DI-engine强化学习入门(七)如何自定义神经网络模型

在强化学习中,需要根据决策问题和策略选择合适的神经网络。DI-engine中,神经网络模型可以通过两种方式指定:

  1. 使用配置文件中的cfg.policy.model自动生成默认模型。这种方式下,可以在配置文件中指定神经网络的类型(MLP、CNN等)以及超参数(隐层大小、激活函数等),DI-engine会根据这些配置自动构建神经网络模型。这种方式简单易用,适用于常见的标准网络结构。
  2. 自定义模型实例并传入Policy。这种方式下,需要用户自己定义Tensorflow/Pytorch模型类,实现前向传播等接口,然后将实例传入Policy中。这种方式灵活度高,用户可以自由设计任意结构的神经网络。但是需要用户比较熟悉网络定义和 Tensorflow/Pytorch接口。

(注:在强化学习中,策略(Policy)是指智能体(Agent)决策的规则。策略是从状态(State)到动作(Action)的映射,它定义了在给定的状态下,智能体应该采取什么动作。策略可以是确定性的(Deterministic)也可以是随机性的(Stochastic)。)
以上两种方式都会在Policy中封装为neural_net属性,策略学习会通过这个网络完成状态的 embedding 以及动作的选择。这套机制和接口为用户提供了必要的灵活性,可以根据具体问题和需求配置各种神经网络模型。
这是红色的文字

Policy 默认使用的模型是什么
DI-engine 中已经实现的 policy,默认使用 default_model 方法中表明的神经网络模型,例如在 SACPolicy 中:

@POLICY_REGISTRY.register('sac')class SACPolicy(Policy):...    def default_model(self) -> Tuple[str, List[str]]:        if self._cfg.multi_agent:            return 'maqac_continuous', ['ding.model.template.maqac']        else:            return 'qac', ['ding.model.template.qac']...

在这段代码中,我们看到的是一个名为 DI-engine 的强化学习框架中的一个策略(Policy)类的一部分。具体来说,它定义了一个使用Soft Actor-Critic, SAC 算法的策略类。这个段落描述了如何在这个框架内设置和使用策略相关的神经网络模型。

让我们逐步解释这段代码:

  1. @POLICY_REGISTRY.register('sac') 是一个装饰器,它将 SACPolicy 类注册到一个名为 POLICY_REGISTRY 的注册器中,并且用 'sac' 作为这个策略的标识符。这样的注册机制允许框架能够根据名字轻松地查找和实例化策略。
  2. class SACPolicy(Policy): 表明 SACPolicy 是从更一般的 Policy 类派生的,它是一个具体的策略实现,使用了 SAC 算法。
  3. default_model 是 SACPolicy 类的一个方法,它定义了该策略默认使用的模型。这个方法返回两个元素的元组:
  • 'maqac_continuous' 或 'qac':这是在模型注册器中注册的模型名字。根据配置是否是多智能体(multi_agent),它返回不同的模型名。
  • ['ding.model.template.maqac'] 或 ['ding.model.template.qac']:这是包含模型类的文件路径的列表。这个路径告诉 DI-engine 在哪里可以找到定义模型的代码。

4.当使用配置文件时,DI-engine 的入口文件将使用 cfg.policy.model 中的参数(比如 obs_shape, action_shape)来实例化提供的模型类。这个过程是自动化的,意味着用户定义好配置,DI-engine 将负责根据这些配置创建并初始化模型。

5.模型类会根据传入的参数构造适当的神经网络。例如,如果传入的 obs_shape 参数表明观测是一个图像,则模型可能会使用卷积层来处理输入;如果观测是一个向量,则可能使用全连接层。

简而言之,这段代码展示了 DI-engine 如何灵活地处理不同类型的策略和模型,以及如何通过配置文件来方便地自定义和实例化这些策略和模型。这种设计允许研究者和开发者能够轻松试验不同的算法和模型架构,而无需直接修改代码。

如何自定义神经网络模型

在 DI-engine 强化学习框架中,每个策略(如 SACPolicy)通常有一个关联的默认模型(通过 default_model 方法指定),这个默认模型是为特定类型的任务设计的。例如,原始的 qac 模型可能是为处理具有一维观测空间的环境设计的,即观测是一个向量。

但是!如果任务是在一个模拟器(如 dmc2gym,一个DeepMind Control Suite到OpenAI Gym接口的适配器)上运行,并且任务是 cartpole-swingup,而且你希望使用观测为像素的输入(即观测是一个图像),那么默认的 qac 模型不足以处理这样的高维度和多通道的输入。在这种情况下,观测空间的形状是 (3, height, width),其中 3 表示图像的颜色通道数(RGB),height 和 width 分别表示图像的高度和宽度。

在 dmc2gym 文档中,from_pixel 参数设定为 True 意味着环境将提供像素级的观测,而 channels_first 设定为 True 表明图像的通道维度是第一维(这是PyTorch等深度学习库通常采用的格式)。

面对这样的情况,如果你想要使用 SAC 算法处理像素级的观测,你需要自定义一个能够处理这种高维观测的模型。所以我们创建一个新的模型类,该类在内部使用卷积神经网络(CNN)来处理输入的图像数据,并适当地修改网络架构以适应任务的特定要求。

自定义模型完成后,可以将这个模型应用到 SACPolicy 中,替换原本的 qac 模型。涉及到以下几个步骤:

  1. 实现一个新的模型类,它继承自某个基础模型类,并覆盖必要的方法以支持像素级输入。
  2. 在策略配置中指定你的自定义模型,以便 DI-engine 使用你提供的模型而不是默认模型。
  3. 确保你的自定义模型注册到 DI-engine 的模型注册器中,这样框架可以识别和使用它。

自定义 model 基本步骤
1. 明确环境 (env) 和策略 (policy)
首先,需要确定你的强化学习任务的具体环境和任务。例如,我们选择 dmc2gym 环境中的 cartpole-swingup 任务,并且决定观测将以像素数据的形式提供,我们的观测空间是一个图像,其形状为 (3, height, width)。下面我们使用 SAC 算法来进行学习。

在这里,from_pixel = True 表明环境将提供基于像素的观测,channels_first = True 表明图像数据的通道维度在前,这通常是深度学习库(如 PyTorch)的标准格式。

2. 查阅策略中的 default_model 是否适用
接下来,需要检查选择的策略是否具有适用于任务的默认模型。这可以通过查看策略的文档或直接查阅源代码来完成。以 DI-engine 中的 SAC 策略为例,可以查看 SACPolicy 类中的 default_model 方法来了解默认模型:

如果进一步看一下 ding.model.template.qac 中的 QAC 模型,咱们可能会发现它仅支持一维的观测空间,而不支持像 (3, height, width) 这样的图像形状。这意味着对于我们的 cartpole-swingup 任务,需要创建一个自定义模型来处理像素级的观测。

3. 自定义模型 (custom_model) 实现
自定义模型的实现需要遵循一些基本的原则,以确保与 DI-engine 框架的兼容性。

a. 实现功能
自定义模型需要实现默认模型中的所有公共方法。包括:

  • __init__: 构造函数,对模型的各个部分进行初始化。
  • forward: 定义模型如何从输入到输出的前向传递。
  • compute_actor: 计算策略网络的输出,即给定观测值时的动作。
  • compute_critic: 计算价值网络的输出,即动作的价值。

b. 保持返回值类型一致
自定义模型的方法需要保证与原始默认模型的返回值类型一致,以便于替换使用。

c. 利用已实现的 encoder 和 head
在 ding/model/common 下有多种 encoder 和 head 的实现,可用于构建不同部分的模型:

  • Encoder: 负责对输入数据进行编码,使其适合后续的处理。例如,ConvEncoder 用于处理图像观测输入,FCEncoder 用于处理一维观测输入。

点击DI-engine强化学习入门(七)如何自定义神经网络模型 - 古月居可查看全文

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/10711.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Leetcode 107:二叉树的层次遍历II

给你二叉树的根节点 root &#xff0c;返回其节点值 自底向上的层序遍历 。 &#xff08;即按从叶子节点所在层到根节点所在的层&#xff0c;逐层从左向右遍历&#xff09;。 思路&#xff1a;翻转title102的结果即可。 //层次遍历二叉树public static List<List<Integ…

RuvarOA协同办公平台 多处 SQL注入漏洞复现

0x01 产品简介 RuvarOA办公自动化系统是广州市璐华计算机科技有限公司采用组件技术和Web技术相结合,基于Windows平台,构建在大型关系数据库管理系统基础上的,以行政办公为核心,以集成融通业务办公为目标,将网络与无线通讯等信息技术完美结合在一起设计而成的新型办公自动…

安卓LeakCanary研究

一、安卓LeakCanary概述 LeakCanary是Square公司推出的一款开源的内存泄漏检测工具&#xff0c;专为Android平台设计。它通过简洁直观的方式帮助开发者识别和解决应用程序中的内存泄漏问题&#xff0c;从而优化应用性能&#xff0c;减少崩溃风险。LeakCanary的核心优势在于其自…

el-checkbox选中后的值为id,组件显示为label中文

直接上代码 方法一 <el-checkbox v-for"item in list" :key"item.id" :label"item.id">{{中文}} </el-checkbox> 方法二 <el-checkbox-group class"flex_check" v-model"rkStatusList" v-for"item…

react 逻辑 AND 运算符 ()

在 React 组件中&#xff0c;当你想在条件为 true 时渲染一些 JSX 时&#xff0c;它经常会出现&#xff0c;或者什么都不渲染。使用 &#xff0c;只有在以下情况下才能有条件地呈现复选标记&#xff1a;&&isPackedtrue return (<li className"item">{…

续篇——源码部署LAMP环境上线项目——禅道项目

上篇:LNMP环境部署WordPress——使用源码包安装方式部署环境-CSDN博客 目录 一.前提准备 1. 名词区别 2. 下载项目软件包 3. 上传项目源码到虚拟机并解压 二.安装Apache 1. 环境清理 2.关闭Nginx 3. 下载Apache 4. 下载APR组件 4.1 安装apr 4.2 安装apr-util组件 5…

MySQL运维总结

以下是个人工作中用到的mysql运维总结。 基本运维命令 看下死锁的语句&#xff1a;show engine innodb status \G; 修改最大连接数&#xff1a;set global max_connections1400; 使用profile查询sql执行耗时&#xff1a; 1、set profiling 1 ; 启用profile , session级别的配…

Kotlin: ‘return‘ is not allowed here

报错&#xff1a;以下函数的内部函数return语句报错 Kotlin: return is not allowed here fun testReturn(summary: (String) -> String): String {var msg summary("summary收到参数")println("test内部调用参数&#xff1a;>结果是 &#xff1a;${msg…

数据分享—全国分省河流水系

河流水系数据是日常研究中必备的数据之一&#xff0c;本期推文主要分享全国分省份的水系和河流数据&#xff0c;梧桐君会不定期的更新数据&#xff0c;欢迎长期订阅。 数据预览 山东省河流水系 吉林省河流水系 四川省河流水系 数据获取方式 链接&#xff1a;https://pan.baidu.…

C++容器常用集合(附传送门)

C常用的容器&#xff1a; string容器 C容器——string-CSDN博客 储存字符串的 vector容器 C容器——vector-CSDN博客 向量是动态数组&#xff0c;可以自动扩展以容纳更多元素。 插入和删除元素的时间复杂度取决于操作的位置 tuple容器&#xff08;元组&#xff09; C容器…

永久免费的多域名通配符SSL证书申请流程

如果拥有多个域名&#xff0c;且有部分域名拥有子域名&#xff0c;那么多域名通配符证书是非常合适的选择。预算有限或者前期测试可以考虑免费版本的&#xff0c;国产证书厂商JoySSL则提供免费的多域名通配符证书 。 具体流程如下 1创建管理账号 登录JoySSL官网&#xff0c;创…

【启明智显分享】国产自主HMI核心板Model3

Model3是一款高性能的工业级HMI&#xff08;人机界面&#xff09;核心板&#xff0c;也是一款纯国产HMI方案&#xff0c;工业级标准&#xff0c;稳定、可靠&#xff1b; 工业级HMI芯片–Model3 纯国产HMI方案 Model3核心板&#xff0c;具有2D加速&#xff0c;PNG解码&…

AI学习指南概率论篇-概率分布

AI学习指南概率论篇-概率分布 概率分布的概述 概率分布是概率论中的一个重要概念&#xff0c;用于描述随机变量的取值和其对应的概率。概率分布可以帮助我们理解和预测事件发生的可能性&#xff0c;并在AI中扮演着重要角色。在机器学习和深度学习中&#xff0c;概率分布被广泛…

【Python单点知识】深入理解与应用类多态

文章目录 0. 前言1. 多态类的概念2. Python中实现多态类的途径2.1 类的继承2.2 抽象基类2.3 duck typing 3. 多态类的应用场景4. 结论 0. 前言 按照国际惯例&#xff0c;首先声明&#xff1a;本文只是我自己学习的理解&#xff0c;虽然参考了他人的宝贵见解及成果&#xff0c;但…

Day25 代码随想录打卡|栈与队列篇---用队列实现栈

题目&#xff08;leecode T225&#xff09;&#xff1a; 请你仅使用两个队列实现一个后入先出&#xff08;LIFO&#xff09;的栈&#xff0c;并支持普通栈的全部四种操作&#xff08;push、top、pop 和 empty&#xff09;。 实现 MyStack 类&#xff1a; void push(int x) 将…

嵌入式学习day17

FIFO FIFO也称命名管道&#xff0c;它是一种文件类型 特点 FIFO可以在无关的进程之间交换数据&#xff0c;与无名管道不同FIFO有路径名与之相关联&#xff0c;它以一种特殊设备文件形式存在于文件系统中。FIFO的通信方式类似于在进程中使用文件来传输数据&#xff0c;只不过…

Linux之·网络编程·I/O复用·select

系列文章目录 文章目录 前言一、概述1.1 介绍IO复用的概念和作用1.1.1 I/O复用具体使用的场景1.1.2 I/O复用常用函数 二、select函数的重要性和用途2.1 基本的select函数2.2 如何使用FD_SET、FD_CLR等宏来设置和清除文件描述符集合2.3 select()函数函数整体使用框架&#xff1a…

linux性能监控之slabtop

slabtop命令是以实时的方式显示内核slab缓冲区的细节信息&#xff0c;是linux自带的命令 [rootk8s-master ~]# slabtop --helpUsage:slabtop [options]Options:-d, --delay <secs> delay updates-o, --once only display once, then exit-s, --sort <char&…

服务器硬件命令查看

服务器硬件命令查看 1. 主板 sudo dmidecode -t baseboard2. CPU # CPU型号&#xff08;product→ version&#xff09;、CPU名称&#xff08;id 约定一个名称 cpu cpu:0、cpu:1&#xff09;、厂商(vendor)、主频(size HZ)、核数、架构&#xff08;缺失 product&#xf…

MVC 过滤器

MVC 过滤器常用有4种 Action过滤器&#xff08;IActionFilter&#xff09; 》 行为过滤器Result过滤器 &#xff08;IResultFilter&#xff09;》 视图过滤器 或 结果过滤器Exception过滤器&#xff08;IExceptionFilter&#xff09;》 异常过滤器Authorization过滤器&#xf…