[NLP] LLM---<训练中文LLama2(五)>对SFT后的LLama2进行DPO训练

当前关于LLM的共识

大型语言模型(LLM)使 NLP 中微调模型的过程变得更加复杂。最初,当 ChatGPT 等模型首次出现时,最主要的方法是先训练奖励模型,然后优化 LLM 策略。从人类反馈中强化学习(RLHF)极大地推动了NLP的发展,并将NLP中许多长期面临的挑战抛在了一边。基于人类反馈的强化学习 (Reinforcement Learning from Human Feedback,RLHF) 事实上已成为 GPT-4 或 Claude 等 LLM 训练的最后一步,它可以确保语言模型的输出符合人类在闲聊或安全性等方面的期望。

然而,它也给 NLP 引入了一些 RL 相关的复杂性: 既要构建一个好的奖励函数,并训练一个模型用以估计每个状态的价值 (value); 又要注意最终生成的 LLM 不能与原始模型相差太远,如果太远的话会使得模型容易产生乱码而非有意义的文本。该过程非常复杂,涉及到许多复杂的组件,而这些组件本身在训练过程中又是动态变化的,因此把它们料理好并不容易。

现在主流的LLM,比如chatglm、chinese-alpaca,主要进行了三步操作:

Step1:知识学习,CLM,大规模语料库上的预训练,本步的模型拥有续写的功能

Step2:知识表达,指令微调,在指令数据上进行微调,本步骤可以使用Lora等节省显存的方式,本模型可以听懂人类指令并进行回答的功能

Step3:偏好学习,RLHF或本文所提的DPO,可以让模型的输出更符合人类偏好,通俗说就是同样一句话,得调教的让模型输出人类喜欢的表达方式,好比高情商的人说话让人舒服

第二步,还是多多少少学习了一点知识,第三步则几乎不学知识,只学表达方式了。

RLHF太耗时耗力了,得提前训练好RewardModel,然后PPO阶段,得加载4个模型,2个推理,2个训练,实在是太不友好了。

下图是SFT+RLHF的过程,对应上文的Step2和Step3,主要包括指令微调模型、训练奖励模型和PPO优化。

现在大多数目前开源的LLM模型都只做了前2步:预训练和指令微调。

而其中原因就是第3步人类反馈强化学习(RLHF)实现起来很困难:

1.需要人类反馈数据(很难收集)
2.奖励模型训练(很难训练)
3. PPO强化学习微调(不仅很耗资源,而且也很难训练)

但是能不能不要最后一步呢,一般来说还是有RLHF比较好,有主要有以下几个原因:

  1. 提高安全性和可控性;
  2. 改进交互性;
  3. 克服数据集偏差;
  4. 提供个性化体验;
  5. 符合道德规范;
  6. 持续优化和改进。

RLHF使得ChatGPT这样的大型对话模型既具备强大能力,又能够接受人类价值观的指导,生成更智能、安全、有益的对话回复。这是未来可信赖和可解释AI的重要发展方向。

所以这一步还是非常重要。那如何解决人类反馈强化学习(RLHF)训练这个难题呢?

DPO (Differentiable Policy Optimization) 算法

Rafailov、Sharma、Mitchell 等人最近发表了一篇论文 Direct Preference Optimization,论文提出将现有方法使用的基于强化学习的目标转换为可以通过简单的二元交叉熵损失直接优化的目标,这一做法大大简化了 LLM 的提纯过程。

DPO 是为实现对 LLM 的精确控制而引入的一种方法。从人类反馈强化学习(RLHF)的基础是训练奖励模型,然后使用近端策略优化(PPO)使语言模型的输出与人类的偏好相一致。这种方法虽然有效,但既复杂又不稳定。DPO 将受限奖励最大化问题视为人类偏好数据的分类问题。这种方法稳定、高效、计算量小。它无需进行奖励模型拟合、大量采样和超参数调整。

DPO(Direct Preference Optimization)是一种直接偏好优化算法,它与PPO(Proximal Policy Optimization)优化的目标相同。主要思路是:

1.定义policy模型(策略模型)和reference模型(参考模型),Policy模型是需要训练的对话生成模型,reference模型是给定的预训练模型或人工构建的模型。

2.对于给定prompt,计算两模型对正样本和负样本的概率,正样本是人类选择的回复,负样本是被拒绝的回复。

3.通过两个模型概率的差值构建DPO损失函数,惩罚policy模型对正样本概率的下降和负样本概率的上升。通过最小化DPO损失进行模型训练。

相比之下DPO就很友好,只需要加载2个模型,其中一个推理,另外一个训练,直接在偏好数据上进行训练即可:

DPO 拒绝有害问题 实战部分

数据集

数据集其实就是标准的RLHF奖励模型的训练集,下载地址在这

Anthropic/hh-rlhf · Datasets at Hugging Face

dikw/hh_rlhf_cn · Datasets at Hugging Face

其样式就是:一个context,一个选择的正样本,一个拒绝的负样本。希望这些样本能够让LLM 尽可能生成用户选择的无害回复,而不要生成有害的回复。

微调代码
下方这段代码实现了基于DPO (Differentiable Policy Optimization) 的对话模型微调。主要步骤包括:

  1. 加载预训练语言模型(这里使用llama-2-7b)并准备量化训练,采用int4量化的+少量lora 参数。
  2. 定义参考模型(int4量化的模型),也使用同样的预训练模型。
  3. 加载Helpful/Harmless数据集,并转换成所需格式。
  4. 定义DPO训练参数,包括batch size,学习率等。
  5. 定义DPO训练器,传入policy模型,参考模型,训练参数等。
  6. 进行DPO微调训练。
  7. 保存微调后的模型,只保存量lora 参数。

关键点:

1. 使用DPO损失函数实现安全性约束的模型训练。不需要额外在训练一个奖励模型。
2. 这也导致整个训练过程只需要策略模型和参考模型 2个LLM模型,不需要额外的显存去加载奖励模型。
3. 整个训练过程策略模型和参考模型可以进行4int的模型量化 + 少量的lora 参数

综上,这段代码对预训练语言模型进行DPO微调,以实现安全可控的对话生成

#!/usr/bin/env python
# coding: utf-8from typing import Dictimport torch
from datasets import Dataset, load_dataset
from trl import DPOTrainer
import bitsandbytes as bnbfrom transformers import TrainingArguments
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import BitsAndBytesConfig
from peft import (LoraConfig,get_peft_model,prepare_model_for_kbit_training
)output_dir1 = "./dpo_output_dir1"
output_dir2 = "./dpo_output_dir2"base_model = "/home/work/llama-2-7b"###准备训练数据
dataset = load_dataset("json", data_files="./dpo_dataset/harmless_base_cn_train.jsonl")
train_val = dataset["train"].train_test_split(test_size=2000, shuffle=True, seed=42
)
train_data = train_val["train"]
val_data = train_val["test"]def extract_anthropic_prompt(prompt_and_response):final = ""for sample in prompt_and_response:final += sample["role"] + "\n" + sample["text"]final += "\n"return finaldef get_hh(dataset, split: str, sanity_check: bool = False, silent: bool = False, cache_dir: str = None) -> Dataset:"""Load the Anthropic Helpful-Harmless dataset from Hugging Face and convert it to the necessary format.The dataset is converted to a dictionary with the following structure:{'prompt': List[str],'chosen': List[str],'rejected': List[str],}Prompts should be structured as follows:\n\nHuman: <prompt>\n\nAssistant:Multiple turns are allowed, but the prompt should always start with \n\nHuman: and end with \n\nAssistant:."""dataset = datasetif sanity_check:dataset = dataset.select(range(min(len(dataset), 1000)))def split_prompt_and_responses(sample) -> Dict[str, str]:prompt = extract_anthropic_prompt(sample["context"])return {"prompt": prompt,"chosen": sample["chosen"]["role"] + "\n" + sample["chosen"]["text"],"rejected": sample["rejected"]["role"] + "\n" + sample["rejected"]["text"],}return dataset.map(split_prompt_and_responses)train_dataset = get_hh(train_data, "train", sanity_check=True)
eval_dataset = get_hh(val_data, "test", sanity_check=True)def find_all_linear_names(model):# cls = bnb.nn.Linear8bitLtcls = bnb.nn.Linear4bitlora_module_names = set()for name, module in model.named_modules():if isinstance(module, cls):names = name.split('.')lora_module_names.add(names[0] if len(names) == 1 else names[-1])if 'lm_head' in lora_module_names:  # needed for 16-bitlora_module_names.remove('lm_head')return list(lora_module_names)def print_trainable_parameters(model):"""Prints the number of trainable parameters in the model."""trainable_params = 0all_param = 0for _, param in model.named_parameters():all_param += param.numel()if param.requires_grad:trainable_params += param.numel()print(f"trainable params: {trainable_params} || all params: {all_param} || trainables%: {100 * trainable_params / all_param}")tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"  # Fix weird overflow issue with fp16 trainingbnb_4bit_compute_dtype = "float16"
compute_dtype = getattr(torch, bnb_4bit_compute_dtype)
bnb_4bit_quant_type = "nf4"
use_nested_quant = Falsebnb_config = BitsAndBytesConfig(load_in_4bit=True,bnb_4bit_quant_type=bnb_4bit_quant_type,bnb_4bit_compute_dtype=compute_dtype,bnb_4bit_use_double_quant=use_nested_quant,
)model = AutoModelForCausalLM.from_pretrained(base_model,trust_remote_code=True,quantization_config=bnb_config,device_map="auto")
model.config.use_cache = False
model = prepare_model_for_kbit_training(model)modules = find_all_linear_names(model)
config = LoraConfig(r=8,lora_alpha=16,lora_dropout=0.05,bias="none",target_modules=modules,task_type="CAUSAL_LM",
)model = get_peft_model(model, config)
print_trainable_parameters(model)###定义参考模型
model_ref = AutoModelForCausalLM.from_pretrained(base_model,trust_remote_code=True,quantization_config=bnb_config,device_map="auto")
###定义dpo训练参数
training_args = TrainingArguments(per_device_train_batch_size=1,max_steps=100,remove_unused_columns=False,gradient_accumulation_steps=2,learning_rate=3e-4,evaluation_strategy="steps",output_dir="./test",
)###定义dpo训练器
dpo_trainer = DPOTrainer(model,model_ref,args=training_args,beta=0.1,train_dataset=train_dataset,eval_dataset=eval_dataset,tokenizer=tokenizer,
)
###训练
dpo_trainer.train()
###模型保存
dpo_trainer.save_model(output_dir1)dpo_trainer.model.save_pretrained(output_dir2)
tokenizer.save_pretrained(output_dir2)

训练过程

其中看出加载了2遍int4量化的模型到显存中,需要训练的策略模型只有一部分lora参数,而参考模型就是原始模型本身.

模型保存

保存下来的参数也就是lora参数,这部分lora 参数就学会了如何拒绝回答有害问题。

至此,我们就学会了如何利用使用DPO +Qlora 实现在完成RLHF的实战。

使用场景

核心原则:偏好数据集中的good/bad response都是和SFT model的训练数据同分布的,也可以说模型是可以生成good/bad response的。

场景1

已有一个SFT model,为了让它更好,对它的output进行偏好标注,然后使用DPO进行训练,这是最正常的使用场景,但是偏好数据集确实避免不了的

场景2

场景1的改进版本,偏好标注不由人来做,而是让gpt4或者一个reward model来标注好坏,至于reward model怎么来,就各凭本事吧

场景3

没有SFT model只有偏好数据集,那就先在偏好数据即中的进行训练,然后在进行DPO的训练。先SFT就是为了符合上文的核心原则

OpenAI独家绝技RLHF也被开源超越啦?!DPO让小白轻松玩转RLHF![已开源] - 知乎 (zhihu.com)

RLHF中的「RL」是必需的吗?有人用二进制交叉熵直接微调LLM,效果更好 - 知乎 (zhihu.com)

直接偏好优化:你的语言模型其实是一个奖励模型 - 知乎 (zhihu.com)

消费级显卡搞定RLHF——DPO算法+QLora微调LLM拒绝有害问题回答实战 - 知乎 (zhihu.com)

使用 DPO 微调 Llama 2 - 知乎 (zhihu.com)

DPO(Direct Preference Optimization):LLM的直接偏好优化 - 知乎 (zhihu.com)

DPO: Direct Preference Optimization 论文解读及代码实践 - 知乎 (zhihu.com)GitHub - mzbac/llama2-fine-tune: Scripts for fine-tuning Llama2 via SFT and DPO.

DPO——RLHF 的替代之《Direct Preference Optimization: Your Language Model is Secretly a Reward Model》论文阅读 - 知乎 (zhihu.com)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/83994.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

优化系统报错提示信息,提高人机交互(一)

1、常规报错及处理 package com.example.demo.controller;import com.example.demo.service.IDemoService; import lombok.AllArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.springframework.web.bind.annotation.GetMapping; import org.springframework.w…

【JAVA】idea初步使用+JDK详细配置

1、官方下载idea 官网&#xff1a;Download IntelliJ IDEA – The Leading Java and Kotlin IDE (1)、下载教程 我下载没截屏&#xff0c;详细教程请看 原文&#xff1a;手把手教你JDKIDEA的安装和环境配置_idea配置jdk_快到锅里来呀的博客-CSDN博客 2、启动项目时候需要配置J…

Spring事件机制之ApplicationEvent

博主介绍&#xff1a;✌全网粉丝4W&#xff0c;全栈开发工程师&#xff0c;从事多年软件开发&#xff0c;在大厂呆过。持有软件中级、六级等证书。可提供微服务项目搭建与毕业项目实战&#xff0c;博主也曾写过优秀论文&#xff0c;查重率极低&#xff0c;在这方面有丰富的经验…

全栈性能测试工具:RunnerGo

随着自动化测试技术的不断进步&#xff0c;自动化测试已成为企业级应用的重要组成部分。然而&#xff0c;传统的性能测试工具往往复杂、繁琐&#xff0c;让企业陷入了两难的境地。软件测试正逐渐从手动测试向自动化测试转变&#xff0c;各种自动化测试工具和框架层出不穷&#…

手撕 LFU 缓存

大家好&#xff0c;我是 方圆。LFU 的缩写是 Least Frequently Used&#xff0c;简单理解则是将使用最少的元素移除&#xff0c;如果存在多个使用次数最小的元素&#xff0c;那么则需要移除最近不被使用的元素。LFU 缓存在 LeetCode 上是一道困难的题目&#xff0c;实现起来并不…

【Godot】解决游戏中的孤立/孤儿节点及分析器性能问题的分析处理

Godot 4.1 因为我在游戏中发现&#xff0c;越运行游戏变得越来越卡&#xff0c;当你使用 Node 节点中的 print_orphan_nodes() 方法打印信息的时候&#xff0c;会出现如下的孤儿节点信息 孤儿节点信息是以 节点实例ID - Stray Node: 节点名称(Type: 节点类型) 作为格式输出&a…

腾讯mini项目-【指标监控服务重构】2023-08-23

今日已办 进度和问题汇总 请求合并 feature/venus tracefeature/venus metricfeature/profile-otel-baserunner-stylebugfix/profile-logger-Syncfeature/profile_otelclient_enable_config 完成otel 开关 trace-采样metrice-reader 已经都在各自服务器运行&#xff0c;并接入…

创造性地解决冲突

1、冲突的根本原因是矛盾双方存在不可调和的目标冲突。 2、要知己知彼&#xff1a; 知己&#xff1a;就是对自己的问题、需求进行客观定义&#xff0c;说明需求和问题的意义或价值、阐述解决方案和期望效果&#xff1b; 知彼&#xff1a;站在对方立场&#xff0c;深挖对方真…

根据3d框的八个顶点坐标,求他的中心点,长宽高和yaw值(Python)

要从一个3D框的八个顶点求出它的中心点、长、宽、高和yaw值&#xff0c;首先需要明确框的几何形状和坐标点的顺序。通常这样的框是一个矩形体&#xff08;长方体&#xff09;&#xff0c;但其方向并不一定与坐标轴平行。 以下是一个步骤来解决这个问题&#xff1a; 求中心点&a…

Unity Bolt UGUI事件注册方式总结

Bolt插件提供了丰富的事件注册方式&#xff0c;开发者几乎不用编写任何代码就可以完成事件的注册&#xff0c;进行交互。下面是我使用UI事件注册的相关总结。 1、通过UI控件自身拖拽实现事件的注册。 Button的事件注册&#xff1a; 新建一个UnityEvent事件&#xff0c; Butt…

Kafka消费者组重平衡(二)

文章目录 概要重平衡通知机制消费组组状态消费端重平衡流程Broker端重平衡流程 概要 上一篇Kafka消费者组重平衡主要介绍了重平衡相关的概念&#xff0c;本篇主要梳理重平衡发生的流程。 为了更好地观察&#xff0c;数据准备如下&#xff1a; kafka版本&#xff1a;kafka_2.1…

nodejs定时任务

项目需求&#xff1a; 每5秒执行一次&#xff0c;多个定时任务错开&#xff0c;即cron表达式中斜杆前带数字&#xff0c;例如 ‘1/5 * * * * *’定时任务准时&#xff0c;延误低 搜索了nodejs的定时任务&#xff0c;其实不多&#xff0c;找到了以下三个常用的&#xff1a; n…

OpenCV中的HoughLines函数和HoughLinesP函数到底有什么区别?

一、简述 基于OpenCV进行直线检测可以使用HoughLines和HoughLinesP函数完成的。这两个函数之间的唯一区别在于,第一个函数使用标准霍夫变换,第二个函数使用概率霍夫变换(因此名称为 P)。概率版本之所以如此,是因为它仅分析点的子集并估计这些点都属于同一条线的概率。此实…

2D游戏开发和3D游戏开发有什么不同?

2D游戏开发和3D游戏开发是两种不同类型的游戏制作方法&#xff0c;它们之间有一些显著的区别&#xff1a; 1. 图形和视觉效果&#xff1a; 2D游戏开发&#xff1a; 2D游戏通常使用二维图形&#xff0c;游戏世界和角色通常在一个平面上显示。这种类型的游戏具有平面的外观&…

数据仓库模型设计V2.0

一、数仓建模的意义 数据模型就是数据组织和存储方法&#xff0c;它强调从业务、数据存取和使用角度合理存储数据。只有将数据有序的组织和存储起来之后&#xff0c;数据才能得到高性能、低成本、高效率、高质量的使用。 高性能&#xff1a;良好的数据模型能够帮助我们快速查询…

shell脚本命令

Shell命令是在类Unix操作系统中使用的命令行解释器&#xff08;shell&#xff09;中执行的命令。Shell命令可以用于执行系统命令、操作文件、进行文本处理、管理进程等。以下是一些常见的Shell命令&#xff1a; 1. ls&#xff1a;列出当前目录下的文件和文件夹。 2. cd&#x…

界面组件DevExpress WinForms v23.1亮点 - 全新升级HTML CSS模板

DevExpress WinForms拥有180组件和UI库&#xff0c;能为Windows Forms平台创建具有影响力的业务解决方案。DevExpress WinForms能完美构建流畅、美观且易于使用的应用程序&#xff0c;无论是Office风格的界面&#xff0c;还是分析处理大批量的业务数据&#xff0c;它都能轻松胜…

2020-2023中国高等级自动驾驶产业发展趋势研究-概念界定

1.1 概念界定 自动驾驶发展过程中&#xff0c;中国出现了诸多专注于研发L3级以上自动驾驶的公司&#xff0c;其在业界地位也越来越重要。本报告围绕“高等级自动驾驶” 展开&#xff0c;并聚焦于该技术2020-2023年在中国市场的变化趋势进行研究。 1.1.1 什么是自动驾驶 自动驾驶…

C#中的方法

引言 在C#编程语言中&#xff0c;方法是一种封装了一系列可执行代码的重要构建块。通过方法&#xff0c;我们可以将代码逻辑进行模块化和复用&#xff0c;提高代码的可读性和可维护性。本文将深入探讨C#中的方法的定义、参数传递、返回值、重载、递归等方面的知识&#xff0c;…

小型水库雨水情测报和大坝安全监测解决方案

一、建设背景 我国小型水库数量众多&#xff0c;大多由农村集体经济组织管理&#xff0c;灌溉、供水、防洪、生 态效益突出&#xff0c;是农业生产、农民生活、农村发展和区域防洪的重要基础设施&#xff0c;实施乡 村振兴战略和生态文明建设的重要支撑保障。由于小型水库工程存…