JoyRL Actor-Critic算法

策略梯度算法的缺点

这里策略梯度算法特指蒙特卡洛策略梯度算法,即 REINFORCE 算法。 相比于 DQN 之类的基于价值的算法,策略梯度算法有以下优点。

  • 适配连续动作空间。在将策略函数设计的时候我们已经展开过,这里不再赘述。
  • 适配随机策略。由于策略梯度算法是基于策略函数的,因此可以适配随机策略,而基于价值的算法则需要一个确定的策略。此外其计算出来的策略梯度是无偏的,而基于价值的算法则是有偏的。

 但同样的,策略梯度算法也有其缺点。

  • 采样效率低。由于使用的是蒙特卡洛估计,与基于价值算法的时序差分估计相比其采样速度必然是要慢很多的,这个问题在前面相关章节中也提到过。
  • 高方差。虽然跟基于价值的算法一样都会导致高方差,但是策略梯度算法通常是在估计梯度时蒙特卡洛采样引起的高方差,这样的方差甚至比基于价值的算法还要高。
  • 收敛性差。容易陷入局部最优,策略梯度方法并不保证全局最优解,因为它们可能会陷入局部最优点。策略空间可能非常复杂,存在多个局部最优点,因此算法可能会在局部最优点附近停滞。
  • 难以处理高维离散动作空间:对于离散动作空间,采样的效率可能会受到限制,因为对每个动作的采样都需要计算一次策略。当动作空间非常大时,这可能会导致计算成本的急剧增加。

结合了策略梯度和值函数的 Actor-Critic 算法则能同时兼顾两者的优点,并且甚至能缓解两种方法都很难解决的高方差问题。

Q:为什么各自都有高方差的问题,结合了之后反而缓解了这个问题呢?

A:策略梯度算法是因为直接对策略参数化,相当于既要利用策略去与环境交互采样,又要利用采样去估计策略梯度,而基于价值的算法也是需要与环境交互采样来估计值函数的,因此也会有高方差的问题。

 而结合之后呢,Actor 部分还是负责估计策略梯度和采样,但 Critic 即原来的值函数部分就不需要采样而只负责估计值函数了,并且由于它估计的值函数指的是策略函数的值,相当于带来了一个更稳定的估计,来指导 Actor 的更新,反而能够缓解策略梯度估计带来的方差。

Q Actor-Critic算法

如图 10.1 所示,我们通常将 Actor 和 Critic 分别用两个模块来表示,即图中的策略函数( Policy )和价值函数( Value Function )。Actor与环境交互采样,然后将采样的轨迹输入 Critic 网络,Critic 网络估计出当前状态-动作对的价值,然后再将这个价值作为 Actor 网络的梯度更新的依据,这也是所有 Actor-Critic 算法的基本通用架构

A2C与A3C算法

A2C

A3C

广义优势估计

由于优势函数通本质上来说还是使用蒙特卡洛估计,因此尽管减去了基线,有时候还是会产生高方差,从而导致训练过程不稳定

实战:A2C算法

定义模型

Critic 的输入是状态,输出则是一个维度的价值,而 Actor 输入的也会状态,但输出的是概率分布

class Critic(nn.Module):def __init__(self,state_dim):self.fc1 = nn.Linear(state_dim, 256)self.fc2 = nn.Linear(256, 256)self.fc3 = nn.Linear(256, 1)def forward(self, x):x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))value = self.fc3(x)return valueclass Actor(nn.Module):def __init__(self, state_dim, action_dim):self.fc1 = nn.Linear(state_dim, 256)self.fc2 = nn.Linear(256, 256)self.fc3 = nn.Linear(256, action_dim)def forward(self, x):x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))logits_p = F.softmax(self.fc3(x), dim=1)return logits_p

这里由于是离散的动作空间,根据在策略梯度章节中设计的策略函数,我们使用了 softmax 函数来输出概率分布。另外,实践上来看,由于 Actor 和 Critic 的输入是一样的,因此我们可以将两个网络合并成一个网络,以便于加速训练。这有点类似于 Duelling DQN 算法中的做法

class ActorCritic(nn.Module):def __init__(self, state_dim, action_dim):self.fc1 = nn.Linear(state_dim, 256)self.fc2 = nn.Linear(256, 256)self.action_layer = nn.Linear(256, action_dim)self.value_layer = nn.Linear(256, 1)def forward(self, x):x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))logits_p = F.softmax(self.action_layer(x), dim=1)value = self.value_layer(x)return logits_p, value

动作采样

与 DQN 算法不同等确定性策略不同,A2C 的动作输出不再是 Q 值最大对应的动作,而是从概率分布中采样动作,这意味着即使是很小的概率,也有可能被采样到,这样就能保证探索性

# Categorical分布函数,能直接从概率分布中采样动作
from torch.distributions import Categorical
class Agent:def __init__(self):self.model = ActorCritic(state_dim, action_dim)def sample_action(self,state):'''动作采样函数'''state = torch.tensor(state, device=self.device, dtype=torch.float32)logits_p, value = self.model(state)dist = Categorical(logits_p) action = dist.sample() return action

策略更新

我们首先需要计算出优势函数,一般先计算出回报,然后减去网络输出的值即可

class Agent:# 定义一个Agent类def _compute_returns(self, rewards, dones):# 计算回报returns = []  # 初始化一个回报列表discounted_sum = 0  # 初始化折扣累计和# 从后向前遍历奖励和是否结束的序列for reward, done in zip(reversed(rewards), reversed(dones)):# 如果游戏结束,则折扣累计和重置为0if done:discounted_sum = 0# 否则,将奖励加上折现因子gamma乘以之前的折扣累计和discounted_sum = reward + (self.gamma * discounted_sum)# 将计算出的折扣累计和添加到回报列表的开头returns.insert(0, discounted_sum)# 将回报列表转换为PyTorch张量,并移到Agent指定的设备上returns = torch.tensor(returns, device=self.device, dtype=torch.float32).unsqueeze(dim=1)# 对回报进行归一化处理returns = (returns - returns.mean()) / (returns.std() + 1e-5)  # 添加一个很小的数以避免除以零return returnsdef compute_advantage(self):'''计算优势函数'''# 从记忆库中随机抽取一批经验logits_p, states, rewards, dones = self.memory.sample()# 计算回报returns = self._compute_returns(rewards, dones)# 将状态转换为PyTorch张量,并移到Agent指定的设备上states = torch.tensor(states, device=self.device, dtype=torch.float32)# 前向传播模型以获得动作的概率和对数概率logits_p, values = self.model(states)# 计算优势,即回报与批评价值的差advantages = returns - valuesreturn advantages

这里我们使用了一个技巧,即将回报归一化,这样可以让优势函数的值域在 [−1,1] 之间,这样可以让优势函数更稳定,从而减少方差。计算优势之后就可以分别计算 Actor 和 Critic 的损失函数了

class Agent:def compute_loss(self):'''计算损失函数'''logits_p, states, rewards, dones = self.memory.sample()returns = self._compute_returns(rewards, dones)states = torch.tensor(states, device=self.device, dtype=torch.float32)logits_p, values = self.model(states)advantages = returns - valuesdist = Categorical(logits_p)log_probs = dist.log_prob(actions)# 注意这里策略损失反向传播时不需要优化优势函数,因此需要将其 detach 掉actor_loss = -(log_probs * advantages.detach()).mean() critic_loss = advantages.pow(2).mean()return actor_loss, critic_loss

练习题

1.相比于 REINFORCE 算法, A2C 主要的改进点在哪里,为什么能提高速度?

(1)结合了策略梯度和值函数的 Actor-Critic 算法则能同时兼顾两者的优点,并且甚至能缓解两种方法都很难解决的高方差问题

(2)A2C计算了一个优势函数来衡量实际回报与批评价值之间的差异

(3)A2C在计算回报时使用了均值标准化,这有助于加快学习的收敛速度

2.A2C 算法是 on-policy 的吗?为什么?

是的。A2C算法通过Actor-Critic实现on-policy学习。Actor负责生成行动的概率分布,而Critic负责评估状态的价值。在A2C的更新过程中,智能体根据Actor生成的策略选择行动,并使用这些行动的结果来更新Actor和Critic。因此,A2C在执行和学习时使用的是同一策略

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/641855.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

MATLAB数据处理: 每种样本类型随机抽样

tn5;% 每种类型随机抽样数 indextrain[];% 训练样本序号集 for i1:typenumber index301 find(typemat i); n2length(index301); index302randperm(n2); index401index301(index302(1:tn)); indextrain[indextrain; index401]; end 该代码可以对大样…

java进阶

文章目录 一、Java进阶1.注解(Annotation)a.内置注解b.元注解c.自定义注解 2.对象克隆3. Java设计模式(Java design patterns)a.软件设计模式概念b.建模语言(UML)c.面向对象设计原则d.设计模式 总结面向对象…

项目工程下载与XML配置文件下载:EtherCAT超高速实时运动控制卡XPCIE1032H上位机C#开发(十)

XPCIE1032H功能简介 XPCIE1032H是一款基于PCI Express的EtherCAT总线运动控制卡,可选6-64轴运动控制,支持多路高速数字输入输出,可轻松实现多轴同步控制和高速数据传输。 XPCIE1032H集成了强大的运动控制功能,结合MotionRT7运动…

深度解析Oladance、韶音、南卡开放式耳机:选购指南与天花板级推荐

​随着开放式耳机在日常生活中越来越受欢迎,许多品牌纷纷降低材料品质以迎合大众需求,导致耳机的性能和音质严重下滑。这让消费者在选择优质开放式耳机时感到困惑。作为一名专业的耳机评测人员,我近期对多款热门开放式耳机进行了深入的测评&a…

Leetcode—92.反转链表II【中等】

2023每日刷题(八十一) Leetcode—92.反转链表II 算法思想 实现代码 /*** Definition for singly-linked list.* struct ListNode {* int val;* ListNode *next;* ListNode() : val(0), next(nullptr) {}* ListNode(int x) : val(x), n…

kubernetes Pod 异常排查步骤

kubernetes Pod 异常排查步骤 详细排查图查看容器状态查看容器列表容器未启动成功排查容器启动成功排查pod状态对应原因 详细排查图 查看容器状态 查看容器列表 查看容器列表,最好在后面跟上命名空间,不跟上查询出来是默认的 kubectl get pods -n kubesphere-system单独查看某…

【Spring 篇】深入探讨MyBatis映射文件中的动态SQL

MyBatis,这个名字在Java开发者的世界中犹如一道光芒,照亮着持久层操作的道路。而在MyBatis的映射文件中,动态SQL则是一个让人爱-hate的存在。有时候,你感叹它的灵活性,有时候,你可能会为它的繁琐而头痛。但…

windows 11安装VMware 17 ,VMware安装Ubuntu 20.4

一、下载安装激活VMware 17 下载与激活:Vmware 17 下载地址、最新激活码 2024 _ 注意:安装路径自己选择,路径中尽可能避免中文或空格 二、下载Ubuntu 镜像 下载镜像地址:清华大学开源软件镜像站 点开下载镜像地址,找…

中科星图——Sentinel-2_MSI_L2A数据集

数据名称: Sentinel-2_MSI_L2A 数据来源: Copernicus 时空范围: 2022年10月-2023年1月 空间范围: 全国 数据简介: 哨兵2号(Sentinel-2)卫星是高分辨率多光谱成像卫星,携带一…

分布式一致性算法---Raft初探

读Raft论文也有一段时间了,但是自己总是以目前并没有完全掌握为由拖着这篇博客。今天先以目前的理解程度(做了6.824的lab2A和lab2B)对这篇论文做一个初步总结,之后有了更深入的理解之后再进行迭代,关于本文有任何疑问欢…

苹果眼镜(Vision Pro)的开发者指南(3)-【3D UI SwiftUI和RealityKit】介绍

为了更深入地理解SwiftUI和RealityKit,建议你参加专注于SwiftUI场景类型的系列会议。这些会议将帮助你掌握如何在窗口、卷和空间中构建出色的用户界面。同时,了解Model 3D API将为你提供更多关于如何为应用添加深度和维度的知识。此外,通过学习RealityView渲染3D内容,你将能…

【Java数据结构 -- 队列:队列有关面试oj算法题】

队列、循环队列、用队列模拟栈、用栈模拟队列 1.队列1.1 什么是队列1.2 创建队列1.3 队列是否为空和获取队头元素 empty()peek()1.4 入队offer()1.5 出队(头删)poll() 2. 循环队列2.1 创建循环队列2.2 判断是否为空isEmpty()和满isFull()2.3 入队enQueue…

JAVA的面试题四

1.电商行业特点 (1)分布式: ①垂直拆分:根据功能模块进行拆分 ②水平拆分:根据业务层级进行拆分 (2)高并发: 用户单位时间内访问服务器数量,是电商行业中面临的主要问题 (3)集群&…

python数据分析——numpy基本用法

numpy数据类型 在NumPy中,有多种数据类型可用于表示数组的元素。以下是一些常见的NumPy数据类型: int - 整数类型,如int8、int16、int32、int64等。uint -无符号整数类型,如uint8、uint16、uint32、uint64等。float -浮点数类型…

PaddleNLP 如何打包成Windows环境可执行的exe?

当我们使用paddleNLP完成业务开发后,需要将PaddleNLP打包成在Windows操作系统上可执行的exe程序。操作流程: 1.环境准备: python环境:3.7.4 2.安装Pyinstaller pip install pyinstaller 3.目录结构,main.py为可执…

测试开发基础 | 计算机网络篇(二):物理层与数据链路层

【摘要】 计算机网络知识是自动化测试等技术基础,也是测试面试必考题目。霍格沃兹测试学院特别策划了本系列文章,将带大家一步步夯实计算机网络的基础知识。由于物理层知识在互联网软件研发工作中用到的并不多,所以可以仅做一个简单的了解。物…

jQuery语法知识(DOM操作)

一、class 属性: .addClass()、.hasClass().removeClass()、.toggleClass() 二、DOM 插入并包裹现有内容 1、.wrap( wrappingElement): 在每个配的元素外层包上一个html元素。 …

Buildroot显示kernel logo

buildroot开机时DSI屏幕变成跟uart一样输出log,现在想显示logo 1、failed to show loader logo [ 2.467479] mmcblk1: p1 p2 p3 p4 p5 p6 p7 p8 p9 [ 2.468827] rockchip-drm display-subsystem: cant not find any loader display [ 2.468859] rockc…

Windows Service 2008 r2的安装

创建虚拟机–(操作非常简单,跟着图片的数据下一步即可) 选择自己要安装的虚拟机版本 在这里可以更改虚拟机存放的位置 这里的40个G并不会马上占用,当虚拟机里的东西到40个G的大小就不会再存储东西了 选择和自己虚拟…

春运倒计时,AR 引领铁路运输安全新风向

根据中国交通新闻网发布最新消息,今年春运全国跨区域人员流动量预计达 90 亿人次。 随着春运期间旅客数量不断创下新高,铁路运输面临着空前的挑战与压力。 图源:pixabay 聚焦铁路运输效率与旅客安全保障问题,本期行业趋势将探讨 …