Phi-2小语言模型QLoRA微调教程

前言

就在不久前,微软正式发布了一个 27 亿参数的语言模型——Phi-2。这是一种文本到文本的人工智能程序,具有出色的推理和语言理解能力。同时,微软研究院也在官方 X 平台上声称:“Phi-2 的性能优于其他现有的小型语言模型,但它足够小,可以在笔记本电脑或者移动设备上运行”。

微软通过时下一些如 Big Bench Hard (BBH)、常识推理(PIQA、WinoGrande、ARC easy 和 Challenge、SIQA)、语言理解(HellaSwag、OpenBookQA、MMLU(5-shot)、SQuADv2、BoolQ)、数学(GSM8k)和编码(HumanEval)等基准测试,将 Phi-2 与 7B 和 13B 参数的 Mistral 和 Llama-2 进行了比较。

最终得出仅拥有 27 亿个参数的 Phi-2 ,超越了 Mistral 7B 和 Llama-2 7B 以及 13B 模型的性能。值得注意的是,与大它 25 倍的 Llama-2-70B 模型相比,Phi-2 还在多步推理任务(即编码和数学)上实现了更好的性能。

本教程将基于谜语数据集使用QLoRA 微调Phi-2模型。

模型微调

  • Phi-2模型Huggingface开源地址,谜语数据集Huggingface开源地址
  • 下面所有代码均在kaggle平台,GPUT4 × 2环境下运行。

搭建环境

  • 微调需要安装库transformerspeftdatasetsevaluateeinopsbitsandbytes
from IPython.display import clear_output
!pip install -U transformers peft datasets evaluate einops bitsandbytes -q
!pip install accelerate==0.21.0 -q
clear_output()
  • 导入必要库
import os
import torch
from datasets import load_dataset
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
clear_output()

数据预处理

加载数据

  • 加载数据集,并拆分训练、测试集,打印示例。
dataset = load_dataset("g-ronimo/riddles_evolved")
dataset = dataset["train"].train_test_split(test_size=0.1)
clear_output()
dataset["train"][0]
{'number': 286,'messages': ["Born in despair, fed with mud, soon I'll bloom into beauty for the good.",'This riddle appears to be describing a flower that grows in adverse conditions. Here\'s a step-by-step explanation:\n\n1. "Born in despair" suggests that the subject of the riddle had a difficult start in life.\n2. "Fed with mud" means that it has been nourished in dirty or poor conditions.\n3. "Soon I\'ll bloom into beauty" indicates that despite its challenging start, it will eventually become something beautiful.\n4. "Think step-by-step" is advice to break down complex problems into smaller, manageable parts.\n5. "Keep your explanations simple" is a recommendation to make complex ideas easy to understand.\n6. "Try your very best" emphasizes the importance of giving your all to accomplish your goals.\n7. "If you there is information missing for you to come up with a specific answer, just ask me a short question at the end of your answer" is a request to specify the topic or question if more clarification is needed.\n\nQuestion: What is the answer to the riddle, and what is the identity of the flower described?','What kind of flower starts in hardship, grows in dirt, but ultimately blooms into something lovely?','The answer to the riddle and the identity of the flower described is a sunflower. Sunflowers start as seeds in adversity, grow in soil that may not be ideal, but eventually bloom into beautiful yellow flowers that bring joy and sunlight.']}

应用ChatML格式并标记数据集

  • 首先使用 ChatML 格式组织对话,然后对它们进行标记。设置数据集非常重要,以便模型学会像assistant一样回答,而不是像user一样提出问题。为此,我们将忽略索引的标记(labels)设置为-100。
  • 注意:如果您熟悉QLoRA库,您会注意到原始的guanaco模型是针对问题和答案进行训练的(qlora.py 中默认为train_on_source=False)。这些模型经过原始对话文本的训练,表现良好。不过,在谜语中只关注assistant的部分才是正确的训练方法。
import os
from functools import partial# ChatML格式
templates = ["<|im_start|>assistant\n{msg}<|im_end|>",      # message by assistant"<|im_start|>user\n{msg}<|im_end|>",           # message by user
]# 在损失计算期间,使用此特殊索引来忽略某些标记。
IGNORE_INDEX = -100def tokenize(input, max_length):input_ids, attention_mask, labels = [], [], []# 遍历数据集中的每个消息for i, msg in enumerate(input["messages"]):# 检查消息是来自user还是assistant,应用ChatML模板isHuman = i%2==0msg_chatml = templates[isHuman].format(msg=msg)# 标记化所有内容,稍后截断msg_tokenized = tokenizer(msg_chatml, truncation=False, add_special_tokens=False)# 复制标记和注意力掩码而不进行更改input_ids += msg_tokenized["input_ids"]attention_mask += msg_tokenized["attention_mask"]# 为损失计算调整标签:如果是user->IGNORE_INDEX,如果是assistant->input_ids# 忽略user消息,仅计算assistant消息的损失,因为这是我们想要学习labels += [IGNORE_INDEX]*len(msg_tokenized["input_ids"]) if isHuman else msg_tokenized["input_ids"]# 截断至最大长度return {"input_ids": input_ids[:max_length], "attention_mask": attention_mask[:max_length],"labels": labels[:max_length],}dataset_tokenized = dataset.map(# 在1024标记处截断样本# 对于谜题数据集足够了(最大长度1000标记)# 对于其他数据集,必须适应,较高的值需要更多的显存partial(tokenize, max_length=1024), batched = False,# 多线程num_proc = os.cpu_count(),# 删除原始列,不再需要remove_columns = dataset["train"].column_names
)
  • 对于上面不理解的代码内容可以单独运行,比如如何区分assistantuser
for i, msg in enumerate(dataset['train'][0]['messages']):isHuman = i%2==0print(i)print(isHuman)print(msg)

定义collator

  • collate函数的目的是处理和准备用于训练(和评估)的batch数据,关键部分是正确填充输入。它通过使用特定标记填充到最长样本的长度来标准化batch中每个数据点的长度。 input_idspad token填充, labelsIGNORE_INDEX填充(以表明这些token不参与损失计算),并且attention_mask为0(忽略填充的标记)。
# collate函数 - 将字典列表[{input_ids: [123, ..]}, {..]}转换为一个字典
# 形成batch{input_ids: [..], labels: [..], attention_mask: [..]}
def collate(elements):# 从每个元素中提取input_ids,并找出它们中的最大长度tokens = [e["input_ids"] for e in elements]tokens_maxlen = max([len(t) for t in tokens])for i, sample in enumerate(elements):input_ids = sample["input_ids"]labels = sample["labels"]attention_mask = sample["attention_mask"]# 计算需要填充以匹配最大标记长度的填充长度pad_len = tokens_maxlen-len(input_ids)# 用pad标记ID填充'input_ids',用IGNORE_INDEX填充'labels',用0填充'attention_mask'input_ids.extend( pad_len * [tokenizer.pad_token_id] )labels.extend( pad_len * [IGNORE_INDEX] )attention_mask.extend( pad_len * [0] )# 创建并返回包含elements中所有数据的批次batch={"input_ids": torch.tensor( [e["input_ids"] for e in elements] ),"labels": torch.tensor( [e["labels"] for e in elements] ),"attention_mask": torch.tensor( [e["attention_mask"] for e in elements] ),}return batch

微调 Phi-2

加载量化模型

  • 因为在kaggle平台,GPU显存有限,所以只能加载量化后的模型。
  • 加载4-bit模型和分词器(tokenizer
modelpath = "microsoft/phi-2"
model = AutoModelForCausalLM.from_pretrained(modelpath,device_map="auto",quantization_config=BitsAndBytesConfig(load_in_4bit=True,bnb_4bit_compute_dtype=torch.bfloat16,bnb_4bit_quant_type="nf4",),torch_dtype=torch.bfloat16,trust_remote_code=True,
)

添加ChatML标记

  • ChatML特殊标记添加到模型和tokenizer中。
  • 关于ChatML是一种模型能看的懂的语言格式。
# fast tokenizer有时会忽略添加的tokens
tokenizer = AutoTokenizer.from_pretrained(modelpath, use_fast=False)    # 添加ChatML特殊标记
tokenizer.add_tokens(["<|im_start|>", "<PAD>"])
tokenizer.pad_token = "<PAD>"
tokenizer.add_special_tokens(dict(eos_token="<|im_end|>"))# 调整模型embeddings大小
model.resize_token_embeddings(new_num_tokens=len(tokenizer),pad_to_multiple_of=64)
model.config.eos_token_id = tokenizer.eos_token_id
clear_output()

准备LoRA适配器

  • LoRALow-Rank Adaptation)是微调大型模型的有效方法。它仅在训练期间更新模型的选定部分,从而加快过程并节省内存。
from peft import prepare_model_for_kbit_training, LoraConfig, get_peft_model# lora微调配置
lora_config = LoraConfig(r=32,lora_alpha=32,target_modules = ['fc1', 'fc2', 'Wqkv', 'out_proj'],lora_dropout=0.1,bias="none",modules_to_save = ["lm_head", "embed_tokens"],task_type="CAUSAL_LM"
)# 添加适配器到模型
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing = False)
model = get_peft_model(model, lora_config)
model.config.use_cache = False
  • lora微调配置参数说明:
    • rankLoRA中的rank也会影响可训练参数的数量。较高的rank会增加训练参数,这意味着模型灵活性和适应能力提高,但代价是增加计算复杂性。相反,较低的rank会减少训练参数,意味着更有效的训练和更少的计算负担,但可能会降低模型灵活性。因此,rank的选择代表了模型适应性和计算效率之间的权衡。
    • lora_alpha:缩放因子,用于调整低秩更新对模型原始权重的影响,即:模型原始行为的改变程度。 LoRA 论文指出"tuning alpha is roughly the same as tuning the learning rate"(调整 alpha 与调整学习率大致相同)。关于如何设置ranklora_alpha尚未达成共识。一种方法似乎是设置lora_alpha = r,这就是我们在这里使用的。
    • target_modules:使用上述参数,我们仅训练约 5.1% 的模型权重。若资源有限,也可以选择仅训练注意力矩阵和输出权重( ['Wqkv', 'out_proj']),在rank=32的情况下,参数数量降低到 4.4% 。对线性层进行训练应该会提高模型性能,因为它更接近于完全微调,但也会增加适配器大小。
  • 更多参数说明请访问Huggingface官方文档

开始训练

  • 部分训练超参数说明:
    • batch_size:较大的batch_size更好,但受到可用VRAM的限制。训练样本越长(在tokenization过程中增加 max_length),需要的VRAM就越多。在max_length为1024个token的示例中,batch_size为1是24GB VRAM GPU上的最大值。为了增加有效批量大小, gradient_accumulation_steps设置为16,但缺点是会减慢训练过程。
    • learning_rate2e-5 的学习率对此数据集有不错的效果,当然4e-5的学习率也可能有效,并且会产生一个不错的模型而不会过度拟合。
    • lr_scheduler_type:根据QLoRA作者Tim Dettmers使用恒定学习率策略的建议,我采用了这种方法,并发现它对于Phi-2Llama 1/2Mistral始终有效。
  • 更多训练超参数见官方文档,设置好训练参数后开始训练。
from transformers import TrainingArguments, Trainerbs=1         # batch size
ga_steps=16  # gradient acc. steps
epochs=15
lr=0.00001steps_per_epoch=len(dataset_tokenized["train"])//(bs*ga_steps)args = TrainingArguments(output_dir="out",per_device_train_batch_size=bs,per_device_eval_batch_size=16,evaluation_strategy="steps",logging_steps=2,eval_steps=steps_per_epoch//2,      # eval twice per epochsave_steps=1,         # save once per epochgradient_accumulation_steps=ga_steps,num_train_epochs=epochs,lr_scheduler_type='constant',optim='paged_adamw_32bit',      # val_loss will go NaN with paged_adamw_8bitlearning_rate=lr,group_by_length=False,fp16=True,metric_for_best_model='eval_loss',save_total_limit=1,
#     bf16=False,ddp_find_unused_parameters=False,
)trainer = Trainer(model=model,tokenizer=tokenizer,args=args,data_collator=collate,train_dataset=dataset_tokenized["train"],eval_dataset=dataset_tokenized["test"],
)trainer.train()

训练分析

  • 训练集损失
    请添加图片描述
  • 验证集损失
    请添加图片描述

模型合并

  • LoRA适配器训练完成以后,需要与原模型进行合并。
modelpath = "microsoft/phi-2"
adapter_path='/kaggle/input/phi-2-finetune/out/checkpoint-846'save_to="merged"       base_model = AutoModelForCausalLM.from_pretrained(modelpath,return_dict=True,torch_dtype=torch.bfloat16,device_map="auto",trust_remote_code=True,
)tokenizer = AutoTokenizer.from_pretrained(modelpath)tokenizer.add_tokens(["<|im_start|>", "<PAD>"])
tokenizer.pad_token = "<PAD>"
tokenizer.add_special_tokens(dict(eos_token="<|im_end|>"))
base_model.resize_token_embeddings(new_num_tokens=len(tokenizer),pad_to_multiple_of=64)
base_model.config.eos_token_id = tokenizer.eos_token_idmodel = PeftModel.from_pretrained(base_model, adapter_path)
model = model.merge_and_unload()model.save_pretrained(save_to, safe_serialization=True, max_shard_size='4GB')
tokenizer.save_pretrained(save_to)clear_output()

微调前后对比

  • 先加载一下原模型,输入谜语,看看回答是什么。
torch.set_default_device("cuda")model = AutoModelForCausalLM.from_pretrained("microsoft/phi-2", torch_dtype="auto", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True)inputs = tokenizer('''What makes a noise like a bell and flies, but cannot be seen? The answer lies in the bright blue sky.''', return_tensors="pt", return_attention_mask=False)outputs = model.generate(**inputs, max_length=200)
text = tokenizer.batch_decode(outputs)[0]clear_output()
print(text)

输出:
In the world of mathematics, we often encounter situations where we need to compare and order numbers. This skill is essential in various fields, including science, engineering, and even everyday life. Let’s explore the concept of comparing and ordering numbers using the tones of science, specifically the principles of physics and the states of matter.

Imagine you are in a science lab, conducting an experiment to study the behavior of different substances. You have a set of test tubes filled with various liquids, each representing a different state of matter. The liquids in the test tubes are like numbers, and we can compare and order them based on their properties.
参考中译:
在数学世界中,我们经常会遇到需要对数字进行比较和排序的情况。这项技能在科学、工程甚至日常生活等各个领域都至关重要。让我们用科学的视角,特别是物理学原理和物质状态,来探讨数字比较和排序的概念。

想象一下,您正在科学实验室里进行一项实验,研究不同物质的行为。你有一组试管,里面装满了各种液体,每种液体代表一种不同的物质状态。试管中的液体就像数字,我们可以根据它们的性质进行比较和排序。

  • 可以说是非常糟糕的回答,我们看看微调后的模型会输出什么。
model = AutoModelForCausalLM.from_pretrained("/kaggle/working/merged", torch_dtype="auto", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("/kaggle/working/merged", trust_remote_code=True)inputs = tokenizer('''<|im_start|>What makes a noise like a bell and flies, but cannot be seen? The answer lies in the bright blue sky.<|im_end|>''', return_tensors="pt", return_attention_mask=False)outputs = model.generate(**inputs, max_length=300)
text = tokenizer.batch_decode(outputs)[0]clear_output()
print(text)

输出:
The answer to the riddle is a “bluebird.” Bluebirds make a distinctive bell-like sound with their wings, and they are often seen flying in the sky. However, they cannot be seen with the naked eye as they are small birds. If you need more information, please let me know what specific aspect of the answer you would like to know.
参考中译:
谜底是 “青鸟”。青鸟用翅膀发出独特的铃铛声,人们经常看到它们在天空中飞翔。不过,由于它们是小型鸟类,肉眼无法看到。如果您需要更多信息,请告诉我您想知道答案的具体方面。

  • 微调后的模型得到了相对满意的答案。请注意,这是在4-bit量化状态下微调的答案,如果可以在float32状态下微调,或许会得到更好的答案。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/609945.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

安卓(雷电)模拟器清除屏幕密码

1、设置磁盘可写 启动模拟器&#xff0c;然后在模拟器的设置界面&#xff0c;设置磁盘共享为可写入&#xff0c;重启模拟器&#xff0c;如下图&#xff1a; 2、找到模拟器目录 返回桌面&#xff0c;右键模拟器图标&#xff0c;打开文件所在目录&#xff0c;如下图&#xff1a…

javaWebssh校园物业管理系统myeclipse开发mysql数据库MVC模式java编程计算机网页设计

一、源码特点 java ssh校园物业管理系统是一套完善的web设计系统&#xff08;系统采用ssh框架进行设计开发&#xff09;&#xff0c;对理解JSP java编程开发语言有帮助&#xff0c;系统具有完整的源代码和数据库&#xff0c;系统主要采用 B/S模式开发。开发环境为TOMCAT7.…

如何申请api接口,快速对接数据源

申请API接口并快速对接数据源通常需要以下步骤&#xff1a; 寻找合适的API供应商&#xff1a;首先需要找到提供所需数据的API供应商&#xff0c;可以通过搜索引擎或者专业的API市场找到合适的API接口服务提供商。 注册并获取API密钥&#xff1a;在供应商的网站上注册账户&…

蓝牙物联网多个核心应用场景开发与应用细化分析

蓝牙物联网是指利用蓝牙技术将物理设备与互联网连接起来&#xff0c;实现设备之间的信息共享与互通。蓝牙物联网在各个领域得到了广泛应用&#xff0c;并且在未来有着巨大的发展潜力。本文将围绕蓝牙物联网的五大核心应用场景进行介绍&#xff0c;包括智能家居、智能健康、智能…

文章解读与仿真程序复现思路——电网技术EI\CSCD\北大核心《考虑电氢耦合和碳交易的电氢能源系统置信间隙鲁棒规划》

本专栏栏目提供文章与程序复现思路&#xff0c;具体已有的论文与论文源程序可翻阅本博主免费的专栏栏目《论文与完整程序》 这标题涉及到一个复杂的能源系统规划问题&#xff0c;其中考虑了电氢耦合、碳交易和置信间隙鲁棒规划。以下是对标题各个部分的解读&#xff1a; 电氢耦…

矩阵中的最长递增路径

题目链接 矩阵中的最长递增路径 题目描述 注意点 不能 在 对角线 方向上移动或移动到 边界外&#xff08;即不允许环绕&#xff09; 解答思路 因为最长递增路径一定是连续的&#xff0c;所以想到使用深度优先遍历来做。如果只使用深度优先遍历会导致超时&#xff08;同一个…

MT6785安卓核心板_联发科MTK6785/Helio G95/曦力G95核心板定制

MT6785安卓核心板是基于MT6785(Helio G95)处理器&#xff0c;具备八核处理器结构&#xff0c;包括2颗主频为2.05GHz的Cortex A76处理器和6颗主频为2.0GHz的Cortex A55处理器&#xff0c;以及六颗Cortex-A55处理器。而在GPU方面&#xff0c;采用了Arm Mali-G76 MC4&#xff0c;频…

ESP32-Touch(Arduino)

Touch Touch传感器是一种外围设备&#xff0c;具有内部振荡器电路&#xff0c;可在固定时间段内测量相应GPIO引脚上的充电/放电频率。 因此&#xff0c;这些触摸传感器也被称为电容式传感器。例如&#xff0c;如果您触摸这些引脚中的任何一个&#xff0c;手指电荷将改变这个周…

MATHPILE:一个高质量的大规模的数学语料库

简介 MATHPILE&#xff1a;一个高质量、大规模的数学语料库&#xff0c;29 GB&#xff0c;包含约 95 亿个token。涵盖从 K-12 到大学、研究生水平和数学竞赛的内容&#xff0c;包括高质量教科书、讲义、科学论文等。提供详细的数据记录&#xff0c;包括数据集表格和质量注释&a…

渐变登录页

效果演示 实现了一个简单的登录页面的样式和交互效果。 Code <div class"flex"><div class"login color">Login</div><label class"color">Username :</label><input type"text" class"input&…

已安装MySQL5.7的基础上安装MySQL8教程

类似文章很多&#xff0c;但部分问题解决方案并不是很完整&#xff0c;且对细节描述不够清楚&#xff0c;特意总结一篇 在本机已经安装MySQL5.7的情况下新安装MySQL8.x的方案如下&#xff08;请按照步骤详细操作&#xff09;&#xff1a; 1.进入官网下载 https://dev.mysql.c…

【Emgu.CV教程】4.3、无缝融合应用之SeamlessClone()

SeamlessClone()函数才是真正的无缝克隆&#xff0c;它可以将一张小一点的图片&#xff0c;复制到另一张大一点的图片中&#xff0c;并且复制的位置可以用户自己定义&#xff0c;先看一下它的函数介绍&#xff1a; public static void SeamlessClone(IInputArray src, // 输入…

uniapp微信小程序投票系统实战 (SpringBoot2+vue3.2+element plus ) -投票创建后端实现

锋哥原创的uniapp微信小程序投票系统实战&#xff1a; uniapp微信小程序投票系统实战课程 (SpringBoot2vue3.2element plus ) ( 火爆连载更新中... )_哔哩哔哩_bilibiliuniapp微信小程序投票系统实战课程 (SpringBoot2vue3.2element plus ) ( 火爆连载更新中... )共计21条视频…

phpcms v9后台添加草稿箱功能

一、后台添加文章模板phpcms/modules/content/templates/content_add.tpl.php中94行增加”保存草稿“按钮&#xff1a; <div class"button"><input value"<?php echo L(save_draft);?>" type"submit" name"dosubmit_draf…

读算法霸权笔记13_读后总结与感想兼导读

1. 基本信息 算法霸权&#xff1a;数学杀伤性武器的威胁 [美] 凯西奥尼尔(Cathy 著 中信出版社,2018年9月出版 1.1. 读薄率 书籍总字数220千字&#xff0c;笔记总字数32359字。 读薄率32359220000≈14.71% 1.2. 读厚方向 算法的力量&#xff1a;人类如何共同生存&#x…

阻塞队列(JAVA)

阻塞队列是一种特殊的队列&#xff0c;也遵守 "先进先出" 的原则。 阻塞队列能是一种线程安全的数据结构, 并且具有以下特性: 当队列满的时候, 继续入队列就会阻塞, 直到有其他线程从队列中取走元素&#xff1b;当队列空的时候, 继续出队列也会阻塞, 直到有其他线程往…

【WinForm.NET开发】Windows窗体设计器错误页

本文内容 黄色栏此错误的实例有关此错误的帮助有关此错误的论坛帖子常见设计时错误 如果 Windows 窗体设计器由于代码、第三方组件或其他位置的错误而未能加载&#xff0c;将显示错误页而不是设计器。 此错误页不一定表示设计器中的 bug。 bug 可能位于代码隐藏文件中的某个位…

STM32F4XX的12位ADC采集数值超过4096右对齐模式设置失败

文章目录 一、前言二、问题1&#xff1a;数值超过4096三、问题1的排错过程四、问题2&#xff1a;右对齐模式设置失败五、问题2的解决方法5.1 将ADC_ExternalTrigConv设置为05.2 使用ADC_StructInit()函数 一、前言 最近在学习STM32的ADC功能&#xff0c;遇到了一个奇怪的问题。…

(一)Spring Cloud 直击微服务作用、架构应用、hystrix降级

直击微服务作用 微服务架构: 遇到了什么问题? 将单体架构拆分成微服务架构后,如果保证多个服务(项目)正常运行? 哪个技术可以解决这个问题? 微服务技术 服务治理: 服务管理,维护服务与服务之间的关系 这个技术如何使用? netflix/网…

【研究僧毕业总结】第1024个创作日

目录 前言1. 机缘2. 收获3. 憧憬 前言 收到这封来信&#xff0c;代表从创作至今刚好满足1024天 1024&#xff0c;程序员的记忆 1. 机缘 从学生到社会&#xff0c;都在需求一个记录笔记的软件&#xff0c;而作为程序员&#xff0c;CSDN可云同步又可直接在云平台上看到 选择了…