RLHF学习

整体流程

三个步骤分解:

  1. 预训练一个语言模型 (LM) ;
  2. 聚合问答数据并训练一个奖励模型 (Reward Model,RM) ;
  3. 用强化学习 (RL) 方式微调 LM。

在这里插入图片描述

在这里插入图片描述

RW

RM 的训练是 RLHF 区别于旧范式的开端。这一模型接收一系列文本并返回一个标量奖励,数值上对应人的偏好。我们可以用端到端的方式用 LM 建模,或者用模块化的系统建模 (比如对输出进行排名,再将排名转换为奖励) 。这一奖励数值将对后续无缝接入现有的 RL 算法至关重要。

  • 关于模型选择方面:
    RM 可以是另一个经过微调的 LM,也可以是根据偏好数据从头开始训练的 LM。例如 Anthropic 提出了一种特殊的预训练方式,即用偏好模型预训练 (Preference Model Pretraining,PMP) 来替换一般预训练后的微调过程。因为前者被认为对样本数据的利用率更高。但对于哪种 RM 更好尚无定论。

  • 过程:
    在这里插入图片描述

  • Bradley-Terry(BT)模型是一个常见选择(在可以获得多个排序答案的情况下,Plackett-Luce 是更一般的排序模型)

  • **排序损失:**在最后一层 transformer 层后添加一个线性层以获得奖励值的标量预测。为了确保奖励函数具有较低的方差,之前的工作会对奖励进行归一化
    在这里插入图片描述

RW代码

from dataclasses import dataclass, field
from typing import Optionalimport tyro
from accelerate import Accelerator
from datasets import load_dataset
from peft import LoraConfig
from tqdm import tqdm
from transformers import AutoModelForSequenceClassification, AutoTokenizer, BitsAndBytesConfigfrom trl import RewardConfig, RewardTrainer, is_xpu_availabletqdm.pandas()@dataclass
class ScriptArguments:model_name: str = "facebook/opt-350m""""the model name"""dataset_name: str = "Anthropic/hh-rlhf""""the dataset name"""dataset_text_field: str = "text""""the text field of the dataset"""eval_split: str = "none""""the dataset split to evaluate on; default to 'none' (no evaluation)"""load_in_8bit: bool = False"""load the model in 8 bits precision"""load_in_4bit: bool = False"""load the model in 4 bits precision"""trust_remote_code: bool = True"""Enable `trust_remote_code`"""reward_config: RewardConfig = field(default_factory=lambda: RewardConfig(output_dir="output",per_device_train_batch_size=64,num_train_epochs=1,gradient_accumulation_steps=16,gradient_checkpointing=True,gradient_checkpointing_kwargs={"use_reentrant": False},learning_rate=1.41e-5,report_to="tensorboard",remove_unused_columns=False,optim="adamw_torch",logging_steps=500,evaluation_strategy="no",max_length=512,))use_peft: bool = False"""whether to use peft"""peft_config: Optional[LoraConfig] = field(default_factory=lambda: LoraConfig(r=16,lora_alpha=16,bias="none",task_type="SEQ_CLS",modules_to_save=["scores"],),)args = tyro.cli(ScriptArguments)
args.reward_config.evaluation_strategy = "steps" if args.eval_split != "none" else "no"# Step 1: Load the model
if args.load_in_8bit and args.load_in_4bit:raise ValueError("You can't load the model in 8 bits and 4 bits at the same time")
elif args.load_in_8bit or args.load_in_4bit:quantization_config = BitsAndBytesConfig(load_in_8bit=args.load_in_8bit, load_in_4bit=args.load_in_4bit)# Copy the model to each devicedevice_map = ({"": f"xpu:{Accelerator().local_process_index}"}if is_xpu_available()else {"": Accelerator().local_process_index})
else:device_map = Nonequantization_config = Nonemodel = AutoModelForSequenceClassification.from_pretrained(args.model_name,quantization_config=quantization_config,device_map=device_map,trust_remote_code=args.trust_remote_code,num_labels=1,
)# Step 2: Load the dataset and pre-process it
tokenizer = AutoTokenizer.from_pretrained(args.model_name)
train_dataset = load_dataset(args.dataset_name, split="train")# Tokenize chosen/rejected pairs of inputs
# Adapt this section to your needs for custom datasets
def preprocess_function(examples):new_examples = {"input_ids_chosen": [],"attention_mask_chosen": [],"input_ids_rejected": [],"attention_mask_rejected": [],}for chosen, rejected in zip(examples["chosen"], examples["rejected"]):tokenized_chosen = tokenizer(chosen)tokenized_rejected = tokenizer(rejected)new_examples["input_ids_chosen"].append(tokenized_chosen["input_ids"])new_examples["attention_mask_chosen"].append(tokenized_chosen["attention_mask"])new_examples["input_ids_rejected"].append(tokenized_rejected["input_ids"])new_examples["attention_mask_rejected"].append(tokenized_rejected["attention_mask"])return new_examples# Preprocess the dataset and filter out examples that are longer than args.max_length
train_dataset = train_dataset.map(preprocess_function,batched=True,num_proc=4,
)
train_dataset = train_dataset.filter(lambda x: len(x["input_ids_chosen"]) <= args.reward_config.max_lengthand len(x["input_ids_rejected"]) <= args.reward_config.max_length
)if args.eval_split == "none":eval_dataset = None
else:eval_dataset = load_dataset(args.dataset_name, split=args.eval_split)eval_dataset = eval_dataset.map(preprocess_function,batched=True,num_proc=4,)eval_dataset = eval_dataset.filter(lambda x: len(x["input_ids_chosen"]) <= args.reward_config.max_lengthand len(x["input_ids_rejected"]) <= args.reward_config.max_length)# Step 4: Define the LoraConfig
if args.use_peft:peft_config = args.peft_config
else:peft_config = None# Step 5: Define the Trainer
trainer = RewardTrainer(model=model,tokenizer=tokenizer,args=args.reward_config,train_dataset=train_dataset,eval_dataset=eval_dataset,peft_config=peft_config,
)trainer.train()

RLHF

  • 动手学强化学习: https://hrl.boyuai.com/chapter/2/actor-critic%E7%AE%97%E6%B3%95

让我们首先将微调任务表述为 RL 问题。

  • 首先,该 策略 (policy) 是一个接受提示并返回一系列文本 (或文本的概率分布) 的 LM。
  • 这个策略的 行动空间 (action space) 是 LM 的词表对应的所有词元 (一般在 50k 数量级)
  • 观察空间 (observation space) 是可能的输入词元序列,也比较大 (词汇量 ^ 输入标记的数量) 。
  • 奖励函数 是偏好模型和策略转变约束 (Policy shift constraint) 的结合。
    在这里插入图片描述

在这里插入图片描述

  • KL散度这一项被用于惩罚 RL 策略在每个训练批次中生成大幅偏离初始模型,以确保模型输出合理连贯的文本。如果去掉这一惩罚项可能导致模型在优化中生成乱码文本来愚弄奖励模型提供高奖励值。

可视化进度条的一种方法:

with tqdm(total=int(num_episodes / 10), desc='Iteration %d' % i) as pbar:for i_episode in range(episode):if (i_episode + 1) % 10 == 0:pbar.set_postfix({'episode':'%d' % (num_episodes / 10 * i + i_episode + 1),'return':'%.3f' % np.mean(return_list[-10:])})pbar.update(1)

策略梯度

  • 基于值函数的方法主要是学习值函数,然后根据值函数导出一个策略,学习过程中并不存在一个显式的策略;而基于策略的方法则是直接显式地学习一个目标策略。

AC算法

  • 基于值函数的方法只学习一个价值函数,而基于策略的方法只学习一个策略函数

  • Actor-Critic 算法本质上是基于策略的算法,因为这一系列算法的目标都是优化一个带参数的策略,只是会额外学习价值函数,从而帮助策略函数更好地学习。

  • Actor-Critic 算法估计一个动作价值函数 Q Q Q,代替蒙特卡洛采样得到的回报,这便是 Q ( s , a ) Q(s,a) Q(s,a)。这个时候,我们可以把状态价值函数 V V V作为基线,从 Q Q Q函数减去这个 V V V函数则得到了函数 A A A,我们称之为优势函数(advantage function)

Actor-Critic 分为两个部分:Actor(策略网络)和 Critic(价值网络)

  • Actor 要做的是与环境交互,并在 Critic 价值函数的指导下用策略梯度学习一个更好的策略。
  • Critic 要做的是通过 Actor 与环境交互收集的数据学习一个价值函数,这个价值函数会用于判断在当前状态什么动作是好的,什么动作不是好的,进而帮助 Actor 进行策略更新。

在这里插入图片描述
在这里插入图片描述

PPO

PPO惩罚

PPO-惩罚(PPO-Penalty)用拉格朗日乘数法直接将 KL 散度的限制放进了目标函数中,这就变成了一个无约束的优化问题,在迭代的过程中不断更新 KL 散度前的系数。
在这里插入图片描述

PPO截断

在这里插入图片描述

  • 对于连续动作,让策略网络输出连续动作高斯分布(Gaussian distribution)的均值和标准差。后续的连续动作则在该高斯分布中采样得到。
PPO的训练中存在的问题

PPO会找捷径,只要有机会,PPO 算法就会利用这些缺陷。

  1. 显然,当从概率低于 SFT 模型的策略中采样令牌时,这将导致负 KL 惩罚。但平均而言,它将是正的,否则您将无法从策略中正确采样。使用 KL 惩罚项是为了推动模型的输出保持接近基本策略的输出。一般来说,KL 散度衡量两个分布之间的距离,并且始终为正值。

  2. 某些生成策略可以强制生成某些token或抑制某些token。例如,当批量生成完成的序列时,会进行填充;当设置最小长度时,EOS 令牌会被抑制。该模型可以为那些导致负 KL 的标记分配非常高或非常低的概率。当 PPO 算法针对奖励进行优化时,它会追逐这些负面惩罚,从而导致不稳定。

    • 生成响应时需要小心,我们建议在采用更复杂的生成方法之前始终先使用简单的采样策略。
  3. 损失偶尔会出现峰值,这可能会导致进一步的不稳定。

  4. 字符串的重复会导致奖励的突然增加。

DPO

与以往的 RLHF 方法(先学习一个奖励函数,然后通过强化学习优化)不同,我们的方法跳过了奖励建模步骤,直接使用偏好数据优化语言模型。

  • 我们的核心观点是利用从奖励函数到最优策略的解析映射,将对奖励函数的损失转化为对策略的损失。这种变量转换的方法使我们能够跳过显式的奖励建模步骤,同时仍然在现有的人类偏好模型(如 Bradley-Terry 模型)下进行优化。实质上,策略网络既代表语言模型,又代表奖励。

在这里插入图片描述
在这里插入图片描述

RLHF开源工具

  • TRL
  • RL4LM

TRL实践

  • demo.py
# 0. imports
import torch
from transformers import GPT2Tokenizerfrom trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer# 1. load a pretrained model
model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")
model_ref = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token# 2. initialize trainer
ppo_config = {"batch_size": 1}
config = PPOConfig(**ppo_config)
ppo_trainer = PPOTrainer(config, model, model_ref, tokenizer)# 3. encode a query
query_txt = "This morning I went to the "
query_tensor = tokenizer.encode(query_txt, return_tensors="pt").to(model.pretrained_model.device)# 4. generate model response
generation_kwargs = {"min_length": -1,"top_k": 0.0,"top_p": 1.0,"do_sample": True,"pad_token_id": tokenizer.eos_token_id,"max_new_tokens": 20,
}
response_tensor = ppo_trainer.generate([item for item in query_tensor], return_prompt=False, **generation_kwargs)
response_txt = tokenizer.decode(response_tensor[0])# 5. define a reward for response
# (this could be any reward such as human feedback or output from another model)
reward = [torch.tensor(1.0, device=model.pretrained_model.device)]# 6. train model with ppo
train_stats = ppo_trainer.step([query_tensor[0]], [response_tensor[0]], reward)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/650480.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

力扣hot100 最小栈 变种栈

Problem: 155. 最小栈 文章目录 思路&#x1f496; Stack 自定义 Node&#x1f37b; Code 思路 &#x1f469;‍&#x1f3eb; 甜姨 &#x1f496; Stack 自定义 Node 时间复杂度: O ( 1 ) O(1) O(1) 空间复杂度: O ( n ) O(n) O(n) &#x1f37b; Code class MinS…

轻松打卡:使用Spring Boot和Redis Bitmap构建高效签到系统【redis实战 四】

欢迎来到我的博客&#xff0c;代码的世界里&#xff0c;每一行都是一个故事 轻松打卡&#xff1a;使用Spring Boot和Redis Bitmap构建高效签到系统【redis实战 四】 引言(redis实战)前言回顾bitmap基本概念核心特性使用场景 为什么使用redis中的bitmap实现&#xff1f;1. 存储效…

紫光展锐M6780丨超分辨率技术——画质重构还原经典

上一期&#xff0c;我们揭秘了让画质更加炫彩的AI-PQ技术。面对分辨率较低的老电影&#xff0c;光有高饱和度的色彩是不够的&#xff0c;如何能够提高视频影像的分辨率&#xff0c;使画质更加清晰&#xff0c;实现老片新看&#xff1f; 本期带大家揭晓紫光展锐首颗AI8K超高清智…

【分布式技术专题】「分布式技术架构」 探索Tomcat技术架构设计模式的奥秘(Server和Service组件原理分析)

探索Tomcat技术架构设计模式的奥秘 Tomcat系统架构分析Tomcat 整体结构Tomcat总体结构图以 Service 作为“婚姻”1) Service 接口方法列表 2) StandardService 的类结构图方法列表 3) StandardService. SetContainer4) StandardService. addConnector 以 Server 为“居”1) Ser…

etcd技术解析:构建高可用分布式系统的利器

1. 引言 随着云原生技术的兴起&#xff0c;分布式系统的构建变得愈发重要。etcd作为一个高可用的分布式键值存储系统&#xff0c;在这个领域发挥着至关重要的作用。本文将深入探讨etcd的技术细节&#xff0c;以及如何利用它构建高可用的分布式系统。 2. etcd简介 etcd是一个开…

JVM系列-9.性能调优

&#x1f44f;作者简介&#xff1a;大家好&#xff0c;我是爱吃芝士的土豆倪&#xff0c;24届校招生Java选手&#xff0c;很高兴认识大家&#x1f4d5;系列专栏&#xff1a;Spring原理、JUC原理、Kafka原理、分布式技术原理、数据库技术、JVM原理&#x1f525;如果感觉博主的文…

大数据安全 | 期末复习(中)

文章目录 &#x1f4da;感知数据安全⭐️&#x1f407;传感器概述&#x1f407;传感器的静态特性&#x1f407;调制方式&#x1f407;换能攻击&#x1f407;现有防护策略 &#x1f4da;AI安全⭐️&#x1f407;智能语音系统——脆弱性&#x1f407;攻击手段&#x1f407;AI的两…

探索IOC和DI:解密Spring框架中的依赖注入魔法

IOC与DI的详细解析 IOC详解1 bean的声明2 组件扫描 DI详解 IOC详解 1 bean的声明 IOC控制反转&#xff0c;就是将对象的控制权交给Spring的IOC容器&#xff0c;由IOC容器创建及管理对象。IOC容器创建的对象称为bean对象。 要把某个对象交给IOC容器管理&#xff0c;需要在类上…

【深度学习】sdxl中的 text_encoder text_encoder_2 区别

镜像问题是&#xff1a;https://editor.csdn.net/md/?articleId135867689 代码仓库&#xff1a; https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/tree/main 截图&#xff1a; 为什么有两个CLIP编码器 text_encoder 和 text_encoder_2 &#xff1f; 在…

那些年与指针的爱恨情仇(一)---- 指针本质及其相关性质用法

关注小庄 顿顿解馋 (≧∇≦) 引言&#xff1a; 小伙伴们在学习c语言过程中是否因为指针而困扰&#xff0c;指针简直就像是小说女主&#xff0c;它逃咱追&#xff0c;我们插翅难飞…本篇文章让博主为你打理打理指针这个傲娇鬼吧~ 本节我们将认识到指针本质&#xff0c;何为指针和…

RHCE项目:使用LNMP搭建私有云存储

目录 一、准备工作 1、关闭防火墙、安全软件 2、搭建LNMP环境 3、上传软件 4、设置nextcloud安装命令权限 二、数据库 1、设置数据库 2、重启数据库 三、配置nginx 四、安装nextcloud 五、内网穿透 1、创建内网映射 2、linux系统安装花生壳客户端 3、重新打开浏览…

林浩然与极限的“无穷”约会

林浩然与极限的“无穷”约会 Lin Haoran’s Encounter with the Mathematical “Infinity” 在数学王国里&#xff0c;有一位名叫林浩然的大侠&#xff0c;他的江湖就是高等数学的殿堂。而他要挑战的终极Boss&#xff0c;便是那个既神秘又顽皮的“极限”。 In the kingdom of …

C# .Net6搭建灵活的RestApi服务器

1、准备 C# .Net6后支持顶级语句&#xff0c;更简单的RestApi服务支持&#xff0c;可以快速搭建一个极为简洁的Web系统。推荐使用Visual Studio 2022&#xff0c;安装"ASP.NET 和Web开发"组件。 2、创建工程 关键步骤如下&#xff1a; 包添加了“Newtonsoft.Json”&…

【Git】项目管理笔记

文章目录 本地电脑初始化docker报错.gitignoregit loggit resetgit statusgit ls-filesgit rm -r -f --cached拉取仓库文件更新本地的项目报错处理! [rejected] master -> master (fetch first)gitgitee.com: Permission denied (publickey).error: remote origin already e…

BACnet转IEC104网关BE104

随着电力系统信息化建设和数字化转型的进程不断加速&#xff0c;对电力能源的智能化需求也日趋增强。健全稳定的智慧电力系统能够为工业生产、基础设施建设以及国防建设提供稳定的能源支持。在此背景下&#xff0c;高性能的工业电力数据传输解决方案——协议转换网关应运而生&a…

研发日记,Matlab/Simulink避坑指南(六)——字节分割Bug

文章目录 前言 背景介绍 问题描述 分析排查 解决方案 总结归纳 前言 见《研发日记&#xff0c;Matlab/Simulink避坑指南&#xff08;一&#xff09;——Data Store Memory模块执行时序Bug》 见《研发日记&#xff0c;Matlab/Simulink避坑指南(二)——非对称数据溢出Bug》…

Arduino开发实例-DRV8833电机驱动器控制直流电机

DRV8833电机驱动器控制直流电机 文章目录 DRV8833电机驱动器控制直流电机1、DRV8833电机驱动器介绍2、硬件接线图3、代码实现DRV8833 使用 MOSFET,而不是 BJT。 MOSFET 的压降几乎可以忽略不计,这意味着几乎所有来自电源的电压都会传递到电机。 这就是为什么 DRV8833 不仅比基…

寒假思维训练day15 牛客练习赛121

牛客练习赛ABCD题解&#xff0c;更新一个题解作为今天的任务收尾。 寒假思维训练day15 摘要&#xff1a; Part1&#xff1a;B题&#xff0c;B-You Brought Me A Gentle Breeze on the Field_牛客练习赛121 (nowcoder.com) Part2: C题&#xff0c;C-氧气少年的水滴 2_牛客练…

职言圈 | 小伙年薪95w,女朋友父母却爱搭不理,如今上岸国家电网,人不到,叔叔阿姨吃饭都不动筷子...

在职场中&#xff0c;每个人都经历着不同的起伏和挑战。有时候&#xff0c;我们会面临着职业生涯上的起伏&#xff0c;但正是这些经历让我们成长&#xff0c;让我们更加坚韧。 就像这位网友一样&#xff0c;他在过去的一年里经历了美团L8的职位和年薪95W&#xff0c;却面临着女…

Sentinel解密:SlotChain中的SLot大揭秘

欢迎来到我的博客&#xff0c;代码的世界里&#xff0c;每一行都是一个故事 Sentinel解密&#xff1a;SlotChain中的SLot大揭秘 前言SlotChain简介&#xff1a;Sentinel的第一道防线入场仪式&#xff1a;SlotChain中的初始化SlotSlotChain的执行流程&#xff1a;从规则解析到流…