ray.rllib-入门实践-12:自定义policy

        在本博客开始之前,先厘清一下几个概念之间的区别与联系:env,  agent,  model, algorithm, policy. 

        强化学习由两部分组成: 环境(env)和智能体(agent)。环境(env)提供观测值和奖励; agent读取观测值,输出动作或决策。agent是algorithm的类对象。 policy是algorithm的子类, 比如ppo, dqn等。因此,自定义policy本质上是自定义algorithm.  algorithm 主要由两部分组成: 网络结构(model)和损失函数(loss)。 网络结构(model)的自定义由上一个博客ray.rllib-入门实践-11: 自定义模型/网络 进行了介绍:在alrorithm外创建新的model类, 通过 AlgorithmConfig类传入algorithm。因此, c从实际操作上, 自定义algorithm就变成了自定义algorithm 的 loss.

        因此,本博客所提到的自定义policy, 本质上就是继承一个Algorithm, 并修改它的loss函数。

        与之前介绍的自定义env, 自定义model一样, 自定义policy也包含三个步骤:

        1. 继承某个Policy, 创建一个新Policy类, 修改它的损失函数。

        2. 把自己的Policy封装为一个Algorithm, 使ray可识别

        3. 配置使用自己的Policy.

环境配置:

        torch==2.5.1

        ray==2.10.0

        ray[rllib]==2.10.0

        ray[tune]==2.10.0

        ray[serve]==2.10.0

        numpy==1.23.0

        python==3.9.18

一、 自定义 policy

import torch 
import gymnasium as gym 
from gymnasium import spaces
from ray.rllib.utils.annotations import override
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.algorithms.ppo import PPO, PPOConfig, PPOTorchPolicy
from typing import Dict, List, Type, Union
from ray.rllib.models.action_dist import ActionDistribution
from ray.rllib.policy.sample_batch import SampleBatch## 1. 自定义 policy, 主要是改变 policy 的 loss 的计算  # 神经网络的损失函数
class MY_PPOTorchPolicy(PPOTorchPolicy):"""PyTorch policy class used with PPO."""def __init__(self, observation_space:gym.spaces.Box, action_space:gym.spaces.Box, config:PPOConfig): PPOTorchPolicy.__init__(self,observation_space,action_space,config)## PPOTorchPolicy 内部对 PPOConfig 格式的config 执行了to_dict()操作,后面可以以 dict 的形式使用 config@override(PPOTorchPolicy) def loss(self,model: ModelV2,dist_class: Type[ActionDistribution],train_batch: SampleBatch):## 原始损失original_loss = super().loss(model, dist_class, train_batch) # PPO原来的损失函数, 也可以完全自定义新的loss函数, 但是非常不建议。## 新增自定义损失,这里以正则化损失作为示例addiontial_loss = torch.tensor(0.0) ## 自己定义的lossaddiontial_loss = torch.tensor(0.)for param in model.parameters():addiontial_loss += torch.norm(param)## 得到更新后的损失new_loss = original_loss + 0.01 * addiontial_lossreturn new_loss

二、 把自己的policy封装在一个算法中

## 2. 把自己的 policy 封装在算法中: 
##    继承自PPO, 创建一个新的算法类, 默认调用的是自定义的policy
class MY_PPO(PPO):## 重写 get_default_policy_class 函数, 使其返回自定义的policy def get_default_policy_class(self, config):return MY_PPOTorchPolicy

三、使用自己的策略创建智能体,执行训练

配置方法1:

## 三、使用自己的策略创建智能体,执行训练
## method-1
from ray.tune.logger import pretty_print 
config = PPOConfig(algo_class = MY_PPO) ## 配置使用自己的算法
config = config.environment("CartPole-v1")
config = config.rollouts(num_rollout_workers=2)
config = config.framework(framework="torch") 
algo = config.build()
result = algo.train()
print(pretty_print(result))

配置方法2:

## 3. 使用新策略执行训练
## method-2
from ray.tune.logger import pretty_print 
config = PPOConfig()
config = config.environment("CartPole-v1")
config = config.rollouts(num_rollout_workers=2)
config = config.framework(framework="torch") 
algo = MY_PPO(config=config,)  ## 在这里使用自己的policy
result = algo.train()
print(pretty_print(result))

四、代码汇总:

import torch 
import gymnasium as gym 
from gymnasium import spaces
from ray.rllib.utils.annotations import override
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.algorithms.ppo import PPO, PPOConfig, PPOTorchPolicy
from typing import Dict, List, Type, Union
from ray.rllib.models.action_dist import ActionDistribution
from ray.rllib.policy.sample_batch import SampleBatch
from ray.tune.logger import pretty_print ## 1. 自定义 policy, 主要是改变 policy 的 loss 的计算  # 神经网络的损失函数
class MY_PPOTorchPolicy(PPOTorchPolicy):"""PyTorch policy class used with PPO."""def __init__(self, observation_space:gym.spaces.Box, action_space:gym.spaces.Box, config:PPOConfig): PPOTorchPolicy.__init__(self,observation_space,action_space,config)## PPOTorchPolicy 内部对 PPOConfig 格式的config 执行了to_dict()操作,后面可以以 dict 的形式使用 config@override(PPOTorchPolicy) def loss(self,model: ModelV2,dist_class: Type[ActionDistribution],train_batch: SampleBatch):## 原始损失original_loss = super().loss(model, dist_class, train_batch) # PPO原来的损失函数, 也可以完全自定义新的loss函数, 但是非常不建议。## 新增自定义损失,这里以正则化损失作为示例addiontial_loss = torch.tensor(0.0) ## 自己定义的lossaddiontial_loss = torch.tensor(0.)for param in model.parameters():addiontial_loss += torch.norm(param)## 得到更新后的损失new_loss = original_loss + 0.01 * addiontial_lossreturn new_loss## 2. 把自己的 policy 封装在算法中: 
##    继承自PPO, 创建一个新的算法类, 默认调用的是自定义的policy
class MY_PPO(PPO):## 重写 get_default_policy_class 函数, 使其返回自定义的policy def get_default_policy_class(self, config):return MY_PPOTorchPolicy## 三、使用自己的策略创建智能体,执行训练
## method-1
# config = PPOConfig(algo_class = MY_PPO) ## 配置使用自己的算法
# config = config.environment("CartPole-v1")
# config = config.rollouts(num_rollout_workers=2)
# config = config.framework(framework="torch") 
# algo = config.build()
# result = algo.train()
# print(pretty_print(result))## method-2
config = PPOConfig()
config = config.environment("CartPole-v1")
config = config.rollouts(num_rollout_workers=2)
config = config.framework(framework="torch") 
algo = MY_PPO(config=config,)  ## 在这里配置使用自己的policy
result = algo.train()
print(pretty_print(result))

五、后记:

        如何既要修改网络结构,又要修改loss函数, 可以结合 上一篇博客 和本博客共同实现。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/68197.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

CVE-2024-23897-Jenkins任意文件读取漏洞复现

content Jenkins是什么CVE-2024-23897总结修复建议 Jenkins是什么 Jenkins是一人基于Java开发的、可扩展的持续集成引擎,用于持续、自动地构建/测试软件项目,可以监控一些定时执行的任务。 官网文档: Jenkins是一款开源 CI&CD 软件&…

Lua 环境的安装

1.安装Lua运行环境 本人采用的是在windows系统中使用cmd指令方式进行安装,安装指令如下: winget install "lua for windows" 也曾使用可执行程序安装过,但由于电脑是加密电脑,最后都已失败告终。使用此方式安装可以安…

基于微信小程序的网上订餐管理系统

作者:计算机学姐 开发技术:SpringBoot、SSM、Vue、MySQL、JSP、ElementUI、Python、小程序等,“文末源码”。 专栏推荐:前后端分离项目源码、SpringBoot项目源码、Vue项目源码、SSM项目源码、微信小程序源码 精品专栏:…

【矩阵二分】力扣378. 有序矩阵中第 K 小的元素

给你一个 n x n 矩阵 matrix ,其中每行和每列元素均按升序排序,找到矩阵中第 k 小的元素。 请注意,它是 排序后 的第 k 小元素,而不是第 k 个 不同 的元素。 你必须找到一个内存复杂度优于 O(n2) 的解决方案。 示例 1&#xff1…

基于微信小程序的助农扶贫系统设计与实现(LW+源码+讲解)

专注于大学生项目实战开发,讲解,毕业答疑辅导,欢迎高校老师/同行前辈交流合作✌。 技术范围:SpringBoot、Vue、SSM、HLMT、小程序、Jsp、PHP、Nodejs、Python、爬虫、数据可视化、安卓app、大数据、物联网、机器学习等设计与开发。 主要内容:…

Effective C++ 规则47: 请使用 Traits Class 表现类型信息

1、背景 C 是一种静态类型语言,类型的特性在编译期就可以被识别和操作。为了更好地利用编译期信息来编写高效、灵活、可维护的代码,C 提供了一些技术来“萃取”或“提取”类型的相关信息。即利用 traits 类来封装和提取类型信息,以便在编译期…

Linux Futex学习笔记

Futex 简介 概述: Futex(Fast Userspace Mutex)是linux的一种特有机制,设计目标是避免传统的线程同步原语(如mutex、条件变量等)在用户空间和内核空间之间频繁的上下文切换。Futex允许在用户空间处理锁定和等待的操作&#xff0…

小柯剧场训练营第一期音乐剧演员与第二期报名拉开帷幕!

在当下社会文化浪潮中,小柯剧场凭借其独特的培养模式和“先看戏后买票”的良心举措,已然成为艺术界的一股清流。1月12日,由“第一期免费训练营”优秀学员们所带来的新一代《稳稳的幸福》成功落幕,引起了社会的广泛关注。该活动不仅…

基于迁移学习的ResNet50模型实现石榴病害数据集多分类图片预测

完整源码项目包获取→点击文章末尾名片! 番石榴病害数据集 背景描述 番石榴 (Psidium guajava) 是南亚的主要作物,尤其是在孟加拉国。它富含维生素 C 和纤维,支持区域经济和营养。不幸的是,番石榴生产受到降…

【论文阅读】HumanPlus: Humanoid Shadowing and Imitation from Humans

作者:Zipeng Fu、Qingqing Zhao、Qi Wu、Gordon Wetstein、Chelsea Finn 项目共同负责人,斯坦福大学 项目网址:https://humanoid-ai.github.io 摘要 制造外形与人类相似的机器人的一个关键理由是,我们可以利用大量的人类数据进行…

2025牛客寒假算法基础集训营2

H 一起画很大的圆&#xff01; 看起来像是一道计算几何的题&#xff0c;实际上通过分析和猜想&#xff0c;是有O1复杂度的结论的。具体证明略&#xff0c;结论是三点越接近共线&#xff0c;得出的半径越大。 #include <bits/stdc.h> using namespace std; #define endl \…

PVE 虚拟机安装 Debian 无图形化界面服务器

Debian 安装 Debian 镜像下载 找一个Debian镜像服务器&#xff0c;根据需要的版本和自己硬件选择。 iso-cd/&#xff1a;较小&#xff0c;仅包含安装所需的基础组件&#xff0c;可能需要网络访问来完成安装。有镜像 debian-12.9.0-amd64-netinst.isoiso-dvd/&#xff1a;较…

硬件学习笔记--35 AD23的使用常规操作

原理图设计 1&#xff09;新建原理图&#xff0c;File-new-Schematic。相关设置参考&#xff0c;主要包含图纸设置以及常规的工具栏。 PCB的设计 新建PCB&#xff0c;设置相应的规则&#xff08;与原理图中相对应&#xff09;&#xff0c;放到同一个工程中。如果有上一版本的…

解读2025年生物医药创新技术:展览会与论坛的重要性

2025生物医药创新技术与应用发展展览会暨论坛&#xff0c;由天津市生物医药行业协会、BIO CHINA生物发酵展组委会携手主办&#xff0c;山东信世会展服务有限公司承办&#xff0c;定于2025年3月3日至5日在济南黄河国际会展中心盛大开幕。展会规模60000平方米、800参展商、35场会…

Poseidon哈希为什么适合做ZKP

论文- https://eprint.iacr.org/2019/458.pdf Poseidon 哈希算法的硬件加速与实现 实用的计算完整性证明系统领域&#xff0c;如 SNARKs、STARKs、Bulletproofs&#xff0c;正在经历非常动态的发展&#xff0c;最近出现了几种具有改进性能和放宽设置要求的结构。此类系统的许多…

开始步入达梦中级dba

分析内存使用需要的方法之一 disql /nolog conn sysdba/sysdbaselect value from v$parameter where nameMEMORY_LEAK_CHECK; SP_SET_PARA_VALUE(0,MEMORY_LEAK_CHECK,1); select * from V$MEM_REGINFO; select * from V$MEM_HEAP;

UDP 广播组播点播的区别及联系

1、网络IP地址的分类 组播地址是分类编址的IPv4地址中的D类地址&#xff0c;又叫多播地址&#xff0c;他的前四位必须是1110&#xff0c;所以网络地址的二进制取值范围是11100000~11101111对应的十进制为 224~~239。所以以224~239开头的网络地址都是组播地址。 组播地址的功能…

opengrok_使用技巧

Searchhttps://xrefandroid.com/android-15.0.0_r1/https://xrefandroid.com/android-15.0.0_r1/ 选择搜索的目录&#xff08;工程&#xff09; 手动在下拉框中选择&#xff0c;或者 使用下面三个快捷按钮进行选择或者取消选择。 输入搜索的条件 搜索域说明 域 fullSearc…

IDEA中Maven使用的踩坑与最佳实践

文章目录 IDEA中Maven使用的踩坑与最佳实践一、环境配置类问题1. Maven环境配置2. IDEA中Maven配置建议 二、常见问题与解决方案1. 依赖下载失败2. 依赖冲突解决3. 编译问题修复 三、效率提升技巧1. IDEA Maven Helper插件使用2. 常用Maven命令配置3. 多模块项目配置4. 资源文件…

Flink读写Kafka(Table API)

前面(Flink读写Kafka(DataStream API)_flink kafka scram-CSDN博客)我们已经讲解了使用DataStream API来读取Kafka,在这里继续讲解下使用Table API来读取Kafka,和前面一样也是引入相同的依赖即可。 <dependency> <groupId>org.apache.flink</groupId&…