pytorch强化学习(1)——DQNSARSA

实验环境

python=3.10
torch=2.1.1
gym=0.26.2
gym[classic_control]
matplotlib=3.8.0
numpy=1.26.2

DQN代码

首先是module.py代码,在这里定义了网络模型和DQN模型

import torch
import torch.nn as nn
import numpy as npclass Net(nn.Module):# 构造只有一个隐含层的网络def __init__(self, n_states, n_hidden, n_actions):super(Net, self).__init__()# [b,n_states]-->[b,n_hidden]self.network = nn.Sequential(torch.nn.Linear(n_states, n_hidden),torch.nn.ReLU(),torch.nn.Linear(n_hidden, n_actions))# 前传def forward(self, x):  # [b,n_states]return self.network(x)class DQN:def __init__(self, n_states, n_hidden, n_actions, lr, gamma, epsilon):# 属性分配self.n_states = n_states  # 状态的特征数self.n_hidden = n_hidden  # 隐含层个数self.n_actions = n_actions  # 动作数self.lr = lr  # 训练时的学习率self.gamma = gamma  # 折扣因子,对下一状态的回报的缩放self.epsilon = epsilon  # 贪婪策略,有1-epsilon的概率探索# 计数器,记录迭代次数self.count = 0# 实例化训练网络self.q_net = Net(self.n_states, self.n_hidden, self.n_actions)# 优化器,更新训练网络的参数self.optimizer = torch.optim.Adam(self.q_net.parameters(), lr=lr)self.criterion = torch.nn.MSELoss()  # 损失函数def choose_action(self, gym_state):state = torch.Tensor(gym_state)if np.random.random() < self.epsilon:action_values = self.q_net(state)  # q_net(state)采取动作后的预测action = action_values.argmax().item()else:# 随机选择一个动作action = np.random.randint(self.n_actions)return actiondef update(self, gym_state, action, reward, next_gym_state, done):state, next_state = torch.tensor(gym_state), torch.tensor(next_gym_state)q_value = self.q_net(state)[action]# 前千万不能缺少done,如果下一步游戏结束的花,那下一步的q值应该为0q_target = reward + self.gamma * self.q_net(next_state).max() * (1 - float(done))self.optimizer.zero_grad()dqn_loss = self.criterion(q_value, q_target)dqn_loss.backward()self.optimizer.step()

然后是train.py代码,在这里调用DQN模型和gym环境,来进行训练:

import gym
import torch
from module import DQN
import matplotlib.pyplot as pltlr = 1e-3  # 学习率
gamma = 0.95  # 折扣因子
epsilon = 0.8  # 贪心系数
n_hidden = 200  # 隐含层神经元个数env = gym.make("CartPole-v1")
n_states = env.observation_space.shape[0]  # 4
n_actions = env.action_space.n  # 2 动作的个数dqn = DQN(n_states, n_hidden, n_actions, lr, gamma, epsilon)if __name__ == '__main__':reward_list = []for i in range(500):state = env.reset()[0]  # len=4total_reward = 0done = Falsewhile True:# 获取当前状态下需要采取的动作action = dqn.choose_action(state)# 更新环境next_state, reward, done, _, _ = env.step(action)dqn.update(state, action, reward, next_state, done)state = next_statetotal_reward += rewardif done:breakprint("第%d回合,total_reward=%f" % (i, total_reward))reward_list.append(total_reward)# 绘图episodes_list = list(range(len(reward_list)))plt.plot(episodes_list, reward_list)plt.xlabel('Episodes')plt.ylabel('Returns')plt.title('DQN Returns')plt.show()

SARSA代码

首先是module.py代码,在这里定义了网络模型和SARSA模型。
SARSA和DQN基本相同,只有在更新Q网络的时候略有不同,已在代码相应位置做出注释。

import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as Fclass Net(nn.Module):# 构造只有一个隐含层的网络def __init__(self, n_states, n_hidden, n_actions):super(Net, self).__init__()# [b,n_states]-->[b,n_hidden]self.network = nn.Sequential(torch.nn.Linear(n_states, n_hidden),torch.nn.ReLU(),torch.nn.Linear(n_hidden, n_actions))# 前传def forward(self, x):  # [b,n_states]return self.network(x)class SARSA:def __init__(self, n_states, n_hidden, n_actions, lr, gamma, epsilon):# 属性分配self.n_states = n_states  # 状态的特征数self.n_hidden = n_hidden  # 隐含层个数self.n_actions = n_actions  # 动作数self.lr = lr  # 训练时的学习率self.gamma = gamma  # 折扣因子,对下一状态的回报的缩放self.epsilon = epsilon  # 贪婪策略,有1-epsilon的概率探索# 计数器,记录迭代次数self.count = 0# 实例化训练网络self.q_net = Net(self.n_states, self.n_hidden, self.n_actions)# 优化器,更新训练网络的参数self.optimizer = torch.optim.Adam(self.q_net.parameters(), lr=lr)self.criterion = torch.nn.MSELoss()  # 损失函数def choose_action(self, gym_state):state = torch.Tensor(gym_state)# 基于贪婪系数,有一定概率采取随机策略if np.random.random() < self.epsilon:action_values = self.q_net(state)  # q_net(state)是在当前状态采取各个动作后的预测action = action_values.argmax().item()else:# 随机选择一个动作action = np.random.randint(self.n_actions)return actiondef update(self, gym_state, action, reward, next_gym_state, done):state, next_state = torch.tensor(gym_state), torch.tensor(next_gym_state)q_value = self.q_net(state)[action]'''sarsa在更新网络时选择的是q_net(next_state)[next_action] 这是sarsa算法和dqn的唯一不同dqn是选择max(q_net(next))'''next_action = self.choose_action(next_state)# 千万不能缺少done,如果下一步游戏结束的话,那下一步的q值应该为0,而不是q网络输出的值q_target = reward + self.gamma * self.q_net(next_state)[next_action] * (1 - float(done))self.optimizer.zero_grad()dqn_loss = self.criterion(q_value, q_target)dqn_loss.backward()self.optimizer.step()

SARSA也有tarin.py文件,功能和上面DQN的一样,内容也几乎完全一样,只是把DQN的名字改成SARSA而已,所以在这里不再赘述。

运行结果

DQN的运行结果如下:
在这里插入图片描述

SARSA运行结果如下:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/227963.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

LLM大语言模型(二):Streamlit 无需前端经验也能画web页面

目录 问题 Streamlit是什么&#xff1f; 怎样用Streamlit画一个LLM的web页面呢&#xff1f; 文本输出 页面布局 滑动条 按钮 对话框 输入框 总结 问题 假如你是一位后端开发&#xff0c;没有任何的web开发经验&#xff0c;那如何去实现一个LLM的对话交互页面呢&…

Python MySQL数据库连接与基本使用

一、应用场景 python项目连接MySQL数据库时&#xff0c;需要第三方库的支持。这篇文章使用的是PyMySQL库&#xff0c;适用于python3.x。 二、安装 pip install PyMySQL三、使用方法 导入模块 import pymysql连接数据库 db pymysql.connect(hostlocalhost,usercode_space…

Spring MVC开发流程

1.Spring MVC环境基本配置 Maven工程依赖spring-webmvc <dependency><groupId>org.springframework</groupId><artifactId>spring-webmvc</artifactId><version>5.1.9.RELEASE</version> </dependency>web.xml配置Dispatche…

NSSCTF第16页(2)

[NSSRound#4 SWPU]1zweb(revenge) 查看index.php <?php class LoveNss{public $ljt;public $dky;public $cmd;public function __construct(){$this->ljt"ljt";$this->dky"dky";phpinfo();}public function __destruct(){if($this->ljt"…

day01unittest复习,断言

1.unittest 方法执行前 # def setUp(self) -> None: # print(方法执行前执行) # # def tearDown(self) -> None: # print(方法执行后执行一次) 2.unittest 类方法执行前后执行一次 classmethod def setUpClass(cls) -> None:print(类执行前执行一次)classm…

微分和导数(一)

1.微分&#xff1a; 假设我们有⼀个函数f : R → R&#xff0c;其输⼊和输出都是标量。如果f的导数存在&#xff0c;这个极限被定义为 如果f′(a)存在&#xff0c;则称f在a处是可微的。如果f在⼀个区间内的每个数上都是可微的&#xff0c;则此函数在此区间中是可微的。导数f′…

网络协议 - UDP 协议详解

网络协议 - UDP 协议详解 UDP概述UDP特点UDP的首部格式UDP校验 參考文章 基于TCP和UDP的协议非常广泛&#xff0c;所以也有必要对UDP协议进行详解。 UDP概述 UDP(User Datagram Protocol)即用户数据报协议&#xff0c;在网络中它与TCP协议一样用于处理数据包&#xff0c;是一种…

必要时进行保护性拷贝

保护性拷贝&#xff08;Defensive Copy&#xff09;是一种常见的编程实践&#xff0c;用于在传递参数或返回值时&#xff0c;创建副本以防止原始对象被意外修改。以下是一个例子&#xff0c;展示了何时进行保护性拷贝&#xff1a; mport java.util.ArrayList; import java.uti…

数据手册Datasheet解读-肖特基二极管笔记

数据手册Datasheet解读笔记1-肖特基二极管 数据手册大体结构共包含10个部分肖特基二极管-SS14第一重点关注点&#xff1a;极限值第二重点关注点&#xff1a;电气特性 数据手册大体结构共包含10个部分 1.Features一特性 2.Application一应用 3.Description一说明4.Pin Configur…

关于在Java中打印“数字”三角形图形的汇总

之前写过一篇利用*打印三角形汇总&#xff0c;网友需要查看可以去本专栏查找之前的文章&#xff0c;这里利用二维数组嵌套循环打印“数字”三角形&#xff0c;汇总如下&#xff0c;话不多说&#xff0c;直接上代码&#xff1a; /*** 打印如下数字三角形图形*/ public class Wo…

逻辑分析仪_使用手册

LA1010 1> 能干啥&#xff1f;2> 硬件连接3> 软件安装4> 参数设置4.1> 采样深度和采样率4.2> 添加协议解析器4.3> 毛刺过滤设置 1> 能干啥&#xff1f; 测量通信波形&#xff0c;并自动解析&#xff1b; 比如测量&#xff0c;UART&#xff0c;SPI&…

【DataSophon】大数据管理平台DataSophon-1.2.1安装部署详细流程

&#x1f984; 个人主页——&#x1f390;开着拖拉机回家_Linux,大数据运维-CSDN博客 &#x1f390;✨&#x1f341; &#x1fa81;&#x1f341;&#x1fa81;&#x1f341;&#x1fa81;&#x1f341;&#x1fa81;&#x1f341; &#x1fa81;&#x1f341;&#x1fa81;&am…

java_web_电商项目

java_web_电商项目 1.登录界面2.注册界面3. 主界面4.分页界面5.商品详情界面6. 购物车界面7.确认订单界面8.个人中心界面9.收货地址界面10.用户信息界面11.用户余额充值界面12.后台首页13.后台商品增加14.后台用户增加15.用户管理16.源码分享1.登录页面的源码2.我们的主界面 1.…

在线二进制原码,补码,反码计算器

具体请前往&#xff1a;在线原码/反码/补码计算器

LLM中的Prompt提示

简介 在LLM中&#xff0c;prompt&#xff08;提示&#xff09;是一个预先设定的条件&#xff0c;它可以限制模型自由发散&#xff0c;而是围绕提示内容进行展开。输入中添加prompt&#xff0c;可以强制模型关注特定的信息&#xff0c;从而提高模型在特定任务上的表现。 结构 …

会声会影怎么使用? 会声会影2024快速掌握入门技巧

一听说视频剪辑我们就不由得联想到电影、电视等一些高端的视频剪辑技术&#xff0c;大家都觉得视频剪辑是一个非常复杂而且需要很昂贵的设备才可以完成的技术活&#xff0c;这对很多“门外汉”来说都可望而不可及。实际上&#xff0c;使用会声会影剪辑视频不仅是很多人都可以操…

【深度强化学习】策略梯度方法:REINFORCE、Actor-Critic

参考 Reinforcement Learning, Second Edition An Introduction By Richard S. Sutton and Andrew G. Barto非策略梯度方法的问题 之前的算法&#xff0c;无论是 MC&#xff0c;TD&#xff0c;SARSA&#xff0c;Q-learning&#xff0c; 还是 DQN、Double DQN、Dueling DQN…

STM32G030C8T6:使用按键控制LED亮灭(外部中断)

本专栏记录STM32开发各个功能的详细过程&#xff0c;方便自己后续查看&#xff0c;当然也供正在入门STM32单片机的兄弟们参考&#xff1b; 本小节的目标是&#xff0c;系统主频64 MHZ,采用高速外部晶振&#xff0c;通过KEY1 按键的PA0 引脚配置成中断输入引脚&#xff0c;PB9引…

写好ChatGPT提示词原则之:清晰且具体(clear specific)

ChatGPT 的优势在于它允许用户跨越机器学习和深度学习的复杂门槛&#xff0c;直接利用已经训练好的模型。然而&#xff0c;即便是这些先进的大型语言模型也面临着上下文理解和模型固有局限性的挑战。为了最大化这些大型语言模型&#xff08;LLM&#xff09;的潜力&#xff0c;关…

Spring 6(二)【IOC原理】

前言 1、IOC IoC 是 Inversion of Control 的简写&#xff0c;译为“控制反转”&#xff0c;它不是一门技术&#xff0c;而是一种设计思想&#xff0c;是一个重要的面向对象编程法则&#xff0c;能够指导我们如何设计出松耦合、更优良的程序。 1.1、控制反转 控制反转不是技术…