PPO 跑CartPole-v1

gym-0.26.2
cartPole-v1

参考动手学强化学习书中的代码,并做了一些修改

代码

import gym
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdmclass PolicyNet(nn.Module):def __init__(self, state_dim, hidden_dim, action_dim):super().__init__()self.fc1 = nn.Linear(state_dim, hidden_dim)self.fc2 = nn.Linear(hidden_dim, action_dim)def forward(self, x):x = F.relu(self.fc1(x))return F.softmax(self.fc2(x), dim=1)class ValueNet(nn.Module):def __init__(self, state_dim, hidden_dim):super().__init__()self.fc1 = nn.Linear(state_dim, hidden_dim)self.fc2 = nn.Linear(hidden_dim, 1)def forward(self, x):x = F.relu(self.fc1(x))return self.fc2(x)class PPO:"""PPO算法,采用截断的方式"""def __init__(self, state_dim, hidden_dim, action_dim, actor_lr, critic_lr, lmbda, epochs, eps, gamma, device):self.actor = PolicyNet(state_dim, hidden_dim, action_dim).to(device)self.critic = ValueNet(state_dim, hidden_dim).to(device)self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=actor_lr)self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=critic_lr)self.gamma = gammaself.lmbda = lmbdaself.epochs = epochs    # 一条序列的数据用来训练轮数self.eps = eps  # PPO 中阶段范围的参数self.device = devicedef take_action(self, state):state = torch.FloatTensor([state]).to(self.device)probs = self.actor(state)action_dist = torch.distributions.Categorical(probs)action = action_dist.sample()return action.item()def gae(self, td_delta):td_delta = td_delta.detach().numpy()advantages_list = []advantage = 0.0for delta in td_delta[::-1]:advantage = self.gamma * self.lmbda * advantage + deltaadvantages_list.append(advantage)advantages_list.reverse()return torch.FloatTensor(advantages_list)def update(self, transition_dist):states = torch.FloatTensor(transition_dist['states']).to(self.device)actions = torch.tensor(transition_dist['actions']).reshape((-1, 1)).to(self.device)rewards = torch.FloatTensor(transition_dist['rewards']).reshape((-1, 1)).to(self.device)next_states = torch.FloatTensor(transition_dist['next_states']).to(self.device)dones = torch.FloatTensor(transition_dist['dones']).reshape((-1, 1)).to(self.device)td_target = rewards + self.gamma * self.critic(next_states) * (1 - dones)td_delta = td_target - self.critic(states)# GAE 计算广义优势advantage = self.gae(td_delta.cpu()).to(self.device)old_log_probs = torch.log(self.actor(states).gather(1, actions)).detach()for _ in range(self.epochs):log_probs = torch.log(self.actor(states).gather(1, actions))ration = torch.exp(log_probs - old_log_probs)surr1 = ration * advantagesurr2 = torch.clamp(ration, 1-self.eps, 1+self.eps) * advantage # 截断actor_loss = torch.mean(-torch.min(surr1, surr2))   # PPO损失函数critic_loss = torch.mean(F.mse_loss(self.critic(states), td_target.detach()))self.actor_optimizer.zero_grad()self.critic_optimizer.zero_grad()actor_loss.backward()critic_loss.backward()self.actor_optimizer.step()self.critic_optimizer.step()def moving_average(a, window_size):cumulative_sum = np.cumsum(np.insert(a, 0, 0))middle = (cumulative_sum[window_size:] - cumulative_sum[:-window_size]) / window_sizer = np.arange(1, window_size-1, 2)begin = np.cumsum(a[:window_size-1])[::2] / rend = (np.cumsum(a[:-window_size:-1])[::2] / r)[::-1]return np.concatenate((begin, middle, end))def train():actor_lr = 1e-3critic_lr = 1e-2num_episodes = 500hidden_dim = 128gamma = 0.98lmbda = 0.95epochs = 10eps = 0.2device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")env_name = "CartPole-v1"env = gym.make(env_name)torch.manual_seed(0)state_dim = env.observation_space.shape[0]action_dim = env.action_space.nagent = PPO(state_dim, hidden_dim, action_dim, actor_lr, critic_lr, lmbda, epochs, eps, gamma, device)return_list = []for i in range(10):with tqdm(total=int(num_episodes / 10), desc='Iteration %d' % i) as pbar:for i_episode in range(int(num_episodes / 10)):episode_return = 0transition_dict = {'states': [], 'actions': [], 'next_states': [], 'rewards': [], 'dones': []}state, _ = env.reset()done, truncated = False, Falsewhile not done and not truncated:action = agent.take_action(state)next_state, reward, done, truncated, _ = env.step(action)done = done or truncated    # 这个地方要注意transition_dict['states'].append(state)transition_dict['actions'].append(action)transition_dict['next_states'].append(next_state)transition_dict['rewards'].append(reward)transition_dict['dones'].append(done)state = next_stateepisode_return += rewardreturn_list.append(episode_return)agent.update(transition_dict)if (i_episode + 1) % 10 == 0:pbar.set_postfix({'episode': '%d' % (num_episodes / 10 * i + i_episode + 1),'return': '%.3f' % np.mean(return_list[-10:])})pbar.update(1)episodes_list = list(range(len(return_list)))plt.plot(episodes_list, return_list)plt.xlabel('Episodes')plt.ylabel('Returns')plt.title(f'PPO on {env_name}')plt.show()mv_return = moving_average(return_list, 9)plt.plot(episodes_list, mv_return)plt.xlabel('Episodes')plt.ylabel('Returns')plt.title(f'PPO on {env_name}')plt.show()if __name__ == '__main__':train()

pycharm中运行结果:

效果看起很好。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/629478.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

HTML--JavaScript--语法基础

变量与常量 这个基本上没啥问题 变量命名规则: 变量由字母、数字、下划线、$组成,且变量第一个字符不能为数字 变量不能是系统关键字和保留字 语法: var 变量名 值;所有Javacript变量都由var声明 定义赋值字符串: …

GaussDB(DWS)查询优化技术大揭秘

GaussDB(DWS)查询优化技术大揭秘 大数据时代,数据量呈爆发式增长,经常面临百亿、千亿数据查询场景,当数据仓库数据量较大、SQL语句执行效率低时,数据仓库性能会受到影响。本文将深入讲解在GaussDB(DWS)中如何进行表结构设计&#…

【Web】websocket应用的是哪个协议

🍎个人博客:个人主页 🏆个人专栏:Web ⛳️ 功不唐捐,玉汝于成 前言 在当今互联网时代,实时性和即时通讯成为网络应用日益重要的一部分。WebSocket 协议作为一种创新性的通信协议,极大地改善了…

C语言——编译和链接

(图片由AI生成) 0.前言 C语言是最受欢迎的编程语言之一,以其接近硬件的能力和高效性而闻名。理解C语言的编译和链接过程对于深入了解其运行原理至关重要。本文将详细介绍C语言的翻译环境和运行环境,重点关注编译和链接的各个阶段…

Architecture Lab:预备知识2【汇编call/leave/ret指令、CS:APP练习4.4】

chap4的练习4.4(page.255)让用Y86-64实现rsum(递归求数组元素之和),提示为:先得到x86-64汇编代码,然后转换成Y86-64的 这是rsum的c实现: long rsum(long *start, long count) {if …

【面试合集】说说微信小程序的发布流程?

面试官:说说微信小程序的发布流程? 一、背景 在中大型的公司里,人员的分工非常仔细,一般会有不同岗位角色的员工同时参与同一个小程序项目。为此,小程序平台设计了不同的权限管理使得项目管理者可以更加高效管理整个团…

微软推出付费版Copilot

关注卢松松,会经常给你分享一些我的经验和观点。 微软已经超越苹果,成了全球市值最高的公司,其他公司都因为AI大裁员,而微软正好相反,当然这个原因很简单:就是微软强制把AI全面接入到系统里来了。而Copilot…

Mac系统下,保姆级Jenkins自动化部署Android

一、Jenkins自动化部署 1、安装jenkins 官网:macOS Installers for Jenkins LTS 选择macOS brew install jenkins-lts 安装最新: brew install jenkins-lts 启动jenkins服务: brew services start jenkins-lts 重启jenkins服务: brew services restart jenkin…

web开发学习笔记(2.js)

1.引入 2.js的两种引入方式 3.输出语句 4.全等运算符 5.定义函数 6.数组 7.数组属性 8.字符串对象的对应方法 9.自定义对象 10.json对象 11.bom属性 12.window属性 13.定时刷新时间 14.跳转网址 15.DOM文档对象模型 16.获取DOM对象,根据DOM对象来操作网页 如下图…

基于杂交PSO算法的风光储微网日前优化调度(MATLAB实现)

微网中包含:风电、光伏、储能、微型燃气轮机,以最小化电网购电成本、光伏风机的维护成本、蓄电池充放电维护成本、燃气轮机运行成本及污染气体治理成本为目标,综合考虑:功率平衡约束、燃气轮机爬坡约束、电网交换功率约束、储能装…

【GCC】6 接收端实现:周期构造RTCP反馈包

基于m98代码。GCC涉及的代码,可能位于:webrtc/modules/remote_bitrate_estimator webrtc/modules/congestion_controller webrtc/modules/rtp_rtcp/source/rtcp_packet/transport_feedback.cc webrtc 之 RemoteEstimatorProxy 对 remote_bitrate_estimator 的 RemoteEstimato…

Vue 富文本实现内容项目倒序

应用场景: 比如写计划和待办事项,内容少还好,内容多了最新的内容就放在下面了,每次打开要滚动到最后才能看到,这时可以使用倒序把最新的排在最前面。 倒序前: 倒序后: 倒序代码: …

设计模式⑥ :访问数据结构

文章目录 一、前言二、Visitor 模式1. 介绍2. 应用3. 总结 三、Chain of Responsibility 模式1. 介绍2. 应用3. 总结 参考内容 一、前言 有时候不想动脑子,就懒得看源码又不像浪费时间所以会看看书,但是又记不住,所以决定开始写"抄书&q…

ElasticSearch概述+SpringBoot 集成ES

ES概述 开源的、高扩展的、分布式全文检索引擎【站内搜索】 解决问题 1.搜索词是一个整体时,不能拆分(mysql整体连续) 2.效率会低,不会用到索引(mysql索引失效) 解决方式 进行数据的存储(只存储…

【51单片机系列】继电器使用

文章来源:《零起点学Proteus单片机仿真技术》。 本文是关于继电器使用相关内容。 继电器广泛应用在工业控制中,通过继电器对其他大电流的电器进行控制。 继电器控制原理图如下。继电器部分包括控制线圈和3个引脚,A引脚接电源,B引…

排序算法9----计数排序(C)

计数排序是一种非比较排序,不比较大小 。 1、思想 计数排序又称为鸽巢原理,是对哈希直接定址法的变形应用。 2、步骤 1、统计数据:统计每个数据出现了多少次。(建立一个count数组,范围从[MIN,MAX],MAX代表arr中…

.Net 8.0 Web API Controllers 添加到 windows 服务

示例源码下载:https://download.csdn.net/download/hefeng_aspnet/88747022 创建 Windows 服务的方法之一是从工作线程服务模板开始。 但是,如果您希望能够让它托管 API 控制器(也许是为了查看它正在运行的进程的状态)&#xff0…

深入浅出Spring AOP

第1章:引言 大家好,我是小黑,咱们今天要聊的是Java中Spring框架的AOP(面向切面编程)。对于程序员来说,理解AOP对于掌握Spring框架来说是超级关键的。它像是魔法一样,能让咱们在不改变原有代码的…

git基础知识

简述 git 的安装配置、工作区域划分、文件类型、基本命令。 基础安装与配置 基于 WSL 的 Ubuntu 下的 git 打开或关闭Windows功能->Hyper-V、Virtual Machine Platform、Windows Subsystem for Linux # 1.必须运行 Windows 10 版本 2004 及更高版本(内部版本 …

matplotlib绘制动态瀑布图

绘制瀑布图思路:遍历指定文件目录下所有的csv文件,每读一个文件,取文件前20行数据进行保存,如果超过规定的行数300行,将最旧的数据删除,仅保留300行数据进行展示。 网上找的大部分绘制瀑布图的代码&#x…