reinforce 跑 CartPole-v1

gym版本是0.26.1
CartPole-v1的详细信息,点链接里看就行了。
修改了下动手深度强化学习对应的代码。

然后这里 J ( θ ) J(\theta) J(θ)梯度上升更新的公式是用的不严谨的,这个和王树森书里讲的严谨公式有点区别。

代码

import gym
import torch
from torch import nn
from torch.nn import functional as F
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import rl_utils # 这个要下载源码,然后放到同个文件目录下,链接在上面给出了
from d2l import torch as d2l # 这个是动手深度学习的库, pip/conda install d2l 就好了class PolicyNet(nn.Module):def __init__(self, state_dim, hidden_dim, action_dim):super().__init__()self.fc1 = nn.Linear(state_dim, hidden_dim)self.fc2 = nn.Linear(hidden_dim, action_dim)def forward(self, X):X = F.relu(self.fc1(X))return F.softmax(self.fc2(X),dim=1)class REINFORCE:def __init__(self, state_dim, hidden_dim, action_dim, learning_rate, gamma, device):self.policy_net = PolicyNet(state_dim, hidden_dim, action_dim).to(device)self.optimizer = torch.optim.Adam(self.policy_net.parameters(), lr = learning_rate)self.gamma = gamma # 折扣因子self.device = devicedef take_action(self, state): # 根据动作概率分布随机采样state = torch.tensor(np.array([state]),dtype=torch.float).to(self.device)probs = self.policy_net(state)action_dist = torch.distributions.Categorical(probs)action = action_dist.sample()return action.item()def update(self, transition_dict):  # 公式用的是简化推导reward_list = transition_dict['rewards']state_list = transition_dict['states']action_list = transition_dict['actions']G = 0self.optimizer.zero_grad()for i in reversed(range(len(reward_list))):  # 从最后一步算起reward = reward_list[i]state = torch.tensor(np.array([state_list[i]]), dtype=torch.float).to(self.device)action = torch.tensor([action_list[i]]).reshape(-1,1).to(self.device)log_prob = torch.log(self.policy_net(state).gather(1, action))G = self.gamma * G + reward loss = -log_prob * G  # 因为梯度更新是减的,所以取个负号loss.backward()self.optimizer.step()
lr = 1e-3
num_episodes = 1000
hidden_dim = 128
gamma = 0.98
device = d2l.try_gpu()env_name="CartPole-v1"
env = gym.make(env_name)
print(f"_max_episode_steps:{env._max_episode_steps}")
torch.manual_seed(0)
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.nagent = REINFORCE(state_dim, hidden_dim, action_dim, lr, gamma, device)
return_list = []
for i in range(10):with tqdm(total=int(num_episodes/10), desc=f'Iteration {i}') as pbar:for i_episode in range(int(num_episodes/10)):episode_return = 0transition_dict = {'states': [], 'actions': [], 'next_states': [], 'rewards': [], 'dones': []}state = env.reset()[0]done, truncated= False, Falsewhile not done and not truncated :  # 主要是这部分和原始的有点不同action = agent.take_action(state)next_state, reward, done, truncated, info = env.step(action)transition_dict['states'].append(state)transition_dict['actions'].append(action)transition_dict['next_states'].append(next_state)transition_dict['rewards'].append(reward)transition_dict['dones'].append(done)state = next_stateepisode_return += rewardreturn_list.append(episode_return)agent.update(transition_dict)if (i_episode+1) % 10 == 0:pbar.set_postfix({'episode': '%d' % (num_episodes / 10 * i + i_episode+1), 'return': '%.3f' % np.mean(return_list[-10:])})pbar.update(1)episodes_list = list(range(len(return_list)))
plt.plot(episodes_list, return_list)
plt.xlabel('Episodes')
plt.ylabel('Returns')
plt.title(f'REINFORCE on {env_name}')
plt.show()mv_return = rl_utils.moving_average(return_list, 9)
plt.plot(episodes_list, mv_return)
plt.xlabel('Episodes')
plt.ylabel('Returns')
plt.title(f'REINFORCE on {env_name}')
plt.show()

我是在jupyter里直接跑的,结果如下所示。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/212418.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

innobackupex备份目录

innobackupeex全备脚本思路 四个需求如下: (1)每天晚上23点执行,这需要linux系统做一个定时任务 00 23 * * * /bin/sh /shell/tencent_xtrabackup_all.sh /dev/null 2>&1 (2)每天。。看到这个词…

标识符···

定义 标识符只能由字母、数字、下划线(_)和美元符号($)组成。标识符必须以字母、下划线或美元符号开头,不能以数字开头。标识符对大小写敏感,例如"myVariable"和"myvariable"是不同的…

Android 11 适配——整理总结篇

背景 > 经过检测,我们识别到您的应用,目前未适配安卓11(API30),请您关注适配截止时间,尽快开展适配工作,避免影响应用正常发布和经营。 > targetSdkVersion30 升级适配工作参考文档&am…

从零开发短视频电商 Jmeter压测示例模板详解(无认证场景)

文章目录 添加线程组添加定时器添加HTTP请求默认值添加HTTP头管理添加HTTP请求添加结果断言响应断言 Response AssertionJSON断言 JSON Assertion持续时间断言 Duration Assertion 添加察看结果树添加聚合报告添加表格察看结果参考 以压测百度搜索为例 https://www.baidu.com/s…

class066 一维动态规划【算法】

class066 一维动态规划 算法讲解066【必备】从递归入手一维动态规划 code1 509斐波那契数列 // 斐波那契数 // 斐波那契数 (通常用 F(n) 表示)形成的序列称为 斐波那契数列 // 该数列由 0 和 1 开始,后面的每一项数字都是前面两项数字的和。…

kotlin - ViewBinding

前言 为什么用ViewBinding,而不用findViewById(),这个有很多优秀的博主都做了讲解,就不再列出了。 可参考下列博主的文章: kotlin ViewBinding的使用 文章里也给出了如何在gradle中做出相应的配置。 (我建议先看这位博…

【LeetCode热题100】【滑动窗口】无重复字符的最长子串

给定一个字符串 s ,请你找出其中不含有重复字符的 最长子串 的长度。 示例 1: 输入: s "abcabcbb" 输出: 3 解释: 因为无重复字符的最长子串是 "abc",所以其长度为 3。示例 2: 输入: s "bbbbb" 输出: 1 解释: 因为无…

Docker安装教程

docker官网 1.卸载旧版 yum remove docker \docker-client \docker-client-latest \docker-common \docker-latest \docker-latest-logrotate \docker-logrotate \docker-engine2.配置Docker的yum库 安装yum工具 yum install -y yum-utils配置Docker的yum源 yum-config-ma…

Redis,什么是缓存穿透?怎么解决?

Redis,什么是缓存穿透?怎么解决? 1、缓存穿透 一般的缓存系统,都是按照key去缓存查询,如果不存在对用的value,就应该去后端系统查找(比如DB数据库)。一些恶意的请求会故意查询不存在…

不想写大量 if 判断?试试用规则执行器优化,就很丝滑!

近日在公司领到一个小需求,需要对之前已有的试用用户申请规则进行拓展。我们的场景大概如下所示: if (是否海外用户) {return false; }if (刷单用户) {return false; }if (未付费用户 && 不再服务时段) {return false }if (转介绍用户 || 付费用户 || 内推…

16ASM 分段和机器码

8086CPU存储分段管理 问题1:8086是16位cpu,最多可访问(寻址)多大内存? 运算器一次最多处理16位的数据。地址寄存器的最大宽度为16位。访问的最大内存为:216 64K 即 0000 - FFFF。 问题2:808…

Hadoop集群破坏试验可靠性验证

集群环境说明: 准备5台服务器,hadoop1、hadoop2、hadoop3、hadoop4、hadoop5; 分别部署5个节点的zookeeper集群、hadoop集群、hbase集群 本次对于Hadoop集群测试主要分为五个方面: 手动进行datanode节点删除:&#…

typedef 与#define 的区别

typedef 与#define 的区别 typedef : 给一个已经存在的数据类型(注意:是类型不是变量)取一个别名,而非定义一个新的数据类型 #define宏定义: #define宏定义:在预编译时直接进行简单的文本替换 举…

WIFI直连(Wi-Fi P2P)

一、概述 Wifi peer-to-peer(也称Wifi-Direct)是Wifi联盟推出的一项基于原来WIfi技术的可以让设备与设备间直接连接的技术,使用户不需要借助局域网或者AP(Access Point)就可以进行一对一或一对多通信。这种技术的应用…

计算机毕业设计 SpringBoot的乐乐农产品销售系统 Javaweb项目 Java实战项目 前后端分离 文档报告 代码讲解 安装调试

🍊作者:计算机编程-吉哥 🍊简介:专业从事JavaWeb程序开发,微信小程序开发,定制化项目、 源码、代码讲解、文档撰写、ppt制作。做自己喜欢的事,生活就是快乐的。 🍊心愿:点…

Xmanager

什么是 XManager Xmanager 是市场上领先的 PC X 服务器,可将X应用程序的强大功能带入 Windows 环境。 提供了强大的会话管理控制台,易于使用的 X 应用程序启动器,X 服务器配置文件管理工具,SSH 模块和高性能 PC X 服务器。 Xman…

javaScript(六):DOM操作

文章目录 1、DOM介绍2、DOM:获取Element对象3、DOM:事件监听3.1、事件介绍3.2、常见事件3.3、设置事件的两种方式3.4、事件案例 1、DOM介绍 概念 Document Object Model ,文档对象模型 将标记语言的各个组成部分封装为对应的对象&#xff1a…

Realme X7 Pro Root 刷机教程

Realme X7 Pro 刷机教程 Just For Fun,最近倒腾了下Realme X7 Pro 刷root。此博客为个人记录刷机过程,如有机友跟随本教程操作,请谨慎操作!!! 以下教程真针对Realme X7 Pro,其他版本方法未知&…

springboot(ssm高校竞赛管理系统 在线竞赛平台 Java系统

springboot(ssm高校竞赛管理系统 在线竞赛平台 Java系统 开发语言:Java 框架:ssm/springboot vue JDK版本:JDK1.8(或11) 服务器:tomcat 数据库:mysql 5.7(或8.0) 数…

qt 模型视图结构

在Qt中,Model、View和Delegate三者之间的关系如下: Model(模型):Model是数据的抽象表示,它提供了一种结构化的方式来存储和管理数据。Model负责维护数据的状态,并提供接口供其他组件&#xff08…