Py之trl:trl(一款采用强化学习训练Transformer语言模型和稳定扩散模型的全栈库)的简介、安装、使用方法之详细攻略

Py之trl:trl(一款采用强化学习训练Transformer语言模型和稳定扩散模型的全栈库)的简介、安装、使用方法之详细攻略

目录

trl的简介

1、亮点

2、PPO是如何工作的:PPO对语言模型微调三步骤,Rollout→Evaluation→Optimization

trl的安装

trl的使用方法

1、基础用法

(1)、如何使用库中的SFTTrainer

(2)、如何使用库中的RewardTrainer

(3)、如何使用库中的PPOTrainer

2、进阶用法

LLMs之BELLE:源码解读(ppo_train.py文件)训练一个基于强化学习的自动对话生成模型—解析命令行参数→加载数据集(datasets库)→初始化模型分词器和PPOConfig配置参数(trl库)→模型训练(accelerate分布式训练+DeepSpeed推理加速,生成对话→计算奖励【评估生成质量】→执行PPO算法更新【改善生成文本的质量】)→模型保存之详细攻略

LLMs之BELLE:源码解读(dpo_train.py文件)训练一个基于强化学习的自动对话生成模型(DPO算法微调预训练语言模型)—解析命令行参数与初始化→加载数据集(json格式)→模型训练与评估之详细攻略


trl的简介

          TRL - Transformer Reinforcement Learning使用强化学习的全栈Transformer语言模型。trl 是一个全栈库,其中我们提供一组工具,用于通过强化学习训练Transformer语言模型和稳定扩散模型,从监督微调步骤(SFT)到奖励建模步骤(RM)再到近端策略优化(PPO)步骤。该库建立在Hugging Face 的 transformers 库之上。因此,可以通过 transformers 直接加载预训练语言模型。目前,大多数解码器架构和编码器-解码器架构都得到支持。请参阅文档或示例/文件夹,以查看示例代码片段以及如何运行这些工具。

GitHub地址:GitHub - huggingface/trl: Train transformer language models with reinforcement learning.

1、亮点

>> SFTTrainer:一个轻量级且友好的围绕transformer Trainer的包装器,可以在自定义数据集上轻松微调语言模型或适配器。

>> RewardTrainer: transformer Trainer的一个轻量级包装,可以轻松地微调人类偏好的语言模型(Reward Modeling)。

>> potrainer:用于语言模型的PPO训练器,它只需要(查询、响应、奖励)三元组来优化语言模型。

>> AutoModelForCausalLMWithValueHead & AutoModelForSeq2SeqLMWithValueHead:一个转换器模型,每个令牌有一个额外的标量输出,可以用作强化学习中的值函数。

>> 示例:使用BERT情感分类器训练GPT2生成积极的电影评论,仅使用适配器的完整RLHF,训练GPT-j减少毒性,Stack-Llama示例等。

2、PPO是如何工作的PPO对语言模型微调三步骤,Rollout→Evaluation→Optimization

通过PPO对语言模型进行微调大致包括三个步骤:

Rollout

Rollout(展开):语言模型基于查询生成响应或继续,查询可以是句子的开头。

Evaluation

Evaluation(评估):使用一个函数、模型、人类反馈或它们的组合来评估查询和响应。重要的是,此过程应为每个查询/响应对产生一个标量值

Optimization

Optimization(优化):这是最复杂的部分。在优化步骤中,使用查询/响应对来计算序列中token的对数概率。这是通过训练的模型和一个参考模型(通常是微调之前的预训练模型)来完成的。两个输出之间的KL-散度被用作附加奖励信号,以确保生成的响应不会偏离参考语言模型太远。然后,使用PPO训练主动语言模型。

这个过程在下面的示意图中说明。

trl的安装

pip install trl

trl的使用方法

1、基础用法

(1)、如何使用库中的SFTTrainer

以下是如何使用库中的SFTTrainer的基本示例。SFTTrainer是用于轻松微调语言模型或适配器的transformers Trainer的轻量包装器。

# imports
from datasets import load_dataset
from trl import SFTTrainer# get dataset
dataset = load_dataset("imdb", split="train")# get trainer
trainer = SFTTrainer("facebook/opt-350m",train_dataset=dataset,dataset_text_field="text",max_seq_length=512,
)# train
trainer.train()

(2)、如何使用库中的RewardTrainer

以下是如何使用库中的RewardTrainer的基本示例。RewardTrainer是用于轻松微调奖励模型或适配器的transformers Trainer的包装器。

# imports
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from trl import RewardTrainer# load model and dataset - dataset needs to be in a specific format
model = AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=1)
tokenizer = AutoTokenizer.from_pretrained("gpt2")...# load trainer
trainer = RewardTrainer(model=model,tokenizer=tokenizer,train_dataset=dataset,
)# train
trainer.train()

(3)、如何使用库中的PPOTrainer

以下是如何使用库中的PPOTrainer的基本示例。基于查询,语言模型创建响应,然后进行评估。评估可以是人工干预或另一个模型的输出。

# imports
import torch
from transformers import AutoTokenizer
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead, create_reference_model
from trl.core import respond_to_batch# get models
model = AutoModelForCausalLMWithValueHead.from_pretrained('gpt2')
model_ref = create_reference_model(model)tokenizer = AutoTokenizer.from_pretrained('gpt2')# initialize trainer
ppo_config = PPOConfig(batch_size=1,
)# encode a query
query_txt = "This morning I went to the "
query_tensor = tokenizer.encode(query_txt, return_tensors="pt")# get model response
response_tensor  = respond_to_batch(model, query_tensor)# create a ppo trainer
ppo_trainer = PPOTrainer(ppo_config, model, model_ref, tokenizer)# define a reward for response
# (this could be any reward such as human feedback or output from another model)
reward = [torch.tensor(1.0)]# train model for one step with ppo
train_stats = ppo_trainer.step([query_tensor[0]], [response_tensor[0]], reward)

2、进阶用法

LLMs之BELLE:源码解读(ppo_train.py文件)训练一个基于强化学习的自动对话生成模型—解析命令行参数→加载数据集(datasets库)→初始化模型分词器和PPOConfig配置参数(trl库)→模型训练(accelerate分布式训练+DeepSpeed推理加速,生成对话→计算奖励【评估生成质量】→执行PPO算法更新【改善生成文本的质量】)→模型保存之详细攻略

https://yunyaniu.blog.csdn.net/article/details/133865725

LLMs之BELLE:源码解读(dpo_train.py文件)训练一个基于强化学习的自动对话生成模型(DPO算法微调预训练语言模型)—解析命令行参数与初始化→加载数据集(json格式)→模型训练与评估之详细攻略

https://yunyaniu.blog.csdn.net/article/details/133873621

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/107798.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

4.3 划分子网和构造超网

思维导图: 4.3.1 划分子网 **4.3 划分子网和构造超网笔记:** --- **4.3.1 划分子网** **1. 两级IP地址到三级IP地址的转变:** **关键点:** - **问题背景:** 早期的ARPANET对IP地址的设计存在不足: 1…

LeetCode 63. 不同路径 II

63. 不同路径 II 思路: 动态规划 dp[i][j] :表示从(0 ,0)出发,到(i, j) 有dp[i][j]条不同的路径 根据题意,只能向下或者向右移动一步,则dp[i][j] dp[i - 1][j] dp[i][j - 1] 但是…

uni-app--》基于小程序开发的电商平台项目实战(五)

🏍️作者简介:大家好,我是亦世凡华、渴望知识储备自己的一名在校大学生 🛵个人主页:亦世凡华、 🛺系列专栏:uni-app 🚲座右铭:人生亦可燃烧,亦可腐败&#xf…

IDEA spring-boot项目启动,无法加载或找到启动类问题解决

问题描述:找不到或无法加载主类 xxx.xxx.xxx.Classname 解决方案: 1.检查启动设置: 启动类所在包运行环境(一般选择默认即可)设置完成即可进行运行测试 2.如果第一步没有解决问题,试着第二步&#xff1a…

常见三维建模软件有哪些?各自的特点是什么?

常见的三维建模软件包括以下这些: 1. 3DS Max 3D Studio Max,简称3DS MAX,是当今世界上销售量最大的三维建模、动画及渲染软件。它的应用范围广泛,包括计算机游戏中的动画制作、影视片的特效制作等。3DS MAX的操作相对容易&#…

【学习笔记】RabbitMQ02:交换机,以及结合springboot快速开始

参考资料 RabbitMQ官方网站RabbitMQ官方文档噼咔噼咔-动力节点教程 文章目录 四、RabbitMQ :Exchange 交换机4.1 交换机类型4.2 扇形交换机 Fanout Exchange4.2.1 概念4.2.1 实例:生产者4.2.1.1 添加起步依赖4.2.1.2 配置文件4.2.1.3 JavaBean进行配置4.…

iMazing 3中文版功能介绍免费下载安装教程

iMazing 3中文版免费下载是一款兼容Win和Mac的iOS设备管理软件。iMazing 3能够将音乐、文件、消息和应用等数据从任何 iPhone、iPad 或 iPod 传输到 Mac 或 PC 上。 使用iMazing 3独特的 iOS 备份功能保证数据安全:设定自动无线备份时间并支持快照;将备份保存到外接驱动器和网…

17 - 并发容器的使用:识别不同场景下最优容器

在并发编程中,我们经常会用到容器。今天我要和你分享的话题就是:在不同场景下我们该如何选择最优容器。 1、并发场景下的 Map 容器 假设我们现在要给一个电商系统设计一个简单的统计商品销量 TOP 10 的功能。常规情况下,我们是用一个哈希表…

如何通过Photoshop将视频转换成GIF图片

一、应用场景 1、将视频转有趣动图发朋友圈 2、写CSDN无法上传视频,而可以用GIF动图替代 3、其他 二、实现步骤 1、打开Photoshop APP 2、点击文件——导入——视频帧到图层 3、选择视频文件 4、配置视频信息,按照图片提示配置完毕之后点击确定&…

C# Winform编程(4)多文档窗口(MDI)

多文档窗口(MDI) 添加菜单,IsMdiContainer设为True: From窗口添加菜单 Form1.cs using System; using System.Collections.Generic; using System.ComponentModel; using System.Data; using System.Drawing; using System.Linq; using S…

snk-给github界面加一个有趣的动画

How to enable GitHub Actions on your Profile README for a snake-eating contribution graph 🐍 - DEV Community Platane/Platane (github.com) ① 创建New repository 名字和你自己的Github 用户名一样。 ② 创建之后,再这个Zero-coder仓库下创建…

学信息系统项目管理师第4版系列29_信息系统治理

1. IT治理 1.1. 描述组织采用有效的机制对信息技术和数据资源开发利用,平衡信息化发展和数字化转型过程中的风险,确保实现组织的战略目标的过程 1.2. 驱动因素 1.2.1. 信息孤岛 1.2.2. 信息资源整合目标空泛 1.3. 高质量IT治理因素 1.3.1. 良好的I…

Flask框架配置celery-[1]:flask工厂模式集成使用celery,可在异步任务中使用flask应用上下文,即拿即用,无需更多配置

一、概述 1、celery框架和flask框架在运行时,是在不同的进程中,资源是独占的。 2、celery异步任务如果想使用flask中的功能,如orm,是需要在flask应用上下文管理器中执行orm操作的 3、使用celery是需要使用到中间件的&#xff0…

内容分发网络CDN分布式部署真的可以加速吗?原理是什么?

Cdn快不快?她为什么会快?同样的带宽为什么她会快?原理究竟是什么,同学们本着普及知识的想法,我了解的不是很深入,适合小白来看我的帖子,如果您是大佬还请您指正错误的地方,先谢谢大佬…

nodejs基于vue网上考勤系统

本网上考勤系统是针对目前考勤的实际需求, 采用计算机系统来管理信息,取代人工管理模式,查询便利,信息准确率高,节省了开支,提高了工作的效率。 本网上考勤系统主要包括个人中心、员工请假管理、员工考勤管…

asp.net酒店管理系统VS开发sqlserver数据库web结构c#编程Microsoft Visual Studio

一、源码特点 asp.net酒店管理系统是一套完善的web设计管理系统,系统具有完整的源代码和数据库,系统主要采用B/S模式开发。开发环境为vs2010,数据库为sqlserver2008,使用c#语言开发 asp.net 酒店管理系统1 二、功能介绍 …

如何使用FME开发自动化分析报告功能

目录 前言 一、使用的技术栈 二、技术难点解析 1.专题图 2.WORD文档实现 2.1 动态标题 2.3动态表格和文本 2.3专题图插入 三、完成NewGIS部署 四、模板总览图 总结 前言 一个标准项目分析报告需要需要包括3个方面: 文本叙述,主要体现在对某项专项数据的…

Radius OTP完成堡垒机登录认证 安当加密

Radius OTP(One-Time Password)是一种用于身份验证的协议,它通过向用户发送一个一次性密码来验证用户的身份。使用Radius OTP可以实现堡垒机登录,以下是一些实现步骤: 1、安装Radius服务器 首先需要安装Radius服务器…

数字化转型“同群效应”(2000-2022年)

参照霍春辉等(2023)的做法,团队对上市公司-数字化转型“同群效应”进行测算。将同行业、同省的其他企业定义为同群企业,并以该群体数字化转型程度均值、中位数作为衡量 一、数据介绍 数据名称:数字化转型“同群效应”…

c++视觉检测------Shi-Tomasi 角点检测

Shi-Tomasi 角点检测 :goodFeaturesToTrack() goodFeaturesToTrack() 函数是 OpenCV 中用于角点检测的功能函数。它的主要作用是检测图像中的良好特征点,通常用于计算机视觉任务中的光流估算、目标跟踪等。 函数签名: void goodFeaturesTo…