使用DPO微调Llama2

简介

0a8c71720b4dabdfcaae8d442bd45bcc.jpeg

基于人类反馈的强化学习 (Reinforcement Learning from Human Feedback,RLHF) 事实上已成为 GPT-4 或 Claude 等 LLM 训练的最后一步,它可以确保语言模型的输出符合人类在闲聊或安全性等方面的期望。然而,它也给 NLP 引入了一些 RL 相关的复杂性: 既要构建一个好的奖励函数,并训练一个模型用以估计每个状态的价值 (value); 又要注意最终生成的 LLM 不能与原始模型相差太远,如果太远的话会使得模型容易产生乱码而非有意义的文本。该过程非常复杂,涉及到许多复杂的组件,而这些组件本身在训练过程中又是动态变化的,因此把它们料理好并不容易。

Rafailov、Sharma、Mitchell 等人最近发表了一篇论文 Direct Preference Optimization,论文提出将现有方法使用的基于强化学习的目标转换为可以通过简单的二元交叉熵损失直接优化的目标,这一做法大大简化了 LLM 的提纯过程。

本文介绍了直接偏好优化 (Direct Preference Optimization,DPO) 法,该方法现已集成至 TRL 库 中。同时,我们还展示了如何在 stack-exchange preference 数据集上微调最新的 Llama v2 7B 模型, stack-exchange preference 数据集中包含了各个 stack-exchange 门户上的各种问题及其排序后的回答。

DPO 与 PPO

在通过 RL 优化人类衍生偏好时,一直以来的传统做法是使用一个辅助奖励模型来微调目标模型,以通过 RL 机制最大化目标模型所能获得的奖励。直观上,我们使用奖励模型向待优化模型提供反馈,以促使它多生成高奖励输出,少生成低奖励输出。同时,我们使用冻结的参考模型来确保输出偏差不会太大,且继续保持输出的多样性。这通常需要在目标函数设计时,除了奖励最大化目标外再添加一个相对于参考模型的 KL 惩罚项,这样做有助于防止模型学习作弊或钻营奖励模型。

DPO 绕过了建模奖励函数这一步,这源于一个关键洞见: 从奖励函数到最优 RL 策略的分析映射。这个映射直观地度量了给定奖励函数与给定偏好数据的匹配程度。有了它,作者就可与将基于奖励和参考模型的 RL 损失直接转换为仅基于参考模型的损失,从而直接在偏好数据上优化语言模型!因此,DPO 从寻找最小化 RLHF 损失的最佳方案开始,通过改变参量的方式推导出一个 仅需 参考模型的损失!

有了它,我们可以直接优化该似然目标,而不需要奖励模型或繁琐的强化学习优化过程。

如何使用 TRL 进行训练

如前所述,一个典型的 RLHF 流水线通常包含以下几个环节:

  1. 有监督微调 (supervised fine-tuning,SFT)
  2. 用偏好标签标注数据
  3. 基于偏好数据训练奖励模型
  4. RL 优化

TRL 库包含了所有这些环节所需的工具程序。而 DPO 训练直接消灭了奖励建模和 RL 这两个环节 (环节 3 和 4),直接根据标注好的偏好数据优化 DPO 目标。

使用 DPO,我们仍然需要执行环节 1,但我们仅需在 TRL 中向 DPOTrainer 提供环节 2 准备好的偏好数据,而不再需要环节 3 和 4。标注好的偏好数据需要遵循特定的格式,它是一个含有以下 3 个键的字典:

  • prompt : 即推理时输入给模型的提示
  • chosen : 即针对给定提示的较优回答
  • rejected :  即针对给定提示的较劣回答或非给定提示的回答

例如,对于 stack-exchange preference 数据集,我们可以通过以下工具函数将数据集中的样本映射至上述字典格式并删除所有原始列:

def return_prompt_and_responses(samples) -> Dict[str, str, str]:
    return {
        "prompt": [
            "Question: " + question + "\n\nAnswer: "
            for question in samples["question"]
        ],
        "chosen": samples["response_j"], # rated better than k
        "rejected": samples["response_k"], # rated worse than j
    }

dataset = load_dataset(
    "lvwerra/stack-exchange-paired",
    split="train",
    data_dir="data/rl"
)
original_columns = dataset.column_names

dataset.map(
    return_prompt_and_responses,
    batched=True,
    remove_columns=original_columns
)

一旦有了排序数据集,DPO 损失其实本质上就是一种有监督损失,其经由参考模型获得隐式奖励。因此,从上层来看,DPOTrainer 需要我们输入待优化的基础模型以及参考模型:

dpo_trainer = DPOTrainer(
    model, # 经 SFT 的基础模型
    model_ref, # 一般为经 SFT 的基础模型的一个拷贝
    beta=0.1, # DPO 的温度超参
    train_dataset=dataset, # 上文准备好的数据集
    tokenizer=tokenizer, # 分词器
    args=training_args, # 训练参数,如: batch size, 学习率等
)

其中,超参 beta 是 DPO 损失的温度,通常在 0.1 到 0.5 之间。它控制了我们对参考模型的关注程度,beta 越小,我们就越忽略参考模型。对训练器初始化后,我们就可以简单调用以下方法,使用给定的 training_args 在给定数据集上进行训练了:

dpo_trainer.train()

基于 Llama v2 进行实验

在 TRL 中实现 DPO 训练器的好处是,人们可以利用 TRL 及其依赖库 (如 Peft 和 Accelerate) 中已有的 LLM 相关功能。有了这些库,我们甚至可以使用 bitsandbytes 库提供的 QLoRA 技术 来训练 Llama v2 模型。

有监督微调

如上文所述,我们先用 TRL 的 SFTTrainer 在 SFT 数据子集上使用 QLoRA 对 7B Llama v2 模型进行有监督微调:

# load the base model in 4-bit quantization
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

base_model = AutoModelForCausalLM.from_pretrained(
    script_args.model_name, # "meta-llama/Llama-2-7b-hf"
    quantization_config=bnb_config,
    device_map={"": 0},
    trust_remote_code=True,
    use_auth_token=True,
)
base_model.config.use_cache = False

# add LoRA layers on top of the quantized base model
peft_config = LoraConfig(
    r=script_args.lora_r,
    lora_alpha=script_args.lora_alpha,
    lora_dropout=script_args.lora_dropout,
    target_modules=["q_proj", "v_proj"],
    bias="none",
    task_type="CAUSAL_LM",
)
...
trainer = SFTTrainer(
    model=base_model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    peft_config=peft_config,
    packing=True,
    max_seq_length=None,
    tokenizer=tokenizer,
    args=training_args, # HF Trainer arguments
)
trainer.train()

DPO 训练

SFT 结束后,我们保存好生成的模型。接着,我们继续进行 DPO 训练,我们把 SFT 生成的模型作为 DPO 的基础模型和参考模型,并在上文生成的 stack-exchange preference 数据上,以 DPO 为目标函数训练模型。我们选择对模型进行 LoRa 微调,因此我们使用 Peft 的 AutoPeftModelForCausalLM 函数加载模型:

model = AutoPeftModelForCausalLM.from_pretrained(
    script_args.model_name_or_path, # location of saved SFT model
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16,
    load_in_4bit=True,
    is_trainable=True,
)
model_ref = AutoPeftModelForCausalLM.from_pretrained(
    script_args.model_name_or_path, # same model as the main one
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16,
    load_in_4bit=True,
)
...
dpo_trainer = DPOTrainer(
    model,
    model_ref,
    args=training_args,
    beta=script_args.beta,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    tokenizer=tokenizer,
    peft_config=peft_config,
)
dpo_trainer.train()
dpo_trainer.save_model()

可以看出,我们以 4 比特的方式加载模型,然后通过 peft_config 参数选择 QLora 方法对其进行训练。训练器还会用评估数据集评估训练进度,并报告一些关键指标,例如可以选择通过 WandB 记录并显示隐式奖励。最后,我们可以将训练好的模型推送到 HuggingFace Hub。

总结

SFT 和 DPO 训练脚本的完整源代码可在该目录 examples/stack_llama_2 处找到,训好的已合并模型也已上传至 HF Hub (见 此处)。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/52636.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

linux篇---使用systemctl start xxx启动自己的程序|开机启动|守护进程

linux篇---使用systemctl start xxx启动自己的程序|开机启动|守护进程 1、创建服务2、修改权限3、启动服务4、测试 机器:Nvidia Jetson Xavier系统:ubuntu 18.04 最近在使用symfony的console组件,需要执行一个后台的php进程,并且…

C++ Day5

目录 一、静态成员 1.1 概念 1.2 格式 1.3 银行账户实例 二、类的继承 2.1 目的 2.2 概念 2.3 格式 2.4 继承方式 2.5 继承中的特殊成员函数 2.5.1 构造函数 2.5.2析构函数 2.5.3 拷贝构造函数 2.5.4拷贝赋值函数 总结: 三、多继承 3.1 概念 3.2 格…

大数据领域都有什么发展方向

近年来越来越多的人选择大数据行业,大数据行业前景不错薪资待遇好,各大名企对于大数据人才需求不断上涨。 大数据从业领域很宽广,不管是科技领域还是食品产业,零售业等都是需要大数据人才进行大数据的处理,以提供更好…

【python】Leetcode(primer-set)

文章目录 78. 子集(集合的所有子集)90. 子集 II(集合的所有子集)792. 匹配子序列的单词数(判断是否为子集)500. 键盘行(集合的交集)409. 最长回文串(set) 更多…

测试先行:探索测试驱动开发的深层价值

引言 在软件开发的世界中,如何确保代码的质量和可维护性始终是一个核心议题。测试驱动开发(TDD)为此提供了一个答案。与传统的开发方法相比,TDD鼓励开发者从用户的角度出发,先定义期望的结果,再进行实际的开发。这种方法不仅可以确保代码满足预期的需求,还可以在整个开…

恒运资本:什么是股票分红和基金分红?两者有什么区别?

出资者在进行股票出资和基金出资的时分,能够参与股票分红或许基金分红,可是许多新手对分红不是很了解。那么什么是股票分红和基金分红?两者有什么区别?下面就由恒运资本为大家分析: 什么是股票分红和基金分红&#xff…

LeGO-Loam代码解析(二)--- Lego-LOAM的地面点分离、聚类、两步优化方法

1 地面点分离剔除方法 1.1 数学推导 LeGO-LOAM 中前端改进中很重要的一点就是充分利用了地面点,那首先自然是提取 对地面点的提取。 如上图,相邻的两个扫描线束的同一列打在地面上如 点所示,他们的垂直高度差 ,水平距离差 ,计算垂直高度差和水平高度差…

n5173b是德科技keysight N5173B信号发生器

产品概述 是德科技/安捷伦N5173B EXG模拟信号发生器 当您需要平衡预算和性能时,是德科技N5173B EXG微波模拟信号发生器是经济高效的选择。它提供解决宽带滤波器、放大器、接收机等参数测试的基本信号。执行基本LO上变频或CW阻塞,低成本覆盖13、20、31.…

Dynamic CRM开发 - 实体字段(三)计算字段

在 Dynamic CRM开发 - 实体字段(一)中提到了实体字段的属性字段类型:有简单、计算、汇总三种,本篇幅通过一个示例讲解计算字段。 有这样一个需求:根据用户填写的出生日期,计算年龄。 1、创建一个“出生日期”字段,时间类型即可。 2、创建一个计算字段“年龄”,如下图…

数据库(DQL,多表设计,事务,索引)

目录 查询数据库表中数据 where 条件列表 group by 分组查询 having 分组后条件列表 order by 排序字段列表 limit 分页参数 多表设计 一对多 多对多 一对一 多表查询 事物 索引 查询数据库表中数据 关键字:SELECT 中间有空格,加引…

【C++】—— C++11新特性之 “右值引用和移动语义”

前言: 本期,我们将要的介绍有关 C右值引用 的相关知识。对于本期知识内容,大家是必须要能够掌握的,在面试中是属于重点考察对象。 目录 (一)左值引用和右值引用 1、什么是左值?什么是左值引用…

PID直观感受简述

0、仿真控制框图 1、增加p的作用(增加响应)P 2、增加I的作用(消除稳差)PI 3、增加D的作用(抑制波动)PID 加入对噪声很敏 4、综合比对

linux中定时器的使用

在Linux中&#xff0c;可以使用timer_create、timer_settime和timer_delete等函数来创建和管理定时器。下面是一个简单的示例程序&#xff0c;演示如何在Linux中使用定时器&#xff1a; #include <stdio.h> #include <stdlib.h> #include <signal.h> #inclu…

STL---list

目录 1. list的介绍及使用 1.1 list的介绍 1.2 list的使用注意事项 2.list接口介绍及模拟实现 2.1构造​编辑 2.2容量 2.3修改 3.list迭代器 4.迭代器失效 5.模拟实现 6.vector和list的区别 1. list的介绍及使用 1.1 list的介绍 list的文档介绍 1. list是可以在常…

数据库第十七课-------ETL任务调度系统的安装和使用

作者前言 &#x1f382; ✨✨✨✨✨✨&#x1f367;&#x1f367;&#x1f367;&#x1f367;&#x1f367;&#x1f367;&#x1f367;&#x1f382; ​&#x1f382; 作者介绍&#xff1a; &#x1f382;&#x1f382; &#x1f382; &#x1f389;&#x1f389;&#x1f389…

Jenkins配置远程服务器SSH Server流程

说明&#xff1a;以阿里云轻量应用服务器&#xff0c;本文介绍如何在Jenkins中配置远程服务器&#xff0c;Jenkins安装参考这篇文章&#xff1b; 第一步&#xff1a;启动服务 首先&#xff0c;启动Jenkins容器&#xff0c;进入Jenkins管理后台&#xff0c;点击系统配置&#…

echarts 的dataZoom滑块两端文字被遮挡

问题&#xff1a; 期望&#xff1a; 解决方案&#xff1a; 1&#xff1a;调整宽度&#xff08;4版本的没有width属性&#xff09; 2. 参考&#xff1a;echarts图标设置dataZoom拖拽时间轴时自动调整两侧文字的位置_datazoom 位置_乌栖曲的博客-CSDN博客 设置文字的定位 cons…

物联网(IoT)安全挑战与解决方案: 分析物联网设备面临的安全威胁,以及如何设计和管理安全的IoT生态系统

第一章&#xff1a;引言 随着科技的飞速发展&#xff0c;物联网&#xff08;IoT&#xff09;作为连接世界的桥梁&#xff0c;已经成为现代社会不可或缺的一部分。然而&#xff0c;随着IoT设备数量的不断增加&#xff0c;其安全问题也日益显著。本文将深入探讨IoT领域面临的安全…

暄桐展览| 我们桐学有自己的习作展(1)

林曦老师《从书法之美到生活之美》的第五阶课程《静定的滋养2021》已告一段落。570天的用功&#xff0c;桐学们的技艺都有了水涨船高的进益。      无论书法课&#xff08;全阶和五阶&#xff09;还是国画课&#xff0c;暄桐都有一套完整系统的教学体系&#xff0c;也会在桐…

Java | IDEA中Netty运行多个client的方法

想要运行多个client但出现这种提示&#xff1a; 解决方法 1、打开IDEA&#xff0c;右上角找到下图&#xff0c;并点击 2、勾选