强化学习-RLHF-PPO入门

一、定义

  1. 强化学习微调分类
  2. RM模型 数据集格式
  3. 训练流程
  4. Reward 模型训练流程(分类模型,积极为1,消极为0) AutoModelForSequenceClassification
  5. Reward 模型训练案例
  6. PPO模型训练流程
  7. PPO模型训练案例

二、实现

  1. 强化学习微调分类
    RLHF:基于人类反馈对语言模型进行强化学习, 分两步:
    1. RM (Reward Model)奖励模型建模,构造人类偏好排序数据集,训练奖励模型,用来建模人类偏好。
    2 RL(Reinforcement Learning)强化学习,用奖励模型来训练SFT模型,生成模型使用奖励或惩罚来更新其策略,以便生成更高质量、更符合人类偏好的文本.
    DPO(Direct Preference Optimization): 直接偏好优化方法,DPO通过直接优化语言模型来实现对其行为的精确控制,而无需使用复杂的强化学习,也可以有效学习到人类偏好。
    RLHF主要是进行对齐微调, 目标是将大语言模型的行为与人类的价值观或偏好对齐。
    PPO: (Proximal Policy Optimization,近端策略优化)是一种在强化学习领域广泛使用的算法.

  2. RM模型 数据集格式

{conversations:  [0: {from:  
"human",value:  "国会的转发 美国国会由众议院和参议院组成,每两年换届一次(参议员任期为6年,但参议院选举是错位的。是更常见地转发国会议员还是来自国会外部?"}],
chosen:  {from:  "gpt",value:  "计算推文的政党边际概率,我们可以使用以下代码这表明大多数转发不是来自国会议员,而是来自国会之外。"},
rejected:  {from:  "gpt",value:  "回答问题的第(计算转发国会议员或来自国会以外的人的边际概率"}}

其中chosen 代表是好的回答, rejected代表的是不好的回答

  1. 训练流程
    在这里插入图片描述
    训练reward Model---->PPO模型

  2. Reward 模型训练流程(激励模型为深度学习模型)
    数据处理:

def preprocess_function(examples):new_examples = {"input_ids_chosen": [],"attention_mask_chosen": [],"input_ids_rejected": [],"attention_mask_rejected": [],}for chosen, rejected in zip(examples["chosen"], examples["rejected"]):tokenized_chosen = tokenizer(chosen)tokenized_rejected = tokenizer(rejected)new_examples["input_ids_chosen"].append(tokenized_chosen["input_ids"])new_examples["attention_mask_chosen"].append(tokenized_chosen["attention_mask"])new_examples["input_ids_rejected"].append(tokenized_rejected["input_ids"])new_examples["attention_mask_rejected"].append(tokenized_rejected["attention_mask"])return new_examples

训练求损失:AutoModelForSequenceClassification 分类模型

model = AutoModelForSequenceClassification.from_pretrained(model_config.model_name_or_path, num_labels=1, **model_kwargs
)
def compute_loss(self,model: Union[PreTrainedModel, nn.Module],inputs: Dict[str, Union[torch.Tensor, Any]],return_outputs=False,
) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, torch.Tensor]]]:if not self.use_reward_data_collator:warnings.warn("The current compute_loss is implemented for RewardDataCollatorWithPadding,"" if you are using a custom data collator make sure you know what you are doing or"" implement your own compute_loss method.")rewards_chosen = model(input_ids=inputs["input_ids_chosen"],attention_mask=inputs["attention_mask_chosen"],return_dict=True,)["logits"]rewards_rejected = model(input_ids=inputs["input_ids_rejected"],attention_mask=inputs["attention_mask_rejected"],return_dict=True,)["logits"]# calculate loss, optionally modulate with marginif "margin" in inputs:loss = -nn.functional.logsigmoid(rewards_chosen - rewards_rejected - inputs["margin"]).mean()else:loss = -nn.functional.logsigmoid(rewards_chosen - rewards_rejected).mean()if return_outputs:return loss, {"rewards_chosen": rewards_chosen,"rewards_rejected": rewards_rejected,}return loss
  1. Reward 模型训练案例
    https://github.com/huggingface/trl/blob/main/examples/scripts/reward_modeling.py

  2. PPO模型训练流程

在这里插入图片描述

步骤:
1. 语言模型预测
2. 激活模型评估(分类模型),1 代表积极,0 代表消极
3. PPO算法优化。
数据:

def tokenize(sample):sample["input_ids"] = tokenizer.encode(sample["review"])[: input_size()]sample["query"] = tokenizer.decode(sample["input_ids"])return sample
  1. 加载模型, 参考模型(参考模型可以为None)
# We then build the PPOTrainer, passing the model, the reference model, the tokenizer
ppo_trainer = PPOTrainer(ppo_config, model, ref_model, tokenizer, dataset=dataset, data_collator=collator)
# Get response from gpt2    待训练的模型响应,参考模型响应
response_tensors, ref_response_tensors = ppo_trainer.generate(query_tensors, return_prompt=False, generate_ref_response=True, **generation_kwargs
)
batch["response"] = tokenizer.batch_decode(response_tensors)
batch["ref_response"] = tokenizer.batch_decode(ref_response_tensors)
  1. 激活模型评估(分类模型),1 代表积极,0 代表消极
2. 获取激励值
# Compute sentiment score
texts = [q + r for q, r in zip(batch["query"], batch["response"])]
pipe_outputs = sentiment_pipe(texts, **sent_kwargs)                      #激励函数
rewards = [torch.tensor(output[1]["score"]) for output in pipe_outputs]   #激励值
  1. PPO算法优化。
#   问题query  、  模型响应   、激励值
#其中上图优化模块,均在step 方法中实现。
stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
step内部:all_logprobs, logits_or_none, values, masks = self.batched_forward_pass(self.model,queries,responses,model_inputs,response_masks=response_masks,return_logits=full_kl_penalty,)with self.optional_peft_ctx():ref_logprobs, ref_logits_or_none, _, _ = self.batched_forward_pass(self.model if self.is_peft_model else self.ref_model,queries,responses,model_inputs,return_logits=full_kl_penalty,)
  1. PPO模型训练案例
    https://github.com/huggingface/trl/blob/main/examples/scripts/ppo.py

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/33977.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

python,ipython 和 jupyter notebook 之间的关系

python,ipython 和 jupyter notebook 之间的关系 文章目录 python,ipython 和 jupyter notebook 之间的关系1. Python2. IPython3. Jupyter Notebook启动 Jupyter Notebook 关系总结 Python、IPython 和 Jupyter Notebook 是相互关联但具有不同功能的工具…

机器学习之数学基础(七)~过拟合(over-fitting)和欠拟合(under-fitting)

目录 1. 过拟合与欠拟合 1.1 Preliminary concept 1.2 过拟合 over-fitting 1.3 欠拟合 under-fitting 1.4 案例解析:黑天鹅 1. 过拟合与欠拟合 1.1 Preliminary concept 误差 经验误差:模型对训练集数据的误差。泛化误差:模型对测试…

【可控图像生成系列论文(四)】IP-Adapter 具体是如何训练的?1公式篇

系列文章目录 【可控图像生成系列论文(一)】 简要介绍了 MimicBrush 的整体流程和方法;【可控图像生成系列论文(二)】 就MimicBrush 的具体模型结构、训练数据和纹理迁移进行了更详细的介绍。【可控图像生成系列论文&…

RabbitMQ 消息传递

消息何去何从 mandatory和immediate是channel.basicPublish方法中的两个参数,他们都有当消息传递过程中不可达目的地时将消息返回给生产者的功能。RabbitMQ提供的备份交换器可以将未能被交换器路由的消息(没有绑定队列或者没有匹配的绑定)存…

转行供应链—安全库存

安全库存(又称保险库存) 安全库存的定义: 安全库存是一种缓冲库存,用于应对需求波动和供应链不确定性。这些不确定因素可能包括订货期间需求的增长、到货延期等情况。其目的是在供应链出现意外问题时,确保企业能够持续…

学习C++第三天——对引用的深入了解

引用 引用不是新定义一个变量,而是给已存在变量取了一个别名,编译器不会为引用变量开辟内存空 间,它和它引用的变量共用同一块内存空间。 一个变量可以有多个引用: int main() {//一个变量可以有多个引用int a 10;int& b …

OpenAI 收购桌面实时协作公司 Multi;iOS 18 开放 iPhone 镜像测试丨RTE 开发者日报 Vol.231

开发者朋友们大家好: 这里是 「RTE 开发者日报」 ,每天和大家一起看新闻、聊八卦。我们的社区编辑团队会整理分享 RTE(Real-Time Engagement) 领域内「有话题的 新闻 」、「有态度的 观点 」、「有意思的 数据 」、「有思考的 文…

c++ 设计模式 的课本范例(上)

( 0 ) 这里补充面向对象设计的几个原则: 开闭原则OCP : 面向增补开放,面向代码修改关闭。其实反映到代码设计上就是类的继承,通过继承与多态,可以不修改原代码,又增加新的类似的功能。 依赖倒置原则 Depend…

如何从iPhone恢复错误删除的照片

嘿,iPhone 用户!作为一名苹果专业人士,我见过相当多的“哎呀,我删除了它!”的时刻。今天,我在这里指导您完成从iPhone中恢复那些珍贵的,错误删除的照片的迷宫。坐下来,拿起你的设备&…

分压电路 ADC计算电压 【老板再也不能开除我了 】

经典分压电路 一个电压过来 adc这里的电压等于: 如是12位adc 那么他最大值就是4095 如参考电压是5v 则:5v/4095 实际电压V*(R2/(R1R2))/adc值 转化:实际电压V 5v*(adc值/4095)/(R2/(R1R2)) :老板再也不能 因为不会…

11.利用RTC制作实时时钟

RTC 从配置上分两大部分—时钟的配置、定时器的配置 时钟的配置:可以直接访问 直接由RCC的BDCR来配置时钟:时钟源的选择 定时器的配署:不可以直接访问,因为定时器相关的寄存器在备份区域 1、使能备份区域访问— PWREN、BKPEN …

PointCloudLib-滤波模块(Filtering)-直通滤波

使用直通过滤器过滤点云 在本教程中,我们将学习如何沿着 指定维度 – 即,切断位于 或 在给定用户范围之外。 代码 首先,创建一个文件,比方说,在你最喜欢的 编辑器,并将以下内容放入其中:passthrough.cpp #include <iostream> #include <pcl/point_types.h&g…

推荐系统-FM模型

参考&#xff1a;推荐系统&#xff08;三&#xff09;Factorization Machines&#xff08;FM&#xff09;_svmmf-CSDN博客 一句话概括下FM&#xff1a;隐式向量特征交叉----类似embedding的思想 LR 如果利用LR做特征的二阶交叉&#xff0c;有&#xff1a; 但这个公式存在显著…

在分布式系统中,Erlang 的错误处理和容错机制是如何实现的,又面临哪些挑战?

Erlang是一种被广泛用于构建高可用、容错性强的分布式系统的编程语言。它提供了一些内建的错误处理和容错机制来处理系统中的错误和故障。 下面是Erlang中常用的错误处理和容错机制&#xff1a; 进程监控&#xff08;Process Monitoring&#xff09;&#xff1a;Erlang的进程是…

case when 使用——mysql sql

case when的使用方法主要有两种&#xff1a; 第一种&#xff1a; UPDATE USER SET USERNAME CASE WHEN ID 1 THEN USERNAME1 WHEN ID 2 THEN USERNAME2 WHEN ID 3 THEN USERNAME3 END , PASSWORD CASE WHEN ID 1 THEN PASSWORD1 WHEN ID 2 THEN PASSWORD2 WHEN ID…

Open3D 点云的ISS关键点提取

目录 一、概述 1.1原理 1.2应用场景 1.3算法实现步骤 二、代码实现 2.1 完整代码 2.2关键函数 2.3关键点可视化 三、实现效果 3.1原始点云 3.2提取后点云 一、概述 1.1原理 ISS&#xff08;Intrinsic Shape Signatures&#xff09;关键点提取是一种常用于三维点云的…

【LLM-多模态】高效多模态大型语言模型综述

一、结论写在前面 模型规模的庞大及训练和推理成本的高昂&#xff0c;限制了MLLMs在学术界和工业界的广泛应用。因此&#xff0c;研究高效轻量级的MLLMs具有巨大潜力&#xff0c;特别是在边缘计算场景中。 论文深入探讨了高效MLLM文献的领域&#xff0c;提供了一个全面的视角…

Win10可用的VC6.0绿色版及辅助插件assist_X

VC6.0&#xff0c;作为微软的经典开发工具&#xff0c;承载着无数开发者的青春与回忆。它曾是Windows平台上软件开发的重要基石&#xff0c;为开发者们提供了稳定且强大的编程环境&#xff0c;尤其是其MFC&#xff08;Microsoft Foundation Classes&#xff09;库&#xff0c;为…

计算机网络:408考研|湖科大教书匠|原理参考系统I|学习笔记

系列目录 计算机网络总纲领 计算机网络特殊考点 目录 系列目录更新日志数据链路层(Data Link Layer)一、基本概念二、三个重要问题三、 &#x1f31f;点对点协议(PPP, Point-to-Point Protocol)四、 以太网五、802.11 无线局域网(简称Wi-Fi) 物理层(Physical Layer)一、传输方…

SSM宠物领养系统-计算机毕业设计源码08465

目 录 摘要 1 绪论 1.1课题背景及意义 1.2研究现状 1.3ssm框架介绍 1.3论文结构与章节安排 2 宠物领养系统系统分析 2.1 可行性分析 2.2 系统流程分析 2.2.1 数据流程 3.3.2 业务流程 2.3 系统功能分析 2.3.1 功能性分析 2.3.2 非功能性分析 2.4 系统用例分析 …