InstructGPT的四阶段:预训练、有监督微调、奖励建模、强化学习涉及到的公式解读

1. 预训练

在这里插入图片描述

1. 语言建模目标函数(公式1):

L 1 ( U ) = ∑ i log ⁡ P ( u i ∣ u i − k , … , u i − 1 ; Θ ) L_1(\mathcal{U}) = \sum_{i} \log P(u_i \mid u_{i-k}, \dots, u_{i-1}; \Theta) L1(U)=ilogP(uiuik,,ui1;Θ)

  • 解释
    • U = { u 1 , u 2 , … , u n } \mathcal{U} = \{u_1, u_2, \dots, u_n\} U={u1,u2,,un}是输入的未标注语料(token序列)。
    • k k k是上下文窗口的大小,即预测当前词 u i u_i ui时,使用前 k k k个词( u i − k , … , u i − 1 u_{i-k}, \dots, u_{i-1} uik,,ui1)作为上下文。
    • P ( u i ∣ u i − k , … , u i − 1 ; Θ ) P(u_i \mid u_{i-k}, \dots, u_{i-1}; \Theta) P(uiuik,,ui1;Θ) 是模型根据前 k k k个词预测 u i u_i ui的条件概率,其中参数 Θ \Theta Θ是通过训练得到的神经网络参数。
    • 通过最大化对数似然,模型被训练以最小化预测和真实词之间的差距,这个过程通常通过随机梯度下降(SGD)进行。

2. Transformer解码器结构(公式2):

公式2描述了模型的架构,采用了多层的Transformer解码器。Transformer通过自注意力机制来捕捉上下文依赖关系,并对输入序列进行编码。

初始嵌入层:

h 0 = U W e + W p h_0 = U W_e + W_p h0=UWe+Wp

  • 解释
    • U = ( u − k , … , u − 1 ) U = (u_{-k}, \dots, u_{-1}) U=(uk,,u1) 是上下文窗口中输入序列的词向量。
    • W e W_e We是词嵌入矩阵,用于将输入的token转换为词向量。
    • W p W_p Wp是位置嵌入矩阵,提供每个token的位置编码,用于捕捉词序信息。
Transformer块:

h l = transformer_block ( h l − 1 ) ∀ i ∈ [ 1 , n ] h_l = \text{transformer\_block}(h_{l-1}) \quad \forall i \in [1, n] hl=transformer_block(hl1)i[1,n]

  • 解释
    • h l h_l hl表示第 l l l层的Transformer输出。
    • 每一层 h l h_l hl是通过前一层 h l − 1 h_{l-1} hl1经过Transformer块(自注意力和前馈网络)的处理得到的。
    • 共有 n n n层,每一层都通过类似的操作进行。
输出层:

P ( u ) = softmax ( h n W e T ) P(u) = \text{softmax}(h_n W_e^T) P(u)=softmax(hnWeT)

  • 解释
    • h n h_n hn是最后一层的输出,经过词嵌入矩阵的转置 W e T W_e^T WeT变换后,再通过Softmax函数计算每个词的概率分布。
    • 这个概率分布用于预测输出的目标词 u u u,Softmax确保输出的各个词的概率和为1。

总结:

  • 该模型采用无监督的方式进行预训练,利用大规模未标注语料数据,通过最大化词序列的条件概率来训练语言模型。
  • 预训练的模型架构基于Transformer解码器,通过多层自注意力机制和位置编码来有效捕捉上下文信息,并使用Softmax输出目标词的概率分布。

2. 有监督微调

在这里插入图片描述

1. 微调任务的目标函数(公式3):

P ( y ∣ x 1 , … , x m ) = softmax ( h l m W y ) P(y \mid x^1, \dots, x^m) = \text{softmax}(h_l^m W_y) P(yx1,,xm)=softmax(hlmWy)

  • 解释
    • x 1 , … , x m x^1, \dots, x^m x1,,xm 是输入的token序列。
    • h l m h_l^m hlm表示输入序列经过预训练模型(如Transformer)的最后一层输出的激活值(即特征表示)。
    • W y W_y Wy是用于预测目标标签 y y y的线性层的参数矩阵。
    • Softmax 函数将线性层的输出转换为每个类的概率分布,用于分类任务中的标签预测。

2. 最大化目标函数(公式4):

L 2 ( C ) = ∑ ( x , y ) log ⁡ P ( y ∣ x 1 , … , x m ) L_2(C) = \sum_{(x, y)} \log P(y \mid x^1, \dots, x^m) L2(C)=(x,y)logP(yx1,,xm)

  • 解释
    • 这是监督学习的目标函数,模型通过最大化预测标签 y y y 的对数概率来微调模型参数。
    • C C C是标注数据集,包含输入序列 x x x和相应的标签 y y y
    • 目标是最大化所有样本的对数似然,确保模型在监督任务中的准确性。

3. 辅助目标函数(公式5):

L 3 ( C ) = L 2 ( C ) + λ ∗ L 1 ( C ) L_3(C) = L_2(C) + \lambda \ast L_1(C) L3(C)=L2(C)+λL1(C)

  • 解释
    • 为了提高监督学习的效果,模型还结合了语言建模的辅助目标,即无监督的语言建模损失( L 1 L_1 L1 )和监督任务的损失( L 2 L_2 L2)相结合。
    • λ \lambda λ 是用于平衡两个目标的权重参数。
    • 这样做的好处是:可以通过语言模型的任务帮助监督任务更好地泛化,同时加快收敛速度。这种辅助目标在之前的研究中已证明可以有效提高性能。

总结:

  • 监督微调阶段,模型利用预训练好的参数,结合带标签的数据来优化预测性能。
  • 通过最大化预测标签 y y y的对数概率,模型适应特定任务。
  • 引入语言建模作为辅助任务,有助于提升模型的泛化能力和训练效率。

3. 奖励建模

在这里插入图片描述

这个段落介绍了 奖励模型(Reward Modeling, RM) 在 InstructGPT 模型中的训练方式,具体描述了模型如何从 监督微调模型(SFT) 继续优化以输出奖励分数(reward score),并通过 比较(comparison)来训练这个奖励模型。

核心内容解释:

  1. 奖励模型的基础

    • 奖励模型的训练是基于监督微调模型(SFT)进一步改进的。为了输出奖励,SFT 模型的最后一层(unembedding layer)被移除,留下的是可以根据给定的提示(prompt)和回应(response)输出一个标量的奖励值。
    • 为了节省计算资源,他们使用了一个 6B 参数的奖励模型(RM),因为较大规模的 175B 参数奖励模型虽然理论上可能更准确,但实际训练时表现不稳定,且不适合作为 RL 的值函数。
  2. Stiennon et al. (2020) 的方法

    • 数据集:奖励模型通过比较训练,数据集包含两个模型生成的输出(response),并根据这些生成的结果做出比较。
    • 损失函数:训练中使用了 交叉熵损失(cross-entropy loss),这些比较的标签是人类标注员根据两者的优劣给出的。交叉熵损失衡量了两个结果中哪一个应该被人类标注员优先选择,实际优化的是两者奖励分数的对数几率(log odds)。
  3. 加速比较收集过程

    • 在标注过程中,给标注员展示了 K = 4 到 9 个不同的生成回应,标注员需要对这些生成的结果进行排名。这会生成 ( K 2 ) \binom{K}{2} (2K)对比较,即每个标注任务中有 K 个回应时,将产生 K 中选取 2 个的组合数的比较对数。比如 K = 9 K = 9 K=9 时,会生成 ( 9 2 ) = 36 \binom{9}{2} = 36 (29)=36 对比较。
    • 为了避免过度拟合,研究者决定不将所有比较对一起训练(因为不同的比较对之间存在强相关性),而是仅从每个提示中抽取一对比较结果作为一个训练样本。这种方式更加 计算高效,因为对于每个完成的任务只需要一次前向传播(forward pass),而不是处理所有 ( K 2 ) \binom{K}{2} (2K) 的比较对。
  4. 奖励模型的损失函数

    奖励模型的损失函数定义如下:
    loss ( θ ) = − 1 ( K 2 ) E ( x , y w , y l ) ∼ D [ log ⁡ ( σ ( r θ ( x , y w ) − r θ ( x , y l ) ) ) ] \text{loss}(\theta) = -\frac{1}{\binom{K}{2}} \mathbb{E}_{(x, y_w, y_l) \sim D} \left[ \log \left( \sigma \left( r_\theta(x, y_w) - r_\theta(x, y_l) \right) \right) \right] loss(θ)=(2K)1E(x,yw,yl)D[log(σ(rθ(x,yw)rθ(x,yl)))]

    • σ ( ⋅ ) \sigma(\cdot) σ()sigmoid 函数,用于将差值映射到 [0, 1] 区间,用来表示某一结果被标注员认为更优的概率。
    • r θ ( x , y ) r_\theta(x, y) rθ(x,y) 是奖励模型对于给定提示 x x x和生成的回应 y y y所输出的奖励分数。
    • y w y_w yw y l y_l yl分别是优胜和劣胜的生成结果。
    • D D D是人类标注员比较的结果数据集。
    • 1 ( K 2 ) \frac{1}{\binom{K}{2}} (2K)1 是对所有比较的标准化处理。

解释

  • 损失函数的目标是最大化优胜结果 y w y_w yw 比劣胜结果 y l y_l yl更受偏好的概率(通过两者奖励分数差异的 sigmoid 值来实现)。这个函数通过最小化损失来优化奖励模型,使得奖励模型能够更准确地给出与人类标注员偏好一致的分数。
  1. 防止过拟合的解释(脚注 5)
    • 如果每一个比较对都被视为一个单独的数据点,那么每个生成的回应可能在训练中会得到 K − 1 K-1 K1 次更新,从而导致模型过拟合。而研究人员发现,模型过度训练甚至只需一个 epoch 就会过拟合。为了解决这个问题,他们只对每个提示下的一对回应进行一次前向传播训练,从而避免过拟合。

总结:

  • 奖励模型的训练基于人类反馈,通过比较两个模型生成的回应来进行优化。该训练过程使用了 交叉熵损失函数,优化目标是让奖励模型尽可能地预测出哪个回应更符合人类标注员的偏好。
  • 通过只选取部分比较对进行训练(而不是所有组合对),减少了计算开销,并有效避免了模型过拟合。

4. 强化学习(PPO)

在这里插入图片描述

1. 强化学习(Reinforcement Learning, RL)

在这部分,模型使用了强化学习(RL)进行微调,采用了 PPO(Proximal Policy Optimization) 算法来优化策略。PPO 是一种策略梯度算法,常用于强化学习任务中,通过限制策略更新的步长来提高训练的稳定性。

2. 环境设置

强化学习的环境被设置为一个 多臂老虎机问题(bandit environment),该环境会随机给定提示(prompt),模型需要生成相应的回应。生成的回应会通过奖励模型(reward model)来打分并结束当前回合。

为了防止模型过度优化奖励模型,训练过程中在每个 token 的输出时,加入了 KL 惩罚项,这项惩罚的来源是监督微调模型(SFT, Supervised Fine-Tuned Model)。

3. PPO-ptx 模型

为了提高模型的泛化能力,研究者还尝试将 预训练梯度(pretraining gradients) 与 PPO 的梯度混合,构建了所谓的 PPO-ptx 模型。这种方法可以解决在某些公共 NLP 数据集上性能回退的问题。

他们使用了以下的 目标函数(Objective Function)

objective ( ϕ ) = E ( x , y ) ∼ D π ϕ R L [ r θ ( x , y ) − β log ⁡ ( π ϕ R L ( y ∣ x ) π S F T ( y ∣ x ) ) ] + γ E x ∼ D pretrain [ log ⁡ ( π ϕ R L ( x ) ) ] \text{objective}(\phi) = \mathbb{E}_{(x,y) \sim D_{\pi^{RL}_{\phi}}} \left[ r_\theta(x, y) - \beta \log \left( \frac{\pi^{RL}_{\phi}(y \mid x)}{\pi^{SFT}(y \mid x)} \right) \right] + \gamma \mathbb{E}_{x \sim D_{\text{pretrain}}} \left[ \log(\pi^{RL}_{\phi}(x)) \right] objective(ϕ)=E(x,y)DπϕRL[rθ(x,y)βlog(πSFT(yx)πϕRL(yx))]+γExDpretrain[log(πϕRL(x))]

4. 公式解读

第一部分:PPO 的核心目标

E ( x , y ) ∼ D π ϕ R L [ r θ ( x , y ) − β log ⁡ ( π ϕ R L ( y ∣ x ) π S F T ( y ∣ x ) ) ] \mathbb{E}_{(x,y) \sim D_{\pi^{RL}_{\phi}}} \left[ r_\theta(x, y) - \beta \log \left( \frac{\pi^{RL}_{\phi}(y \mid x)}{\pi^{SFT}(y \mid x)} \right) \right] E(x,y)DπϕRL[rθ(x,y)βlog(πSFT(yx)πϕRL(yx))]

  • E ( x , y ) ∼ D π ϕ R L \mathbb{E}_{(x,y) \sim D_{\pi^{RL}_{\phi}}} E(x,y)DπϕRL表示在 RL 策略 π ϕ R L \pi^{RL}_{\phi} πϕRL 生成的分布 D π ϕ R L D_{\pi^{RL}_{\phi}} DπϕRL上的期望。
  • r θ ( x , y ) r_\theta(x, y) rθ(x,y) 是奖励模型 r θ r_\theta rθ对生成的响应 y y y的奖励。
  • β log ⁡ ( π ϕ R L ( y ∣ x ) π S F T ( y ∣ x ) ) \beta \log \left( \frac{\pi^{RL}_{\phi}(y \mid x)}{\pi^{SFT}(y \mid x)} \right) βlog(πSFT(yx)πϕRL(yx))是 KL 散度惩罚项,惩罚 RL 策略 π ϕ R L \pi^{RL}_{\phi} πϕRL偏离 SFT 模型 π S F T \pi^{SFT} πSFT的程度,其中 β \beta β控制 KL 惩罚的权重。

解释
这部分的目标是最大化模型的奖励,同时通过 KL 惩罚防止策略 π ϕ R L \pi^{RL}_{\phi} πϕRL与监督微调模型 π S F T \pi^{SFT} πSFT偏离过远。惩罚项确保模型在强化学习时不走偏,保持与原本训练目标的相似性。

第二部分:预训练损失

γ E x ∼ D pretrain [ log ⁡ ( π ϕ R L ( x ) ) ] \gamma \mathbb{E}_{x \sim D_{\text{pretrain}}} \left[ \log(\pi^{RL}_{\phi}(x)) \right] γExDpretrain[log(πϕRL(x))]

  • E x ∼ D pretrain \mathbb{E}_{x \sim D_{\text{pretrain}}} ExDpretrain 表示在预训练数据分布 D pretrain D_{\text{pretrain}} Dpretrain上的期望。
  • log ⁡ ( π ϕ R L ( x ) ) \log(\pi^{RL}_{\phi}(x)) log(πϕRL(x)) 是 RL 策略 π ϕ R L \pi^{RL}_{\phi} πϕRL生成的结果的对数概率。
  • γ \gamma γ是预训练损失项的权重,控制预训练数据与强化学习的结合程度。

解释
这一部分引入了预训练的损失,使得模型能够保持在大规模预训练数据上的表现,防止模型在强化学习过程中完全依赖于奖励模型而失去通用能力。通过设置 γ \gamma γ,我们可以平衡预训练损失和强化学习损失的影响。

5. 策略和符号说明

  • π ϕ R L \pi^{RL}_{\phi} πϕRL 是在强化学习中学习到的策略,参数为 ϕ \phi ϕ
  • π S F T \pi^{SFT} πSFT 是通过监督学习微调得到的策略,它代表了模型在强化学习之前的性能。
  • β \beta β是 KL 惩罚的权重系数,控制 RL 策略和 SFT 策略的偏离程度。
  • γ \gamma γ 是预训练损失的权重系数,控制预训练梯度在 PPO 优化中的作用。

总结:

在这篇文章中,InstructGPT 使用了强化学习中的 PPO(Proximal Policy Optimization) 进行策略优化,同时通过引入 KL 散度惩罚项 来确保 RL 策略与 SFT 策略不过度偏离。此外,预训练损失 通过一个额外的项加入到了目标函数中,以解决在某些 NLP 任务上的性能回退问题。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/56060.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

C++和OpenGL实现3D游戏编程【连载15】——着色器初步

🔥C和OpenGL实现3D游戏编程【目录】 1、本节实现的内容 上一节我们介绍了通过VBO、VAO和EBO怎样将顶点发送到GPU显存,利用GPU与显存之间的高效处理速度,来提高我们的图形渲染效率。那么在此过程中,我们又可以通过着色器&#xff…

硬件开发笔记(三十一):TPS54331电源设计(四):PCB布板12V转5V电路、12V转3.0V和12V转4V电路

若该文为原创文章,转载请注明原文出处 本文章博客地址:https://hpzwl.blog.csdn.net/article/details/142757509 长沙红胖子Qt(长沙创微智科)博文大全:开发技术集合(包含Qt实用技术、树莓派、三维、OpenCV…

《OpenCV计算机视觉》—— 人脸检测

文章目录 一、人脸检测流程介绍二、用于人脸检测的关键方法1.加载分类器(cv2.CascadeClassifier())2.检测图像中的人脸(cv2.CascadeClassifier.detectMultiscale()) 三、代码实现 一、人脸检测流程介绍 下面是一张含有多个人脸的…

人工智能和机器学习之线性代数(一)

人工智能和机器学习之线性代数(一) 人工智能和机器学习之线性代数一将介绍向量和矩阵的基础知识以及开源的机器学习框架PyTorch。 文章目录 人工智能和机器学习之线性代数(一)基本定义标量(Scalar)向量&a…

【硬件模块】HC-08蓝牙模块

蓝牙模块型号 HC-08蓝牙模块实物图 HC-08蓝牙模块引脚介绍 STATE:状态输出引脚。未连接时,则为低电平。连接成功时,则为高电平。可以在程序中作指示引脚使用; RXD:串口接收引脚。接单片机的 TX 引脚(如…

Linux编辑器-vim的配置及其使用

vim是一种多模式的编辑器: 1.命令模式(默认模式):用户所有的输入都会当作命令,不会当作文本输入。 2.插入模式:写代码, 按「 i 」切换进入插入模式「 insert mode 」,按 “i” 进入…

SCI论文快速排版:word模板一键复制样式和格式【重制版】

关注B站可以观看更多实战教学视频:hallo128的个人空间SCI论文快速排版:word模板一键复制样式和格式:视频操作视频重置版2【推荐】 SCI论文快速排版:word模板一键复制样式和格式【重制版】 模板与普通文档的区别 为了让读者更好地…

【C++贪心 DFS】2673. 使二叉树所有路径值相等的最小代价|1917

本文涉及知识点 C贪心 反证法 决策包容性 CDFS LeetCode2673. 使二叉树所有路径值相等的最小代价 给你一个整数 n 表示一棵 满二叉树 里面节点的数目,节点编号从 1 到 n 。根节点编号为 1 ,树中每个非叶子节点 i 都有两个孩子,分别是左孩子…

苹果最新论文:LLM只是复杂的模式匹配 而不是真正的逻辑推理

大语言模型真的可以推理吗?LLM 都是“参数匹配大师”?苹果研究员质疑 LLM 推理能力,称其“不堪一击”!苹果的研究员 Mehrdad Farajtabar 等人最近发表了一篇论文,对大型语言模型 (LLM) 的推理能…

【数据结构笔记】搜索树

目录 二叉搜索树 结构特征 搜索 插入 删除 单子节点删除 双子节点删除 平衡二叉搜索树 AVL树 失衡与重平衡 插入失衡 删除失衡 “34”平衡重构 伸展树 逐层伸展 双层伸展 插入 删除 红黑树 结构特征 插入 自底向上的染色插入 双红修正 RR-1 RR-2 自顶…

超GPT3.5性能,无限长文本,超强RAG三件套,MiniCPM3-4B模型分享

MiniCPM3-4B是由面壁智能与清华大学自然语言处理实验室合作开发的一款高性能端侧AI模型,它是MiniCPM系列的第三代产品,具有4亿参数量。 MiniCPM3-4B模型在性能上超过了Phi-3.5-mini-Instruct和GPT-3.5-Turbo-0125,并且与多款70亿至90亿参数的…

RabbitMQ 入门(四)SpringAMQP五种消息类型

一、WorkQueue(工作消息队列) Work queues,也被称为(Task queues),任务模型。简单来说就是让多个消费者绑定到一个队列,共同消费队列中的消息。 当消息处理比较耗时的时候,可能生产消息的速度会远远大于…

Python自然语言处理之pyltp模块介绍、安装与常见操作案例

pyltp是哈尔滨工业大学社会计算与信息检索研究中心推出的一款基于Python封装的自然语言处理工具,它提供了哈工大LTP(Language Technology Platform)工具包的接口。LTP工具包以其强大的中文分词、词性标注、命名实体识别、依存句法分析等功能&…

Vue——Uniapp回到顶部悬浮按钮

代码示例 <template><view class"updata" click"handleup" :style"{bottom: bottomTypepx}" ><i class"iconfont icon-huidaodingbu"></i></view> </template><script> export default {n…

《机器学习与数据挖掘综合实践》实训课程教学解决方案

一、引言 随着信息技术的飞速发展&#xff0c;人工智能已成为推动社会进步的重要力量。作为人工智能的核心技术之一&#xff0c;机器学习与数据挖掘在各行各业的应用日益广泛。本方案旨在通过系统的理论教学、丰富的实践案例和先进的实训平台&#xff0c;帮助学生掌握机器学习…

C++ 比大小

//输入两个可能有前导 0 的大整数&#xff0c;a,b请输出他们谁大谁小#include <iostream> #include <string> #include <string.h> using namespace std; #define M 100005 int main() {char a[M], b[M];char *pa, *pb;pa a;pb b;cin >> a >> …

第十五届蓝桥杯C/C++学B组(解)

1.握手问题 解题思路一 数学方法 50个人互相握手 &#xff08;491&#xff09;*49/2 &#xff0c;减去7个人没有互相握手&#xff08;61&#xff09;*6/2 答案&#xff1a;1024 解题思路二 思路&#xff1a; 模拟 将50个人从1到50标号&#xff0c;对于每两个人之间只握一…

P327. 渔夫捕鱼算法问题

问题描述&#xff1a; A、B、C、D、E 这5个人合伙夜间捕鱼&#xff0c;凌晨时都已经疲惫不堪&#xff0c;于是各自在河边的树丛中找地方睡着了。第二天日上三竿时&#xff0c;A第一个醒来&#xff0c;他将鱼平分为5份&#xff0c;把多余的一条扔回河中&#xff0c;然后拿着自己…

【D3.js in Action 3 精译_034】4.1 D3 中的坐标轴的创建(中一)

当前内容所在位置&#xff08;可进入专栏查看其他译好的章节内容&#xff09; 第一部分 D3.js 基础知识 第一章 D3.js 简介&#xff08;已完结&#xff09; 1.1 何为 D3.js&#xff1f;1.2 D3 生态系统——入门须知1.3 数据可视化最佳实践&#xff08;上&#xff09;1.3 数据可…

FFmpeg的简单使用【Windows】--- 简单的视频混合拼接

实现功能 点击【选择文件】按钮在弹出的对话框中选择多个视频&#xff0c;这些视频就是一会将要混剪的视频素材&#xff0c;点击【开始处理】按钮之后就会开始对视频进行处理&#xff0c;处理完毕之后会将处理后的文件路径返回&#xff0c;并在页面展示处理后的视频。 视频所…