基于DeepSpeed Chat详解 PPO 算法中的actor_loss_fn及其核心参数

详解 PPO 算法中的 actor_loss_fn 及其核心参数


1. 引言

在强化学习中,PPO(Proximal Policy Optimization,近端策略优化)算法是一种经典且高效的策略优化方法。它通过重要性采样(Importance Sampling)和策略剪切(Clipping)来平衡策略改进的幅度模型训练的稳定性

这篇博客将重点解析 DeepSpeed Chat 的 actor_loss_fn 函数,代码链接: https://github.com/microsoft/DeepSpeedExamples/blob/master/applications/DeepSpeed-Chat/dschat/rlhf/ppo_trainer.py#L44,结合代码详细说明其作用,并通过一个具体的例子,说明模型生成的 response 是如何对应到 logprobsold_logprobsadvantagesmask 这几个参数的。


2. 代码解析

下面是 actor_loss_fn 函数的实现:

def actor_loss_fn(self, logprobs, old_logprobs, advantages, mask):# 计算 log 概率比log_ratio = (logprobs - old_logprobs) * maskratio = torch.exp(log_ratio)# 策略梯度损失(两种形式)pg_loss1 = -advantages * ratiopg_loss2 = -advantages * torch.clamp(ratio, 1.0 - self.cliprange, 1.0 + self.cliprange)# 选择损失中的最大值,并加权求和(masked)pg_loss = torch.sum(torch.max(pg_loss1, pg_loss2) * mask) / mask.sum()return pg_loss
2.1 参数说明
  • logprobs:当前策略(即 Actor 网络)在生成的 response 上计算得到的 log 概率。这是动态更新的,随着 Actor 策略的更新而变化。
  • old_logprobs:旧策略(之前保存的 Actor 策略)在相同 response 上计算得到的 log 概率。这是固定的,不随当前策略变化。
  • advantages:旧策略计算出的优势函数值,表示每个 token 的相对价值。这也是固定的。
  • mask:一个布尔型张量,用于标记有效的 token(即非填充部分)。例如,response 的部分 token 有效,其余部分被填充为无效。
2.2 核心逻辑
  1. 计算 Log 概率比
    log_ratio = ( logprobs − old_logprobs ) ⋅ mask \text{log\_ratio} = (\text{logprobs} - \text{old\_logprobs}) \cdot \text{mask} log_ratio=(logprobsold_logprobs)mask
    通过当前策略和旧策略的 log 概率差,计算重要性比率的对数。

  2. 将 Log 概率比转换为比例
    ratio = exp ⁡ ( log_ratio ) \text{ratio} = \exp(\text{log\_ratio}) ratio=exp(log_ratio)
    转换为实际的概率比。

  3. 定义两种策略梯度损失

    • 未剪切的损失
      pg_loss1 = − advantages ⋅ ratio \text{pg\_loss1} = -\text{advantages} \cdot \text{ratio} pg_loss1=advantagesratio
    • 剪切后的损失
      pg_loss2 = − advantages ⋅ clip ( ratio , 1 − ϵ , 1 + ϵ ) \text{pg\_loss2} = -\text{advantages} \cdot \text{clip}(\text{ratio}, 1-\epsilon, 1+\epsilon) pg_loss2=advantagesclip(ratio,1ϵ,1+ϵ)
      通过限制 ( ratio \text{ratio} ratio) 在 ( [ 1 − ϵ , 1 + ϵ ] [1-\epsilon, 1+\epsilon] [1ϵ,1+ϵ]) 范围内,避免策略更新幅度过大,防止训练不稳定。
  4. 选取最大损失
    对每个 token,取两种损失中的较大值,以确保更新过程不会偏离策略优化目标。

  5. 加权平均损失
    使用 mask 对损失加权,避免填充值对计算的干扰。


3. 参数与 Response 的对应关系

为了更好地理解 logprobsold_logprobsadvantagesmask,我们通过一个具体例子来说明:

假设输入
  • Prompt"Translate the following text into French: 'Hello, how are you?'"
  • Response(生成的句子):"Bonjour, comment ça va?"
具体参数计算过程
  1. logprobs
    当前策略对 response 每个 token 的生成概率:

    • logprobs 可能为:
      [-1.2, -0.8, -1.5, -0.6, -1.0]  # 对应 "Bonjour,", "comment", "ça", "va", "?"
      
    • 它是通过当前 Actor 模型的 softmax 输出取对数后得到的。
  2. old_logprobs
    旧策略对相同 response 的生成概率:

    • old_logprobs 可能为:
      [-1.0, -0.9, -1.4, -0.7, -0.9]  # 对应 "Bonjour,", "comment", "ça", "va", "?"
      
  3. advantages
    每个 token 的优势值,表示生成该 token 的相对价值:

    • 例如:
      [0.5, 0.8, 0.3, 0.6, 0.4]  # 对应 "Bonjour,", "comment", "ça", "va", "?"
      
    • 这些值通过旧策略与环境交互(或奖励模型)计算得到,表示每个 token 对整体策略的贡献。
  4. mask
    一个布尔型标记,有效部分为 1,填充部分为 0:

    • 例如,假设 response 中没有填充值:
      [1, 1, 1, 1, 1]
      

4. 为什么引入 logprobs 和策略剪切?

在 PPO 中,策略更新会参考旧策略,但如果更新幅度太大(即 ( ratio \text{ratio} ratio) 偏离 1 太远),可能会导致模型“训歪”:

  • 过度更新:策略过快改变,导致新策略偏离旧策略太远,破坏学习稳定性。
  • 欠更新:策略改变过小,无法显著优化目标函数。

通过引入 logprobs 和剪切机制(( clip \text{clip} clip)),我们可以控制每次更新的幅度,确保:

  • 更新不偏离旧策略太远。
  • 改进的同时保留策略的多样性和探索能力。

5. 实例代码

以下是基于 PyTorch 的实现,模拟计算 actor_loss

import torch# 模拟输入
logprobs = torch.tensor([-1.2, -0.8, -1.5, -0.6, -1.0])  # 当前策略的 log 概率
old_logprobs = torch.tensor([-1.0, -0.9, -1.4, -0.7, -0.9])  # 旧策略的 log 概率
advantages = torch.tensor([0.5, 0.8, 0.3, 0.6, 0.4])  # 每个 token 的优势值
mask = torch.tensor([1, 1, 1, 1, 1])  # 有效部分的掩码
cliprange = 0.2  # 剪切范围# 计算 log 概率比和比率
log_ratio = (logprobs - old_logprobs) * mask
ratio = torch.exp(log_ratio)# 策略梯度损失
pg_loss1 = -advantages * ratio
pg_loss2 = -advantages * torch.clamp(ratio, 1.0 - cliprange, 1.0 + cliprange)
pg_loss = torch.sum(torch.max(pg_loss1, pg_loss2) * mask) / mask.sum()# 输出结果
print("Log Ratios:", log_ratio)
print("Ratios:", ratio)
print("Policy Gradient Loss:", pg_loss.item())

6. 总结

通过这篇博客,我们详细介绍了 actor_loss_fn 的实现逻辑及其关键参数在 PPO 中的作用:

  • logprobsold_logprobs 是新旧策略的桥梁,通过重要性采样连接两者。
  • advantages 衡量生成 token 的相对价值。
  • mask 避免无效 token 对损失的影响。
  • 引入剪切机制确保训练稳定。

通过具体例子和代码演示,我们更直观地理解了这些参数与 PPO 的关系,以及它们在策略优化中的关键作用。

【1】pg_loss = torch.sum(torch.max(pg_loss1, pg_loss2) * mask) / mask.sum() 解释

要理解这句代码 pg_loss = torch.sum(torch.max(pg_loss1, pg_loss2) * mask) / mask.sum(),我们可以通过一个具体的例子,模拟实际计算过程,特别关注 mask 的作用和 mask.sum() 的意义。


代码结构解析

  1. pg_loss1pg_loss2

    • 两种策略梯度损失的计算:
      • pg_loss1 是未剪切的策略梯度损失。
      • pg_loss2 是剪切后的策略梯度损失(通过 torch.clamp 限制比率的范围)。
    • torch.max(pg_loss1, pg_loss2):对每个 token,取两种损失中的较大值。
  2. 乘以 mask

    • mask 是一个布尔型张量,用于标记哪些 token 是有效的,哪些是无效的。
    • 乘以 mask 的作用是:将无效的 token 对应的损失置为 0,避免它们对整体损失的计算产生干扰。
  3. torch.sum(... * mask)

    • 计算有效 token 的总损失,忽略无效 token。
  4. / mask.sum()

    • 归一化:用有效 token 的数量(mask.sum())对总损失进行归一化,确保损失是平均的,而不是直接被无效 token 的数量拉低。

实际数值模拟

假设我们有以下数据:

输入数据
  • pg_loss1:未剪切的损失

    pg_loss1 = [-0.5, -0.8, -0.3, -1.2, -0.6]  # 每个 token 的损失
    
  • pg_loss2:剪切后的损失

    pg_loss2 = [-0.4, -0.9, -0.5, -0.8, -0.7]
    
  • mask:有效 token 标记

    mask = [1, 1, 1, 0, 0]  # 前 3 个 token 有效,后两个 token 无效
    
计算步骤
  1. torch.max(pg_loss1, pg_loss2)
    对每个 token,取 pg_loss1pg_loss2 中的较大值(注意:负数值较小的绝对值是较大值)。

    max_loss = torch.max(pg_loss1, pg_loss2)= [-0.4, -0.8, -0.3, -0.8, -0.6]
    
  2. 乘以 mask
    通过与 mask 相乘,无效 token(mask = 0)对应的损失被置为 0:

    masked_loss = max_loss * mask= [-0.4, -0.8, -0.3, 0.0, 0.0]
    
  3. torch.sum(masked_loss)
    求有效 token 的总损失:

    total_loss = torch.sum(masked_loss)= -0.4 + (-0.8) + (-0.3) + 0.0 + 0.0= -1.5
    
  4. mask.sum()
    统计有效 token 的数量:

    mask_sum = mask.sum()= 1 + 1 + 1 + 0 + 0= 3
    
  5. 归一化损失
    将总损失除以有效 token 数量,计算平均损失:

    pg_loss = total_loss / mask_sum= -1.5 / 3= -0.5
    

乘以 mask 的作用

在这个例子中,mask 的作用是:

  1. 将无效 token(mask = 0)对应的损失设置为 0,防止这些无效 token 对总损失产生影响。
    • 如果没有乘以 mask,则 pg_loss 会错误地包含无效 token 的损失。
  2. 确保只有有效 token 的损失被纳入计算。

归一化 (/ mask.sum()) 的作用

在训练过程中,不同 batch 的有效 token 数量可能不一样。例如:

  • 一个 batch 的 mask 有效部分是 3 个 token。
  • 另一个 batch 的 mask 有效部分是 5 个 token。

如果不对总损失归一化,那么:

  • 有效 token 数量多的 batch 会对损失计算产生更大的权重。
  • 归一化后,每个 token 的贡献是平均的,确保训练的稳定性。

总结

通过这个例子,可以清楚地看到:

  1. torch.max(pg_loss1, pg_loss2):用于选择更大的损失,平衡未剪切和剪切更新之间的效果。
  2. 乘以 mask:屏蔽无效 token 的损失,专注于有效部分。
  3. 归一化 (/ mask.sum()):确保损失与有效 token 的数量无关,保持不同 batch 的训练一致性。

最终,这段代码实现了 PPO 策略更新时的核心逻辑:通过限制每次更新幅度,稳定地优化 Actor 策略,同时屏蔽无效 token 对损失的干扰。

【2】 为什么在 torch.max(pg_loss1, pg_loss2) 中使用 max 而不是 min

PPO(Proximal Policy Optimization)算法中的损失函数设计目标是 限制策略更新幅度,确保策略的稳定性和性能提升。这里的 max 并非简单地取大值,而是通过裁剪(clipping)机制实现 选择对策略有利的损失,同时约束更新幅度。


PPO 的核心损失函数回顾

PPO的目标函数为:
L CLIP ( θ ) = E [ min ⁡ ( r t ( θ ) A ^ t , clip ( r t ( θ ) , 1 − ϵ , 1 + ϵ ) A ^ t ) ] L^{\text{CLIP}}(\theta) = \mathbb{E} \left[ \min \left( r_t(\theta) \hat{A}_t, \, \text{clip}(r_t(\theta), 1 - \epsilon, 1 + \epsilon) \hat{A}_t \right) \right] LCLIP(θ)=E[min(rt(θ)A^t,clip(rt(θ),1ϵ,1+ϵ)A^t)]

其中:

  • ( r t ( θ ) = π θ ( a t ∣ s t ) π old ( a t ∣ s t ) r_t(\theta) = \frac{\pi_\theta(a_t | s_t)}{\pi_{\text{old}}(a_t | s_t)} rt(θ)=πold(atst)πθ(atst) ):新旧策略的比例。
  • ( A ^ t \hat{A}_t A^t ):优势值(Advantage)。
  • ( ϵ \epsilon ϵ ):裁剪范围(比如 ( 0.2 0.2 0.2 ))。
  • clip:将比例 ( r t r_t rt ) 裁剪到 ( [ 1 − ϵ , 1 + ϵ ] [1 - \epsilon, 1 + \epsilon] [1ϵ,1+ϵ] )。

PPO 的损失函数通过裁剪机制,限制策略更新幅度,防止模型发生 过大更新,从而训练不稳定。

在实现中,pg_loss1pg_loss2 分别为:

  • pg_loss1 = - advantages * ratio(未裁剪的损失)
  • pg_loss2 = - advantages * clipped_ratio(裁剪后的损失)

为什么使用 max 而不是 min

  1. 损失的符号
    PPO 使用的是 负的优势值(-Advantage) 来计算损失,因为优化目标是 最小化损失,而优势值的正负决定了动作的好坏。

    • 当 ( A ^ t > 0 \hat{A}_t > 0 A^t>0 ):表示动作优于基准,应该增加概率。
    • 当 ( A ^ t < 0 \hat{A}_t < 0 A^t<0 ):表示动作劣于基准,应该减少概率。
  2. max 的意义
    在两种情况下,PPO 的目标是:

    • 对于好的动作(优势 ( > 0 > 0 >0 )),我们允许更新比例 ( r t r_t rt ) 稍微增大,但不能超出裁剪范围。
    • 对于坏的动作(优势 ( < 0 < 0 <0 )),我们允许更新比例 ( r t r_t rt ) 稍微减小,但不能超出裁剪范围。

    通过 max,我们实现了以下效果:

    • 当策略偏离较少时,使用未裁剪的损失(pg_loss1)。
    • 当策略偏离较多时,使用裁剪后的损失(pg_loss2),限制更新幅度。
  3. 为什么不是 min
    如果使用 min,会选择较小的损失值,反而可能导致:

    • 对好的动作,更新不足(不增大概率)。
    • 对坏的动作,更新过度(减少概率过大)。
      这将违背 PPO 的设计目标,无法有效地限制策略更新幅度,导致训练不稳定。

数值示例分析

假设我们有以下输入数据:

  • 优势值(Advantage):[2.0, -1.0]
  • 比例(ratio)[1.1, 0.7]
  • 裁剪范围[0.8, 1.2]

根据 PPO 公式:

  1. 未裁剪的损失pg_loss1):
    pg_loss1 = − advantages × ratio \text{pg\_loss1} = - \text{advantages} \times \text{ratio} pg_loss1=advantages×ratio

  2. 裁剪后的损失pg_loss2):
    pg_loss2 = − advantages × clip ( ratio , 1 − ϵ , 1 + ϵ ) \text{pg\_loss2} = - \text{advantages} \times \text{clip}(\text{ratio}, 1 - \epsilon, 1 + \epsilon) pg_loss2=advantages×clip(ratio,1ϵ,1+ϵ)


代码实现

import torch# 输入数据
advantages = torch.tensor([2.0, -1.0])  # 优势值
ratio = torch.tensor([1.1, 0.7])  # 更新比例
cliprange = 0.2  # 裁剪范围# 计算未裁剪的损失
pg_loss1 = -advantages * ratio# 计算裁剪后的损失
clipped_ratio = torch.clamp(ratio, 1.0 - cliprange, 1.0 + cliprange)
pg_loss2 = -advantages * clipped_ratio# 取 max
final_loss = torch.max(pg_loss1, pg_loss2)print("pg_loss1 (未裁剪):", pg_loss1)
print("pg_loss2 (裁剪后):", pg_loss2)
print("Final Loss (max):", final_loss)

输出结果

pg_loss1 (未裁剪): tensor([-2.2000,  0.7000])
pg_loss2 (裁剪后): tensor([-2.2000,  0.8000])
Final Loss (max): tensor([-2.2000,  0.8000])

解释

  1. 第一个样本(优势 > 0)

    • pg_loss1 = -2.2(未裁剪)
    • pg_loss2 = -2.2(裁剪后的损失与未裁剪相等)
    • max,最终损失为 -2.2
  2. 第二个样本(优势 < 0)

    • pg_loss1 = 0.7(未裁剪)
    • pg_loss2 = 0.8(裁剪后的比例被限制为 0.8)
    • max,最终损失为 0.8,表示更新受到了裁剪限制。

结论

  • max(pg_loss1, pg_loss2) 的作用:确保在未裁剪损失和裁剪损失之间取较大值,限制策略的更新幅度,稳定训练过程。
  • 损失函数关系
    • 对于 优势 ( > 0 ),我们防止概率增大过多。
    • 对于 优势 ( < 0 ),我们防止概率减小过多。
  • 使用 min 会导致训练不稳定,无法有效地限制更新幅度,违反 PPO 的核心设计理念。

PPO 损失函数前面是否有负号取决于损失函数的最优化目标,即最大化策略性能还是最小化损失。


【3】为什么有的资料ppo的损失函数前面有负号,-min(),而有的材料前面没有负号?

1. 从优化目标的角度解释

  • 强化学习的目标:最大化策略的累积回报 ( R t R_t Rt )。
    由于优化器(例如 Adam)默认是最小化损失,我们需要对最大化目标取负数,从而转换为最小化损失问题。

因此:

  • 如果目标是最大化策略性能,则公式为:
    maximize  L CLIP ( θ ) \text{maximize } L^{\text{CLIP}}(\theta) maximize LCLIP(θ)
  • 为了与优化器配合(最小化),我们取负号:
    minimize  − L CLIP ( θ ) \text{minimize } -L^{\text{CLIP}}(\theta) minimize LCLIP(θ)

带负号的形式
actor_loss = − min ⁡ ( r t ⋅ Adv t , clip ( r t , 1 − ϵ , 1 + ϵ ) ⋅ Adv t ) \text{actor\_loss} = - \min \left( r_t \cdot \text{Adv}_t, \, \text{clip}(r_t, 1 - \epsilon, 1 + \epsilon) \cdot \text{Adv}_t \right) actor_loss=min(rtAdvt,clip(rt,1ϵ,1+ϵ)Advt)


2. 从损失函数的数值角度

  • 优势值 ( Adv t \text{Adv}_t Advt )
    • 当 ( Adv t > 0 \text{Adv}_t > 0 Advt>0 ):表示当前动作优于基准,应该增大概率。
    • 当 ( Adv t < 0 \text{Adv}_t < 0 Advt<0 ):表示当前动作劣于基准,应该减小概率。

如果前面有负号:

  • 最大化优势值:对于 ( Adv t > 0 \text{Adv}_t > 0 Advt>0 ),损失会减小(概率增大);对于 ( Adv t < 0 \text{Adv}_t < 0 Advt<0 ),损失会增大(概率减小)。

如果前面没有负号:

  • 直接是奖励优化形式,适合展示最大化回报的目标,但在实际代码中优化器不直接适配这种形式。

3. 两种形式的适用场景

  • 带负号的形式:通常出现在代码实现中,因为优化器默认执行最小化损失。
  • 不带负号的形式:出现在理论公式或论文中,强调最大化目标。

结论

带负号的形式是为了将“最大化回报”转化为优化器的最小化损失问题,而不带负号的形式仅仅表示理论上最大化优势目标。实际实现中,大多数代码会包含负号。


【4】把pg_loss1和2的负号去掉,最后pg_loss取min再加负号,这样可以吗?

下面是我问GPT的问题:

def actor_loss_fn(self, logprobs, old_logprobs, advantages, mask):## policy gradient losslog_ratio = (logprobs - old_logprobs) * maskratio = torch.exp(log_ratio)pg_loss1 = advantages * ratio pg_loss2 = advantages * torch.clamp(ratio, 1.0 - self.cliprange, 1.0 + self.cliprange)pg_loss = -torch.sum(torch.min(pg_loss1, pg_loss2) * mask) / mask.sum() return pg_loss

代码修改成这样可以吗?把pg_loss1和2的负号去掉,最后pg_loss取min,再加负号,这样可以吗?或者不可以吗?为什么

修改后的代码是可以的!在 PPO 的优化目标中,加负号min 的组合恰好符合 PPO 损失函数的设计。以下会详细解释为什么这样修改是合理的。


1. PPO 损失函数的原理

PPO 的目标是最大化期望回报,同时限制策略更新的幅度,以保证训练的稳定性。其目标函数为:
L PPO ( θ ) = E t [ min ⁡ ( r t ( θ ) ⋅ Adv t , clip ( r t ( θ ) , 1 − ϵ , 1 + ϵ ) ⋅ Adv t ) ] L^{\text{PPO}}(\theta) = \mathbb{E}_t \left[ \min\left(r_t(\theta) \cdot \text{Adv}_t, \, \text{clip}(r_t(\theta), 1 - \epsilon, 1 + \epsilon) \cdot \text{Adv}_t\right) \right] LPPO(θ)=Et[min(rt(θ)Advt,clip(rt(θ),1ϵ,1+ϵ)Advt)]
其中:

  • ( r t ( θ ) = π θ ( a t ∣ s t ) π old ( a t ∣ s t ) r_t(\theta) = \frac{\pi_\theta(a_t|s_t)}{\pi_{\text{old}}(a_t|s_t)} rt(θ)=πold(atst)πθ(atst) ) 是新旧策略的概率比。
  • ( Adv t \text{Adv}_t Advt ) 是优势函数,表示当前动作相对于平均水平的好坏。
  • 裁剪操作:通过 ( clip ( r t ( θ ) , 1 − ϵ , 1 + ϵ ) \text{clip}(r_t(\theta), 1 - \epsilon, 1 + \epsilon) clip(rt(θ),1ϵ,1+ϵ) ) 限制比率的变化幅度。

为了最大化 ( L PPO L^{\text{PPO}} LPPO ),通常将其转化为一个最小化问题,即在前面加上负号:
actor_loss = − E t [ min ⁡ ( r t ( θ ) ⋅ Adv t , clip ( r t ( θ ) , 1 − ϵ , 1 + ϵ ) ⋅ Adv t ) ] \text{actor\_loss} = -\mathbb{E}_t \left[ \min\left(r_t(\theta) \cdot \text{Adv}_t, \, \text{clip}(r_t(\theta), 1 - \epsilon, 1 + \epsilon) \cdot \text{Adv}_t\right) \right] actor_loss=Et[min(rt(θ)Advt,clip(rt(θ),1ϵ,1+ϵ)Advt)]


2. 代码分析

修改后的代码:

log_ratio = (logprobs - old_logprobs) * mask
ratio = torch.exp(log_ratio)
pg_loss1 = advantages * ratio
pg_loss2 = advantages * torch.clamp(ratio, 1.0 - self.cliprange, 1.0 + self.cliprange)
pg_loss = -torch.sum(torch.min(pg_loss1, pg_loss2) * mask) / mask.sum()
return pg_loss
主要修改点:
  1. 去掉了 pg_loss1pg_loss2 的负号。
  2. 在最终计算 pg_loss 时,加了负号。
  3. 使用 torch.min(pg_loss1, pg_loss2)
分析:
  1. 去掉 pg_loss1pg_loss2 的负号
    这里 pg_loss1pg_loss2 都表示未裁剪和裁剪后的损失,原始 PPO 损失的优化目标是最大化优势乘以比率。直接使用优势和比率(不加负号)可以保留目标的最大化形式。

  2. 加负号到 pg_loss 前面
    PPO 的目标是最大化 ( L PPO L^{\text{PPO}} LPPO ),但优化器通常是最小化损失函数,因此需要加负号将最大化目标转化为最小化问题。

  3. 使用 torch.min(pg_loss1, pg_loss2)

    • pg_loss1:未裁剪的情况 ( Adv t ⋅ r t ( θ ) \text{Adv}_t \cdot r_t(\theta) Advtrt(θ) )。
    • pg_loss2:裁剪后的情况 ( Adv t ⋅ clip ( r t ( θ ) , 1 − ϵ , 1 + ϵ ) \text{Adv}_t \cdot \text{clip}(r_t(\theta), 1 - \epsilon, 1 + \epsilon) Advtclip(rt(θ),1ϵ,1+ϵ) )。
    • min:选取两者中较小的值,目的是在不损害回报的情况下,限制策略更新幅度,从而保证训练的稳定性。

3. 为什么这样可以?

PPO 的损失函数最终形式是:
actor_loss = − E t [ min ⁡ ( r t ( θ ) ⋅ Adv t , clip ( r t ( θ ) , 1 − ϵ , 1 + ϵ ) ⋅ Adv t ) ] \text{actor\_loss} = -\mathbb{E}_t \left[ \min\left(r_t(\theta) \cdot \text{Adv}_t, \, \text{clip}(r_t(\theta), 1 - \epsilon, 1 + \epsilon) \cdot \text{Adv}_t\right) \right] actor_loss=Et[min(rt(θ)Advt,clip(rt(θ),1ϵ,1+ϵ)Advt)]

与修改后的代码一一对应:

  • pg_loss1 表示未裁剪的部分 ( Adv t ⋅ r t ( θ ) \text{Adv}_t \cdot r_t(\theta) Advtrt(θ) )。
  • pg_loss2 表示裁剪后的部分 ( Adv t ⋅ clip ( r t ( θ ) , 1 − ϵ , 1 + ϵ ) \text{Adv}_t \cdot \text{clip}(r_t(\theta), 1 - \epsilon, 1 + \epsilon) Advtclip(rt(θ),1ϵ,1+ϵ) )。
  • torch.min(pg_loss1, pg_loss2):实现了公式中的 min
  • -torch.sum(...) / mask.sum():加了负号,将最大化目标转化为最小化损失,并进行归一化处理。

这样写在数学意义和计算结果上都是正确的,符合 PPO 的原理。


4. 总结

  • 修改后的代码是合理的,符合 PPO 损失函数的设计。
  • 去掉 pg_loss1pg_loss2 的负号,并在最后的 pg_loss 上加负号,可以正确地实现最大化期望回报的目标。
  • 使用 torch.min 选择最保守的更新方向,保证 PPO 的稳定性。

这种写法与标准的 PPO 损失函数完全一致,功能正确且符合原理。

DeepSpeed-Chat 中的 PPO 实现,和你提出的代码虽然看起来形式有所不同,但两者最终达到的效果是一样的。从理论上讲,它们都是在实现 PPO 损失函数,差别在于写法的具体细节。

我们一步步来对比分析:


【5】那为什么DeepSpeed-Chat采用前者?在pg_loss1前面加负号?

1. DeepSpeed-Chat 中的代码

在 DeepSpeed-Chat PPO Trainer 中,损失函数是这样写的:

log_ratio = (logprobs - old_logprobs)
ratio = torch.exp(log_ratio)
pg_loss1 = -advantages * ratio
pg_loss2 = -advantages * torch.clamp(ratio, 1.0 - self.cliprange, 1.0 + self.cliprange)
pg_loss = torch.sum(torch.max(pg_loss1, pg_loss2) * mask) / mask.sum()
损失函数分析
  • -advantages * ratio:表示未裁剪的策略梯度损失(带负号是为了最小化损失)。
  • -advantages * torch.clamp(...):表示裁剪后的策略梯度损失(带负号)。
  • torch.max(pg_loss1, pg_loss2):选取两者中 最大 的值(注意是负数中的最大值),它实现了 PPO 的 min 损失 的效果。

2. 你修改的代码

你的修改代码为:

pg_loss1 = advantages * ratio
pg_loss2 = advantages * torch.clamp(ratio, 1.0 - self.cliprange, 1.0 + self.cliprange)
pg_loss = -torch.sum(torch.min(pg_loss1, pg_loss2) * mask) / mask.sum()
损失函数分析
  • advantages * ratioadvantages * torch.clamp(...):没有负号,表示直接的策略梯度目标。
  • torch.min(pg_loss1, pg_loss2):取两者中最小值。
  • 负号 - 加在最终损失上:将最大化目标转化为最小化。

3. 两者的数学等价性

从数学角度看:

  1. DeepSpeed 的写法 使用 torch.max(pg_loss1, pg_loss2)负的优势

    • ( pg_loss1 = − Adv t ⋅ r t ( θ ) \text{pg\_loss1} = - \text{Adv}_t \cdot r_t(\theta) pg_loss1=Advtrt(θ) )
    • ( pg_loss2 = − Adv t ⋅ clip ( r t ( θ ) , 1 − ϵ , 1 + ϵ ) \text{pg\_loss2} = - \text{Adv}_t \cdot \text{clip}(r_t(\theta), 1 - \epsilon, 1 + \epsilon) pg_loss2=Advtclip(rt(θ),1ϵ,1+ϵ) )
    • 取最大值:因为两者都是负数,取负数中的最大值,相当于取正数中的最小值(原始的 min)。
  2. 你的写法 直接取 min,然后在结果前加负号:

    • ( pg_loss1 = Adv t ⋅ r t ( θ ) \text{pg\_loss1} = \text{Adv}_t \cdot r_t(\theta) pg_loss1=Advtrt(θ) )
    • ( pg_loss2 = Adv t ⋅ clip ( r t ( θ ) , 1 − ϵ , 1 + ϵ ) \text{pg\_loss2} = \text{Adv}_t \cdot \text{clip}(r_t(\theta), 1 - \epsilon, 1 + \epsilon) pg_loss2=Advtclip(rt(θ),1ϵ,1+ϵ) )
    • 取最小值,再加负号,得到负的损失。

为什么 DeepSpeed 选择前者?

虽然两种写法是等价的,但 DeepSpeed-Chat 选择前者的原因主要有以下几点:

  1. 代码风格和一致性
    DeepSpeed 的代码风格更偏向于将负号提前应用到 pg_loss1pg_loss2 上。这样可以让损失计算的逻辑和大部分框架保持一致,方便阅读和维护。

  2. 梯度计算的数值稳定性
    带负号的损失(pg_loss1pg_loss2)在早期阶段就应用负号,可能减少后续计算过程中的误差传播,更加数值稳定。

  3. 易于对接优化器
    优化器通常最小化损失函数。如果损失的定义是负的优势乘以比例,这样的形式可以直接传入优化器,避免在外部再引入负号。

  4. 便于调试和日志记录
    在调试过程中,直接观察 pg_loss1pg_loss2 的值(已经是负值)会更加直观,可以更容易判断未裁剪和裁剪部分的差异。


总结

  • 两者数学上等价,因为取负号和取最小值/最大值的顺序不同,但本质上实现了同样的损失函数。
  • DeepSpeed 选择前者是出于代码风格、一致性、数值稳定性和易于调试的考量。
  • 如果你修改后的代码符合 PPO 的原理,且没有引入其他数值问题(如梯度异常),理论上是完全可行的。

对于实际框架而言,写法的选择更多是代码规范和工程设计的考虑,而不是数学上的根本区别。

后记

2024年12月14日12点36分于上海,在GPT4o大模型辅助下完成。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/63145.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

D3 基础1

D3 D3.js (Data-Driven Documents) 是一个基于 JavaScript 的库&#xff0c;用于生成动态、交互式数据可视化。它通过操作文档对象模型 (DOM) 来生成数据驱动的图形。官方网站是 https://d3js.org/ <!DOCTYPE html> <html lang"en"><head><me…

基线检查:Windows安全基线.【手动 || 自动】

基线定义 基线通常指配置和管理系统的详细描述&#xff0c;或者说是最低的安全要求&#xff0c;它包括服务和应用程序设置、操作系统组件的配置、权限和权利分配、管理规则等。 基线检查内容 主要包括账号配置安全、口令配置安全、授权配置、日志配置、IP通信配置等方面内容&…

Python -- Linux中的Matplotlib图中无法显示中文 (中文为方框)

目的 用matplotlib生成的图中文无法正常显示 方法 主要原因: 没找到字体 进入windows系统的C:\Windows\Fonts目录, 复制自己想要的字体 粘贴到Linux服务器中对应python文件所处的文件夹内 设置字体: 设置好字体文件的路径在需要对字体设置的地方设置字体 效果 中文正常显…

快速理解类的加载过程

当程序主动使用某个类时&#xff0c;如果该类还未加载到内存中&#xff0c;则系统会通过如下三个步骤来对该类进行初始化&#xff1a; 1.加载&#xff1a;将class文件字节码内容加载到内存中&#xff0c;并将这些静态数据转换成方法区的运行时数据结构&#xff0c;然后生成一个…

搭建 Elasticsearch 集群:完整教程

本文将详细介绍如何在 Linux 环境下搭建一个 Elasticsearch 集群&#xff0c;涵盖环境准备、配置优化、服务启动等多个环节。 一、环境准备 创建安装目录 mkdir /es cd /es解压 Elasticsearch 安装包 tar -xzf elasticsearch-7.10.1-linux-x86_64.tar.gz -C /es配置环境变量 编…

宝塔-docker拉取宝塔镜像,并运行宝塔镜像

宝塔-拉取宝塔镜像&#xff0c;并运行镜像 第1步&#xff1a;查询 docker search btpanel/baota此docker镜像由堡塔安全官方发布&#xff0c;镜像版本为宝塔面板9.2.0正式版和9.0.0_lts 稳定版&#xff0c;镜像会随着宝塔面板更新。 目前支持x86_64和arm架构可供下载使用 版本…

使用 Valgrind 检测 C 程序中的内存问题 -基础教程

内存泄漏是许多 C 语言程序中的常见问题&#xff0c;它不仅会导致程序性能下降&#xff0c;甚至可能让系统崩溃。为了检测和修复这些问题&#xff0c;Valgrind 是一个非常强大的工具&#xff0c;它可以帮助我们分析 C 程序中的内存使用情况&#xff0c;检测内存泄漏、越界访问、…

穷举vs暴搜vs深搜vs回溯vs剪枝专题一>子集

题目&#xff1a; 两个方法本质就是决策树的画法不同 方法一解析&#xff1a; 代码&#xff1a; class Solution {private List<List<Integer>> ret;//返回结果private List<Integer> path;//记录路径&#xff0c;注意返回现场public List<List<Int…

leecode双指针部分题目

leecode双指针部分题目 1. 验证回文串2. 判断子序列3. 两数之和 II - 输入有序数组4. 盛最多水的容器5. 三数之和 1. 验证回文串 如果在将所有大写字符转换为小写字符、并移除所有非字母数字字符之后&#xff0c;短语正着读和反着读都一样。则可以认为该短语是一个 回文串 。 …

Web 应用如何使用sqlite?使用 sql.js 实现前端 SQLite 数据库操作

前言 在 Web 应用开发中&#xff0c;前端数据处理的重要性日益增加。为了实现更高效的前端数据管理&#xff0c;特别是在处理结构化数据时&#xff0c;sql.js 提供了一个出色的解决方案。sql.js 是将 SQLite 数据库编译为 JavaScript 的库&#xff0c;允许开发者在浏览器环境中…

docker 安装 mysql8.0容器外无法连接

文章目录 概要问题描述解决方案其他命令 概要 主要是mysql5.7和mysql8.0的兼容性问题。 排查了很久 其实就是配置文件的一句话的事情 感觉mysql8.0更为严谨 这样可能是考虑杜绝一些漏洞吧 问题描述 在容器内 netstat -an | grep 3306 都不行 在容器外 netstat -an | grep 2…

TCP协议简单分析和握手挥手过程

TCP介绍 TCP是可靠的传输层协议&#xff0c;建立连接之前会经历3次握手的阶段。 确认机制&#xff1a;接受方 收到数据之后会向 发送方 回复ACK重传机制&#xff1a;发送方 在一定时间内没有收到 接收方的ACK就会重新发送 握手目的&#xff1a;与端口建立连接 TCP的三次握手 …

VisualStudio vsix插件自动加载

本文介绍如何在Visual Studio扩展中实现PackageRegistration&#xff0c;包括设置UseManagedResourcesOnly为true&#xff0c;允许背景加载&#xff0c;并针对C#、VB、F#项目提供自动装载&#xff0c;附官方文档链接。增加以下特性即可…… [PackageRegistration(UseManagedRe…

opencv所有常见函数

一、opencv图像操作 二、opencv图像的数值运算 三、opencv图像的放射变换 四、opencv空间域图像滤波 五、图像灰度化与直方图 六、形态学图像处理 七、阈值处理与边缘检测 八、轮廓和模式匹配

【Excel】单元格分列

目录 分列&#xff08;新手友好&#xff09; 1. 选中需要分列的单元格后&#xff0c;选择 【数据】选项卡下的【分列】功能。 2. 按照分列向导提示选择适合的分列方式。 3. 分好就是这个样子 智能分列&#xff08;进阶&#xff09; 高级分列 Tips&#xff1a; 新手推荐基…

【STM32练习】基于STM32的PM2.5环境监测系统

一.项目背景 最近为了完成老师交付的任务&#xff0c;遂重制了一下小项目用STM32做一个小型的环境监测系统。 项目整体示意框图如下&#xff1a; 二.器件选择 单片机&#xff08;STM32F103&#xff09;数字温湿度模块&#xff08;DHT11&#xff09;液晶显示模块&#xff08;0.8…

《开源数据:开启信息共享与创新的宝藏之门》

《开源数据&#xff1a;开启信息共享与创新的宝藏之门》 一、开源数据概述&#xff08;一&#xff09;开源数据的定义&#xff08;二&#xff09;开源数据的发展历程 二、开源数据的优势&#xff08;一&#xff09;成本效益优势&#xff08;二&#xff09;灵活性与可定制性&…

ReactPress最佳实践—搭建导航网站实战

Github项目地址&#xff1a;https://github.com/fecommunity/easy-blog 欢迎Star。 近期&#xff0c;阮一峰在科技爱好者周刊第 325 期中推荐了一款开源工具——ReactPress&#xff0c;ReactPress一个基于 Next.js 的博客和 CMS 系统&#xff0c;可查看 demo站点。&#xff08;…

2024,大模型杀进“决赛圈”

Henry Chesbrough在著作《通过技术创新盈利势在必行》中&#xff0c;曾提出过一个创新的“漏斗模型”。开放式创新一开始鼓励百花齐放&#xff0c;但最终只有10%的技术能够通过这个漏斗&#xff0c;成功抵达目标市场target market&#xff0c;进入到商业化与产业化的下一个阶段…

STM8单片机学习笔记·GPIO的片上外设寄存器

目录 前言 IC基本定义 三极管基础知识 单片机引脚电路作用 STM8GPIO工作模式 GPIO外设寄存器 寄存器含义用法 CR1&#xff1a;Control Register 1 CR2&#xff1a;Control Register 2 ODR&#xff1a;Output Data Register IDR&#xff1a;Input Data Register 赋值…