Pytorch深度强化学习2-1:基于价值的强化学习——DQN算法

目录

  • 0 专栏介绍
  • 1 基于价值的强化学习
  • 2 深度Q网络与Q-learning
  • 3 DQN原理分析
  • 4 DQN训练实例

0 专栏介绍

本专栏重点介绍强化学习技术的数学原理,并且采用Pytorch框架对常见的强化学习算法、案例进行实现,帮助读者理解并快速上手开发。同时,辅以各种机器学习、数据处理技术,扩充人工智能的底层知识。

🚀详情:《Pytorch深度强化学习》


1 基于价值的强化学习

根据不动点定理,最优策略和最优价值函数是唯一的(对该经典理论不熟悉的请看Pytorch深度强化学习1-4:策略改进定理与贝尔曼最优方程详细推导),通过优化价值函数间接计算最优策略的方法称为基于价值的强化学习(value-based)框架。设状态空间为 n n n维欧式空间 S = R n S=\mathbb{R} ^n S=Rn,每个维度代表状态的一个特征。此时状态-动作值函数记为

Q ( s , a ; θ ) Q\left( \boldsymbol{s},\boldsymbol{a};\boldsymbol{\theta } \right) Q(s,a;θ)

其中 s \boldsymbol{s} s是状态向量, a \boldsymbol{a} a是动作空间中的动作向量, θ \boldsymbol{\theta } θ是神经网络的参数向量。深度学习完成了从输入状态到输出状态-动作价值的映射

s → Q ( s , a ; θ ) [ Q ( s , a 1 ) Q ( s , a 2 ) ⋯ Q ( s , a m ) ] T ( a 1 , a 2 , ⋯ , a m ∈ A ) \boldsymbol{s}\xrightarrow{Q\left( \boldsymbol{s},\boldsymbol{a};\boldsymbol{\theta } \right)}\left[ \begin{matrix} Q\left( \boldsymbol{s},a_1 \right)& Q\left( \boldsymbol{s},a_2 \right)& \cdots& Q\left( \boldsymbol{s},a_m \right)\\\end{matrix} \right] ^T\,\, \left( a_1,a_2,\cdots ,a_m\in A \right) sQ(s,a;θ) [Q(s,a1)Q(s,a2)Q(s,am)]T(a1,a2,,amA)

相当于对无穷维Q-Table的一次隐式查表,对经典Q-learing算法不熟悉的请看Pytorch深度强化学习1-6:详解时序差分强化学习(SARSA、Q-Learning算法)、Pytorch深度强化学习案例:基于Q-Learning的机器人走迷宫。设目标价值函数为 Q ∗ Q^* Q,若采用最小二乘误差,可得损失函数为

J ( θ ) = E [ 1 2 ( Q ∗ ( s , a ) − Q ( s , a ; θ ) ) 2 ] J\left( \boldsymbol{\theta } \right) =\mathbb{E} \left[ \frac{1}{2}\left( Q^*\left( \boldsymbol{s},\boldsymbol{a} \right) -Q\left( \boldsymbol{s},\boldsymbol{a};\boldsymbol{\theta } \right) \right) ^2 \right] J(θ)=E[21(Q(s,a)Q(s,a;θ))2]

采用梯度下降得到参数更新公式为

θ ← θ + α ( Q ∗ ( s , a ) − Q ( s , a ; θ ) ) ∂ Q ( s , a ; θ ) ∂ θ \boldsymbol{\theta }\gets \boldsymbol{\theta }+\alpha \left( Q^*\left( \boldsymbol{s},\boldsymbol{a} \right) -Q\left( \boldsymbol{s},\boldsymbol{a};\boldsymbol{\theta } \right) \right) \frac{\partial Q\left( \boldsymbol{s},\boldsymbol{a};\boldsymbol{\theta } \right)}{\partial \boldsymbol{\theta }} θθ+α(Q(s,a)Q(s,a;θ))θQ(s,a;θ)

随着迭代进行, Q ( s , a ; θ ) Q\left( \boldsymbol{s},\boldsymbol{a};\boldsymbol{\theta } \right) Q(s,a;θ)将不断逼近 Q ∗ Q^* Q,由 Q ( s , a ; θ ) Q\left( \boldsymbol{s},\boldsymbol{a};\boldsymbol{\theta } \right) Q(s,a;θ)进行的策略评估和策略改进也将迭代至最优。

2 深度Q网络与Q-learning

Q-learning和深度Q学习(Deep Q-learning, DQN)是强化学习领域中两种重要的算法,它们在解决智能体与环境之间的决策问题方面具有相似之处,但也存在一些显著的异同。这里进行简要阐述以加深对二者的理解。

  • Q-learning是一种基于值函数的强化学习算法。它通过使用Q-Table来表示每个状态和动作对的预期回报。Q值函数用于指导智能体在每个时间步选择最优动作。通过不断更新Q值函数来使其逼近最优的Q值函数
  • DQN是对Q-learning的深度网络版本,它将神经网络引入Q-learning中,以处理具有高维状态空间的问题。通过使用深度神经网络作为函数逼近器,DQN可以学习从原始输入数据(如像素值)直接预测每个动作的Q值

在这里插入图片描述

3 DQN原理分析

深度Q网络(Deep Q-Network, DQN)的核心原理是通过

  • 经验回放池(Experience Replay):考虑到强化学习采样的是连续非静态样本,样本间的相关性导致网络参数并非独立同分布,使训练过程难以收敛,因此设置经验池存储样本,再通过随机采样去除相关性;
  • 目标网络(Target Network):考虑到若目标价值 与当前价值 是同一个网络时会导致优化目标不断变化,产生模型振荡与发散,因此构建与 结构相同但慢于 更新的独立目标网络来评估目标价值,使模型更稳定。

拟合了高维状态空间,是Q-Learning算法的深度学习版本,算法流程如表所示

在这里插入图片描述

4 DQN训练实例

最简单的例子是使用全连接网络来构造DQN

class DQN(nn.Module):def __init__(self, input_dim, output_dim):super(DQN, self).__init__()self.input_dim = input_dimself.output_dim = output_dimself.fc = nn.Sequential(nn.Linear(self.input_dim[0], 128),nn.ReLU(),nn.Linear(128, 256),nn.ReLU(),nn.Linear(256, self.output_dim))def __str__(self) -> str:return "Fully Connected Deep Q-Value Network, DQN"def forward(self, state):qvals = self.fc(state)return qvals

基于贝尔曼最优原理的损失计算如下

def computeLoss(self, batch):states, actions, rewards, next_states, dones = batchstates = torch.FloatTensor(states).to(self.device)actions = torch.LongTensor(actions).to(self.device)rewards = torch.FloatTensor(rewards).to(self.device)next_states = torch.FloatTensor(next_states).to(self.device)dones = (1 - torch.FloatTensor(dones)).to(self.device)# 根据实际动作提取Q(s,a)值curr_Q = self.model(states).gather(1, actions.unsqueeze(1)).squeeze(1)next_Q = self.target_model(next_states)max_next_Q = torch.max(next_Q, 1)[0]expected_Q = rewards.squeeze(1) + self.gamma * max_next_Q * donesloss = self.criterion(curr_Q, expected_Q.detach())return loss

基于经验回放池和目标网络的参数更新如下

def update(self, batch_size):batch = self.replay_buffer.sample(batch_size)loss = self.computeLoss(batch)self.optimizer.zero_grad()loss.backward()self.optimizer.step()# 更新target网络for target_param, param in zip(self.target_model.parameters(), self.model.parameters()):target_param.data.copy_(self.tau * param + (1 - self.tau) * target_param)# 退火self.epsilon = self.epsilon + self.epsilon_delta \if self.epsilon < self.epsilon_max else self.epsilon_max

基于DQN可以实现最基本的智能体,下面给出一些具体案例

  • Pytorch深度强化学习案例:基于DQN实现Flappy Bird游戏与分析

在这里插入图片描述

完整代码联系下方博主名片获取


🔥 更多精彩专栏

  • 《ROS从入门到精通》
  • 《Pytorch深度学习实战》
  • 《机器学习强基计划》
  • 《运动规划实战精讲》

👇源码获取 · 技术交流 · 抱团学习 · 咨询分享 请联系👇

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/578371.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

SQL server 数据库练习题及答案(练习3)

一、编程题 公司部门表 department 字段名称 数据类型 约束等 字段描述 id int 主键&#xff0c;自增 部门ID name varchar(32) 非空&#xff0c;唯一 部门名称 description varchar(1024) …

HTTP content-type内容类型的常见格式

本专栏是汇集了一些HTML常常被遗忘的知识&#xff0c;这里算是温故而知新&#xff0c;往往这些零碎的知识点&#xff0c;在你开发中能起到炸惊效果。我们每个人都没有过目不忘&#xff0c;过久不忘的本事&#xff0c;就让这一点点知识慢慢渗透你的脑海。 本专栏的风格是力求简洁…

分布式训练通信NCCL之Ring-Allreduce详解

&#x1f380;个人主页&#xff1a; https://zhangxiaoshu.blog.csdn.net &#x1f4e2;欢迎大家&#xff1a;关注&#x1f50d;点赞&#x1f44d;评论&#x1f4dd;收藏⭐️&#xff0c;如有错误敬请指正! &#x1f495;未来很长&#xff0c;值得我们全力奔赴更美好的生活&…

fpga 8段4位数码管verilator模拟

8段4位数码管verilator模拟 seg.v module seg(input wire clk,input wire rst_n,output wire[7:0] SEG,output wire[3:0] SEL );reg[7:0] digit[0:15] {8h3f, 8h06, 8h5b, 8h4f, 8h66, 8h6d, 8h7d,8h07,8h7f,8h6f, 8h77, 8h7c, 8h39, 8h5e, 8h79, 8h71};reg[31:0] cnt 32…

Opencv_CUDA实现推理图像前处理与后处理

Opencv_CUDA实现推理图像前处理与后处理 通过trt 或者 openvino部署深度学习算法时&#xff0c;往往会通过opencv的Mat及算法将图像转换为固定的格式作为输入openvino图像的前后处理后边将在单独的文章中写出今晚空闲搜了一些opencv_cuda的使用方法&#xff0c;在此总结一下前…

云服务器ECS运维管理

目录 实时掌握CPU、内存使用情况 实时掌握存储的使用情况 定期对云服务器数据做好备份 定期检查云服务器的安全运行情况 要想保证云服务器长期稳定的使用&#xff0c;除了依靠阿里云&#xff08;云服务提供商&#xff09;的技术支持&#xff0c;自身必要的安全维护手段也是…

W6100-EVB-Pico评估版介绍

文章目录 1 简介2 硬件资源2.1 硬件规格2.2 引脚定义2.3 工作条件 3 参考资料3.1 Datasheet3.2 原理图3.3 尺寸图&#xff08;尺寸&#xff1a;mm&#xff09;3.4 参考例程 4 硬件协议栈优势 1 简介 W6100-EVB-Pico是一款基于树莓派RP2040和全硬件TCP/IP协议栈以太网芯片W6100的…

ApiPost测试token验证端口(若依)

首先ApiPost自带默认环境与Mock环境。 接下来自己创建新环境设置变量。 注&#xff1a;若本地环境与生产环境端口不一致&#xff0c;这里的url也要带上端口号 创建一个本地环境&#xff0c;增加环境变量url&#xff0c;默认值为localhost。 再新建一个生产环境。 新建一个登…

Hadoop集群部署

目录 1 模板虚拟机环境准备 1.1 修改网卡配置文件 扩展 1.2 修改主机名 1.3 在虚拟机中需要的基础文件包 1.4 关闭防火墙 1.5 创建Hadoop的账户及文件 2 模板虚拟机安装JDK 3 模板虚拟机安装Hadoop 4 克隆虚拟机 5 虚拟机配置主机名称映射 6 集群分发脚本 7 SSH无…

HTML代码全解析

HTML代码全解析实例解析 <!DOCTYPE html> 声明为 HTML5 文档<html> 元素是 HTML 页面的根元素<head> 元素包含了文档的元&#xff08;meta&#xff09;数据&#xff0c;如 <meta charset"utf-8"> 定义网页编码格式为 utf-8。<title> 元…

1233. 全球变暖(bfs宽搜相邻点)

题目&#xff1a; 1233. 全球变暖 - AcWing题库 思路&#xff1a;bfs 1.临接问题&#xff0c;最短路径问题--->bfs。 2.被完全淹没--->岛屿所以部分均临海。 代码&#xff1a; #include<bits/stdc.h> using namespace std; const int N1010; struct Point …

【Linux系统编程】进程状态

介绍 进程的状态指的是进程在执行过程中所处的状态。进程的状态随着进程的执行和外界条件的变化而转换。我们可用 kill 命令来进程控制进程的状态。 kill中的 kill -l 指令用于查看系统中定义的所有信号及其对应的编号。这些信号可以用于 kill 命令来向进程发送特定的信号控制其…

鸿蒙(HarmonyOS)项目方舟框架(ArkUI)之Toast组件

鸿蒙&#xff08;HarmonyOS&#xff09;项目方舟框架&#xff08;ArkUI&#xff09;之Toast组件 一、操作环境 操作系统: Windows 10 专业版、IDE:DevEco Studio 3.1、SDK:HarmonyOS 3.1 二、Toast组件 Toast 的应用场景也非常广泛&#xff0c;比如网络请求出错了可以弹一个…

鸿蒙开发(二)- 鸿蒙DevEco开发环境搭建

上篇说到&#xff0c;鸿蒙开发目前势头旺盛&#xff0c;头部大厂正在如火如荼地进行着&#xff0c;华为也对外宣称已经跟多个厂商达成合作。目前看来&#xff0c;对于前端或客户端开发人员来说&#xff0c;掌握下鸿蒙开发还是有些必要性的。如果你之前是从事Android开发的&…

论文阅读<CF-YOLO: Cross Fusion YOLO for Object Detection in Adverse Weather.....>

论文链接&#xff1a;https://arxiv.org/pdf/2309.08152.pdfhttps://arxiv.org/pdf/2206.01381.pdfhttps://arxiv.org/pdf/2309.08152.pdf 代码链接&#xff1a;https://github.com/DiffPrompter/diff-prompter 目前没有完整代码放出。 恶劣天气下的目标检测主要有以下三种解…

Stable Diffusion系列(三):网络分类与选择

文章目录 网络分类模型基座模型衍生模型二次元模型2.5D模型写实风格模型 名称解读 VAELora嵌入文件放置界面使用 网络分类 当使用SD webui绘图时&#xff0c;为了提升绘图质量&#xff0c;可以多种网络混合使用&#xff0c;可选的网络包括了模型、VAE、超网络、Lora和嵌入。 …

引用jquery.js的html5基础页面模板

本专栏是汇集了一些HTML常常被遗忘的知识&#xff0c;这里算是温故而知新&#xff0c;往往这些零碎的知识点&#xff0c;在你开发中能起到炸惊效果。我们每个人都没有过目不忘&#xff0c;过久不忘的本事&#xff0c;就让这一点点知识慢慢渗透你的脑海。 本专栏的风格是力求简洁…

使用LLaMA-Factory微调ChatGLM3

1、创建虚拟环境 略 2、部署LLaMA-Factory &#xff08;1&#xff09;下载LLaMA-Factory https://github.com/hiyouga/LLaMA-Factory &#xff08;2&#xff09;安装依赖 pip3 install -r requirements.txt&#xff08;3&#xff09;启动LLaMA-Factory的web页面 CUDA_VI…

Java经典框架之Spring MVC

Spring MVC Java 是第一大编程语言和开发平台。它有助于企业降低成本、缩短开发周期、推动创新以及改善应用服务。如今全球有数百万开发人员运行着超过 51 亿个 Java 虚拟机&#xff0c;Java 仍是企业和开发人员的首选开发平台。 课程内容的介绍 1. Spring MVC 入门案例 2. 基…

JVS低代码平台:多级菜单配置的详细教程与演示

多级菜单是软件系统一种常见的用户界面设计&#xff0c;它允许用户通过点击或选择不同的菜单项来执行不同的操作或访问不同的功能。多级菜单通常由多个级别的菜单组成&#xff0c;每个级别都包含一组可选择的菜单项。用户可以通过点击或选择菜单项来进入下一级菜单&#xff0c;…