ETH开源PPO算法学习

前言

项目地址:https://github.com/leggedrobotics/rsl_rl

项目简介:快速简单的强化学习算法实现,设计为完全在 GPU 上运行。这段代码是 NVIDIA Isaac GYM 提供的 rl-pytorch 的进化版。

下载源码,查看目录,整个项目模块化得非常好,每个部分各司其职。下面我们自底向上地进行讲解加粗的部分。

rsl_rl/
│ __init__.py

├─algorithms/
│ │ __init__.py
│ │ ppo.py # PPO算法的实现
│ │
├─env/
│ │ __init__.py
│ │ vec_env.py # 实现并行处理多个环境的向量化环境
│ │
├─modules/
│ │ __init__.py
│ │ actor_critic.py # 定义 Actor-Critic 网络结构
│ │ actor_critic_recurrent.py # 定义包含循环层的 Actor-Critic 网络
│ │ normalizer.py # 数据正规化工具,有助于训练过程的稳定性
│ │
├─runners/
│ │ __init__.py
│ │ on_policy_runner.py # 实现用于执行 on-policy 算法训练循环的运行器
│ │
├─storage/
│ │ __init__.py
│ │ rollout_storage.py # 存储和管理策略 rollout 数据的工具
│ │
└─utils/
│ __init__.py
│ neptune_utils.py # 用于与 Neptune.ai 集成的工具
│ utils.py # 通用实用工具函数
│ wandb_utils.py # 用于与 Weights & Biases 集成的工具

rollout 数据储存和管理(rollout_storage.py)

定义了一个名为 RolloutStorage 的类,用于存储和管理在强化学习训练过程中从环境中收集到的数据(称为rollouts)。

  • 定义Transition

用于存储单个时间步的所有相关数据,包括观察值、动作、奖励、完成标志(dones)、值函数估计、动作的对数概率、动作的均值和标准差,以及可能的隐藏状态(对于使用循环网络的情况)。

  • 特权观察值(Privileged Observations)

除了self.observations外还有self.privileged_observations的使用,在强化学习中是指那些在训练期间可用但在实际部署或测试时不可用的额外信息。这些信息通常提供了环境的内部状态或其他有助于学习的提示,但在现实世界应用中可能难以获得或完全不可用。在训练期间使用特权观察值的一种常见方法是通过教师-学生架构(我们常常也称作特权学习),其中一个拥有全部信息的教师模型(可以访问特权观察值)来指导一个学生模型(只能访问普通观察值)。学生模型的目标是模仿教师模型的决策,尽管它没有直接访问特权信息。

  • 奖励和优势的计算
    def compute_returns(self, last_values, gamma, lam):advantage = 0for step in reversed(range(self.num_transitions_per_env)):if step == self.num_transitions_per_env - 1:next_values = last_valueselse:next_values = self.values[step + 1]next_is_not_terminal = 1.0 - self.dones[step].float()delta = self.rewards[step] + next_is_not_terminal * gamma * next_values - self.values[step]advantage = delta + next_is_not_terminal * gamma * lam * advantageself.returns[step] = advantage + self.values[step]# Compute and normalize the advantagesself.advantages = self.returns - self.valuesself.advantages = (self.advantages - self.advantages.mean()) / (self.advantages.std() + 1e-8)

这段代码实现的是在强化学习中计算回报(returns)和优势(advantages)的逻辑,具体是使用了一种称为广义优势估算(Generalized Advantage Estimation, GAE)的方法。GAE是一种权衡偏差和方差以及平滑回报信号的技术,由以下几个数学公式定义:

  1. TD残差(Temporal Difference Residual):
    δ t = R t + γ V ( S t + 1 ) ( 1 − d o n e t ) − V ( S t ) \delta_t = R_t + \gamma V(S_{t+1}) (1 - done_t) - V(S_t) δt=Rt+γV(St+1)(1donet)V(St)
    其中, δ t \delta_t δt是时刻 t t t的TD残差, R t R_t Rt是奖励, γ \gamma γ是折扣因子, V ( S t ) V(S_t) V(St)是状态 S t S_t St的价值函数估计, d o n e t done_t donet是表示当前状态是否为终止状态的指示函数(如果当前状态为终止状态,则 d o n e t = 1 done_t = 1 donet=1;否则, d o n e t = 0 done_t = 0 donet=0)。如果 d o n e t = 1 done_t = 1 donet=1,那么 γ V ( S t + 1 ) \gamma V(S_{t+1}) γV(St+1)项将为 0,因为终止状态之后没有未来回报。

  2. GAE优势估计:
    A t G A E ( γ , λ ) = ∑ l = 0 ∞ ( γ λ ) l δ t + l A_t^{GAE(\gamma, \lambda)} = \sum_{l=0}^{\infty} (\gamma \lambda)^l \delta_{t+l} AtGAE(γ,λ)=l=0(γλ)lδt+l
    在代码中,这个无限求和是通过迭代地计算来近似的,具体的迭代公式为:
    A t = δ t + ( γ λ ) A t + 1 ( 1 − d o n e t ) A_t = \delta_t + (\gamma \lambda) A_{t+1} (1 - done_t) At=δt+(γλ)At+1(1donet)
    其中, A t A_t At是时刻 t t t的优势估计, λ \lambda λ是用来平衡TD估计和蒙特卡罗估计之间权重的参数。

  3. 回报的计算:
    G t = A t + V ( S t ) G_t = A_t + V(S_t) Gt=At+V(St)
    其中, G t G_t Gt是时刻 t t t的回报估计。

代码中使用的变量名与数学符号的对应关系:

变量名数学符号含义
rewards[step] R t R_t Rt时刻 t t t的奖励
gamma γ \gamma γ折扣因子,用于计算未来奖励的现值
values[step] V ( S t ) V(S_t) V(St)状态 S t S_t St在当前策略下的价值函数估计
dones[step] d o n e t done_t donet指示当前状态 S t S_t St是否为终止状态的标志(1 表示终止,0 表示非终止)
delta δ t \delta_t δt时刻 t t t的 TD 残差
advantage A t A_t At时刻 t t t的优势估计,根据 GAE 方法计算
lam λ \lambda λ用于 GAE 计算中平衡 TD 估计和蒙特卡罗估计之间权重的参数
returns[step] G t G_t Gt时刻 t t t的回报估计
advantages A t n o r m A_t^{norm} Atnorm标准化后的优势估计
mu_A, sigma_A μ A \mu_A μA, σ A \sigma_A σA优势估计的平均值和标准差
epsilon ϵ \epsilon ϵ避免除零错误而加的小常数,通常取值为 1e-8

代码中的循环从最后一个转换开始向前迭代,使用以上的数学公式来计算每一步的优势和回报。最后,它还对优势进行了标准化处理,即从每个优势中减去所有优势的平均值,并除以标准差,以减少训练期间的方差并加速收敛。标准化公式如下:
A t n o r m = A t − μ A σ A + ϵ A_t^{norm} = \frac{A_t - \mu_A}{\sigma_A + \epsilon} Atnorm=σA+ϵAtμA
其中, μ A \mu_A μA是优势的平均值, σ A \sigma_A σA是优势的标准差, ϵ \epsilon ϵ​ 是为了防止除以零而加的一个小常数(在代码中为 1e-8)。

  • 轨迹的平均长度

类中并没有显式存储轨迹的长度,轨迹长度隐含在self.dones之中。代码中使用的方法是:将每个环境中最后一步置为‘1’,然后flatten(展开)、拼接所有环境中的dones得到flat_dones,差分数组中为‘1’位置的索引得到智能体在每个环境中的步数,即轨迹长度。这个统计量有助于了解训练过程中智能体的表现。

  • mini-batch迭代器

mini_batch_generator 函数通过在多个训练周期(num_epochs)内,从经验回放缓冲区中随机选择小批量数据(包括观察值 observations、动作 actions、奖励 rewards 等)来生成小批量数据集。该函数利用 torch.randperm 生成随机索引 indices 来随机化数据抽样,进而支持基于批处理的学习方法,如梯度下降。通过每次只处理必要的数据量,该生成器在优化模型参数的同时,也优化了内存使用,确保了训练过程的高效性和灵活性。

(未完待续)

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/708331.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

创新之巅 健康之选 森歌集成灶智能水洗新揭秘

2024年2月27日,一场引领智能厨电风潮的盛会在杭州隆重召开。森歌集成灶以“勠力同心 共生共歌”为主题,成功举办了2024森歌智能厨电优秀经销商峰会。此次峰会上,森歌集成灶发布了令人瞩目的奥运冠军同款智能厨电新品——森歌鲸洗小灶Z60&…

前端架构: 脚手架之多package项目管理和架构

多package项目管理 1 )多package项目管理概述 通常来说,当一个项目变大了以后,我们就要对这个项目进行拆分在前端当中,对于项目进行拆分的方式,通常把它称之为javascript包管理需要使用一个工具叫做 npm (Node Packag…

Java开发的核心模式 - MVC

文章目录 1、MVC设计模式2、Web开发本质3、服务器的性能瓶颈 1、MVC设计模式 MVC设计模式示意图 在整个Java学习之旅中,MVC(Model-View-Controller)设计模式无疑占据着极其重要的地位,堪称理解和掌握Java项目开发精髓的钥匙。如…

IP源防攻击IPSG(IP Source Guard)

IP源防攻击IPSG(IP Source Guard)是一种基于二层接口的源IP地址过滤技术,它能够防止恶意主机伪造合法主机的IP地址来仿冒合法主机,还能确保非授权主机不能通过自己指定IP地址的方式来访问网络或攻击网络。 2.1 IPSG基本原理 绑定…

关于delphi6提示[Fatal Error] File not found: ‘System.pas‘

关于delphi6提示[Fatal Error] File not found: System.pas 一、[Fatal Error] File not found: System.pas的原因 1、System.dcu的输出没有覆盖项目引用路径..\..\dcu下 2、注意事项:System.pas等源码不能赋值到..\..\dcu 3、下述控件的Tools-Environment Optio…

如何利用ChatGPT搞科研?论文检索、写作、基金润色、数据分析、科研绘图(全球地图、植被图、箱型图、雷达图、玫瑰图、气泡图、森林图等)

以ChatGPT、LLaMA、Gemini、DALLE、Midjourney、Stable Diffusion、星火大模型、文心一言、千问为代表AI大语言模型带来了新一波人工智能浪潮,可以面向科研选题、思维导图、数据清洗、统计分析、高级编程、代码调试、算法学习、论文检索、写作、翻译、润色、文献辅助…

深入浅出JVM(十七)之并发垃圾收集器CMS

上篇文章介绍用户线程与GC线程并发执行时可能产生的问题以及使用三色标记法演示原始快照和增量更新两种解决方案 这篇文章将主要介绍并发垃圾收集器中的CMS,其中CMS使用增量更新来解决对象消失问题,如果不了解增量更新的同学可以查看上篇文章深入浅出JV…

【k8s 高级调度--污点和容忍】

1、调度概念 在 Kubernetes 中,调度(scheduling)指的是确保 Pod 匹配到合适的节点, 以便 kubelet 能够运行它们。 抢占(Preemption)指的是终止低优先级的 Pod 以便高优先级的 Pod 可以调度运行的过程。 驱逐…

为什么会对猫毛过敏?如何缓解?浮毛克星—宠物空气净化器推荐

猫咪过敏通常是因为它们身上的Fel d1蛋白质导致的,这些蛋白质附着在猫咪的皮屑上。猫咪舔毛的过程会带出这些蛋白质,一旦接触就可能引发过敏症状,比如打喷嚏等。因此,减少空气中的浮毛数量有助于减轻过敏现象。猫用空气净化器可以…

Tomcat架构分析

Tomcat的核心组件 Tomcat将请求器和处理器分离,使用多种请求器支持不同的网络协议,而处理器只有一个。从而网络协议和容器解耦。 Tomcat的容器 Host:Tomcat提供多个域名的服务,其将每个域名都视为一个虚拟的主机,在…

半导体行业案例:Jira与龙智插件助力某半导体企业实现精益项目管理

近日,龙智Atlassian技术团队收到了国内一家大型半导体企业的感谢信。龙智团队提供的半导体行业项目管理解决方案和服务受到了客户的好评: 在龙智团队的支持下,我们的业务取得了喜人的成果和进步。龙智公司的专业服务和产品,是我们…

skiplist(高阶数据结构)

目录 一、概念 二、实现 三、对比 一、概念 skiplist是由William Pugh发明的,最早出现于他在1990年发表的论文《Skip Lists: A Probabilistic Alternative to Balanced Trees》 skiplist本质上是一种查找结构,用于解决算法中的查找问题,…

AI:145-智能监控系统下的行人安全预警与法律合规分析

🚀点击这里跳转到本专栏,可查阅专栏顶置最新的指南宝典~ 🎉🎊🎉 你的技术旅程将在这里启航! 从基础到实践,深入学习。无论你是初学者还是经验丰富的老手,对于本专栏案例和项目实践都有参考学习意义。 ✨✨✨ 每一个案例都附带关键代码,详细讲解供大家学习,希望…

2024年阿里云2核4G配置服务器测评_ECS和轻量性能测评

阿里云2核4G服务器多少钱一年?2核4G服务器1个月费用多少?2核4G服务器30元3个月、85元一年,轻量应用服务器2核4G4M带宽165元一年,企业用户2核4G5M带宽199元一年。本文阿里云服务器网整理的2核4G参加活动的主机是ECS经济型e实例和u1…

Vue3制作一个可拖拽的小箭头

效果图 可以抓住小箭头进行左右拖拽&#xff0c;不会做git图&#xff0c;所以只有静态效果QAQ 代码 <template><div class"tip"draggable"true"dragstart"start" //拖拽开始时drag"dragging" //拖拽种dragend "…

2024.2.27每日一题

之前是出去旅游了没发&#xff0c;现在开学了&#xff0c;继续每日一题&#xff0c;继续卷&#xff0c;一上来就是困难题&#x1f613;&#xff0c;直接cv大法。 LeetCode 统计树中的合法路径数目 2867. 统计树中的合法路径数目 - 力扣&#xff08;LeetCode&#xff09; 题目…

选择何种操作系统作为网站服务器

选择操作系统时&#xff0c;需考虑稳定性、安全性、成本、兼容性和技术支持等因素&#xff0c;常见选项有Windows Server和Linux发行版。 选择网站服务器的操作系统是一个关键的决策&#xff0c;因为它将影响到网站的性能、稳定性、安全性以及未来的扩展性&#xff0c;目前市场…

数据库之ACID

一、ACID **原子性&#xff08;Atomicity&#xff09;&#xff1a;**即事务是不可分割的最小工作单元&#xff0c;事务内的操作要么全做&#xff0c;要么全不做&#xff0c;不能只做一部分&#xff1b; 一致性&#xff08;Consistency&#xff09;&#xff1a;在事务执行前数据…

【大数据】Flink SQL 语法篇(八):集合、Order By、Limit、TopN

Flink SQL 语法篇&#xff08;八&#xff09;&#xff1a;集合、Order By、Limit、TopN 1.集合操作2.Order By、Limit 子句2.1 Order By 子句2.2 Limit 子句 3.TopN 子句 1.集合操作 集合操作支持 Batch / Streaming 任务。 UNION&#xff1a;将集合合并并且去重。UNION ALL&a…

DataGrip 2023:让数据库开发变得更简单、更高效 mac/win版

JetBrains DataGrip 2023是一款功能强大的数据库IDE&#xff0c;专为数据库开发和管理而设计。通过DataGrip&#xff0c;您可以连接到各种关系型数据库管理系统(RDBMS)&#xff0c;并使用其提供的一组工具来查询、管理、编辑和开发数据库。 DataGrip 2023 软件获取 DataGrip …