使用 DPO 微调 Llama 2

简介

基于人类反馈的强化学习 (Reinforcement Learning from Human Feedback,RLHF) 事实上已成为 GPT-4 或 Claude 等 LLM 训练的最后一步,它可以确保语言模型的输出符合人类在闲聊或安全性等方面的期望。然而,它也给 NLP 引入了一些 RL 相关的复杂性: 既要构建一个好的奖励函数,并训练一个模型用以估计每个状态的价值 (value); 又要注意最终生成的 LLM 不能与原始模型相差太远,如果太远的话会使得模型容易产生乱码而非有意义的文本。该过程非常复杂,涉及到许多复杂的组件,而这些组件本身在训练过程中又是动态变化的,因此把它们料理好并不容易。

Rafailov、Sharma、Mitchell 等人最近发表了一篇论文 Direct Preference Optimization,论文提出将现有方法使用的基于强化学习的目标转换为可以通过简单的二元交叉熵损失直接优化的目标,这一做法大大简化了 LLM 的提纯过程。

本文介绍了直接偏好优化 (Direct Preference Optimization,DPO) 法,该方法现已集成至 TRL 库 中。同时,我们还展示了如何在 stack-exchange preference 数据集上微调最新的 Llama v2 7B 模型, stack-exchange preference 数据集中包含了各个 stack-exchange 门户上的各种问题及其排序后的回答。

DPO 与 PPO

在通过 RL 优化人类衍生偏好时,一直以来的传统做法是使用一个辅助奖励模型来微调目标模型,以通过 RL 机制最大化目标模型所能获得的奖励。直观上,我们使用奖励模型向待优化模型提供反馈,以促使它多生成高奖励输出,少生成低奖励输出。同时,我们使用冻结的参考模型来确保输出偏差不会太大,且继续保持输出的多样性。这通常需要在目标函数设计时,除了奖励最大化目标外再添加一个相对于参考模型的 KL 惩罚项,这样做有助于防止模型学习作弊或钻营奖励模型。

DPO 绕过了建模奖励函数这一步,这源于一个关键洞见: 从奖励函数到最优 RL 策略的分析映射。这个映射直观地度量了给定奖励函数与给定偏好数据的匹配程度。有了它,作者就可与将基于奖励和参考模型的 RL 损失直接转换为仅基于参考模型的损失,从而直接在偏好数据上优化语言模型!因此,DPO 从寻找最小化 RLHF 损失的最佳方案开始,通过改变参量的方式推导出一个 仅需 参考模型的损失!

有了它,我们可以直接优化该似然目标,而不需要奖励模型或繁琐的强化学习优化过程。

如何使用 TRL 进行训练

如前所述,一个典型的 RLHF 流水线通常包含以下几个环节:

  1. 有监督微调 (supervised fine-tuning,SFT)

  2. 用偏好标签标注数据

  3. 基于偏好数据训练奖励模型

  4. RL 优化

TRL 库包含了所有这些环节所需的工具程序。而 DPO 训练直接消灭了奖励建模和 RL 这两个环节 (环节 3 和 4),直接根据标注好的偏好数据优化 DPO 目标。

使用 DPO,我们仍然需要执行环节 1,但我们仅需在 TRL 中向 DPOTrainer 提供环节 2 准备好的偏好数据,而不再需要环节 3 和 4。标注好的偏好数据需要遵循特定的格式,它是一个含有以下 3 个键的字典:

  • prompt : 即推理时输入给模型的提示

  • chosen : 即针对给定提示的较优回答

  • rejected :  即针对给定提示的较劣回答或非给定提示的回答

例如,对于 stack-exchange preference 数据集,我们可以通过以下工具函数将数据集中的样本映射至上述字典格式并删除所有原始列:

def return_prompt_and_responses(samples) -> Dict[str, str, str]:return {"prompt": ["Question: " + question + "\n\nAnswer: "for question in samples["question"]],"chosen": samples["response_j"], # rated better than k"rejected": samples["response_k"], # rated worse than j}dataset = load_dataset("lvwerra/stack-exchange-paired",split="train",data_dir="data/rl"
)
original_columns = dataset.column_namesdataset.map(return_prompt_and_responses,batched=True,remove_columns=original_columns
)

一旦有了排序数据集,DPO 损失其实本质上就是一种有监督损失,其经由参考模型获得隐式奖励。因此,从上层来看,DPOTrainer 需要我们输入待优化的基础模型以及参考模型:

dpo_trainer = DPOTrainer(model, # 经 SFT 的基础模型model_ref, # 一般为经 SFT 的基础模型的一个拷贝beta=0.1, # DPO 的温度超参train_dataset=dataset, # 上文准备好的数据集tokenizer=tokenizer, # 分词器args=training_args, # 训练参数,如: batch size, 学习率等
)

其中,超参 beta 是 DPO 损失的温度,通常在 0.10.5 之间。它控制了我们对参考模型的关注程度,beta 越小,我们就越忽略参考模型。对训练器初始化后,我们就可以简单调用以下方法,使用给定的 training_args 在给定数据集上进行训练了:

dpo_trainer.train()

基于 Llama v2 进行实验

在 TRL 中实现 DPO 训练器的好处是,人们可以利用 TRL 及其依赖库 (如 Peft 和 Accelerate) 中已有的 LLM 相关功能。有了这些库,我们甚至可以使用 bitsandbytes 库提供的 QLoRA 技术 来训练 Llama v2 模型。

有监督微调

如上文所述,我们先用 TRL 的 SFTTrainer 在 SFT 数据子集上使用 QLoRA 对 7B Llama v2 模型进行有监督微调:

# load the base model in 4-bit quantization
bnb_config = BitsAndBytesConfig(load_in_4bit=True,bnb_4bit_quant_type="nf4",bnb_4bit_compute_dtype=torch.bfloat16,
)base_model = AutoModelForCausalLM.from_pretrained(script_args.model_name, # "meta-llama/Llama-2-7b-hf"quantization_config=bnb_config,device_map={"": 0},trust_remote_code=True,use_auth_token=True,
)
base_model.config.use_cache = False# add LoRA layers on top of the quantized base model
peft_config = LoraConfig(r=script_args.lora_r,lora_alpha=script_args.lora_alpha,lora_dropout=script_args.lora_dropout,target_modules=["q_proj", "v_proj"],bias="none",task_type="CAUSAL_LM",
)
...
trainer = SFTTrainer(model=base_model,train_dataset=train_dataset,eval_dataset=eval_dataset,peft_config=peft_config,packing=True,max_seq_length=None,tokenizer=tokenizer,args=training_args, # HF Trainer arguments
)
trainer.train()

DPO 训练

SFT 结束后,我们保存好生成的模型。接着,我们继续进行 DPO 训练,我们把 SFT 生成的模型作为 DPO 的基础模型和参考模型,并在上文生成的 stack-exchange preference 数据上,以 DPO 为目标函数训练模型。我们选择对模型进行 LoRa 微调,因此我们使用 Peft 的 AutoPeftModelForCausalLM 函数加载模型:

model = AutoPeftModelForCausalLM.from_pretrained(script_args.model_name_or_path, # location of saved SFT modellow_cpu_mem_usage=True,torch_dtype=torch.float16,load_in_4bit=True,is_trainable=True,
)
model_ref = AutoPeftModelForCausalLM.from_pretrained(script_args.model_name_or_path, # same model as the main onelow_cpu_mem_usage=True,torch_dtype=torch.float16,load_in_4bit=True,
)
...
dpo_trainer = DPOTrainer(model,model_ref,args=training_args,beta=script_args.beta,train_dataset=train_dataset,eval_dataset=eval_dataset,tokenizer=tokenizer,peft_config=peft_config,
)
dpo_trainer.train()
dpo_trainer.save_model()

可以看出,我们以 4 比特的方式加载模型,然后通过 peft_config 参数选择 QLora 方法对其进行训练。训练器还会用评估数据集评估训练进度,并报告一些关键指标,例如可以选择通过 WandB 记录并显示隐式奖励。最后,我们可以将训练好的模型推送到 HuggingFace Hub。

总结

SFT 和 DPO 训练脚本的完整源代码可在该目录 examples/stack_llama_2 处找到,训好的已合并模型也已上传至 HF Hub (见 此处)。

你可以在 这儿 找到我们的模型在训练过程的 WandB 日志,其中包含了 DPOTrainer 在训练和评估期间记录下来的以下奖励指标:

  • rewards/chosen (较优回答的奖励) : 针对较优回答,策略模型与参考模型的对数概率二者之差的均值,按 beta 缩放。

  • rewards/rejected (较劣回答的奖励) : 针对较劣回答,策略模型与参考模型的对数概率二者之差的均值,按 beta 缩放。

  • rewards/accuracy (奖励准确率) : 较优回答的奖励大于相应较劣回答的奖励的频率的均值

  • rewards/margins (奖励余裕值) : 较优回答的奖励与相应较劣回答的奖励二者之差的均值。

直观上讲,在训练过程中,我们希望余裕值增加并且准确率达到 1.0,换句话说,较优回答的奖励高于较劣回答的奖励 (或余裕值大于零)。随后,我们还可以在评估数据集上计算这些指标。

我们希望我们代码的发布可以降低读者的入门门槛,让大家可以在自己的数据集上尝试这种大语言模型对齐方法,我们迫不及待地想看到你会用它做哪些事情!如果你想试试我们训练出来的模型,可以玩玩这个 space: trl-lib/stack-llama。

🤗 宝子们可以戳 阅读原文 查看文中所有的外部链接哟!


英文原文: https://hf.co/blog/dpo-trl

原文作者: Kashif Rasul, Younes Belkada, Leandro von Werra

译者: Matrix Yao (姚伟峰),英特尔深度学习工程师,工作方向为 transformer-family 模型在各模态数据上的应用及大规模模型的训练推理。

审校/排版: zhongdongy (阿东)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/54729.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

python网络编程

文章目录 socket套接字客户端/服务模型linux文件描述符fdLinux网络IO模型详解网络服务器Apache VS Nginx生产者消费者-生成器版客户端/服务端-多线程版IO多路复用TCPServer模型异步IO多路复用TCPServer模型 socket套接字 套接字(socket)是抽象概念,表示T…

c++ qt--事件(第六部分)

c qt–事件(第六部分) 一.编辑伙伴,编辑顺序(按TAB进行切换) 1.编辑伙伴 此功能在设计界面如下的位置 1.设置伙伴关系 鼠标左键长按一个Label组件然后把鼠标移到另一个组件上 2.伙伴关系的作用 伙伴关系的作用就是…

HoudiniVex笔记_P26_RecursionBasics递归基础

原视频:https://www.youtube.com/playlist?listPLzRzqTjuGIDhiXsP0hN3qBxAZ6lkVfGDI Bili:Houdini最强VEX算法教程 - VEX for Algorithmic Design_哔哩哔哩_bilibili Houdini版本:19.5 1、概述 递归是一种直接或者间接地调用自身的算法&a…

猜数游戏-Rust版

cargo new guessing_game 创建项目 输入任意内容,并打印出来 main.rs: use std::io; // 像String这些类型都在预先导入的prelude里,如果要使用的不在prelude里,则需要显式导入fn main() { println!("猜数"); println!("…

37、springboot 为 spring mvc 提供的自动配置及对自动配置的一些自定义定制(大体思路)

springboot 为 spring mvc 提供的自动配置及对自动配置的一些自定义定制(大体思路) ★ Spring Boot主流支持两个MVC框架: Spring MVC(基于Servlet) Spring WebFlux(基于Reactive,属于响应式AP…

vcomp140.dll丢失的修复方法分享,电脑提示vcomp140.dll丢失修复方法

今天,我的电脑出现了一个奇怪的问题,打开某些程序时总是提示“找不到vcomp140.dll文件”。这个问题让我非常头疼,因为我无法正常使用电脑上的一些重要软件。为了解决这个问题,我在网上查找了很多资料,并尝试了多种方法…

使用Aircrack-ng进行无线网络破解

Aircrack-ng是一款流行的无线网络渗透测试工具,主要用于密码破解和网络分析。但是,请注意,仅在有合法授权的情况下使用这些工具。 以下是一个使用Aircrack-ng进行无线网络破解的示例,以及一些步骤和注意事项: 步骤&a…

2023年最新版Windows环境下|Java8(jdk1.8)安装教程

个人主页:平行线也会相交 欢迎 点赞👍 收藏✨ 留言✉ 加关注💓本文由 平行线也会相交 原创 收录于专栏【JavaSE_primary】 jdk1.8的下载和使用总共分为3个步骤: jdk1.8的下载、jdk1.8的安装、配置环境变量。 目录 一、jdk1.8下载…

NP_hard

P类问题就是指那些计算机比较容易算出答案的问题。 NP类问题就是指那些已知答案以后计算机可以比较容易地验证答案的问题。 当谈论NP-hard问题时,可以使用旅行推销员问题(Traveling Salesman Problem,TSP)作为一个通俗易懂的例子…

IDEA常用配置之类Tab页多行显示

文章目录 IDEA常用配置之类Tab页多行显示 IDEA常用配置之类Tab页多行显示 默认在Idea中打开类过多,后面会隐藏显示,这里修改配置,将类设置为多行显示,方便查找已经打开的类 修改后显示样式

Unity——后期处理举例

Post Processing(后期处理)并不属于特效,但现代的特效表现离不开后期处理的支持。本文以眩光(Bloom)为例,展示一种明亮的激光的制作方法 1、安装后期处理扩展包 较新的Unity版本已经内置了新版的后期处理扩…

sftp命令 添加端口(亲测)

要在sftp命令中指定端口&#xff0c;请使用以下语法&#xff1a; sftp -oPort<port_number> <username><host> 其中&#xff0c;<port_number>是你要连接的SFTP服务器的端口号&#xff0c;<username>是登录SFTP服务器所使用的用户名&#xff0…

wxpython:wx.Menu 菜单示例

pip install wxpython4.2 wxPython-4.2.0-cp37-cp37m-win_amd64.whl (18.0 MB) Successfully installed wxpython-4.2.0 cd \Python37\Scripts wxdemo.exe 下载 wxPython-demo-4.2.0.tar.gz wxdocs.exe 下载 wxPython-docs-4.2.0.tar.gz 编写 test_wx_menu.py 如下 # -*- …

创建型模式-建造者模式

使用多个简单的对象一步一步构建成一个复杂的对象 主要解决&#xff1a;主要解决在软件系统中&#xff0c;有时候面临着"一个复杂对象"的创建工作&#xff0c;其通常由各个部分的子对象用一定的算法构成&#xff1b;由于需求的变化&#xff0c;这个复杂对象的各个部…

ARM DIY(三)板载串口和 LCD 调试

前言 今天焊接两大关键输入输出设备&#xff1a;串口和屏幕。 串口 串口部分使用 CP2102N 芯片&#xff08;USB 转 TTL&#xff09;&#xff0c;这样用一根数据线连接板子和 PC 就可以直接调试了。 焊接 CP2102 和 Type C 上电调试&#xff0c;串口可以正常输入输出。 看来…

使用Python写入数据到Excel:实战指南

在数据科学领域&#xff0c;Excel是一种广泛使用的电子表格工具&#xff0c;可以方便地进行数据管理和分析。然而&#xff0c;当数据规模较大或需要自动化处理时&#xff0c;手动操作Excel可能会变得繁琐。此时&#xff0c;使用Python编写程序将数据写入Excel文件是一个高效且便…

【办公类-16-01-02】2023年度上学期“机动班下午代班的排班表——跳过周三、节日和周末”(python 排班表系列)

背景需求&#xff1a; 2023年第一学期&#xff08;2023年9-2024年1月&#xff09;&#xff0c;我又被安排为“机动班”&#xff0c;根据新学期的校历&#xff0c;手动推算本学期的机动班的带班表 排版原则 1、班级数量&#xff1a;共有6个班级&#xff0c;循环滚动 2、每周次…

安装启动yolo5教程

目录 一、下载yolo5项目 二、安装miniconda&#xff08;建议不要安装在C盘&#xff09; 三、安装CUDA 四、安装pytorch 五、修改配置参数 六、修改电脑参数 七、启动项目 博主硬件&#xff1a; Windows 10 家庭中文版 一、下载yolo5项目 GitHub - ultralytics/yolov5:…

使用EF Core更新与修改生产数据库

使用EF Core的Code First&#xff0c;在设计阶段&#xff0c;直接使用Database.EnsureCreated()和EnsureDeleted()可以快速删除、更新最新的数据结构。由于没有什么数据&#xff0c;删除的风险非常低。但是对于已经投入生产的数据库&#xff0c;这个方法就绝对不可行了。 考虑…

【C进阶】指针(一)

大家好&#xff0c;我是深鱼~ 【前言】&#xff1a; 指针的主题&#xff0c;在初阶指针章节已经接触过了&#xff0c;我们知道了指针的概念&#xff1a; 1.指针就是个变量&#xff0c;用来存放地址&#xff0c;地址的唯一标识一块内存空间&#xff08;指针变量&#xff09;&a…