SAC(Soft Actor-Critic)理论与代码解释

标题

  • 理论
    • 序言基础
      • Q值与V值
      • 算法区别
  • SAC
      • 概念
      • Q函数与V函数
      • 最大化熵强化学习(Maximum Entropy Reinforcement Learning, MERL)
      • 算法流程
        • 1个actor,4个Q Critic
        • 1个actor,2个V Critic,2个Q Critic
  • 代码详解

参考连接:SAC(Soft Actor-Critic)阅读笔记 - Feliks的文章 - 知乎

理论

序言基础

Q值与V值

在强化学习中,Critic网络可以采用Q值(动作值函数)或V值(状态值函数),具体选择取决于你使用的算法以及问题的特性。

  1. Q值(动作值函数): Critic网络输出每个状态动作对的Q值,表示在给定状态下采取某个动作的预期累积奖励。这种方法通常用于Q-learning和Deep Q Network(DQN)等算法中,其中主要关注最优动作的选择。

  2. V值(状态值函数): Critic网络输出每个状态的V值,表示在给定状态下的预期累积奖励。这种方法通常用于值迭代方法,如异策略(Off-policy)的蒙特卡洛控制和异策略时序差分学习。 V(s) 表示智能体在状态 s 下,从该状态开始直到未来所能获得的累积奖励的期望值。换句话说,它是智能体处于状态 s 时,遵循某种策略所带来的长期回报的估计。

选择Q值还是V值通常取决于你解决的问题。如果你关心在每个状态下选择最优动作,那么使用Q值更为合适。如果你更关心每个状态的价值,而不仅仅是最优动作的话,那么使用V值可能更合适。 Q ( s , a ) Q(s, a) Q(s,a) 表示智能体在状态 s 下执行动作 a 后,紧接着直到未来的累积奖励的期望值。与 V 值相比,Q 值不仅考虑了状态,还考虑了特定的动作选择。

在一些算法中,如深度确定性策略梯度(Deep Deterministic Policy Gradient,DDPG),使用的是一个Critic网络同时输出Q值和Actor网络的参数。这种情况下,Critic网络的输出可以同时用于评估状态动作对的Q值和评估状态的V值。

算法区别

D4PG(引入分布式的critic,并使用多个actor(learner)共同与环境交互)

TD3(参考了double Q-learning的思想来优化critic,延缓actor的更新,计算critic的优化目标时在action上加一个小扰动)

PPO:依赖于importance sampling实现的off-policy算法在面对太大的策略差异时将无能为力(正在训练的policy与实际与环境交互时的policy差异过大),所以学者们认为PPO其实是一种on-policy的算法,这类算法在训练时需要保证生成训练数据的policy与当前训练的policy一致,对于过往policy生成的数据难以再利用,所以在sample efficiency这条衡量强化学习(Reinforcement Learning, RL)算法的重要标准上难以取得优秀的表现。

SAC

概念

SAC是基于最大熵(maximum entropy)这一思想发展的RL算法,其采用与PPO类似的随机分布式策略函数(Stochastic Policy),并且是一个off-policy,actor-critic算法
在这里插入图片描述
将熵引入RL算法的好处为,可以让策略(policy)尽可能随机,agent可以更充分地探索状态空间,避免策略早早地落入局部最优点(local optimum),并且可以探索到多个可行方案来完成指定任务,提高抗干扰能力。

Q函数与V函数

在这里插入图片描述

最大化熵强化学习(Maximum Entropy Reinforcement Learning, MERL)

MERL采用了独特的策略模型。为了适应更复杂的任务,MERL中的策略不再是以往的高斯分布形式,而是用基于能量的模型(energy-based model)来表示策略:
在这里插入图片描述

算法流程

算法同样包括策略评估(Policy Evaluation),与策略优化(Policy Improvement),在这两个步骤交替运行下,值函数与策略都可以不断逼近最优。
在这里插入图片描述

1个actor,4个Q Critic

SAC的论文有两篇,一篇是《Soft Actor-Critic Algorithms and Applications》,2018年12月挂arXiv,其中SAC算法流程如下所示,它包括1个actor网络,4个Q Critic网络:(代码使用的是这个:Github链接)
在这里插入图片描述

1个actor,2个V Critic,2个Q Critic

一篇是《Soft Actor-Critic: Off-Policy Maximum Entropy Deep Reinforcement Learning with a Stochastic Actor》,2018年1月挂arXiv,其中SAC算法流程如下所示,它包括1个actor网络,2个V Critic网络(1个V Critic网络,1个Target V Critic网络),2个Q Critic网络:
参考知乎
在这里插入图片描述
在这里插入图片描述

代码详解

Actor网络

class Actor(nn.Module):def __init__(self, state_dim, action_dim, hidden_width, max_action):super(Actor, self).__init__()self.max_action = max_actionself.l1 = nn.Linear(state_dim, hidden_width)self.l2 = nn.Linear(hidden_width, hidden_width)self.mean_layer = nn.Linear(hidden_width, action_dim)self.log_std_layer = nn.Linear(hidden_width, action_dim)def forward(self, x, deterministic=False, with_logprob=True):x = F.relu(self.l1(x))x = F.relu(self.l2(x))mean = self.mean_layer(x)log_std = self.log_std_layer(x)  # We output the log_std to ensure that std=exp(log_std)>0log_std = torch.clamp(log_std, -20, 2)std = torch.exp(log_std)dist = Normal(mean, std)  # Generate a Gaussian distributionif deterministic:  # When evaluating,we use the deterministic policya = meanelse:a = dist.rsample()  # reparameterization trick: mean+std*N(0,1)if with_logprob:  # The method refers to Open AI Spinning up, which is more stable.log_pi = dist.log_prob(a).sum(dim=1, keepdim=True)log_pi -= (2 * (np.log(2) - a - F.softplus(-2 * a))).sum(dim=1, keepdim=True)else:log_pi = Nonea = self.max_action * torch.tanh(a)  # Use tanh to compress the unbounded Gaussian distribution into a bounded action interval.return a, log_pi

理论中的训练策略 π( ϕ \phi ϕ) 时的损失函数:

在这里插入图片描述
对应代码的:

        # Compute actor lossa, log_pi = self.actor(batch_s)Q1, Q2 = self.critic(batch_s, a)Q = torch.min(Q1, Q2)actor_loss = (self.alpha * log_pi - Q).mean()   ##这里就是关键了撒

Q函数训练时的损失函数:

在这里插入图片描述
对应代码:

        with torch.no_grad():batch_a_, log_pi_ = self.actor(batch_s_)  # a' from the current policy# Compute target Qtarget_Q1, target_Q2 = self.critic_target(batch_s_, batch_a_)target_Q = batch_r + self.GAMMA * (1 - batch_dw) * (torch.min(target_Q1, target_Q2) - self.alpha * log_pi_)# Compute current Qcurrent_Q1, current_Q2 = self.critic(batch_s, batch_a)# Compute critic losscritic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(current_Q2, target_Q)

温度系数的更新

在这里插入图片描述
H 0 \mathcal{H_0} H0 是预先定义好的最小策略熵的阈值。

        # Update alphaif self.adaptive_alpha:# We learn log_alpha instead of alpha to ensure that alpha=exp(log_alpha)>0alpha_loss = -(self.log_alpha.exp() * (log_pi + self.target_entropy).detach()).mean()self.alpha_optimizer.zero_grad()alpha_loss.backward()self.alpha_optimizer.step()self.alpha = self.log_alpha.exp()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/656484.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Opencv——霍夫变换

霍夫直线变换 霍夫直线变换(Hough Line Transform)用来做直线检测 为了加升大家对霍夫直线的理解,我在左图左上角大了一个点,然后在右图中绘制出来经过这点可能的所有直线 绘制经过某点的所有直线的示例代码如下,这个代码可以直接拷贝运行 import cv2 as cv import matplot…

基于JavaWeb开发的服装网上商城系统【附源码】

基于JavaWeb开发的服装网上商城系统【附源码】 🍅 作者主页 央顺技术团队 🍅 欢迎点赞 👍 收藏 ⭐留言 📝 🍅 文末获取源码联系方式 📝 🍅 查看下方微信号获取联系方式 承接各种定制系统 &#…

如何优化博客的内容和用户体验

在当今数字时代,博客成为了分享知识、展示个人专业能力和吸引读者的重要工具。然而,随着越来越多的博客涌现,如何优化博客的内容和用户体验成为了一个关键的问题。本文将为你提供一些有效的技巧,帮助你优化博客的内容和提升用户体…

Phoncent博客,探索Rie Kudan的GPT创作之举

近日,大家都在谈论日本作家Rie Kudan,她凭借其小说《东京共鸣塔》("Tokyo-to Dojo-to")荣获了日本极具声望的芥川奖。这本小说引起了广泛的讨论和思考,因为令人惊讶的是,Kudan在其中直接引用了人…

报告发布 | 聚铭网络参编的《数据安全风险评估实务:问题剖析与解决思路》正式发布

近日,由中国信通院、中国通信标准化协会主办,中国通信标准化协会大数据技术标准推进委员会承办的“数据资产管理大会数据安全论坛”在北京成功召开。会上正式发布了《数据安全风险评估实务:问题剖析与解决思路》报告(以下简称“报…

事务、MVCC、锁

目录 事务MVCC锁 事务 四大特性:ACID 脏读:事务A读取到未提交事务B修改的数据 不可重复读:事务A修改了未提交事务B读取的数据 幻读:事务A增删了未提交事务B读取的数据 不可重复读与幻读都是读取的结果不同,前者侧重于…

什么是零知识证明?

Web3 的核心原则之一——透明度,也可能是其最大的缺点之一。没有人希望他们的所有在线活动(从金融交易到个人身份数据)都可供任何人公开查看。为了使区块链能够扩展并变得更容易访问,隐私必须成为首要任务。 零知识证明能够改变我…

一些著名的软件都用什么语言编写?

1、操作系统 Microsoft Windows :汇编 -> C -> C 备注:曾经在智能手机的操作系统(Windows Mobile)考虑掺点C#写的程序,比如软键盘,结果因为写出来的程序太慢,实在无法和别的模块合并&…

2001-2022年全国30省就业人数数据

2001-2022年全国30省就业人数数据 1、时间:2001-2022年 2、来源:各省年鉴、人口和就业年鉴、wind 3、指标:省份、年份、就业人数 4、范围:30个省市 5、缺失情况:无缺失 6、指标解释: 就业人口是指一…

[嵌入式软件][入门篇][仿真平台] STM32CubeMX的搭建

文章目录 一、简介二、STM32CubeMX的使用(1) 新建文件,芯片选型(2) sys设置和RCC设置(3) 配置时钟(4) 生成代码 三、仿真平台的使用 一、简介 STM32CubeMX是一种图形工具,通过分步过程可以非常轻松地配置STM32微控制器和微处理器,生成相应的初…

保护医疗数据不受威胁:MPLS专线在医疗网络安全中的角色

随着数字技术的快速发展,医疗行业正在经历一场革命。从电子健康记录到远程医疗服务,数字化不仅提高了效率,也带来了前所未有的挑战--尤其是关于数据安全和隐私保护的挑战。在这样的背景下,如何确保敏感的医疗数据安全传输&#xf…

第一口就喝到了珠珠

x*ay*bc;假设b杯比a杯大,那么就是往b中可以加入a杯,然后倒出b杯,就是求x和y的最大公因数,用cn1*k*xn2*k*ygcd(a,b)*(n1*xn2*y);而且c要小于a或者b的最大值. int gcd(int a, int b) {while (b) {int t b;b a % b;a t;}return s…

腾讯云Linux(OpenCloudOS)安装tomcat9(9.0.85)

腾讯云Linux(OpenCloudOS)安装tomcat9 下载并上传 tomcat官网 https://tomcat.apache.org/download-90.cgi 下载完成后上传至自己想要放置的目录下 解压文件 输入tar -xzvf apache-tomcat-9.0.85.tar.gz解压文件,建议将解压后的文件重新命名为tomcat,方便后期进…

大模型学习与实践笔记(十四)

使用 OpenCompass 评测 InternLM2-Chat-7B 模型使用 LMDeploy 0.2.0 部署后在 C-Eval 数据集上的性能 步骤1:下载internLM2-Chat-7B 模型,并进行挂载 以下命令将internlm2-7b模型挂载到当前目录下: ln -s /share/model_repos/internlm2-7b/ ./ 步骤2&…

音频分离软件有哪些?这些软件轻松分离

音频分离软件有哪些?随着音频处理需求的日益增长,音频分离软件成为了许多人的必备工具。为了满足这些需求,市面上涌现出了许多优秀的音频分离软件。本文将为您介绍5款知名的音频分离软件,让您轻松实现音频处理。 1.口袋视频转换器…

Linux ---- Shell编程之正则表达式

一、正则表达式 ​ 由一类特殊字符及文本字符所编写的模式,其中有些字符(元字符)不表示字符字面意义,而表示控制或通配的功能,类似于增强版的通配符功能,但与通配符不同,通配符功能是用…

Boost.Test-如何将测试套件(源码文件)组织成工程、并执行测试

Boost.Test资源及示例的续篇 1.测试套件TestSuite的源码文件组织如下图 2.CMakeLists.txt需要自己编写,本例内容如下 cmake_minimum_required(VERSION 3.5.0 FATAL_ERROR) project(mytestmodule) enable_testing()# indicates the location of the boost instal…

OAK深度相机主机时钟同步提升10倍!

编辑:OAK中国 首发:oakchina.cn 喜欢的话,请多多👍⭐️✍ 内容可能会不定期更新,官网内容都是最新的,请查看首发地址链接。 ▌前言 Hello,大家好,这里是OAK中国,我是Ash…

近期作业总结(函数,递归,二进制)

二分查找函数 写一个二分查找函数 功能&#xff1a;在一个升序数组中查找指定的数值&#xff0c;找到了就返回下标&#xff0c;找不到就返回-1。 int bin_search(int arr[], int left, int right, int key) {int mid 0;while (left < right) {mid (right left) / 2;if…

【pytest系列】- assert断言的使用

&#x1f525; 交流讨论&#xff1a;欢迎加入我们一起学习&#xff01; &#x1f525; 资源分享&#xff1a;耗时200小时精选的「软件测试」资料包 &#x1f525; 教程推荐&#xff1a;火遍全网的《软件测试》教程 &#x1f4e2;欢迎点赞 &#x1f44d; 收藏 ⭐留言 &#x1…