DI-engine强化学习入门(七)如何自定义神经网络模型

在强化学习中,需要根据决策问题和策略选择合适的神经网络。DI-engine中,神经网络模型可以通过两种方式指定:

  1. 使用配置文件中的cfg.policy.model自动生成默认模型。这种方式下,可以在配置文件中指定神经网络的类型(MLP、CNN等)以及超参数(隐层大小、激活函数等),DI-engine会根据这些配置自动构建神经网络模型。这种方式简单易用,适用于常见的标准网络结构。
  2. 自定义模型实例并传入Policy。这种方式下,需要用户自己定义Tensorflow/Pytorch模型类,实现前向传播等接口,然后将实例传入Policy中。这种方式灵活度高,用户可以自由设计任意结构的神经网络。但是需要用户比较熟悉网络定义和 Tensorflow/Pytorch接口。

(注:在强化学习中,策略(Policy)是指智能体(Agent)决策的规则。策略是从状态(State)到动作(Action)的映射,它定义了在给定的状态下,智能体应该采取什么动作。策略可以是确定性的(Deterministic)也可以是随机性的(Stochastic)。)
以上两种方式都会在Policy中封装为neural_net属性,策略学习会通过这个网络完成状态的 embedding 以及动作的选择。这套机制和接口为用户提供了必要的灵活性,可以根据具体问题和需求配置各种神经网络模型。
这是红色的文字

Policy 默认使用的模型是什么
DI-engine 中已经实现的 policy,默认使用 default_model 方法中表明的神经网络模型,例如在 SACPolicy 中:

@POLICY_REGISTRY.register('sac')class SACPolicy(Policy):...    def default_model(self) -> Tuple[str, List[str]]:        if self._cfg.multi_agent:            return 'maqac_continuous', ['ding.model.template.maqac']        else:            return 'qac', ['ding.model.template.qac']...

在这段代码中,我们看到的是一个名为 DI-engine 的强化学习框架中的一个策略(Policy)类的一部分。具体来说,它定义了一个使用Soft Actor-Critic, SAC 算法的策略类。这个段落描述了如何在这个框架内设置和使用策略相关的神经网络模型。

让我们逐步解释这段代码:

  1. @POLICY_REGISTRY.register('sac') 是一个装饰器,它将 SACPolicy 类注册到一个名为 POLICY_REGISTRY 的注册器中,并且用 'sac' 作为这个策略的标识符。这样的注册机制允许框架能够根据名字轻松地查找和实例化策略。
  2. class SACPolicy(Policy): 表明 SACPolicy 是从更一般的 Policy 类派生的,它是一个具体的策略实现,使用了 SAC 算法。
  3. default_model 是 SACPolicy 类的一个方法,它定义了该策略默认使用的模型。这个方法返回两个元素的元组:
  • 'maqac_continuous' 或 'qac':这是在模型注册器中注册的模型名字。根据配置是否是多智能体(multi_agent),它返回不同的模型名。
  • ['ding.model.template.maqac'] 或 ['ding.model.template.qac']:这是包含模型类的文件路径的列表。这个路径告诉 DI-engine 在哪里可以找到定义模型的代码。

4.当使用配置文件时,DI-engine 的入口文件将使用 cfg.policy.model 中的参数(比如 obs_shape, action_shape)来实例化提供的模型类。这个过程是自动化的,意味着用户定义好配置,DI-engine 将负责根据这些配置创建并初始化模型。

5.模型类会根据传入的参数构造适当的神经网络。例如,如果传入的 obs_shape 参数表明观测是一个图像,则模型可能会使用卷积层来处理输入;如果观测是一个向量,则可能使用全连接层。

简而言之,这段代码展示了 DI-engine 如何灵活地处理不同类型的策略和模型,以及如何通过配置文件来方便地自定义和实例化这些策略和模型。这种设计允许研究者和开发者能够轻松试验不同的算法和模型架构,而无需直接修改代码。

如何自定义神经网络模型

在 DI-engine 强化学习框架中,每个策略(如 SACPolicy)通常有一个关联的默认模型(通过 default_model 方法指定),这个默认模型是为特定类型的任务设计的。例如,原始的 qac 模型可能是为处理具有一维观测空间的环境设计的,即观测是一个向量。

但是!如果任务是在一个模拟器(如 dmc2gym,一个DeepMind Control Suite到OpenAI Gym接口的适配器)上运行,并且任务是 cartpole-swingup,而且你希望使用观测为像素的输入(即观测是一个图像),那么默认的 qac 模型不足以处理这样的高维度和多通道的输入。在这种情况下,观测空间的形状是 (3, height, width),其中 3 表示图像的颜色通道数(RGB),height 和 width 分别表示图像的高度和宽度。

在 dmc2gym 文档中,from_pixel 参数设定为 True 意味着环境将提供像素级的观测,而 channels_first 设定为 True 表明图像的通道维度是第一维(这是PyTorch等深度学习库通常采用的格式)。

面对这样的情况,如果你想要使用 SAC 算法处理像素级的观测,你需要自定义一个能够处理这种高维观测的模型。所以我们创建一个新的模型类,该类在内部使用卷积神经网络(CNN)来处理输入的图像数据,并适当地修改网络架构以适应任务的特定要求。

自定义模型完成后,可以将这个模型应用到 SACPolicy 中,替换原本的 qac 模型。涉及到以下几个步骤:

  1. 实现一个新的模型类,它继承自某个基础模型类,并覆盖必要的方法以支持像素级输入。
  2. 在策略配置中指定你的自定义模型,以便 DI-engine 使用你提供的模型而不是默认模型。
  3. 确保你的自定义模型注册到 DI-engine 的模型注册器中,这样框架可以识别和使用它。

自定义 model 基本步骤
1. 明确环境 (env) 和策略 (policy)
首先,需要确定你的强化学习任务的具体环境和任务。例如,我们选择 dmc2gym 环境中的 cartpole-swingup 任务,并且决定观测将以像素数据的形式提供,我们的观测空间是一个图像,其形状为 (3, height, width)。下面我们使用 SAC 算法来进行学习。

在这里,from_pixel = True 表明环境将提供基于像素的观测,channels_first = True 表明图像数据的通道维度在前,这通常是深度学习库(如 PyTorch)的标准格式。

2. 查阅策略中的 default_model 是否适用
接下来,需要检查选择的策略是否具有适用于任务的默认模型。这可以通过查看策略的文档或直接查阅源代码来完成。以 DI-engine 中的 SAC 策略为例,可以查看 SACPolicy 类中的 default_model 方法来了解默认模型:

如果进一步看一下 ding.model.template.qac 中的 QAC 模型,咱们可能会发现它仅支持一维的观测空间,而不支持像 (3, height, width) 这样的图像形状。这意味着对于我们的 cartpole-swingup 任务,需要创建一个自定义模型来处理像素级的观测。

3. 自定义模型 (custom_model) 实现
自定义模型的实现需要遵循一些基本的原则,以确保与 DI-engine 框架的兼容性。

a. 实现功能
自定义模型需要实现默认模型中的所有公共方法。包括:

  • __init__: 构造函数,对模型的各个部分进行初始化。
  • forward: 定义模型如何从输入到输出的前向传递。
  • compute_actor: 计算策略网络的输出,即给定观测值时的动作。
  • compute_critic: 计算价值网络的输出,即动作的价值。

b. 保持返回值类型一致
自定义模型的方法需要保证与原始默认模型的返回值类型一致,以便于替换使用。

c. 利用已实现的 encoder 和 head
在 ding/model/common 下有多种 encoder 和 head 的实现,可用于构建不同部分的模型:

  • Encoder: 负责对输入数据进行编码,使其适合后续的处理。例如,ConvEncoder 用于处理图像观测输入,FCEncoder 用于处理一维观测输入。

点击DI-engine强化学习入门(七)如何自定义神经网络模型 - 古月居可查看全文

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/10711.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

RuvarOA协同办公平台 多处 SQL注入漏洞复现

0x01 产品简介 RuvarOA办公自动化系统是广州市璐华计算机科技有限公司采用组件技术和Web技术相结合,基于Windows平台,构建在大型关系数据库管理系统基础上的,以行政办公为核心,以集成融通业务办公为目标,将网络与无线通讯等信息技术完美结合在一起设计而成的新型办公自动…

el-checkbox选中后的值为id,组件显示为label中文

直接上代码 方法一 <el-checkbox v-for"item in list" :key"item.id" :label"item.id">{{中文}} </el-checkbox> 方法二 <el-checkbox-group class"flex_check" v-model"rkStatusList" v-for"item…

续篇——源码部署LAMP环境上线项目——禅道项目

上篇:LNMP环境部署WordPress——使用源码包安装方式部署环境-CSDN博客 目录 一.前提准备 1. 名词区别 2. 下载项目软件包 3. 上传项目源码到虚拟机并解压 二.安装Apache 1. 环境清理 2.关闭Nginx 3. 下载Apache 4. 下载APR组件 4.1 安装apr 4.2 安装apr-util组件 5…

Kotlin: ‘return‘ is not allowed here

报错&#xff1a;以下函数的内部函数return语句报错 Kotlin: return is not allowed here fun testReturn(summary: (String) -> String): String {var msg summary("summary收到参数")println("test内部调用参数&#xff1a;>结果是 &#xff1a;${msg…

数据分享—全国分省河流水系

河流水系数据是日常研究中必备的数据之一&#xff0c;本期推文主要分享全国分省份的水系和河流数据&#xff0c;梧桐君会不定期的更新数据&#xff0c;欢迎长期订阅。 数据预览 山东省河流水系 吉林省河流水系 四川省河流水系 数据获取方式 链接&#xff1a;https://pan.baidu.…

永久免费的多域名通配符SSL证书申请流程

如果拥有多个域名&#xff0c;且有部分域名拥有子域名&#xff0c;那么多域名通配符证书是非常合适的选择。预算有限或者前期测试可以考虑免费版本的&#xff0c;国产证书厂商JoySSL则提供免费的多域名通配符证书 。 具体流程如下 1创建管理账号 登录JoySSL官网&#xff0c;创…

【启明智显分享】国产自主HMI核心板Model3

Model3是一款高性能的工业级HMI&#xff08;人机界面&#xff09;核心板&#xff0c;也是一款纯国产HMI方案&#xff0c;工业级标准&#xff0c;稳定、可靠&#xff1b; 工业级HMI芯片–Model3 纯国产HMI方案 Model3核心板&#xff0c;具有2D加速&#xff0c;PNG解码&…

Day25 代码随想录打卡|栈与队列篇---用队列实现栈

题目&#xff08;leecode T225&#xff09;&#xff1a; 请你仅使用两个队列实现一个后入先出&#xff08;LIFO&#xff09;的栈&#xff0c;并支持普通栈的全部四种操作&#xff08;push、top、pop 和 empty&#xff09;。 实现 MyStack 类&#xff1a; void push(int x) 将…

Linux之·网络编程·I/O复用·select

系列文章目录 文章目录 前言一、概述1.1 介绍IO复用的概念和作用1.1.1 I/O复用具体使用的场景1.1.2 I/O复用常用函数 二、select函数的重要性和用途2.1 基本的select函数2.2 如何使用FD_SET、FD_CLR等宏来设置和清除文件描述符集合2.3 select()函数函数整体使用框架&#xff1a…

linux性能监控之slabtop

slabtop命令是以实时的方式显示内核slab缓冲区的细节信息&#xff0c;是linux自带的命令 [rootk8s-master ~]# slabtop --helpUsage:slabtop [options]Options:-d, --delay <secs> delay updates-o, --once only display once, then exit-s, --sort <char&…

MVC 过滤器

MVC 过滤器常用有4种 Action过滤器&#xff08;IActionFilter&#xff09; 》 行为过滤器Result过滤器 &#xff08;IResultFilter&#xff09;》 视图过滤器 或 结果过滤器Exception过滤器&#xff08;IExceptionFilter&#xff09;》 异常过滤器Authorization过滤器&#xf…

python零基础知识 - 定义列表的三种方式,循环列表索引值

这一小节&#xff0c;我们将从零基础的角度看一下&#xff0c;python都有哪些定义列表的方式&#xff0c;并且循环这个列表的时候&#xff0c;怎么循环&#xff0c;怎么循环他的索引值&#xff0c;怎么拿到的就是元素值。 说完循环&#xff0c;我们会说一说关键的break和contin…

i春秋-GetFlag

题目 考点 sql注入&#xff0c;md5加密&#xff0c;代码审计&#xff0c;利用eval函数 解题 参考wp https://www.cnblogs.com/qiaowukong/p/13630130.html找md5值 看见验证码中的提示&#xff0c;就是去找一个md5值前六位是指定值的数&#xff08;严格来说不一定是数&…

【userfaultfd+条件竞争劫持modprobe_path】TSGCTF 2021 -- lkgit

前言 入门题&#xff0c;单纯就是完成每日一道 kernel pwn 的 kpi &#x1f600; 题目分析 内核版本&#xff1a;v5.10.25&#xff0c;可以使用 userfaultfd&#xff0c;不存在 cg 隔离开启了 smap/smep/kaslr/kpti 保护开启了 SLAB_HADNERN/RANDOM 保护 题目给了源码&…

第二步->手撕spring源码之bean操作

本步骤目标 本步骤继续完善 Spring Bean 容器框架的功能开发&#xff0c;在这个开发过程中会用到较多的接口、类、抽象类&#xff0c;它们之间会有类的实现、类的继承。 这一次我们把 Bean 的创建交给容器&#xff0c;而不是我们在调用时候传递一个实例化好的 Bean 对象&#x…

【git】通过JetBrains IDE对git的操作

应该适用于所有jetbrains产品。 一、拉取(pull)代码 上方工具栏-Git-克隆。然后填写git地址与本地存放地址。 二、搁置 修改代码后搁置代码&#xff08;不提交&#xff0c;但是也不撤销已修改的代码&#xff0c;把它暂存起来&#xff09;。 界面的左上角。1->2->3。…

【网站项目】SpringBoot803房屋租赁管理系统

&#x1f64a;作者简介&#xff1a;拥有多年开发工作经验&#xff0c;分享技术代码帮助学生学习&#xff0c;独立完成自己的项目或者毕业设计。 代码可以私聊博主获取。&#x1f339;赠送计算机毕业设计600个选题excel文件&#xff0c;帮助大学选题。赠送开题报告模板&#xff…

14.跳跃游戏Ⅱ

文章目录 题目简介题目解答解法一&#xff1a;贪心算法动态规划代码&#xff1a;复杂度分析&#xff1a; 题目链接 大家好&#xff0c;我是晓星航。今天为大家带来的是 跳跃游戏Ⅱ 相关的讲解&#xff01;&#x1f600; 题目简介 题目解答 解法一&#xff1a;贪心算法动态规划…

《QT实用小工具·六十三》QT实现微动背景,界面看似静态实则动态

1、概述 源码放在文章末尾 该项目实现了微动背景&#xff0c;界面看似静态实则动态&#xff0c;风动&#xff0c;幡动&#xff0c;仁者心动&#xff0c;所以到底是什么在动&#xff1f;哈哈~ 界面会偷偷一点一点改动文字颜色的颜色填充。 虽然是动态&#xff0c;但是慢到难以…

Python---Numpy万字总结(2)

NumPy的应用&#xff08;2&#xff09; 数组对象的方法 获取描述统计信息 描述统计信息主要包括数据的集中趋势、离散程度和频数分析等&#xff0c;其中集中趋势主要看均值和中位数&#xff0c;离散程度可以看极值、方差、标准差等 array1 np.random.randint(1, 100, 10) …