paper
基于HIM的离线RL算法,解决基于序列模型的离线强化学习算法缺乏对序列拼接能力。
Intro
文章提出了ContextFormer,旨在解决决策变换器(Decision Transformer, DT)在轨迹拼接(stitching)能力上的不足。轨迹拼接是离线RL中一个重要的能力,它允许算法通过组合次优的轨迹片段来获得更优的策略。ContextFormer通过集成基于上下文信息的模仿学习(Imitation Learning, IL)和序列建模,模仿有限数量专家轨迹的表示,来实现次优轨迹片段的拼接。实验结果表明,ContextFormer在多模仿学习设置下具有竞争力,并且在与其他DT变体的比较中表现出色。
两个定义
上述两个定义分别给出基于隐变量的条件序列模型建模方式,以及使用专家序列,通过度量经过embedding后的变量距离,使得待优化策略应满足靠近专家策略,远离次优轨迹策略。对于定义二有如下形式化的目标来优化上下文隐变量表征
J z ∗ = min z ∗ , I ϕ E τ ∗ ∼ π ∗ ( τ ) [ ∥ z ∗ − I ϕ ( τ ∗ ) ∥ ] − E τ ^ ∼ π ^ [ ∥ z ∗ − I ϕ ( τ ^ ) ∥ ] , \mathcal{J}_{\mathbf{z}^{*}}=\operatorname*{min}_{\mathbf{z}^{*},I_{\phi}}\mathbb{E}_{\tau^{*}\sim\pi^{*}(\tau)}[\|\mathbf{z}^{*}-I_{\phi}(\tau^{*})\|]\\-\mathbb{E}_{\hat{\tau}\sim\hat{\pi}}[\|\mathbf{z}^{*}-I_{\phi}(\hat{\tau})\|], Jz∗=z∗,IϕminEτ∗∼π∗(τ)[∥z∗−Iϕ(τ∗)∥]−Eτ^∼π^[∥z∗−Iϕ(τ^)∥],
Method
ContextFormer的训练过程包括两个关键模型:Hindsight Information Extractor I ϕ I_{\phi} Iϕ和Contextual Policy。Hindsight Information Extractor使用BERT作为编码器,并采用VQ-VAE(Vector Quantization Variational Autoencoder)损失来训练。Contextual Policy则是一个基于潜在条件的序列模型(DT),通过上下文信息作为目标来优化策略接近专家策略。
根据定义4.1建模序列模型以及 I ϕ I_{\phi} Iϕ,通过监督学习方式优化上下文策略 π z \pi_z πz以及HI extractor。
J π z , I ϕ = E τ ∼ ( π ∗ , π ^ ) [ ∥ π z ( ⋅ ∣ I ϕ ( τ ) , s 0 , a 0 , ⋯ , I ϕ ( τ ) , s t ) − a t ∥ ] , ( 4 ) \mathcal{J}_{\pi_{\mathbf{z}},I_{\phi}}=\mathbb{E}_{\tau\sim(\pi^{*},\hat{\pi})}[\|\pi_{\mathbf{z}}(\cdot|I_{\phi}(\tau),\mathbf{s}_{0},\mathbf{a}_{0},\cdots,I_{\phi}(\tau),\mathbf{s}_{t})-\mathbf{a}_{t}\|], (4) Jπz,Iϕ=Eτ∼(π∗,π^)[∥πz(⋅∣Iϕ(τ),s0,a0,⋯,Iϕ(τ),st)−at∥],(4)
其中 π ^ a n d π ∗ \hat{\pi}\mathrm{~and~}\pi^{*} π^ and π∗分别表示次优策略以及专家策略。同时,基于定义4.2对 I ϕ I_\phi Iϕ以及上下文embedding z ∗ z^* z∗进行优化。
J z ∗ , I ϕ = min z ∗ , I ϕ E τ ^ ∼ π ^ ( τ ) , τ ∗ ∼ π ∗ ( τ ) [ ∥ z ∗ − I ϕ ( τ ∗ ) ∥ − ∣ ∣ z ∗ − I ϕ ( τ ^ ) ∣ ∣ ] ( 5 ) \mathcal{J}_{\mathbf{z}^{*},I_{\phi}}=\min_{\mathbf{z}^{*},I_{\phi}}\mathbb{E}_{\hat{\tau}\sim\hat{\pi}(\tau),\tau^{*}\sim\pi^{*}(\tau)}[\|\mathbf{z}^{*}-I_{\phi}(\tau^{*})\|-||\mathbf{z}^{*}-I_{\phi}(\hat{\tau})||] (5) Jz∗,Iϕ=z∗,IϕminEτ^∼π^(τ),τ∗∼π∗(τ)[∥z∗−Iϕ(τ∗)∥−∣∣z∗−Iϕ(τ^)∣∣](5)
除此外,对于 I ϕ I_\phi Iϕ还需VQ-loss进行优化,三者联合构成了VQ-VAE的训练损失函数。
伪代码
(伪代码Training部分的第二步,VQ-loss应对应公式20)