ray.rllib-入门实践-11: 自定义模型/网络

在ray.rllib中定义和使用自己的模型, 分为以下三个步骤:

1. 定义自己的模型。

2. 向ray注册自定义的模型

3. 在config中配置使用自定义的模型

环境配置:

        torch==2.5.1

        ray==2.10.0

        ray[rllib]==2.10.0

        ray[tune]==2.10.0

        ray[serve]==2.10.0

        numpy==1.23.0

        python==3.9.18

一、 定义自己的模型 

需要继承自 TFModel 或 TorchModelV2, 并重写需要自定义的方法, 其代码框架如下:

import torch.nn as nn
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2class My_Model(TorchModelV2, nn.Module): ## 重构以下函数, 函数接口不能变。def __init__(self, obs_space, action_space, num_outputs, model_config, name, *, custom_arg1, custom_arg2): ...def forward(self, input_dict, state, seq_lens): ...def value_function(self): ...

示例如下:

## 1. 定义自己的模型
import numpy as np 
import torch.nn as nn
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
import gymnasium as gym 
from gymnasium import spaces  
from ray.rllib.utils.typing import Dict, TensorType, List, ModelConfigDictclass My_Model(TorchModelV2, nn.Module):def __init__(self, obs_space:gym.spaces.Space, action_space:gym.spaces.Space, num_outputs:int, model_config:ModelConfigDict,  ## PPOConfig.training(model = ModelConfigDict), 调用的是config.model中的参数name:str,*, custom_arg1, custom_arg2):TorchModelV2.__init__(self, obs_space, action_space, num_outputs,model_config,name)nn.Module.__init__(self)## 测试 custom_arg1 , custom_arg2 传递进来的是什么数值print(f"===========================  custom_arg1 = {custom_arg1}, custom_arg2 = {custom_arg2}")## 定义网络层obs_dim = int(np.product(obs_space.shape))action_dim = int(np.product(action_space.shape))## shareNetself.shared_fc = nn.Linear(obs_dim,128)## actorNetself.actorNet = nn.Linear(128, action_dim)## criticNetself.criticNet = nn.Linear(128,1)self._feature = None def forward(self, input_dict, state, seq_lens):obs = input_dict["obs"].float()self._feature = self.shared_fc.forward(obs)action_logits = self.actorNet.forward(self._feature)return action_logits, state def value_function(self):value = self.criticNet.forward(self._feature).squeeze(1)return value 

        在rllib中,每个算法的所有网络都被汇集到同一个 ModelV2 类下,供算法调用。actor 网络和critic网络可以在外面定义,也可以在model内部直接定义。 model的forward用于返回actor网络的输出, value_function函数用于返回critic网络的输出。 网络结构和网络层共享可以自定义设置。输入输出接口,需要与上面保持严格一致。

二、 向ray注册自定义模型

        ray.rllib.model.ModelCatalog 类,用于向ray注册自定义的model, 还可以用于获取env的 preprocessors 和 action distributions。

import ray 
from ray.rllib.models import ModelCatalog # ModelCatalog 类: 用于注册 models, 获取env的 preprocessors 和 action distributions。 ModelCatalog.register_custom_model(model_name="my_torch_model", model_class = My_Model)

三、 在算法中配置并使用自定义的模型

主要是在 config.training() 模块中的 model 子模块中传入两个配置信息:

        1)"custom_model":"my_torch_model" ,                      
         2)"custom_model_config": {"custom_arg1": 1, "custom_arg2": 2,}})  

两个关键字固定不变,填入自己注册的模型名和对应的模型参数即可。

可以有以下三种配置代码的编写方式:

配置方法1:

## 3. 在训练中使用自定义模型
from ray.rllib.algorithms.ppo import PPOConfig
from ray.tune.logger import pretty_print config = PPOConfig()
config = config.environment("CartPole-v1")
config = config.rollouts(num_rollout_workers=2)
config = config.framework(framework="torch") ## 配置使用自定义的模型
config = config.training(model= {"custom_model":"my_torch_model" ,                      "custom_model_config": {"custom_arg1": 1, "custom_arg2": 2,}})  
## 主要在上面两行配置使用自己的模型
##    配置 model 的 "custom_model" 项,用于指定rllib算法所使用的模型
##    配置 model 的 "custom_model_config" 项,用于传入自定义的网络参数,供自定义的model使用。
##    这两个关键词不可更改。algo = config.build()
## 4. 执行训练
result = algo.train()
print(pretty_print(result))

与以上配置内容一样,还可以用以下两种配置写法:

配置方法2:

config = PPOConfig()
config = config.environment("CartPole-v1")
config = config.rollouts(num_rollout_workers=2)
config = config.framework(framework="torch") ## 配置自定义模型
model_config_dict = {}
model_config_dict["custom_model"] = "my_torch_model" 
model_config_dict["custom_model_config"] = {"custom_arg1": 1, "custom_arg2": 2,}
config = config.training(model= model_config_dict)  algo = config.build()

 配置方法3(推荐):

config = PPOConfig()
config = config.environment("CartPole-v1")
config = config.rollouts(num_rollout_workers=2)
config = config.framework(framework="torch") ## 配置自定义模型
config.model["custom_model"] = "my_torch_model"
config.model["custom_model_config"] = {"custom_arg1": 1, "custom_arg2": 2,}algo = config.build()

 代码汇总:

"""
在ray.rllib中定义和使用自己的模型, 分为以下三个步骤:
1. 定义自己的模型。 需要继承自 TFModel 或 TorchModelV2, 并重写需要自定义的方法import torch.nn as nnfrom ray.rllib.models.torch.torch_modelv2 import TorchModelV2class CustomTorchModel(TorchModelV2, nn.Module): ## 重构以下函数, 函数接口不能变。 def __init__(self, obs_space, action_space, num_outputs, model_config, name, *, custom_arg1, custom_arg2): ...def forward(self, input_dict, state, seq_lens): ...def value_function(self): ...2. 向ray注册自定义的模型from ray.rllib.models import ModelCatalogModelCatalog.register_custom_model("wzg_torch_model", CustomTorchModel)3. 在config中配置使用自定义的模型model_config_dict = {"custom_model":"wzg_torch_model","custom_model_config":{"custom_arg1": 1,"custom_arg2": 2}}config = PPOConfig()# config = config.training(model = model_config_dict)config.model["custom_model"] = "wzg_torch_model"config.model["custom_model_config"] = {"custom_arg1": 1,"custom_arg2": 2}
"""## 1. 定义自己的模型
import numpy as np 
import torch.nn as nn
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
import gymnasium as gym 
from gymnasium import spaces  
from ray.rllib.utils.typing import Dict, TensorType, List, ModelConfigDictclass My_Model(TorchModelV2, nn.Module):def __init__(self, obs_space:gym.spaces.Space, action_space:gym.spaces.Space, num_outputs:int, model_config:ModelConfigDict,  ## PPOConfig.training(model = ModelConfigDict), 调用的是config.model中的参数name:str,*, custom_arg1, custom_arg2):TorchModelV2.__init__(self, obs_space, action_space, num_outputs,model_config,name)nn.Module.__init__(self)## 测试 custom_arg1 , custom_arg2 传递进来的是什么数值print(f"===========================  custom_arg1 = {custom_arg1}, custom_arg2 = {custom_arg2}")## 定义网络层obs_dim = int(np.product(obs_space.shape))action_dim = int(np.product(action_space.shape))## shareNetself.shared_fc = nn.Linear(obs_dim,128)## actorNetself.actorNet = nn.Linear(128, action_dim)## criticNetself.criticNet = nn.Linear(128,1)self._feature = None def forward(self, input_dict, state, seq_lens):obs = input_dict["obs"].float()self._feature = self.shared_fc.forward(obs)action_logits = self.actorNet.forward(self._feature)return action_logits, state def value_function(self):value = self.criticNet.forward(self._feature).squeeze(1)return value ## 2. 向ray注册自定义模型
import ray 
from ray.rllib.models import ModelCatalog # ModelCatalog 类: 用于注册 models, 获取env的 preprocessors 和 action distributions。 ModelCatalog.register_custom_model(model_name="my_torch_model", model_class = My_Model)
ray.init()## 3. 在训练中使用自定义模型
from ray.rllib.algorithms.ppo import PPOConfig
from ray.tune.logger import pretty_print config = PPOConfig()
config = config.environment("CartPole-v1")
config = config.rollouts(num_rollout_workers=2)
config = config.framework(framework="torch") # ## 配置自定义模型:方法 1
# config = config.training(model= {"custom_model":"my_torch_model" ,                      
#                                  "custom_model_config": {"custom_arg1": 1, "custom_arg2": 2,}})  
# ## 配置自定义模型:方法 2
# model_config_dict = {}
# model_config_dict["custom_model"] = "my_torch_model" 
# model_config_dict["custom_model_config"] = {"custom_arg1": 1, "custom_arg2": 2,}
# config = config.training(model= model_config_dict) ## 配置自定义模型: 方法 3 (个人更喜欢, 因为嵌套层次少)
config.model["custom_model"] = "my_torch_model"
config.model["custom_model_config"] = {"custom_arg1": 1, "custom_arg2": 2,}## 错误方法:
# model_config_dict = {}
# model_config_dict["custom_model"] = "my_torch_model" 
# model_config_dict["custom_model_config"] = {"custom_arg1": 1, "custom_arg2": 2,}
# config.model = model_config_dict # 会清空 model 里面的其他默认配置,导致报错algo = config.build()## 4. 执行训练
result = algo.train()
print(pretty_print(result))


 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/893725.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

攻防世界easyRSA

解密脚本: p473398607161 q4511491 e17def extended_euclidean(a, b):if b 0:return a, 1, 0gcd, x1, y1 extended_euclidean(b, a % b)x y1y x1 - (a // b) * y1return gcd, x, ydef calculate_private_key(p, q, e):phi (p - 1) * (q - 1)gcd, x, y extend…

常见的多媒体框架(FFmpeg GStreamer DirectShow AVFoundation OpenMax)

1.FFmpeg FFmpeg是一个非常强大的开源多媒体处理框架,它提供了一系列用于处理音频、视频和多媒体流的工具和库。它也是最流行且应用最广泛的框架! 官方网址:https://ffmpeg.org/ FFmpeg 的主要特点和功能: 编解码器支持: FFmpe…

.NET MAUI进行UDP通信(二)

上篇文章有写过一个简单的demo&#xff0c;本次对项目进行进一步的扩展&#xff0c;添加tabbar功能。 1.修改AppShell.xaml文件&#xff0c;如下所示&#xff1a; <?xml version"1.0" encoding"UTF-8" ?> <Shellx:Class"mauiDemo.AppShel…

计算机网络之链路层

本文章目录结构出自于《王道计算机考研 计算机网络_哔哩哔哩_bilibili》 02 数据链路层 在网上看到其他人做了详细的笔记&#xff0c;就不再多余写了&#xff0c;直接参考着学习吧。 1 详解数据链路层-数据链路层的功能【王道计算机网络笔记】_wx63088f6683f8f的技术博客_51C…

YOLOv11改进,YOLOv11检测头融合DSConv(动态蛇形卷积),并添加小目标检测层(四头检测),适合目标检测、分割等任务

前言 精确分割拓扑管状结构例如血管和道路,对各个领域至关重要,可确保下游任务的准确性和效率。然而,许多因素使任务变得复杂,包括细小脆弱的局部结构和复杂多变的全局形态。在这项工作中,注意到管状结构的特殊特征,并利用这一知识来引导 DSCNet 在三个阶段同时增强感知…

Flutter android debug 编译报错问题。插件编译报错

下面相关内容 都以 Mac 电脑为例子。 一、问题 起因&#xff1a;&#xff08;更新 Android studio 2024.2.2.13、 Flutter SDK 3.27.2&#xff09; 最近 2025年 1 月 左右&#xff0c;我更新了 Android studio 和 Flutter SDK 再运行就会出现下面的问题。当然 下面的提示只是其…

扣子平台音频功能:让声音也能“智能”起来

在数字化时代&#xff0c;音频内容的重要性不言而喻。无论是在线课程、有声读物&#xff0c;还是各种多媒体应用&#xff0c;音频都是传递信息、增强体验的关键元素。扣子平台的音频功能&#xff0c;为开发者和内容创作者提供了一个强大而灵活的工具&#xff0c;让音频的使用和…

RubyFPV开源代码之系统简介

RubyFPV开源代码之系统简介 1. 源由2. 工程架构3. 特性介绍&#xff08;软件&#xff09;3.1 特性亮点3.2 数字优势3.3 使用功能 4. DEMO推荐&#xff08;硬件&#xff09;4.1 天空端4.2 地面端4.3 按键硬件Raspberry PiRadxa 3W/E/C 5. 软件设计6. 参考资料 1. 源由 RubyFPV以…

将 OneLake 数据索引到 Elasticsearch - 第二部分

作者&#xff1a;来自 Elastic Gustavo Llermaly 及 Jeffrey Rengifo 本文分为两部分&#xff0c;第二部分介绍如何使用自定义连接器将 OneLake 数据索引并搜索到 Elastic 中。 在本文中&#xff0c;我们将利用第 1 部分中学到的知识来创建 OneLake 自定义 Elasticsearch 连接器…

PMP–一、二、三模–分类–14.敏捷

文章目录 敏捷中的角色职责与3个工件--题干关键词角色职责3个工件 高频考点分析&#xff08;一、过程&#xff1b;二、人员&#xff09;一、过程&#xff1a;1.1 变更管理&#xff1a;1.1.1 瀑布型变更&#xff08;一次交付、尽量限制、确定性需求 &#xff1e;风险储备&#x…

Vue2下篇

插槽&#xff1a; 基本插槽&#xff1a; 普通插槽&#xff1a;父组件向子组件传递静态内容。基本插槽只能有一个slot标签&#xff0c;因为这个是默认的位置&#xff0c;所以只能有一个 <!-- ParentComponent.vue --> <template> <ChildComponent> <p>…

【科研建模】Pycaret自动机器学习框架使用流程及多分类项目实战案例详解

Pycaret自动机器学习框架使用流程及项目实战案例详解 1 Pycaret介绍2 安装及版本需求3 Pycaret自动机器学习框架使用流程3.1 Setup3.2 Compare Models3.3 Analyze Model3.4 Prediction3.5 Save Model4 多分类项目实战案例详解4.1 ✅ Setup4.2 ✅ Compare Models4.3 ✅ Experime…

Linux学习笔记——网络管理命令

一、网络基础知识 TCP/IP四层模型 以太网地址&#xff08;MAC地址&#xff09;&#xff1a; 段16进制数据 IP地址&#xff1a; 子网掩码&#xff1a; 二、接口管命令 ip命令&#xff1a;字符终端&#xff0c;立即生效&#xff0c;重启配置会丢失 nmcli命令&#xff1a;字符…

手撕Diffusion系列 - 第九期 - 改进为Stable Diffusion(原理介绍)

手撕Diffusion系列 - 第九期 - 改进为Stable Diffusion&#xff08;原理介绍&#xff09; 目录 手撕Diffusion系列 - 第九期 - 改进为Stable Diffusion&#xff08;原理介绍&#xff09;DDPM 原理图Stable Diffusion 原理Stable Diffusion的原理解释Stable Diffusion 和 Diffus…

JAVAweb学习日记(八) 请数据库模型MySQL

一、MySQL数据模型 二、SQL语言 三、DDL 详细见SQL学习日记内容 四、DQL-条件查询 五、DQL-分组查询 聚合函数&#xff1a; 分组查询&#xff1a; 六、DQL-分组查询 七、分页查询 八、多表设计-一对多&一对一&多对多 一对多-外键&#xff1a; 一对一&#xff1a; 多…

微信小程序1.1 微信小程序介绍

1.1 微信小程序介绍 内容提要 1.1 什么是微信小程序 1.2 微信小程序的功能 1.3 微信小程序使用场景 1.4 微信小程序能取代App吗 1.5 微信小程序的发展历程 1.6微信小程序带来的机会

音频入门(一):音频基础知识与分类的基本流程

音频信号和图像信号在做分类时的基本流程类似&#xff0c;区别就在于预处理部分存在不同&#xff1b;本文简单介绍了下音频处理的方法&#xff0c;以及利用深度学习模型分类的基本流程。 目录 一、音频信号简介 1. 什么是音频信号 2. 音频信号长什么样 二、音频的深度学习分…

Midjourney中的强变化、弱变化、局部重绘的本质区别以及其有多逆天的功能

开篇 Midjourney中有3个图片“微调”&#xff0c;它们分别为&#xff1a; 强变化&#xff1b;弱变化&#xff1b;局部重绘&#xff1b; 在Discord里分别都是用命令唤出的&#xff0c;但如今随着AI技术的发达在类似AI可人一类的纯图形化界面中&#xff0c;我们发觉这样的逆天…

【Linux】命令为桥,存在为岸,穿越虚拟世界的哲学之道

文章目录 Linux基础入门&#xff1a;探索操作系统的内核与命令一、Linux背景与发展历史1.1 Linux的起源与发展1.2 Linux与Windows的对比 二、Linux的常用命令2.1 ls命令 - "List"&#xff08;列出文件)2.2 pwd命令 - "Print Working Directory"&#xff08…

[护网杯 2018]easy_tornado1

题目 、 依次点击文件查看 /flag.txt flag in /fllllllllllllag /welcome.txt render /hints.txt md5(cookie_secretmd5(filename)) tornado模板注入 报cookie /error?msg{{handler.settings}} cookie_secret: 6647062b-e68d-4406-90d3-06e307fa955c} 使用python脚本…