Py之trl:trl(一款采用强化学习训练Transformer语言模型和稳定扩散模型的全栈库)的简介、安装、使用方法之详细攻略

Py之trl:trl(一款采用强化学习训练Transformer语言模型和稳定扩散模型的全栈库)的简介、安装、使用方法之详细攻略

目录

trl的简介

1、亮点

2、PPO是如何工作的:PPO对语言模型微调三步骤,Rollout→Evaluation→Optimization

trl的安装

trl的使用方法

1、基础用法

(1)、如何使用库中的SFTTrainer

(2)、如何使用库中的RewardTrainer

(3)、如何使用库中的PPOTrainer

2、进阶用法

LLMs之BELLE:源码解读(ppo_train.py文件)训练一个基于强化学习的自动对话生成模型—解析命令行参数→加载数据集(datasets库)→初始化模型分词器和PPOConfig配置参数(trl库)→模型训练(accelerate分布式训练+DeepSpeed推理加速,生成对话→计算奖励【评估生成质量】→执行PPO算法更新【改善生成文本的质量】)→模型保存之详细攻略

LLMs之BELLE:源码解读(dpo_train.py文件)训练一个基于强化学习的自动对话生成模型(DPO算法微调预训练语言模型)—解析命令行参数与初始化→加载数据集(json格式)→模型训练与评估之详细攻略


trl的简介

          TRL - Transformer Reinforcement Learning使用强化学习的全栈Transformer语言模型。trl 是一个全栈库,其中我们提供一组工具,用于通过强化学习训练Transformer语言模型和稳定扩散模型,从监督微调步骤(SFT)到奖励建模步骤(RM)再到近端策略优化(PPO)步骤。该库建立在Hugging Face 的 transformers 库之上。因此,可以通过 transformers 直接加载预训练语言模型。目前,大多数解码器架构和编码器-解码器架构都得到支持。请参阅文档或示例/文件夹,以查看示例代码片段以及如何运行这些工具。

GitHub地址:GitHub - huggingface/trl: Train transformer language models with reinforcement learning.

1、亮点

>> SFTTrainer:一个轻量级且友好的围绕transformer Trainer的包装器,可以在自定义数据集上轻松微调语言模型或适配器。

>> RewardTrainer: transformer Trainer的一个轻量级包装,可以轻松地微调人类偏好的语言模型(Reward Modeling)。

>> potrainer:用于语言模型的PPO训练器,它只需要(查询、响应、奖励)三元组来优化语言模型。

>> AutoModelForCausalLMWithValueHead & AutoModelForSeq2SeqLMWithValueHead:一个转换器模型,每个令牌有一个额外的标量输出,可以用作强化学习中的值函数。

>> 示例:使用BERT情感分类器训练GPT2生成积极的电影评论,仅使用适配器的完整RLHF,训练GPT-j减少毒性,Stack-Llama示例等。

2、PPO是如何工作的PPO对语言模型微调三步骤,Rollout→Evaluation→Optimization

通过PPO对语言模型进行微调大致包括三个步骤:

Rollout

Rollout(展开):语言模型基于查询生成响应或继续,查询可以是句子的开头。

Evaluation

Evaluation(评估):使用一个函数、模型、人类反馈或它们的组合来评估查询和响应。重要的是,此过程应为每个查询/响应对产生一个标量值

Optimization

Optimization(优化):这是最复杂的部分。在优化步骤中,使用查询/响应对来计算序列中token的对数概率。这是通过训练的模型和一个参考模型(通常是微调之前的预训练模型)来完成的。两个输出之间的KL-散度被用作附加奖励信号,以确保生成的响应不会偏离参考语言模型太远。然后,使用PPO训练主动语言模型。

这个过程在下面的示意图中说明。

trl的安装

pip install trl

trl的使用方法

1、基础用法

(1)、如何使用库中的SFTTrainer

以下是如何使用库中的SFTTrainer的基本示例。SFTTrainer是用于轻松微调语言模型或适配器的transformers Trainer的轻量包装器。

# imports
from datasets import load_dataset
from trl import SFTTrainer# get dataset
dataset = load_dataset("imdb", split="train")# get trainer
trainer = SFTTrainer("facebook/opt-350m",train_dataset=dataset,dataset_text_field="text",max_seq_length=512,
)# train
trainer.train()

(2)、如何使用库中的RewardTrainer

以下是如何使用库中的RewardTrainer的基本示例。RewardTrainer是用于轻松微调奖励模型或适配器的transformers Trainer的包装器。

# imports
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from trl import RewardTrainer# load model and dataset - dataset needs to be in a specific format
model = AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=1)
tokenizer = AutoTokenizer.from_pretrained("gpt2")...# load trainer
trainer = RewardTrainer(model=model,tokenizer=tokenizer,train_dataset=dataset,
)# train
trainer.train()

(3)、如何使用库中的PPOTrainer

以下是如何使用库中的PPOTrainer的基本示例。基于查询,语言模型创建响应,然后进行评估。评估可以是人工干预或另一个模型的输出。

# imports
import torch
from transformers import AutoTokenizer
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead, create_reference_model
from trl.core import respond_to_batch# get models
model = AutoModelForCausalLMWithValueHead.from_pretrained('gpt2')
model_ref = create_reference_model(model)tokenizer = AutoTokenizer.from_pretrained('gpt2')# initialize trainer
ppo_config = PPOConfig(batch_size=1,
)# encode a query
query_txt = "This morning I went to the "
query_tensor = tokenizer.encode(query_txt, return_tensors="pt")# get model response
response_tensor  = respond_to_batch(model, query_tensor)# create a ppo trainer
ppo_trainer = PPOTrainer(ppo_config, model, model_ref, tokenizer)# define a reward for response
# (this could be any reward such as human feedback or output from another model)
reward = [torch.tensor(1.0)]# train model for one step with ppo
train_stats = ppo_trainer.step([query_tensor[0]], [response_tensor[0]], reward)

2、进阶用法

LLMs之BELLE:源码解读(ppo_train.py文件)训练一个基于强化学习的自动对话生成模型—解析命令行参数→加载数据集(datasets库)→初始化模型分词器和PPOConfig配置参数(trl库)→模型训练(accelerate分布式训练+DeepSpeed推理加速,生成对话→计算奖励【评估生成质量】→执行PPO算法更新【改善生成文本的质量】)→模型保存之详细攻略

https://yunyaniu.blog.csdn.net/article/details/133865725

LLMs之BELLE:源码解读(dpo_train.py文件)训练一个基于强化学习的自动对话生成模型(DPO算法微调预训练语言模型)—解析命令行参数与初始化→加载数据集(json格式)→模型训练与评估之详细攻略

https://yunyaniu.blog.csdn.net/article/details/133873621

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/107798.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

4.3 划分子网和构造超网

思维导图: 4.3.1 划分子网 **4.3 划分子网和构造超网笔记:** --- **4.3.1 划分子网** **1. 两级IP地址到三级IP地址的转变:** **关键点:** - **问题背景:** 早期的ARPANET对IP地址的设计存在不足: 1…

LeetCode 63. 不同路径 II

63. 不同路径 II 思路: 动态规划 dp[i][j] :表示从(0 ,0)出发,到(i, j) 有dp[i][j]条不同的路径 根据题意,只能向下或者向右移动一步,则dp[i][j] dp[i - 1][j] dp[i][j - 1] 但是…

第五十九章 学习常用技能 - 将数据从一个数据库移动到另一个数据库

文章目录 第五十九章 学习常用技能 - 将数据从一个数据库移动到另一个数据库 第五十九章 学习常用技能 - 将数据从一个数据库移动到另一个数据库 如果需要将数据从一个数据库移动到另一个数据库,请执行以下操作: 识别包含数据及其索引的Global。 如果…

Android 13.0 首次开机进入Launcher3前黑屏几秒的几种情况问题的总结

1.概述 在13.0的系统产品定制化工作中,对于系统开发确实有些难度,特别是在开机阶段遇到的问题,比如开机动画播放完毕进入锁屏界面黑屏几秒然后进入 锁屏界面,这就需要根据开机日志来分析问题所在,在工作中遇到的几种黑屏情况做下记录,然后供以后解决问题做参考 2.首次开…

Java String之正则表达式

Java String之正则表达式 导言 最近做项目时,遇到了限制输入字符格式的问题,采用了Java String的正则表达式,下面针对正则表达式使用进行概述 正则表达式 正则表达式类似可以通俗的理解为字符模板,通过符号的方式进行表述&…

uni-app--》基于小程序开发的电商平台项目实战(五)

🏍️作者简介:大家好,我是亦世凡华、渴望知识储备自己的一名在校大学生 🛵个人主页:亦世凡华、 🛺系列专栏:uni-app 🚲座右铭:人生亦可燃烧,亦可腐败&#xf…

Fetch与Axios数据请求

什么是Polyfill? Polyfill是一个js库,主要抚平不同浏览器之间对js实现的差异。比如,html5的storage(session,local), 不同浏览器,不同版本,有些支持,有些不支持。Polyfill(Polyfill有很多,在Gi…

IDEA spring-boot项目启动,无法加载或找到启动类问题解决

问题描述:找不到或无法加载主类 xxx.xxx.xxx.Classname 解决方案: 1.检查启动设置: 启动类所在包运行环境(一般选择默认即可)设置完成即可进行运行测试 2.如果第一步没有解决问题,试着第二步&#xff1a…

常见三维建模软件有哪些?各自的特点是什么?

常见的三维建模软件包括以下这些: 1. 3DS Max 3D Studio Max,简称3DS MAX,是当今世界上销售量最大的三维建模、动画及渲染软件。它的应用范围广泛,包括计算机游戏中的动画制作、影视片的特效制作等。3DS MAX的操作相对容易&#…

【学习笔记】RabbitMQ02:交换机,以及结合springboot快速开始

参考资料 RabbitMQ官方网站RabbitMQ官方文档噼咔噼咔-动力节点教程 文章目录 四、RabbitMQ :Exchange 交换机4.1 交换机类型4.2 扇形交换机 Fanout Exchange4.2.1 概念4.2.1 实例:生产者4.2.1.1 添加起步依赖4.2.1.2 配置文件4.2.1.3 JavaBean进行配置4.…

iMazing 3中文版功能介绍免费下载安装教程

iMazing 3中文版免费下载是一款兼容Win和Mac的iOS设备管理软件。iMazing 3能够将音乐、文件、消息和应用等数据从任何 iPhone、iPad 或 iPod 传输到 Mac 或 PC 上。 使用iMazing 3独特的 iOS 备份功能保证数据安全:设定自动无线备份时间并支持快照;将备份保存到外接驱动器和网…

Halcon Camera-calibration 相关算子(二)

(1) set_calib_data( : : CalibDataID, ItemType, ItemIdx, DataName, DataValue : ) 功能:在标定数据模型中设置数据。 控制输入参数1:CalibDataID:标定数据模型句柄; 控制输入参数2:ItemType:标定数据…

HFSS中激励方式学习笔记(总)

HFSS中激励方式 文章目录 HFSS中激励方式波端口激励(Wave Port)集总端口激励(Lumped Port)floquet端口激励(floquet Port)入射波激励(Incident Port)电压源激励(Voltage …

17 - 并发容器的使用:识别不同场景下最优容器

在并发编程中,我们经常会用到容器。今天我要和你分享的话题就是:在不同场景下我们该如何选择最优容器。 1、并发场景下的 Map 容器 假设我们现在要给一个电商系统设计一个简单的统计商品销量 TOP 10 的功能。常规情况下,我们是用一个哈希表…

如何通过Photoshop将视频转换成GIF图片

一、应用场景 1、将视频转有趣动图发朋友圈 2、写CSDN无法上传视频,而可以用GIF动图替代 3、其他 二、实现步骤 1、打开Photoshop APP 2、点击文件——导入——视频帧到图层 3、选择视频文件 4、配置视频信息,按照图片提示配置完毕之后点击确定&…

c#调用CUDA执行YOLOV5对象检测

c#使用调YOLOV5对象检测,并调用CUDA进行计算 1.CUDA版本11.2 2.cuDNN用cudnn-windows-x86_64-8.9.3.28_cuda11-archive 记得把压缩包的三个文件夹放到cuda根目录下覆盖 3.Microsoft.ML.OnnxRuntime.Gpu要使用1.13.1,如果版本太新,SessionOptions会报…

C# Winform编程(4)多文档窗口(MDI)

多文档窗口(MDI) 添加菜单,IsMdiContainer设为True: From窗口添加菜单 Form1.cs using System; using System.Collections.Generic; using System.ComponentModel; using System.Data; using System.Drawing; using System.Linq; using S…

snk-给github界面加一个有趣的动画

How to enable GitHub Actions on your Profile README for a snake-eating contribution graph 🐍 - DEV Community Platane/Platane (github.com) ① 创建New repository 名字和你自己的Github 用户名一样。 ② 创建之后,再这个Zero-coder仓库下创建…

从javascript到vue再到react的演变

当提到前端开发中的框架时,JavaScript、Vue.js和React.js是三个最常见的名词。它们代表了Web开发中不同的技术选择和演变过程。本文将探讨JavaScript从原生到Vue.js再到React.js的演变,以及每个阶段的特点和优势。 JavaScript: 动态语言的基础 JavaScr…

学信息系统项目管理师第4版系列29_信息系统治理

1. IT治理 1.1. 描述组织采用有效的机制对信息技术和数据资源开发利用,平衡信息化发展和数字化转型过程中的风险,确保实现组织的战略目标的过程 1.2. 驱动因素 1.2.1. 信息孤岛 1.2.2. 信息资源整合目标空泛 1.3. 高质量IT治理因素 1.3.1. 良好的I…