IEEE TAI 2024
paper
1 Introduction
一篇offline to online 的文章,有效解决迁移过程出现的performance drop。所提出的O2AC算法首先在离线阶段添加一项BC惩罚项,用于限制策略靠近专家策略;而在在线微调阶段,通过动态调整BC的权重,缓解performance drop。
2 Method
2.1 offline
离线阶段,采用BC结合确定性策略优化方法。最大化下列损失函数:
J o f f i n e ( θ ) = E ( s , a ) ∼ B [ ζ Q ϕ ( s , π θ ( s ) ) − ∥ π θ ( s ) − a ∥ 2 ] J_{\mathrm{offine}}(\boldsymbol{\theta})=\mathbb{E}_{(\boldsymbol{s},\boldsymbol{a})\sim\mathcal{B}}\left[\zeta Q_{\boldsymbol{\phi}}(\boldsymbol{s},\pi_{\boldsymbol{\theta}}(\boldsymbol{s}))-\left\|\pi_{\boldsymbol{\theta}}(\boldsymbol{s})-\boldsymbol{a}\right\|^2\right] Joffine(θ)=E(s,a)∼B[ζQϕ(s,πθ(s))−∥πθ(s)−a∥2]
其中, ζ \zeta ζ用于平衡BC以及一般policy iteration,其数值如下:
ζ = α 1 m ∑ ( s i , a i ) ∈ B ‾ ∣ Q ( s i , a i ) ∣ \zeta=\frac{\alpha}{\frac1m\sum_{(\boldsymbol{s}_i,\boldsymbol{a}_i)\in\overline{\mathcal{B}}}|Q(\boldsymbol{s}_i,\boldsymbol{a}_i)|} ζ=m1∑(si,ai)∈B∣Q(si,ai)∣α
其中 B ‾ \overline{\mathcal{B}} B表示从Buffer中采样地mini-batch, size为m
2.2 online
在线微调阶段,对确定性策略优化的损失函数表示如下
J o n l i n e ( θ ) = E ( s , a ) ∼ B [ ζ Q ϕ ( s , π θ ( s ) ) − λ ∥ π θ ( s ) − a ∥ 2 ] J_{\mathrm{online}}(\boldsymbol{\theta})=\mathbb{E}_{(\boldsymbol{s},\boldsymbol{a})\sim\mathcal{B}}\left[\zeta Q_{\boldsymbol{\phi}}(\boldsymbol{s},\pi_{\boldsymbol{\theta}}(\boldsymbol{s}))-\lambda\left\|\pi_{\boldsymbol{\theta}}(\boldsymbol{s})-\boldsymbol{a}\right\|^2\right] Jonline(θ)=E(s,a)∼B[ζQϕ(s,πθ(s))−λ∥πθ(s)−a∥2]
相较于offline,损失函数增加对BC权重因子 λ \lambda λ。该数值是动态减少的,实验设置为每5k steps, 减少10%。对Q价值的更新则是类似于TD3,使用两个target网络以及延时更新。
L ( ϕ ) = E ( s , a ) ∼ B [ ( y ˉ − Q ϕ ( s , a ) ) 2 ] where y ˉ = r + min i = 1 , 2 Q ϕ i ˉ ( s , ′ a ′ ∼ π θ ˉ ) . \begin{aligned}L(\phi)&=\mathbb{E}_{(\boldsymbol{s},\boldsymbol{a})\sim\mathcal{B}}\left[\left(\bar{y}-Q_{\boldsymbol{\phi}}(\boldsymbol{s},\boldsymbol{a})\right)^2\right]\\\\\text{where }\bar{y}&=r+\min_{i=1,2}Q_{\bar{\boldsymbol{\phi}_i}}(\boldsymbol{s},'\boldsymbol{a}'\sim\pi_{\bar{\boldsymbol{\theta}}}).\end{aligned} L(ϕ)where yˉ=E(s,a)∼B[(yˉ−Qϕ(s,a))2]=r+i=1,2minQϕiˉ(s,′a′∼πθˉ).
伪代码如下:
Summary
有个疑问,online阶段对策略进行更新时,采样的数据(s,a)是来自replaybuffer B \mathcal{B} B。 B \mathcal{B} B包含在线阶段真实交互数据以及离线数据。如果(s,a)是OOD或者质量差数据,那么此时BC项应该尽可能地不要发挥作用。简单的调整 λ \lambda λ恐怕效果不够。可以探索添在BC项再加一个指示函数自适应地判断,“异常数据”直接截断为0.