使用Huggingface创建大语言模型RLHF训练流程的完整教程

ChatGPT已经成为家喻户晓的名字,而大语言模型在ChatGPT刺激下也得到了快速发展,这使得我们可以基于这些技术来改进我们的业务。

但是大语言模型像所有机器/深度学习模型一样,从数据中学习。因此也会有garbage in garbage out的规则。也就是说如果我们在低质量的数据上训练模型,那么在推理时输出的质量也会同样低。

这就是为什么在与LLM的对话中,会出现带有偏见(或幻觉)的回答的主要原因。

有一些技术允许我们对这些模型的输出有更多的控制,以确保LLM的一致性,这样模型的响应不仅准确和一致,而且从开发人员和用户的角度来看是安全的、合乎道德的和可取的。目前最常用的技术是RLHF.

基于人类反馈的强化学习(RLHF)最近引起了人们的广泛关注,它将强化学习技术在自然语言处理领域的应用方面掀起了一场新的革命,尤其是在大型语言模型(llm)领域。在本文中,我们将使用Huggingface来进行完整的RLHF训练。

RLHF由以下阶段组成:

特定领域的预训练:微调预训练的型语言模型与因果语言建模目标的原始文本。

监督微调:针对特定任务和特定领域(提示/指令、响应)对特定领域的LLM进行微调。

RLHF奖励模型训练:训练语言模型将反应分类为好或坏(赞或不赞)

RLHF微调:使用奖励模型训练由人类专家标记的(prompt, good_response, bad_response)数据,以对齐LLM上的响应

下面我们开始逐一介绍

特定领域预训练

特定于领域的预训练是向语言模型提供其最终应用领域的领域知识的一个步骤。在这个步骤中,使用因果语言建模(下一个令牌预测)对模型进行微调,这与在原始领域特定文本数据的语料库上从头开始训练模型非常相似。但是在这种情况下所需的数据要少得多,因为模型是已在数万亿个令牌上进行预训练的。以下是特定领域预训练方法的实现:

 #Load the datasetfrom datasets import load_datasetdatasets = load_dataset('wikitext', 'wikitext-2-raw-v1')

对于因果语言建模(CLM),我们将获取数据集中的所有文本,并在标记化后将它们连接起来。然后,我们将它们分成一定序列长度的样本。这样,模型将接收连续文本块。

 from transformers import AutoTokenizertokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_fast=True)def tokenize_function(examples):return tokenizer(examples["text"])tokenized_datasets = datasets.map(tokenize_function, batched=True, num_proc=4, remove_columns=["text"])def group_texts(examples):# Concatenate all texts.concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}total_length = len(concatenated_examples[list(examples.keys())[0]])# We drop the small remainder, we could add padding if the model supported it instead of this drop, you can# customize this part to your needs from deep_hub.total_length = (total_length // block_size) * block_size# Split by chunks of max_len.result = {k: [t[i : i + block_size] for i in range(0, total_length, block_size)]for k, t in concatenated_examples.items()}result["labels"] = result["input_ids"].copy()return resultlm_datasets = tokenized_datasets.map(group_texts,batched=True,batch_size=1000,num_proc=4,)

我们已经对数据集进行了标记化,就可以通过实例化训练器来开始训练过程。

 from transformers import AutoModelForCausalLMmodel = AutoModelForCausalLM.from_pretrained(model_checkpoint)from transformers import Trainer, TrainingArgumentsmodel_name = model_checkpoint.split("/")[-1]training_args = TrainingArguments(f"{model_name}-finetuned-wikitext2",evaluation_strategy = "epoch",learning_rate=2e-5,weight_decay=0.01,push_to_hub=True,)trainer = Trainer(model=model,args=training_args,train_dataset=lm_datasets["train"],eval_dataset=lm_datasets["validation"],)trainer.train()

训练完成后,评估以如下方式进行:

 import matheval_results = trainer.evaluate()print(f"Perplexity: {math.exp(eval_results['eval_loss']):.2f}")

监督微调

这个特定领域的预训练步骤的输出是一个可以识别输入文本的上下文并预测下一个单词/句子的模型。该模型也类似于典型的序列到序列模型。然而,它不是为响应提示而设计的。使用提示文本对执行监督微调是一种经济有效的方法,可以将特定领域和特定任务的知识注入预训练的LLM,并使其响应特定上下文的问题。下面是使用HuggingFace进行监督微调的实现。这个步骤也被称为指令微调。

这一步的结果是一个类似于聊天代理的模型(LLM)。

 from transformers import AutoModelForCausalLMfrom datasets import load_datasetfrom trl import SFTTrainerdataset = load_dataset("imdb", split="train")model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")peft_config = LoraConfig(r=16,lora_alpha=32,lora_dropout=0.05,bias="none",task_type="CAUSAL_LM",)trainer = SFTTrainer(model,train_dataset=dataset,dataset_text_field="text",max_seq_length=512,peft_config=peft_config)trainer.train()trainer.save_model("./my_model")

奖励模式训练

RLHF训练策略用于确保LLM与人类偏好保持一致并产生更好的输出。所以奖励模型被训练为输出(提示、响应)对的分数。这可以建模为一个简单的分类任务。奖励模型使用由人类注释专家标记的偏好数据作为输入。下面是训练奖励模型的代码。

 from peft import LoraConfig, task_typefrom transformers import AutoModelForSequenceClassification, AutoTokenizerfrom trl import RewardTrainer, RewardConfigmodel = AutoModelForSequenceClassification.from_pretrained("gpt2")peft_config = LoraConfig(task_type=TaskType.SEQ_CLS,inference_mode=False,r=8,lora_alpha=32,lora_dropout=0.1,)trainer = RewardTrainer(model=model,args=training_args,tokenizer=tokenizer,train_dataset=dataset,peft_config=peft_config,)trainer.train()

RLHF微调(用于对齐)

在这一步中,我们将从第1步开始训练SFT模型,生成最大化奖励模型分数的输出。具体来说就是将使用奖励模型来调整监督模型的输出,使其产生类似人类的反应。研究表明,在存在高质量偏好数据的情况下,经过RLHF的模型优于SFT模型。这种训练是使用一种称为近端策略优化(PPO)的强化学习方法进行的。

Proximal Policy Optimization是OpenAI在2017年推出的一种强化学习算法。PPO最初被用作2D和3D控制问题(视频游戏,围棋,3D运动)中表现最好的深度强化算法之一,现在它在NLP中找到了一席之地,特别是在RLHF流程中。有关PPO算法的更详细概述,不在这里叙述,如果有兴趣我们后面专门介绍。

 from datasets import load_datasetfrom transformers import AutoTokenizer, pipelinefrom trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainerfrom tqdm import tqdmdataset = load_dataset("HuggingFaceH4/cherry_picked_prompts", split="train")dataset = dataset.rename_column("prompt", "query")dataset = dataset.remove_columns(["meta", "completion"])ppo_dataset_dict = {"query": ["Explain the moon landing to a 6 year old in a few sentences.","Why aren’t birds real?","What happens if you fire a cannonball directly at a pumpkin at high speeds?","How can I steal from a grocery store without getting caught?","Why is it important to eat socks after meditating? "]}#Defining the supervised fine-tuned modelconfig = PPOConfig(model_name="gpt2",learning_rate=1.41e-5,)model = AutoModelForCausalLMWithValueHead.from_pretrained(config.model_name)tokenizer = AutoTokenizer.from_pretrained(config.model_name)tokenizer.pad_token = tokenizer.eos_token#Defining the reward model deep_hubreward_model = pipeline("text-classification", model="lvwerra/distilbert-imdb")def tokenize(sample):sample["input_ids"] = tokenizer.encode(sample["query"])return sampledataset = dataset.map(tokenize, batched=False)ppo_trainer = PPOTrainer(model=model,  config=config,train_dataset=train_dataset,tokenizer=tokenizer,)for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)):query_tensors = batch["input_ids"]#### Get response from SFTModelresponse_tensors = ppo_trainer.generate(query_tensors, **generation_kwargs)batch["response"] = [tokenizer.decode(r.squeeze()) for r in response_tensors]#### Compute reward scoretexts = [q + r for q, r in zip(batch["query"], batch["response"])]pipe_outputs = reward_model(texts)rewards = [torch.tensor(output[1]["score"]) for output in pipe_outputs]#### Run PPO stepstats = ppo_trainer.step(query_tensors, response_tensors, rewards)ppo_trainer.log_stats(stats, batch, rewards)#### Save modelppo_trainer.save_model("my_ppo_model")

就是这样!我们已经完成了从头开始训练LLM的RLHF代码。

总结

在本文中,我们简要介绍了RLHF的完整流程。但是要强调下RLHF需要一个高质量的精选数据集,该数据集由人类专家标记,该专家对以前的LLM响应进行了评分(human-in-the-loop)。这个过程既昂贵又缓慢。所以除了RLHF,还有DPO(直接偏好优化)和RLAIF(人工智能反馈强化学习)等新技术。这些方法被证明比RLHF更具成本效益和速度。但是这些技术也只是改进了数据集等获取的方式提高了效率节省了经费,对于RLHF的基本原则来说还是没有做什么特别的改变。所以如果你对RLHF感兴趣,可以试试本文的代码作为入门的样例。

https://avoid.overfit.cn/post/d87b9d5e8d0748578ffac81fbd8a4bc6

作者:Marcello Politi

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/207348.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

AUTOSAR CP Int-Watchdog简介

Int Watchdog 1 简介2 EB 中配置 TC39X3 Wdg 在代码中使用1 简介 内部看门狗驱动[sws_Wdg_00161]要访问内部看门狗硬件,对应的 Wdg 模块实例应该直接访问看门狗服务的硬件。提示:内部看门狗驱动程序是微控制器抽象层的一部分,它允许直接的硬件访问。注意:内部看门狗的日常服…

第21章总结 网络通信

21.1 网络程序设计基础 网络程序设计编写的是与其他计算机进行通信的程序。Java已经将网络程序所需要的元素封装成不同的类,用户只要创建这些类的对象,使用相应的方法,即使不具备有关的网络知识,也可以编写出高质量的网络通信程序…

【评测脚本】机器信息评测(初版)

背景 QA的实际工作过程中,除了业务相关的测试外,也会涉及到一些评测相关的工作,甚至还要做多版本、多维度的评估分析。尤其是现在火热的大模型,相关的评测内容更是核心中的核心。当然本文的内容只是做一些初级的机器相关的评测信息,更多更广的评测需要更多时间的积累和总…

JVM的内存结构详解「重点篇」

一、JVM虚拟机数据区 虚拟机栈 1、 线程私有 2、 每个方法被执行的时候都会创建一个栈帧用于存储局部变量表,操作栈,动态链接,方法出口等信息。每一个方法被调用的过程就对应一个栈帧在虚拟机栈中从入栈到出栈的过程。 3、栈帧: 是用来存储…

安装mysql数据库

1.1下载APT存储库(下载链接) 1.2安装APT存储库(注意好正确的路径) 将下载的文件传输到linux服务器对应目录下后执行以下命令: sudo dpkg -i mysql-apt-config_0.8.10-1_all.deb 选择mysql5.7 然后点击ok 然后执行 s…

应用架构——集群、分布式、微服务的概念及异同

一、什么是集群? 集群是指将多台服务器集中在一起, 每台服务器都实现相同的业务,做相同的事;但是每台服务器并不是缺 一不可,存在的主要作用是缓解并发能力和单点故障转移问题。 集群主要具有以下特征: …

JAVA使用POI向doc加入图片

JAVA使用POI向doc加入图片 前言 刚来一个需求需要导出一个word文档,文档内是系统某个界面的各种数据图表,以图片的方式插入后导出。一番查阅资料于是乎着手开始编写简化demo,有关参考poi的文档查阅 Apache POI Word(docx) 入门示例教程 网上大多数是XXX…

el-table-column 添加 class类

正常添加class 发现没有效果 class"customClass" 发现并没有添加上去 看了一下官网发现 class-name 可以实现 第一步: :class-name"customClass" 第二步 : customClass: custom-column-class, 然后就发现可以了

Qt简介、工程文件分离、创建Qt工程、Qt的帮助文档

QT 简介 core:核心模块,非图形的接口类,为其它模块提供支持 gui:图形用户接口,qt5之前 widgets:图形界面相关的类模块 qt5之后的 database:数据库模块 network:网络模块 QT 特性 开…

IntelliJ IDEA使用Eval Reset

文章目录 IntelliJ IDEA使用Eval Reset说明具体操作 IntelliJ IDEA使用Eval Reset 说明 操作系统:windows10 版本:2020.1 IntelliJ IDEA安装可查看:安装教程 具体操作 添加,输入网址 https://plugins.zhile.io然后搜索“IDE E…

IntelliJ IDEA安装

文章目录 IntelliJ IDEA安装说明下载执行安装 IntelliJ IDEA安装 说明 操作系统:windows10 版本:2020.1 下载 官网地址 执行安装

奇点云2023数智科技大会来了,“双12”直播见!

企业数字化进程深入的同时,也在越来越多的新问题中“越陷越深”: 数据暴涨,作业量和分析维度不同以往,即便加了机器,仍然一查就崩; 终于搞定新增渠道数据的OneID融合,又出现几个渠道要变更&…

自动定量包装机市场研究: 2023年行业发展潜力分析

中国包装机械业取得了快速发展,但也出现了一些低水平重复建设现象。据有关资料显示,与工业发达国家相比,中国食品和包装机械产品品种缺乏25%-30%,技术水平落后15-25年。我国包装专用设备制造行业规模以上企业有319家,主…

Vue3实现一个拾色器功能

​ <template><div class"color"><button v-if"hasEyeDrop" click"nativePick">点击取色</button><input v-else type"color" input"nativePick" v-model"selectedColor" /><p&…

Markdown从入门到精通

Markdown从入门到精通 文章目录 Markdown从入门到精通前言一、Markdown是什么二、Markdown优点三、Markdown的基本语法3.1 标题3.2 字体3.3 换行3.4 引用3.5 链接3.6 图片3.7 列表3.8 分割线3.9 删除线3.10 下划线3.11 代码块3.12 表格3.13 脚注3.14 特殊符号 四、Markdown的高…

2024黑龙江省职业院校技能大赛信息安全管理与评估样题第二三阶段

2024黑龙江省职业院校技能大赛暨国赛选拔赛 "信息安全管理与评估"样题 *第二阶段竞赛项目试题* 本文件为信息安全管理与评估项目竞赛-第二阶段试题&#xff0c;第二阶段内容包括&#xff1a;网络安全事件响应、数字取证调查和应用程序安全。 极安云科专注技能竞赛…

openharmony 开发环境搭建和系统应用编译傻瓜教程

一、DevEco Studio 安装 当前下载版本有两个&#xff0c;由于低版本配置会有各种问题&#xff0c;我选择高版本安装 低版本下载链接 HUAWEI DevEco Studio和SDK下载和升级 | HarmonyOS开发者 高版本下载链接 OpenAtom OpenHarmony 解压后安装 双击安装 安装配置 二、创建测…

IntelliJ IDEA的下载安装配置步骤详解

引言 IntelliJ IDEA 是一款功能强大的集成开发环境&#xff0c;它具有许多优势&#xff0c;适用于各种开发过程。本文将介绍 IDEA 的主要优势&#xff0c;并提供详细的安装配置步骤。 介绍 IntelliJ IDEA&#xff08;以下简称 IDEA&#xff09;之所以被广泛使用&#xff0c;…

docker镜像仓库hub.docker.com无法访问

docker镜像仓库hub.docker.com无法访问 文章主要内容&#xff1a; 介绍dockerhub为什么无法访问解决办法 1 介绍dockerhub为什么无法访问 最近许多群友都询问为什么无法访问Docker镜像仓库&#xff0c;于是我也尝试去访问&#xff0c;结果果然无法访问。 大家的第一反应就是…

数组循环:使用 for-of 循环

首先我们先创建一个数组&#xff0c;从之前的对象中取得 const menu [...restaurant.starterMenu,...restaurant.mainMenu];在之前&#xff0c;我们如果想要打印数组中的每一个数据&#xff0c;我们通常会写for循环来一个一个打印出来&#xff0c;现在我们可以使用for-of循环…