【DQN】基于pytorch的强化学习算法Demo

目录

  • 简介
  • 代码

简介

DQN(Deep Q-Network)是一种基于深度神经网络的强化学习算法,于2013年由DeepMind提出。它的目标是解决具有离散动作空间的强化学习问题,并在多个任务中取得了令人瞩目的表现。

DQN的核心思想是使用深度神经网络来逼近状态-动作值函数(Q函数),将当前状态作为输入,输出每个可能动作的Q值估计。通过不断迭代和更新网络参数,DQN能够逐步学习到最优的Q函数,并根据Q值选择具有最大潜在回报的动作。

DQN的训练过程中采用了两个关键技术:经验回放和目标网络。经验回放是一种存储并重复使用智能体经历的经验的方法,它可以破坏数据之间的相关性,提高训练的稳定性。目标网络用于解决训练过程中的估计器冲突问题,通过固定一个与训练网络参数较为独立的目标网络来提供稳定的目标Q值,从而减少训练的不稳定性。

DQN还采用了一种策略称为epsilon-贪心策略来在探索和利用之间进行权衡。初始时,智能体以较高的概率选择随机动作(探索),随着训练的进行,该概率逐渐降低,让智能体更多地依靠Q值选择最佳动作(利用)。

DQN在许多复杂任务中取得了显著的成果,特别是在Atari游戏等需要视觉输入的任务中。它的成功在很大程度上得益于深度神经网络的强大拟合能力和经验回放的效果,使得智能体能够通过与环境的交互进行自主学习。

代码

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import gym# Hyper Parameters
BATCH_SIZE = 32
LR = 0.01                   # learning rate
EPSILON = 0.9               # greedy policy
GAMMA = 0.9                 # reward discount
TARGET_REPLACE_ITER = 100   # target update frequency
MEMORY_CAPACITY = 2000
env = gym.make('CartPole-v1',render_mode="human")
#env = gym.make('CartPole-v0')
env = env.unwrapped
N_ACTIONS = env.action_space.n
N_STATES = env.observation_space.shape[0]
ENV_A_SHAPE = 0 if isinstance(env.action_space.sample(), int) else env.action_space.sample().shape     # to confirm the shapeclass Net(nn.Module):def __init__(self, ):super(Net, self).__init__()self.fc1 = nn.Linear(N_STATES, 50)self.fc1.weight.data.normal_(0, 0.1)   # initializationself.out = nn.Linear(50, N_ACTIONS)self.out.weight.data.normal_(0, 0.1)   # initializationdef forward(self, x):x = self.fc1(x)x = F.relu(x)actions_value = self.out(x)return actions_valueclass DQN(object):def __init__(self):self.eval_net, self.target_net = Net(), Net()self.learn_step_counter = 0                                     # for target updatingself.memory_counter = 0                                         # for storing memoryself.memory = np.zeros((MEMORY_CAPACITY, N_STATES * 2 + 2))     # initialize memoryself.optimizer = torch.optim.Adam(self.eval_net.parameters(), lr=LR)self.loss_func = nn.MSELoss()def choose_action(self, x):x = torch.unsqueeze(torch.FloatTensor(x), 0)# input only one sampleif np.random.uniform() < EPSILON:   # greedyactions_value = self.eval_net.forward(x)action = torch.max(actions_value, 1)[1].data.numpy()action = action[0] if ENV_A_SHAPE == 0 else action.reshape(ENV_A_SHAPE)  # return the argmax indexelse:   # randomaction = np.random.randint(0, N_ACTIONS)action = action if ENV_A_SHAPE == 0 else action.reshape(ENV_A_SHAPE)return actiondef store_transition(self, s, a, r, s_):transition = np.hstack((s, [a, r], s_))# replace the old memory with new memoryindex = self.memory_counter % MEMORY_CAPACITYself.memory[index, :] = transitionself.memory_counter += 1def learn(self):# target parameter updateif self.learn_step_counter % TARGET_REPLACE_ITER == 0:self.target_net.load_state_dict(self.eval_net.state_dict())self.learn_step_counter += 1# sample batch transitionssample_index = np.random.choice(MEMORY_CAPACITY, BATCH_SIZE)b_memory = self.memory[sample_index, :]b_s = torch.FloatTensor(b_memory[:, :N_STATES])b_a = torch.LongTensor(b_memory[:, N_STATES:N_STATES+1].astype(int))b_r = torch.FloatTensor(b_memory[:, N_STATES+1:N_STATES+2])b_s_ = torch.FloatTensor(b_memory[:, -N_STATES:])# q_eval w.r.t the action in experienceq_eval = self.eval_net(b_s).gather(1, b_a)  # shape (batch, 1)q_next = self.target_net(b_s_).detach()     # detach from graph, don't backpropagateq_target = b_r + GAMMA * q_next.max(1)[0].view(BATCH_SIZE, 1)   # shape (batch, 1)loss = self.loss_func(q_eval, q_target)self.optimizer.zero_grad()loss.backward()self.optimizer.step()dqn = DQN()  # 创建 DQN 对象print('\nCollecting experience...')
for i_episode in range(400):  # 进行 400 个回合的训练s, info = env.reset()  # 环境重置,获取初始状态 s 和其他信息ep_r = 0  # 初始化本回合的总奖励 ep_r 为 0while True:env.render()  # 显示环境,通过调用 render() 方法,可以将当前环境的状态以图形化的方式呈现出来.a = dqn.choose_action(s)  # 根据当前状态选择动作 a# 下一个状态(nextstate):返回智能体执行动作a后环境的下一个状态。在示例中,它存储在变量s_中。奖励(reward):返回智能体执行动作a后在环境中获得的奖励。在示例中,它存储在变中。# 完成标志(doneflag):返回一个布尔值,指示智能体是否已经完成了当前环境。在示例中,它存储在变量done中。# 截断标志(truncatedflag):返回一个布尔值,表示当前状态是否是由于达到了最大时间步骤或其他特定条件而被截断。在示例中,它存储在变量truncated中。# 其他信息(info):返回一个包含其他辅助信息的字典或对象。在示例中,它存储在变量info中。# 执行动作,获取下一个状态 s_,奖励 r,done 标志位,以及其他信息s_, r, done, truncated, info = env.step(a)# 修改奖励值#根据智能体在x方向和theta方向上与目标位置的偏离程度,计算两个奖励值r1和r2。具体计算方法是将每个偏离程度除以相应的阈值,然后减去一个常数(0.8和0.5)得到奖励值。这样,如果智能体在这两个方向上的偏离程度越小,奖励值越高。x, x_dot, theta, theta_dot = s_  # 从 s_ 中提取参数r1 = (env.x_threshold - abs(x)) / env.x_threshold - 0.8  # 根据 x 的偏离程度计算奖励 r1r2 = (env.theta_threshold_radians - abs(theta)) / env.theta_threshold_radians - 0.5  # 根据 theta 的偏离程度计算奖励 r2r = r1 + r2  # 组合两个奖励成为最终的奖励 rdqn.store_transition(s, a, r, s_)  # 存储状态转换信息到经验池ep_r += r  # 更新本回合的总奖励if dqn.memory_counter > MEMORY_CAPACITY:  # 当经验池中的样本数量超过阈值 MEMORY_CAPACITY 时进行学习dqn.learn()if done:  # 如果本回合结束print('Ep: ', i_episode,'| Ep_r: ', round(ep_r, 2))  # 打印本回合的回合数和总奖励if done:  # 如果任务结束break  # 跳出当前回合的循环s = s_  # 更新状态,准备进行下一步动作选择

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/165866.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

企业数字化转型的作用是什么?_光点科技

在当今快速变化的商业环境中&#xff0c;数字化转型已成为企业发展的重要策略。企业数字化转型指的是利用数字技术改造传统业务模式和管理方式&#xff0c;以提升效率、增强竞争力和创造新的增长机会。 提升运营效率&#xff1a;数字化转型通过引入自动化工具和智能系统&#x…

指数退避重试

指数退避重试&#xff08;Exponential Backoff and Retry&#xff09;是一种网络通信中常用的错误处理和重试策略。它通常用于处理临时性的故障&#xff0c;例如网络延迟、服务器过载或临时性的错误&#xff0c;以提高系统的可靠性和稳定性。 基本思想是&#xff0c;当发生一个…

NX二次开发UF_CSYS_ask_wcs 函数介绍

文章作者&#xff1a;里海 来源网站&#xff1a;https://blog.csdn.net/WangPaiFeiXingYuan UF_CSYS_ask_wcs Defined in: uf_csys.h int UF_CSYS_ask_wcs(tag_t * wcs_id ) overview 概述 Gets the object identifier of the coordinate system to which the work coordin…

JMeter压测常见面试问题

1、JMeter可以模拟哪些类型的负载&#xff1f; JMeter可以模拟各种类型的负载&#xff0c;包括但不限于Web应用程序、API、数据库、FTP、SMTP、JMS、SOAP / RESTful Web服务等。这使得JMeter成为一个功能强大且灵活的压力测试工具。 2、如何配置JMeter来进行分布式压力测试&a…

在华为昇腾开发板安装gdal-python

作者:朱金灿 来源:clever101的专栏 为什么大多数人学不会人工智能编程?>>> 在华为昇腾开发板安装gdal-python分为两步:编译gdal库和下载gdal对应的python包。 1.编译gdal库 首先下载gdal库,。在linux(arm架构)上编译的gdal库及其第三方库源码,内含一个编译…

智慧法院 | RPA+AI打造智慧执行助手,解决“案多人少”现实难题

为深化政法智能化建设&#xff0c;加强“智慧治理”“智慧法院”“智慧检务”“智慧警务”“智慧司法”等信息平台建设&#xff0c;深入实施大数据战略&#xff0c;实现科技创新成果同政法工作深度融合。法制日报社于今年3月继续举办了2023政法智能化建设创新案例及论文征集宣传…

Unity UGUI的HorizontalLayoutGroup(水平布局)组件

Horizontal Layout Group | Unity UI | 1.0.0 1. 什么是HorizontalLayoutGroup组件&#xff1f; HorizontalLayoutGroup是Unity UGUI中的一种布局组件&#xff0c;用于在水平方向上对子物体进行排列和布局。它可以根据一定的规则自动调整子物体的位置和大小&#xff0c;使它…

Shell脚本:Linux Shell脚本学习指南(第二部分Shell编程)二

第二部分&#xff1a;Shell编程&#xff08;二&#xff09; 十一、Shell数组&#xff1a;Shell数组定义以及获取数组元素 和其他编程语言一样&#xff0c;Shell 也支持数组。数组&#xff08;Array&#xff09;是若干数据的集合&#xff0c;其中的每一份数据都称为元素&#…

Navicat 技术指引 | GaussDB服务器对象的创建/设计(编辑)

Navicat Premium&#xff08;16.2.8 Windows版或以上&#xff09; 已支持对GaussDB 主备版的管理和开发功能。它不仅具备轻松、便捷的可视化数据查看和编辑功能&#xff0c;还提供强大的高阶功能&#xff08;如模型、结构同步、协同合作、数据迁移等&#xff09;&#xff0c;这…

【华为OD题库-034】字符串化繁为简-java

题目 给定一个输入字符串&#xff0c;字符串只可能由英文字母(a ~ z、A ~ Z)和左右小括号()组成。当字符里存在小括号时&#xff0c;小括号是成对的&#xff0c;可以有一个或多个小括号对&#xff0c;小括号对不会嵌套&#xff0c;小括号对内可以包含1个或多个英文字母也可以不…

Jenkins Ansible 参数构建

首先在Jenkins中创建自由项目 在web端配置完成后在另一台机子上下载nginx 在gitlab端创建项目并创建文件配置代码 在有Jenkins的机器上下载Ansible [rootslave1 ~]# yum -y install epel-release [rootslave1 ~]# yum -y install ansible再进入下载nginx机器中克隆gitlab项目…

Android 框架层AIDL 添加接口

文章目录 AIDL的原理构建AIDL的流程往冻结的AIDL中加接口 AIDL的原理 可以利用ALDL定义客户端与服务均认可的编程接口&#xff0c;以便二者使用进程间通信 (IPC) 进行相互通信。在 Android 中&#xff0c;一个进程通常无法访问另一个进程的内存。因此&#xff0c;为进行通信&a…

卷积神经网络(AlexNet)鸟类识别

文章目录 一、前言二、前期工作1. 设置GPU&#xff08;如果使用的是CPU可以忽略这步&#xff09;2. 导入数据3. 查看数据 二、数据预处理1. 加载数据2. 可视化数据3. 再次检查数据4. 配置数据集 三、AlexNet (8层&#xff09;介绍四、构建AlexNet (8层&#xff09;网络模型五、…

微信小程序image组件图片设置最大宽度 宽高自适应

问题描述&#xff1a;在使用微信小程序image组件的时候&#xff0c;在不确定图片宽高情况下 想给一个最大宽度让图片自适应&#xff0c;按比例&#xff0c;image的widthfiex和heightFiex并不能满足&#xff08;只指定最大宽/高并不会生效&#xff09; 问题解决&#xff1a;使用…

居家适老化设计第二十九条---卫生间之花洒

无电源 灯光显示 无障碍扶手型花洒 以上产品图片均来源于淘宝 侵权联系删除 居家适老化卫生间的花洒通常具有以下特点和功能&#xff1a;1. 高度可调节&#xff1a;适老化卫生间花洒可通过调节高度&#xff0c;满足不同身高的老年人使用需求&#xff0c;避免弯腰或过高伸展造…

【开源】基于Vue.js的固始鹅块销售系统

项目编号&#xff1a; S 060 &#xff0c;文末获取源码。 \color{red}{项目编号&#xff1a;S060&#xff0c;文末获取源码。} 项目编号&#xff1a;S060&#xff0c;文末获取源码。 目录 一、摘要1.1 项目介绍1.2 项目录屏 二、功能模块2.1 数据中心模块2.2 鹅块类型模块2.3 固…

qgis添加xyz栅格瓦片

方式1&#xff1a;手动一个个添加 左侧浏览器-XYZ Tiles-右键-新建连接 例如添加高德瓦片地址 https://wprd01.is.autonavi.com/appmaptile?langzh_cn&size1&style7&x{x}&y{y}&z{z} 双击即可呈现 收集到的一些图源&#xff0c;仅供参考&#xff0c;其中一…

【C++学习手札】模拟实现list

​ &#x1f3ac;慕斯主页&#xff1a;修仙—别有洞天 ♈️今日夜电波&#xff1a;リナリア—まるりとりゅうが 0:36━━━━━━️&#x1f49f;──────── 3:51 &#x1f504; ◀️ ⏸ ▶️…

聊聊httpclient的staleConnectionCheckEnabled

序 本文主要研究一下httpclient的staleConnectionCheckEnabled staleConnectionCheckEnabled org/apache/http/client/config/RequestConfig.java public class RequestConfig implements Cloneable {public static final RequestConfig DEFAULT new Builder().build();pr…

【ARM 嵌入式 编译 Makefile 系列 18 -- Makefile 中的 export 命令详细介绍】

文章目录 Makefile 中的 export 命令详细介绍Makefile 使用 export导出与未导出变量的区别示例&#xff1a;导出变量以供子 Makefile 使用 Makefile 中的 export 命令详细介绍 在 Makefile 中&#xff0c;export 命令用于将变量从 Makefile 导出到由 Makefile 启动的子进程的环…