使用DPO微调大模型Qwen2详解

简介

基于人类反馈的强化学习 (Reinforcement Learning from Human Feedback,RLHF) 事实上已成为 GPT-4 或 Claude 等 LLM 训练的最后一步,它可以确保语言模型的输出符合人类在闲聊或安全性等方面的期望。但传统的RLHF比较复杂,且还需要奖励模型,故DPO方法被提出,其将现有方法使用的基于强化学习的目标转换为可以通过简单的二元交叉熵损失直接优化的目标,这一做法大大简化了 LLM 的提纯过程。
且huggingface的trl库已经集成了dpo,使用起来非常方便。

本次以QWEN2(蹭热点),为例进行训练,分别介绍单轮对话的DPO多轮对话的DPO,对应的数据集分别如下(均在huggingface):

  • 单轮:lvwerra/stack-exchange-paired
  • 多轮:trl-internal-testing/hh-rlhf-helpful-base-trl-style

通过DPO微调模型大概可以简单的分为两个步骤:
1、将数据处理成所需格式。
2、使用DPOTrainer进行训练

两种形式的dpo代码已集成至github上的大模型训练框架,并做了详细的使用解释及代码位置说明,可见:https://github.com/mst272/LLM-Dojo/tree/main/train_args/dpo

项目包括一个每个人都可以以此为基础构建自己的开源大模型训练框架流程、支持主流模型使用deepspeed进行Lora、Qlora、DPO等训练、主流模型的chat template模版、以及一些tricks的从零实现模块。欢迎大家star 共同学习!:

单轮对话构建DpoDataset

标准的DpoDataset数据集,最终的数据集对象应包含这3个条目。条目应命名为:

  • prompt
  • chosen
  • rejected

官方示例

单轮官方示例如下:

dpo_dataset_dict = {"prompt": ["hello","how are you","What is your name?","What is your name?","Which is the best programming language?","Which is the best programming language?","Which is the best programming language?",],"chosen": ["hi nice to meet you","I am fine","My name is Mary","My name is Mary","Python","Python","Java",],"rejected": ["leave me alone","I am not fine","Whats it to you?","I dont have a name","Javascript","C++","C++",],
}

多轮示例为上述提到的数据集,大家可以大概看一下是长这个样子:
在这里插入图片描述

从头开始构建

比较简单的方式是套用官方给的示例,如下所示,只需要将数据集映射为上述我们提到的prompt、chosen、rejected格式,此时传递给DPOTrainer的数据是未编码之前的,DPOTrainer中会自动的给我们进行编码。注意下面并没有添加对应模型的chat template,根据不同模型的template可以在return_prompt_and_responses中自行添加即可。

def return_prompt_and_responses(samples) -> Dict[str, str, str]:return {"prompt": ["Question: " + question + "\n\nAnswer: "for question in samples["question"]],"chosen": samples["response_j"], # rated better than k"rejected": samples["response_k"], # rated worse than j}dataset = load_dataset("lvwerra/stack-exchange-paired",split="train",data_dir="data/rl"
)
original_columns = dataset.column_namesdataset.map(return_prompt_and_responses,batched=True,remove_columns=original_columns
)dpo_trainer = DPOTrainer(model, # 经 SFT 的基础模型model_ref, # 一般为经 SFT 的基础模型的一个拷贝beta=0.1, # DPO 的温度超参train_dataset=dataset, # 上文准备好的数据集tokenizer=tokenizer, # 分词器args=training_args, # 训练参数,如: batch size, 学习率等
)

为了便于我们理解数据处理细节及进行一些魔改操作,我们可以从头自己构建一个DpoDataset。
首先,深入DPOTrainer源码可以看到其数据处理操作主要是在tokenize_row函数,如下所示,
在这里插入图片描述
最终返回的是一个batch字典字段,代码部分如下所示:
在这里插入图片描述
在这里插入图片描述
最终返回的字段为:

dict(prompt_input_ids,prompt_attention_mask,chosen_input_ids,chosen_attention_mask,chosen_labels,rejected_input_ids,rejected_attention_mask,rejected_labels,)

主要的__getitem__代码如下所示:

    def __getitem__(self, item):data = self.data_list[item]data = json.loads(data)  # 将json格式转换为python字典prompt =  data['prompt']chosen = data['chosen']rejected = data['rejected']# 对prompt进行编码prompt = self.user_format.format(content=prompt, stop_token=self.tokenizer.eos_token)if self.system_format is not None:system = self.systemif system is not None:system_text = self.system_format.format(content=system)input_ids = self.tokenizer.encode(system_text, add_special_tokens=False)prompt_input_ids = input_ids + self.tokenizer.encode(prompt, add_special_tokens=False)else:prompt_input_ids = self.tokenizer.encode(prompt, add_special_tokens=False)# 进行回答的input id编码chosen = self.assistant_format.format(content=chosen, stop_token=self.tokenizer.eos_token)rejected = self.assistant_format.format(content=rejected, stop_token=self.tokenizer.eos_token)chosen_input_ids = self.tokenizer.encode(chosen, add_special_tokens=False)rejected_input_ids = self.tokenizer.encode(rejected, add_special_tokens=False)# 对最大长度进行截断longer_response_length = max(len(chosen_input_ids), len(rejected_input_ids))# keep end 对prompt截断if len(prompt_input_ids) + longer_response_length > self.max_seq_length:max_prompt_length = max(self.max_prompt_length, self.max_seq_length - longer_response_length)prompt_input_ids = prompt_input_ids[-max_prompt_length:]# 如果还不符合则回答截断if len(prompt_input_ids) + longer_response_length > self.max_seq_length:chosen_input_ids = chosen_input_ids[: self.max_seq_length - len(prompt_input_ids)]rejected_input_ids = rejected_input_ids[: self.max_seq_length - len(prompt_input_ids)]chosen_labels = [-100] * len(prompt_input_ids) + chosen_input_idschosen_input_ids = prompt_input_ids + chosen_input_idsrejected_labels = [-100] * len(prompt_input_ids) + rejected_input_idsrejected_input_ids = prompt_input_ids + rejected_input_idsassert len(chosen_labels) == len(chosen_input_ids)assert len(rejected_labels) == len(rejected_input_ids)inputs = dict(prompt_input_ids=prompt_input_ids,prompt_attention_mask=[1] * len(prompt_input_ids),chosen_input_ids=chosen_input_ids,chosen_attention_mask=[1] * len(chosen_input_ids),chosen_labels=chosen_labels,rejected_input_ids=rejected_input_ids,rejected_attention_mask=[1] * len(rejected_input_ids),rejected_labels=rejected_labels,)return inputs

适配DPOTrainer

构建完dataset后要适配DPOTrainer,可以看到其需要使用dataset进行一个map操作,这也就是DPOTrainer自动给我们处理数据的入口。
在这里插入图片描述
在我们自建的Dataset类中添加一个map函数映射会self即可:

    def map(self, func, **kwargs):return self

多轮对话构建DpoDataset

多轮对话构建我们这里就不自己去写了,直接采用DPOTrainer中自带的数据处理即可。
部分代码如下所示:

        if tokenizer.chat_template is None:tokenizer.chat_template = "{% for message in messages %}{{message['role'] + ': ' + message['content'] + '\n\n'}}{% endfor %}{{ eos_token }}"train_dataset = load_dataset(data_files=args.train_data_path, path='json')def process(row):row["chosen"] = tokenizer.apply_chat_template(row["chosen"], tokenize=False)row["rejected"] = tokenizer.apply_chat_template(row["rejected"], tokenize=False)return rowtrain_dataset = train_dataset.map(process)train_dataset = train_dataset['train']return train_dataset

完整代码集成至github项目中,具体可参见:

开始Qwen2-8B 多轮和单轮DPO训练

使用DPOTrainer即可开始训练

trainer = DPOTrainer(model,ref_model=None,args=train_args,train_dataset=train_dataset,tokenizer=tokenizer,peft_config=peft_config)
dpo_trainer.train()
dpo_trainer.save_model()

总结

两种形式的dpo代码已集成至github上的大模型训练框架,并做了详细的使用解释及代码位置说明,可见:https://github.com/mst272/LLM-Dojo/tree/main/train_args/dpo

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/25671.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

BabylonJS 6.0文档 Deep Dive 动画(四):通过动画排序制作卡通片

一种最为直接的方法是为每个动画剪辑(Animatin Clip)指定开始时间,最终形成一个卡通动画(Cartoon)。 1. 设计 1.1 概述 动画的脚本如下: 摄像机显示了一栋带门的建筑物。摄像机靠近门并停止。门打开&am…

掌控数据流:深入解析 Java Stream 编程

Java 8 引入了一种新的抽象称为流(Stream),它可以让你以一种声明的方式处理数据。Java 8 Stream API 可以极大提高 Java 程序员的生产力,使代码更简洁,更易读,并利用多核架构进行外部迭代。这里将详细介绍 …

【NoSQL数据库】Redis简介

Redis Redis简介 Redis关系型数据库和非关系型数据库Redis 简介redis速度快的原因 Redis 配置Linux 源码安装redis命令工具 关系型数据库和非关系型数据库 关系型数据库(Relational Database)和非关系型数据库(Non-Relational Database&…

重学Spring总结

1、Spring框架的诞生 文章目录 1、Spring框架的诞生1、BeanFactory 快速入门1.1、BeanFactory完成了loC思想的实现:1)导入Spring相关的依赖:2)定义Uservice接口及其UserviceImpl实现类;3)创建Bean的配置资源文件,文件名最好为&…

新材料正不断推动模具3D打印行业发展

随着工业4.0的浪潮席卷全球,模具制造行业也迎来了技术革新的新纪元。3D打印技术以其独特的制造优势,正逐渐在模具制造领域崭露头角。然而,要实现模具3D打印技术的广泛应用,高性能的打印材料是不可或缺的关键因素。 材料是模具3D打…

【Golang】Map 稳定有序遍历的实现与探索:保序遍历之道

【Golang】Map 稳定有序遍历的实现与探索:保序遍历之道 大家好 我是寸铁👊 总结了一篇【Golang】Map 稳定有序遍历的实现与探索:保序遍历之道✨ 喜欢的小伙伴可以点点关注 💝 前言🍎 在计算机科学中,数据结…

【内存管理】内存管理概述

文章目录 内存管理硬件结构早期内存的使用方法分段分页逻辑地址,线性地址(intel架构)虚拟地址物理地址结构图 虚拟地址到物理地址的转换内存管理总览系统调用vm_area_struct缺页中断伙伴系统slab分配器页面回收反向映射KSMhuge page页迁移内存…

[AI Google] 使用 Gemini 取得更多成就:试用 1.5 Pro 和更多智能功能

总结 Google 正在为超过 35 种语言的 Gemini Advanced 订阅者推出 Gemini 1.5 Pro。此次更新包括 100 万个 token 的上下文窗口、改进的数据分析功能和增强的多模态图像理解。新功能包括用于自然对话的 Gemini Live、先进的规划工具和可定制的 Gems。更新还集成了更多 Google …

【MySQL】(基础篇五) —— 排序检索数据

排序检索数据 本章将讲授如何使用SELECT语句的ORDER BY子句,根据需要排序检索出的数据。 排序数据 还是使用上一节中的例子,查询employees表中的last_name字段 SELECT last_name FROM employees;输出结果: 发现其输出并没有特定的顺序。其实&#xf…

电子电气架构 --- 信息安全测试模糊测试

我是穿拖鞋的汉子,魔都中坚持长期主义的汽车电子工程师。 老规矩,分享一段喜欢的文字,避免自己成为高知识低文化的工程师: 屏蔽力是信息过载时代一个人的特殊竞争力,任何消耗你的人和事,多看一眼都是你的不对。非必要不费力证明自己,无利益不试图说服别人,是精神上的节…

Objective-C的初始化方法中,应该如何读写属性

除非有明确的原因需要使用setter, getter, 否则总是应该直接访问, 也就是直接使用实例变量(也称为 iVar)来读写数据 理由: 避免子类覆盖setter方法的影响:若在初始化方法中使用setter方法, 使用此方法实例化子类, 可能会调用子类…

纯理论容器实现的原理

近期在复习容器的原理,希望这篇文章可以帮助到大家。 一、什么是容器? 容器本质上就是主机上的一个进程。这个进程拥有自己的用户空间并且和主机共享内核空间。 容器内的进程可以通过系统调用与内核进行交互,使用内核提供的各种功能和资源。…

刷代码随想录有感(99):动态规划——使用最小花费爬楼梯

题干&#xff1a; 代码&#xff1a; class Solution { public:int minCostClimbingStairs(vector<int>& cost) {vector<int>dp(cost.size() 1);dp[0] 0;dp[1] 0;for(int i 2; i < cost.size(); i){dp[i] min(dp[i - 1] cost[i - 1], dp[i - 2] cost…

Leetcode 力扣114. 二叉树展开为链表 (抖音号:708231408)

给你二叉树的根结点 root &#xff0c;请你将它展开为一个单链表&#xff1a; 展开后的单链表应该同样使用 TreeNode &#xff0c;其中 right 子指针指向链表中下一个结点&#xff0c;而左子指针始终为 null 。展开后的单链表应该与二叉树 先序遍历 顺序相同。 示例 1&#xf…

KUKA机器人中断编程详细教程1—了解中断

在公众号查看更多内容。 在KUKA机器人编程与调试中&#xff0c;经常会用到中断编程。通过中断实现机器人暂停&#xff0c;或者停止当前的动作进入中断后的程序中接着运行&#xff0c;以此来满足实际的调试要求。 1、中断的概念 ①当出现诸如输入等定义的事件时&#xff0c;…

【算法篇】求最长公共前缀JavaScript版本

题目描述 给你一个大小为 n 的字符串数组 strs &#xff0c;其中包含n个字符串 , 编写一个函数来查找字符串数组中的最长公共前缀&#xff0c;返回这个公共前缀。 数据范围&#xff1a; 数据范围:0<n<5000&#xff0c;0<len(strsi)< 5000 进阶:空间复杂度 O(1)&a…

Typora Markdown编辑器 for Mac v1.8.10 安装

Mac分享吧 文章目录 效果一、准备工作二、开始安装1、双击运行软件&#xff0c;将其从左侧拖入右侧文件夹中&#xff0c;等待安装完毕2. 应用程序显示软件图标&#xff0c;表示安装成功 三、运行调试1、修改主题2、显示文档列表&#xff0c;如下图3、查看版本信息 **安装完成&…

【PR2019】怎样批量添加转场效果及修改默认持续时间

一&#xff0c;设置“交叉溶解”效果到所有素材 选择效果&#xff0c;右击“将所选过渡设置为默认过渡”&#xff1a; 框选所有素材&#xff0c;“Ctrl D”&#xff1a; 每个素材中间有有了交叉溶解的效果&#xff1a; 二&#xff0c;修改效果属性 2.1&#xff0c;单个修…

北航第五次数据结构与程序设计编程题复习

北航第五次数据结构与程序设计编程题复习 树叶节点遍历&#xff08;树-基础题&#xff09;计算器&#xff08;表达式计算-表达式树实现&#xff09;服务优化词频统计&#xff08;树实现&#xff09; 树叶节点遍历&#xff08;树-基础题&#xff09; 【问题描述】 从标准输入中…

CTFHUB-SQL注入-报错注入

目录 报错注入概述 报错注入的原理 报错注入的步骤 报错注入的常用函数 实战案例 结论 方法1&#xff1a;updatexml函数 查看数据库名 查看表名 查看表中数据 方法2&#xff1a;extractvalue函数 查看数据库名 查看数据库中的表名 查看表中字段名 查看表中数据 报…