Training language models to follow instructions with human feedback 论文阅读

论文原文:https://arxiv.org/pdf/2203.02155

论文简介

语言模型越大并不意味着它能更好的理解用户的意图,因此在这篇论文中,展示了根据人的反馈对模型进行微调,使得语言模型能够在各种人物上更好的理解用户的意图。在评估中,1.3B参数的InstructGPT模型的输出比175B GPT-3的输出更受欢迎,尽管参数少了100倍。此外,InstructGPT模型虽然在公共的数据上的效果有所降低,但是真实性和减少有害方面生成的能力提升。论文表明,尽管InstructGPT仍然会犯一些简单的错误,但根据人类反馈进行微调是能够理解人类意图的一个有效的方式和方向。
**相当于是,OpenAI提出了”align“的概念,希望模型的输出与人类的意图”对齐“,其用的方法是RLHF(Reinforcement Learning from Human Feedback)基于人类反馈的强化学习。**

方法和实验细节

在这里插入图片描述

Collect demonstration data, and train a supervised policy. (收集范例数据,并以有监督方式训练)

我们的打标签者提供了输入提示分布(prompt distribution)上所需行为的范例(有关此分布的详细信息,请参阅第 3.2 节)。 然后,我们使用有监督学习在该数据集上微调预训练的 GPT-3 模型。这部分就是根据prompts,也就是写的各种问题,进行标注,将prompts和标注的对话作为人工标注的数据集,对预训练的GPT-3进行有监督微调

Collect comparison data, and train a reward model. (收集比较数据,训练奖励模型)

我们收集了模型输出之间比较的数据集,其中打标记者根据输入标明了他们更喜欢的输出。 然后我们训练奖励模型来预测人类偏好的输出。用上一步得到的SFT模型生成各种问题的答案,再对这些答案进行比较(排序式)标注,如D>C>A=B,基于这个标注数据集,在去掉最后的嵌入层的SFT模型基础上进行有监督学习训练一个RM(reward model),这样使用模型来模仿标注者进行打分

Optimize a policy against the reward model using PPO. (使用PPO针对奖励模型优化策略)

我们使用RM奖励模型的输出作为标量奖励。 我们使用 PPO 算法微调监督策略以优化此奖励。
步骤2和步骤3可以不断迭代; 收集当前最佳策略的更多比较数据,用于训练新的 RM,然后训练新的策略。 在实践中,我们的大部分比较数据来自监管的学习,也有一些来自我们的PPO学习。用上一步的RM模型进行打分,然后分数就可以用强化学习来对SFT模型进行优化

数据集

打标签者提供了输入提示分布(prompt distribution)上所需行为的范例,根据论文所说,为了训练第一个InstructGPT模型,打标签者需要自己编写提示,分为三种:

  • Plain:只是要求标记者提出一个任意的任务,同时确保任务具有足够的多样性。
  • Few-shot:要求标注者提出一条指令,以及针对该指令的多个查询/相应对。
  • User-based:在OpenAI API的候补名单申请中陈述了许多用例,要求标注者提出与这些用例相对应的提示。
    根据这些提示,生成了三个用于微调过程的不同数据集:(1)SFT数据集,带有用于训练SFT模型的打标签者范例数据,(2)RM数据集,带有用于训练的模型已被打标签者分了等级的数据,(3)PPO数据集,没有任何人工标签,用于RLHF微调的输入。SFT数据集包含大约13k个训练提示数据(来自API和标记者编写),RM数据集有33k个训练提示数据(来自API和打标记者编写),PPO数据集有31k个训练提示数据(仅来自API)。
    在这里插入图片描述
    上表显示了API提示(特别是RM数据集)的用例类别的分布,大多数用例都是生成的,而不是分类或QA。在表二中展示了一些说明性提示(由研究人员编写,以模仿提交给InstructGPT模型的提示类型)。

任务

训练任务来自两个来源:(1)由标注者编写的提示数据集和(2)提交给API上的早期InstructGPT模型的提示数据集。这些提示非常多样化,包括生成、问答、对话、摘要、提取和其他自然语言任务。数据集超过96%是英语。
对于每个自然语言提示,任务通常是通过自然语言指令直接指定的(例如”写一个关于聪明青蛙的故事“),但也可以通过少数例子间接指定(例如给出两个青蛙故事的例子,并提示模型生成一个新的)或隐含的连续(例如提供一个关于青蛙的故事的开始)。在每种情况下,我们都要求标注者尽最大努力推断出写提示的用户的意图,并要求他们跳过任务非常不清楚的输入(相当于当任务非常不清楚的时候,可以跳过回答,避免答非所问)。此外,在我们提供给他们的指示和他们的最佳判断的指导下,标注者还需考虑到隐含的意图,如回应的真实性,以及潜在的有害输出,如有偏见或有毒的语言。

模型

我们从GPT-3预训练语言模型开始。这些模型是在广泛分布的互联网数据上进行训练的,可以适应广泛的下游任务,但行为特征不佳。从这些模型开始,我们用三种不同的技术训练模型:

  • 有监督微调(SFT——Supervised fine-tuning),我们使用监督学习对标记器演示中的GPT-3进行微调。我们训练了16个epoch,使用余弦学习率衰减,0.2的残差dropout。我们根据验证集上的RM分数进行最终的SFT模型选择。我们发现SFT模型在1个epoch后对验证损失上过拟合,然而我们发现尽管存在过拟合,但更多epochs的训练有助于RM分数和人类偏好评级。(尽管这个SFT模型训练更多的epoch会产生过拟合,但是这是为了得到后续的RM模型的初始化模型,对RM模型有帮助,并不是直接使用这个SFT模型,所以过拟合没关系
  • 奖励建模(RM——Reward model),从移除了最后的非嵌入层的SFT模型开始(GPT模型最后的softmax层是用于得到每个词的概率,去掉softmax层以后,增加一个线性层来投影,将所有词的输出投影到一个值上面,即输出一个标量的分数),我们训练了一个模型来接收提示和相应,并输出标量奖励。在本文中,我们只使用6B RM,这样可以节省大量计算,而且我们发现175B RM训练可能不稳定,因此不太适合用作RL(Reinforcement learning)中的值函数。RM在同一输入的两个模型输出之间进行比较的数据集上训练。他们使用交叉熵损失,将比较作为标签——奖励的差异代表了人类标记者更喜欢一种反应的对数几率。
    为了加速分等级数据的收集,我们向标签提供者提供 K = 4 K=4 K=4 K = 9 K=9 K=9之间的任何排名相应。这会为显示给标签者的每个提示生成 ( K 2 ) = C K 2 \binom{K}{2}=C_K^2 (2K)=CK2比较。由于分等级数据在每个标记任务中都非常相关,我们发现,如果我们简单地将分等级数据混洗到一个数据集中,在数据集上的一次遍历会导致奖励模型过拟合。相反,我们将每个提示的所有 ( K 2 ) \binom{K}{2} (2K)比较数据作为单个批处理元素进行训练。这在计算上要高效得多,因为它只需要每次完成一次RM的前向传递(而不是超过 ( K 2 ) \binom{K}{2} (2K)次前向传递),而且因为它不在过拟合,大大提高了验证准确性和日志损失。
    具体来说,奖励模型的损失函数为(这里使用的是排序中最常见的pairwise ranking loss,成对排名损失):
    在这里插入图片描述
    这里的 r θ ( x , y ) r_{\theta}(x,y) rθ(x,y)是表示prompt x x x和相应 y y y在参数为 θ \theta θ的奖励模型下的奖励值, y w y_w yw是在prompt x x x下生成的一对响应 y w y_w yw y l y_l yl中更受欢迎的那一个, D D D是比较的数据集。每一个排名对 y i , y j y_i,y_j yi,yj的损失是 − l o g ( σ ( y i − y j ) ) -log(\sigma(y_i-y_j)) log(σ(yiyj)),换成奖励函数就是 − l o g ( σ ( r θ ( x , y w ) ) − r θ ( x , y l ) ) -log(\sigma(r_{\theta}(x,y_w))-r_{\theta}(x,y_l)) log(σ(rθ(x,yw))rθ(x,yl)),然后共 C K 2 C_K^2 CK2个排序对,所以期望除以它。
    目标是最小化这个loss,也就是最大化这两个奖励的差值, l o g ( σ ) log(\sigma) log(σ)最开始的时候是把生成的每个输出对都作为单独的数据混洗到数据集中,这样的话就需要超过 ( K 2 ) \binom{K}{2} (2K)次前向传递,而且输出对之间有重复,这样容易过拟合,所以将所有的输出对都统一作为单个批处理元素进行训练,这样的话就只需要 K K K次前向传递,因为奖励模型只需要算出9个奖励。之所以取 K = 9 K=9 K=9,是因为考虑到人工标注的时候,很大一部分是花在读懂这个prompt,所以在 K = 4 K=4 K=4 K = 9 K=9 K=9之间,只多了不到一倍的时间,但是标注的数据由6变成了36,多了6倍
    最后,由于RM损失对于奖励的变化是不变的,我们使用偏差对奖励模型进行归一化,以便在进行RL之前,标记器演示的平均得分为0。
  • 强化学习(RL——Reinforcement learning),我们使用PPO在我们的环境中微调了SFT模型。该模型是一个bandit环境,它呈现随机的客户提示并期望对提示的响应。给定提示和相应,它会产生由奖励模型确定的奖励并结束情节。此外,我们在每个token上上添加了SFT模型的每个token的KL惩罚,以减轻奖励模型的过度优化。从RM初始化值函数。我们称这些模型为PPO。
    我们还尝试将预训练梯度混合到PPO梯度中,以修复公共NLP数据集上的性能回归。我们称这些模型为”PPO-ptx“。我们在RL训练中最大化以下组合目标函数:
    在这里插入图片描述
    其中 π Θ R L \pi_{\Theta}^{RL} πΘRL是学习到的RL策略, π S F T \pi^{SFT} πSFT是有监督训练的模型, D p r e t r a i n D_{pretrain} Dpretrain是预训练分布。KL奖励系数 β \beta β,预训练损失系数 γ \gamma γ分别控制KL惩罚和预训练梯度的强度。对于”PPO“模型, γ \gamma γ被设置为0,除非另有说明,本文中的InstructGPT指的是PPO-ptx模型。对于上面说的31k个prompts数据集 D D D,都使用当前的RL模型,也就是RL策略 π θ R L \pi_{\theta}^{RL} πθRL,输出 y y y,然后用RM模型得到分数 r θ ( x , y ) ,目标函数是希望这个分数最大化 r_{\theta}(x,y),目标函数是希望这个分数最大化 rθ(x,y),目标函数是希望这个分数最大化然后根据这个目标函数,更新RL模型,然后再用RM模型计算得分,反复迭代。
    目标函数中还有两项,在此分别解释一下, β l o g ( π Θ R L ( y ∣ x ) / π S F T ( y ∣ x ) ) \beta log(\pi_{\Theta}^{RL}(y|x)/\pi^{SFT}(y|x)) βlog(πΘRL(yx)/πSFT(yx))是正则项,这是PPO的主要思想,随着模型的更新,RL产生的输出 y y y和原始的 S F T SFT SFT模型输出的 y y y会逐渐不一样,即数据分布( y ∣ x y|x yx)的差异会越来越大, R L RL RL的输出可能会不准,所以论文在loss里加入了一个KL散度 KL ( P ∥ Q ) = ∑ x P ( x ) log ⁡ ( P ( x ) Q ( x ) ) = ∫ P ( x ) log ⁡ ( P ( x ) Q ( x ) ) d x \text{KL}(P \parallel Q) = \sum_{x} P(x) \log \left(\frac{P(x)}{Q(x)}\right)= \int P(x) \log \left(\frac{P(x)}{Q(x)}\right)\, dx KL(PQ)=xP(x)log(Q(x)P(x))=P(x)log(Q(x)P(x))dx,用于描述一个概率分布相对于另一个概率分布的非对称性差异,相当于用这个散度来正则,希望RLSFT的输出分布不要偏太远,因为是最大化目标函数,所以要最小化KL散度需要在前面加一个负号。
    γ E x D p r e t r a i n [ l o g ( π Θ R L ( x ) ) ] \gamma E_x ~ D_{pretrain}[log(\pi_{\Theta}^{RL}(x))] γEx Dpretrain[log(πΘRL(x))],由于前两项目标函数只和人类排序部分有关,所以训练出来会导致模型仅仅对排序的结果较好,而在最终任务通用NLP任务上性能会下降,所以论文在loss中加入了GPT-3预训练模型的目标函数, D p r e t r a i n D_{pretrain} Dpretrain表示从训练GPT-3的预训练数据中采样 x x x,然后输入RL模型得到输出概率 π Θ R L ( x ) \pi_{\Theta}^{RL}(x) πΘRL(x),这样相当于是GPT-3本身的损失函数。

    总的来说,如果 γ = 0 \gamma=0 γ=0就是一个PPO函数,否则就是一个PPO加上一个GPT-3的目标函数的结合成为RL模型的目标函数,也就是PPO-ptx
    在这里插入图片描述

讨论

论文提出,本文使用的”对齐技术“——RLHF,是用于对齐人类系统的一个重要方法。与预训练相比,增加模型对齐的成本是适中的(仅仅标注几万条prompt数据),与训练GPT-3的花费相比(海量的各种数据),只占一小部分。上述结果也表明,RLHF在使语言模型更加helpful(真实性和无害性是被隐式优化了)方面非常有效,甚至比模型增加100倍更有效。所以,在自然语言领域,研究alignment可能比训练更大规模的模型更具性价比。
align也有争议,就是到底要align人类到什么地步,是用户让做什么就做什么,还是要理解用户更深层的、内在的一些东西。此外最后的RL模型也不是必要的,如果在第一步多标数据,在GPT-3微调,步骤会变得简单,可能更加实用。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/28699.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

树莓派4B学习笔记11:PC端网线SSH连接树莓派

今日继续学习树莓派4B 4G:(Raspberry Pi,简称RPi或RasPi) 本人所用树莓派4B 装载的系统与版本如下: 版本可用命令 (lsb_release -a) 查询: Opencv 版本是4.5.1: 今日学习使用网线连接树莓派,网线可以提供更…

使用 C# 学习面向对象编程:第 8 部分

抽象方法 亲爱的读者,本文是 OOP 的第四大支柱,也是最后一大支柱。对于 OOP 初学者来说,这很容易让人困惑。因此,我们用非常简单的语言提供了一个示例。 “抽象用于管理复杂性。无法创建抽象类的对象。抽象类用于继承。” 例如…

降噪领夹麦克风哪个牌子好?揭秘无线领夹麦克风哪个降噪好

相信很多新手视频创作者都有一个疑问:为什么别人的视频或者直播音质这么清晰,几乎没什么噪音呢?其实最主要的就是麦克风的原因,相机或手机内置的麦克风是无法提供高质量的音频记录以及很好的指向性的。 想要拍摄出来的视频作品拥有…

每一个男人都曾有一个机器人的梦想

每一个男人都曾有一个机器人的梦想 我也有 每一个男人都曾有一个机器人的梦想。对于我来说,这个梦想始于童年时代,那时变形金刚风靡一时,几乎所有80后的孩子都为之疯狂。我是80后中的一员,那时候的科技还远没有如今这般发达&#…

《现代通信原理与技术》码间串扰和​​​​​​​无码间串扰的眼图对比实验报告

实 验:码间串扰和无码间串扰的眼图对比实验报告 摘 要: 在数字通信系统中,码间串扰(Inter-Symbol Interference, ISI)是影响信号质量和系统性能的重要因素之一。本实验通过MATLAB软件生成并对比了受码间串扰影响和未…

华为昇腾异构计算架构CANN及AI芯片简介

异构计算架构CANN 异构计算架构CANN(Compute Architecture for Neural Networks)是华为针对AI场景推出的异构计算架构,向上支持多种AI框架,包括MindSpore、PyTorch、TensorFlow等,向下服务AI处理器与编程,…

Open To Buy(OTB)计划:零售业者的库存管理利器

在当今快速变化的服装市场中,如何高效、精准地进行商品管理成为了服装企业竞争的关键。OTB(Open-to-Buy)作为一种有效的商品管理方法,在企业管理中扮演着至关重要的角色。它基于预算、商品计划以及市场需求等多维度因素&#xff0…

Android开发系列(二)Jetpack Compose 之Text控件

Jetpack Compose是一种全新的声明式UI框架,用于构建Android应用程序。Jetpack Compose Text控件是Compose中用于显示文本的基本UI组件。 Text是一个可组合函数,函数声明如下所示。 Composable fun Text(text: String,modifier: Modifier Modifier,colo…

AGI 远不止 ChatGPT!一文入门 AGI 通识及应用开发

AI 大语言模型进入爆发阶段 2022 年 12 月 ChatGPT 突然爆火,原因是其表现出来的智能化已经远远突破了我们的常规认知。虽然其呈现在使用者面前仅仅只是一个简单的对话问答形式,但是它的内容化水平非常强大,甚至在某些方面已经超过人类了&am…

k8s上使用ConfigMap 和 Secret

使用ConfigMap 和 Secret 实验目标: 学习如何使用 ConfigMap 和 Secret 来管理应用的配置。 实验步骤: 创建一个 ConfigMap 存储应用配置。创建一个 Secret 存储敏感信息(如数据库密码)。在 Pod 中挂载 ConfigMap 和 Secret&am…

计算机视觉全系列实战教程:(八)图像变换-点运算、灰度变换、直方图变换

图像变换:点运算、灰度变换、直方图变换 1.点运算(1)What(2)Why 2.灰度变换(1)What(2)Why(作用)(3)Which(有哪些灰度变换) 3.直方图修正(1)直方图均衡化 1.点运算 (1)What 通过点运算,输出图像的每个像素的灰度值仅仅取决于输入图像中相对应…

【招联消费金融股份】有限公司2024年5月18日【算法开发岗暑期实习】一面试经验分享

招联消费金融股份有限公司2024年5月18日面试经验分享 面试流程:共30多分钟,先3分钟自我介绍,然后细细介绍简历上面的论文和实习信息。问题1:扩散模型的noise schedule有什么研究。问题2:有哪些常见的数学分布问题3&…

新版嘎嘎快充互联互通系统配置文档

宝塔环境配置 登录宝塔账号,安装nginx、mysql5.7、php7.2、supervisor、redisphp安装扩展: 1)安装swooleloader72 将嘎嘎官方提供的swoole_loader_72_nts.so文件上传到 /www/server/php/72/lib/php/extensions/no-debug-non-zts-20170718…

Spring的事务步骤

一、事务处理方案: Spring框架中提供的事务处理方案:一共有两种: 1.适合中小项目使用的, 注解方案: 注解的方式做事务用起来简单,灵活,方便,中小型项目中用它比较方便&#xff0c…

基于STM32和人工智能的智能水质监测系统

目录 引言环境准备智能水质监测系统基础代码实现:实现智能水质监测系统 4.1 数据采集模块4.2 数据处理与分析4.3 控制系统4.4 用户界面与数据可视化应用场景:智能水质管理与优化问题解决方案与优化收尾与总结 1. 引言 随着环境保护意识的提高&#xf…

【C/C++】【学生成绩管理系统】深度剖析

可接各类C/C管理系统课设 目录 实现功能 部分1:系统设置和主菜单 1. 引入头文件 2. 定义结构体 3. 函数声明 4. 主函数 部分2:添加学生信息 部分3:删除学生信息 部分4:修改学生信息 部分5:查询学生信息 部分…

数组元素的内存地址计算【数据结构与算法C#版】

数组元素被存储在连续的内存空间中,这意味着计算数组元素的内存地址非常容易。给定数组内存地址(首 元素内存地址)和某个元素的索引,我们可以使用下方图 所示的公式计算得到该元素的内存地址,从而直接 访问该元素。 观…

电源小白入门学习11——反激电源电路原理

电源小白入门学习11——反激电源、正激电源 隔离电源变压器介绍反激电源 前面我们学习了BUCK、BOOST、BUCK-BOOST 等各种各样的DCDC变换器,但是他们都有一共同的特点,即能量的传输路径时一个完整的通路,输入与输出之间不存在电气隔离&#xf…

【2024最新华为OD-C/D卷试题汇总】[支持在线评测] 字符串变换(100分) - 三语言AC题解(Python/Java/Cpp)

🍭 大家好这里是清隆学长 ,一枚热爱算法的程序员 ✨ 本系列打算持续跟新华为OD-C/D卷的三语言AC题解 💻 ACM银牌🥈| 多次AK大厂笔试 | 编程一对一辅导 👏 感谢大家的订阅➕ 和 喜欢💗 📎在线评测链接 字符串变换(100分) 🌍 评测功能需要订阅专栏后私信联系清隆…

【Unity实战篇】| 快速制作一个简易时钟,包括2D和3D时钟

前言 【Unity实战篇】| 快速制作一个时钟,包括2D和3D时钟一、2D时钟制作1.1 钟表盘制作1.2 指针制作1.3 钟表搭建1.4 设置时钟的中心点1.5 时钟旋转逻辑 二、3D时钟制作2.1 搭建表盘和指针2.2 调整指针的位置和节点2.3 时钟旋转逻辑 总结 前言 时钟 这个东西想必不…