【PPO】近端策略优化【Clip版本,离散动作】

本博客代码参考了《动手学强化学习-PPO》

PPO算法是在Actor-Critic的基础上进行训练目标的调整。其改进的地方在于对每次参数更新进行了限制。

PPO 是 TRPO 的一种改进算法,它在实现上简化了 TRPO 中的复杂计算,并且它在实验中的性能大多数情况下会比 TRPO 更好,因此目前常被用作一种常用的基准算法。需要注意的是,TRPO 和 PPO 都属于在线策略学习算法,即使优化目标中包含重要性采样的过程,但其只是用到了上一轮策略的数据,而不是过去所有策略的数据。


文章目录

  • PPO
    • 策略网络定义
    • 价值网络定义
    • PPO算法整体结构
    • 参数更新
      • 截断
      • 重要性采样
  • 在线训练
  • 一些概念


PPO

策略网络定义

import torch.nn as nn
import torch.nn.functional as Fclass PolicyNet(nn.Module):def __init__(self, state_dim, hidden_dim, action_dim):super(PolicyNet, self).__init__()self.fc1 = nn.Linear(state_dim, hidden_dim)self.fc2 = nn.Linear(hidden_dim, action_dim)def forward(self,x):x = F.relu(self.fc1(x))return F.softmax(self.fc2(x),dim=1)

可以看到PPO算法中的策略网络输出的是对应的动作概率,从最后一行代码中可以识别出F.softmax(self.fc2(x),dim=1)

价值网络定义

import torch.nn as nn
import torch.nn.functional as Fclass ValueNet(nn.Module):def __init__(self, state_dim, hidden_dim):super(ValueNet, self).__init__()self.fc1 = nn.Linear(state_dim, hidden_dim)self.fc2 = nn.Linear(hidden_dim, 1)def forward(self,x):x = F.relu(self.fc1(x))return self.fc2(x)

PPO算法与Actor-Critic算法一样,其价值网络仅用来进行价值估计。所以其最后的输出的维度是一维,从self.fc2 = nn.Linear(hidden_dim, 1)可以看出

PPO算法整体结构

class PPO:def __init__(self, state_dim, hidden_dim, action_dim,actor_lr, critic_lr, gamma, lmbda, epochs, eps, device):self.actor = PolicyNet(state_dim, hidden_dim, action_dim)self.critic = ValueNet(state_dim, hidden_dim)self.actor_optimizer = torch.optim.Adam(lr = actor_lr)self.critic_optimizer = torch.optim.Adam(lr = critic_lr)self.device = deviceself.eps = epsself.lmbda = lmbdaself.epochsdef take_action(state):state = torch.tensor([state].dtype=torch.float).to(self.device)probs = self.actor(state)action_dist = torch.distributions.Categorical(x)action = action_dist.sample()return action.item()def update(self,transition_dict):states = torch.tensor(transition_dict['satats'],dtype=torch.float).to(self.device)actions torch.tensor(transition_dict['actions']).view(-1,1).to(self.device)rewards = torch.tensor(transition_dict['rewards'],dtype=torch.float).view(-1,1).to(self.device)next_states = torch.tensor(transition_dict['next_states'],dtype=torch.float).to(self.device)dones = torch.tensor[transition_dict['dones'],dtype=torch.float).view(-1,1).to(self.device)# 计算td targettd_target = rewards + self.gamma*self.critic(next_states)td_delta = td_target - self.critic(states)advantages = rf_utils(self.gamma, self.lmbda, td_delta.cpu()).to(self.device))old_log_probs = torch.log(self.actor(states).gather(1,actions)).detach() # 参数更新之后,就不是当前策略的函数了,所以要进行detach()进行梯度截断for _ in range(self.epochs):log_probs = torch.log(self.actor(states).gather(1,actions))ratio = torch.exp(log_probs - old_log_probs)surr1 = ratio*advantagessurr2 = torch.clamp(ratio, 1 -self.eps, 1+self.eps)*advantagesactor_loss = torch.mean(-torch.min(surr1, surr2))critic_loss = torch.mean(torch.mse_loss(self.critic(states), td_target.detach()))self.criti_optimizer.zero_grad()self.actor_optimizer.zero_grad()actor_loss.backward()critic_loss.backward()self.critic_optimizer.step()self.actor_optimizer.step()

参数更新

PPO算法相比于Actor-Critic在参数更新部分有两个重要的调整:1. 近端策略优化; 2. 重要性采样

截断

PPO 的另一种形式 PPO-截断(PPO-Clip)更加直接,它在目标函数中进行限制,以保证新的参数和旧的参数的差距不会太大,即:
arg ⁡ max ⁡ θ E s ∼ ν E a ∼ π θ k ( ⋅ ∣ s ) [ min ⁡ ( π θ ( a ∣ s ) π θ k ( a ∣ s ) A π θ k ( s , a ) , clip ⁡ ( π θ ( a ∣ s ) π θ k ( a ∣ s ) , 1 − ϵ , 1 + ϵ ) A π θ k ( s , a ) ) ] \arg\max_{\theta}\mathbb{E}_{s\sim\nu}\mathbb{E}_{a\sim\pi_{\theta_{k}}(\cdot|s)}\left[\min\left(\frac{\pi_{\theta}(a|s)}{\pi_{\theta_{k}}(a|s)}A^{\pi_{\theta_{k}}}(s,a),\operatorname{clip}\left(\frac{\pi_{\theta}(a|s)}{\pi_{\theta_{k}}(a|s)},1-\epsilon,1+\epsilon\right)A^{\pi_{\theta_{k}}}(s,a)\right)\right] argmaxθEsνEaπθk(s)[min(πθk(as)πθ(as)Aπθk(s,a),clip(πθk(as)πθ(as),1ϵ,1+ϵ)Aπθk(s,a))]
其中 clip ⁡ ( x , l , r ) : = max ⁡ ( min ⁡ ( x , r ) , l ) \operatorname{clip}(x,l,r):=\max(\min(x,r),l) clip(x,l,r):=max(min(x,r),l) ,即把 x \text{x} x 限制在 [ l , r ] [l,r] [l,r] 内。上式中是一个超参数,表示进行截断(clip)的范围。

如果 A π θ k ( s , a ) > 0 A^{\pi_{\theta_{k}}}(s,a)>0 Aπθk(s,a)>0,说明这个动作的价值高于平均,最大化这个式子会增大 π θ ( a ∣ s ) π θ k ( a ∣ s ) \frac{\pi_{\theta}(a|s)}{\pi_{\theta_{k}}(a|s)} πθk(as)πθ(as),但不会让其超过 1 + ϵ 1+\epsilon 1+ϵ。反之,如果 A π θ k ( s , a ) < 0 A^{\pi_{\theta_{k}}}(s,a)<0 Aπθk(s,a)<0,最大化这个式子会减小 π θ ( a ∣ s ) π θ k ( a ∣ s ) \frac{\pi_{\theta}(a|s)}{\pi_{\theta_{k}}(a|s)} πθk(as)πθ(as) ,但不会让其超过$1-\epsilon $。如下图所示

在这里插入图片描述
代码surr2 = torch,clamp(ratio, 1- self.eps, 1+ self.eps)*advantages 和 代码actor_loss = torch.mean(- torch.min(surr1, surr2))提现了截断的思想

重要性采样

PPO中使用了重要性采样,从而一定程度上缓解了样本使用效率低的问题,提高了单次样本参与模型参数训练的次数。重要性采样的公式推导,与蒙特卡洛的近似分布有关。尽管重要性采样允许PPO算法使用单轮样本进行多次训练,但是PPO算法更偏向于 on-policy。参考博客
代码ratio = torch.exp(log_probs - old_log_probs)其中的ratio就是重要性采样因子,通过该因子,从而调整出新策论与旧策略的分布差异。

在线训练

def train_on_policy_agent(agent, env, num_episodes):return_list = []for i in range(10):with tqdm(total = int(num_episodes/10),desc='Iteration %d':% i) as pbar:for i_episode in range(int(episode/10)):state = env.reset()transition_dict = {'states':[],'actions':[],'next_state'=[],'rewards':[],'dones':[]}episode_return = 0done = False while not done:action = agent.take_action(state)next_state,reward, done, _ = agent.take_action(state)transition_dict['states'].append(state)transition_dict['actions'].append(action)transition_dict['next_states'].append(next_state)transition_dict['rewards'].append(reward)transition_dict['dones'].append(done)state = next_stateepisode_return += rewardreturn_list.append(episode_return)agent.update(transition_dict)if (i+i_episode)%10 == 0:pbar.set_postfix({'episode': '%d' % (num_episodes/10 * i + i_episode+1), 'return': '%.3f' % np.mean(return_list[-10:])})bar.update(1)return return_list

由于PPOon-policy类型的强化学习,所以训练PPO参数需要使用在线策略学习方式,即每次与环境交互一轮之后,就要根据交互收集到的轨迹transition进行一次参数更新,体现在代码agent.update(transition_dict

一些概念

在这里插入图片描述

  • Q-learning、DQN 及 DQN 改进算法都是基于价值(value-based), 他们通过选择最大价值动作来与环境交互

  • Actor-Critic 结合了两者的特点,Reinforce 作为 Actor 部分, DQN 作为 Critic 部分,从而结合了两者的有点

  • Reinforce、Actor-Critic 通过 SARSA 样本数据进行训练,所以他们是 on-policy基于策略,也就是学习一个策略,从策略中进行动作概率分布抽样

  • REINFORCE 算法基于蒙特卡洛采样,只能在序列结束后进行更新, Actor-Critic 算法则可以在每一步之后都进行更新,并且不对任务的步数做限制。

  • TRPO 是在 Actor-Critic 的基础上加入了更新幅度限,也就是制信任区域(trust region),从而避免模型效果的震荡。

  • PPO 是 TRPO 的改良版, 基于 TRPO 的思想,但是 PPO算法实现更加简单,没有TRPO 的计算那么复杂和远算量那么大。

  • PPO 有两种形式,一是 PPO-惩罚,二是 PPO-截断,

  • PPO-截断总是比 PPO-惩罚表现得更好, 大量实验表明。

  • REINFORCE、Actor-Critic 以及两个改进算法——TRPO 和 PPO, 这类算法有一个共同的特点:它们都是在线策略算法,这意味着它们的样本效率(sample efficiency)比较低。

  • TRPO(trust region policy optimization,TRPO)。当策略网络是深度模型时,沿着策略梯度更新参数,很有可能由于步长太长,策略突然显著变差,进而影响训练效果。针对这个问题,考虑在更新时找到一块信任区域(trust region),在这个区域上更新策略时能够得到某种策略性能的安全性保证,这就是信任区域策略优化(trust region policy optimization,TRPO)算法的主要思想。TRPO 算法在 2015 年被提出,它在理论上能够保证策略学习的性能单调性,并在实际应用中取得了比策略梯度算法更好的效果。

  • Policy Base 与 Value Base

    • Policy-Base (基于策略) 通过环境,直接输出下一步要采取的各种动作的概率,然后根据概率采取行动,所以每种动作都有可能被选中,只是可能性不同。如,Policy Gradients等。
    • Value-Based RL(基于价值)
      输出所有动作的价值,根据最高价值来选择动作。如,Q learning 、DQN等。(对于不连续的动作,这两种方法都可行,但如果是连续的动作基于价值的方法是不能用的,我们只能用一个概率分布在连续动作中选择特定的动作)。
    • Actor-Critic (主要基于策略,价值作为辅助)
      结合这两种方法建立一种Actor-Critic的方法,基于概率会给出做出的动作,基于价值会对做出的动作进行评分,是二者的综合,但更偏向于策略模型领域。
  • on-policy 与 off-policy 的异同点。 无论是在线策略(on-policy)算法还是离线策略(off-policy)算法,都有一个共同点:智能体在训练过程中可以不断和环境交互,得到新的反馈数据。二者的区别主要在于在线策略算法会直接使用这些反馈数据,而离线策略算法会先将数据存入经验回放池中,需要时再采样

  • Model-Based和Model-Free 是关于如何对环境建模和学习的方法。Model-Based建立模型,Model-Free直接学习策略或价值函数。

  • Q-learning 与 Sarsa

    • 为什么 Q learning 能够使用Exprience replay 而 Sarsa 不能使用Exprience replay。 Q-learning的目标是求解“真正”的 Q ∗ ( s , a ) Q^{*}(s,a) Q(s,a) ,而Sarsa的目标则是求解 Q π ( s , a ) Q_\pi(s,a) Qπ(s,a) 。在Q-learning中,我们一般会采用experience replay技术,即准备一个数据库并不断把Agent新产生的 ( s , a , r , s ′ ) (s,a,r,s^{\prime}) (s,a,r,s) 数据集存入数据库中。我们每次会从数据库中随机抽取一个batch的数据集用以训练,这意味着每次训练时我们用到的数据集可能是Agent在很久以前产生的。但是,无论我们用到的数据是Agent在训练中的哪一个阶段产生的,数据都是服从环境分布的,所以它们当然都可以被用以训练。
    • Sarsa 类型为什么是 on-policy 在线策略学习。 在Sarsa中,情况则与Q-learning很不一样。对于 ( s , a , r , s ′ , a ′ ) (s,a,r,s',a') (s,a,r,s,a) 的训练数据集,我们不但要求 ( r , s ′ ) (r,s^{\prime}) (r,s) 应该服从环境分布,也要求 a ′ a^{\prime} a 必须服从 π \pi π 关于 s ′ s^{\prime} s 的条件分布。在训练中,Q表的内容会不断被改变,所以Agent产生数据的策略 也会不断被改变。这意味在Agent过去产生的 ( s , a , r , s ′ , a ′ ) (s,a,r,s',a') (s,a,r,s,a) 中, ( s ′ , a ′ ) (s^{\prime},a^{\prime}) (s,a) 可能不服从现在策略 π \pi π 对应的条件分布,因此Agent在过去产生的数据就不能用以现在的训练。 由于上述的原因,我们不能在Sarsa中采用experience replay。在训练中,设当前Agent产生数据的策略为 π \pi π 。我们可以一次性用Agent产生大量服从环境及 π \pi π 分布的数据,并用这些数据来进行训练。而训练过后,Q表的内容发生了变化,这意味着Agent产生数据的策略变成了与 π \pi π 不同的 π ′ \pi^{\prime} π 。这时,刚才那些服从环境与 π \pi π 分布的 ( s , a , r , s ′ , a ′ ) (s,a,r,s',a') (s,a,r,s,a) 数据就变得不再有价值,我们只能将其丢弃。接下来,我们就要让Agent用当前产生数据的策略 继续产生大量的数据,并进行下一步的训练。
  • off-line (离线学习)的应用场景。 在现实生活中的许多场景下,让尚未学习好的智能体和环境交互可能会导致危险发生,或是造成巨大损失。例如,在训练自动驾驶的规控智能体时,如果让智能体从零开始和真实环境进行交互,那么在训练的最初阶段,它操控的汽车无疑会横冲直撞,造成各种事故。再例如,在推荐系统中,用户的反馈往往比较滞后,统计智能体策略的回报需要很长时间。而如果策略存在问题,早期的用户体验不佳,就会导致用户流失等后果。因此,离线强化学习(offline reinforcement learning)的目标是,在智能体不和环境交互的情况下,仅从已经收集好的确定的数据集中,通过强化学习算法得到比较好的策略。

  • Reinforce 基于蒙特卡洛采样,只能在序列结束后进行更新,这同时也要求任务具有有限的步数,而 Actor-Critic 算法则可以在每一步之后都进行更新,并且不对任务的步数做限制。这是因为没有价值网络,所以他无法对当下状态的价值进行直接估计,只能通过蒙特卡洛采样进行逆向奖励累积。而Actor-Critic 由于有价值网络,所以可以直接通过计算前后状态和奖励的差值来计算时序误差。当然Actor-Critic也可以采用蒙特卡洛采样,不过这种方法需要更多的计算量


https://blog.csdn.net/qq_45889056/article/details/130297960
https://zhuanlan.zhihu.com/p/166412379
https://blog.csdn.net/qq_45889056/article/details/130297960

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/747452.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

服务模块划分规范

一、PO :(persistant object )&#xff0c;持久对象 可以看成是与数据库中的表相映射的java对象。使用Hibernate来生成PO是不错的选择。 二、VO :(value object) &#xff0c;值对象 通常用于业务层之间的数据传递&#xff0c;和PO一样也是仅仅包含数据而已。但应是抽象出的…

功能问题:如何用Docker部署一个后端项目?

大家好&#xff0c;我是大澈&#xff01; 本文约1800字&#xff0c;整篇阅读大约需要3分钟。 关注微信公众号&#xff1a;“程序员大澈”&#xff0c;免费加入问答群&#xff0c;一起交流技术难题与未来&#xff01; 现在关注公众号&#xff0c;免费送你 ”前后端入行大礼包…

SwiftU的组件 - TabView

SwiftU的组件 - TabView 记录一下SwiftU的组件 - TabView的两种style分别的使用方式 import SwiftUIstruct TabViewBootCamp: View {State var selectedIndex 0var body: some View {NavigationView {TabView(selection: $selectedIndex) {HomeView(selectedIndex: $selected…

基于python的《彩图版飞机大战》程序使用说明(附源码下载)

在PyCharm中运行《彩图版飞机大战》即可进入如图1所示的游戏界面。 图1 游戏主界面 具体的操作步骤如下&#xff1a; &#xff08;1&#xff09;玩游戏。在游戏主界面中&#xff0c;从屏幕的顶部不断出现下落的敌机&#xff0c;玩家按下键盘上的↑、↓、←、→方向键移动飞机…

Android 深入Http(2)加密与编码

可以对二进制数据&#xff08;比如图片、视频&#xff09; 经典算法&#xff1a; DES&#xff08;密钥短被弃用了&#xff09; AES &#xff08;密钥很长 很顶&#xff09; 速度快&#xff0c;效率高 IDEA 3DES&#xff08;三重DES&#xff0c;听起来就很慢和重 &#xf…

VGG论文学习笔记

题目&#xff1a;VERY DEEP CONVOLUTIONAL NETWORKS FOR LARGE-SCALE IMAGE RECOGNITION 论文下载地址&#xff1a;VGG论文 摘要 目的&#xff1a;研究深度对精度的影响 方法&#xff1a;使用3*3滤波器不断增加深度&#xff0c;16和19效果显著 成绩&#xff1a;在ImageNet 20…

搭建知识管理系统并不复杂,这篇教程来帮你

许多人都有这样的体验&#xff1a;我们抓住的想法和知识总在不经意间溜走&#xff0c;我们想要的信息总是一时无法找到。因此&#xff0c;搭建一个能够系统化、分类和索引存储这些知识的“知识管理系统”是必要的。听上去很专业&#xff0c;其实并不复杂&#xff0c;让我们一步…

mysql: 如何开启慢查询日志?

1 确认慢查询日志功能已开启 执行以下sql语句&#xff0c;查看慢查询功能是否开启&#xff1a; show VARIABLES like slow_query_log;如果为ON&#xff0c;表示打开&#xff1b;如果为OFF&#xff0c;表示没有打开&#xff0c;需要开启慢查询功能。 执行以下sql语句&#xff0…

修改 MySQL update_time 默认值的坑

由于按规范需要对 update_time 字段需要对它做默认值的设置 现在有一个原始的表是这样的 CREATE TABLE test_up (id bigint(20) unsigned NOT NULL AUTO_INCREMENT COMMENT 主键id,update_time datetime default null COMMENT 操作时间,PRIMARY KEY (id) ) ENGINEInnoDB DEF…

MapStruct代替BeanUtils.copyProperties ()使用

1.为什么MapStruct代替BeanUtils.copyProperties () 第一&#xff1a;因为BeanUtils 采用反射的机制动态去进行拷贝映射&#xff0c;特别是Apache的BeanUtils的性能很差&#xff0c;而且并不支持所有数据类型的拷贝&#xff0c;虽然使用较为方便&#xff0c;但是强烈不建议使用…

鸿蒙Harmony应用开发—ArkTS声明式开发(基础手势:NavRouter)

导航组件&#xff0c;默认提供点击响应处理&#xff0c;不需要开发者自定义点击事件逻辑。 说明&#xff1a; 该组件从API Version 9开始支持。后续版本如有新增内容&#xff0c;则采用上角标单独标记该内容的起始版本。 子组件 必须包含两个子组件&#xff0c;其中第二个子组…

分析型数据库的主要使用场景有哪些?

如今数据已经成为了企业和组织的核心资产。如何有效地管理和利用这些数据&#xff0c;成为了决定竞争力的关键。分析型数据库作为数据处理领域的重要工具&#xff0c;为各行各业提供了强大的数据分析和洞察能力。基于分析型数据库&#xff08;Apache Doris &#xff09;构建的现…

当模型足够大时,Bias项不会有什么特别的作用

问题来源&#xff1a; 阅读OLMo论文时&#xff0c;发现有如下一段话&#xff1a; 加上前面研究llama和mistral结构时好奇为什么都没有偏置项了 偏置项的作用&#xff1a; 回到第一性原理来分析&#xff0c;为什么要有偏置项的存在呢&#xff1f; 在神经网络中&#xff0c;…

跨境热点!TikTok直播网络要求是什么?

TikTok直播作为一种互动性强、实时性要求高的社交媒体形式&#xff0c;对网络环境有着一系列特定的需求。了解并满足这些需求&#xff0c;对于确保用户体验、提高直播质量至关重要。本文将深入探讨TikTok直播对网络环境的要求以及如何优化网络设置以满足这些要求。 TikTok直播的…

mac启动elasticsearch

1.首先下载软件&#xff0c;然后双击解压&#xff0c;我用的是7.17.3的版本 2.然后执行如下命令 Last login: Thu Mar 14 23:14:44 on ttys001 diannao1xiejiandeMacBook-Air ~ % cd /Users/xiejian/local/software/elasticsearch/elasticsearch-7.17.3 diannao1xiejiandeMac…

鸿蒙Harmony应用开发—ArkTS声明式开发(基础手势:Menu)

以垂直列表形式显示的菜单。 说明&#xff1a; 该组件从API Version 9开始支持。后续版本如有新增内容&#xff0c;则采用上角标单独标记该内容的起始版本。 Menu组件需和bindMenu或bindContextMenu方法配合使用&#xff0c;不支持作为普通组件单独使用。 子组件 包含MenuIt…

HTML—CSS盒子模型(Box Model)

基本介绍&#xff1a; CSS处理网页时&#xff0c;HTML的每一个标签可以看作是一个盒子&#xff0c;网页布局将指定的标签放到指定的位置上摆放&#xff0c;相当于摆放盒子。 每一个标签(盒子)所包含的内容&#xff1a;从外到内 ①外边距(margin)—规定盒子与盒子之间的距离&…

LeetCode---388周赛

题目列表 3074. 重新分装苹果 3075. 幸福值最大化的选择方案 3076. 数组中的最短非公共子字符串 3077. K 个不相交子数组的最大能量值 一、重新分装苹果 注意题目中说同一个包裹中的苹果可以分装&#xff0c;那么我们只要关心苹果的总量即可&#xff0c;在根据贪心&#x…

为什么光学器件需要厚度

确定光学厚度的限值 光学元件的功能和性能在很大程度上受到可用光学材料的限制。制造和光学元件设计的最新发展现在拓宽了可以实现的目标。特别是&#xff0c;平面光学器件或超表面可以设计为具有大块光学元件的功能&#xff0c;但其厚度缩小到仅几百纳米。米勒现在提出了一项…

git小白入门

git是什么 Git是一种流行的版本控制系统&#xff0c;被广泛用于软件开发中来跟踪和管理代码的变化。它是由Linus Torvalds在2005年创建的&#xff0c;最初的目的是为了更高效地管理Linux内核的开发。Git使得多人在同一个项目上工作变得更加简单&#xff0c;可以轻松合并不同开…