A2C原理和代码实现

参考王树森《深度强化学习》课程和书籍


1、A2C原理:

在这里插入图片描述


Observe a transition: ( s t , a t , r t , s t + 1 ) (s_t,{a_t},r_t,s_{t+1}) (st,at,rt,st+1)

TD target:
y t = r t + γ ⋅ v ( s t + 1 ; w ) . y_{t} = r_{t}+\gamma\cdot v(s_{t+1};\mathbf{w}). yt=rt+γv(st+1;w).
TD error:
δ t = v ( s t ; w ) − y t . \quad\delta_t = v(s_t;\mathbf{w})-y_t. δt=v(st;w)yt.
Update the policy network (actor) by:
θ ← θ − β ⋅ δ t ⋅ ∂ ln ⁡ π ( a t ∣ s t ; θ ) ∂ θ . \mathbf{\theta}\leftarrow\mathbf{\theta}-\beta\cdot\delta_{t}\cdot\frac{\partial\ln\pi(a_{t}\mid s_{t};\mathbf{\theta})}{\partial \mathbf{\theta}}. θθβδtθlnπ(atst;θ).


def compute_value_loss(self, bs, blogp_a, br, bd, bns):# 目标价值。with torch.no_grad():target_value = br + self.args.discount * torch.logical_not(bd) * self.V_target(bns).squeeze()# torch.logical_not 对输入张量取逻辑非# 计算value loss。value_loss = F.mse_loss(self.V(bs).squeeze(), target_value)return value_loss

Update the value network (critic) by:
w ← w − α ⋅ δ t ⋅ ∂ v ( s t ; w ) ∂ w . \mathbf{w}\leftarrow\mathbf{w}-\alpha\cdot\delta_{t}\cdot{\frac{\partial{v(s_{t}};\mathbf{w})}{\partial\mathbf{w}}}. wwαδtwv(st;w).


def compute_policy_loss(self, bs, blogp_a, br, bd, bns):# 建议对比08_a2c.py,比较二者的差异。with torch.no_grad():value = self.V(bs).squeeze()policy_loss = 0for i, logp_a in enumerate(blogp_a):policy_loss += -logp_a * value[i]policy_loss = policy_loss.mean()return policy_loss

2、A2C完整代码实现:

参考后修改注释:最初的代码在https://github.com/wangshusen/DRL

"""8.3节A2C算法实现。"""
import argparse
import os
from collections import defaultdict
import gym
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categoricalclass ValueNet(nn.Module):def __init__(self, dim_state):super().__init__()self.fc1 = nn.Linear(dim_state, 64)self.fc2 = nn.Linear(64, 32)self.fc3 = nn.Linear(32, 1)def forward(self, state):x = F.relu(self.fc1(state))x = F.relu(self.fc2(x))x = self.fc3(x)return xclass PolicyNet(nn.Module):def __init__(self, dim_state, num_action):super().__init__()self.fc1 = nn.Linear(dim_state, 64)self.fc2 = nn.Linear(64, 32)self.fc3 = nn.Linear(32, num_action)def forward(self, state):x = F.relu(self.fc1(state))x = F.relu(self.fc2(x))x = self.fc3(x)prob = F.softmax(x, dim=-1)return probclass A2C:def __init__(self, args):self.args = argsself.V = ValueNet(args.dim_state)self.V_target = ValueNet(args.dim_state)self.pi = PolicyNet(args.dim_state, args.num_action)self.V_target.load_state_dict(self.V.state_dict())def get_action(self, state):probs = self.pi(state)m = Categorical(probs)action = m.sample()logp_action = m.log_prob(action)return action, logp_actiondef compute_value_loss(self, bs, blogp_a, br, bd, bns):# 目标价值。with torch.no_grad():target_value = br + self.args.discount * torch.logical_not(bd) * self.V_target(bns).squeeze()# 计算value loss。value_loss = F.mse_loss(self.V(bs).squeeze(), target_value)return value_lossdef compute_policy_loss(self, bs, blogp_a, br, bd, bns):# 目标价值。with torch.no_grad():target_value = br + self.args.discount * torch.logical_not(bd) * self.V_target(bns).squeeze()# 计算policy loss。with torch.no_grad():advantage = target_value - self.V(bs).squeeze()policy_loss = 0for i, logp_a in enumerate(blogp_a):policy_loss += -logp_a * advantage[i]policy_loss = policy_loss.mean()return policy_lossdef soft_update(self, tau=0.01):def soft_update_(target, source, tau_=0.01):for target_param, param in zip(target.parameters(), source.parameters()):target_param.data.copy_(target_param.data * (1.0 - tau_) + param.data * tau_)soft_update_(self.V_target, self.V, tau)class Rollout:def __init__(self):self.state_lst = []self.action_lst = []self.logp_action_lst = []self.reward_lst = []self.done_lst = []self.next_state_lst = []def put(self, state, action, logp_action, reward, done, next_state):self.state_lst.append(state)self.action_lst.append(action)self.logp_action_lst.append(logp_action)self.reward_lst.append(reward)self.done_lst.append(done)self.next_state_lst.append(next_state)def tensor(self):bs = torch.as_tensor(self.state_lst).float()ba = torch.as_tensor(self.action_lst).float()blogp_a = self.logp_action_lstbr = torch.as_tensor(self.reward_lst).float()bd = torch.as_tensor(self.done_lst)bns = torch.as_tensor(self.next_state_lst).float()return bs, ba, blogp_a, br, bd, bnsclass INFO:def __init__(self):self.log = defaultdict(list)self.episode_length = 0self.episode_reward = 0self.max_episode_reward = -float("inf")def put(self, done, reward):if done is True:self.episode_length += 1self.episode_reward += rewardself.log["episode_length"].append(self.episode_length)self.log["episode_reward"].append(self.episode_reward)if self.episode_reward > self.max_episode_reward:self.max_episode_reward = self.episode_rewardself.episode_length = 0self.episode_reward = 0else:self.episode_length += 1self.episode_reward += rewarddef train(args, env, agent: A2C):V_optimizer = torch.optim.Adam(agent.V.parameters(), lr=3e-3)pi_optimizer = torch.optim.Adam(agent.pi.parameters(), lr=3e-3)info = INFO()rollout = Rollout()state, _ = env.reset()for step in range(args.max_steps):action, logp_action = agent.get_action(torch.tensor(state).float())next_state, reward, terminated, truncated, _ = env.step(action.item())done = terminated or truncatedinfo.put(done, reward)rollout.put(state,action,logp_action,reward,done,next_state,)state = next_stateif done is True:# 模型训练。bs, ba, blogp_a, br, bd, bns = rollout.tensor()value_loss = agent.compute_value_loss(bs, blogp_a, br, bd, bns)V_optimizer.zero_grad()value_loss.backward(retain_graph=True)V_optimizer.step()policy_loss = agent.compute_policy_loss(bs, blogp_a, br, bd, bns)pi_optimizer.zero_grad()policy_loss.backward()pi_optimizer.step()agent.soft_update()# 打印信息。info.log["value_loss"].append(value_loss.item())info.log["policy_loss"].append(policy_loss.item())episode_reward = info.log["episode_reward"][-1]episode_length = info.log["episode_length"][-1]value_loss = info.log["value_loss"][-1]print(f"step={step}, reward={episode_reward:.0f}, length={episode_length}, max_reward={info.max_episode_reward}, value_loss={value_loss:.1e}")# 重置环境。state, _ = env.reset()rollout = Rollout()# 保存模型。if episode_reward == info.max_episode_reward:save_path = os.path.join(args.output_dir, "model.bin")torch.save(agent.pi.state_dict(), save_path)if step % 10000 == 0:plt.plot(info.log["value_loss"], label="value loss")plt.legend()plt.savefig(f"{args.output_dir}/value_loss.png", bbox_inches="tight")plt.close()plt.plot(info.log["episode_reward"])plt.savefig(f"{args.output_dir}/episode_reward.png", bbox_inches="tight")plt.close()def eval(args, env, agent):agent = A2C(args)model_path = os.path.join(args.output_dir, "model.bin")agent.pi.load_state_dict(torch.load(model_path))episode_length = 0episode_reward = 0state, _ = env.reset()for i in range(5000):episode_length += 1action, _ = agent.get_action(torch.from_numpy(state))next_state, reward, terminated, truncated, info = env.step(action.item())done = terminated or truncatedepisode_reward += rewardstate = next_stateif done is True:print(f"episode reward={episode_reward}, length={episode_length}")state, _ = env.reset()episode_length = 0episode_reward = 0if __name__ == "__main__":parser = argparse.ArgumentParser()parser.add_argument("--env", default="CartPole-v1", type=str, help="Environment name.")parser.add_argument("--dim_state", default=4, type=int, help="Dimension of state.")parser.add_argument("--num_action", default=2, type=int, help="Number of action.")parser.add_argument("--output_dir", default="output", type=str, help="Output directory.")parser.add_argument("--seed", default=42, type=int, help="Random seed.")parser.add_argument("--max_steps", default=100_000, type=int, help="Maximum steps for interaction.")parser.add_argument("--discount", default=0.99, type=float, help="Discount coefficient.")parser.add_argument("--lr", default=1e-3, type=float, help="Learning rate.")parser.add_argument("--batch_size", default=32, type=int, help="Batch size.")parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")parser.add_argument("--do_train", action="store_true", help="Train policy.")parser.add_argument("--do_eval", action="store_true", help="Evaluate policy.")args = parser.parse_args()env = gym.make(args.env)agent = A2C(args)if args.do_train:train(args, env, agent)if args.do_eval:eval(args, env, agent)

3、torch.distributions.Categorical()

probs = policy_network(state)
# Note that this is equivalent to what used to be called multinomial
m = Categorical(probs) # 用probs构造一个分布
action = m.sample() # 按照probs进行采样
next_state, reward = env.step(action)
loss = -m.log_prob(action) * reward # log_prob 计算log(probs[action])的值
loss.backward()

Probability distributions - torch.distributions — PyTorch 2.0 documentation

next_state, reward = env.step(action)
loss = -m.log_prob(action) * reward # log_prob 计算log(probs[action])的值
loss.backward()


[Probability distributions - torch.distributions — PyTorch 2.0 documentation](https://pytorch.org/docs/stable/distributions.html)[【PyTorch】关于 log_prob(action) - 简书 (jianshu.com)](https://www.jianshu.com/p/06a5c47ee7c2)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/33157.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【RabbitMQ】golang客户端教程5——使用topic交换器

topic交换器(主题交换器) 发送到topic交换器的消息不能具有随意的routing_key——它必须是单词列表,以点分隔。这些词可以是任何东西,但通常它们指定与消息相关的某些功能。一些有效的routing_key示例:“stock.usd.ny…

知识图谱基本工具Neo4j使用笔记 四 :使用csv文件批量导入图谱数据

文章目录 一、系统说明二、说明三、简单介绍1. 相关代码以及参数2. 简单示例 四、实际数据实践1. 前期准备(1) 创建一个用于测试的neo4j数据库(2)启动neo4j 查看数据库 2. 实践(1) OK 上面完成后&#xff0…

阿里云Linux服务器安装FTP站点全流程

阿里云百科分享使用阿里云服务器安装FTP全教程,vsftpd(very secure FTP daemon)是Linux下的一款小巧轻快、安全易用的FTP服务器软件。本教程介绍如何在Linux实例上安装并配置vsftpd。 目录 前提条件 步骤一:安装vsftpd 步骤二…

HTTP代理编程:Python实用技巧与代码实例

今天我要与大家分享一些关于HTTP代理编程的实用技巧和Python代码实例。作为一名HTTP代理产品供应商,希望通过这篇文章,帮助你们掌握一些高效且实用的编程技巧,提高开发和使用HTTP代理产品的能力。 一、使用Python的requests库发送HTTP请求&a…

无涯教程-Perl - mkdir函数

描述 此功能使用MODE指定的模式创建一个名称和路径EXPR的目录,为清楚起见,应将其作为八进制值提供。 语法 以下是此函数的简单语法- mkdir EXPR,MODE返回值 如果失败,此函数返回0,如果成功,则返回1。 例 以下是显示其基本用法的示例代码- #!/usr/bin/perl -w$dirname &…

Docker desktop使用配置

1. 下载安装 https://www.docker.com/ 官网下载并安装doker desktop 2. 配置镜像 (1)首先去阿里云网站上进行注册:https://cr.console.aliyun.com/cn-hangzhou/instances/mirrors (2)注册完成后搜索:容…

Jmeter入门之digest函数 jmeter字符串连接与登录串加密应用

登录请求中加密串是由多个子串连接,再加密之后传输。 参数连接:${var1}${var2}${var3} 加密函数:__digest (函数助手里如果没有该函数,请下载最新版本的jmeter5.0) 函数助手:Options > …

1.Fay-UE5数字人工程导入(UE数字人系统教程)

非常全面的数字人解决方案(含源码) Fay-UE5数字人工程导入 1、工程下载:xszyou/fay-ue5: 可对接fay数字人的ue5工程 (github.com) 2、ue5下载安装:Unreal Engine 5 3、ue5插件安装 依次安装以下几个插件 4、双击运行工程 5、切换中文 6、检…

JavaWeb学习|JavaBean;MVC三层架构;Filter;Listener

1.JavaBean 实体类 JavaBean有特定的写法: 必须要有一个无参构造 属性必须私有化。 必须有对应的get/set方法 用来和数据库的字段做映射 ORM; ORM:对象关系映射 表--->类 字段-->属性 行记录---->对象 2.<jsp&#xff1a;useBean 标签 3. MVC三层架构 4. Filter …

Mybatis 初识

目录 1. MyBatis入门 1.1 MyBatis的定义 1.2 MyBatis的核心 MyBatis的核心 JDBC 的操作回顾 1.3 MyBatis的执行流程 MyBatis基本工作原理 2. MyBatis的使用 2.1 MyBatis环境搭建 2.1.1 创建数据库和表 2.1.2 添加MyBatis框架支持 老项目添加MyBatis 新项目添加MyBatis 2.1.3 设…

考研算法38天:反序输出 【字符串的翻转】

题目 题目收获 很简单的一道题&#xff0c;但是还是有收获的&#xff0c;我发现我连scanf的字符串输入都忘记咋用了。。。。。我一开始写的 #include <iostream> #include <cstring> using namespace std;void deserve(string &str){int n str.size();int…

css小练习:案例6.炫彩加载

一.效果浏览图 二.实现思路 html部分 HTML 写了一个加载动画效果&#xff0c;使用了一个包含多个 <span> 元素的 <div> 元素&#xff0c;并为每个 <span> 元素设置了一个自定义属性 --i。 这段代码创建了一个简单的动态加载动画&#xff0c;由20个垂直排列的…

Flask实现接口mock,安装及使用教程(一)

1、什么是接口mock 主要是针对单元测试的应用&#xff0c;它可以很方便的解除单元测试中各种依赖&#xff0c;大大的降低了编写单元测试的难度 2、什么是mock server 正常情况下&#xff1a;测试客户端——测试——> 被测系统 ——依赖——>外部服务依赖 在被测系统和…

AI:01-基于机器学习的深度学习的玫瑰花种类的识别

文章目录 一、数据集介绍二、数据预处理三、模型构建四、模型训练五、模型评估六、模型训练七、模型评估八、总结深度学习技术在图像识别领域有着广泛的应用,其中一种应用就是玫瑰花种类的识别。在本文中,我们将介绍如何使用机器学习和深度学习技术来实现玫瑰花种类的识别,并…

运维监控学习1

1、监控对象&#xff1a; 1、监控对象的理解&#xff1b;CPU是怎么工作的&#xff1b; 2、监控对象的指标&#xff1a;CPU使用率&#xff1b;上下文切换&#xff1b; 3、确定性能基准线&#xff1a;CPU负载多少才算高&#xff1b; 2、监控范围&#xff1a; 1、硬件监控&#x…

“掌握类与对象,点亮编程之路“(下)

White graces&#xff1a;个人主页 &#x1f649;专栏推荐:《C语言入门知识》&#x1f649; &#x1f649; 内容推荐:“掌握类与对象&#xff0c;点亮编程之路“(上)&#x1f649; &#x1f439;今日诗词:春风得意马蹄疾&#xff0c;一日看尽长安花&#x1f439; 目录 &…

vscode里面报:‘xxx‘ is assigned a value but never used.解决办法

const setCurPage: React.Dispatch<React.SetStateAction<number>> 已声明“setCurPage”&#xff0c;但从未读取其值。ts(6133) setCurPage is assigned a value but never used.eslinttypescript-eslint/no-unused-vars 出现这个报错是eslint导致的&#xff0…

P450进阶款无人机室内定位功能研测

在以往的Prometheus 450&#xff08;P450&#xff09;无人机上&#xff0c;我们搭载的是Intel Realsense T265定位模块&#xff0c;使用USB连接方式挂载到机载计算机allspark上&#xff0c;通过机载上SDK驱动T265运行并输出SLAM信息&#xff0c;以此来实现室内定位功能。 为进…

倒数纪念日-生日提醒事项时间管理倒计时软件

倒数纪念日​​​​​​​是一款功能强大的时间管理、事项提醒软件。帮你更好的管理倒数日、纪念日、生日、节假日、还款日等各种重要日子&#xff0c;通知提醒&#xff0c;让你不再错过生命中的每一个重要日子。 【功能简介】 分类管理&#xff1a;倒数日、纪念日、自定义分类…

AJAX-笔记(持续更新中)

文章目录 Day1 Ajax入门1.AJAX概念和axios的使用2. 认识URL3.URL的查询参数4.常用的请求方法和数据提交5.HTTP协议-报文6.接口文档7.form-serialize插件8.案例用户登录 Day2 Ajax综合案bootstrap弹框图书管理图片上传更换背景个人信息设置 Day3 AJAX原理XMLHttpRequestPromise封…