RLHF几大常用框架实践对比(trlx、deepspeedchat、colossalaichat)

原文:RLHF几大常用框架实践对比(trlx、deepspeedchat、colossalaichat) - 知乎

目录

收起

一、RLHF的作用

二、实践效果

三、怎么做

1、框架

2、算法

3、数据

4、调参

一、RLHF的作用

从InstructGPT的论文中看,RLHF目的就是为了让模型输出的结果能和人类对齐。而所谓对齐,体现在三点:

  • 有用:即遵守指令的能力
  • 诚实:不容易胡说八道
  • 安全:不容易生成不合法的、有害、有毒的信息

RLHF在这篇论文中,我们都知道分为三个步骤,包括SFT(微调模型)、RM(训练回报模型或者叫偏好模型)、RL(强化学习)。那么只靠SFT能做到对齐这件事吗?应该可以做到一部分,现在网上大多数流行的开源模型基本上也止步到SFT这个步骤。其实SFT其实也展现出了很不错的性能,但是从实践上看,例如moss要做到和人类比较好的对齐,光微调的数据就达到100w的级别,这个级别的高质量数据收集起来代价还是比较高的,而后面RL的步骤,从实践结果来看,它能够用少量的数据让模型在对齐上的效果和泛化性达到一个新的高度。

从这个文章Awesome 论文合集 |不看这些论文,你都不知道 RLHF 是如此的神奇 (4) - OpenDILab浦策的文章 - 知乎 Awesome 论文合集 |不看这些论文,你都不知道 RLHF 是如此的神奇 (4) - 知乎看,RLHF有这三个优点:

  • 建立优化范式:为无法显式定义奖励函数的决策任务,建立新的优化范式。对于需要人类偏好指引的机器学习任务,探索出一条可行且较高效的交互式训练学习方案。
  • 省数据(Data-Efficient):相对其他的训练方法,例如监督学习,Top-K 采样等,RLHF 能够利用更少的人类反馈数据达到相近的训练效果。
  • 省参数(Parameter-Efficient):相对其他的训练方法,例如监督学习,Top-K 采样等,RLHF 可以让参数量较小的神经网络也能发挥出强大的性能。

从符尧大神的文章Notion – The all-in-one workspace for your notes, tasks, wikis, and databases.里可以看出RLHF的效果如下:

  • 翔实的回应: text-davinci-003 的生成通常比 text-davinci-002长。 ChatGPT 的回应则更加冗长,以至于用户必须明确要求“用一句话回答我”,才能得到更加简洁的回答。这是 RLHF 的直接产物。
  • **公正的回应:**ChatGPT 通常对涉及多个实体利益的事件(例如政治事件)给出非常平衡的回答。这也是RLHF的产物。
  • **拒绝不当问题:**这是内容过滤器和由 RLHF 触发的模型自身能力的结合,过滤器过滤掉一部分,然后模型再拒绝一部分。
  • **拒绝其知识范围之外的问题:**例如,拒绝在2021 年 6 月之后发生的新事件(因为它没在这之后的数据上训练过)。这是 RLHF 最神奇的部分,因为它使模型能够隐式地区分哪些问题在其知识范围内,哪些问题不在其知识范围内。

二、实践效果

我们的中文实验大多是基于GLM10B的huggingface版本进行的。SFT和大多网上的策略是一样的,使用开源的指令数据集,和一些ChatgptAPI生成的数据集训练。目前网上没有像英文领域有那么多公开的偏好数据集,早期我们直接用翻译接口翻译了HH-RLHF数据集,然后训练了一个回报模型。之后在一些中文的多轮对话上做强化学习,这样粗糙的RLHF,已经可以得到一个能够生成翔实的回应的PPO模型了。但是,也只是变得翔实而已,遵守指令的能力甚至变弱了,也没有丝毫安全性的提升(因为完全没相关数据)。

后来在清华开源的安全数据集上,经过一些精挑细选,分布到RM和PPO中,模型就可以保持翔实的前提下提高安全性,但是指令的遵循能力还很弱。但是这里也证明了一点,只要数据分布合理,RM和PPO就能让模型得到相关能力的提升。所以要得到一个对指令有广泛理解,答案翔实,安全且诚实,对于RM数据集的要求还是蛮高的,同时PPO应该也有相同的分布。

我们使用的RM的数据集和PPO数据集都没有达到1w级别,这也证实了强化学习算法的泛化性确实很强。

三、怎么做

1、框架

现在RLHF相关的框架非常多,基本上每周都有新的开源框架出现。选择一个合适的框架,一个是方便我们写代码,一个是能够节省更多显存。我们学习使用的框架有DeepspeedChat、Trlx、ColossalAI-Chat,同时也包括一些常用的框架例如Accelerate、PEFT等。每个框架都有自己的优缺点,这里大概说一下:

Trlx:GitHub - CarperAI/trlx: A repo for distributed training of language models with Reinforcement Learning via Human Feedback (RLHF)

优势:应该是目前网上大家提到的,使用最广泛的LLM的强化学习框架了。这个框架里面的算法基本是参考了OpenAI当年LM强化学习开源代码的实现,在此基础上,增加了Accelerate框架的调用支持,还有对各种常见的LM的封装,主要是添加了ValueFunction的head,还有一些冻结参数的支持。

不足:代码逻辑比起其他框架来说,有些凌乱,新手看起来不太友好。我第一个学的就是Trlx,后来看ColossalAI感觉Trlx写的真乱。还有就是Trlx的代码里,默认情况下,离线策略只执行一次,然后就训练,感觉有点奇怪。我实践经验上看,多次迭代效果是更好的。其次就是Trlx里面对Huggingface的模型封装比较复杂,我要在GLM上改挺麻烦的。

补充:trlx默认的参数基本是都是ok的,特别是gamma和lam的值,改了之后效果可能会差很多

  • DeepspeedChat:DeepSpeedExamples/applications/DeepSpeed-Chat at master · microsoft/DeepSpeedExamples

优势:应该是目前最容易能达成100B以上Huggingface模型强化学习的框架了。里面强化学习的部分大多和Trlx的算法一致,添加了PTX损失和EMA算法。代码逻辑也比较清晰。借助最新0.9.0版本deepspeed新增的混合引擎,实现zero3推理时,自动完成张量并行,大大降低了100B基本模型的强化学习门槛。

缺点:lora功能不完善。deepspeed混合引擎目前只支持几个BLOOM、GPT等LM,如果要支持GLM不知道要怎么改。所以暂时没有使用它。

  • ColossalAI-Chat:https://github.com/hpcaitech/ColossalAI/tree/main/applications/Chat

优势:代码逻辑清晰,新手学习非常友好。自己实现了一个和Trlx不太一样的PPO算法,每个句子只生成一个reward,没有时间步的概念,自然也就没有基于GAE求解优势函数的算法。这个算法本身我实践下来不太好训练,后面我们自己将其中的value function进行优化后,才成功训练起来,效果还是很不错的。

不足:ColossalAI框架本身不太完善。新功能得等社区慢慢更新,和很多流行的框架也不兼容,比较麻烦。

目前我们采用的是Accelerate+deepspeed的基本框架,同时用PEFT的lora减少显存占用。其实Accelerate和deepspeed的组合也不是特别好,Accelerate里面如果调用deepspeed的话,只支持一个模型和一个优化器,这导致ppo训练的时候比较麻烦,还不如直接使用原生的deepspeed。但是accelerate在分布式训练的时候,确实有它的优势,帮你解决了很多麻烦的事情,代码写起来比较省心。

这里我们特别提一下PEFT新分支中有个对多适配器lora的支持,这个功能天生就和PPO非常的搭,相当于一个基模型,通过挂多个lora的适配器,就可以随时变成RM、Critic、Actor、RefModel。同时加载四个模型,只需要消耗几乎等同于一个模型的显存,非常的香。GLM10B,开启zero2,在PPO训练的时候,单卡开到bs4,最终大概占用了30多G的显存。

补充:lora Multi Adapter功能已经合并到主分支了,详情可以看0.3.0的更新公告。

顺便提一下RLHF里一些好用的显存优化方法:

  • 多lora适配器(不全量训练的PPO神器)
  • deepspeed zero(什么地方都可以用它)
  • gradient checkpointing(显存节省神器)
  • flash attention(也是显存节省神器,LLAMA可以直接用,GLM不知道怎么适配)
  • deepspeed混合引擎(30B以上PPO神器,希望以后能提供如何适配更多模型的文档)
  • BF16(不会有FP16的溢出问题,训练PPO的时候比较安全)

2、算法

我们在Accelerate+deepspeed+peft的基本框架下,参照ColossalAI的代码逻辑,重新实现了一种回报模型算法和三种对齐算法。

2.1 回报模型

回报模型的结构和loss设计基本和Trlx保持一致,分数是取句末token的分数,实践证明这样训练后的权重用来初始化Critic是最有利于训练多时间步的PPO的。ColossalAI的回报模型分数是将句子所有的token求平均,这个如果是训练单步的PPO是没啥区别的,但是训练多步的话就不太合适。所以最后我还是都统一用Trlx的风格。

2.2 对齐算法

对齐算法我们实现了三种,一个是Trlx的多步PPO算法、一个是ColossalAI的单步PPO算法、一个是最近阿里开源的RRHF算法GitHub - GanjinZero/RRHF: RRHF & Wombat。

其中单步的PPO算法,ColossalAI默认是用一个Critic模型去拟合reward,这样训练出来的优势值很小很难训练。其实优势函数的本意就是累积奖励-累积奖励的期望。而对于单步的PPO来说,累积奖励就是单步奖励,而单步奖励的期望,其实并不需要一个神经网络去拟合。我们可以简单的通过随机生成n个答案,将它们的平均reward作为累积奖励的期望就可以训练的。这样即节省了一个神经网络,效果也非常好。

对于RRHF算法,原文中是离线生成了所有训练数据的答案,再去做训练。比较费时,训练起来也比较慢。我们也改成和ColossalAI类似的制作一小批数据就训练一次的方式,这样reward的增长会快一些。

实践下来,Trlx的多步PPO算法、ColossalAI的单步PPO、RRHF它们三者的reward上涨的量都差不多。RRHF上涨会快一些,但是的KL散度偏离要比PPO大很多。不过,RRHF基本不需要调参,PPO需要比较精细的调参。

3、数据

不知道中文什么时候能够有开源的比较完备的偏好数据集,能够涵盖较多的指令场景,同时在真实性、安全性方便也能有所顾及。其实只要有问题就行,答案最好是让sft去生成再找人打标,从instructgpt论文里看,这样ppo阶段的分数才是最精确的。

4、调参

在影响PPO算法性能的10个关键技巧(附PPO算法简洁Pytorch实现) - Beaman的文章 - 知乎 影响PPO算法性能的10个关键技巧(附PPO算法简洁Pytorch实现) - 知乎这篇文章里,提到了很多PPO的优化方法,里面我只试了一部分,目前来看,对优势值的正则化是有效的,能够让actor的loss变得稳定,如果是分布式的场景,记得要同步后再做正则,这块Trlx有相关的实现。Adam Optimizer Epsilon Parameter这个也是有效的,很神奇。对reward和value的正则化我没有试过。然后梯度裁剪、学习率衰减那些我都是有加的,多少都有点用。

目前来看,主要就是每轮到底要制作多少离线的数据,太少模型会学的不太稳定,太多模型会学的太慢,这个需要多做实验尝试。然后就是每批数据要训练多少轮,太少,模型学的慢,太多容易过拟合。不知道deepspeedchat为什么会说他们只制作一次,训练一轮是最好的。这个我这边感觉还是多轮迭代比较好。

希望各位大佬也能分享一下经验,一起学习学习

参考:1、https://www.libhunt.com/compare-DeepSpeed-vs-ColossalAI

           2、https://aicconf.net/pdf/AI%20infra%E6%8A%80%E6%9C%AF%E5%88%9B%E6%96%B0%E8%AE%BA%E5%9D%9B-%E3%80%90%E5%B0%A4%E6%B4%8B%E4%B8%A8%E6%BD%9E%E6%99%A8%E7%A7%91%E6%8A%80%E3%80%91-%E3%80%8AColossal-AI%EF%BC%9AAI%E5%A4%A7%E6%A8%A1%E5%9E%8B%E7%9A%84%E6%8C%91%E6%88%98%E4%B8%8E%E7%B3%BB%E7%BB%9F%E4%BC%98%E5%8C%96%E3%80%8B.pdf

        3、https://hpc-ai.com/benchmarks

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/601963.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

一文详解动态 Schema

在数据库中,Schema 常有,而动态 Schema 不常有。 例如,SQL 数据库有预定义的 Schema,但这些 Schema 通常都不能修改,用户只有在创建时才能定义 Schema。Schema 的作用是告诉数据库使用者所希望的表结构,确保…

网络安全红队常用的攻击方法及路径

一、信息收集 收集的内容包括目标系统的组织架构、IT资产、敏感信息泄露、供应商信息等各个方面,通过对收集的信息进行梳理,定位到安全薄弱点,从而实施下一步的攻击行为。 域名收集 1.备案查询 天眼查爱企查官方ICP备案查询 通过以上三个…

Java BIO、NIO、AIO、Netty知识详解(值得珍藏)

1. 什么是IO Java中I/O是以流为基础进行数据的输入输出的,所有数据被串行化(所谓串行化就是数据要按顺序进行输入输出)写入输出流。简单来说就是java通过io流方式和外部设备进行交互。 在Java类库中,IO部分的内容是很庞大的,因为它涉及的领…

YOLOv5改进 | Neck篇 | 利用Damo-YOLO的RepGFPN改进特征融合层

一、本文介绍 本文给大家带来的改进机制是Damo-YOLO的RepGFPN(重参数化泛化特征金字塔网络),利用其优化YOLOv5的Neck部分,可以在不影响计算量的同时大幅度涨点(亲测在小目标和大目标检测的数据集上效果均表现良好涨点幅度超级高!)。RepGFPN不同于以往提出的改进模块,其…

K8S--- volumesvolumeMount

一、Volume 简介 在容器当中的磁盘文件(on-disk file )是短暂的(ephemeral),这会对重要的应用程序或者数据产生一些问题。当容器崩溃或停止时,会出现一个问题,即容器状态不会被保存,因此在容器生命周期内被创建或者修改的文件都将丢失。在容器崩溃期间,kubelet会以干净状…

【数据库】聊聊常见的索引优化-上

数据库对于现有互联网应用来说,其实是非常重要的后端存储组件,而大多数系统故障都是由于存储所导致的,而数据库是重中之重,所以为了比较好掌握SQL的基本优化手段,打算用两篇文章从基本的联合索引优化、group by/order …

【Web开发】会话管理与无 Cookie 环境下的实现策略

🍎个人博客:个人主页 🏆个人专栏: Web开发 ⛳️ 功不唐捐,玉汝于成 目录 前言 正文 问题: 思路: 方法: 结语 我的其他博客 前言 在当今Web应用程序中,会话…

Go (一) 基础部分5 -- 单元测试,协程(goroutine),管道(channel)

一、单元测试 Go自带一个轻量级的"测试框架testing"和自带的"go test"命令来实现单元测试和性能测试。 1.确保每个函数时可运行,并且运行结果是正确的。 2.确保写出来的代码性能是好的。 3.单元测试能及时的发现程序设计或实现的逻辑错误&#…

程序员副业之无人直播助眠

介绍和概览 大家好,我是小黑,本文给大家介绍一个比较轻松简单的副业,无人直播助眠副业。 这个项目的核心就是通过直播一些助眠素材来赚钱。比如你可以放一些舒缓的雨声之类的,吸引观众进来。然后,咱们可以挂个小程序…

系统运维-Linux SSH密码登录免密登录密钥登陆

SSH:安全外壳协议,是一种在不安全网络上用于安全远程登录和其他安全网络服务的协议 SSH由三部分构成: 1.传输层协议 [SSH-TRANS]: 提供了服务器认证,保密性及完整性。此外它有时还提供压缩功能。 SSH-TRANS 通常运行在…

spring boot 集成邮件发送功能

一、首先到QQ邮箱申请开启POP3、SMTP协议 二、安装依赖 <dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-mail</artifactId></dependency><dependency><groupId>org.springframew…

C++ 单调栈 || 单调栈模版题

给定一个长度为 N 的整数数列&#xff0c;输出每个数左边第一个比它小的数&#xff0c;如果不存在则输出 −1 。 输入格式 第一行包含整数 N &#xff0c;表示数列长度。 第二行包含 N 个整数&#xff0c;表示整数数列。 输出格式 共一行&#xff0c;包含 N 个整数&#…

ROS 传感器—相机的介绍

在ROS中&#xff0c;相机是一种常见的传感器设备&#xff0c;用于获取视觉信息。ROS支持多种类型的相机&#xff0c;并提供了统一的接口和工具来处理相机数据&#xff0c;使得开发者可以方便地在不同硬件平台上实现视觉功能。 在ROS中&#xff0c;可以通过usb_cam、camera_dri…

探索生成式AI:自动化、问题解决与创新力

目录 自动化和效率&#xff1a;生成式AI的颠覆力量 解谜大师生成式AI&#xff1a;如何理解和解决问题 创新与创造力的启迪&#xff1a;生成式AI的无限潜能 自动化和效率&#xff1a;生成式AI的颠覆力量 1. 神奇的代码生成器&#xff1a;生成式AI可以帮助开发人员像魔术一样快…

TemporalKit的纯手动安装

最近在用本地SD安装temporalkit插件 本地安装插件最常见的问题就是&#xff0c;GitCommandError:… 原因就是&#xff0c;没有科学上网&#xff0c;而且即使搭了ladder&#xff0c;在SD的“从网址上安装”或是“插件安装”都不行&#xff0c;都不行&#xff01;&#xff01;&am…

【JAVA】OPENGL+TIFF格式图片,不同阈值旋转效果

有些科学研究领域会用到一些TIFF格式图片&#xff0c;由于是多张图片相互渐变&#xff0c;看起来比较有意思&#xff1a; import java.io.IOException; import java.text.SimpleDateFormat; import java.util.Date; import java.util.logging.*;/*** 可以自已定义日志打印格式…

窗体控件(表格和控制器)

DataGridView 控件 DataGridView控件是C#中的一个Windows Forms控件&#xff0c;用于在应用程序中显示和编辑表格形式的数据。 先拖出四个label控件和四个TextBox控件和一个ComboBox和一个Button按钮&#xff0c;下面是一个DataGridView控件 准备一个Student类 namespace _窗…

今日学习的 jdbc statement的增删改

首先要获取jdbc文件 Class.forName("com.mysql.jdbc.Driver"); 连接数据库&#xff08;数据库要提前打完在写增删改查&#xff09; Connection connection DriverManager.getConnection("jdbc:mysql://localhost:3306/db_day11","root","…

八大算法排序@堆排序(C语言版本)

目录 堆排序大堆排序概念算法思想建堆建堆核心算法建堆的代码 排序代码实现 小堆排序代码实现时间复杂度空间复杂度 特性总结 堆排序 堆排序借用的是堆的特性来实现排序功能的。大堆需要满足父节点大于子节点&#xff0c;因此堆顶是整个数组中的最大元素。小堆则相反&#xff0…