使用DPO微调Llama2

简介

0a8c71720b4dabdfcaae8d442bd45bcc.jpeg

基于人类反馈的强化学习 (Reinforcement Learning from Human Feedback,RLHF) 事实上已成为 GPT-4 或 Claude 等 LLM 训练的最后一步,它可以确保语言模型的输出符合人类在闲聊或安全性等方面的期望。然而,它也给 NLP 引入了一些 RL 相关的复杂性: 既要构建一个好的奖励函数,并训练一个模型用以估计每个状态的价值 (value); 又要注意最终生成的 LLM 不能与原始模型相差太远,如果太远的话会使得模型容易产生乱码而非有意义的文本。该过程非常复杂,涉及到许多复杂的组件,而这些组件本身在训练过程中又是动态变化的,因此把它们料理好并不容易。

Rafailov、Sharma、Mitchell 等人最近发表了一篇论文 Direct Preference Optimization,论文提出将现有方法使用的基于强化学习的目标转换为可以通过简单的二元交叉熵损失直接优化的目标,这一做法大大简化了 LLM 的提纯过程。

本文介绍了直接偏好优化 (Direct Preference Optimization,DPO) 法,该方法现已集成至 TRL 库 中。同时,我们还展示了如何在 stack-exchange preference 数据集上微调最新的 Llama v2 7B 模型, stack-exchange preference 数据集中包含了各个 stack-exchange 门户上的各种问题及其排序后的回答。

DPO 与 PPO

在通过 RL 优化人类衍生偏好时,一直以来的传统做法是使用一个辅助奖励模型来微调目标模型,以通过 RL 机制最大化目标模型所能获得的奖励。直观上,我们使用奖励模型向待优化模型提供反馈,以促使它多生成高奖励输出,少生成低奖励输出。同时,我们使用冻结的参考模型来确保输出偏差不会太大,且继续保持输出的多样性。这通常需要在目标函数设计时,除了奖励最大化目标外再添加一个相对于参考模型的 KL 惩罚项,这样做有助于防止模型学习作弊或钻营奖励模型。

DPO 绕过了建模奖励函数这一步,这源于一个关键洞见: 从奖励函数到最优 RL 策略的分析映射。这个映射直观地度量了给定奖励函数与给定偏好数据的匹配程度。有了它,作者就可与将基于奖励和参考模型的 RL 损失直接转换为仅基于参考模型的损失,从而直接在偏好数据上优化语言模型!因此,DPO 从寻找最小化 RLHF 损失的最佳方案开始,通过改变参量的方式推导出一个 仅需 参考模型的损失!

有了它,我们可以直接优化该似然目标,而不需要奖励模型或繁琐的强化学习优化过程。

如何使用 TRL 进行训练

如前所述,一个典型的 RLHF 流水线通常包含以下几个环节:

  1. 有监督微调 (supervised fine-tuning,SFT)
  2. 用偏好标签标注数据
  3. 基于偏好数据训练奖励模型
  4. RL 优化

TRL 库包含了所有这些环节所需的工具程序。而 DPO 训练直接消灭了奖励建模和 RL 这两个环节 (环节 3 和 4),直接根据标注好的偏好数据优化 DPO 目标。

使用 DPO,我们仍然需要执行环节 1,但我们仅需在 TRL 中向 DPOTrainer 提供环节 2 准备好的偏好数据,而不再需要环节 3 和 4。标注好的偏好数据需要遵循特定的格式,它是一个含有以下 3 个键的字典:

  • prompt : 即推理时输入给模型的提示
  • chosen : 即针对给定提示的较优回答
  • rejected :  即针对给定提示的较劣回答或非给定提示的回答

例如,对于 stack-exchange preference 数据集,我们可以通过以下工具函数将数据集中的样本映射至上述字典格式并删除所有原始列:

def return_prompt_and_responses(samples) -> Dict[str, str, str]:
    return {
        "prompt": [
            "Question: " + question + "\n\nAnswer: "
            for question in samples["question"]
        ],
        "chosen": samples["response_j"], # rated better than k
        "rejected": samples["response_k"], # rated worse than j
    }

dataset = load_dataset(
    "lvwerra/stack-exchange-paired",
    split="train",
    data_dir="data/rl"
)
original_columns = dataset.column_names

dataset.map(
    return_prompt_and_responses,
    batched=True,
    remove_columns=original_columns
)

一旦有了排序数据集,DPO 损失其实本质上就是一种有监督损失,其经由参考模型获得隐式奖励。因此,从上层来看,DPOTrainer 需要我们输入待优化的基础模型以及参考模型:

dpo_trainer = DPOTrainer(
    model, # 经 SFT 的基础模型
    model_ref, # 一般为经 SFT 的基础模型的一个拷贝
    beta=0.1, # DPO 的温度超参
    train_dataset=dataset, # 上文准备好的数据集
    tokenizer=tokenizer, # 分词器
    args=training_args, # 训练参数,如: batch size, 学习率等
)

其中,超参 beta 是 DPO 损失的温度,通常在 0.1 到 0.5 之间。它控制了我们对参考模型的关注程度,beta 越小,我们就越忽略参考模型。对训练器初始化后,我们就可以简单调用以下方法,使用给定的 training_args 在给定数据集上进行训练了:

dpo_trainer.train()

基于 Llama v2 进行实验

在 TRL 中实现 DPO 训练器的好处是,人们可以利用 TRL 及其依赖库 (如 Peft 和 Accelerate) 中已有的 LLM 相关功能。有了这些库,我们甚至可以使用 bitsandbytes 库提供的 QLoRA 技术 来训练 Llama v2 模型。

有监督微调

如上文所述,我们先用 TRL 的 SFTTrainer 在 SFT 数据子集上使用 QLoRA 对 7B Llama v2 模型进行有监督微调:

# load the base model in 4-bit quantization
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

base_model = AutoModelForCausalLM.from_pretrained(
    script_args.model_name, # "meta-llama/Llama-2-7b-hf"
    quantization_config=bnb_config,
    device_map={"": 0},
    trust_remote_code=True,
    use_auth_token=True,
)
base_model.config.use_cache = False

# add LoRA layers on top of the quantized base model
peft_config = LoraConfig(
    r=script_args.lora_r,
    lora_alpha=script_args.lora_alpha,
    lora_dropout=script_args.lora_dropout,
    target_modules=["q_proj", "v_proj"],
    bias="none",
    task_type="CAUSAL_LM",
)
...
trainer = SFTTrainer(
    model=base_model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    peft_config=peft_config,
    packing=True,
    max_seq_length=None,
    tokenizer=tokenizer,
    args=training_args, # HF Trainer arguments
)
trainer.train()

DPO 训练

SFT 结束后,我们保存好生成的模型。接着,我们继续进行 DPO 训练,我们把 SFT 生成的模型作为 DPO 的基础模型和参考模型,并在上文生成的 stack-exchange preference 数据上,以 DPO 为目标函数训练模型。我们选择对模型进行 LoRa 微调,因此我们使用 Peft 的 AutoPeftModelForCausalLM 函数加载模型:

model = AutoPeftModelForCausalLM.from_pretrained(
    script_args.model_name_or_path, # location of saved SFT model
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16,
    load_in_4bit=True,
    is_trainable=True,
)
model_ref = AutoPeftModelForCausalLM.from_pretrained(
    script_args.model_name_or_path, # same model as the main one
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16,
    load_in_4bit=True,
)
...
dpo_trainer = DPOTrainer(
    model,
    model_ref,
    args=training_args,
    beta=script_args.beta,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    tokenizer=tokenizer,
    peft_config=peft_config,
)
dpo_trainer.train()
dpo_trainer.save_model()

可以看出,我们以 4 比特的方式加载模型,然后通过 peft_config 参数选择 QLora 方法对其进行训练。训练器还会用评估数据集评估训练进度,并报告一些关键指标,例如可以选择通过 WandB 记录并显示隐式奖励。最后,我们可以将训练好的模型推送到 HuggingFace Hub。

总结

SFT 和 DPO 训练脚本的完整源代码可在该目录 examples/stack_llama_2 处找到,训好的已合并模型也已上传至 HF Hub (见 此处)。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/52636.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

linux篇---使用systemctl start xxx启动自己的程序|开机启动|守护进程

linux篇---使用systemctl start xxx启动自己的程序|开机启动|守护进程 1、创建服务2、修改权限3、启动服务4、测试 机器:Nvidia Jetson Xavier系统:ubuntu 18.04 最近在使用symfony的console组件,需要执行一个后台的php进程,并且…

C++ Day5

目录 一、静态成员 1.1 概念 1.2 格式 1.3 银行账户实例 二、类的继承 2.1 目的 2.2 概念 2.3 格式 2.4 继承方式 2.5 继承中的特殊成员函数 2.5.1 构造函数 2.5.2析构函数 2.5.3 拷贝构造函数 2.5.4拷贝赋值函数 总结: 三、多继承 3.1 概念 3.2 格…

k8s--基本概念理解

必填字段 在要创建的 Kubernetes 对象的文件中.yaml,您需要设置以下字段的值: apiVersion- 您使用哪个版本的 Kubernetes API 创建此对象 kind- 你想创建什么样的对象 metadata- 有助于唯一标识对象的数据,包括name字符串、UID和可选namesp…

大数据领域都有什么发展方向

近年来越来越多的人选择大数据行业,大数据行业前景不错薪资待遇好,各大名企对于大数据人才需求不断上涨。 大数据从业领域很宽广,不管是科技领域还是食品产业,零售业等都是需要大数据人才进行大数据的处理,以提供更好…

Python调用文心一言的API

最近申请了文心一言的key,然后尝试调用了一下文心一言,这里使用一个简单的方式来调用文心一言: pip install paddle-pipelinesfrom pipelines.nodes import ErnieBotapi_key "your apply key" secret_key "your apply secr…

【python】Leetcode(primer-set)

文章目录 78. 子集(集合的所有子集)90. 子集 II(集合的所有子集)792. 匹配子序列的单词数(判断是否为子集)500. 键盘行(集合的交集)409. 最长回文串(set) 更多…

测试先行:探索测试驱动开发的深层价值

引言 在软件开发的世界中,如何确保代码的质量和可维护性始终是一个核心议题。测试驱动开发(TDD)为此提供了一个答案。与传统的开发方法相比,TDD鼓励开发者从用户的角度出发,先定义期望的结果,再进行实际的开发。这种方法不仅可以确保代码满足预期的需求,还可以在整个开…

IBM Db2 笔记

目录 1. IBM Db2 笔记1.1. 常用命令1.2. 登录命令行模式 (Using the Db2 command line processor)1.3. issue1.3.1. db2: command not found/SQL10007N Message "-1390" could not be retreived. Reason code: "3".1.3.2. db2 修改 dbm cfg 的时候报 SQL50…

C++ Day4

目录 一、拷贝赋值函数 1.1 作用 1.2 格式 二、匿名对象 2.2 格式 三、友元 3.1作用 3.2格式 3.3 种类 3.4 全局函数做友元 3.5类做友元 3.6 成员函数做友元 3.7注意 四、常成员函数和常对象 4.1 常成员函数 4.1.1格式 示例: 4.2 常对象 作用&…

shell学习笔记(详细整理)

主要介绍:主要是常用变量,运算符,条件判断,流程控制,函数,常用shell工具(cut,sed,awk,sort)。 一. Shell概述 程序员为什么要学习Shell呢? 1)需要看懂运维人员编写的Shell程序。 2…

恒运资本:什么是股票分红和基金分红?两者有什么区别?

出资者在进行股票出资和基金出资的时分,能够参与股票分红或许基金分红,可是许多新手对分红不是很了解。那么什么是股票分红和基金分红?两者有什么区别?下面就由恒运资本为大家分析: 什么是股票分红和基金分红&#xff…

LeGO-Loam代码解析(二)--- Lego-LOAM的地面点分离、聚类、两步优化方法

1 地面点分离剔除方法 1.1 数学推导 LeGO-LOAM 中前端改进中很重要的一点就是充分利用了地面点,那首先自然是提取 对地面点的提取。 如上图,相邻的两个扫描线束的同一列打在地面上如 点所示,他们的垂直高度差 ,水平距离差 ,计算垂直高度差和水平高度差…

n5173b是德科技keysight N5173B信号发生器

产品概述 是德科技/安捷伦N5173B EXG模拟信号发生器 当您需要平衡预算和性能时,是德科技N5173B EXG微波模拟信号发生器是经济高效的选择。它提供解决宽带滤波器、放大器、接收机等参数测试的基本信号。执行基本LO上变频或CW阻塞,低成本覆盖13、20、31.…

Dynamic CRM开发 - 实体字段(三)计算字段

在 Dynamic CRM开发 - 实体字段(一)中提到了实体字段的属性字段类型:有简单、计算、汇总三种,本篇幅通过一个示例讲解计算字段。 有这样一个需求:根据用户填写的出生日期,计算年龄。 1、创建一个“出生日期”字段,时间类型即可。 2、创建一个计算字段“年龄”,如下图…

ChatGPT:ChatGPT 的发展史,ChatGPT 优缺点以及ChatGPT 在未来生活中的发展趋势和应用

目录 1.ChatGPT 是什么 2. ChatGPT 的发展史 3.ChatGPT 优缺点 4.ChatGPT 在未来生活中的发展趋势和应用 5.ChatGPT经历了几个版本 1.ChatGPT 是什么 ChatGPT 是一个在线聊天机器人,可以与使用者进行语义对话和提供帮助。它可以回答各种问题,提供建议…

数据库(DQL,多表设计,事务,索引)

目录 查询数据库表中数据 where 条件列表 group by 分组查询 having 分组后条件列表 order by 排序字段列表 limit 分页参数 多表设计 一对多 多对多 一对一 多表查询 事物 索引 查询数据库表中数据 关键字:SELECT 中间有空格,加引…

按斤称的C++散知识

一、多线程 std::thread()、join() 的用法&#xff1a;使用std::thread()可以创建一个线程&#xff0c;同时指定线程执行函数以及参数&#xff0c;同时也可使用lamda表达式。 #include <iostream> #include <thread>void threadFunction(int num) {std::cout <…

SpringBoot整合FFmpeg进行视频分片上传(Linux)

SpringBoot整合FFmpeg进行视频分片上传&#xff08;Linux&#xff09; 上传的核心思路&#xff1a; 1.将文件按一定的分割规则&#xff08;静态或动态设定&#xff0c;如手动设置20M为一个分片&#xff09;&#xff0c;用slice分割成多个数据块。 2.为每个文件生成一个唯一标识…

给计算机专业准大学生的一些建议

随着互联网的发展&#xff0c;计算机专业也越来越受到了广泛的关注。在这个信息爆炸的时代&#xff0c;互联网为计算机专业的学生提供了更多的机遇和挑战。那么&#xff0c;对于目前互联网形式给计算机专业准大学生&#xff0c;我想分享以下几个建议。 一、掌握基础知识 在学…

【C++】—— C++11新特性之 “右值引用和移动语义”

前言&#xff1a; 本期&#xff0c;我们将要的介绍有关 C右值引用 的相关知识。对于本期知识内容&#xff0c;大家是必须要能够掌握的&#xff0c;在面试中是属于重点考察对象。 目录 &#xff08;一&#xff09;左值引用和右值引用 1、什么是左值&#xff1f;什么是左值引用…