Stanford斯坦福 CS 224R: 深度强化学习 (5)

离线强化学习:第一部分

强化学习(RL)旨在让智能体通过与环境交互来学习最优策略,从而最大化累积奖励。传统的RL训练都是在线(online)进行的,即智能体在训练过程中不断与环境交互,实时生成新的状态-动作数据,并基于新数据来更新策略。这种在线学习虽然简单直观,但也存在一些局限性:

  • 在线交互的样本效率较低,许多采集到的数据未被充分利用
  • 对于一些高风险场景(如自动驾驶),在线探索可能会带来安全隐患
  • 一些领域积累了大量历史数据,在线学习无法直接利用这些数据

为了克服在线学习的局限,研究者们提出了离线强化学习(Offline RL)范式。本章我们将系统地介绍离线RL的动机、挑战与核心方法,探讨如何从静态数据集出发来训练最优策略。

1. 为什么需要离线强化学习?

传统RL算法都是在线学习的,其训练流程可以概括为:

  1. 用当前策略 π θ \pi_{\theta} πθ 采集一批数据 D \mathcal{D} D
  2. 基于新数据 D \mathcal{D} D 或累积数据 D 1 : t \mathcal{D}_{1:t} D1:t 来更新策略
  3. 重复步骤1-2,直到策略收敛

这里的数据采集可以是on-policy的(即直接用 π θ \pi_{\theta} πθ 采集),也可以是off-policy的(用其他策略如 ϵ \epsilon ϵ-greedy采集)。但无论哪种方式,智能体都要在训练过程中持续与环境交互。

离线RL则打破了这一限制,其目标是仅利用一个静态数据集 D = { ( s , a , r , s ′ ) i } i = 1 N \mathcal{D}=\{(s,a,r,s')_i\}_{i=1}^N D={(s,a,r,s)i}i=1N 来学习最优策略 π ∗ \pi^* π。形式化地,这个数据集可以看作从某个未知的行为策略(behavior policy) π β \pi_{\beta} πβ 采样得到:

s ∼ d π β ( ⋅ ) a ∼ π β ( ⋅ ∣ s ) s ′ ∼ P ( ⋅ ∣ s , a ) r = r ( s , a ) \begin{aligned} s \sim d^{\pi_{\beta}}(\cdot) \\ a \sim \pi_{\beta}(\cdot \mid s) \\ s' \sim \mathcal{P}(\cdot \mid s,a) \\ r = r(s,a) \end{aligned} sdπβ()aπβ(s)sP(s,a)r=r(s,a)

d π β d^{\pi_{\beta}} dπβ 表示 π β \pi_{\beta} πβ 诱导的状态分布。值得注意的是, π β \pi_{\beta} πβ 可以是一组策略的混合,即数据集 D \mathcal{D} D 可能来自多个行为策略。

因此,离线RL的优化目标可以写为:

max ⁡ θ ∑ t E s t ∼ d π θ ( ⋅ ) , a t ∼ π θ ( ⋅ ∣ s t ) [ r ( s t , a t ) ] \max_{\theta} \sum_{t} \mathbb{E}_{s_t \sim d^{\pi_{\theta}}(\cdot),a_t \sim \pi_{\theta}(\cdot \mid s_t)} [r(s_t,a_t)] θmaxtEstdπθ(),atπθ(st)[r(st,at)]

即最大化新策略 π θ \pi_{\theta} πθ 在其诱导的状态分布上的期望累积奖励。

相比在线RL,离线RL有如下优势:

  • 可以充分利用人工采集的、历史系统产生的离线数据,提高样本利用率
  • 避免了在线交互的安全风险,适用于高风险场景的策略学习
  • 可以复用之前项目积累的数据,减少重复采集的成本

当然,我们也可以将离线学习和在线学习相结合,用离线数据进行预训练,再用在线数据进行finetune。但本章我们将重点放在纯离线RL设定上。

2. 离线RL能否直接套用off-policy算法?

基于离线数据集学习策略,这听起来与off-policy RL很相似。那么我们是否可以直接将off-policy算法应用到离线场景中呢?

以经典的Q-learning为例,它的目标是最小化temporal difference (TD) error:

min ⁡ Q E ( s , a , r , s ′ ) ∼ D [ ( Q ( s , a ) − ( r + γ max ⁡ a ′ Q ( s ′ , a ′ ) ) ) 2 ] \min_{Q} \mathbb{E}_{(s,a,r,s') \sim \mathcal{D}} \left[ \left( Q(s,a) - \left( r + \gamma \max_{a'} Q(s',a') \right) \right)^2 \right] QminE(s,a,r,s)D[(Q(s,a)(r+γamaxQ(s,a)))2]

其中 Q ( s , a ) Q(s,a) Q(s,a) 为状态-动作值函数。学到Q函数后,策略可以通过 ϵ \epsilon ϵ-greedy或直接选取最大Q值的动作来生成。

然而,如果我们直接用一个静态数据集(如由一个次优策略采集)来训练Q-learning,会发生什么?

在训练初期Q网络随机初始化时,对于某个状态 s ′ s' s,网络输出的各个动作值 Q ( s ′ , ⋅ ) Q(s',\cdot) Q(s,) 可能与实际值相差较大。假设某个动作 a ′ a' a 在数据集 D \mathcal{D} D 中从未出现过,但Q网络给出了一个非常乐观的估计 Q ( s ′ , a ′ ) Q(s',a') Q(s,a)。在Q-learning更新时,这个乐观估计会通过 max ⁡ \max max 操作传播到 Q ( s , a ) Q(s,a) Q(s,a),导致 Q ( s , a ) Q(s,a) Q(s,a) 被高估。

随着训练的进行,Q值的高估会不断放大,最终导致离线训练策略 π θ \pi_{\theta} πθ 过度偏离行为策略 π β \pi_{\beta} πβ。这种偏离一方面源自对未见过动作的乐观估计,另一方面也反映了分布漂移问题: π θ \pi_{\theta} πθ 访问了 π β \pi_{\beta} πβ 未覆盖的状态空间。当 π θ \pi_{\theta} πθ 部署到实际环境中时,这些偏差会导致严重的性能降级,甚至灾难性后果。

因此,如何缓解离线RL中的Q值过估计,是我们需要重点解决的问题。这也是离线RL区别于off-policy RL的关键。接下来我们将介绍两大类缓解过估计的方法。

3. 数据约束方法

既然Q值过估计源于对未见过动作的错误泛化,一个很自然的想法是:能否约束 π θ \pi_{\theta} πθ 不要偏离 π β \pi_{\beta} πβ 太远?直观地说,如果新策略的动作分布与数据集的动作分布接近,就可以避免查询到未见过的(可能高估的)Q值。

形式化地,带行为约束的离线RL目标可以写为:

max ⁡ θ ∑ t E s t ∼ d π θ ( ⋅ ) , a t ∼ π θ ( ⋅ ∣ s t ) [ r ( s t , a t ) ] s.t. D ( π θ , π β ) ≤ ϵ \begin{aligned} \max_{\theta} & \sum_{t} \mathbb{E}_{s_t \sim d^{\pi_{\theta}}(\cdot),a_t \sim \pi_{\theta}(\cdot \mid s_t)} [r(s_t,a_t)] \\ \text{s.t.} & D(\pi_{\theta}, \pi_{\beta}) \leq \epsilon \end{aligned} θmaxs.t.tEstdπθ(),atπθ(st)[r(st,at)]D(πθ,πβ)ϵ

其中 D ( ⋅ , ⋅ ) D(\cdot,\cdot) D(,) 为某种分布差异度量, ϵ \epsilon ϵ 为预设的容忍度。这样一来,我们限制了 π θ \pi_{\theta} πθ π β \pi_{\beta} πβ 的偏离程度,避免了对OOD(out-of-distribution)动作的过度乐观估计。一些常用的分布约束形式包括:

  • 支撑集约束: π θ ( a ∣ s ) > 0 \pi_{\theta}(a \mid s) > 0 πθ(as)>0 当且仅当 π β ( a ∣ s ) ≥ δ \pi_{\beta}(a \mid s) \geq \delta πβ(as)δ。即新策略只能选择在数据集中出现过的动作。
  • KL散度约束: D K L ( π θ ∥ π β ) ≤ ϵ D_{KL}(\pi_{\theta} \| \pi_{\beta}) \leq \epsilon DKL(πθπβ)ϵ。即限制新旧策略的KL散度在 ϵ \epsilon ϵ 以内。

前者比较符合我们的直观要求,但在实践中难以准确实现。后者便于优化求解,但约束相对宽泛。

不过这里还有一个问题:上述约束中的 π β \pi_{\beta} πβ 往往是未知的,因为离线数据集可能来自多个行为策略的混合。为了实现约束,我们需要先从数据中学习一个 π β \pi_{\beta} πβ 的逼近 π ^ β \hat{\pi}_{\beta} π^β。最简单的做法是行为克隆(behavior cloning),即监督学习动作概率:

π ^ β = arg ⁡ max ⁡ π E s , a ∼ D [ log ⁡ π ( a ∣ s ) ] \hat{\pi}_{\beta} = \arg\max_{\pi} \mathbb{E}_{s,a \sim \mathcal{D}} [\log \pi(a \mid s)] π^β=argπmaxEs,aD[logπ(as)]

这等价于对 π β \pi_{\beta} πβ 的最大似然估计。学到 π ^ β \hat{\pi}_{\beta} π^β 后,我们就可以基于它来施加行为约束。那么如何在优化过程中高效实现这些约束呢?下面介绍两种常见做法。

第一种做法是直接修改策略优化目标,将约束项合并进去。以KL散度约束为例:

max ⁡ θ E s ∼ D , a ∼ π θ [ Q ( s , a ) ] − α D K L ( π θ ( ⋅ ∣ s ) ∥ π ^ β ( ⋅ ∣ s ) ) \max_{\theta} \mathbb{E}_{s \sim \mathcal{D},a \sim \pi_{\theta}} [Q(s,a)] - \alpha D_{KL}(\pi_{\theta}(· \mid s) \| \hat{\pi}_{\beta}(· \mid s)) θmaxEsD,aπθ[Q(s,a)]αDKL(πθ(s)π^β(s))

这里 α \alpha α 为拉格朗日乘子,控制约束的强度。这种无约束化的优化目标在实践中更易求解。对于高斯、分类等常见的策略分布族,KL散度往往有简洁的解析形式。

第二种做法是修改策略的奖励函数,引入一项鼓励与 π β \pi_{\beta} πβ 接近的附加奖励:

r ~ ( s , a ) = r ( s , a ) + α log ⁡ π ^ β ( a ∣ s ) \tilde{r}(s,a) = r(s,a) + \alpha \log \hat{\pi}_{\beta}(a \mid s) r~(s,a)=r(s,a)+αlogπ^β(as)

这相当于把约束的误差项分摊到每一步的奖励中,让策略不仅追求高累积奖励,也避免选择 π β \pi_{\beta} πβ 很少访问的动作。

事实上,这两种做法在数学上是等价的。感兴趣的读者可以参考[Wu et al., 2019]了解更多实现细节。

4. 保守估计方法

数据约束方法本质上是通过约束策略行为来规避Q值过估计。与之互补地,我们也可以直接针对值函数,在估计过程中鼓励保守、惩罚过于乐观的预测。这类方法的优点在于不需要显式地建模行为策略 π β \pi_{\beta} πβ

其中一个代表性算法是保守Q学习(Conservative Q-Learning, CQL)[Kumar et al., 2020]。回顾标准的Q学习目标:

min ⁡ Q E ( s , a , r , s ′ ) ∼ D [ ( Q ( s , a ) − ( r + γ E π θ [ Q ( s ′ , ⋅ ) ] ) ) 2 ] \min_{Q} \mathbb{E}_{(s,a,r,s') \sim \mathcal{D}} \left[ \left( Q(s,a) - \left( r + \gamma \mathbb{E}_{\pi_{\theta}}[Q(s',\cdot)] \right) \right)^2 \right] QminE(s,a,r,s)D[(Q(s,a)(r+γEπθ[Q(s,)]))2]

CQL在此基础上引入了一项保守正则项:

min ⁡ Q E ( s , a , r , s ′ ) ∼ D [ ( Q ( s , a ) − ( r + γ E π θ [ Q ( s ′ , ⋅ ) ] ) ) 2 ] + α E s ∼ D , a ∼ μ ( ⋅ ∣ s ) [ Q ( s , a ) ] − α E ( s , a ) ∼ D [ Q ( s , a ) ] \begin{aligned} \min_{Q} & \mathbb{E}_{(s,a,r,s') \sim \mathcal{D}} \left[ \left( Q(s,a) - \left( r + \gamma \mathbb{E}_{\pi_{\theta}}[Q(s',\cdot)] \right) \right)^2 \right] \\ & + \alpha \mathbb{E}_{s \sim \mathcal{D},a \sim \mu(\cdot \mid s)} [Q(s,a)] - \alpha \mathbb{E}_{(s,a) \sim \mathcal{D}} [Q(s,a)] \end{aligned} QminE(s,a,r,s)D[(Q(s,a)(r+γEπθ[Q(s,)]))2]+αEsD,aμ(s)[Q(s,a)]αE(s,a)D[Q(s,a)]

其中 μ ( ⋅ ∣ s ) \mu(\cdot \mid s) μ(s) 为任意一个先验行为分布(可学习), α \alpha α 为权重系数。

这个正则项由两部分组成:

  • E s ∼ D , a ∼ μ ( ⋅ ∣ s ) [ Q ( s , a ) ] \mathbb{E}_{s \sim \mathcal{D},a \sim \mu(\cdot \mid s)} [Q(s,a)] EsD,aμ(s)[Q(s,a)] 对所有状态-动作对的Q值求期望,通过最大化 μ \mu μ 来放大大的Q值
  • E ( s , a ) ∼ D [ Q ( s , a ) ] \mathbb{E}_{(s,a) \sim \mathcal{D}} [Q(s,a)] E(s,a)D[Q(s,a)] 只对数据集 D \mathcal{D} D 中observerd的状态-动作对求期望,作为对比

第一项相当于鼓励 μ \mu μ 去寻找预测值最大的(可能高估的)状态-动作对,而第二项则把这些对的预测值拉回到实际观测值。两项的差形成了一个惩罚,抑制了Q网络在数据集外对动作值的过高估计。

可以证明,当 α \alpha α 足够大时,用 L C Q L L_{CQL} LCQL 学到的Q函数 Q ^ π \hat{Q}^{\pi} Q^π 可以下界真实Q函数 Q π Q^{\pi} Qπ:

Q ^ π ( s , a ) ≤ Q π ( s , a ) , ∀ ( s , a ) \hat{Q}^{\pi}(s,a) \leq Q^{\pi}(s,a), \forall (s,a) Q^π(s,a)Qπ(s,a),(s,a)

因此CQL确保了对Q值的保守估计,避免了过度乐观。

根据 Q ^ π \hat{Q}^{\pi} Q^π,我们可以得到一个下界最优策略 π C Q L \pi_{CQL} πCQL:

π C Q L ( ⋅ ∣ s ) = arg ⁡ max ⁡ π E a ∼ π ( ⋅ ∣ s ) [ Q ^ π ( s , a ) ] \pi_{CQL}(\cdot \mid s) = \arg\max_{\pi} \mathbb{E}_{a \sim \pi(\cdot \mid s)} [\hat{Q}^{\pi}(s,a)] πCQL(s)=argπmaxEaπ(s)[Q^π(s,a)]

当动作空间离散时,这个问题的解为确定性策略:

π C Q L ( a ∣ s ) = { 1 if  a = arg ⁡ max ⁡ a ′ Q ^ π ( s , a ′ ) 0 otherwise \pi_{CQL}(a \mid s) = \begin{cases} 1 & \text{if } a = \arg\max_{a'} \hat{Q}^{\pi}(s,a') \\ 0 & \text{otherwise} \end{cases} πCQL(as)={10if a=argmaxaQ^π(s,a)otherwise

当动作空间连续时,我们可以通过梯度上升来逼近最优策略:

θ ← θ + η ∇ θ E s ∼ D , a ∼ π θ ( ⋅ ∣ s ) [ Q ^ π ( s , a ) ] \theta \leftarrow \theta + \eta \nabla_{\theta} \mathbb{E}_{s \sim \mathcal{D},a \sim \pi_{\theta}(\cdot \mid s)} [\hat{Q}^{\pi}(s,a)] θθ+ηθEsD,aπθ(s)[Q^π(s,a)]

其中 η \eta η 为学习率。

完整的CQL算法流程如下:

  1. L C Q L L_{CQL} LCQL 和数据集 D \mathcal{D} D 来训练Q网络 Q ^ π \hat{Q}^{\pi} Q^π
  2. 根据 Q ^ π \hat{Q}^{\pi} Q^π 来更新策略 π θ \pi_{\theta} πθ
  3. 重复1-2步骤,直到 Q ^ π \hat{Q}^{\pi} Q^π π θ \pi_{\theta} πθ 都收敛

实践中,我们还需要给先验分布 μ ( ⋅ ∣ s ) \mu(\cdot \mid s) μ(s) 赋予一个具体的形式。一个常见的选择是指数型分布族:

μ ( a ∣ s ) ∝ exp ⁡ ( Q ^ π ( s , a ) ) \mu(a \mid s) \propto \exp(\hat{Q}^{\pi}(s,a)) μ(as)exp(Q^π(s,a))

此时第一个正则项可以化简为:

E s ∼ D , a ∼ μ ( ⋅ ∣ s ) [ Q ( s , a ) ] = E s ∼ D [ log ⁡ ∑ a exp ⁡ ( Q ^ π ( s , a ) ) ] \mathbb{E}_{s \sim \mathcal{D},a \sim \mu(\cdot \mid s)} [Q(s,a)] = \mathbb{E}_{s \sim \mathcal{D}} [\log \sum_{a} \exp(\hat{Q}^{\pi}(s,a))] EsD,aμ(s)[Q(s,a)]=EsD[logaexp(Q^π(s,a))]

这样我们就无需显式地优化 μ \mu μ,而是通过对数求和运算来放大Q值的差异。

CQL从惩罚乐观估计的角度出发,巧妙地规避了分布漂移问题,在许多离线RL基准上取得了sota的性能。其优势在于:

  • 算法简单,易于实现,可以与任意Q学习变体结合
  • 超参数( α \alpha α)较少,调参负担小
  • 通过下界Q值来保证策略改进,而不需要显式地约束策略行为
  • 在低数据质量场景下表现稳健

当然,CQL也非万能。其主要局限包括:

  • 在复杂环境中,估计的Q值下界可能过于保守,限制了策略改进的空间
  • 即便估计了状态-动作对的真实Q值,在离线数据稀疏时策略的泛化性能仍然堪忧

未来还需要在这些方面进一步改进CQL框架。但它为离线值估计提供了一种新的视角,为后续算法的发展奠定了基础。

除了CQL,研究者还提出了其他一些惩罚乐观估计的方法。感兴趣的读者可以进一步参考:

  • Conservative Offline Policy Evaluation (CopE) [Voloshin et al., 2021]
  • Critic Regularized Regression (CRR) [Wang et al., 2020]
  • Advantage-Weighted Regression (AWR) [Peng et al., 2019]

5. 基于模型的离线RL

前面介绍的CQL通过正则化Q值来避免过估计。另一种思路是利用环境模型,用模型生成的虚拟数据来辅助值估计。

具体来说,我们先从离线数据集 D \mathcal{D} D 中学习一个环境模型 p ^ ( s ′ ∣ s , a ) \hat{p}(s' \mid s,a) p^(ss,a)。然后用 p ^ \hat{p} p^ 从真实数据中采样的状态出发,滚动生成多条虚拟轨迹 { ( s , a ) i } \{(s,a)_i\} {(s,a)i}。接下来把这些轨迹添加到数据集 D \mathcal{D} D 中,扩充训练集。最后把扩充后的数据集用于Q网络的训练。

这一过程可以总结为:

  • 模型训练: p ^ = arg ⁡ max ⁡ p E D [ log ⁡ p ( s ′ ∣ s , a ) ] \hat{p} = \arg\max_{p} \mathbb{E}_{\mathcal{D}} [\log p(s' \mid s,a)] p^=argmaxpED[logp(ss,a)]
  • 虚拟数据生成: D ~ = { ( s , a ) i ∣ s ∼ D , a ∼ π β ( ⋅ ∣ s ) , s ′ ∼ p ^ ( ⋅ ∣ s , a ) } \tilde{\mathcal{D}} = \{(s,a)_i \mid s \sim \mathcal{D}, a \sim \pi_{\beta}(\cdot \mid s), s' \sim \hat{p}(\cdot \mid s,a)\} D~={(s,a)isD,aπβ(s),sp^(s,a)}
  • 数据扩充: D a u g = D ∪ D ~ \mathcal{D}_{aug} = \mathcal{D} \cup \tilde{\mathcal{D}} Daug=DD~
  • Q值训练: 用 D a u g \mathcal{D}_{aug} Daug 训练Q网络

直观地说,由模型生成的虚拟数据覆盖了真实数据未覆盖的状态-动作空间。如果模型预测得足够准,就能帮助Q网络减少外推误差。此外,这些虚拟数据服从行为策略 π β \pi_{\beta} πβ 的分布,因此也起到了隐式约束 π θ \pi_{\theta} πθ 的作用。

当然,这一切的前提是学到一个准确可靠的环境模型 p ^ \hat{p} p^。一个差的模型反而会引入有偏的估计,误导Q网络的训练。因此基于模型的离线RL对模型质量要求很高,这在复杂环境中往往难以满足。同时即便有了一个完美模型,我们仍然面临探索不足的问题:如果某些重要状态在真实数据中没有被访问到,仅靠模型想象也无济于事。

尽管如此,将规划与学习相结合仍是一个有前景的方向。环境模型可以充分利用静态数据中蕴含的先验知识,加速值估计与策略搜索的过程。未来还需进一步探索更鲁棒的模型学习算法,以及模型不确定性评估与主动采样等技术。

6. 离线RL:一种数据驱动的范式

通过以上讨论,我们看到离线RL的核心挑战在于如何缓解利用静态数据训练时的过估计问题。这需要我们在策略约束和值估计两个层面进行针对性的优化。

数据约束方法从策略出发,通过模仿行为策略来规避分布漂移。而保守估计方法则从值函数出发,通过惩罚过高的外推值来纠偏。二者分别从横向和纵向阻断了Q值过估计的来源。我们还可以用模型生成的数据来加强估计,前提是要有一个高置信的环境模型。

值得一提的是,离线RL并非要完全取代传统的在线学习范式,而是作为一种互补的视角,从数据的角度重新审视强化学习问题。在许多在线交互受限的场景下,离线RL让我们能最大限度地利用先验知识,用更少的样本学到更好的策略。同时离线数据也为在线学习提供了一个有价值的预训练,大大缩短了实际部署中的探索成本。

展望未来,离线RL有望成为一种更为普适的范式,为RL在更广泛领域的应用扫清障碍。这需要我们在算法、工程、应用等多个层面协同发力:

  • 算法层面,需要进一步完善离线值估计、策略优化、模型学习等核心组件,提高其效率与鲁棒性。
  • 工程层面,需要开发成熟的离线RL代码库与工具链,简化算法实现与部署流程。
  • 应用层面,需要在更多领域积累有价值的离线数据集,验证算法性能,并反哺算法迭代。

总之,离线RL代表了一种数据驱动的智能优化思路。它凝结了统计学习和最优控制的精华,为AI在开放环境中的自主决策铺平了道路。让我们拭目以待这一领域的蓬勃发展,期待RL在人类社会的更多应用!

7. 代码实战

下面我们通过几个简单的例子来演示离线RL算法的实现。我们选择了OpenAI Gym中的Pendulum环境作为测试平台。该环境的状态空间为3维,动作空间为1维连续。我们的目标是控制摆球快速、平稳地悬停在竖直位置。

首先,我们定义一些辅助函数来采集离线数据。collect_demo使用专家策略(此处为PID控制器)来采集示教轨迹。collect_random则使用随机策略采集数据,模拟低质量的离线数据集。

import gym
import numpy as np
import torch
from typing import Dict, List, Tupledef collect_demo(env: gym.Env, buffer_size: int) -> List[Tuple[np.ndarray, float, np.ndarray, bool]]:"""Use PID controller to collect expert trajectories"""buffer = []state = env.reset()for _ in range(buffer_size):setpoint = np.array([1.0, 0.0, 0.0])pid = PIDController(setpoint, 0.5, 1.3, 0.1, 1.0, 0.1, 0.1)while True:action = pid(state)next_state, reward, done, _ = env.step(action)buffer.append((state, action, reward, next_state, done))state = next_stateif done:state = env.reset()breakreturn bufferdef collect_random(env: gym.Env, buffer_size: int) -> List[Tuple[np.ndarray, float, np.ndarray, bool]]:"""Collect trajectories using uniform random policy"""  buffer = []state = env.reset()for _ in range(buffer_size):action = env.action_space.sample()next_state, reward, done, _ = env.step(action)buffer.append((state, action, reward, next_state, done))state = next_stateif done:state = env.reset() return buffer

接下来,我们定义CQL算法的核心组件:Q网络和策略网络。Q网络使用一个3层MLP来拟合状态-动作值函数。策略网络也是一个3层MLP,输出动作的均值和方差,服从高斯分布。

class QNetwork(torch.nn.Module):def __init__(self, state_dim: int, action_dim: int, hidden_dim: int):super().__init__()self.layers = torch.nn.Sequential(torch.nn.Linear(state_dim + action_dim, hidden_dim),torch.nn.ReLU(),torch.nn.Linear(hidden_dim, hidden_dim),   torch.nn.ReLU(),torch.nn.Linear(hidden_dim, 1))def forward(self, state: torch.Tensor, action: torch.Tensor) -> torch.Tensor:x = torch.cat([state, action], dim=1)q_value = self.layers(x)return torch.squeeze(q_value, -1)class Policy(torch.nn.Module):def __init__(self, state_dim: int, action_dim: int, hidden_dim: int):super().__init__()self.layers = torch.nn.Sequential(torch.nn.Linear(state_dim, hidden_dim),torch.nn.ReLU(),torch.nn.Linear(hidden_dim, hidden_dim),torch.nn.ReLU()  )self.mean_layer = torch.nn.Linear(hidden_dim, action_dim)self.log_std_layer = torch.nn.Linear(hidden_dim, action_dim)def forward(self, state: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:x = self.layers(state)mean = self.mean_layer(x)log_std = self.log_std_layer(x)log_std = torch.clamp(log_std, -20, 2)  # avoid numerical issuesreturn mean, log_stddef sample_action(self, state: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:mean, log_std = self(state)std = torch.exp(log_std)  dist = torch.distributions.Normal(mean, std)action = dist.rsample()log_prob = dist.log_prob(action).sum(dim=-1)return action, log_prob

现在我们可以实现完整的CQL算法了。我们将离线数据集 D \mathcal{D} D 随机划分为训练集和验证集。在每个训练步中,从验证集采样一批状态作为 s ∼ D s \sim \mathcal{D} sD,根据当前Q网络计算正则项 E s ∼ D , a ∼ μ ( ⋅ ∣ s ) [ Q ( s , a ) ] \mathbb{E}_{s \sim \mathcal{D},a \sim \mu(\cdot \mid s)} [Q(s,a)] EsD,aμ(s)[Q(s,a)]。然后从训练集采样一批转移 ( s , a , r , s ′ ) (s,a,r,s') (s,a,r,s) 来计算TD误差。最终基于TD误差、正则项和在 D \mathcal{D} D 上的期望值 E ( s , a ) ∼ D [ Q ( s , a ) ] \mathbb{E}_{(s,a) \sim \mathcal{D}} [Q(s,a)] E(s,a)D[Q(s,a)] 来优化Q网络参数。策略网络则基于更新后的Q网络用 E s ∼ D , a ∼ π θ ( ⋅ ∣ s ) [ Q ^ π ( s , a ) ] \mathbb{E}_{s \sim \mathcal{D},a \sim \pi_{\theta}(\cdot \mid s)} [\hat{Q}^{\pi}(s,a)] EsD,aπθ(s)[Q^π(s,a)] 来优化。

def cql(env: gym.Env, buffer: List[Tuple[np.ndarray, float, np.ndarray, bool]],hidden_dim: int,learning_rate: float,alpha: float,gamma: float,tau: float,batch_size: int,num_epochs: int):"""Train soft actor-critic algorithm with CQL regularization.Args:env: gym training environmentbuffer: offline datasethidden_dim: hidden dimension of MLP  learning_rate: learning rate for Adam optimizeralpha: weight for CQL regularizergamma: discount factortau: target network update ratebatch_size: batch size for sampling num_epochs: number of training epochs"""state_dim = env.observation_space.shape[0]action_dim = env.action_space.shape[0]q1 = QNetwork(state_dim, action_dim, hidden_dim)q2 = QNetwork(state_dim, action_dim, hidden_dim)target_q1 = QNetwork(state_dim, action_dim, hidden_dim)target_q2 = QNetwork(state_dim, action_dim, hidden_dim)target_q1.load_state_dict(q1.state_dict())target_q2.load_state_dict(q2.state_dict())pi = Policy(state_dim, action_dim, hidden_dim)q1_optimizer = torch.optim.Adam(q1.parameters(), lr=learning_rate)q2_optimizer = torch.optim.Adam(q2.parameters(), lr=learning_rate)pi_optimizer = torch.optim.Adam(pi.parameters(), lr=learning_rate)# split buffer into train and val sets  val_size = int(0.2 * len(buffer))train_buffer, val_buffer = buffer[:-val_size], buffer[-val_size:] for epoch in range(num_epochs):# sample a batch of transitions from train set state_batch, action_batch, reward_batch, next_state_batch, done_batch = zip(*random.sample(train_buffer, batch_size))state_batch = torch.tensor(state_batch, dtype=torch.float32)action_batch = torch.tensor(action_batch, dtype=torch.float32).unsqueeze(1)  reward_batch = torch.tensor(reward_batch, dtype=torch.float32).unsqueeze(1)next_state_batch = torch.tensor(next_state_batch, dtype=torch.float32)done_batch = torch.tensor(done_batch, dtype=torch.float32).unsqueeze(1)# compute conservative q-valueswith torch.no_grad():next_action, next_log_pi = pi.sample_action(next_state_batch)q1_next = target_q1(next_state_batch, next_action)q2_next = target_q2(next_state_batch, next_action)min_q_next = torch.min(q1_next, q2_next) - alpha * next_log_piq_target = reward_batch + gamma * (1 - done_batch) * min_q_nextq1_pred = q1(state_batch, action_batch)  q1_loss = 0.5 * torch.mean((q1_pred - q_target)**2)q2_pred = q2(state_batch, action_batch)  q2_loss = 0.5 * torch.mean((q2_pred - q_target)**2)# compute CQL regularizercql_state_batch = torch.tensor(random.sample(val_buffer, batch_size), dtype=torch.float32) cql_actions = torch.linspace(-2.0, 2.0, steps=51).unsqueeze(0).repeat(batch_size, 1)  q1_values = q1(cql_state_batch, cql_actions)q2_values = q2(cql_state_batch, cql_actions)cql_values = torch.logsumexp(torch.cat([q1_values, q2_values], dim=1), dim=1)q1_data_values = q1(state_batch, action_batch)q2_data_values = q2(state_batch, action_batch)cql1_loss = torch.mean(cql_values - q1_data_values)cql2_loss = torch.mean(cql_values - q2_data_values)q1_loss += alpha * cql1_loss q2_loss += alpha * cql2_lossq1_optimizer.zero_grad()q1_loss.backward()q1_optimizer.step()q2_optimizer.zero_grad()q2_loss.backward()q2_optimizer.step()# update policypi_actions, pi_log_prob = pi.sample_action(state_batch)q1_pi = q1(state_batch, pi_actions)  q2_pi = q2(state_batch, pi_actions)min_q_pi = torch.min(q1_pi, q2_pi)pi_loss = torch.mean(alpha * pi_log_prob - min_q_pi)pi_optimizer.zero_grad()pi_loss.backward()pi_optimizer.step()# update target networksfor target_param, param in zip(target_q1.parameters(), q1.parameters()):target_param.data.copy_(tau * param + (1 - tau) * target_param)  for target_param, param in zip(target_q2.parameters(), q2.parameters()):target_param.data.copy_(tau * param + (1 - tau) * target_param)return pi  # return the trained policy

现在让我们测试一下CQL在Pendulum环境中的表现:

env = gym.make('Pendulum-v1')expert_buffer = collect_demo(env, buffer_size=5000)
random_buffer = collect_random(env, buffer_size=5000)# combine expert and random data to create the offline dataset
buffer = expert_buffer + random_buffercql_policy = cql(env, buffer, hidden_dim=256, learning_rate=3e-4, alpha=0.5, gamma=0.99, tau=0.005,batch_size=256, num_epochs=100)# evaluate the trained policy
for _ in range(10):state = env.reset()done = Falsetotal_reward = 0while not done:action, _ = cql_policy.sample_action(torch.from_numpy(state))next_state, reward, done, _ = env.step(action.detach().numpy())total_reward += rewardstate = next_stateprint(f'Total reward: {total_reward:.2f}')

运行这段代码,你会看到CQL学到的策略能以较高的累积奖励控制Pendulum。我们可以通过调节混合数据集中的专家数据比例,来考察CQL在不同离线数据质量下的表现。你还可以尝试修改网络结构、超参数,甚至环境本身,来深入理解CQL算法的特点。

8. 总结与展望

本章我们系统地介绍了离线强化学习的背景、挑战与代表性算法。与传统RL依赖在线交互不同,离线RL旨在从静态数据集中学习最优策略。这为RL在高风险、数据稀缺场景的应用扫清了障碍。

我们首先分析了Q值过估计这一离线RL的核心难题,即如何缓解策略优化时的分布漂移。针对这一问题,研究者分别从数据约束和保守估计两个角度提出了解决方案。前者通过模仿行为策略来限制策略搜索空间,后者则通过惩罚过高的Q值来纠正乐观估计。我们分别以BCQ和CQL为例,详细讲解了这两类方法的动机和实现。

此外,我们也探讨了基于模型的离线RL,即先从数据中学习一个环境模型,再用模型产生的虚拟轨迹来辅助Q函数训练。尽管对模型质量和探索效率提出了更高要求,但这一思路为离线数据的充分利用提供了新的视角。

纵观全章,我们强调了离线RL作为一种数据驱动范式的优势与潜力。通过从人类专家、历史系统等渠道收集知识,它让我们得以快速、低成本地构建高性能智能体。未来,随着离线RL算法的不断完善,以及应用领域的持续拓展,相信它必将在自动驾驶、智能家居、工业控制等诸多领域发挥重要作用。

与此同时,我们也要看到离线RL仍面临不少理论和实践挑战:

  • 在高维、稀疏数据上如何更高效地进行离线值估计与策略优化?
  • 如何评估和提高离线数据集的质量,尤其是在reward含噪、分布不平衡时?
  • 如何设计更鲁棒、更可解释的模型学习算法,降低泛化风险?
  • 如何权衡离线训练和在线微调,实现策略的安全部署与持续进化?

这需要RL、统计学、因果推断、泛化理论等多个领域的交叉创新。让我们携手探索,共同推动离线RL的发展,用人工智能点亮未来!

参考文献

  1. Levine S, Kumar A, Tucker G, et al. Offline reinforcement learning: Tutorial, review, and perspectives on open problems. arXiv preprint arXiv:2005.01643, 2020.

  2. Fujimoto S, Meger D, Precup D. Off-policy deep reinforcement learning without exploration. In International Conference on Machine Learning, 2019.

  3. Kumar A, Zhou A, Tucker G, et al. Conservative q-learning for offline reinforcement learning. arXiv preprint arXiv:2006.04779, 2020.

  4. Wu Y, Tucker G, Nachum O. Behavior regularized offline reinforcement learning. arXiv preprint arXiv:1911.11361, 2019.

  5. Kidambi R, Rajeswaran A, Netrapalli P, et al. Morel: Model-based offline reinforcement learning. arXiv preprint arXiv:2005.05951, 2020.

Q&A

我用一个简单的例子来说明离线强化学习的应用场景和优势。

假设我们要开发一个智能推荐系统,为用户推荐感兴趣的商品。传统的推荐算法主要基于用户的历史交互数据(如点击、购买记录)来挖掘用户兴趣,然后匹配相似商品。这种方法虽然简单直观,但也存在一些局限性:

  • 只利用了用户的历史数据,难以适应用户兴趣的变化
  • 无法主动探索用户对新品类、新特性的喜好
  • 容易陷入"马太效应",加剧推荐的同质化

如果我们把推荐问题建模为一个强化学习任务,就可以在一定程度上克服这些局限。具体来说,我们可以把:

  • 用户的个人信息和历史行为作为状态
  • 向用户推荐的商品作为动作
  • 用户的反馈(如点击、购买、评分)作为奖励

这样一来,RL智能体的目标就是学习一个推荐策略,使得累积奖励(用户满意度)最大化。通过与用户的长期交互,智能体可以逐步优化推荐策略,及时捕捉用户兴趣的变化,并主动探索新的可能性。

然而,在实际的推荐系统中,我们很难让智能体与真实用户频繁交互。一方面,过于频繁的推荐会影响用户体验;另一方面,将未经验证的策略直接应用于线上系统是有风险的,可能导致用户流失、收入下降等不良后果。

离线RL为这一问题提供了一个理想的解决方案。我们可以利用系统积累的海量历史交互数据,从中学习用户兴趣分布和行为模式,再利用这些先验知识指导推荐策略的离线训练。具体流程如下:

  1. 收集用户与推荐系统的历史交互日志,包括用户特征、推荐商品、用户反馈等
  2. 对日志数据进行预处理,提取状态、动作、奖励,并划分为训练集和测试集
  3. 用离线RL算法在训练集上学习推荐策略,并在测试集上评估性能
  4. 选择性能最优的策略在线上系统中小流量测试,并根据反馈进一步调优
  5. 逐步扩大优化后策略的应用范围,持续监测和改进

可以看到,离线RL让我们能在不影响线上系统的情况下,充分利用历史数据来优化推荐策略。它不仅降低了试错成本,也提高了策略迭代的效率。一些研究表明,基于离线RL的推荐算法能显著提升用户的点击率、转化率和满意度,为企业创造更多商业价值。

当然,要让离线RL在推荐系统中真正发挥作用,还需要解决一些问题:

  • 如何保证离线数据的覆盖性和质量,尽可能减少分布漂移?
  • 如何权衡探索和利用,在满足用户当前兴趣的同时,发掘用户的潜在需求?
  • 如何引入因果推断,学习推荐行为对用户反馈的因果影响,而不是简单的相关性?
  • 如何在离线评估和在线测试间建立有效映射,减少策略部署的风险?

这需要算法、工程、产品等多团队的通力合作。但离线RL为构建更加智能、高效、以用户为中心的推荐系统铺平了道路。相信未来它将与其他机器学习方法一起,为智能推荐的发展注入新的活力。

总结一下,离线强化学习的主要优势包括:

  • 充分利用历史数据,减少在线试错,降低探索成本
  • 避免对真实环境的干扰,提高优化效率和安全性
  • 自动学习复杂的用户行为模式,捕捉用户兴趣变化
  • 通过主动探索,发现新的推荐可能,提升用户满意度

因此,当我们拥有大量高质量的离线数据,且在线交互代价较高时,离线RL就是一个不错的选择。它为传统的监督学习和在线学习方法提供了有益补充,让我们能更从容、更智能地决策。

希望这个例子能帮你理解离线RL的应用场景和优势。我们只是以推荐系统为例,展示了它的一种可能性。事实上,离线RL还可以用于广告投放、搜索引擎、智能助理等多个领域。随着算法的进步和数据的积累,相信它必将在更多场景中大放异彩,为人类生活带来更多便利。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/15923.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Could not find Chrome This can occur if either】

爬虫练习中遇到的问题 使用puppeteer执行是提示一下错误 Error: Could not find Chrome (ver. 125.0.6422.78). This can occur if either you did not perform an installation before running the script (e.g. npx puppeteer browsers install chrome) oryour cache path…

CLIP 论文的关键内容

CLIP 论文整体架构 该论文总共有 48 页,除去最后的补充材料十页去掉,正文也还有三十多页,其中大部分篇幅都留给了实验和响应的一些分析。 从头开始的话,第一页就是摘要,接下来一页多是引言,接下来的两页就…

常用 CSS 写法

不是最后一个 :not(:last-child)渐变色 background: linear-gradient(270deg, #15aaff 0%, #02396a 100%);文字渐变色 background-image: linear-gradient(to right, #ff7e5f, #feb47b); -webkit-background-clip: text; background-clip: text; color: transparent;

python文件IO基础知识

目录 1.open函数打开文件 2.文件对象读写数据和关闭 3.文本文件和二进制文件的区别 4.编码和解码 读写文本文件时 读写二进制文件时 5.文件指针位置 6.文件缓存区与flush()方法 1.open函数打开文件 使用 open 函数创建一个文件对象,read 方法来读取数据&…

谈谈磁盘的那些操作

磁盘格式化 是指把一张空白的盘划分成一个个小区域并编号,以供计算机存储和读取数据。格式化是一种纯物理操作,是在磁盘的所有数据区上写零的操作过程,同时对硬盘介质做一致性检测,并且标记出不可读和坏的扇区。由于大部分硬盘在…

电子技术学习路线

在小破站上看到大佬李皆宁的技术路线分析,再结合自己这几年的工作。发现的确是这样,跟着大佬的技术路线去学习是会轻松很多,现在想想,这路线其实跟大学四年的学习顺序是很像的。 本期记录学习路线,方便日后查看。 传统…

python 深度图生成点云(方法二)

深度图生成点云 一、介绍1.1 概念1.2 思路1.3 函数讲解二、代码示例三、结果示例接上篇:深度图生成点云(方法1) 一、介绍 1.1 概念 深度图生成点云:根据深度图像(depth image)和相机内参(camera intrinsics)生成点云(PointCloud)。 1.2 思路 点云坐标的计算公式如…

pillow学习7

绘制验证码 from PIL import Image,ImageFilter,ImageFont,ImageDraw import random width100 hight100 imImage.new(RGB,(width,hight),(255,255,255)) drawImageDraw.Draw(im) #获取颜色 def get_color1():return (random.randint(200, 255), random.randint(200, 255), ran…

京东Java社招面试题真题,最新面试题

Java中接口与抽象类的区别是什么? 1、定义方式: 接口是完全抽象的,只能定义抽象方法和常量,不能有实现;而抽象类可以有抽象方法和具体实现的方法,也可以定义成员变量。 2、实现与继承: 一个类…

几种常用的配置文件格式对比分析——ini、json、xml、toml、yaml

配置文件用于存储软件程序的配置信息,以便程序能够根据这些信息进行自定义和调整。常用的配置文件格式包括INI、XML、JSON和YAML。下面对它们进行简单介绍,并分析各自的优缺点。 1. INI 文件格式 简介: INI(Initialization&…

FPGA之tcp/udp

在调试以太网的过程中,考虑了vivado IP配置(管脚、reset等),SDK中PHY芯片的配置(芯片地址、自适应速率配置等),但是,唯独忽略了tcp/udp协议,所以在ping通之后仍无法连接。 所以现在来学习一下tcp与udp的区别 ---- 为什…

经典面试题:进程、线程、协程开销问题,为什么进程切换的开销比线程的大?

上下文切换的过程? 上下文切换是操作系统在将CPU从一个进程切换到另一个进程时所执行的过程。它涉及保存当前执行进程的状态并加载下一个将要执行的进程的状态。下面是上下文切换的详细过程: 保存当前进程的上下文: 当操作系统决定切换到另…

浪潮信息IPF24:AI+时代,创新驱动未来,携手共创智慧新纪元

如今,数字化时代的浪潮席卷全球,人工智能已经成为推动社会进步的重要引擎。浪潮信息IPF24作为行业领先的AI技术盛会,不仅为业界提供了交流合作的平台,更在激发创新活力、拓展发展路径、加速AI技术落地等方面发挥了重要作用。 升级…

OS复习笔记ch6-2

死锁的解决 死锁的预防(打疫苗)死锁的避免(戴口罩)死锁的检测(做核酸) 死锁的预防 前面我们提到了死锁的四个必要条件 防止前三个必要条件,就是间接预防防止最后一个必要条件–循环等待&…

软测刷题-错题1

提高测试效率的方法: 1、不要做无效的测试 2.不要做重复的测试 3.不同测试版本的测试侧重点 4.优化测试顺序 LoadRunner是对服务器进行施压。 在数据库中存在的用户数是指注册用户数。 input标签可以直接使用send_keys实现上传,而非input标签是无法直…

Rust后台管理系统Salvo-admin源码编译

1.克隆salvo-admin后台管理系统源码: https://github.com/lyqgit/salvo-admin.git 2.编译 编译成功 3.创建mysql数据库与执行sql脚本 输入名称ry-vue 执行sql脚本 全部执行上面3个sql 修改数据库用户名与密码: 清理及重新编译 cargo clean cargo build 4.运行并测试 cargo…

Android内存碎片化调优

概念 内存碎片分为两种,一种是内存页中的碎片,被称为内部碎片;另一种是空闲分散的内存页,凑不齐一个组物理地址连续的空闲内存页,就没办法分配了,这些散落的内存页被称为外部碎片。 在Android系统中,内存碎片化是指内存中存在很多小块的空闲内存,这些内存块之间不连续…

使用vue,mybatis,mysql,tomcat,axios实现简单的登录注册功能

目录 第一步环境搭建 后端: 前端: 第二步画流程图 web: service: dao层: 第三步前端代码的实现 这是开始的页面,接下来我们要到router路由下书写#login的路径 路由中的component在我们自己创建的views书写vue文件…

单日收益1000+看了就会的项目,最新灵异短视频项目,简单好上手可放大操作

各位好友,佳哥在此与大伙儿聊聊一项神秘莫测的短视频项目。你或许会想,“又是一个视频创作项目?” 但别急,这个项目与众不同,日入千元不再是梦,而且它的易用性让人惊喜,无论你是初学者还是资深玩…

春秋云境CVE-2018-7422

简介 WordPress Plugin Site Editor LFI 正文 1.进入靶场 2.漏洞利用 /wp-content/plugins/site-editor/editor/extensions/pagebuilder/includes/ajax_shortcode_pattern.php?ajax_path/../../../../../../flag看别人wp做的。不懂怎么弄的,有没有大佬讲一下的