Noisy DQN 跑 CartPole-v1

gym 0.26.1
CartPole-v1
NoisyNet DQN

NoisyNet 就是把原来Linear里的w/b 换成 mu + sigma * epsilon, 这是一种非常简单的方法,但是可以显著提升DQN的表现。
和之前最原始的DQN相比就是改了两个地方,一个是Linear改成了NoisyLinear,另外一个是在agenttake_action的时候策略 由ε-greedy改成了直接取argmax。详细见下面的代码。

本文的实现参考王树森的深度强化学习。

引用书上的一段话, 噪声DQN本身就带有随机性,可以鼓励探索,起到与ε-greedy策略相同的作用,直接用a_t = argmax Q(s,a,epsilon; mu,sigma), 作为行为策略,效果比ε-greedy更好。

import gym
import torch
from torch import nn
from torch.nn import functional as F
import numpy as np
import random
import collections
from tqdm import tqdm
import matplotlib.pyplot as plt
from d2l import torch as d2l
import rl_utils
import mathclass ReplayBuffer:"""经验回放池"""def __init__(self, capacity):self.buffer = collections.deque(maxlen=capacity) # 队列,先进先出def add(self, state, action, reward, next_state, done): # 将数据加入bufferself.buffer.append((state, action, reward, next_state, done))def sample(self, batch_size): # 从buffer中采样数据,数量为batch_sizetransition = random.sample(self.buffer, batch_size)state, action, reward, next_state, done = zip(*transition)return np.array(state), action, reward, np.array(next_state), donedef size(self): # 目前buffer中数据的数量return len(self.buffer)class NoisyLinear(nn.Linear):def __init__(self, in_features, out_features, sigma_init=0.017, bias=True):super().__init__(in_features, out_features, bias)self.sigma_weight = nn.Parameter(torch.full((out_features, in_features), sigma_init))self.register_buffer("epsilon_weight", torch.zeros(out_features, in_features))if bias:self.sigma_bias = nn.Parameter(torch.full((out_features,), sigma_init))self.register_buffer("epsilon_bias", torch.zeros(out_features))self.reset_parameters()def reset_parameters(self):std = math.sqrt(3 / self.in_features)self.weight.data.uniform_(-std, std)self.bias.data.uniform_(-std, std)def forward(self, x, is_training=True):self.epsilon_weight.normal_()bias = self.biasif bias is not None:self.epsilon_bias.normal_()bias = bias + self.sigma_bias * self.epsilon_bias.dataif is_training:return F.linear(x, self.weight + self.sigma_weight * self.epsilon_weight.data, bias)else:return F.linear(x, self.weight, bias)class Q(nn.Module):def __init__(self, state_dim, hidden_dim, action_dim):super().__init__()self.fc1 = NoisyLinear(state_dim, hidden_dim)self.fc2 = NoisyLinear(hidden_dim, action_dim)def forward(self, x, is_training=True):x = F.relu(self.fc1(x, is_training)) # 隐藏层之后使用ReLU激活函数return self.fc2(x, is_training)class DQN:"""DQN算法"""def __init__(self, state_dim, hidden_dim, action_dim, lr, gamma, target_update, device):self.action_dim = action_dimself.q = Q(state_dim, hidden_dim, action_dim).to(device) # Q网络self.target_q = Q(state_dim, hidden_dim, action_dim).to(device) # 目标网络self.target_q.load_state_dict(self.q.state_dict())  # 加载参数self.optimizer = torch.optim.Adam(self.q.parameters(), lr=lr)self.gamma = gammaself.target_update = target_update # 目标网络更新频率self.count = 0 # 计数器,记录更新次数self.device = devicedef take_action(self, state): # 这个地方就不用epsilon-贪婪策略state = torch.tensor(np.array([state]), dtype=torch.float).to(self.device)action = self.q(state).argmax().item()return actiondef update(self, transition_dict):states = torch.tensor(transition_dict['states'], dtype=torch.float).to(self.device)actions = torch.tensor(transition_dict['actions']).reshape(-1,1).to(self.device)rewards = torch.tensor(transition_dict['rewards'], dtype=torch.float).reshape(-1,1).to(self.device)next_states = torch.tensor(transition_dict['next_states'], dtype=torch.float).to(self.device)dones = torch.tensor(transition_dict['dones'], dtype=torch.float).reshape(-1,1).to(self.device)q_values = self.q(states).gather(1, actions) # Q值# 下个状态的最大Q值max_next_q_values = self.target_q(next_states).max(1)[0].reshape(-1,1)q_targets = rewards + self.gamma * max_next_q_values * (1- dones) # TD误差loss = F.mse_loss(q_values, q_targets) # 均方误差self.optimizer.zero_grad() # 梯度清零,因为默认会梯度累加loss.mean().backward() # 反向传播self.optimizer.step() # 更新梯度if self.count % self.target_update == 0:self.target_q.load_state_dict(self.q.state_dict())self.count += 1
lr = 2e-3
num_episodes = 500
hidden_dim = 128
gamma = 0.98
target_update = 10
buffer_size = 10000
minimal_size = 500
batch_size = 64
device = d2l.try_gpu()
print(device)env_name = "CartPole-v1"
env = gym.make(env_name)
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)
replay_buffer = ReplayBuffer(buffer_size)
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n
agent = DQN(state_dim, hidden_dim, action_dim, lr, gamma, target_update, device)
return_list = []for i in range(10):with tqdm(total=int(num_episodes/10), desc=f'Iteration {i}') as pbar:for i_episode in range(int(num_episodes/10)):episode_return = 0state = env.reset()[0]done, truncated= False, Falsewhile not done and not truncated :action = agent.take_action(state)next_state, reward, done, truncated, info = env.step(action)replay_buffer.add(state, action, reward, next_state, done)state = next_stateepisode_return += reward# 当buffer数据的数量超过一定值后,才进行Q网络训练if replay_buffer.size() > minimal_size:b_s, b_a, b_r, b_ns, b_d = replay_buffer.sample(batch_size)transition_dict = {'states': b_s, 'actions': b_a, 'next_states': b_ns, 'rewards': b_r, 'dones': b_d}agent.update(transition_dict)return_list.append(episode_return)if (i_episode+1) % 10 == 0:pbar.set_postfix({'episode': '%d' % (num_episodes / 10 * i + i_episode+1), 'return': '%.3f' % np.mean(return_list[-10:])})pbar.update(1)episodes_list = list(range(len(return_list)))
plt.plot(episodes_list, return_list)
plt.xlabel('Episodes')
plt.ylabel('Returns')
plt.title(f'Noisy DQN on {env_name}')
plt.show()mv_return = rl_utils.moving_average(return_list, 9)
plt.plot(episodes_list, mv_return)
plt.xlabel('Episodes')
plt.ylabel('Returns')
plt.title(f'Noisy DQN on {env_name}')
plt.show()

这次是在pycharm上运行jupyter file,结果如下:




效果对比之前的DQN 详细参考这篇 表现是显著提升。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/609025.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

基于 SpringBoot + vue 的医院管理系统(含源码,数据库,文档)

基于 SpringBoot vue 的医院管理系统 †前后端分离思想,这个系统简直太棒了!屯 光这个系统采用了 前后端分离思想,后端使用 SpringBoot和 SpringMVC框架,让代码更高效,更易于维护。前端则使用了 vue js 和ElementU…

mybatis plus相同Id与xml配置错误时,mybatis plus解决逻辑

前言 处理做项目的问题,其中不乏奇奇怪怪的问题,其中mybatis plus的问题感觉有点隐蔽,有些是运行时出现,有些是运行到具体的逻辑触发,对于应用的状态监控提出了极大的挑战,应用的状态由健康检查接口提供&a…

Vue_00001_CLI

初始化脚手架 初始化脚手架步骤: 第一步(仅第一次执行):全局安装vue/cli。 命令:npm install -g vue/cli 第二步:切换到要创建项目的目录,然后使用命令创建项目。 命令:vue creat…

python在一个文件夹中按文件名排序,获取这一文件的下一个文件名

在Python中,您可以使用sorted()函数对文件夹内的文件进行排序,并使用os模块来遍历这些文件。获取特定文件的下一个文件名可以通过排序后的列表来查找。下面是一个简单的示例代码,展示了如何对一个文件夹中的文件按文件名进行排序,…

【基础工具篇使用】Windows环境下瑞芯微开发工具的安装和使用

文章目录 Rockchip 烧录驱动的安装Rockchip 烧录工具使用导入配置MASKROM 模式烧录LOADER 模式烧录Update.img 包的烧录 Rockchip 烧录驱动的安装 瑞芯微提供了 RKDevTool 上位机烧录工具,此工具只能在 Windows 系统下运行,运行前要先安装驱动文件 Ro…

mitmproxy代理抓包使用mock数据

第一步 安装Python环境 下载Python环境安装包https://www.python.org/getit/https://link.jianshu.com/?thttps%3A%2F%2Fwww.python.org%2Fgetit%2F (图a) 安装Python的时候勾选“Add Python 3.5 to PATH”选项(图a) 打开CMD命…

Pytest接口自动化测试框架搭建

一. 背景 Pytest目前已经成为Python系自动化测试必学必备的一个框架,网上也有很多的文章讲述相关的知识。最近自己也抽时间梳理了一份pytest接口自动化测试框架,因此准备写文章记录一下,做到尽量简单通俗易懂,当然前提是基本的py…

书生·浦语大模型实战营第二次课堂笔记

文章目录 什么是大模型?pip,conda换源模型下载 什么是大模型? 人工智能领域中参数数量巨大、拥有庞大计算能力和参数规模的模型 特点及应用: 利用大量数据进行训练拥有数十亿甚至数千亿个参数模型在各种任务重展现出惊人的性能 …

数据结构入门到入土——链表(完)LinkedList

目录 一,双向链表 1.单向链表的缺点 2.什么是双向链表? 3.自主实现双向链表 接口实现: 二,LinkedList 1.LinkedList的使用 1.1 什么是LinkedList? 1.2 LinkedList的使用 1.LinkedList的构造 2.LinkedList的…

elementUI点击el-card选中变色,且点击别的空白处不变色

1. <script>的data中添加属性&#xff1a; selectIndex: 0 2.<template>中添加el-card元素&#xff1a; click.native调用原生click方法。 click.native是在vue中&#xff0c;避免vue父模块调用成了vue成子模块中的this.emit(click, value)的方法&#xff0c;而…

Pruning Papers

[ICML 2020] Rigging the Lottery: Making All Tickets Winners 整个训练过程中mask是动态的&#xff0c;有drop和grow两步&#xff0c;drop是根据权重绝对值的大小丢弃&#xff0c;grow是根据剩下激活的权重中梯度绝对值生长没有先prune再finetune/retrain的两阶段过程 Laye…

C++-nullptr-类型推导

1、nullptr&#xff08;掌握&#xff09;&#xff08;NULL 就是0&#xff09; NULL 在源码当中就是0&#xff0c;因此可能会存在一些二义性的问题。 #include <iostream> #include <memory> using namespace std;void func(int a) {cout << "a " …

[原创][R语言]股票分析实战[9]:周内第N天转换为星期N因子

[简介] 常用网名: 猪头三 出生日期: 1981.XX.XX QQ联系: 643439947 个人网站: 80x86汇编小站 https://www.x86asm.org 编程生涯: 2001年~至今[共22年] 职业生涯: 20年 开发语言: C/C、80x86ASM、PHP、Perl、Objective-C、Object Pascal、C#、Python 开发工具: Visual Studio、D…

工业异常检测AnomalyGPT-Demo试跑

写在前面&#xff1a;如果你有大的cpu和gpu可以使用&#xff0c;直接根据官方的安装说明就可以&#xff0c;如果没有&#xff0c;可以点进来试着看一下我个人的安装经验。 一、试跑环境 NVIDIA4090显卡24g,cpu内存33G&#xff0c;交换空间8g,操作系统ubuntu22.04(试跑过程cpu…

问题 F: 分巧克力

题目描述 儿童节那天有 K 位小朋友到小明家做客。小明拿出了珍藏的巧克力招待小朋友们。小明一共有 N 块巧克力&#xff0c;其中第i 块HiWi 的方格组成的长方形。 为了公平起见&#xff0c;小明需要从这 N 块巧克力中切出 K 块巧克力分给小朋友们。 切出的巧克力需要满足&am…

SEO写作:撰写在Google上排名的博客文章的13个技巧

随着排名的提高&#xff0c;您的网站可以提高其整体知名度。最终目标是通过有效的优化来推动自然流量&#xff0c;增加转化率&#xff0c;并实现业务目标。 如果你不针对搜索引擎优化你的内容&#xff0c;你的网站可能会在搜索引擎结果页面&#xff08;SERP&#xff09;上出现…

第7章-第9节-Java中的Stream流(链式调用)

1、什么是Stream流 Lambda表达式&#xff0c;基于Lambda所带来的函数式编程&#xff0c;又引入了一个全新的Stream概念&#xff0c;用于解决集合类库既有的鼻端。 2、案例 假设现在有一个需求&#xff0c; 将list集合中姓张的元素过滤到一个新的集合中&#xff1b;然后将过滤…

Realm Management Extension领域管理扩展简介

本博客介绍了领域管理扩展(RME),这是Arm的架构扩展。RME是Arm机密计算架构(Arm CCA)的硬件组件,同时包括软件元素。RME动态地将资源和内存转移到新的受保护的地址空间,高特权软件或TrustZone固件无法访问。由于存在这个地址空间,Arm CCA构建了受保护的执行环境,称为领…

详解Oracle数据库的启动

Oracle数据库的启动&#xff0c;其概念可参考Overview of Instance and Database Startup。 其过程可参见下图&#xff1a; 当数据库从关闭状态进入打开数据库状态时&#xff0c;它会经历以下阶段。 阶段Mount状态描述1实例在没有挂载数据库的情况下启动实例已启动&#xff…

Numerical calculation and its application based on NumPy/SciPy

Numerical calculation and its application based on NumPy/SciPy 线性代表微分和积分统计插值 线性代表 { 3 x 1 2 x 2 8 − 3 x 1 5 x 2 − 1 \begin{cases} \begin{equation} \begin{split} 3x_12x_2 8\\ -3x_15x_2 -1 \end{split} \end{equation} \end{cases} {3x1​…