使用 ORPO 微调 Llama 3

原文地址:https://towardsdatascience.com/fine-tune-llama-3-with-orpo-56cfab2f9ada

更便宜、更快的统一微调技术

2024 年 4 月 19 日

ORPO 是一种新的令人兴奋的微调技术,它将传统的监督微调和偏好校准阶段合并为一个过程。这减少了训练所需的计算资源和时间。此外,经验结果表明,在各种模型大小和基准上,ORPO 都优于其他配准方法。

在本文中,我们将使用 ORPO 和 TRL 库对新的 Llama 3 8B 模型进行微调。

ORPO

指令调整和偏好对齐是使大型语言模型(LLM)适应特定任务的基本技术。传统上,这涉及一个多阶段过程:1/ 对指令进行监督微调 (SFT),使模型适应目标领域;2/偏好调整方法,如人工反馈强化学习 (RLHF) 或直接偏好优化 (DPO),以提高生成首选响应而非拒绝响应的可能性。

12

不过,研究人员也发现了这种方法的局限性。虽然 SFT 能有效地使模型适应所需的领域,但却无意中增加了在生成首选答案的同时生成不想要的答案的概率。这就是为什么有必要进行偏好调整阶段,以拉大首选输出和拒绝输出的可能性之间的差距。

13

由 Hong 和 Lee(2024 年)提出的 ORPO 将指令调整和偏好调整结合到一个单一的、整体的训练过程中,为这一问题提供了一个优雅的解决方案。ORPO 修改了标准语言建模目标,将负对数似然损失与几率比(OR)项相结合。这种赔率损失会对被拒绝的反应进行弱惩罚,同时对偏好的反应进行强奖励,从而使模型能够同时学习目标任务并与人类偏好保持一致。

14

使用 ORPO 微调 Llama 3

Llama 3 是 Meta 开发的最新 LLM 系列。这些模型是在一个包含 15 万亿个词库(相比之下,Llama 2 包含 2T 个词库)的广泛数据集上训练的。目前已发布两种规模的模型:700 亿参数模型和较小的 80 亿参数模型。70B 模型已经表现出令人印象深刻的性能,在 MMLU 基准测试中获得 82 分,在 HumanEval 基准测试中获得 81.7 分。

Llama 3 模型还将上下文长度增加到 8,192 个标记(Llama 2 为 4,096 个标记),并有可能通过 RoPE 扩展到 32k。此外,这些模型还使用了具有 128K 标记词汇的新标记化器,从而将文本编码所需的标记数量减少了 15%。这个词汇量也是参数从 7B 增加到 8B 的原因。

15

ORPO 需要一个偏好数据集,包括提示、选择的答案和拒绝的答案。

按照惯例,我们先安装所需的库:

pip install -U transformers datasets accelerate peft trl bitsandbytes wandb

安装完成后,我们就可以导入必要的库并登录 W&B(可选):

import gc
import os
import torch
import wandb
from datasets import load_dataset
from google.colab import userdata
from peft import LoraConfig, PeftModel, prepare_model_for_kbit_training
from transformers import (AutoModelForCausalLM,AutoTokenizer,BitsAndBytesConfig,TrainingArguments,pipeline,
)
from trl import ORPOConfig, ORPOTrainer, setup_chat_format
wb_token = userdata.get('wandb')
wandb.login(key=wb_token)

如果你使用的是最新的 GPU,你还可以使用 Flash Attention 库,以更高效的方式替换默认的急切关注实现。

if torch.cuda.get_device_capability()[0] >= 8:!pip install -qqq flash-attnattn_implementation = "flash_attention_2"torch_dtype = torch.bfloat16
else:attn_implementation = "eager"torch_dtype = torch.float16

在下文中,我们将利用 bitsandbytes 以 4 位精度加载 Llama 3 8B 模型。然后,我们使用 QLoRA 的 PEFT 设置 LoRA 配置。我还使用方便的 setup_chat_format() 函数来修改模型和标记符,以支持 ChatML。它会自动应用聊天模板,添加特殊标记,并调整模型嵌入层的大小,以匹配新的词汇量大小。

# Model
base_model = "meta-llama/Meta-Llama-3-8B"
new_model = "OrpoLlama-3-8B"
# QLoRA config
bnb_config = BitsAndBytesConfig(load_in_4bit=True,bnb_4bit_quant_type="nf4",bnb_4bit_compute_dtype=torch_dtype,bnb_4bit_use_double_quant=True,
)
# LoRA config
peft_config = LoraConfig(r=16,lora_alpha=32,lora_dropout=0.05,bias="none",task_type="CAUSAL_LM",target_modules=['up_proj', 'down_proj', 'gate_proj', 'k_proj', 'q_proj', 'v_proj', 'o_proj']
)
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(base_model)
# Load model
model = AutoModelForCausalLM.from_pretrained(base_model,quantization_config=bnb_config,device_map="auto",attn_implementation=attn_implementation
)
model, tokenizer = setup_chat_format(model, tokenizer)
model = prepare_model_for_kbit_training(model)

既然模型已经准备好进行训练,我们就可以处理数据集了。我们加载 mlabonne/orpo-dpo-mix-40k 并使用 apply_chat_template() 函数将 "选择 "和 "拒绝 "列转换为 ChatML 格式。请注意,我只使用了 1,000 个样本,而不是整个数据集,因为这样运行时间太长。

dataset_name = "mlabonne/orpo-dpo-mix-40k"
dataset = load_dataset(dataset_name, split="all")
dataset = dataset.shuffle(seed=42).select(range(10))
def format_chat_template(row):row["chosen"] = tokenizer.apply_chat_template(row["chosen"], tokenize=False)row["rejected"] = tokenizer.apply_chat_template(row["rejected"], tokenize=False)return row
dataset = dataset.map(format_chat_template,num_proc= os.cpu_count(),
)
dataset = dataset.train_test_split(test_size=0.01)

首先,我们需要设置几个超参数:

  • 学习率: 与传统的 SFT 甚至 DPO 相比,ORPO 使用非常低的学习率。这个 8e-6 的值来自原始论文,大致相当于 1e-5 的 SFT 学习率和 5e-6 的 DPO 学习率。我建议将其提高到 1e-6 左右,以进行真正的微调。
  • beta: 它是论文中的 $\lambda$ 参数,默认值为 0.1。原始论文的附录展示了如何通过烧蚀研究来选择它。
  • 其他参数,如 max_length 和批量大小,都设置为使用尽可能多的可用 VRAM(在此配置中约为 20 GB)。理想情况下,我们会对模型进行 3-5 个历元的训练,但这里我们会坚持使用 1 个历元。

最后,我们可以使用 ORPOTrainer 对模型进行训练。

orpo_args = ORPOConfig(learning_rate=8e-6,beta=0.1,lr_scheduler_type="linear",max_length=1024,max_prompt_length=512,per_device_train_batch_size=2,per_device_eval_batch_size=2,gradient_accumulation_steps=4,optim="paged_adamw_8bit",num_train_epochs=1,evaluation_strategy="steps",eval_steps=0.2,logging_steps=1,warmup_steps=10,report_to="wandb",output_dir="./results/",
)
trainer = ORPOTrainer(model=model,args=orpo_args,train_dataset=dataset["train"],eval_dataset=dataset["test"],peft_config=peft_config,tokenizer=tokenizer,
)
trainer.train()
trainer.save_model(new_model)

在 L4 GPU 上对这 1,000 个样本进行模型训练耗时约 2 小时。我们来看看 W&B 图:

16

虽然损失减少了,但选择答案和拒绝答案之间的差异并不明显:平均余量和准确率分别仅略高于零和 0.5。

在原始论文中,作者在 Anthropic/hh-rlhf 数据集(161k 个样本)上训练了 10 个历时的模型,这比我们快速运行的时间要长得多。他们还使用 Llama 3 进行了实验,并慷慨地分享了他们的日志。

在本文的最后,让我们将 QLoRA 适配器与基础模型合并,并将其推送到Hugging Face Hub。

# Flush memory
del trainer, model
gc.collect()
torch.cuda.empty_cache()
# Reload tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(base_model)
model = AutoModelForCausalLM.from_pretrained(base_model,low_cpu_mem_usage=True,return_dict=True,torch_dtype=torch.float16,device_map="auto",
)
model, tokenizer = setup_chat_format(model, tokenizer)
# Merge adapter with base model
model = PeftModel.from_pretrained(model, new_model)
model = model.merge_and_unload()
model.push_to_hub(new_model, use_temp_dir=False)
tokenizer.push_to_hub(new_model, use_temp_dir=False)

现在,我们完成了对 Llama 3 的快速微调:mlabonne/OrpoLlama-3-8B。你可以使用这个 "Hugging Face Space"来体验一下。正如 W&B 曲线所示,虽然模型训练不足,但我还是使用 LLM AutoEval 在 Nous 的基准套件上进行了一些评估。

17

我们的 ORPO 微调实际上非常不错,在每个基准测试中都提高了基本型号的性能。这是令人鼓舞的,很可能意味着对整个 40k 样本进行微调会产生很好的结果。

对于开源社区来说,这是一个激动人心的时刻,越来越多的高质量开放重量模型被发布。闭源模型和开放重量模型之间的差距正在慢慢缩小,而微调是为你的使用案例获得最佳性能的重要工具。

18

结论

在本文中,我们介绍了 ORPO 算法,并解释了它如何将 SFT 和偏好校准阶段统一为一个过程。然后,我们使用 TRL 在自定义偏好数据集上对 Llama 3 8B 模型进行了微调。最终的模型显示了令人鼓舞的结果,并凸显了 ORPO 作为一种新的微调范例的潜力。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/6308.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【深度学习】第二门课 改善深层神经网络 Week 2 3 优化算法、超参数调试和BN及其框架

🚀Write In Front🚀 📝个人主页:令夏二十三 🎁欢迎各位→点赞👍 收藏⭐️ 留言📝 📣系列专栏:深度学习 💬总结:希望你看完之后,能对…

python实现验证码-图片类型

1 utils.py import randomdef get_random_code():code for i in range(5):# 随机生成大写字母upper_char chr(random.randint(65, 90))lower_char chr(random.randint(97, 122))num_char str(random.randint(0, 9))res random.choice([upper_char, lower_char, num_char]…

软件应用开发安全设计指南

1.1 应用系统架构安全设计要求 设计时要充分考虑到系统架构的稳固性、可维护性和可扩展性,以确保系统在面对各种安全威胁时能够稳定运行。 在设计系统架构时,要充分考虑各种安全威胁,如DDoS攻击、SQL注入、跨站脚本攻击(XSS&…

如何使用 Node.js 开发一个文件上传功能?

在 Node.js 中实现文件上传功能可以通过多种方式完成,但其中最常用的方法之一是使用 Express 框架和 Multer 中间件。Express 是一个流行的 Node.js Web 框架,而 Multer 是一个用于处理文件上传的中间件。 步骤 1: 准备工作 首先,确保你已经…

《Fundamentals of Power Electronics》——升压隔离型变换器、SEPIC隔离型变换器

以下是升压型隔离变换器的相关知识点: 升压型隔离变换器可以通过互换降压型隔离变换器的电源与负载的位置得到。升压型隔离变换器有许多种结构,此处简短的讨论两种情况。这些转换器主要使用在高压电源和低谐波整流器中。 图6.36所示是一种全桥型电路结…

企业定制AI智能名片商城小程序:重塑营销场景,引领数字化营销新纪元

在数字化时代的浪潮中,多企业AI智能名片商城小程序以其独特的魅力和创新的功能,为消费者带来了前所未有的购物体验。它不仅是一个汇聚各类商品的购物平台,更是一个充满活力和创造力的社群生态。通过强化社群互动、鼓励用户生成内容以及引入积…

uniapp 文字转语音(文字播报、语音合成)、震动提示插件 Ba-TTS

简介(下载地址) Ba-TTS 是一款uniapp语音合成(tts)插件,支持文本转语音(无服务费),支持震动提示。 支持语音合成,文本转语音支持震动(可自定义任意震动效果…

一对一WebRTC视频通话系列(二)——websocket和join信令实现

本系列博客主要记录WebRtc实现过程中的一些重点,代码全部进行了注释,便于理解WebRTC整体实现。 一对一WebRTC视频通话系列往期博客: 一对一WebRTC视频通话系列(一)—— 创建页面并显示摄像头画面 websocket和join信令…

Go实战训练之Web Server 与路由树

Server & 路由树 Server Web 核心 对于一个 Web 框架,至少要提供三个抽象: Server:代表服务器的抽象Context:表示上下文的抽象路由树 Server 从特性上来说,至少要提供三部分功能: 生命周期控制&…

堆栈打印跟踪Activity的启动过程(基于Android10.0.0-r41),framework修改,去除第三方app的倒计时页面

文章目录 堆栈打印跟踪Activity的启动过程(基于Android10.0.0-r41),framework修改,去除第三方app的倒计时页面1.打印异常堆栈2.去除第三方app的倒计时页面3.模拟点击事件跳过首页进入主页 堆栈打印跟踪Activity的启动过程(基于Android10.0.0-r41)&#x…

领域驱动设计(DDD)笔记(三)后端工程架构

文章链接 领域驱动设计(DDD)笔记(一)基本概念-CSDN博客领域驱动设计(DDD)笔记(二)代码组织原则-CSDN博客领域驱动设计(DDD)笔记(三)后端工程架构-CSDN博客前导 领域驱动设计(Domain Driven Design,简称DDD)是业内主导的业务工程理论。它在各中权威人士被广泛讨论…

华为云耀云服务器开放端口

博客主页:花果山~程序猿-CSDN博客 关注我一起学习,一起进步,一起探索编程的无限可能吧!让我们一起努力,一起成长! 目录 一.华为云控制台开放端口 寻找到安全组信息 2. 添加开放的端口信息 3. 检查是否成…

信息泄露.

一,遍历目录 目录遍历:没有过滤目录相关的跳转符号(例如:../),我们可以利用这个目录找到服务器中的每一个文件,也就是遍历。 tipe:依次点击文件就可以找到flag 二,phpi…

JavaScript基础(四)

逻辑运算符 && 与 : 多个条件同时满足 ΙΙ 或 : 多个条件满足一个 &#xff01; 非 : 否定某个条件 例: <script> //&多个条件同时满足&#xff0c;才返回true //任意一个为false&#xff0c;就返回false var a 10; var b 20; …

vue快速入门(五十)重定向

注释很详细&#xff0c;直接上代码 上一篇 本篇建立在之前篇目前提下针对重定向进行演示 新增内容 路由重定向写法 源码 src/router/index.js //导入所需模块 import Vue from "vue"; import VueRouter from "vue-router"; import myMusic from "/v…

【51蛋骗鸡595点阵88数码管流水灯综合应用】2021-12-30

缘由51单片机变量进阶与点阵LED-嵌入式-CSDN问答 大佬们 求解单片机点亮点阵程序 被困3天了一直想不明白 - 24小时必答区 #include<reg52.h>//头文件sbit shcpP1^2;//数据输入时钟线 595的11脚 sbit stcpP1^1;//输出存储器锁存时钟线 595的12脚 sbit dsP1^0;//数据线 5…

AI视频教程下载:零代码创建AI智能体、AI Agents和ChatGPT的Gpts

这门课程专注于提示工程的掌握&#xff0c;教你以精确的方式引导GPT&#xff0c;利用它们的生成能力产生卓越的AI驱动结果。一步一步地&#xff0c;你将学会创建多样化的GPT军团——每个都设计来满足特定的专业需求。 从提供个性化职业变更指导的职业教练AI&#xff0c;到以惊…

无人机+飞行汽车:低空经济新引擎,有望爆发式增长

无人机和飞行汽车作为低空经济的新引擎&#xff0c;正在引领一场全新的交通革命。随着技术的不断进步和政策的支持&#xff0c;低空经济有望成为未来经济发展的重要领域&#xff0c;实现爆发式增长。 首先&#xff0c;无人机和飞行汽车具有独特的优势和应用场景。无人机可以在…

Adobe PS 2023、Adobe Photoshop 2023下载教程、安装教程

Adobe Photoshop &#xff08;<-下载连接&#xff09;简介&#xff1a; Adobe Photoshop是一款广泛使用的图像处理软件&#xff0c;由Adobe公司开发。它提供了许多强大的工具和功能&#xff0c;可以用于图像编辑、合成、修饰、设计等各个领域。用户可以使用Photoshop来调整…

Mybatis四种实例化对象方式

代码准备 创建mybatis-config.xml <?xml version"1.0" encoding"UTF-8" ?> <!DOCTYPE configurationPUBLIC "-//mybatis.org//DTD Config 3.0//EN""http://mybatis.org/dtd/mybatis-3-config.dtd"> <configuration…