LLM微调(三)| 大模型中RLHF + Reward Model + PPO技术解析

        本文将深入探讨RLHF(Reinforcement Learning with Human Feedback)、RM(reward model)和PPO(Proximal Policy Optimizer)算法的概念。然后,通过代码演示使用RLHF训练自己的大模型和奖励模型RM。最后,简要深入研究模型毒性和幻觉,以及如何创建一个更面向模型的产品或更有益、诚实、无害、可靠,并与人类反馈对齐的生成人工智能的生命周期。

一、RLHF(Reinforcement Learning with Human Feedback)

图片

       先来举一个简单的例子——想象一下,我们正在创建一个LLM会话式人工智能产品模型,它可以为经历艰难时期的人类提供治疗,如果我们训练了一个大模型,但没有使其与人类保持一致,它通过药物滥用等方式为这些人提供让他们感觉更好和最佳的非法方式,这将导致伤害、缺乏有效的可靠性和帮助。正如OpenAI CTO所说,大模型领域正在蓬勃发展,大模型更可靠、更一致、产生更少的幻觉,唯一可能的方法是使用来自不同人群的人类反馈,以及其他方式,如RAG、Langchain,来提供基于上下文的响应。生成人工智能生命周期可以最大限度地提高了帮助性,最大限度地减少了困难,避免了与危险话题的讨论和参与。

       在深入了解RLHF之前,我们先介绍一下强化学习的基本原理,如下图所示:

图片

     RL是Agent与环境Environment不断交互的过程,首先Agent处于Environment的某个state状态下,然后执行一个action,就会对环境产生影响,从而进入另一个state下,如果对Environment是好的或者是期待的,那么会得到正向的reward,否则是负向的,最终一般是让整个迭代过程中累积reward最大。

二、在大模型的什么环节使用RL呢?

图片

       这里有Agent、Environment和大模型的Current Context,在这种情况下,策略就是知道我们预训练或者微调过的LLM模型。现在我们希望能够在给定的域中生成文本,对吗?因此,我们采取行动,LLM获取当前上下文窗口和环境上下文,并基于该动作,获得奖励。带着奖励的策略就是人类反馈的地方。

三、奖励模型Reward Model介绍

       基于人类的反馈数据来训练一个奖励模型,该模型会在RLHF中被调用,并且不需要人类的参与,就可以根据用户不同的Prompt来分配不同的奖励reward,这个过程被称为”Rollout“。

那么如何构建人类反馈的数据集呢?

图片

数据集格式,如下图所示:

图片

四、奖励模型Reward Model训练

有了人类反馈的数据集,我们就可以基于如下流程来训练RM模型:

图片

五、使用RLHF (PPO & KL Divergence)进行微调

  1. 把一个Prompt数据集输入给初始LLM中;

  2. 给instruct LLM输入大量的Prompts,并得到一些回复;

  3. 把Prompt补全输入给已经训练好的RM模型,RM会生成对应的score,然后把这些score输入给RL算法;

  4. 我们在这里使用的RL算法是PPO,会根据Prompt生成一些回复,对平均值进行排序,使用反向传播来评估响应,最后将最优的回复输入给instruct LLM;

  5. 进行几次迭代后,会得到一个奖励模型,但这有一个不利的方面。

PS:如果我们的模型不断接受积极价值观的训练,然后开始提供奇怪、模糊和不符合人类的输出,会怎么样?

图片

        为了解决上述问题,我们采用如下流程:

图片

       首先使用参考模型,冻结其中的所有权重,作为我们人类对齐模型的参考点,然后基于这种迁移,我们使用KL散度惩罚添加到奖励中,这样当模型产生幻觉时,它会使模型回到参考模型附近,以提供积极但不奇怪的积极反应。我们可以使用PEFT适配器来训练我们的PPO模型,并使模型在推出时越来越一致。

六、使用RLHF (PEFT + LORA + PPO)微调实践

6.1 安装相关的包

!pip install --upgrade pip!pip install --disable-pip-version-check \    torch==1.13.1 \    torchdata==0.5.1 --quiet​​​​​
!pip install \    transformers==4.27.2 \    datasets==2.11.0 \    evaluate==0.4.0 \    rouge_score==0.1.2 \    peft==0.3.0 --quiet# Installing the Reinforcement Learning library directly from github.!pip install git+https://github.com/lvwerra/trl.git@25fa1bd

6.2 导入相关的包​​​​​​​

from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification, AutoModelForSeq2SeqLM, GenerationConfigfrom datasets import load_datasetfrom peft import PeftModel, PeftConfig, LoraConfig, TaskType
# trl: Transformer Reinforcement Learning libraryfrom trl import PPOTrainer, PPOConfig, AutoModelForSeq2SeqLMWithValueHeadfrom trl import create_reference_modelfrom trl.core import LengthSampler
import torchimport evaluate
import numpy as npimport pandas as pd
# tqdm library makes the loops show a smart progress meter.from tqdm import tqdmtqdm.pandas()

6.3 加载LLaMA 2模型​​​​​​​

from transformers import AutoTokenizer, AutoModelForCausalLM
tokenizer = AutoTokenizer.from_pretrained("codellama/CodeLlama-34b-Instruct-hf")model = AutoModelForCausalLM.from_pretrained("codellama/CodeLlama-34b-Instruct-hf")huggingface_dataset_name = "knkarthick/dialogsum"dataset_original = load_dataset(huggingface_dataset_name)dataset_original

6.4 预处理数据集​​​​​​​

def build_dataset(model_name,    dataset_name,    input_min_text_length,    input_max_text_length):“””Preprocess the dataset and split it into train and test parts.Parameters:- model_name (str): Tokenizer model name.- dataset_name (str): Name of the dataset to load.- input_min_text_length (int): Minimum length of the dialogues.- input_max_text_length (int): Maximum length of the dialogues.Returns:- dataset_splits (datasets.dataset_dict.DatasetDict): Preprocessed dataset containing train and test parts.“””    # load dataset (only “train” part will be enough for this lab).    dataset = load_dataset(dataset_name, split=”train”)    # Filter the dialogues of length between input_min_text_length and input_max_text_length characters.    dataset = dataset.filter(lambda x: len(x[“dialogue”]) > input_min_text_length and len(x[“dialogue”]) <= input_max_text_length, batched=False)    # Prepare tokenizer. Setting device_map=”auto” allows to switch between GPU and CPU automatically.    tokenizer = AutoTokenizer.from_pretrained(model_name, device_map=”auto”)    def tokenize(sample):        # Wrap each dialogue with the instruction.        prompt = f”””        Summarize the following conversation.        {sample[“dialogue”]}        Summary:        “””        sample[“input_ids”] = tokenizer.encode(prompt)        # This must be called “query”, which is a requirement of our PPO library.        sample[“query”] = tokenizer.decode(sample[“input_ids”])        return sample    # Tokenize each dialogue.    dataset = dataset.map(tokenize, batched=False)    dataset.set_format(type=”torch”)# Split the dataset into train and test parts.    dataset_splits = dataset.train_test_split(test_size=0.2, shuffle=False, seed=42)    return dataset_splitsdataset = build_dataset(model_name=model_name,    dataset_name=huggingface_dataset_name,    input_min_text_length=200,    input_max_text_length=1000)print(dataset)

6.5 抽取模型参数​​​​​​​

def print_number_of_trainable_model_parameters(model):    trainable_model_params = 0    all_model_params = 0    for _, param in model.named_parameters():        all_model_params += param.numel()        if param.requires_grad:            trainable_model_params += param.numel()    return f"\ntrainable model parameters: {trainable_model_params}\nall model parameters: {all_model_params}\npercentage of trainable model parameters: {100 * trainable_model_params / all_model_params:.2f}%"

6.6 将适配器添加到原始salesforce代码生成模型中。现在,我们需要将它们传递到构建的PEFT模型,也将is_trainable=True。​​​​​​​

lora_config = LoraConfig(    r=32, # Rank    lora_alpha=32,    target_modules=["q", "v"],    lora_dropout=0.05,    bias="none",    task_type=TaskType.SEQ_2_SEQ_LM # FLAN-T5)​​​​​​
model = AutoModelForSeq2SeqLM.from_pretrained(model_name,                                               torch_dtype=torch.bfloat16)peft_model = PeftModel.from_pretrained(model,                                        '/kaggle/input/generative-ai-with-llms-lab-3/lab_3/peft-dialogue-summary-checkpoint-from-s3/',                                        lora_config=lora_config,                                       torch_dtype=torch.bfloat16,                                        device_map="auto",                                                                              is_trainable=True)print(f'PEFT model parameters to be updated:\n{print_number_of_trainable_model_parameters(peft_model)}\n')​​​​​​​
ppo_model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(peft_model,torch_dtype=torch.bfloat16,is_trainable=True)print(f'PPO model parameters to be updated (ValueHead + 769 params):\n{print_number_of_trainable_model_parameters(ppo_model)}\n')print(ppo_model.v_head)​​​​​​​
ref_model = create_reference_model(ppo_model)print(f'Reference model parameters to be updated:\n{print_number_of_trainable_model_parameters(ref_model)}\n')

  使用Meta AI基于RoBERTa的仇恨言论模型(https://huggingface.co/facebook/roberta-hate-speech-dynabench-r4-target)作为奖励模型。这个模型将输出logits,然后预测两类的概率:notate和hate。输出另一个状态的logits将被视为正奖励。然后,模型将使用这些奖励值通过PPO进行微调。​​​​​​​

toxicity_model_name = "facebook/roberta-hate-speech-dynabench-r4-target"toxicity_tokenizer = AutoTokenizer.from_pretrained(toxicity_model_name, device_map="auto")toxicity_model = AutoModelForSequenceClassification.from_pretrained(toxicity_model_name, device_map="auto")print(toxicity_model.config.id2label)​​​​​​
non_toxic_text = "#Person 1# tells Tommy that he didn't like the movie."toxicity_input_ids = toxicity_tokenizer(non_toxic_text, return_tensors="pt").input_idslogits = toxicity_model(input_ids=toxicity_input_ids).logitsprint(f'logits [not hate, hate]: {logits.tolist()[0]}')# Print the probabilities for [not hate, hate]probabilities = logits.softmax(dim=-1).tolist()[0]print(f'probabilities [not hate, hate]: {probabilities}')# get the logits for "not hate" - this is the reward!not_hate_index = 0nothate_reward = (logits[:, not_hate_index]).tolist()print(f'reward (high): {nothate_reward}')

6.7 评估模型的毒性​​​​​​​

toxicity_evaluator = evaluate.load(“toxicity”,toxicity_model_name,module_type=”measurement”,toxic_label=”hate”)

参考文献

[1] https://medium.com/@madhur.prashant7/rlhf-reward-model-ppo-on-llms-dfc92ec3885f

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/206597.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

鸿蒙OS应用开发之最简单的程序

鸿蒙OS应用开发之最简单的程序 前面介绍怎么样安装鸿蒙应用开发的环境&#xff0c;然后试着运行起来&#xff0c;并安装运行的虚拟机&#xff0c;以及对应9.0版本的API和SDK等软件。这样就具备了基本的开发基础&#xff0c;就可以进入创建应用程序开发了。 在我们起飞之前&…

idea使用maven的package打包时提示“找不到符号”或“找不到包”

介绍&#xff1a;由于我们的项目是多模块开发项目&#xff0c;在打包时有些模块内容更新导致其他模块在引用该模块时不能正确引入。 情况一&#xff1a;找不到符号 情况一&#xff1a;找不到包 错误代码部分展示&#xff1a; Failure to find com.xxx.xxxx:xxx:pom:0.5 in …

报错处理集

这个报错处理集的错误来源于编译arm平台的so文件产生的。但是后续可以补充成linux一个大的错误处理集。 文章目录 前言一、pandas是什么&#xff1f;二、使用步骤 1.引入库2.读入数据总结 前言 第一次整理的时间是2023年12月8日10:05:59&#xff0c;以下错误来源于欧拉系统编译…

数字人知识库:Awesome-Talking-Head-Synthesis

数字人知识库&#xff1a;Awesome-Talking-Head-Synthesis 文章目录 数字人知识库&#xff1a;Awesome-Talking-Head-SynthesisDatasetsSurveyAudio-drivenText-drivenNeRF & 3DMetricsTools & SoftwareSlides & Presentations Gihub&#xff1a;https://github.co…

查看电脑cuda版本

1.找到NVODIA控制面板 输入NVIDIA搜索即可 出现NVIDIA控制面板 点击系统信息 2.WINR 输入nvidia-smi 检查了一下&#xff0c;电脑没用过GPU&#xff0c;连驱动都没有 所以&#xff0c;装驱动…… 选版本&#xff0c;下载 下载后双击打开安装 重新输入nvidia-smi 显示如下…

Python函数默认参数设置

在某些情况下&#xff0c;程序需要在定义函数时为一个或多个形参指定默认值&#xff0c;这样在调用函数时就可以省略为该形参传入参数值&#xff0c;而是直接使用该形参的默认值。 为形参指定默认值的语法格式如下&#xff1a; 形参名 默认值 从上面的语法格式可以看出&…

Java实现布隆过滤器

一、概述 布隆过滤器本质上是一个很长的二进制数组&#xff0c;主要用来判断一个数据存不存在数组里&#xff0c;如果存在就用1表示&#xff0c;不存在用0表示&#xff0c;它的优点是空间效率和查询时间都比一般的算法要好的多&#xff0c;缺点是有一定的误识别率和删除困难。 …

【Python系列】Python函数

&#x1f49d;&#x1f49d;&#x1f49d;欢迎来到我的博客&#xff0c;很高兴能够在这里和您见面&#xff01;希望您在这里可以感受到一份轻松愉快的氛围&#xff0c;不仅可以获得有趣的内容和知识&#xff0c;也可以畅所欲言、分享您的想法和见解。 推荐:kwan 的首页,持续学…

uni-app 微信小程序之好看的ui登录页面(一)

文章目录 1. 页面效果2. 页面样式代码 更多登录ui页面 uni-app 微信小程序之好看的ui登录页面&#xff08;一&#xff09; uni-app 微信小程序之好看的ui登录页面&#xff08;二&#xff09; uni-app 微信小程序之好看的ui登录页面&#xff08;三&#xff09; uni-app 微信小程…

经纬恒润以太网网关,智能时代网络通关

汽车产业新四化步伐持续加速&#xff0c;智能网联逐渐成为整车标配&#xff0c;随着近年来相关政策频出以及对网联需求和功能的深度挖掘与发展&#xff0c;中国本土市场及本土供应商在这场新浪潮中逐渐走向C位。经纬恒润深耕智能网联领域多年&#xff0c;先后推出四代网关产品&…

JavaSE基础50题:18. 写一个递归方法,输入一个非负整数,返回组成它的数字之和。例如:输入1729,则应该返回1+7+2+9,它的和是19

概述 写一个递归方法&#xff0c;输入一个非负整数&#xff0c;返回组成它的数字之和。例如&#xff1a;输入1729&#xff0c;则应该返回1729&#xff0c;它的和是19。 代码 public class P18 {public static int func(int n) {if (n < 10) {return n;}return n%10 func…

文章解读与仿真程序复现思路——中国电机工程学报EI\CSCD\北大核心《考虑气电联合需求响应的气电综合能源配网系统协调优化运行》

这个标题涉及到一个涉及气体&#xff08;天然气&#xff09;和电力的综合能源配网系统&#xff0c;并且强调了考虑气电联合需求响应的协调优化运行。让我们逐步解读&#xff1a; 气电综合能源配网系统&#xff1a; 这指的是一个结合了气体&#xff08;通常是天然气&#xff09;…

【Java数据结构 -- List和ArrayList与顺序表】

List和ArrayList与顺序表 一. List1.1 List介绍2.1 常见接口介绍3.1 List的使用 二. ArrayList与顺序表1.线性表2.顺序表2.1 接口的实现 3.ArrayList简介4. ArrayList使用4.1 ArrayList的构造 4.2 ArrayList常见操作4.3 ArrayList的遍历4.4 ArrayList的扩容机制5. ArrayList的具…

芯科科技领先无线技术助力控客赋能智慧亚运村

采用芯科科技MG21无线SoC的控客智能家居解决方案为杭州亚运会媒体村打造智能、舒适、便捷、安全的生活体验 中国&#xff0c;北京 - 2023年12月4日 – 致力于以安全、智能无线连接技术&#xff0c;建立更互联世界的全球领导厂商Silicon Labs&#xff08;亦称“芯科科技”&…

名创优品出海更难了,Q3净利增速放缓

近日&#xff0c;有媒体报道称&#xff0c;名创优品受市场流传针对公司的做空报告&#xff0c;叠加高管团队两次减持&#xff0c;共计180万股&#xff0c;若对应25美元的价格&#xff0c;则达4500万美元(约3.2亿元)。 12月5日&#xff0c;其港股收盘股价下跌14.04%&#xff1b…

Javascript 函数介绍

Javascript 函数介绍 很多教程书一上来就讲解一堆语法&#xff0c;例如函数定义、函数调用什么。等读者看完了函数这一章都没搞懂什么是函数。 在讲解什么叫函数之前&#xff0c;我们先看下面一段代码&#xff1a; <!DOCTYPE html> <html xmlns"http://www.w3.…

文心一言 VS 讯飞星火 VS chatgpt (152)-- 算法导论12.2 8题

八、用go语言&#xff0c;证明&#xff1a;在一棵高度为 h 的二叉搜索树中&#xff0c;不论从哪个结点开始&#xff0c;k 次连续的 TREE-SUCCESSOR 调用所需时间为 O(kh)。 文心一言&#xff1a; 要证明在一棵高度为 h 的二叉搜索树中&#xff0c;不论从哪个结点开始&#xf…

VSCODE 运行C程序缓慢解决方法之一

最近更换了mingw的版本&#xff0c;安装路径与之前的mingw路径不大一样。结果发现代码运行的时候很慢&#xff0c;弹出窗口后&#xff0c;迟迟没有打印任何东西&#xff0c;就像卡死了一样。试过网上说的一堆方法&#xff0c;没有什么用。 我按照以下流程进行检查: 1.检查min…

地震反演基础知识3

文章目录 地震勘探原理1 地震波1. 1 地震波概念1. 2 波的传播1. 2. 1 波传播的基本原理1. 2. 2 地震波的反射,折射,透射的1. 2. 3 地震子波&#xff08;seismic wavelet&#xff09;1. 2. 4 地震合成记录 2 地震时距曲线2.1 地震时距曲线作用2.2 不同波的时距曲线2.2.1 直达波时…

【Jeecg Boot 3 - 保姆级】第1节 docker + redis + nginx + redis一键安装启动

一、前言 ▶ JEECG-BOOT 开源版难以吃透的原因 ▶ 为了针对上面痛点&#xff0c;笔者做了如下安排 ▶ 你能收获什么 二、效果(第一节效果) ▶ 启动后端 &#xff1e; 日志 &#xff1e; 接口文档 ▶ 启动前端 三、准备工作 四、实战 ▶ 1、服务器安装 Stag…