强化学习的数学原理学习笔记 - 值函数近似(Value Function Approximation)

文章目录

  • 概览:RL方法分类
  • 值函数近似(Value function approximation)
    • Basic idea
      • 目标函数(objective function)
      • 优化算法(optimization algorithm)
    • Sarsa / Q-learning with function approximation
      • Sarsa with function approximation
      • Q-learning with function approximation
    • 🟦DQN (Deep Q-learning)
      • 关键技术1:两个网络
      • 关键技术2:经验回放(Experience replay)
      • DQN算法步骤(off-policy)


本系列文章介绍强化学习基础知识与经典算法原理,大部分内容来自西湖大学赵世钰老师的强化学习的数学原理课程(参考资料1),并参考了部分参考资料2、3的内容进行补充。

系列博文索引:

  • 强化学习的数学原理学习笔记 - RL基础知识
  • 强化学习的数学原理学习笔记 - 基于模型(Model-based)
  • 强化学习的数学原理学习笔记 - 蒙特卡洛方法(Monte Carlo)
  • 强化学习的数学原理学习笔记 - 时序差分学习(Temporal Difference)
  • 强化学习的数学原理学习笔记 - 值函数近似(Value Function Approximation)
  • 强化学习的数学原理学习笔记 - 策略梯度(Policy Gradient)
  • 强化学习的数学原理学习笔记 - Actor-Critic

参考资料:

  1. 【强化学习的数学原理】课程:从零开始到透彻理解(完结)(主要)
  2. Sutton & Barto Book: Reinforcement Learning: An Introduction
  3. 机器学习笔记

*注:【】内文字为个人想法,不一定准确

概览:RL方法分类

图源:https://zhuanlan.zhihu.com/p/36494307
*图源:https://zhuanlan.zhihu.com/p/36494307

值函数近似(Value function approximation)

在先前的方法中,状态/动作值均以表格的(tabular)形式呈现。但是当状态/动作空间较大或者连续时,以上算法会面临存储开销和泛化能力的问题。因此,考虑通过特定函数的形式近似状态值。

Basic idea

*A simple example
假设状态值 v ( s ) v(s) v(s)与状态 s s s之间呈线性关系,设 v ^ ( s , w ) \hat{v}(s, w) v^(s,w)是对 v ( s ) v(s) v(s)的估计,则有下式:
image.png
其中, w w w为参数向量, ϕ ( s ) \phi(s) ϕ(s)为状态 s s s的特征向量。
这样做的好处在于大大降低了存储开销:不需要存储每个状态值,只需要存储 w w w(即 a a a b b b两个参数)即可。但是弊端在于通过函数近似得到的结果并不一定准确。这种思想可以继续推广到高阶及非线性函数,以提升估计的准确性。

值函数近似的idea:使用参数化(parameterized)的函数近似状态和动作值,即 v ^ ( s , w ) ≈ v π ( s ) \hat{v}(s, w) \approx v_\pi(s) v^(s,w)vπ(s),其中 w ∈ R m w \in \mathbb{R}^m wRm是参数向量。

好处:(1) 便于存储:只需要存储参数,不需要存储状态,而参数的维度往往远小于状态的数量;(2) 泛化能力:当访问一个状态后,参数值发生改变,则整个函数估计发生改变,其余未被访问的状态的状态值同样会发生改变,因此不需要访问每个状态来完成学习过程。

目标函数(objective function)

值函数近似的目标是使得估计值尽可能接近真实状态值,其目标函数为:
J ( w ) = E [ ( v π ( S ) − v ^ ( S , w ) ) 2 ] J(w) = \mathbb{E} [ (v_\pi(S) - \hat{v}(S,w))^2 ] J(w)=E[(vπ(S)v^(S,w))2]
值函数近似的目标,即找到能够使得 J ( w ) J(w) J(w)最小的 w w w。本质上是做策略评估中的状态值估计。
其中, S ∈ S S \in \mathcal{S} SS为随机变量,其概率分布为平稳分布(stationary distribution),描述长期行为(long-run behavior),也被称为steady-state distribution或limiting distribution【*一个随机过程/马尔可夫过程中的概念】。直观理解:如果一个agent按照一个给定策略运行了足够久,其马尔可夫过程最终会达到一个平稳状态【即模型(状态转移概率)是稳定的】。

{ d π ( s ) } s ∈ S \{ d_\pi (s) \}_{s\in \mathcal{S}} {dπ(s)}sS表示策略 π \pi π下的马尔可夫过程的平稳分布,有 d π ( s ) ≥ 0 d_\pi (s) \geq 0 dπ(s)0 ∑ s ∈ S d π ( s ) = 1 \textstyle \sum_{s\in \mathcal{S}} d_\pi(s) =1 sSdπ(s)=1。则值函数近似的目标函数可以写作:
J ( w ) = E [ ( v π ( S ) − v ^ ( S , w ) ) 2 ] = ∑ s ∈ S d π ( s ) ( v π ( S ) − v ^ ( S , w ) ) 2 J(w) = \mathbb{E} [ (v_\pi(S) - \hat{v}(S,w))^2 ] = \sum_{s\in \mathcal{S}} d_\pi (s) (v_\pi(S) - \hat{v}(S,w))^2 J(w)=E[(vπ(S)v^(S,w))2]=sSdπ(s)(vπ(S)v^(S,w))2
其中, d π ( s ) d_\pi (s) dπ(s)表示agent处于状态 s s s的概率,同时也是该状态的(重要性)权重值,因此上式可以看作是对不同状态的估计误差的平方的加权平均。

优化算法(optimization algorithm)

采用随机梯度下降(SGD)算法优化(最小化)目标函数 J ( w ) J(w) J(w)(推导过程略):
w t + 1 = w t − α t ( v π ( s t ) − v ^ ( s t , w t ) ) ∇ w v ^ ( s t , w t ) w_{t+1} = w_t - \alpha_t (v_\pi(s_t) - \hat{v}(s_t, w_t)) \nabla_w \hat{v}(s_t, w_t) wt+1=wtαt(vπ(st)v^(st,wt))wv^(st,wt)

注意到其中 v π ( s t ) v_\pi(s_t) vπ(st)是未知的,其可以用MC或TD近似:

  • MC with 值函数近似:用 g t g_t gt(从 s t s_t st出发的累计折扣回报)近似 v π ( s t ) v_\pi(s_t) vπ(st)
    • w t + 1 = w t − α t ( g t − v ^ ( s t , w t ) ) ∇ w v ^ ( s t , w t ) w_{t+1} = w_t - \alpha_t (g_t - \hat{v}(s_t, w_t)) \nabla_w \hat{v}(s_t, w_t) wt+1=wtαt(gtv^(st,wt))wv^(st,wt)
  • TD with 值函数近似:用 r t + 1 + γ v ^ ( s t + 1 , w t ) r_{t+1} + \gamma \hat{v}(s_{t+1}, w_t) rt+1+γv^(st+1,wt)近似 v π ( s t ) v_\pi(s_t) vπ(st)
    • w t + 1 = w t − α t [ r t + 1 + γ v ^ ( s t + 1 , w t ) − v ^ ( s t , w t ) ] ∇ w v ^ ( s t , w t ) w_{t+1} = w_t - \alpha_t [r_{t+1} + \gamma \hat{v}(s_{t+1}, w_t) - \hat{v}(s_t, w_t)] \nabla_w \hat{v}(s_t, w_t) wt+1=wtαt[rt+1+γv^(st+1,wt)v^(st,wt)]wv^(st,wt)
      • TD target: r t + 1 + γ v ^ ( s t + 1 , w t ) r_{t+1} + \gamma \hat{v}(s_{t+1}, w_t) rt+1+γv^(st+1,wt)
    • *实际上这种方法并不是在优化原本的目标函数,而是在优化另一个相关的目标函数,称作projected Bellman error(详细内容略)

v ^ ( s , w ) \hat{v} (s, w) v^(s,w)的形式选择:早期用线性函数,目前通用神经网络(Neural Network,NN)来拟合未知非线性函数。线性函数的好处在于其理论性非常容易分析,弊端在于其特征向量(比如其阶数)难以选择。
*若 v ^ ( s , w ) \hat{v} (s, w) v^(s,w)为线性函数,则其等价于tabular representation,因此可以将tabular representation看作linear function approximation的一种特殊情况。

Sarsa / Q-learning with function approximation

Sarsa with function approximation

其实就是把TD with function approximation中的状态值换为动作值:
w t + 1 = w t + α t [ r t + 1 + γ q ^ ( s t + 1 , a t + 1 , w t ) − q ^ ( s t , a t , w t ) ] ∇ w q ^ ( s t , a t , w t ) w_{t+1} = w_t + \alpha_t [r_{t+1} + \gamma \hat{q}(s_{t+1}, a_{t+1}, w_t) - \hat{q}(s_t, a_t, w_t)] \nabla_w \hat{q}(s_t, a_t, w_t) wt+1=wt+αt[rt+1+γq^(st+1,at+1,wt)q^(st,at,wt)]wq^(st,at,wt)

和Tabular Sarsa的区别:不是直接更新动作值 q ( s , a ) q(s,a) q(s,a),而是更新参数值 w w w

采用ε-Greedy方法进行策略提升:
π k + 1 ( a ∣ s t ) = { 1 − ε ∣ A ( s ) ∣ ( ∣ A ( s ) ∣ − 1 ) if  a = arg max ⁡ a ∈ A ( s t ) q ^ ( s t , a , w t + 1 ) ε ∣ A ( s ) ∣ otherwise \pi_{k+1}(a|s_t) = \begin{cases} 1-\frac{\varepsilon}{|\mathcal{A} (s)|} (|\mathcal{A}(s)|-1) &\text{if } a = \argmax_{a\in\mathcal{A(s_t)}} \hat{q}(s_t, a, w_{t+1}) \\ \frac{\varepsilon}{|\mathcal{A}(s)|} &\text{otherwise} \end{cases} πk+1(ast)={1A(s)ε(A(s)1)A(s)εif a=argmaxaA(st)q^(st,a,wt+1)otherwise
注意其中的 q ^ ( s t , a , w t + 1 ) \hat{q}(s_t, a, w_{t+1}) q^(st,a,wt+1)需要通过函数计算得到。

Q-learning with function approximation

w t + 1 = w t + α t [ r t + 1 + γ max ⁡ a ∈ A ( s t + 1 ) q ^ ( s t + 1 , a t , w t ) − q ^ ( s t , a t , w t ) ] ∇ w q ^ ( s t , a t , w t ) w_{t+1} = w_t + \alpha_t [r_{t+1} + \gamma {\color{red} \max_{a \in \mathcal{A}(s_{t+1})} \hat{q}(s_{t+1}, a_{t}, w_t)} - \hat{q}(s_t, a_t, w_t)] \nabla_w \hat{q}(s_t, a_t, w_t) wt+1=wt+αt[rt+1+γaA(st+1)maxq^(st+1,at,wt)q^(st,at,wt)]wq^(st,at,wt)

🟦DQN (Deep Q-learning)

尽管在Q-learning with function approximation中,可以使用神经网络作为 q ^ ( s , a , w ) \hat{q} (s, a, w) q^(s,a,w),但其需要复杂的底层运算(如求梯度),因此提出了DQN(Deep Q-learning / Deep Q Network)作为替代。

DQN的目标函数/损失(loss)函数
J ( w ) = E [ ( R + γ max ⁡ α ∈ A ( S ′ ) q ^ ( S ′ , a , w ) − q ^ ( S , A , w ) ) 2 ] J(w) = \mathbb{E} \Big[ \Big(R + \gamma \max_{\alpha \in \mathcal{A} (S') } \hat{q} (S', a, w) - \hat{q} (S, A ,w) \Big) ^2 \Big] J(w)=E[(R+γαA(S)maxq^(S,a,w)q^(S,A,w))2]
其中, ( S , A , R , S ′ ) (S,A,R,S') (S,A,R,S)均为随机变量, R + γ max ⁡ α ∈ A ( S ′ ) q ^ ( S ′ , a , w ) − q ^ ( S , A , w ) R + \gamma \max_{\alpha \in \mathcal{A} (S') } \hat{q} (S', a, w) - \hat{q} (S, A ,w) R+γmaxαA(S)q^(S,a,w)q^(S,A,w)为Q-learning的TD error,也即Bellman optimlity error,当该值为0时取得最优。

关键技术1:两个网络

直接采用梯度下降优化损失函数并不容易,因为其中两项都包含 w w w,求梯度较复杂。一个简单的思路是,将 y = R + γ max ⁡ α ∈ A ( S ′ ) q ^ ( S ′ , a , w ) y = R + \gamma \textstyle \max_{\alpha \in \mathcal{A} (S') } \hat{q} (S', a, w) y=R+γmaxαA(S)q^(S,a,w)视作常数,只需求解 q ^ ( S , A , w ) \hat{q} (S, A ,w) q^(S,A,w)的梯度即可。
因此,DQN引入了两个网络的设计:

  • main network:对应 q ^ ( S , A , w ) \hat{q} (S, A ,w) q^(S,A,w)
  • target network:对应 q ^ ( S ′ , a , w T ) \hat{q} (S', a, w_T) q^(S,a,wT)

main network的参数 w w w实时更新,但target network的参数 w T w_T wT并非实时更新,而是隔一段时间把main network的 w w w赋值过来,因此在这段时间内, w T w_T wT可以被视为常数。

DQN的basic idea:使用梯度下降(GD)优化损失函数,对应梯度为:
∇ w J = E [ ( R + γ max ⁡ α ∈ A ( S ′ ) q ^ ( S ′ , a , w T ) − q ^ ( S , A , w ) ) ∇ w q ^ ( S , A , w ) ] \nabla_w J = \mathbb{E} \Big[ \Big(R + \gamma \max_{\alpha \in \mathcal{A} (S') } {\color{red} \hat{q} (S', a, w_T) } - {\color{blue} \hat{q} (S, A ,w) } \Big) {\color{blue} \nabla_w \hat{q} (S, A ,w) } \Big] wJ=E[(R+γαA(S)maxq^(S,a,wT)q^(S,A,w))wq^(S,A,w)]

训练过程(详见下):在每次迭代中,DQN从回放缓存(replay buffer)中取mini-batch采样 { ( s , a , r , s ′ ) } \{(s, a, r, s')\} {(s,a,r,s)},以 s s s a a a作为输入计算得到 y T = r + γ max ⁡ α ∈ A ( s ′ ) q ^ ( s ′ , a , w T ) y_T = r + \gamma \textstyle \max_{\alpha \in \mathcal{A} (s') } \hat{q} (s', a, w_T) yT=r+γmaxαA(s)q^(s,a,wT),并基于mini-batch { ( s , a , y T ) } \{ (s, a, y_T) \} {(s,a,yT)}最小化损失函数 ( y T − q ^ ( s , a , w ) ) 2 ( y_T - \hat{q} (s, a ,w) )^2 (yTq^(s,a,w))2以训练main network。之后,将main network的参数 w w w赋值给target network的 w T w_T wT

关键技术2:经验回放(Experience replay)

DQN在收集经验采样后,将其存储在回放缓存(replay buffer) B = { ( s , a , r , s ′ ) } \mathcal{B} = \{ (s, a, r, s') \} B={(s,a,r,s)}中。当需要使用采样训练神经网络时,从回放缓存中按照均匀分布(uniform distribution)随机取mini-batch的采样 ,该过程称为经验回放。

  • 均匀分布:对所有 ( s , a ) (s, a) (s,a)对等概率访问(不等概率的话,需要先验知识才能确定哪些 ( s , a ) (s, a) (s,a)对更重要)
    • 这里是把 ( S , A ) (S,A) (S,A)对看作一个随机变量
  • 回放缓存:经验采样的采集有先后顺序,直接按照其顺序使用可能不满足均匀分布的要求,因此将过往经验先存起来再均匀采样,去除采样间的相关性

*实际上经验回放也可以用于tabular Q-learning中,还能提高其采样效率(因为可以重复利用)。

DQN算法步骤(off-policy)

目标:从行为策略 π b \pi_b πb生成的经验采样中,学习一个最优的target network以近似最优动作值
在每次迭代中:

  1. 从回放缓存 B \mathcal{B} B中均匀取mini-batch采样
  2. 对于每个采样 ( s , a , r , s ′ ) (s, a, r, s') (s,a,r,s),计算 y T = r + γ max ⁡ α ∈ A ( s ′ ) q ^ ( s ′ , a , w T ) y_T = r + \gamma \textstyle \max_{\alpha \in \mathcal{A} (s') } \hat{q} (s', a, w_T) yT=r+γmaxαA(s)q^(s,a,wT),其中 w T w_T wT为target network的参数
  3. 使用mini-batch采样 { ( s , a , y T ) } \{ (s, a, y_T) \} {(s,a,yT)}更新main network,以最小化损失函数 ( y T − q ^ ( s , a , w ) ) 2 ( y_T - \hat{q} (s, a ,w) )^2 (yTq^(s,a,w))2

C C C次迭代后,将 w w w赋值给 w T w_T wT
*注意:这里的表述与DQN原论文不同(原论文的NN更高效),但本质是一样的。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/607636.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

学生备考哪款护眼台灯好?2024五款知名品牌强力推荐

最近应后台小伙伴要求,给大家测评一些护眼台灯产品,毕竟现在的孩子近视人数真的非常多,每五个孩子戴眼镜的就有三个了,日常学习中保护视力,由于很多学习时间都是在晚上,台灯成为了为陪伴学习不可或缺的搭档…

深度学习:图神经网络——在推荐系统中的应用

PinSage是工业界应用图神经网络完成推荐任务的第一个成功案例,其从用户数据中构造图(graph)的方法和应对大规模图而采取的实现技巧都值得我们学习。PinSage被应用在图片推荐类Pinterest上。在Pinterest中,每个用户可以创建并命名图…

TikTok电商年度洞察:出海到底“卖什么”?各国多类目爆款洞察,迅速掌握市场领先优势

很多卖家在尝试出海时,常面临两大核心痛点:一是“卖什么”,即选择何种商品进行销售;二是“怎么卖”,即如何通过有效的营销策略将商品销售出去。TikTok主打的内容电商模式,通过短视频和直播等形式&#xff0…

StampedLock锁探究

该锁提供了三种模式的读写控制,当调用获取锁的系列函数时,会返回一个long型的变量,我们称之为戳记(stamp),这个戳记代表了锁的状态。 其中try系列获取锁的函数,当获取锁失败后会返回为0的stamp 值。 当调用释放锁和转换锁的方法…

汽车中的ECU、VCU、MCU、HCU

一、ECU是汽车电脑,刷汽车电脑可以提高动力,也可以减低动力,看需求。 简单原理如下。 1.汽车发动机运转由汽车电脑(即ECU)控制。 2.ECU控制发动机的进气量,喷油量,点火时间等,从而…

成功解决使用git clone下载失败的问题: fatal: 过早的文件结束符(EOF) fatal: index-pack 失败

一.使用 http 可能出现的问题和解决 1.问题描述 ~$ git clone https://github.com/oKermorgant/ecn_baxter_vs.git 正克隆到 ecn_baxter_vs... remote: Enumerating objects: 13, done. remote: Counting objects: 100% (13/13), done. remote: Compressing objects: 100% (…

强直性脊柱炎=“不死的癌症”?这些常识你不可不知→

对强直性脊柱炎这个疾病,大家最常听说的是:强直性脊柱炎症状重、治疗难,会逐渐引发关节畸形、功能丧失,甚至残疾,被称为「不死的癌症」。 然而,近来越来越多患有强直性脊柱炎的明星活跃在荧幕上&#xff0c…

材料表征的微观探测器——台阶高度测量技术概述

一、引言 表面特征是材料、化学等领域的不可或缺的主要研究内容,合理地评价表面形貌、表面特征等,对于相关材料的评定、性能的分析和加工条件的改善都具有重要的意义。 表面台阶高度测量在材料表面研究中有十分重要的作用。一方面,表面测量…

x-cmd pkg | busybox - 嵌入式 Linux 的瑞士军刀

目录 简介首次用户功能特点竞品和相关作品 进一步阅读 简介 busybox 是一个开源的轻量级工具集合,集成了一批最常用 Unix 工具命令,只需要几 MB 大小就能覆盖绝大多数用户在 Linux 的使用,能在多款 POSIX 环境的操作系统(如 Linu…

避免重复扣款:分布式支付系统的幂等性原理与实践

这是《百图解码支付系统设计与实现》专栏系列文章中的第(6)篇。 本文主要讲清楚什么是幂等性原理,在支付系统中的重要应用,业务幂等、全部幂等这些不同的幂等方案选型带来的收益和复杂度权衡,幂等击穿场景及可能的严重…

k8s源码阅读环境配置

源码阅读环境配置 k8s代码的阅读可以让我们更加深刻的理解k8s各组件的工作原理,同时提升我们Go编程能力。 IDE使用Goland,代码阅读环境需要进行如下配置: 从github上下载代码:https://github.com/kubernetes/kubernetes在GOPATH目…

CTF-PWN-沙箱逃脱-【seccomp和prtcl-2】

文章目录 沙箱逃脱prtcl题HITCON CTF 2017 Quals Impeccable Artifactflag文件对应prctl函数检查源码思路exp 沙箱逃脱prtcl题 HITCON CTF 2017 Quals Impeccable Artifact flag文件 此时的flag文件在本文件夹建一个即可 此时的我设置的flag为 对应prctl函数 第一条是禁止…

JavaScript解构赋值完全手册

🧑‍🎓 个人主页:《爱蹦跶的大A阿》 🔥当前正在更新专栏:《VUE》 、《JavaScript保姆级教程》、《krpano》 ​ 目录 ✨ 前言 第一节:解构赋值的基本用法 第二节:对象解构赋值 第三节:数组解构赋值 第四节:参数…

Fluids —— MicroSolvers DOP

目录 Gas SubStep —— 重复执行对应的子步 Switch Solver —— 切换解算器 Gas Attribute Swap —— 交换、复制或移动几何体属性 Gas Intermittent Solve —— 固定时间间隔计算子解算器 Gas External Forces —— 计算外部力并更新速度或速度场 Gas Particle Separate…

【linux】tcpdump 使用

tcpdump 是一个强大的网络分析工具,可以在 UNIX 和类 UNIX 系统上使用,用于捕获和分析网络流量。它允许用户截取和显示发送或接收过网络的 TCP/IP 和其他数据包。 一、安装 tcpdump 通常是默认安装在大多数 Linux 发行版中的。如果未安装,可…

竞赛保研 基于深度学习的人脸表情识别

文章目录 0 前言1 技术介绍1.1 技术概括1.2 目前表情识别实现技术 2 实现效果3 深度学习表情识别实现过程3.1 网络架构3.2 数据3.3 实现流程3.4 部分实现代码 4 最后 0 前言 🔥 优质竞赛项目系列,今天要分享的是 基于深度学习的人脸表情识别 该项目较…

prometheus 黑盒监控

黑盒监控 “白盒监控” 是需要把对应的Exporter程序安装到被监控的目标主机上,从而实现对主机各种资源以及状态的数据采集工作 ”黑盒监控“ 是不需要把Exporter程序部署到被监控的目标主机上,比如全球的网络质量的稳定性,通常用ping操作&am…

2019年认证杯SPSSPRO杯数学建模A题(第一阶段)好风凭借力,送我上青云全过程文档及程序

2019年认证杯SPSSPRO杯数学建模 纸飞机在飞行状态下的运动模型 A题 好风凭借力,送我上青云 原题再现: 纸飞机有许多种折法。世界上有若干具有一定影响力的纸飞机比赛,通常的参赛规定是使用一张特定规格的纸,例如 A4 大小的纸张…

数据结构——队列(Queue)

目录 1.队列的介绍 2.队列工程 2.1 队列的定义 2.1.1 数组实现队列 2.1.2 单链表实现队列 2.2 队列的函数接口 2.2.1 队列的初始化 2.2.2 队列的数据插入(入队) 2.2.3 队列的数据删除(出队) 2.2.4 取队头数据 2.2.5 取队…

python匹配问题

脏数据匹配 一般数据建模步骤中,数据清洗耗时占比80%以上,因为现实中接触到的数据相当脏,无法直接简单的用pandas的merge函数解决。下面以QS大学排名的匹配为例,简单介绍脏数据匹配中会遇到的问题和主要步骤。 1 问题描述 给定…