强化学习:值函数近似【Deep Q-Network,DQN,Deep Q-learning】

强化学习笔记

主要基于b站西湖大学赵世钰老师的【强化学习的数学原理】课程,个人觉得赵老师的课件深入浅出,很适合入门.

第一章 强化学习基本概念
第二章 贝尔曼方程
第三章 贝尔曼最优方程
第四章 值迭代和策略迭代
第五章 强化学习实例分析:GridWorld
第六章 蒙特卡洛方法
第七章 Robbins-Monro算法
第八章 多臂老虎机
第九章 强化学习实例分析:CartPole
第十章 时序差分法
第十一章 值函数近似【DQN】


文章目录

  • 强化学习笔记
  • 一、状态值函数近似
  • 二、动作值函数的近似
    • 1 Deep Q-learning
  • 三、参考资料


在前面介绍的方法中,我们的 v ( s ) v(s) v(s) q ( s , a ) q(s,a) q(s,a)都是用如下表格形式来呈现的:

截屏2024-06-27 11.45.45

这对于 s , a s,a s,a离散且有限的时候是可行的,如果 s , a s,a s,a连续或者 ∣ A ∣ , ∣ S ∣ |\mathcal{A}|,|\mathcal{S}| A,S很大,前面介绍的算法,比如值迭代和策略迭代或者时序差分法就会面临两个问题:

  1. 计算过程需要的内存急剧上升;
  2. 泛化能力一般.

这时候我们引入一类新的方法——值函数估计,顾名思义, v ( s ) v(s) v(s) s s s的函数, q ( s , a ) q(s,a) q(s,a) s 、 a s、a sa的函数,所以我们要得到这两个函数,可以用一些机器学习的方法,比如常见的曲线拟合的方法,还可以结合深度学习的方法来进行对 v ( s ) , q ( s , a ) v(s),q(s,a) v(s),q(s,a)的估计。

一、状态值函数近似

首先我们介绍比较经典的线性方法,近似值函数 v ^ ( ⋅ , w ) \hat{v}(\cdot,\mathbf{w}) v^(,w)是权向量 w w w的线性函数。对应于每个状态 s s s,有一个实值向量 x ( s ) ≐ ( x 1 ( s ) , x 2 ( s ) , … , x d ( s ) ) \mathbf{x}(s)\doteq(x_1(s),x_2(s),\ldots,x_d(s)) x(s)(x1(s),x2(s),,xd(s)),与 w w w具有相同的维数。线性方法通过 w w w x ( s ) \mathbf{x}(s) x(s)的内积来近似状态值函数:
v ^ ( s , w ) ≐ w ⊤ x ( s ) ≐ ∑ i = 1 d w i x i ( s ) . ( 1 ) \hat{v}(s,\mathbf{w})\doteq\mathbf{w}^\top\mathbf{x}(s)\doteq\sum_{i=1}^dw_ix_i(s).\qquad(1) v^(s,w)wx(s)i=1dwixi(s).(1)
其中 x ( s ) x(s) x(s)是状态 s s s特征向量,这里介绍一种常见的多项式构造方法,更多的特征向量构造方法可以见参考文献2的9.5小节。假设 s s s是二维的,那么一个可能的构造为 x ( s ) = ( 1 , s 1 , s 2 , s 1 s 2 ) ∈ R 4 x(s)=(1,s_1,s_2,s_1s_2)\in\mathbb{R}^4 x(s)=(1,s1,s2,s1s2)R4,这就得到一个4维的特征向量,同理可以构造其他格式的多项式特征向量。

现在我们有了近似状态值函数的形式,那么如何估计(1)中的参数 w w w呢?这就用到机器学习里常用的方法——随机梯度下降。假设真实状态值函数为 v π ( s ) v_{\pi}(s) vπ(s),那么我们可以定义一个平方误差:
J ( w ) = E [ ( v π ( s ) − v ^ ( s , w ) ) 2 ] ( 2 ) J(w)=\mathbb{E}[(v_{\pi}(s)-\hat{v}(s,\mathbf{w}))^2]\qquad(2) J(w)=E[(vπ(s)v^(s,w))2](2)
这里的期望是对 s ∈ S s\in\mathcal{S} sS求的,一个常见的假设是所有 s s s均匀分布,那么可以得到:
J ( w ) = 1 ∣ S ∣ ∑ s ( v π ( s ) − v ^ ( s , w ) ) 2 J(w)=\frac{1}{|\mathcal{S}|}\sum_s(v_{\pi}(s)-\hat{v}(s,\mathbf{w}))^2 J(w)=S1s(vπ(s)v^(s,w))2
为了估计 w w w,我们的目标是最小化 J ( w ) J(w) J(w),由梯度下降法可得:
w k + 1 = w k − α k ∇ w J ( w k ) w_{k+1}=w_k-\alpha_k\nabla_wJ(w_k) wk+1=wkαkwJ(wk)
由(2)式推导梯度如下:
∇ w J ( w ) = ∇ w E [ ( v π ( s ) − v ^ ( s , w ) ) 2 ] = E [ ∇ w ( v π ( s ) − v ^ ( s , w ) ) 2 ] = 2 E [ ( v π ( s ) − v ^ ( s , w ) ) ( − ∇ w v ^ ( s , w ) ) ] = − 2 E [ ( v π ( s ) − v ^ ( s , w ) ) ∇ w v ^ ( s , w ) ] \begin{aligned} \nabla_wJ(w)&= \nabla_w\mathbb{E}[(v_\pi(s)-\hat{v}(s,w))^2] \\ &=\mathbb{E}[\nabla_w(v_\pi(s)-\hat{v}(s,w))^2] \\ &=2\mathbb{E}[(v_\pi(s)-\hat{v}(s,w))(-\nabla_w\hat{v}(s,w))] \\ &=-2\mathbb{E}[(v_\pi(s)-\hat{v}(s,w))\nabla_w\hat{v}(s,w)] \\ \end{aligned} wJ(w)=wE[(vπ(s)v^(s,w))2]=E[w(vπ(s)v^(s,w))2]=2E[(vπ(s)v^(s,w))(wv^(s,w))]=2E[(vπ(s)v^(s,w))wv^(s,w)]但这样计算梯度需要对所有 s s s求期望,不实用,所以采用随机梯度下降的方式,任取一个样本:
w t + 1 = w t + α t ( v π ( s t ) − v ^ ( s t , w t ) ) ∇ w v ^ ( s t , w t ) , ( 3 ) w_{t+1}=w_t+\alpha_t(v_\pi(s_t)-\hat v(s_t,w_t))\nabla_w\hat v(s_t,w_t),\qquad(3) wt+1=wt+αt(vπ(st)v^(st,wt))wv^(st,wt),(3)其中 s t ∈ S s_t\in\mathcal{S} stS,但是这个迭代格式还有一个问题,我们需要知道真实的 v π ( s ) v_{\pi}(s) vπ(s),显然我们是不知道的,并且我们要估计的就是这个 v π ( s ) v_{\pi}(s) vπ(s),那么我们用一个近似值来替代迭代格式中的 v π ( s ) v_{\pi}(s) vπ(s),有如下两种方法:

  1. 基于蒙特卡洛学习的状态值函数逼近
    假设 g t g_t gt为某个episode里的从 s t s_t st开始的累积折扣回报。那么我们可以用 g t g_t gt来近似 v π ( s t ) v_\pi(s_t) vπ(st). 迭代格式(3)算法变为
    w t + 1 = w t + α t ( g t − v ^ ( s t , w t ) ) ∇ w v ^ ( s t , w t ) . w_{t+1}=w_t+\alpha_t(g_t-\hat{v}(s_t,w_t))\nabla_w\hat{v}(s_t,w_t). wt+1=wt+αt(gtv^(st,wt))wv^(st,wt).
  2. 基于TD学习的状态值函数逼近
    结合TD学习方法,将 r t + 1 + γ v ^ ( s t + 1 , w t ) r_{t+1}+\gamma\hat{v}(s_{t+1},w_t) rt+1+γv^(st+1,wt)视为对 v π ( s t ) v_\pi(s_{t}) vπ(st). 的一种近似。因此,迭代格式(3)可表示为
    w t + 1 = w t + α t [ r t + 1 + γ v ^ ( s t + 1 , w t ) − v ^ ( s t , w t ) ] ∇ w v ^ ( s t , w t ) . ( 4 ) w_{t+1}=w_ t + \alpha _ t [ r_ { t + 1 } + \gamma \hat { v } ( s_ { t + 1 }, w_ t ) - \hat { v } ( s_ t, w_ t ) ] \nabla _ w \hat { v } ( s_ t, w_ t ).\qquad(4) wt+1=wt+αt[rt+1+γv^(st+1,wt)v^(st,wt)]wv^(st,wt).(4)

二、动作值函数的近似

上面(4)给出了结合TD learning来求 v ^ ( s , w ) \hat{v}(s,w) v^(s,w)的迭代格式,显然我们可以把 ( 4 ) (4) (4)改成SARSA或者Q-learning就能得到基于TD学习的动作值函数逼近的迭代格式,此处不再赘述,详情参考强化学习:时序差分法.下面来介绍深度强化学习里面的一个经典的模型——Deep Q-learning,也叫Deep Q-Network(DQN).

1 Deep Q-learning

DQN方法旨在最小化如下的目标函数:
J ( w ) = E [ ( R + γ max ⁡ a ∈ A ( S ′ ) q ^ ( S ′ , a , w ) − q ^ ( S , A , w ) ) 2 ] , J(w)=\mathbb{E}\left[\left(R+\gamma\max_{a\in\mathcal{A}(S')}\hat{q}(S',a,w)-\hat{q}(S,A,w)\right)^2\right], J(w)=E[(R+γaA(S)maxq^(S,a,w)q^(S,A,w))2],
参数 w w w出现在两个地方,求导不好求,所以DQN的一个核心思想是:采用两个网络来分别近似 q ^ ( S ′ , a , w ) \hat{q}(S',a,w) q^(S,a,w) q ^ ( S , A , w ) \hat{q}(S,A,w) q^(S,A,w),在更新参数时,先把 q ^ ( S ′ , a , w ) \hat{q}(S',a,w) q^(S,a,w)看做固定值,那么 J ( w ) J(w) J(w)就只有一个地方有 w w w求导就相对容易,可以利用梯度下降对近似 q ^ ( S , A , w ) \hat{q}(S,A,w) q^(S,A,w)网络的参数进行更新。那么如何对近似 q ^ ( S ′ , a , w ) \hat{q}(S',a,w) q^(S,a,w)目标网络的参数进行更新呢?DQN提出可以设置一个参数 C C C,每迭代 C C C次,我们将 q ^ ( S , A , w ) \hat{q}(S,A,w) q^(S,A,w)网络的参数复制给target network,这样进行交替更新,最终可以得到一个近似动作值函数。

同时,提出DQN模型的论文还开创性的提出经验回放的技巧,简单来说就是将采样得到的数据 ( S , A , R , S ′ ) (S,A,R,S') (S,A,R,S)放入一个经验缓冲区 D D D,训练神经网络时就用 D D D里面的数据进行训练,这样做的好处是可以去除观测序列中的相关性并对数据分布的变化进行平滑。

下面是DQN算法的伪代码:

截屏2024-06-27 13.19.24

自从DQN的提出,研究人员对其进行了多种改进,如Double DQN、Dueling DQN和Prioritized Experience Replay等,这些改进进一步提升了DQN的性能和稳定性。DQN的提出是深度强化学习领域的重要里程碑,它展示了深度学习在强化学习中的巨大潜力,并为后续研究奠定了基础。

学习了DQN的理论知识后,具体如何实现参考我的这篇文章:基于强化学习DQN的股票预测【DQN的Python实践】.通过具体的代码,我们可以更加深入的理解DQN模型的构造以及实现细节.

三、参考资料

  1. Zhao, S… Mathematical Foundations of Reinforcement Learning. Springer Nature Press and Tsinghua University Press.
  2. Sutton, Richard S., and Andrew G. Barto. Reinforcement learning: An introduction. MIT press, 2018.
  3. Mnih, V., Kavukcuoglu, K.(2015). Human-level control through deep reinforcement learning. Nature, 518(7540), 529-533.

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/36793.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

(番外篇)指针的一些相关习题讲解(速进,干货满满)(1)

前言: 我已经好久没写过博客了,这几天确实有点偷懒了,上次博客我们已经讲完了指针的部分内容,但我觉着没有习题是不够的,于是我出了这一篇番外篇,来让各位读者朋友们进行指针强化,这些题目都是小…

Python27 神经网络中的重要概念和可视化实现

1. 神经网络背后的直观知识 神经网络的工作方式非常相似:它接受多个输入,经过多个隐藏层中的多个神经元进行处理,并通过输出层返回结果,这个过程在技术上称为“前向传播”。 接下来,将神经网络的输出与实际输出进行比…

GIT-LFS使用

0.前言 目前git仓库有很多很大的文件需要管理,但是直接上传,每次clone的文件太大,所有准备使用git-lfs解决。 1、下载和安装 Git LFS 1.1、直接下载二进制包: Releases git-lfs/git-lfs GitHub 安装 Git LFS sudo rpm -ivh…

Spring Boot中获取请求参数的几种方式

前言 在构建现代 Web 应用时,处理来自客户端的请求参数是不可或缺的一部分。Spring Boot作为构建微服务应用的领先框架,提供了多种灵活高效的方式来获取请求参数,满足各种应用场景。 无论您是Spring Boot的初学者,还是希望更深入…

LabVIEW电涡流检测系统

开发了一种基于LabVIEW的软件与硬件结合的电涡流检测系统,通过同步采样技术和编码器的协同工作,显著提高了大型结构物的损伤检测精度和效率,具有良好的应用前景和实用价值。 项目背景 传统的手持式电涡流检测方法因其速度慢、灵敏度低、准确…

<sa8650>QCX 诊断模块和错误处理

<sa8650>QCX 诊断模块和错误处理 一、错误报告设计二、QCarCam API 的错误报告2.1 QCarCamRegisterEventCallback2.2 CarCamErrorInfo_t2.3 QCarCamErrorInfo_t2.4 Error ID2.4.1 QCARCAM_ERROR_WARNING2.4.2 QCARCAM_ERROR_SUBSYSTEM_FATAL2.4.3 QCARCAM_ERROR_FATAL2.4.4 Q…

Links: Challenging Puzzle Game Template(益智游戏模板)

链接:挑战益智游戏 《Links》是一款独特且具有挑战性的益智游戏,即将发布。 每个级别都会向玩家展示不同的棋盘。目标是通过移动和旋转所有棋子来连接它们。每个棋子都有自己的特点和功能-你可以移动它们,旋转它们,或者两者兼而有之。连接所有棋子,以解决难度和挑战不断增…

谷歌发布两款新Gemma 2大语言模型;阿里云开源Qwen2-72B模型荣登榜首

🦉 AI新闻 🚀 谷歌发布两款新Gemma 2大语言模型 摘要:谷歌发布Gemma 2大语言模型,包括90亿和270亿参数两种版本。Gemma 2在推理性能、效率和安全性上较第一代有显著提升。27B模型的性能媲美更大规模的主流模型,且部署…

收银系统开源源码-千呼新零售2.0【打折促销】

千呼新零售2.0系统是零售行业连锁店一体化收银系统,包括线下收银线上商城连锁店管理ERP管理商品管理供应商管理会员营销等功能为一体,线上线下数据全部打通。 适用于商超、便利店、水果、生鲜、母婴、服装、零食、百货、宠物等连锁店使用。 详细介绍请…

OpenAI穿着「皇帝的新衣」;扒了数万条帖子汇总100种AIGC玩法;北美出海的财务避坑指南;我创业「如」有CTO | ShowMeAI日报

👀日报&周刊合集 | 🎡生产力工具与行业应用大全 | 🧡 点赞关注评论拜托啦! 1. 我扒了 Reddit 论坛数万条帖子,汇总了 GenAI 的 100 种玩法 ChatGPT 已经问世一年半了。这期间诞生了很多大语言模型和生成式人工智能…

[数据集][目标检测]金属架螺栓螺丝有无检测数据集VOC+YOLO格式857张3类别

数据集格式:Pascal VOC格式YOLO格式(不包含分割路径的txt文件,仅仅包含jpg图片以及对应的VOC格式xml文件和yolo格式txt文件) 图片数量(jpg文件个数):857 标注数量(xml文件个数):857 标注数量(txt文件个数):857 标注类别…

Vite: 关于Rollup打包

概述 Rollup 是一款基于 ES Module 模块规范实现的 JavaScript 打包工具,在前端社区中赫赫有名,同时也在 Vite 的架构体系中发挥着重要作用不仅是 Vite 生产环境下的打包工具,其插件机制也被 Vite 所兼容,可以说是 Vite 的构建基…

数据结构速成--树和二叉树

由于是速成专题,因此内容不会十分全面,只会涵盖考试重点,各学校课程要求不同 ,大家可以按照考纲复习,不全面的内容,可以看一下小编主页数据结构初阶的内容,找到对应专题详细学习一下。 气死了…

东京裸机云服务器怎么用

东京裸机云服务器是一种结合了物理服务器性能和云服务灵活性的高性能计算服务,它为用户提供了高效、安全的计算和存储能力。在了解如何使用东京裸机云服务器之前,需要了解其基本特性和优势。具体分析如下,rak部落小编为您整理发布。 1. **硬件…

代码随想录第36天|动态规划

62. 不同路径 补充: 对二维数组的操作 dp[j][i] 表示到 j,i 有多少种路径递推公式: dp[j][i] dp[j - 1][i] dp[j][i - 1]初始化: dp[0][i] 和 dp[j][0] 都只有1种情况遍历顺序: 由于dp[j][i] 由 上和左的元素推导, 所以采用从左到右、从上到下的遍历顺序 class Solution {…

怎么隐藏宝塔面板左上角绑定的手机号码?

宝塔面板后台的左上角会显示我们绑定的宝塔账号(手机号码),每次截图的时候都要去抹掉这个号码,那么能不能直接将这个手机号码隐藏掉呢? 如上图红色箭头所示的手机号码,其实就是我们绑定的宝塔账号&#xff…

Delphi-2M:基于病史预测未来健康的改进GPT架构

每周跟踪AI热点新闻动向和震撼发展 想要探索生成式人工智能的前沿进展吗?订阅我们的简报,深入解析最新的技术突破、实际应用案例和未来的趋势。与全球数同行一同,从行业内部的深度分析和实用指南中受益。不要错过这个机会,成为AI领…

CAN-bus总线在冷链运输中的应用

CAN-bus总线在冷链运输中的应用 如图1所示,疫苗冷链是指为保证疫苗从疫苗生产企业到接种单位运转过程中的质量而装备的存储、运输冷藏设施、设备。由于疫苗对温度敏感,从疫苗制造的部门到疫苗使用的现场之间的每一个环节,都可能因温度过高而失效。在储运过程中,一旦温度超…

R语言 | 带P值的相关性热图绘制教程

原文链接:带P值的相关性热图绘制教程 本期教程 往期教程部分内容 **注意:若是在MarkDown格式中无法运行成功,请新建有一个R script文件 ** 一、加载R包 if (!require(corrplot)) install.packages("corrplot") if (!require(Hmi…

【python】PyQt5对象类型的判定,对象删除操作详细解读

✨✨ 欢迎大家来到景天科技苑✨✨ 🎈🎈 养成好习惯,先赞后看哦~🎈🎈 🏆 作者简介:景天科技苑 🏆《头衔》:大厂架构师,华为云开发者社区专家博主,…