全面解读PPO算法:结合DeepSpeed Chat实现分析
1. 什么是PPO?
Proximal Policy Optimization (PPO) 是一种基于策略梯度的强化学习方法,属于 Actor-Critic 框架的改进算法。它的目标是稳定地优化策略,避免策略更新过于激进,同时保持训练效率。
在 PPO 中,主要包含两个核心模块:
- Actor:负责学习策略 ( π θ ( a ∣ s ) \pi_\theta(a|s) πθ(a∣s)),即选择某一动作的概率分布。
- Critic:负责估计状态值 ( V ( s ) V(s) V(s)),为策略优化提供参考。
PPO 的设计核心是裁剪(Clipping)策略,确保策略更新幅度受控,从而提高训练的稳定性。
关于PPO的训练流程,可以参考笔者的另一篇博客:RLHF (PPO) 流程详解: Proximal Policy Optimization
2. PPO 的两个核心损失
PPO 的优化目标包括 Actor Loss(策略损失)和 Critic Loss(值函数损失)。我们将结合 DeepSpeed Chat 的实现,详细讲解这两部分的设计。
2.1 Actor Loss:策略裁剪目标
Actor 的目标是优化策略,使得它选择动作的概率与优势函数 ( A t A_t At) 成正比。优势函数表示当前动作的优劣程度,定义为:
A t = Q ( s t , a t ) − V ( s t ) A_t = Q(s_t, a_t) - V(s_t) At=Q(st,at)−V(st)
PPO实际上采用的是GAE版本的优势函数,请参考笔者的另一篇博客:深入解析强化学习中的 Generalized Advantage Estimation (GAE)
为了防止策略更新过快,PPO 对策略更新的目标函数加入裁剪限制:
r t ( θ ) = π θ ( a t ∣ s t ) π θ old ( a t ∣ s t ) r_t(\theta) = \frac{\pi_\theta(a_t|s_t)}{\pi_{\theta_{\text{old}}}(a_t|s_t)} rt(θ)=πθold(at∣st)πθ(at∣st)
损失函数为:
L actor ( θ ) = E t [ min ( r t ( θ ) ⋅ A t , clip ( r t ( θ ) , 1 − ϵ , 1 + ϵ ) ⋅ A t ) ] \mathcal{L}_{\text{actor}}(\theta) = \mathbb{E}_t \left[ \min \left( r_t(\theta) \cdot A_t, \text{clip}(r_t(\theta), 1-\epsilon, 1+\epsilon) \cdot A_t \right) \right] Lactor(θ)=Et[min(rt(θ)⋅At,clip(rt(θ),1−ϵ,1+ϵ)⋅At)]
其中:
- ( r t ( θ ) r_t(\theta) rt(θ)):新策略与旧策略概率的比值。
- ( ϵ \epsilon ϵ):裁剪范围,通常取 0.1 ~ 0.3。
- ( clip ( ) \text{clip}() clip()):将比值限制在 ( [ 1 − ϵ , 1 + ϵ ] [1-\epsilon, 1+\epsilon] [1−ϵ,1+ϵ]) 范围内。
PPO 的策略目标通过裁剪限制 ( r t ( θ ) r_t(\theta) rt(θ)),避免策略更新幅度过大。
2.2 Critic Loss:值函数裁剪目标
Critic 的目标是学习一个值函数 ( V ( s ) V(s) V(s)),使其接近实际的回报 ( R t R_t Rt)。通常使用均方误差 (MSE) 作为损失函数:
L critic ( ϕ ) = E t [ ( V ϕ ( s t ) − R t ) 2 ] \mathcal{L}_{\text{critic}}(\phi) = \mathbb{E}_t \left[ \left( V_\phi(s_t) - R_t \right)^2 \right] Lcritic(ϕ)=Et[(Vϕ(st)−Rt)2]
在 DeepSpeed Chat 的实现中,Critic Loss 引入了裁剪机制,限制值函数的更新幅度:
V ϕ ( s t ) clipped = clip ( V ϕ ( s t ) , V ϕ ( s t ) old − ϵ , V ϕ ( s t ) old + ϵ ) V_\phi(s_t)^{\text{clipped}} = \text{clip}(V_\phi(s_t), V_\phi(s_t)^{\text{old}} - \epsilon, V_\phi(s_t)^{\text{old}} + \epsilon) Vϕ(st)clipped=clip(Vϕ(st),Vϕ(st)old−ϵ,Vϕ(st)old+ϵ)
损失函数为:
L critic ( ϕ ) = 1 2 ⋅ E t [ max ( ( V ϕ ( s t ) − R t ) 2 , ( V ϕ ( s t ) clipped − R t ) 2 ) ] \mathcal{L}_{\text{critic}}(\phi) = \frac{1}{2} \cdot \mathbb{E}_t \left[ \max \left( \left( V_\phi(s_t) - R_t \right)^2, \left( V_\phi(s_t)^{\text{clipped}} - R_t \right)^2 \right) \right] Lcritic(ϕ)=21⋅Et[max((Vϕ(st)−Rt)2,(Vϕ(st)clipped−Rt)2)]
这种设计能够限制值函数的剧烈变化,从而提高训练稳定性。
3. PPO 的实现:结合 DeepSpeed Chat
在 DeepSpeed Chat 中,PPO 的实现集中在以下几个核心部分。
3.1 Actor 和 Critic 损失的计算
在 train_rlhf
方法中,分别计算 Actor Loss 和 Critic Loss:
注:这段代码只是模拟DeepSpeed Chat,具体实现请看源代码。但是思路是一致的,这里为的是方便理解,进行的简化。
def train_rlhf(self, exp_data):# 提取输入数据logprobs, old_logprobs = exp_data["logprobs"], exp_data["old_logprobs"]values, old_values = exp_data["values"], exp_data["old_values"]rewards, returns = exp_data["rewards"], exp_data["returns"]advantages = returns - values #优势仅仅是模拟,具体实现请看下文中的源代码解析# 计算 Actor 损失# 和DeepSpeed Chat稍有区别,下文的源代码解析中有讲到log_ratio = (logprobs - old_logprobs) * exp_data["mask"]ratio = torch.exp(log_ratio)actor_loss1 = advantages * ratioactor_loss2 = advantages * torch.clamp(ratio, 1.0 - self.cliprange, 1.0 + self.cliprange)actor_loss = -torch.sum(torch.min(actor_loss1, actor_loss2) * exp_data["mask"]) / exp_data["mask"].sum()# 计算 Critic 损失values_clipped = torch.clamp(values,old_values - self.cliprange_value,old_values + self.cliprange_value,)vf_loss1 = (values - returns) ** 2vf_loss2 = (values_clipped - returns) ** 2critic_loss = 0.5 * torch.sum(torch.max(vf_loss1, vf_loss2) * exp_data["mask"]) / exp_data["mask"].sum()return actor_loss, critic_loss
- Actor 损失:根据裁剪策略计算 ( L actor \mathcal{L}_{\text{actor}} Lactor)。
- Critic 损失:通过裁剪值函数计算 ( L critic \mathcal{L}_{\text{critic}} Lcritic)。
实际上,原仓库对它们进行了封装,分别计算两个loss。
下面是 https://github.com/microsoft/DeepSpeedExamples/blob/master/applications/DeepSpeed-Chat/dschat/rlhf/ppo_trainer.py封装的函数,具体解析请参考笔者的另一篇博客: 基于DeepSpeed Chat详解 PPO 算法中的actor_loss_fn及其核心参数
def actor_loss_fn(self, logprobs, old_logprobs, advantages, mask):## policy gradient losslog_ratio = (logprobs - old_logprobs) * maskratio = torch.exp(log_ratio)pg_loss1 = -advantages * ratiopg_loss2 = -advantages * torch.clamp(ratio, 1.0 - self.cliprange,1.0 + self.cliprange)pg_loss = torch.sum(torch.max(pg_loss1, pg_loss2) * mask) / mask.sum()return pg_lossdef critic_loss_fn(self, values, old_values, returns, mask):## value lossvalues_clipped = torch.clamp(values,old_values - self.cliprange_value,old_values + self.cliprange_value,)if self.compute_fp32_loss:values = values.float()values_clipped = values_clipped.float()vf_loss1 = (values - returns)**2vf_loss2 = (values_clipped - returns)**2vf_loss = 0.5 * torch.sum(torch.max(vf_loss1, vf_loss2) * mask) / mask.sum()return vf_loss
关于上文提到的优势计算,DeepSpeed Chat用到的是GAE版本的:
下面代码的解析请参考笔者的另一篇博客:深入理解 Generalized Advantage Estimation (GAE) 及其代码实现:以DeepSpeed-Chat中PPO算法使用为例
def get_advantages_and_returns(self, values, rewards, start):# Adopted from https://github.com/CarperAI/trlx/blob/main/trlx/models/modeling_ppo.py#L134lastgaelam = 0advantages_reversed = []length = rewards.size()[-1]for t in reversed(range(start, length)):nextvalues = values[:, t + 1] if t < length - 1 else 0.0delta = rewards[:, t] + self.gamma * nextvalues - values[:, t]lastgaelam = delta + self.gamma * self.lam * lastgaelamadvantages_reversed.append(lastgaelam)advantages = torch.stack(advantages_reversed[::-1], dim=1)returns = advantages + values[:, start:]return advantages.detach(), returns
3.2 主训练循环
在 main.py
中,PPO 的训练循环如下:
for ppo_ep in range(args.ppo_epochs):for i, (exp_data, unsup_data) in enumerate(zip(exp_dataset, unsup_dataset)):# 训练 Actor 和 Criticactor_loss, critic_loss = trainer.train_rlhf(exp_data)# 记录损失actor_loss_sum += actor_loss.item()critic_loss_sum += critic_loss.item()average_reward += exp_data["rewards"].mean()
exp_data
:经验数据,包含策略概率(logprobs)、状态值(values)、回报(returns)等。train_rlhf()
:调用训练函数,返回 Actor 和 Critic 的损失。actor_loss_sum
和critic_loss_sum
:累积损失用于记录训练进展。
3.3 熵的缺失
值得注意的是,DeepSpeed Chat 的实现没有显式地加入 熵正则项。熵正则项通常用于增加策略的随机性,从而提高探索性。其常见形式为:
L entropy = E t [ − ∑ a π ( a ∣ s t ) log π ( a ∣ s t ) ] \mathcal{L}_{\text{entropy}} = \mathbb{E}_t \left[ -\sum_a \pi(a|s_t) \log \pi(a|s_t) \right] Lentropy=Et[−a∑π(a∣st)logπ(a∣st)]
在没有熵正则项的情况下,模型可能更快收敛到局部最优策略,而缺乏足够的探索。
4. PPO 的总损失函数
PPO 的总损失为 Actor Loss 和 Critic Loss 的加权和,通常还包括熵正则项:
L total = L actor + c 1 ⋅ L critic − c 2 ⋅ L entropy \mathcal{L}_{\text{total}} = \mathcal{L}_{\text{actor}} + c_1 \cdot \mathcal{L}_{\text{critic}} - c_2 \cdot \mathcal{L}_{\text{entropy}} Ltotal=Lactor+c1⋅Lcritic−c2⋅Lentropy
- ( c 1 , c 2 c_1, c_2 c1,c2):权重超参数,用于平衡各项损失。
在 DeepSpeed Chat 的实现中,熵正则项未被显式加入,因此总损失实际上为:
L total = L actor + c 1 ⋅ L critic \mathcal{L}_{\text{total}} = \mathcal{L}_{\text{actor}} + c_1 \cdot \mathcal{L}_{\text{critic}} Ltotal=Lactor+c1⋅Lcritic
5. 总结与改进建议
5.1 DeepSpeed Chat 的 PPO 实现特点
- Actor 和 Critic 的裁剪机制:通过裁剪策略和值函数更新,保证了训练的稳定性。
- 简化实现:在损失函数中省略了熵正则项,从而简化了实现。
5.2 改进建议
为了增强模型的探索性,可以加入熵正则项,并将其权重 (c_2) 调整为适当的值。示例代码如下:
# 计算熵正则项
entropy = -torch.sum(exp_data["logprobs"] * torch.exp(exp_data["logprobs"]) * exp_data["mask"]) / exp_data["mask"].sum()# 总损失函数
total_loss = actor_loss + self.vf_coef * critic_loss - self.entropy_coef * entropy
加入熵正则项后,模型可以在探索和利用之间实现更好的平衡。
5.3 总结
PPO 是一种高效且稳定的强化学习算法,其 Actor 和 Critic 的裁剪机制是其核心设计。DeepSpeed Chat 的实现体现了 PPO 的简化设计,同时也为研究者提供了扩展的空间,例如加入熵正则项、调整损失权重等。
为什么 Critic Loss 使用 max
函数?
在 Proximal Policy Optimization (PPO) 算法中,Critic Loss 使用了裁剪机制来限制值函数的更新幅度,以防止训练过程中的不稳定。损失函数形式如下:
L critic ( ϕ ) = 1 2 ⋅ E t [ max ( ( V ϕ ( s t ) − R t ) 2 , ( V ϕ ( s t ) clipped − R t ) 2 ) ] \mathcal{L}_{\text{critic}}(\phi) = \frac{1}{2} \cdot \mathbb{E}_t \left[ \max \left( \left( V_\phi(s_t) - R_t \right)^2, \left( V_\phi(s_t)^{\text{clipped}} - R_t \right)^2 \right) \right] Lcritic(ϕ)=21⋅Et[max((Vϕ(st)−Rt)2,(Vϕ(st)clipped−Rt)2)]
其中:
- ( V ϕ ( s t ) V_\phi(s_t) Vϕ(st) ):当前值网络对状态 ( s t s_t st ) 的估计。
- ( R t R_t Rt ):目标回报(实际回报)。
- ( V ϕ ( s t ) clipped V_\phi(s_t)^{\text{clipped}} Vϕ(st)clipped ):裁剪后的值函数:
V ϕ ( s t ) clipped = clip ( V ϕ ( s t ) , V ϕ ( s t ) old − ϵ , V ϕ ( s t ) old + ϵ ) V_\phi(s_t)^{\text{clipped}} = \text{clip}\left( V_\phi(s_t), V_\phi(s_t)^{\text{old}} - \epsilon, V_\phi(s_t)^{\text{old}} + \epsilon \right) Vϕ(st)clipped=clip(Vϕ(st),Vϕ(st)old−ϵ,Vϕ(st)old+ϵ)
其中 ( V ϕ ( s t ) old V_\phi(s_t)^{\text{old}} Vϕ(st)old ) 是前一次的估计值,( ϵ \epsilon ϵ) 控制裁剪的范围。
为什么要用 max
函数?
max
函数的引入是为了在值函数更新时,限制过度估计带来的不稳定性,同时保证学习的效果。
- 第一项 ( ( V ϕ ( s t ) − R t ) 2 (V_\phi(s_t) - R_t)^2 (Vϕ(st)−Rt)2):表示当前值函数 ( V ϕ ( s t ) V_\phi(s_t) Vϕ(st) ) 和目标回报 ( R t R_t Rt ) 的误差。
- 第二项 ( ( V ϕ ( s t ) clipped − R t ) 2 (V_\phi(s_t)^{\text{clipped}} - R_t)^2 (Vϕ(st)clipped−Rt)2):表示裁剪后的值函数 ( V ϕ ( s t ) clipped V_\phi(s_t)^{\text{clipped}} Vϕ(st)clipped ) 和目标回报 ( R t R_t Rt ) 的误差。
通过 max
操作,PPO 选择误差较大的那一项来计算损失,确保以下两点:
- 稳定训练过程:当 ( V ϕ ( s t ) V_\phi(s_t) Vϕ(st) ) 变化过大时,裁剪机制 ( V ϕ ( s t ) clipped V_\phi(s_t)^{\text{clipped}} Vϕ(st)clipped ) 会限制更新幅度。
- 避免过度惩罚:即使值函数的更新被裁剪,也不会因为裁剪而导致损失函数过于偏离目标。
数值模拟解析
假设:
- ( V ϕ old = 10 V_\phi^{\text{old}} = 10 Vϕold=10 ):前一次值函数的估计值。
- ( ϵ = 2 \epsilon = 2 ϵ=2 ):裁剪范围。
- ( R t = 12 R_t = 12 Rt=12 ):目标回报。
我们分别计算以下三种情况:
- 当前值函数估计 ( V ϕ = 14 V_\phi = 14 Vϕ=14 )(过高估计)
- 当前值函数估计 ( V ϕ = 8 V_\phi = 8 Vϕ=8 )(过低估计)
- 当前值函数估计 ( V ϕ = 11 V_\phi = 11 Vϕ=11 )(合理范围内)
我们来计算每种情况下的 Critic Loss:
情况1:过高估计 ( V ϕ = 14 V_\phi = 14 Vϕ=14 )
-
裁剪前误差:
( V ϕ − R t ) 2 = ( 14 − 12 ) 2 = 4 (V_\phi - R_t)^2 = (14 - 12)^2 = 4 (Vϕ−Rt)2=(14−12)2=4 -
裁剪后的 ( V ϕ clipped V_\phi^{\text{clipped}} Vϕclipped ):
V ϕ clipped = clip ( 14 , 10 − 2 , 10 + 2 ) = 12 V_\phi^{\text{clipped}} = \text{clip}(14, 10 - 2, 10 + 2) = 12 Vϕclipped=clip(14,10−2,10+2)=12
裁剪后误差:
( V ϕ clipped − R t ) 2 = ( 12 − 12 ) 2 = 0 (V_\phi^{\text{clipped}} - R_t)^2 = (12 - 12)^2 = 0 (Vϕclipped−Rt)2=(12−12)2=0 -
Critic Loss:
L critic = 1 2 ⋅ max ( 4 , 0 ) = 1 2 ⋅ 4 = 2 \mathcal{L}_{\text{critic}} = \frac{1}{2} \cdot \max(4, 0) = \frac{1}{2} \cdot 4 = 2 Lcritic=21⋅max(4,0)=21⋅4=2
情况2:过低估计 ( V ϕ = 8 V_\phi = 8 Vϕ=8 )
-
裁剪前误差:
( V ϕ − R t ) 2 = ( 8 − 12 ) 2 = 16 (V_\phi - R_t)^2 = (8 - 12)^2 = 16 (Vϕ−Rt)2=(8−12)2=16 -
裁剪后的 ( V ϕ clipped V_\phi^{\text{clipped}} Vϕclipped ):
V ϕ clipped = clip ( 8 , 10 − 2 , 10 + 2 ) = 10 V_\phi^{\text{clipped}} = \text{clip}(8, 10 - 2, 10 + 2) = 10 Vϕclipped=clip(8,10−2,10+2)=10
裁剪后误差:
( V ϕ clipped − R t ) 2 = ( 10 − 12 ) 2 = 4 (V_\phi^{\text{clipped}} - R_t)^2 = (10 - 12)^2 = 4 (Vϕclipped−Rt)2=(10−12)2=4 -
Critic Loss:
L critic = 1 2 ⋅ max ( 16 , 4 ) = 1 2 ⋅ 16 = 8 \mathcal{L}_{\text{critic}} = \frac{1}{2} \cdot \max(16, 4) = \frac{1}{2} \cdot 16 = 8 Lcritic=21⋅max(16,4)=21⋅16=8
情况3:合理范围内 ( V ϕ = 11 V_\phi = 11 Vϕ=11 )
-
裁剪前误差:
( V ϕ − R t ) 2 = ( 11 − 12 ) 2 = 1 (V_\phi - R_t)^2 = (11 - 12)^2 = 1 (Vϕ−Rt)2=(11−12)2=1 -
裁剪后的 ( V ϕ clipped V_\phi^{\text{clipped}} Vϕclipped ):
V ϕ clipped = clip ( 11 , 10 − 2 , 10 + 2 ) = 11 V_\phi^{\text{clipped}} = \text{clip}(11, 10 - 2, 10 + 2) = 11 Vϕclipped=clip(11,10−2,10+2)=11
裁剪后误差:
( V ϕ clipped − R t ) 2 = ( 11 − 12 ) 2 = 1 (V_\phi^{\text{clipped}} - R_t)^2 = (11 - 12)^2 = 1 (Vϕclipped−Rt)2=(11−12)2=1 -
Critic Loss:
L critic = 1 2 ⋅ max ( 1 , 1 ) = 1 2 ⋅ 1 = 0.5 \mathcal{L}_{\text{critic}} = \frac{1}{2} \cdot \max(1, 1) = \frac{1}{2} \cdot 1 = 0.5 Lcritic=21⋅max(1,1)=21⋅1=0.5
总结
通过以上数值模拟,我们可以看到:
- 过高估计时:裁剪机制限制了 ( V ϕ V_\phi Vϕ ) 的更新幅度,Critic Loss 较小。
- 过低估计时:裁剪机制限制 ( V ϕ V_\phi Vϕ ) 过度下降,但仍允许一定程度的更新,Critic Loss 较大。
- 合理范围内时:Critic Loss 最小,表示值函数估计已经接近目标回报。
Critic Loss 使用 max
的意义
- 避免值函数更新过大(通过裁剪限制)。
- 同时保证训练仍然能够向正确的方向优化(选择较大误差)。
- 提高训练的稳定性,减少梯度爆炸或值函数震荡的风险。
这种设计是 PPO 算法稳定性的重要来源。
后记
2024年12月14日15点14分于上海,在GPT4o大模型辅助下完成。