reinforce 跑 CartPole-v1

gym版本是0.26.1
CartPole-v1的详细信息,点链接里看就行了。
修改了下动手深度强化学习对应的代码。

然后这里 J ( θ ) J(\theta) J(θ)梯度上升更新的公式是用的不严谨的,这个和王树森书里讲的严谨公式有点区别。

代码

import gym
import torch
from torch import nn
from torch.nn import functional as F
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import rl_utils # 这个要下载源码,然后放到同个文件目录下,链接在上面给出了
from d2l import torch as d2l # 这个是动手深度学习的库, pip/conda install d2l 就好了class PolicyNet(nn.Module):def __init__(self, state_dim, hidden_dim, action_dim):super().__init__()self.fc1 = nn.Linear(state_dim, hidden_dim)self.fc2 = nn.Linear(hidden_dim, action_dim)def forward(self, X):X = F.relu(self.fc1(X))return F.softmax(self.fc2(X),dim=1)class REINFORCE:def __init__(self, state_dim, hidden_dim, action_dim, learning_rate, gamma, device):self.policy_net = PolicyNet(state_dim, hidden_dim, action_dim).to(device)self.optimizer = torch.optim.Adam(self.policy_net.parameters(), lr = learning_rate)self.gamma = gamma # 折扣因子self.device = devicedef take_action(self, state): # 根据动作概率分布随机采样state = torch.tensor(np.array([state]),dtype=torch.float).to(self.device)probs = self.policy_net(state)action_dist = torch.distributions.Categorical(probs)action = action_dist.sample()return action.item()def update(self, transition_dict):  # 公式用的是简化推导reward_list = transition_dict['rewards']state_list = transition_dict['states']action_list = transition_dict['actions']G = 0self.optimizer.zero_grad()for i in reversed(range(len(reward_list))):  # 从最后一步算起reward = reward_list[i]state = torch.tensor(np.array([state_list[i]]), dtype=torch.float).to(self.device)action = torch.tensor([action_list[i]]).reshape(-1,1).to(self.device)log_prob = torch.log(self.policy_net(state).gather(1, action))G = self.gamma * G + reward loss = -log_prob * G  # 因为梯度更新是减的,所以取个负号loss.backward()self.optimizer.step()
lr = 1e-3
num_episodes = 1000
hidden_dim = 128
gamma = 0.98
device = d2l.try_gpu()env_name="CartPole-v1"
env = gym.make(env_name)
print(f"_max_episode_steps:{env._max_episode_steps}")
torch.manual_seed(0)
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.nagent = REINFORCE(state_dim, hidden_dim, action_dim, lr, gamma, device)
return_list = []
for i in range(10):with tqdm(total=int(num_episodes/10), desc=f'Iteration {i}') as pbar:for i_episode in range(int(num_episodes/10)):episode_return = 0transition_dict = {'states': [], 'actions': [], 'next_states': [], 'rewards': [], 'dones': []}state = env.reset()[0]done, truncated= False, Falsewhile not done and not truncated :  # 主要是这部分和原始的有点不同action = agent.take_action(state)next_state, reward, done, truncated, info = env.step(action)transition_dict['states'].append(state)transition_dict['actions'].append(action)transition_dict['next_states'].append(next_state)transition_dict['rewards'].append(reward)transition_dict['dones'].append(done)state = next_stateepisode_return += rewardreturn_list.append(episode_return)agent.update(transition_dict)if (i_episode+1) % 10 == 0:pbar.set_postfix({'episode': '%d' % (num_episodes / 10 * i + i_episode+1), 'return': '%.3f' % np.mean(return_list[-10:])})pbar.update(1)episodes_list = list(range(len(return_list)))
plt.plot(episodes_list, return_list)
plt.xlabel('Episodes')
plt.ylabel('Returns')
plt.title(f'REINFORCE on {env_name}')
plt.show()mv_return = rl_utils.moving_average(return_list, 9)
plt.plot(episodes_list, mv_return)
plt.xlabel('Episodes')
plt.ylabel('Returns')
plt.title(f'REINFORCE on {env_name}')
plt.show()

我是在jupyter里直接跑的,结果如下所示。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/212418.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Android 11 适配——整理总结篇

背景 > 经过检测,我们识别到您的应用,目前未适配安卓11(API30),请您关注适配截止时间,尽快开展适配工作,避免影响应用正常发布和经营。 > targetSdkVersion30 升级适配工作参考文档&am…

从零开发短视频电商 Jmeter压测示例模板详解(无认证场景)

文章目录 添加线程组添加定时器添加HTTP请求默认值添加HTTP头管理添加HTTP请求添加结果断言响应断言 Response AssertionJSON断言 JSON Assertion持续时间断言 Duration Assertion 添加察看结果树添加聚合报告添加表格察看结果参考 以压测百度搜索为例 https://www.baidu.com/s…

class066 一维动态规划【算法】

class066 一维动态规划 算法讲解066【必备】从递归入手一维动态规划 code1 509斐波那契数列 // 斐波那契数 // 斐波那契数 (通常用 F(n) 表示)形成的序列称为 斐波那契数列 // 该数列由 0 和 1 开始,后面的每一项数字都是前面两项数字的和。…

kotlin - ViewBinding

前言 为什么用ViewBinding,而不用findViewById(),这个有很多优秀的博主都做了讲解,就不再列出了。 可参考下列博主的文章: kotlin ViewBinding的使用 文章里也给出了如何在gradle中做出相应的配置。 (我建议先看这位博…

【LeetCode热题100】【滑动窗口】无重复字符的最长子串

给定一个字符串 s ,请你找出其中不含有重复字符的 最长子串 的长度。 示例 1: 输入: s "abcabcbb" 输出: 3 解释: 因为无重复字符的最长子串是 "abc",所以其长度为 3。示例 2: 输入: s "bbbbb" 输出: 1 解释: 因为无…

Redis,什么是缓存穿透?怎么解决?

Redis,什么是缓存穿透?怎么解决? 1、缓存穿透 一般的缓存系统,都是按照key去缓存查询,如果不存在对用的value,就应该去后端系统查找(比如DB数据库)。一些恶意的请求会故意查询不存在…

不想写大量 if 判断?试试用规则执行器优化,就很丝滑!

近日在公司领到一个小需求,需要对之前已有的试用用户申请规则进行拓展。我们的场景大概如下所示: if (是否海外用户) {return false; }if (刷单用户) {return false; }if (未付费用户 && 不再服务时段) {return false }if (转介绍用户 || 付费用户 || 内推…

16ASM 分段和机器码

8086CPU存储分段管理 问题1:8086是16位cpu,最多可访问(寻址)多大内存? 运算器一次最多处理16位的数据。地址寄存器的最大宽度为16位。访问的最大内存为:216 64K 即 0000 - FFFF。 问题2:808…

WIFI直连(Wi-Fi P2P)

一、概述 Wifi peer-to-peer(也称Wifi-Direct)是Wifi联盟推出的一项基于原来WIfi技术的可以让设备与设备间直接连接的技术,使用户不需要借助局域网或者AP(Access Point)就可以进行一对一或一对多通信。这种技术的应用…

计算机毕业设计 SpringBoot的乐乐农产品销售系统 Javaweb项目 Java实战项目 前后端分离 文档报告 代码讲解 安装调试

🍊作者:计算机编程-吉哥 🍊简介:专业从事JavaWeb程序开发,微信小程序开发,定制化项目、 源码、代码讲解、文档撰写、ppt制作。做自己喜欢的事,生活就是快乐的。 🍊心愿:点…

Xmanager

什么是 XManager Xmanager 是市场上领先的 PC X 服务器,可将X应用程序的强大功能带入 Windows 环境。 提供了强大的会话管理控制台,易于使用的 X 应用程序启动器,X 服务器配置文件管理工具,SSH 模块和高性能 PC X 服务器。 Xman…

javaScript(六):DOM操作

文章目录 1、DOM介绍2、DOM:获取Element对象3、DOM:事件监听3.1、事件介绍3.2、常见事件3.3、设置事件的两种方式3.4、事件案例 1、DOM介绍 概念 Document Object Model ,文档对象模型 将标记语言的各个组成部分封装为对应的对象&#xff1a…

Realme X7 Pro Root 刷机教程

Realme X7 Pro 刷机教程 Just For Fun,最近倒腾了下Realme X7 Pro 刷root。此博客为个人记录刷机过程,如有机友跟随本教程操作,请谨慎操作!!! 以下教程真针对Realme X7 Pro,其他版本方法未知&…

【Flutter】vs2022上开发flutter

在vs上开发flutter,结果扩展仓库上没办法找到Dart,Flutter。 在 这 搜索Dart时也无法找到插件。 最后发现是安装工具出错了 安装了 开发需要的是

从线性回归到神经网络

目录 一、线性回归关键思想 1、线性模型 2、基础优化算法 二、线性回归的从零开始实现 1、生成数据集 2、读取数据集 3、初始化模型参数 4、定义模型 5、定义损失函数 6、定义优化算法 7、训练 三、线性回归的简洁实现 1、生成数据集 2、读取数据集 3、定义模型…

论文代码阅读:TGN模型训练阶段代码理解

文章目录 [toc] TGN模型训练阶段代码理解论文信息代码过程手绘代码训练过程compute_temporal_embeddingsupdate_memoryget_raw_messagesget_updated_memoryself.message_aggregator.aggregateself.memory_updater.get_updated_memoryMemoryget_embedding_moduleGraphAttentionE…

【AIGC】Midjourney高级进阶版

Midjourney 真是越玩越上头,真是给它的想象力跪了~ 研究了官方API,出一个进阶版教程 命令 旨在介绍Midjourney在Discord频道中的文本框中支持的指令。 1)shorten 简化Prompt 该指令可以将输入的Prompt为模型可以理解的语言。模型理解语言…

【Linux】如何对文本文件进行有条件地划分?——cut命令

cut 命令可以根据一个指定的标记(默认是 tab)来为文本划分列,然后将此列显示。 例如想要显示 passwd 文件的第一列可以使用以下命令:cut –f 1 –d : /etc/passwd cut:用于从文件的每一行中提取部分内容的命令。-f 1&…

Sql server数据库数据查询

请查询学生信息表的所有记录。 答:查询所需的代码如下: USE 学生管理数据库 GO SELECT * FROM 学生信息表 执行结果如下: 查询学生的学号、姓名和性别。 答:查询所需的代码如下: USE 学生管理数据库 GO SELE…

为什么需要 Kubernetes,它能做什么?

传统部署时代: 早期,各个组织是在物理服务器上运行应用程序。 由于无法限制在物理服务器中运行的应用程序资源使用,因此会导致资源分配问题。 例如,如果在同一台物理服务器上运行多个应用程序, 则可能会出现一个应用程…