【大模型】微调实战—使用 ORPO 微调 Llama 3

ORPO 是一种新颖微调(fine-tuning)技术,它将传统的监督微调(supervised fine-tuning)和偏好对齐(preference alignment)阶段合并为一个过程。这减少了训练所需的计算资源和时间。此外,实证结果表明,ORPO 在各种模型规模和基准测试(benchmarks)上优于其他对齐方法。
在本文中,我们将使用 ORPO 和 TRL 库对新的 Llama 3 8B 模型进行微调。

ORPO

指令微调(instruction tuning)和偏好对齐(preference alignment)是使LLM适应特定任务的基本技术。传统上,这涉及一个多阶段的过程:1/ 在指令上进行监督微调(Supervised Fine-Tuning, SFT),以使模型适应目标领域,然后 2/ 使用偏好对齐方法,如基于人类反馈的强化学习(Reinforcement Learning with Human Feedback, RLHF)或直接偏好优化(Direct Preference Optimization, DPO),以增加生成首选响应而非被拒绝响应的可能性。
在这里插入图片描述

然而,研究人员发现了这种方法的局限性。虽然 SFT 有效地使模型适应所需的领域,但它无意中增加了在首选答案的同时生成不需要的答案的可能性。这就是为什么偏好调整阶段对于扩大首选输出和拒绝输出的可能性之间的差距是必要的。
ORPO 由 Hong 和 Lee (2024) 提出,通过将指令调整和偏好对齐结合到一个单一的整体训练过程中,为这个问题提供了一个优雅的解决方案。 ORPO 修改了标准语言建模目标,将负对数似然损失与优势比 (OR) 项相结合。这种 OR 损失对被拒绝的响应进行弱惩罚,同时对首选响应进行强烈奖励,从而使模型能够同时学习目标任务并与人类偏好保持一致。
在这里插入图片描述
ORPO 已在主要微调库中实现,如 TRL、Axolotl 和 LLaMA-Factory。在下一节中,我们将了解如何与 TRL 一起使用。

使用 ORPO 微调 Llama 3

Llama 3 是Meta开发的最新大型语言模型(LLM)家族。该模型在一个包含15万亿个标记的数据集上进行了训练(相比之下,Llama 2 的训练数据集为2万亿个标记)。目前已经发布了两种模型尺寸:一个是拥有70B参数的模型,另一个是较小的8B参数模型。70B参数的模型已经展示了令人印象深刻的性能,在MMLU基准测试中得分为82,在HumanEval基准测试中得分为81.7。
Llama 3 模型还将上下文长度增加到了8,192个标记(相比之下,Llama 2 为4,096个标记),并且有可能通过RoPE扩展到32k。此外,这些模型使用了一种新的分词器,具有128K标记的词汇量,从而减少了编码文本所需的标记数量15%。这种词汇量的增加也解释了参数从70亿增加到80亿。
ORPO 需要一个偏好数据集,包括提示、选择的答案和拒绝的答案。在此示例中,我们将使用 mlabonne/orpo-dpo-mix-40k ,它是以下高质量 DPO 数据集的组合:

  • argilla/distilabel-capybara-dpo-7k-binarized: highly scored chosen answers >=5 (2,882 samples)
  • argilla/distilabel-intel-orca-dpo-pairs: highly scored chosen answers>=9, not in GSM8K (2,299 samples)
  • argilla/ultrafeedback-binarized-preferences-cleaned: highly scoredchosen answers >=5 (22,799 samples)
  • argilla/distilabel-math-preference-dpo: highly scored chosen answers>=9 (2,181 samples)
  • unalignment/toxic-dpo-v0.2 (541 samples)
  • M4-ai/prm_dpo_pairs_cleaned (7,958 samples)
  • jondurbin/truthy-dpo-v0.1 (1,016 samples)

首先安装所需的库:

pip install -U transformers datasets accelerate peft trl bitsandbytes wandb

安装完成后,我们可以导入必要的库并登录W&B(可选)

import gc
import osimport torch
import wandb
from datasets import load_dataset
# from google.colab import userdata
from peft import LoraConfig, PeftModel, prepare_model_for_kbit_training
from transformers import (AutoModelForCausalLM,AutoTokenizer,BitsAndBytesConfig,TrainingArguments,pipeline,
)
from trl import ORPOConfig, ORPOTrainer, setup_chat_format# wb_token = userdata.get('wandb')
# wandb.login(key=wb_token)

如果您有最新的 GPU,还应该能够使用 Flash Attention 库将默认的 eager Attention 实现替换为更高效的实现。

if torch.cuda.get_device_capability()[0] >= 8:#!pip install -qqq flash-attnattn_implementation = "flash_attention_2"torch_dtype = torch.bfloat16
else:attn_implementation = "eager"torch_dtype = torch.float16

接下来,我们将借助bitsandbytes 以 4 位精度加载 Llama 3 8B 模型。然后,我们使用 QLoRA 的 PEFT 设置 LoRA 配置。我还使用方便的 setup_chat_format() 函数来修改模型和标记生成器以支持 ChatML。它会自动应用此聊天模板,添加特殊标记,并调整模型嵌入层的大小以匹配新的词汇表大小。
请注意,您需要提交访问 meta-llama/Meta-Llama-3-8B 的请求并登录您的 Hugging Face 帐户。或者,您可以加载模型的非门控副本,例如 NousResearch/Meta–Llama-3-8B。(我选择手动从NousResearch/Meta–Llama-3-8B下载)

# Model
base_model = "meta-llama/Meta-Llama-3-8B"
new_model = "OrpoLlama-3-8B"# QLoRA config
bnb_config = BitsAndBytesConfig(load_in_4bit=True,bnb_4bit_quant_type="nf4",bnb_4bit_compute_dtype=torch_dtype,bnb_4bit_use_double_quant=True,
)# LoRA config
peft_config = LoraConfig(r=16,lora_alpha=32,lora_dropout=0.05,bias="none",task_type="CAUSAL_LM",target_modules=['up_proj', 'down_proj', 'gate_proj', 'k_proj', 'q_proj', 'v_proj', 'o_proj']
)# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(base_model)# Load model
model = AutoModelForCausalLM.from_pretrained(base_model,quantization_config=bnb_config,device_map="auto",attn_implementation=attn_implementation
)
model, tokenizer = setup_chat_format(model, tokenizer)
model = prepare_model_for_kbit_training(model)

现在模型已准备好进行训练,我们可以处理数据集了。我们加载 mlabonne/orpo-dpo-mix-40k 并使用 apply_chat_template() 函数将“chosen”和“rejected”列转换为 ChatML 格式。请注意,我仅使用 1,00 个样本,而不是整个数据集,因为运行时间太长。(我选择手动下载)

dataset_name = "mlabonne/orpo-dpo-mix-40k"
dataset = load_dataset(dataset_name, split="all")
dataset = dataset.shuffle(seed=42).select(range(100))def format_chat_template(row):row["chosen"] = tokenizer.apply_chat_template(row["chosen"], tokenize=False)row["rejected"] = tokenizer.apply_chat_template(row["rejected"], tokenize=False)return rowdataset = dataset.map(format_chat_template,num_proc= os.cpu_count(),
)
dataset = dataset.train_test_split(test_size=0.01)

首先,我们需要设置一些超参数: * learning_rate :与传统的 SFT 甚至 DPO 相比,ORPO 使用非常低的学习率。 8e-6这个值来自原始论文,大致对应于SFT学习率1e-5和DPO学习率5e-6。我建议将其增加到 1e-6 左右以进行真正的微调。 * beta :即论文中的 𝜆 参数,默认值为0.1。原始论文的附录显示了如何通过消融研究选择它。 * 其他参数,如 max_length 和批量大小设置为使用尽可能多的可用 VRAM(此配置中约为 20 GB)。理想情况下,我们会训练模型 3-5 个 epoch,但这里我们坚持使用 1 个 epoch。
最后,我们可以使用 ORPOTrainer 来训练模型,它充当包装器。

orpo_args = ORPOConfig(learning_rate=8e-6,beta=0.1,lr_scheduler_type="linear",max_length=1024,max_prompt_length=512,per_device_train_batch_size=2,per_device_eval_batch_size=2,gradient_accumulation_steps=4,optim="paged_adamw_8bit",num_train_epochs=1,evaluation_strategy="steps",eval_steps=0.2,logging_steps=1,warmup_steps=10,report_to="wandb",output_dir="./results/",
)trainer = ORPOTrainer(model=model,args=orpo_args,train_dataset=dataset["train"],eval_dataset=dataset["test"],peft_config=peft_config,tokenizer=tokenizer,
)
trainer.train()
trainer.save_model(new_model)

中间需要选择是否使用W&B,不会使用,我选择不使用
在这里插入图片描述
完成了 Llama 3 的快速微调:mlabonne/OrpoLlama-3-8B
在这里插入图片描述

生成目录:
在这里插入图片描述

合并完整模型到本地:

# Flush memory
del trainer, model
gc.collect()
torch.cuda.empty_cache()# Reload tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(base_model)
model = AutoModelForCausalLM.from_pretrained(base_model,low_cpu_mem_usage=True,return_dict=True,torch_dtype=torch.float16,device_map="auto",
)
model, tokenizer = setup_chat_format(model, tokenizer)# Merge adapter with base model
model = PeftModel.from_pretrained(model, new_model)
model = model.merge_and_unload()# Save the merged model and tokenizer to local directory
local_save_directory = "new_model"
model.save_pretrained(local_save_directory)
tokenizer.save_pretrained(local_save_directory)

得到和初始模型一样结构的微调模型;
在这里插入图片描述
完整教程:https://mlabonne.github.io/blog/posts/2024-04-19_Fine_tune_Llama_3_with_ORPO.html
本文使用代码对原代码改了一部分。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/43814.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

使用微pe装系统

本文仅作为记录,不作为教程。 今天心血来潮想下点游戏玩玩,一看之前分的200gc盘已经红了,再加上大学之后这个笔记本已经用得很少了,于是打算重装电脑。 参考: 微PE辅助安装_哔哩哔哩_bilibil… 1.下载微pe和win10系统到U盘 我这…

Xilinx zc706 USB电路解析

作者 QQ群:852283276 微信:arm80x86 微信公众号:青儿创客基地 B站:主页 https://space.bilibili.com/208826118 参考 USB OTG检测原理 USB3320 USB_ID为低电平时候,为host模式,USB_ID为悬空(高…

python-23-零基础自学python open()和replace()函数运用

学习内容:《python编程:从入门到实践》第二版练习10-2 知识点: 打开文件,replace()替换文件内容,open(), 练习内容: 练习10-2:C语言学习笔记 可使用方法replace()将字符串中的特定单词都替换为另一个单…

kafka系列之offset超强总结及消费后不提交offset情况的分析总结

概述 每当我们调用Kafka的poll()方法或者使用Spring的KafkaListener(其实底层也是poll()方法)注解消费Kafka消息时,它都会返回之前被写入Kafka的记录,即我们组中的消费者还没有读过的记录。 这意味着我们有一种方法可以跟踪该组消费者读取过的记录。 如前…

6.824/6.5840 的Debugging by Pretty Printing配置

TA的原文在:Debugging by Pretty Printing (josejg.com) 为了在WSL2中配置好打印运行日志,我可是忙活了一下午。可恶的log配置 首先是安装rich库Textualize/rich: Rich is a Python library for rich text and beautiful formatting in the terminal. …

用于视频生成的扩散模型

学习自https://lilianweng.github.io/posts/2024-04-12-diffusion-video/ 文章目录 3D UNet和DiTVDMImagen VideoSora 调整图像模型生成视频Make-A-Video(对视频数据微调)Tune-A-VideoGen-1视频 LDMSVD稳定视频扩散 免训练Text2Video-ZeroControlVideo 参…

需求分析|泳道图 ProcessOn教学

文章目录 1.为什么使用泳道图2.具体例子一、如何绘制确定好泳道中枢的角色在中央基于事实来绘制过程不要纠结美观先画主干处理流程再画分支处理流程一个图表达不完,切分子流程过程数不超25 ,A4纸的幅面处理过程过程用动词短语最后美化并加上序号酌情加上…

后端——全局异常处理

一、老办法try-catch 当我们执行一些错误操作导致程序报错时,程序会捕捉到异常报错,这个异常会存在一个Exception对象里 那我们在spring boot工程开发时,当我们执行一个sql查询时报错了,那就会从最底层的Mapper层捕捉到Exceptio…

Android应用程序调试Logcat的使用

Android的程序调试主要使用Logcat进行,本节主要介绍Logcat的使用。 开启调试模式 使用Android Studio进行程序调试,首先需要连接虚拟Android设备或真实Android设备,设备上需要启用调试功能。 虚拟Android设备默认情况下会启用调试功能。对…

微软清华提出全新预训练范式,指令预训练让8B模型实力暴涨!实力碾压70B模型

现在的大模型训练通常会包括两个阶段: 一是无监督的预训练,即通过因果语言建模预测下一个token生成的概率。该方法无需标注数据,这意味着可以利用大规模的数据学习到语言的通用特征和模式。 二是指令微调,即通过自然语言指令构建…

通过高德地图 JS API实现单击鼠标进行标注

效果图: 核心代码: <template><a-modal title="选择地图所在位置" :width="width" :visible="visible" @ok="handleOk" @cancel="handleCancel" cancelText="关闭"><div class="location-…

场外期权有交割日吗?场外期权应该怎么交割?

今天带你了解场外期权有交割日吗&#xff1f;场外期权应该怎么交割&#xff1f;场外个股期权是一种非标准化的金融衍生品&#xff0c;它允许投资者在未来某一特定日期以特定价格买入或卖出某一特定股票。 交割日就是买卖双方进行交割的日期,期权合约具有到期日,到期日的后一天…

C电池 和 D 电池的作用和类型详解及其之间的区别

C 和 D 电池是我们日常生活中必不可少的部件。它们通常用于高功率设备。例如手电筒和玩具。 D 型电池和 C 型电池是两种常见的电池类型。它们是一次性圆柱形电池。您可以在很多设备上使用它们。虽然它们有很多相似之处&#xff0c;但它们也有不同的特点。这些特点使它们适合某…

如何用qq邮箱注册outlook邮箱

&#x1f4d1;打牌 &#xff1a; da pai ge的个人主页 &#x1f324;️个人专栏 &#xff1a; da pai ge的博客专栏 ☁️宝剑锋从磨砺出&#xff0c;梅花香自苦寒来 ​ 目录 第一步输入qq邮箱 第二步…

数据类型及数据块认知

西门子STEP7编程语言 梯形图(LAD) 功能块图(FBD) 语句表(STL) 其中梯形图和功能块图可以相互转换 CPU常用数据区 信号输入区 I 信号输出区 Q 程序中表现形式&#xff0c;IX.X/QX.X;IWX/QWX-访问的是CPU输出输入过程映像区 另一种形式IWX:P/QWX:P-访问的是信号端口地址&#xf…

深度整合全球资源,分贝通打造高效、合规的海外差旅管理平台

在全球化商业活动的背景下,中国企业出海已成为常态。然而,随着海外差旅市场的全面增长,企业在海外支出管理上面临诸多挑战。据2023年数据显示,分贝通出海差旅业务GMV同比增长高达500倍,这一增长背后隐藏着企业对于更省钱、更高效管控方式的迫切需求。 面对与日俱增的开支,企业开…

Websocket 替代方案:如何使用 Firestore 监听实时事件

大家好,我是CodeQi! 一位热衷于技术分享的码仔。 ​在现代 Web 开发中,实时更新功能对于许多应用程序(如聊天应用、协作工具和在线游戏)都是必不可少的。虽然 WebSocket 是一种常用的实时通信技术,但 Google 的 Firestore 也提供了一种强大的替代方案,使得实时监听变得…

Golang中defer和return顺序

在Golang中&#xff0c;defer 和 return 的执行顺序是一个重要的特性&#xff0c;它们的执行顺序如下&#xff1a; return语句不是一条单独的语句&#xff0c;实际上&#xff0c;它是由赋值和返回两部分组成的。赋值步骤会先执行&#xff0c;这一步会计算return语句中的表达式…

赛氪网受邀出席浙江省应用数学研究会,共启数学教育与竞赛新篇章

2024年7月5日&#xff0c;浙江省应用数学研究会在风景如画的嘉兴市成功举办了2024年学术研讨会暨第七届第六次理事会工作会议的首日活动。作为技术支持单位&#xff0c;赛氪网受邀参与此次盛会&#xff0c;彰显了其在数学教育及竞赛领域的深厚实力与积极贡献。 开幕式由嘉兴大学…

linux watchdog 子系统

目录 一、watchdog 子系统二、关键数据结构2.1 watchdog_device2.2 watchdog_ops2.3 watchdog_info 三、重要流程3.1 watchdog 初始化3.2 watchdog 设备注册3.3 watchdog 设备文件操作函数3.4 watchdog 喂狗用户空间 watchdog&#xff08;busybox&#xff09;内核空间喂狗疑问 …