强化学习-RLHF-PPO入门

一、定义

  1. 强化学习微调分类
  2. RM模型 数据集格式
  3. 训练流程
  4. Reward 模型训练流程(分类模型,积极为1,消极为0) AutoModelForSequenceClassification
  5. Reward 模型训练案例
  6. PPO模型训练流程
  7. PPO模型训练案例

二、实现

  1. 强化学习微调分类
    RLHF:基于人类反馈对语言模型进行强化学习, 分两步:
    1. RM (Reward Model)奖励模型建模,构造人类偏好排序数据集,训练奖励模型,用来建模人类偏好。
    2 RL(Reinforcement Learning)强化学习,用奖励模型来训练SFT模型,生成模型使用奖励或惩罚来更新其策略,以便生成更高质量、更符合人类偏好的文本.
    DPO(Direct Preference Optimization): 直接偏好优化方法,DPO通过直接优化语言模型来实现对其行为的精确控制,而无需使用复杂的强化学习,也可以有效学习到人类偏好。
    RLHF主要是进行对齐微调, 目标是将大语言模型的行为与人类的价值观或偏好对齐。
    PPO: (Proximal Policy Optimization,近端策略优化)是一种在强化学习领域广泛使用的算法.

  2. RM模型 数据集格式

{conversations:  [0: {from:  
"human",value:  "国会的转发 美国国会由众议院和参议院组成,每两年换届一次(参议员任期为6年,但参议院选举是错位的。是更常见地转发国会议员还是来自国会外部?"}],
chosen:  {from:  "gpt",value:  "计算推文的政党边际概率,我们可以使用以下代码这表明大多数转发不是来自国会议员,而是来自国会之外。"},
rejected:  {from:  "gpt",value:  "回答问题的第(计算转发国会议员或来自国会以外的人的边际概率"}}

其中chosen 代表是好的回答, rejected代表的是不好的回答

  1. 训练流程
    在这里插入图片描述
    训练reward Model---->PPO模型

  2. Reward 模型训练流程(激励模型为深度学习模型)
    数据处理:

def preprocess_function(examples):new_examples = {"input_ids_chosen": [],"attention_mask_chosen": [],"input_ids_rejected": [],"attention_mask_rejected": [],}for chosen, rejected in zip(examples["chosen"], examples["rejected"]):tokenized_chosen = tokenizer(chosen)tokenized_rejected = tokenizer(rejected)new_examples["input_ids_chosen"].append(tokenized_chosen["input_ids"])new_examples["attention_mask_chosen"].append(tokenized_chosen["attention_mask"])new_examples["input_ids_rejected"].append(tokenized_rejected["input_ids"])new_examples["attention_mask_rejected"].append(tokenized_rejected["attention_mask"])return new_examples

训练求损失:AutoModelForSequenceClassification 分类模型

model = AutoModelForSequenceClassification.from_pretrained(model_config.model_name_or_path, num_labels=1, **model_kwargs
)
def compute_loss(self,model: Union[PreTrainedModel, nn.Module],inputs: Dict[str, Union[torch.Tensor, Any]],return_outputs=False,
) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, torch.Tensor]]]:if not self.use_reward_data_collator:warnings.warn("The current compute_loss is implemented for RewardDataCollatorWithPadding,"" if you are using a custom data collator make sure you know what you are doing or"" implement your own compute_loss method.")rewards_chosen = model(input_ids=inputs["input_ids_chosen"],attention_mask=inputs["attention_mask_chosen"],return_dict=True,)["logits"]rewards_rejected = model(input_ids=inputs["input_ids_rejected"],attention_mask=inputs["attention_mask_rejected"],return_dict=True,)["logits"]# calculate loss, optionally modulate with marginif "margin" in inputs:loss = -nn.functional.logsigmoid(rewards_chosen - rewards_rejected - inputs["margin"]).mean()else:loss = -nn.functional.logsigmoid(rewards_chosen - rewards_rejected).mean()if return_outputs:return loss, {"rewards_chosen": rewards_chosen,"rewards_rejected": rewards_rejected,}return loss
  1. Reward 模型训练案例
    https://github.com/huggingface/trl/blob/main/examples/scripts/reward_modeling.py

  2. PPO模型训练流程

在这里插入图片描述

步骤:
1. 语言模型预测
2. 激活模型评估(分类模型),1 代表积极,0 代表消极
3. PPO算法优化。
数据:

def tokenize(sample):sample["input_ids"] = tokenizer.encode(sample["review"])[: input_size()]sample["query"] = tokenizer.decode(sample["input_ids"])return sample
  1. 加载模型, 参考模型(参考模型可以为None)
# We then build the PPOTrainer, passing the model, the reference model, the tokenizer
ppo_trainer = PPOTrainer(ppo_config, model, ref_model, tokenizer, dataset=dataset, data_collator=collator)
# Get response from gpt2    待训练的模型响应,参考模型响应
response_tensors, ref_response_tensors = ppo_trainer.generate(query_tensors, return_prompt=False, generate_ref_response=True, **generation_kwargs
)
batch["response"] = tokenizer.batch_decode(response_tensors)
batch["ref_response"] = tokenizer.batch_decode(ref_response_tensors)
  1. 激活模型评估(分类模型),1 代表积极,0 代表消极
2. 获取激励值
# Compute sentiment score
texts = [q + r for q, r in zip(batch["query"], batch["response"])]
pipe_outputs = sentiment_pipe(texts, **sent_kwargs)                      #激励函数
rewards = [torch.tensor(output[1]["score"]) for output in pipe_outputs]   #激励值
  1. PPO算法优化。
#   问题query  、  模型响应   、激励值
#其中上图优化模块,均在step 方法中实现。
stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
step内部:all_logprobs, logits_or_none, values, masks = self.batched_forward_pass(self.model,queries,responses,model_inputs,response_masks=response_masks,return_logits=full_kl_penalty,)with self.optional_peft_ctx():ref_logprobs, ref_logits_or_none, _, _ = self.batched_forward_pass(self.model if self.is_peft_model else self.ref_model,queries,responses,model_inputs,return_logits=full_kl_penalty,)
  1. PPO模型训练案例
    https://github.com/huggingface/trl/blob/main/examples/scripts/ppo.py

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/33977.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

机器学习之数学基础(七)~过拟合(over-fitting)和欠拟合(under-fitting)

目录 1. 过拟合与欠拟合 1.1 Preliminary concept 1.2 过拟合 over-fitting 1.3 欠拟合 under-fitting 1.4 案例解析:黑天鹅 1. 过拟合与欠拟合 1.1 Preliminary concept 误差 经验误差:模型对训练集数据的误差。泛化误差:模型对测试…

RabbitMQ 消息传递

消息何去何从 mandatory和immediate是channel.basicPublish方法中的两个参数,他们都有当消息传递过程中不可达目的地时将消息返回给生产者的功能。RabbitMQ提供的备份交换器可以将未能被交换器路由的消息(没有绑定队列或者没有匹配的绑定)存…

学习C++第三天——对引用的深入了解

引用 引用不是新定义一个变量,而是给已存在变量取了一个别名,编译器不会为引用变量开辟内存空 间,它和它引用的变量共用同一块内存空间。 一个变量可以有多个引用: int main() {//一个变量可以有多个引用int a 10;int& b …

OpenAI 收购桌面实时协作公司 Multi;iOS 18 开放 iPhone 镜像测试丨RTE 开发者日报 Vol.231

开发者朋友们大家好: 这里是 「RTE 开发者日报」 ,每天和大家一起看新闻、聊八卦。我们的社区编辑团队会整理分享 RTE(Real-Time Engagement) 领域内「有话题的 新闻 」、「有态度的 观点 」、「有意思的 数据 」、「有思考的 文…

c++ 设计模式 的课本范例(上)

( 0 ) 这里补充面向对象设计的几个原则: 开闭原则OCP : 面向增补开放,面向代码修改关闭。其实反映到代码设计上就是类的继承,通过继承与多态,可以不修改原代码,又增加新的类似的功能。 依赖倒置原则 Depend…

如何从iPhone恢复错误删除的照片

嘿,iPhone 用户!作为一名苹果专业人士,我见过相当多的“哎呀,我删除了它!”的时刻。今天,我在这里指导您完成从iPhone中恢复那些珍贵的,错误删除的照片的迷宫。坐下来,拿起你的设备&…

分压电路 ADC计算电压 【老板再也不能开除我了 】

经典分压电路 一个电压过来 adc这里的电压等于: 如是12位adc 那么他最大值就是4095 如参考电压是5v 则:5v/4095 实际电压V*(R2/(R1R2))/adc值 转化:实际电压V 5v*(adc值/4095)/(R2/(R1R2)) :老板再也不能 因为不会…

PointCloudLib-滤波模块(Filtering)-直通滤波

使用直通过滤器过滤点云 在本教程中,我们将学习如何沿着 指定维度 – 即,切断位于 或 在给定用户范围之外。 代码 首先,创建一个文件,比方说,在你最喜欢的 编辑器,并将以下内容放入其中:passthrough.cpp #include <iostream> #include <pcl/point_types.h&g…

推荐系统-FM模型

参考&#xff1a;推荐系统&#xff08;三&#xff09;Factorization Machines&#xff08;FM&#xff09;_svmmf-CSDN博客 一句话概括下FM&#xff1a;隐式向量特征交叉----类似embedding的思想 LR 如果利用LR做特征的二阶交叉&#xff0c;有&#xff1a; 但这个公式存在显著…

Open3D 点云的ISS关键点提取

目录 一、概述 1.1原理 1.2应用场景 1.3算法实现步骤 二、代码实现 2.1 完整代码 2.2关键函数 2.3关键点可视化 三、实现效果 3.1原始点云 3.2提取后点云 一、概述 1.1原理 ISS&#xff08;Intrinsic Shape Signatures&#xff09;关键点提取是一种常用于三维点云的…

【LLM-多模态】高效多模态大型语言模型综述

一、结论写在前面 模型规模的庞大及训练和推理成本的高昂&#xff0c;限制了MLLMs在学术界和工业界的广泛应用。因此&#xff0c;研究高效轻量级的MLLMs具有巨大潜力&#xff0c;特别是在边缘计算场景中。 论文深入探讨了高效MLLM文献的领域&#xff0c;提供了一个全面的视角…

Win10可用的VC6.0绿色版及辅助插件assist_X

VC6.0&#xff0c;作为微软的经典开发工具&#xff0c;承载着无数开发者的青春与回忆。它曾是Windows平台上软件开发的重要基石&#xff0c;为开发者们提供了稳定且强大的编程环境&#xff0c;尤其是其MFC&#xff08;Microsoft Foundation Classes&#xff09;库&#xff0c;为…

SSM宠物领养系统-计算机毕业设计源码08465

目 录 摘要 1 绪论 1.1课题背景及意义 1.2研究现状 1.3ssm框架介绍 1.3论文结构与章节安排 2 宠物领养系统系统分析 2.1 可行性分析 2.2 系统流程分析 2.2.1 数据流程 3.3.2 业务流程 2.3 系统功能分析 2.3.1 功能性分析 2.3.2 非功能性分析 2.4 系统用例分析 …

uni-push(2.0)常见问题,Android平台

将常用的网址一定要收藏在标签栏中&#xff0c;方便后期找&#xff0c;不然后期会很生气。 草料二维码&#xff0c;这个在线工具可以将打包生成的apk文件生成二维码&#xff0c;供测试人员测试。生成的apk只有五次下载机会&#xff0c;可点击链接后的一键上传&#xff0c;这样…

数据资产管理的艺术之道:深入探索如何在数据价值的最大化、个人隐私的严密保护以及企业持续发展的战略需求之间找到微妙的平衡

在数字化浪潮席卷全球的今天&#xff0c;数据已成为企业最宝贵的资产之一。从市场营销到产品研发&#xff0c;从客户服务到运营管理&#xff0c;数据无处不在&#xff0c;为企业提供了前所未有的洞察力和竞争力。然而&#xff0c;随着数据量的激增和数据类型的多样化&#xff0…

【Linux网络(一)初识计算机网络】

一、网络发展 1.发展背景 2.发展类型 二、网络协议 1.认识协议 2.协议分层 3.OSI七层模型 4.TCP/IP协议 三、网络传输 1.协议报头 2.局域网内的两台主机通信 3.跨网络的两台主机通信 四、网络地址 1.IP地址 2.MAC地址 一、网络发展 1.发展背景 计算机网络的发展…

七天速通javaSE:第三天 程序控制结构:顺序、选择、循环

文章目录 前言一、Scanner类1. hasNext()和hasNextLine()2.next()和nextLine()3. Scanner的其他用法 二、顺序结构三、选择结构1. if单选择结构2. if-else双选择结构3. if-else if多选择结构4. switch选择结构 四、循环结构1. while循环2.do while循环3. for循环&#xff08;常…

表单(forms)

自学python如何成为大佬(目录):https://blog.csdn.net/weixin_67859959/article/details/139049996?spm1001.2014.3001.5501 在app1文件夹下创建一个forms.py文件&#xff0c;添加如下类代码&#xff1a; from django import forms class PersonForm(forms.Form): first_na…

C语言·动态内存管理

1. 为什么要有动态内存管理&#xff1f; 例1&#xff1a; //固定的向内存申请4个字节 int a 10;//申请连续的一块空间 int arr[10]; 这些数据一旦声明定义之后就会在内存中有一块空间&#xff0c;这些空间都是固定的&#xff0c;为了让内存使用更加灵活&#xff0c;这时我们…

【复旦邱锡鹏教授《神经网络与深度学习公开课》笔记】卷积

卷积经常用在信号处理中&#xff0c;用于计算信号的延迟累积。假设一个信号发射器每个时刻 t t t产生一个信号 x t x_t xt​&#xff0c;其信息的衰减率为 w k w_k wk​&#xff0c;即在 k − 1 k-1 k−1个时间步长后&#xff0c;信息为原来的 w k w_k wk​倍&#xff0c;时刻 …