【强化学习】17 ——DDPG(Deep Deterministic Policy Gradient)

文章目录

  • 前言
    • DDPG特点
  • 随机策略与确定性策略
  • DDPG:深度确定性策略梯度
    • 伪代码
    • 代码实践

前言

之前的章节介绍了基于策略梯度的算法 REINFORCE、Actor-Critic 以及两个改进算法——TRPO 和 PPO。这类算法有一个共同的特点:它们都是在线策略算法,这意味着它们的样本效率(sample efficiency)比较低。本章将要介绍的深度确定性策略梯度(deep deterministic policy gradient,DDPG)算法通过使用离线的数据以及Belllman等式去学习 Q Q Q函数,并利用 Q Q Q函数去学习策略。

DDPG特点

  • DDPG是离线学习算法
  • DDPG可以在连续的动作空间中进行使用
  • Open AI Spinning Up 中的DDPG未实现并行运行

随机策略与确定性策略

首先来回顾一下随机策略与确定性策略相关内容

随机策略

  • 离散动作: π ( a ∣ s ; θ ) = exp ⁡ { Q θ ( s , a ) } ∑ a , exp ⁡ { Q θ ( s , a ′ ) } \pi(a|s;\theta)=\frac{\exp\{Q_\theta(s,a)\}}{\sum_a,\exp\{Q_\theta(s,a^{\prime})\}} π(as;θ)=a,exp{Qθ(s,a)}exp{Qθ(s,a)},学习出价值函数之后再求取相应的softmax分布
  • 连续动作: π ( a ∣ s ; θ ) ∝ exp ⁡ { ( a − μ θ ( s ) ) 2 } \pi(a|s;\theta)\propto\exp\left\{\left(a-\mu_\theta(s)\right)^2\right\} π(as;θ)exp{(aμθ(s))2},学习出的策略符合高斯分布(均值,方差)

确定性策略

  • 离散动作: π ( s ; θ ) = arg ⁡ max ⁡ a Q θ ( s , a ) \pi(s;\theta)=\arg\max_aQ_\theta(s,a) π(s;θ)=argmaxaQθ(s,a)策略不可微,但可以通过学习价值函数再求取argmax的方式得到相应的策略
  • 连续动作: a = π ( s ; θ ) a=\pi(s;\theta) a=π(s;θ)策略可微,建立相应的函数映射,通过函数求导的方式进行策略学习

那么如何利用确定性策略学习连续动作呢?首先需要一个用于估计价值的Critic模块。 Q w ( s , a ) ≃ Q π ( s , a ) Q^w(s,a)\simeq Q^\pi(s,a) Qw(s,a)Qπ(s,a) L ( w ) = E s ∼ ρ π , a ∼ π θ [ ( Q w ( s , a ) − Q π ( s , a ) ) 2 ] L(w)=\mathbb{E}_{s\sim\rho^\pi,a\sim\pi_\theta}\left[\left(Q^w(s,a)-Q^\pi(s,a)\right)^2\right] L(w)=Esρπ,aπθ[(Qw(s,a)Qπ(s,a))2]

通过与环境的交互,可以获得状态的总体分布,又因为 a = π ( s ; θ ) a=\pi(s;\theta) a=π(s;θ),因此可以利用链式法则进行求导。首先是 Q Q Q函数对 a a a进行求导( Q Q Q函数通常由网络学习出来,对 a a a向量进行求导相当于是调整相应的梯度以使得获得更大的 Q Q Q值),接着因为 a = π ( s ; θ ) a=\pi(s;\theta) a=π(s;θ),所以 a a a π \pi π进行求导。
J ( π θ ) = E s ∼ ρ π [ Q π ( s , a ) ] J(\pi_\theta)=\mathbb{E}_{s\sim\rho^\pi}[Q^\pi(s,a)] J(πθ)=Esρπ[Qπ(s,a)] ∇ θ J ( π θ ) = E s ∼ ρ π [ ∇ θ π θ ( s ) ∇ a Q π ( s , a ) ∣ a = π θ ( s ) ] \nabla_\theta J(\pi_\theta)=\mathbb{E}_{s\sim\rho^\pi}[\nabla_\theta\pi_\theta(s)\nabla_aQ^\pi(s,a)|_{a=\pi_\theta(s)}] θJ(πθ)=Esρπ[θπθ(s)aQπ(s,a)a=πθ(s)]

上式即为确定性策略梯度定理。确定性策略梯度定理的具体证明过程可参考《动手学强化学习》13.5 节。

DDPG:深度确定性策略梯度

在实际应用中,上述的带有神经函数近似器的actor-critic方法在面对有
挑战性的问题时是不稳定的。深度确定性策略梯度(DDPG)给出了在确定性策略梯度(DPG)基础上的解决方法:
• 经验重放(离线策略)
• 目标网络
• 在动作输入前标准化Q网络
• 添加连续噪声

下面我们来看一下 DDPG 算法的细节。DDPG 要用到4个神经网络,其中 Actor 和 Critic 各用一个网络,此外它们都各自有一个目标网络。DDPG 中 Actor 也需要目标网络因为目标网络也会被用来计算目标 Q Q Q值。DDPG 中目标网络的更新与 DQN 中略有不同:在 DQN 中,每隔一段时间将 Q Q Q网络直接复制给目标 Q Q Q网络;而在 DDPG 中,目标 Q Q Q网络的更新采取的是一种软更新(延时更新)的方式,即让目标 Q Q Q网络缓慢更新,逐渐接近网络,其公式为:
ω − ← τ ω + ( 1 − τ ) ω − \omega^-\leftarrow\tau\omega+(1-\tau)\omega^- ωτω+(1τ)ω

通常 τ \tau τ是一个比较小的数,当 τ = 1 \tau=1 τ=1时,就和 DQN 的更新方式一致了。而目标 μ \mu μ网络(策略网络)也使用这种软更新的方式。

另外,由于 Q Q Q函数存在 Q Q Q值过高估计的问题,DDPG 采用了 Double DQN 中的技术来更新 Q Q Q网络。但是,由于 DDPG 采用的是确定性策略,它本身的探索仍然十分有限。回忆一下 DQN 算法,它的探索主要由 ϵ \epsilon ϵ-贪婪策略的行为策略产生。同样作为一种离线策略的算法,DDPG 在行为策略上引入一个 N \mathcal{N} N随机噪声(原论文使用的是OU噪声,后来许多实验证明高斯噪声具有更好的效果)来进行探索。
在这里插入图片描述

OU噪声

伪代码

在这里插入图片描述

在这里插入图片描述

代码实践

import gymnasium as gym
import numpy as np
from tqdm import tqdm
import torch
import torch.nn.functional as F
import utilclass PolicyNet(torch.nn.Module):def __init__(self, state_dim, hidden_dim, action_dim, action_bound):super(PolicyNet, self).__init__()self.fc1 = torch.nn.Linear(state_dim, hidden_dim)self.fc2 = torch.nn.Linear(hidden_dim, action_dim)# action_bound是环境可以接受的动作最大值self.action_bound = action_bounddef forward(self, x):x = F.relu(self.fc1(x))return torch.tanh(self.fc2(x)) * self.action_boundclass QValueNet(torch.nn.Module):def __init__(self, state_dim, hidden_dim, action_dim):super(QValueNet, self).__init__()self.fc1 = torch.nn.Linear(state_dim + action_dim, hidden_dim)self.fc2 = torch.nn.Linear(hidden_dim, hidden_dim)self.fc_out = torch.nn.Linear(hidden_dim, 1)def forward(self, s, a):# 拼接状态和动作cat = torch.cat([s, a], dim=1)x = F.relu(self.fc1(cat))x = F.relu(self.fc2(x))return self.fc_out(x)class DDPG:''' DDPG算法 '''def __init__(self, state_dim, hidden_dim, action_dim, actor_lr, critic_lr, gamma,action_bound, sigma, tau, buffer_size, minimal_size, batch_size, device, numOfEpisodes, env):self.action_dim = action_dimself.actor = PolicyNet(state_dim, hidden_dim, action_dim, action_bound).to(device)self.critic = QValueNet(state_dim, hidden_dim, action_dim).to(device)self.target_actor = PolicyNet(state_dim, hidden_dim, action_dim, action_bound).to(device)self.target_critic = QValueNet(state_dim, hidden_dim, action_dim).to(device)# 初始化目标价值网络并设置和价值网络相同的参数self.target_critic.load_state_dict(self.critic.state_dict())# 初始化目标策略网络并设置和策略相同的参数self.target_actor.load_state_dict(self.actor.state_dict())self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=critic_lr)self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=actor_lr)self.gamma = gammaself.sigma = sigma  # 高斯噪声的标准差,均值直接设为0self.tau = tau  # 目标网络软更新参数self.device = deviceself.env = envself.numOfEpisodes = numOfEpisodesself.buffer_size = buffer_sizeself.minimal_size = minimal_sizeself.batch_size = batch_sizedef take_action(self, state):state = torch.FloatTensor(np.array([state])).to(self.device)action = self.actor(state).item()# 给动作添加噪声,增加探索action = action + self.sigma * np.random.randn(self.action_dim)return actiondef soft_update(self, net, target_net):for param_target, param in zip(target_net.parameters(), net.parameters()):param_target.data.copy_(param_target.data * (1.0 - self.tau) + param.data * self.tau)def update(self, transition_dict):states = torch.tensor(np.array(transition_dict['states']), dtype=torch.float).to(self.device)actions = torch.tensor(np.array(transition_dict['actions']), dtype=torch.float).view(-1, 1).to(self.device)rewards = torch.tensor(transition_dict['rewards'], dtype=torch.float).view(-1, 1).to(self.device)next_states = torch.tensor(np.array(transition_dict['next_states']), dtype=torch.float).to(self.device)terminateds = torch.tensor(transition_dict['terminateds'], dtype=torch.float).view(-1, 1).to(self.device)truncateds = torch.tensor(transition_dict['truncateds'], dtype=torch.float).view(-1, 1).to(self.device)q_targets = rewards + self.gamma * (self.target_critic(next_states, self.target_actor(next_states))) * (1 - terminateds + truncateds)critic_loss = torch.mean(F.mse_loss(q_targets, self.critic(states, actions)))self.critic_optimizer.zero_grad()critic_loss.backward()self.critic_optimizer.step()actor_loss = -torch.mean(self.critic(states, self.actor(states)))self.actor_optimizer.zero_grad()actor_loss.backward()self.actor_optimizer.step()self.soft_update(self.actor, self.target_actor)  # 软更新策略网络self.soft_update(self.critic, self.target_critic)  # 软更新价值网络def DDPGtrain(self):replay_buffer = util.ReplayBuffer(self.buffer_size)returnList = []for i in range(10):with tqdm(total=int(self.numOfEpisodes / 10), desc='Iteration %d' % i) as pbar:for episode in range(int(self.numOfEpisodes / 10)):# initialize statestate, info = self.env.reset()terminated = Falsetruncated = FalseepisodeReward = 0# Loop for each step of episode:while (not terminated) or (not truncated):action = self.take_action(state)next_state, reward, terminated, truncated, info = self.env.step(action)replay_buffer.add(state, action, reward, next_state, terminated, truncated)state = next_stateepisodeReward += reward# 当buffer数据的数量超过一定值后,才进行Q网络训练if replay_buffer.size() > self.minimal_size:b_s, b_a, b_r, b_ns, b_te, b_tr = replay_buffer.sample(self.batch_size)transition_dict = {'states': b_s,'actions': b_a,'next_states': b_ns,'rewards': b_r,'terminateds': b_te,'truncateds': b_tr}self.update(transition_dict)if terminated or truncated:breakreturnList.append(episodeReward)if (episode + 1) % 10 == 0:  # 每10条序列打印一下这10条序列的平均回报pbar.set_postfix({'episode':'%d' % (self.numOfEpisodes / 10 * i + episode + 1),'return':'%.3f' % np.mean(returnList[-10:])})pbar.update(1)return returnList

超参数设置参考:

    agent = DDPG(state_dim=env.observation_space.shape[0],hidden_dim=256,action_dim=env.action_space.shape[0],actor_lr=3e-4,critic_lr=3e-3,gamma=0.99,action_bound=env.action_space.high[0],sigma=0.01,tau=0.005,buffer_size=10000,minimal_size=1000,batch_size=64,device=device,numOfEpisodes=200,env=env)

在这里插入图片描述
DDPG算法相比之前的在线学习算法,更加稳定,同时收敛速度更快。

深度确定性策略梯度算法(DDPG),它是面向连续动作空间的深度确定性策略训练的典型算法。相比于它的先期工作,即确定性梯度算法(DPG),DDPG 加入了目标网络和软更新的方法,这对深度模型构建的价值网络和策略网络的稳定学习起到了关键的作用。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/131023.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【踩坑及思考】浏览器存储 cookie 最大值超过 4kb,或 http 头 cookie 超过限制值

背景 本地生产环境:超过最大值 cookie token 不存储;客户生产环境:打开系统空白,且控制台报 http 400 错误; 出现了两种现象 现象一:浏览器对大于 4kb 的 cookie 值不存储 导致用户名密码登录&#xff…

解决问题 [Vue warn]: Missing required prop: “index“

vue项目控制台报错 [Vue warn]: Missing required prop: “index” 出现这个报错原因是<el-submenu></el-submenu>标签中缺少index属性&#xff0c;需要加上才能不报错 解决办法是&#xff1a; <el-submenu index""></el-submenu>

linux 下 物理迁移 mysql 数据库 不能启动问题

1、授权问题 # chown -R 777 /app/db/mysql 2、/etc/my.cnf配置问题 [mysqld] basedir/app/db/mysql datadir/app/db/mysql/data socket/app/db/mysql/mysql.sock.lock innodb_buffer_pool_size128M innodb_force_recovery 1 symbolic-links0 [mysqld_safe] log-error/app/…

linux驱动开发环境搭建

使用的是parallel 创建的ubuntu 16.04 ubuntu20.04虚拟机 源码准备 # 先查看本机版本 $ uname -r 5.15.0-86-generic# 搜索相关源码 $ sudo apt-cache search linux-source [sudo] password for showme: linux-source - Linux kernel source with Ubuntu patches linux-sourc…

笔记软件 Keep It mac v2.3.3中文版新增功能

Keep It mac是一款专为 Mac、iPad 和 iPhone 设计的笔记和信息管理应用程序。它允许用户在一个地方组织和管理他们的笔记、网络链接、PDF、图像和其他类型的内容。Keep It 还具有标记、搜索、突出显示、编辑和跨设备同步功能。 Keep It for mac更新日志 修复了更改注释或富文本…

Nacos-2.2.2源码修改集成高斯数据库GaussDB,postresql

一 &#xff0c;下载代码 Release 2.2.2 (Apr 11, 2023) alibaba/nacos GitHub 二&#xff0c; 执行打包 mvn -Prelease-nacos -Dmaven.test.skiptrue -Drat.skiptrue clean install -U 或 mvn -Prelease-nacos ‘-Dmaven.test.skiptrue’ ‘-Drat.skiptrue’ clean instal…

【H.264】RTP h264 码流 实例解析分析 3 : webrtc

【srs】SRS检测IBMF还是annexb 【H.264】RTP h264 码流 实例解析分析 2 : mediasoup收包 mediasoup 并没完整解析rtp包的内容,可能与mediasoup 只需要转发,不需要解码有关系。 webrtc 本身都是全的。 m98代码,先说关键: webrtc的VideoRtpDepacketizer 第一:对RTPVideoType…

成员变量为动态数据时不可轻易使用

问题描述 业务验收阶段&#xff0c;遇到了一个由于成员变量导致的线程问题 有一个kafka切面&#xff0c;用来处理某些功能在调用前后的发送消息&#xff0c;资产类型type是成员变量定义&#xff1b; 资产1类型推送消息是以zichan1为节点&#xff1b;资产2类型推送消息是以zi…

社区牛奶智能售货机为你带来便利与实惠

社区牛奶智能售货机为你带来便利与实惠 低成本&#xff1a;社区牛奶智能货机的最大优势在于成本低廉&#xff0c;租金和人工开支都很少。大部分时间&#xff0c;货柜都是由无人操作来完成销售任务。 购买便利&#xff1a;社区居民只需通过手机扫码支付&#xff0c;支付后即可自…

二、计算机组成原理与体系结构

&#xff08;一&#xff09;数据的表示 不同进制之间的转换 R 进制转十进制使用按权展开法&#xff0c;其具体操作方式为&#xff1a;将 R 进制数的每一位数值用 Rk 形式表示&#xff0c;即幂的底数是 R &#xff0c;指数为 k &#xff0c;k 与该位和小数点之间的距离有关。当…

【PyTorch实战演练】AlexNet网络模型构建并使用Cifar10数据集进行批量训练(附代码)

目录 0. 前言 1. Cifar10数据集 2. AlexNet网络模型 2.1 AlexNet的网络结构 2.2 激活函数ReLu 2.3 Dropout方法 2.4 数据增强 3. 使用GPU加速进行批量训练 4. 网络模型构建 5. 训练过程 6. 完整代码 0. 前言 按照国际惯例&#xff0c;首先声明&#xff1a;本文只是我…

[开源]企业级在线办公系统,基于实时音视频完成在线视频会议功能

一、开源项目简介 企业级在线办公系统 本项目使用了SpringBootMybatisSpringMVC框架&#xff0c;技术功能点应用了WebSocket、Redis、Activiti7工作流引擎&#xff0c; 基于TRTC腾讯实时音视频完成在线视频会议功能。 二、开源协议 使用GPL-3.0开源协议 三、界面展示 部分…

2024天津理工大学中环信息学院专升本机械设计制造自动化专业考纲

2024年天津理工大学中环信息学院高职升本科《机械设计制造及其自动化》专业课考试大纲《机械设计》《机械制图》 《机械设计》考试大纲 教 材&#xff1a;《机械设计》&#xff08;第十版&#xff09;&#xff0c;高等教育出版社&#xff0c;濮良贵、陈国定、吴立言主编&#…

1.如何实现统一的API前缀-web组件篇

文章目录 1. 问题的由来2.实现原理3. 总结 1. 问题的由来 系统提供了 2 种类型的用户&#xff0c;分别满足对应的管理后台、用户 App 场景。 两种场景的前缀不同&#xff0c;分别为/admin-api/和/app-api/&#xff0c;都写在一个controller里面&#xff0c;显然比较混乱。分开…

AI:60-基于深度学习的瓜果蔬菜分类识别

🚀 本文选自专栏:AI领域专栏 从基础到实践,深入了解算法、案例和最新趋势。无论你是初学者还是经验丰富的数据科学家,通过案例和项目实践,掌握核心概念和实用技能。每篇案例都包含代码实例,详细讲解供大家学习。 📌📌📌在这个漫长的过程,中途遇到了不少问题,但是…

网络基础扫盲-多路转发

博客内容&#xff1a;多路转发的常见方式select&#xff0c;poll&#xff0c;epoll 文章目录 一、五种IO模型二、多路转发的常见接口1.select2、poll3、epoll 总结 前言 Linux下一切皆文件&#xff0c;是文件就会存在IO的情况&#xff0c;IO的方式决定了效率的高低。 一、五种…

基于java+springboot+vue在线选课系统

项目介绍 本系统结合计算机系统的结构、概念、模型、原理、方法&#xff0c;在计算机各种优势的情况下&#xff0c;采用JAVA语言&#xff0c;结合SpringBoot框架与Vue框架以及MYSQL数据库设计并实现的。员工管理系统主要包括个人中心、课程管理、专业管理、院系信息管理、学生…

Cube MX 开发高精度电流源跳坑过程/SPI连接ADS1255/1256系列问题总结/STM32 硬件SPI开发过程

文章目录 概要整体架构流程技术名词解释技术细节小结 概要 1.使用STM32F系列开发一款高精度恒流电源&#xff0c;用到了24位高精度采样芯片ADS1255/ADS1256系列。 2.使用时发现很多的坑&#xff0c;详细介绍了每个坑的具体情况和实际的解决办法。 坑1&#xff1a;波特率设置…

如何使用Ruby 多线程爬取数据

现在比较主流的爬虫应该是用python&#xff0c;之前也写了很多关于python的文章。今天在这里我们主要说说ruby。我觉得ruby也是ok的&#xff0c;我试试看写了一个爬虫的小程序&#xff0c;并作出相应的解析。 Ruby中实现网页抓取&#xff0c;一般用的是mechanize&#xff0c;使…

Pytorch从零开始实战08

Pytorch从零开始实战——YOLOv5-C3模块实现 本系列来源于365天深度学习训练营 原作者K同学 文章目录 Pytorch从零开始实战——YOLOv5-C3模块实现环境准备数据集模型选择开始训练可视化模型预测总结 环境准备 本文基于Jupyter notebook&#xff0c;使用Python3.8&#xff0c…