llama-factory训练RLHF-PPO模型

理论上RLHF(强化学习)效果比sft好,也更难训练。ppo有采用阶段,步骤比较多,训练速度很慢.
记录下工作中使用llama-factory调试rlhf-ppo算法流程及参数配置,希望对大家有所帮助.

llama-factory版本: 0.8.2

一 rlhf流程

ppo训练流程图如下, 会用到多个模型, 但初始化阶段, 只需提供sft和reward模型就行.
在这里插入图片描述

四个子模型用途:

  • Actor Model:演员模型,这就是我们想要训练的目标语言模型
  • Reference Model:参考模型,它的作用是在RLHF阶段给语言模型增加一些“约束”,防止语言模型训歪。我们希望训练出来的Actor模型既能达到符合人类喜好的目的,又尽量让它和SFT模型不要差异太大。即希望两个模型的输出分布尽量相似,通过与Actor Model之间的KL散度控制。
  • Critic Model:评论家模型,它的作用是预估总收益V->(t),在RLHF中,我们不仅要训练模型生成符合人类喜好的内容的能力(Actor),也要提升模型对人类喜好量化判断的能力(Critic)。这就是Critic模型存在的意义。
  • Reward Model:奖励模型,它的作用是计算即时收益R->(t) Actor/Critic Model. 在RLHF阶段是需要训练的;而Reward/Reference Model是参数冻结的。

整体算法流程如下:

  1. 训练sft模型

  2. 训练reward奖励模型

  3. 以sft模型初始化Reference和Actor模型,以奖励模型初始化Critic模型。其中,Actor与Critic模型权重可训练,Reference与Reward冻结权重,全程不更新

  4. rlhf-ppo执行过程分析(对应上图的step 3):
    在这里插入图片描述

  • 第一步,我们准备一个batch的prompts

  • 第二步,我们将这个batch的prompts喂给Actor模型,让它生成对应的responses

  • 第三步,我们把prompt+responses喂给我们的Critic/Reward/Reference模型,让它生成用于计算actor/critic loss的数据,按照强化学习的术语,我们称这些数据为经验(experiences)。

  • 第四步,我们根据这些经验,实际计算出actor/critic loss,然后更新Actor和Critic模型。

    涉及的损失函数:
    在这里插入图片描述

至此, 我们对RLHF-PPO工作原理已经有了清晰的认知. 若觉得上述文字不过瘾, 可以看我列出的几篇参考文献, 对ppo复杂的原理有深刻解读.

二 代码实践

2.1 数据准备: 需要准备sft指令微调和reward奖励模型的数据.

  • sft数据格式
[{"instruction": "<question>:查看备案有效期在今天之后的委托信息\nCREATE TABLE 委托备案信息 (序号 FLOAT,\n委托企业名称 VARCHAR(255),\n公司地址 VARCHAR(255),\n被委托企业名称 VARCHAR(255),\n委托产品名称 VARCHAR(255),\n备案日期 VARCHAR(255),\n备案有效期 VARCHAR(255));","input": "","output": "\nSELECT * FROM 委托备案信息 WHERE TO_DATE(备案有效期, 'YYYY-MM-DD') > NOW();\n"},{"instruction": "<question>:哪些镇名拥有重点旅游村?\nCREATE TABLE 镇名休闲农业园区休闲农庄重点旅游村 (镇名 VARCHAR(255),\n休闲农业园区 VARCHAR(255),\n休闲农庄 VARCHAR(255),\n重点旅游村 VARCHAR(255));","input": "","output": "\nSELECT DISTINCT 镇名 FROM 镇名休闲农业园区休闲农庄重点旅游村 WHERE 重点旅游村 IS NOT NULL;\n"},...
]
  • reward数据格式固定,不能随意更改, 经过断点调试发现, 所有模型的reward数据都遵循以下格式, 其中chosen期望偏好, rejected是负向偏好.
[{"conversations": [{"from": "human","value": "<question>:查看备案有效期在今天之后的委托信息\nCREATE TABLE 委托备案信息 (序号 FLOAT,\n委托企业名称 VARCHAR(255),\n公司地址 VARCHAR(255),\n被委托企业名称 VARCHAR(255),\n委托产品名称 VARCHAR(255),\n备案日期 VARCHAR(255),\n备案有效期 VARCHAR(255));"}],"chosen": {"from": "gpt","value": "\nSELECT * FROM 委托备案信息 WHERE TO_DATE(备案有效期, 'YYYY-MM-DD') > NOW();\n"},"rejected": {"from": "gpt","value": "SELECT * FROM 委托备案信息 WHERE 备案有效期 > NOW()"}},{"conversations": [{"from": "human","value": "<question>:哪些镇名拥有重点旅游村?\nCREATE TABLE 镇名休闲农业园区休闲农庄重点旅游村 (镇名 VARCHAR(255),\n休闲农业园区 VARCHAR(255),\n休闲农庄 VARCHAR(255),\n重点旅游村 VARCHAR(255));"}],"chosen": {"from": "gpt","value": "\nSELECT DISTINCT 镇名 FROM 镇名休闲农业园区休闲农庄重点旅游村 WHERE 重点旅游村 IS NOT NULL;\n"},"rejected": {"from": "gpt","value": "SELECT DISTINCT 镇名 FROM PG库 WHERE 重点旅游村 IS NOT NULL;"}},...
]

2.2 训练代码

新版llama-factory不再使用shell脚本传参, 而是通过yaml文件完成, 之后通过以下代码, 根据传入yaml文件不同执行对应的训练任务.

import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))from src.llamafactory.train.tuner import run_exp
import yamldef main(yaml_path_):with open(yaml_path_, 'r', encoding='utf-8') as f:param = yaml.safe_load(f)run_exp(param)if __name__ == "__main__":#1.sft指令微调# yaml_path = '../examples/yblir_configs/qwen2_lora_sft.yaml'# 2.奖励模型训练# yaml_path = '../examples/yblir_configs/qwen2_lora_reward.yaml'# 3.rlhf-ppo训练yaml_path = '../examples/yblir_configs/qwen2_lora_ppo.yaml'main(yaml_path)

sft 超参: qwen2_lora_sft.yaml

# model
model_name_or_path: E:\PyCharm\PreTrainModel\qwen2_7b
#model_name_or_path: /media/xk/D6B8A862B8A8433B/data/qwen2_05b
# method
stage: sft
do_train: true
finetuning_type: lora
lora_target: all# dataset
dataset: train_clean
dataset_dir: ../data
template: qwen
cutoff_len: 1024
#max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 2# output
output_dir: E:\PyCharm\PreTrainModel\qwen2_7b_sft
logging_steps: 10
save_steps: 100
plot_loss: true
overwrite_output_dir: true# train
per_device_train_batch_size: 4
gradient_accumulation_steps: 2
learning_rate: 1.0e-5
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_steps: 0.1
fp16: true# eval
val_size: 0.1
per_device_eval_batch_size: 4
evaluation_strategy: steps
eval_steps: 100

sft训练效果:
在这里插入图片描述

rm模型训练参数: qwen2_lora_reward.yaml

# 训练奖励模型
### model
model_name_or_path: /mnt/e/PyCharm/PreTrainModel/qwen2_7b### method
stage: rm
do_train: true
finetuning_type: lora
lora_target: all### dataset
dataset: rw_data
dataset_dir: ../data
template: qwen
cutoff_len: 1024
max_samples: 3000
overwrite_cache: true
preprocessing_num_workers: 1### output
output_dir: /mnt/e/PyCharm/PreTrainModel/qwen2_7b_rm
logging_steps: 10
save_steps: 100
plot_loss: true
overwrite_output_dir: true### train
per_device_train_batch_size: 2
gradient_accumulation_steps: 2
learning_rate: 1.0e-5
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
fp16: true
ddp_timeout: 180000000### eval
val_size: 0.1
per_device_eval_batch_size: 2
eval_strategy: steps
eval_steps: 500

rm训练效果:

***** eval metrics *****epoch                   =        3.0eval_accuracy           =        1.0eval_loss               =        0.0eval_runtime            = 0:00:16.73eval_samples_per_second =     17.923eval_steps_per_second   =      8.961
[INFO|modelcard.py:450] 2024-06-26 23:02:36,246 >> Dropping the following result as it does not have all the necessary fields:
{'task': {'name': 'Causal Language Modeling', 'type': 'text-generation'}, 'metrics': [{'name': 'Accuracy', 'type': 'accuracy', 'value': 1.0}]}

在这里插入图片描述

sft训练完成后,要先merge才能进行下一步ppo训练.
merge代码及配置文件:

# -*- coding: utf-8 -*-
# @Time    : 2024/5/17 23:21
# @Author  : yblir
# @File    : lyb_merge_model.py
# explain  :
# =======================================================
import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))import yamlfrom src.llamafactory.train.tuner import export_modelif __name__ == "__main__":with open('../examples/yblir_configs/qwen2_lora_sft_merge.yaml', 'r', encoding='utf-8') as f:param = yaml.safe_load(f)export_model(param)

qwen2_lora_sft_merge.yaml

# Note: DO NOT use quantized model or quantization_bit when merging lora adapters# model
model_name_or_path: E:\PyCharm\PreTrainModel\qwen2_7b
adapter_name_or_path: E:\PyCharm\PreTrainModel\qwen2_7b_sft
#model_name_or_path: /media/xk/D6B8A862B8A8433B/data/qwen2_05b
#adapter_name_or_path: /media/xk/D6B8A862B8A8433B/data/qwen2_15b_rw
template: qwen
finetuning_type: lora# export
export_dir: /mnt/e/PyCharm/PreTrainModel/qwen2_7b_sft_merge
export_size: 2
export_device: cpu
# 为true,保存为safetensors格式
export_legacy_format: true

ppo训练: 使用merge后的sft模型. reward_model参数是rm训练的lora参数, 这样做的好处是节约显存, 不然24G显存根本没法训练7B大小的模型. 而弊端就是, 四个子模型的基座是同一个模型. 只有全量的full训练才能选择不同的模型. 目前看, 都用同一个模型也没发现什么问题.

ppo涉及数据采样, 训练很慢, 4090显卡, 对于以下参数, 显存占用约18G, 耗时约4.5小时才训练完.

### model
model_name_or_path: /mnt/e/PyCharm/PreTrainModel/qwen2_7b_sft_merge
reward_model: /mnt/e/PyCharm/PreTrainModel/qwen2_7b_rm### method
stage: ppo
do_train: true
finetuning_type: lora
lora_target: all### dataset
# dataset: identity,alpaca_en_demo
dataset: train_clean
dataset_dir: ../data
template: qwen
cutoff_len: 1024
max_samples: 2000
overwrite_cache: true
preprocessing_num_workers: 1### output
output_dir: /mnt/e/PyCharm/PreTrainModel/qwen2_7b_sql_ppo_1_batch
logging_steps: 10
save_steps: 100
plot_loss: true
overwrite_output_dir: true### train
per_device_train_batch_size: 1
gradient_accumulation_steps: 8
learning_rate: 1.0e-5
num_train_epochs: 2.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
fp16: true
ddp_timeout: 180000000### generate
max_new_tokens: 512
top_k: 0
top_p: 0.9

ppo训练效果
在这里插入图片描述

ppo训练后进行推理, 使用merge后的sft模型进行的ppo的推理的基座模型, ppo训练的finetuning_type是lora, 因此最终保存的也是lora参数,

lyb_qwen_sft_predict.yaml

# model
model_name_or_path: E:\PyCharm\PreTrainModel\qwen2_7b_sft_merge
adapter_name_or_path: E:\PyCharm\PreTrainModel\qwen2_7b_sql_ppo_1_batchstage: sft
finetuning_type: lora
#lora_target: all
#quantization_bit: 8#infer_backend: vllm# dataset
template: qwen
#cutoff_len: 1024

一个简单的推理代码, 注意模型的输入数据, 与ppo训练时入参格式一样, 本文ppo训练使用的数据与sft是同一份.

# -*- coding: utf-8 -*-
# @Time    : 2024/6/16 20:50
# @Author  : yblir
# @File    : lyb_lora_inference.py
# explain  : 
# =======================================================
import yaml
import json
from loguru import logger
import time
import sys
from src.llamafactory.chat import ChatModelif __name__ == '__main__':with open('../examples/yblir_configs/lyb_qwen_sft_predict.yaml', 'r', encoding='utf-8') as f:param = yaml.safe_load(f)chat_model = ChatModel(param)with open('../data/tuning_sample.json', 'r', encoding='utf-8') as f:data = json.load(f)# 预热messages = [{"role": "user", "content": data[0]['instruction']}]_ = chat_model.chat(messages)predict_1000 = []total_time = 0for i, item in enumerate(data):messages = [{"role": "user", "content": item['instruction']}]t1 = time.time()res = chat_model.chat(messages)total_time += time.time() - t1predict_1000.append(res[0].response_text)#print('-------------------------------------------------')print(i,'->',res[0].response_text)# sys.exit()if (i + 1) % 10 == 0:# logger.info(f'当前完成: {i + 1}')sys.exit()if i + 1 == 300:break# json_data = json.dumps(predict_1000, indent=4, ensure_ascii=False)# with open('saves2/qwen_7b_chat_lora_merge_vllm.json', 'w', encoding='utf-8') as f:#     f.write(json_data)logger.success(f'写入完成, 总耗时:{total_time},平均耗时: {round((total_time / 300), 5)} s')

sft与PPO部分推理结果比较, 具体指标要把sql放到数据库去跑一遍才知道, 结果在公司内网, 不再此列出了.
在这里插入图片描述

三 总结

除了ppo, dpo(Direct Preference Optimization:直接偏好优化)也是一种常见的调优手段, 不过多篇paper研究证明性能不如PPO, 在计算资源不足的情况下DPO也是个不过的选择,因为不需要训练奖励模型, 而且训练速度快,效果也比较稳定, 不像PPO那样很容易训崩.
其他LLM偏好对齐训练技术还有ORPO,IPO,CPO以及效果看起来很棒的KTO.
还有最新发表的RLOO,看起来比PPO更好更易训练.
在这里插入图片描述

这个领域发展太快, 脑子快不够用了.
在这里插入图片描述

四 参考文献

https://blog.csdn.net/sinat_37574187/article/details/138200789
https://blog.csdn.net/2301_78285120/article/details/134888984
https://blog.csdn.net/qq_27590277/article/details/132614226
https://blog.csdn.net/qq_35812205/article/details/133563158

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/39538.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Kubernetes】加入节点Node及问题

命令 分别再node节点机器上&#xff0c;执行如下命令&#xff1a; kubeadm join [master机器ip:端口] --token [master机器初始化生成的token] --discovery-token-ca-cent-hash [master机器初始化生成的hash]问题 由于清屏没有记住token和hash的时候&#xff1a; 1&#xff…

Log4j日志框架讲解(全面,详细)

Log4j概述 Log4j是Apache下的一款开源的日志框架&#xff0c;通过在项目中使用 Log4J&#xff0c;我们可以控制日志信息输出到控制台、文件、甚至是数据库中。我们可以控制每一条日志的输出格式&#xff0c;通过定义日志的输出级别&#xff0c;可以 更灵活的控制日志的输出过程…

如何指定Microsoft Print To PDF的输出路径

在上一篇文章中&#xff0c;介绍了三种将文件转换为PDF的方式。默认情况下&#xff0c;在Microsoft Print To PDF的首选项里&#xff0c;是看不到输出路径的设置的。 需要一点小小的手段。 运行输入 control 打开控制面板&#xff0c;选择硬件和声音下的查看设备和打印机 找到…

【ubuntu18.04】 局域网唤醒 wakeonlan

ai服务器经常因为断电,无法重启,当然可以设置bios 来电启动。 这里使用局域网唤醒配置。 自动开关机设置 工具:ethtool 端口 : enp4s0 Wake-on: d 表示禁用Wake-on: g 激活 ,例如:ethtool -s eth0 wol g 配置/etc/rc.local ,这个文件不存在,自己创建工具下载 tengxun W…

【前端vue3】TypeScrip-类型推论和类型别名

类型推论 TypeScript里&#xff0c;在有些没有明确指出类型的地方&#xff0c;类型推论会帮助提供类型。 例如&#xff1a; 变量xiaoc被推断类型为string 如重新给xiaoc赋值数字会报错 let xiaoc "xiaoc"xiaoc 1111111111111如没有给变量指定类型和赋值&#xf…

专题七:Spring源码之BeanDefinition

上一篇我们通过refresh方法中的第二个核心方法obtainBeanFactory&#xff0c;通过createBeanFacotry创建容Spring的初级容器&#xff0c;并定义了容器的两个核心参数是否允许循环引用和覆盖。现在容器有了&#xff0c;我们来看看容器里的第一个重要成员BeanDefinition。 进入lo…

浙大版PTA《Python 程序设计》题目集 参考答案

浙大版PTA《Python 程序设计》题目集 参考答案 本答案配套详解教程专栏&#xff0c;欢迎订阅&#xff1a; PTA浙大版《Python 程序设计》题目集 详解教程_少侠PSY的博客-CSDN博客 01第1章-1 从键盘输入两个数&#xff0c;求它们的和并输出 aint(input()) # 输入a的值 bint(…

从需求是如何最终抽象成最基本的传参入参

第一层&#xff1a;出参和入参 用通俗的话讲&#xff0c;就是给客户提供服务的一种方式&#xff0c;需要包含入参和出参 。入口参数就是程序执行时会调用的参数&#xff0c;出口参数就是程序执行完会返回的参数。入参的值是被调函数需要&#xff0c; 出参的值是主调函数需要的…

【文件上传】

文件上传漏洞 FileUpload 0x01 定义 服务端未对客户端上传文件进行严格的 验证和过滤造成可上传任意文件情况&#xff1b;0x02 攻击满足条件&#xff1a; 1. 上传文件能够被Web容器解释执行   2. 找到文件位置   3.上传文件未被改变内容。&#xff08;躲避安全检查&#…

【Linux系统】CUDA的安装与graspnet环境配置遇到的问题

今天在安装环境时遇到报错&#xff1a; The detected CUDA version (10.1) mismatches the version that was used to compile PyTorch (11.8). Please make sure to use the same CUDA versions. 报错原因&#xff1a;安装的cuda版本不对应&#xff0c;我需要安装cuda的版本…

Spark面试题总结

一、RDD的五大特性是什么 1、RDD是由一些分区构成的&#xff0c;读取文件时有多少个block块&#xff0c;RDD中就会有多少个分区 2、算子实际上是作用在RDD中的分区上的&#xff0c;一个分区是由一个task处理&#xff0c;有多少个分区&#xff0c;总共就有多少个task 3、RDD之间…

windows远程连接无法复制文件

windows远程桌面无法复制文件 解决方案 打开任务管理器管理器,在详细信息界面,找到rdpclip.exe进程&#xff0c;选中并点击结束任务&#xff0c;杀死该进程。 快捷键 win r 打开运行界面&#xff0c;输入 rdpclip.exe &#xff0c;点击确定运行。即可解决无法复制文件问题。…

WebDriver 类的常用属性和方法

目录 &#x1f38d;简介 &#x1f38a;WebDriver 核心概念 &#x1f389;WebDriver 常用属性 &#x1f381;WebDriver 常用方法 &#x1f437;示例代码 &#x1f3aa;注意事项 &#x1f390;结语 &#x1f9e3;参考资料 &#x1f38d;简介 Selenium WebDriver 是一个用…

产品设计的8大步骤

产品设计&#xff0c;通俗来说就是将创新想法或概念转化为落地实体的过程。一般来说&#xff0c;一个成功的产品应当具有创新性、美观性、实用性、可持续性以及经济效益&#xff0c;从而满足用户的使用需求以及市场的发展需求。产品设计也并不是一件简单的事情&#xff0c;产品…

Docker与微服务实战2022 尚

Docker与微服务实战2022 尚硅谷讲师:周阳 1. 基础篇(零基小白) 1 1.1. Docker简介 2 1.2. Docker安装 15 1.3. Docker常用命令 29 1.4. Docker镜像 43 1.5. 本地镜像发布到阿里云 50 1.6. 本地镜像发布到私有库 57 1.7. Docker容器数据卷 64 1.8. Docker常规安装简介 …

firewalld开放端口常用命令

在Linux系统中&#xff0c;常使用firewalld服务来管理防火墙&#xff0c;可以通过命令行来开放特定的端口。 查firewalld运行状态&#xff1a; sudo systemctl status firewalld 确保firewalld正在运行&#xff0c;可以使用以下命令来启动并使其在系统启动时自动运行&#xff1…

经典的卷积神经网络模型 - AlexNet

经典的卷积神经网络模型 - AlexNet flyfish AlexNet 是由 Alex Krizhevsky、Ilya Sutskever 和 Geoffrey Hinton 在 2012 年提出的一个深度卷积神经网络模型&#xff0c;在 ILSVRC-2012&#xff08;ImageNet Large Scale Visual Recognition Challenge 2012&#xff09;竞赛中…

劳务工程元宇宙的探索与实践

随着元宇宙概念的不断深入&#xff0c;各行各业都在探索与这一新兴技术结合的可能性。劳务工程行业也未落后&#xff0c;开始思考和实验如何将元宇宙的概念与劳务工程相结合&#xff0c;以期提高效率、降低成本&#xff0c;同时创造更多价值。本文将探讨劳务工程元宇宙的现状、…

242. 有效的字母异位词【哈希表】【C++】

题目描述 有效的字母异位词 给定两个字符串 s 和 t &#xff0c;编写一个函数来判断 t 是否是 s 的字母异位词。 注意&#xff1a;若 s 和 t 中每个字符出现的次数都相同&#xff0c;则称 s 和 t 互为字母异位词。 示例 1: 输入: s “anagram”, t “nagaram” 输出: true 示…

公司法下的公司注册资金实缴的建议

公司法下的公司注册资金实缴的建议 新公司法已经实施了&#xff0c;现在设立的公司都将要按照新公司法的规定来执行。 那么新公司法对企业最大的影响&#xff0c;就是我们目前热议的公司实缴问题。 公司实缴这个问题我以前讲过好几次。针对近期看到的消息来说下我个人的观点。…