用通俗的方法讲解:大模型微调训练详细说明(附理论+实践代码)

本文内容如下

  • 介绍了大模型训练的微调方法,包括prompt tuning、prefix tuning、LoRA、p-tuning和AdaLoRA等。

  • 介绍了使用deepspeed和LoRA进行大模型训练的相关代码。

  • 给出了petals的介绍,它可以将模型划分为多个块,每个用户的机器负责其中一块,分摊了计算压力。

理解篇

prompt tuning

图片

固定预训练参数,为每一个任务额外添加一个或多个embedding,之后拼接query正常输入LLM,并只训练这些embedding。左图为单任务全参数微调,右图为prompt tuning。

图片

  • 标准的T5模型(橙色线)多任务微调实现了强大的性能,但需要为每个任务存储单独的模型副本。

  • prompt tuning也会随着参数量增大而效果变好,同时使得单个冻结模型可重复使用于所有任务。

  • 显著优于使用GPT-3进行fewshot prompt设计。

  • 当参数达到100亿规模与全参数微调方式效果无异。

代码样例:

from peft import PromptTuningConfig, get_peft_model
peft_config = PromptTuningConfig(task_type="SEQ_CLS", num_virtual_tokens=10)
model = AutoModelForCausalLM.from_pretrained(model_name_or_path, return_dict=True)
model = get_peft_model(model, peft_config)

prefix tuning

图片

prefix tuning依然是固定预训练参数,但除为每一个任务额外添加一个或多个embedding之外,利用多层感知编码prefix,注意多层感知机就是prefix的编码器,不再像prompt tuning继续输入LLM。

embedding = torch.nn.Embedding(num_virtual_tokens, token_dim)
transform = torch.nn.Sequential(torch.nn.Linear(token_dim, encoder_hidden_size),torch.nn.Tanh(),torch.nn.Linear(encoder_hidden_size, num_layers * 2 * token_dim),
)

在三个数据集中prefix和全参数微调的表现对比:

图片

代码样例:

peft_config = PrefixTuningConfig(task_type="CAUSAL_LM", num_virtual_tokens=20)
model = AutoModelForCausalLM.from_pretrained(model_name_or_path, return_dict=True)
model = get_peft_model(model, peft_config)

LoRA

图片

LoRA冻结了预训练模型的参数,并在每一层decoder中加入dropout+Linear+Conv1d额外的参数

那么,LoRA是否能达到全参数微调的性能呢?

根据实验可知,全参数微调要比LoRA方式好的多,但在低资源的情况下也不失为一种选择

图片

细致到每个任务中的差距如下图:

图片

代码样例:

peft_config = LoraConfig(task_type="SEQ_CLS", inference_mode=False, r=8, lora_alpha=16, lora_dropout=0.1)
model = AutoModelForCausalLM.from_pretrained(model_name_or_path, return_dict=True)
model = get_peft_model(model, peft_config)

p-tuning

图片

手动尝试最优的提示无异于大海捞针,于是便有了自动离散提示搜索的方法(作图),但提示是离散的,神经网络是连续的,所以寻找的最优提示可能是次优的。p-tuning依然是固定LLM参数,利用多层感知机和LSTM对prompt进行编码,编码之后与其他向量进行拼接之后正常输入LLM。注意,训练之后只保留prompt编码之后的向量即可,无需保留编码器。

self.lstm_head = torch.nn.LSTM(input_size=self.input_size,hidden_size=self.hidden_size,num_layers=num_layers,dropout=lstm_dropout,bidirectional=True,batch_first=True,)self.mlp_head = torch.nn.Sequential(torch.nn.Linear(self.hidden_size * 2, self.hidden_size * 2),torch.nn.ReLU(),torch.nn.Linear(self.hidden_size * 2, self.output_size),
)
self.mlp_head(self.lstm_head(input_embeds)[0])

以上代码可清晰展示出prompt编码器的结构。

图片

如上图所示,GPT在P-tuning的加持下可达到甚至超过BERT在NLU领域的性能。下图是细致的对比:图片

MP: Manual prompt

FT: Fine-tuning

MP+FT: Manual prompt augmented fine-tuning

PT: P-tuning

代码样例:

peft_config = PromptEncoderConfig(task_type="CAUSAL_LM", num_virtual_tokens=20, encoder_hidden_size=128)
model = AutoModelForCausalLM.from_pretrained(model_name_or_path, return_dict=True)
model = get_peft_model(model, peft_config)

p-tuning v2

图片

p-tuning的问题是在小参数量模型上表现差(如上图所示),于是有了V2版本,类似于LoRA每层都嵌入了新的参数(称之为Deep FT),下图中开源看到p-tuning v2 集合了多种微调方法。p-tuning v2 在多种任务上下进行微调,之后对于不同的任务如token classification与sentence classification添加了随机初始化的任务头(AutoModelForTokenClassification、AutoModelForSequenceClassification),而非使用自然语言的方式,可以说V2是集大成者。

图片

KP: Knowledge Probe,知识探针,用于检测LLM的世界知识掌握能力:https://github.com/facebookresearch/LAMA

SeqTag: Sequence Tagging,如抽取式问答、命名实体识别

Re-param.:Reparameterization,对提示词做单独的编码器

No verb.: No verbalizer,不直接使用LLM head而接一个随机初始化的linear head

以下表格对比了[CLS] label linear head 和 verbalizer with LM head,[CLS] label linear head的方式药略好。

图片

v1到v2的可视化:蓝色部分为参数冻结,橙色部分为可训练部分

图片

下图中对比了FT、PT、PT-2三种方法,粗体为性能最好的,下划线为性能次好的。

图片

代码样例:

peft_config = PrefixTuningConfig(task_type="SEQ_CLS", num_virtual_tokens=20)
model = AutoModelForSequenceClassification.from_pretrained(model_name_or_path, return_dict=True)
model = get_peft_model(model, peft_config)

AdaLoRA

预训练语言模型中的不同权重参数对下游任务的贡献是不同的。因此需要更加智能地分配参数预算,以便在微调过程中更加高效地更新那些对模型性能贡献较大的参数。

具体来说,通过奇异值分解将权重矩阵分解为增量矩阵,并根据新的重要性度量动态地调整每个增量矩阵中奇异值的大小。这样可以使得在微调过程中只更新那些对模型性能贡献较大或必要的参数,从而提高了模型性能和参数效率。

详细的算法如下:

图片

对比不同方法的性能:

图片

代码样例:

peft_config = AdaLoraConfig(peft_type="ADALORA", task_type="SEQ_2_SEQ_LM", r=8, lora_alpha=32, target_modules=["q", "v"],lora_dropout=0.01)
model = AutoModelForCausalLM.from_pretrained(model_name_or_path, return_dict=True)
model = get_peft_model(model, peft_config)

代码篇

注:以下代码在pytorch 1.12.1版本下运行,其他包都是最新版本

deepspeed

官方的demo所需要的配置如下:
在这里插入图片描述

注意到官方给的样例单卡V100只能训练13亿规模的模型,如果换成67亿是否能跑起来呢?

按照官方文档搭建环境:

pip install deepspeed>=0.9.0git clone https://github.com/microsoft/DeepSpeedExamples.git
cd DeepSpeedExamples/applications/DeepSpeed-Chat/
pip install -r requirements.txt

请注意如果你之前装了 deepspeed,请更新至0.9.0

试试全参数微调,这毫无疑问OOM

deepspeed --num_gpus 1 main.py \--data_path Dahoas/rm-static \--data_split 2,4,4 \--model_name_or_path facebook/opt-6.5b \--gradient_accumulation_steps 2 \--lora_dim 128 \--zero_stage 0 \--deepspeed \--output_dir $OUTPUT \&> $OUTPUT/training.log

答案是:我们需要卸载,这次便能愉快的run起来了

deepspeed main.py \--data_path Dahoas/rm-static \--data_split 2,4,4 \--model_name_or_path facebook/opt-6.7b \--per_device_train_batch_size 4 \--per_device_eval_batch_size 4 \--max_seq_len 512 \--learning_rate 9.65e-6 \--weight_decay 0.1 \--num_train_epochs 2  \--gradient_accumulation_steps 1 \--lr_scheduler_type cosine \--num_warmup_steps 0 \--seed 1234 \--lora_dim 128 \--gradient_checkpointing \--zero_stage 3 \--deepspeed \--output_dir $OUTPUT_PATH \&> $OUTPUT_PATH/training.log

可以加上LoRA

deepspeed --num_gpus 1 main.py \--data_path Dahoas/rm-static \--data_split 2,4,4 \--model_name_or_path facebook/opt-6.7b \--per_device_train_batch_size 8 \--per_device_eval_batch_size 8 \--max_seq_len 512 \--learning_rate 1e-3 \--weight_decay 0.1 \--num_train_epochs 2 \--gradient_accumulation_steps 16 \--lr_scheduler_type cosine \--num_warmup_steps 0 \--seed 1234 \--gradient_checkpointing \--zero_stage 0 \--lora_dim 128 \--lora_module_name decoder.layers. \--deepspeed \--output_dir $OUTPUT_PATH \&> $OUTPUT_PATH/training.log

peft

以下代码省略了数据处理

初始化

from datasets import load_dataset,load_from_disk
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer,default_data_collator
from peft import prepare_model_for_int8_training, LoraConfig, get_peft_modelMICRO_BATCH_SIZE = 1  
BATCH_SIZE = 1
GRADIENT_ACCUMULATION_STEPS = BATCH_SIZE // MICRO_BATCH_SIZE
EPOCHS = 3  
LEARNING_RATE = 3e-6  
CUTOFF_LEN = 256  
LORA_R = 16
LORA_ALPHA = 32
LORA_DROPOUT = 0.05

模型加载,并使用int8进行训练

model_path = "facebook/opt-6.7b"
output_dir = "model"
model = AutoModelForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path, add_eos_token=True)
model = prepare_model_for_int8_training(model)  
config = LoraConfig(r=LORA_R,lora_alpha=LORA_ALPHA,target_modules=None,lora_dropout=LORA_DROPOUT,bias="none",task_type="CAUSAL_LM",
)
model = get_peft_model(model, config)
tokenizer.pad_token_id = 0  
data = load_from_disk("data")

训练与保存

trainer = transformers.Trainer(model=model,train_dataset=data["train"],eval_dataset=data["validation"],args=transformers.TrainingArguments(per_device_train_batch_size=MICRO_BATCH_SIZE,per_device_eval_batch_size=MICRO_BATCH_SIZE,gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,warmup_steps=1000,num_train_epochs=EPOCHS,learning_rate=LEARNING_RATE,# bf16=True,  fp16=True,  logging_steps=1,output_dir=output_dir,save_total_limit=4,),data_collator=default_data_collator,
)
model.config.use_cache = False
trainer.train(resume_from_checkpoint=False)
model.save_pretrained(output_dir)

直接这么启动当然会OOM,依然需要卸载

编写accelerate配置文件accelerate.yaml

compute_environment: LOCAL_MACHINE
deepspeed_config:gradient_accumulation_steps: 1gradient_clipping: 1.0offload_optimizer_device: noneoffload_param_device: nonezero3_init_flag: truezero3_save_16bit_model: truezero_stage: 3
distributed_type: DEEPSPEED
downcast_bf16: 'yes'
dynamo_backend: 'yes'
fsdp_config: {}
machine_rank: 0
main_training_function: main
megatron_lm_config: {}
mixed_precision: fp16
num_machines: 1
num_processes: 2
rdzv_backend: static
same_network: true
use_cpu: true

deepspeed配置文件:ds.json

{"fp16": {"enabled": true,"loss_scale": 0,"loss_scale_window": 500,"initial_scale_power": 16,"hysteresis": 2,"min_loss_scale": 1},"optimizer": {"type": "AdamW","params": {"lr": "auto","betas": "auto","eps": 1e-8,"weight_decay": "auto"}},"scheduler": {"type": "WarmupLR","params": {"warmup_min_lr": 0,"warmup_max_lr": 2e-05,"warmup_num_steps": 0}},"zero_optimization": {"stage": 2,"offload_optimizer": {"device": "cpu","pin_memory": false},"allgather_partitions": true,"allgather_bucket_size": 2e8,"overlap_comm": true,"reduce_scatter": true,"reduce_bucket_size": 2e8,"contiguous_gradients": true},"gradient_accumulation_steps":2,"gradient_clipping": "auto","steps_per_print": 2000,"train_batch_size": 4,"train_micro_batch_size_per_gpu": 1,"wall_clock_breakdown": false
}

启动

accelerate launch --dynamo_backend=nvfuser  --config_file accelearte.yaml finetune.py

注:其他方法与Lora使用方法差距不大,不再赘述,在peft项目中均有代码样例。

顺便提一嘴:petals

图片

petals将模型划分为多个块,每个用户的机器负责其中一块,分摊了计算压力,类似于某磁力链接下载工具,利用hivemind库进行去中心化的训练与推理。当然你也可以创建自己局域网的群组,对自己独有的模型进行分块等自定义操作。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/190216.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

InsCode实践分享

在当今信息爆炸的时代,如何从海量信息中脱颖而出,获取更多的关注和认可,成为了许多人的共同追求。作为知乎平台上的优质用户,我愿意分享一些自己的经验和技巧,帮助大家更好地运用InsCode,实现个人成长和进步…

【爬虫逆向分析实战】某笔登录算法分析——本地替换分析法

前言 作者最近在做一个收集粉币的项目,可以用来干嘛这里就不展开了😁,需要进行登录换算token从而达到监控收集的作用,手机抓包发现他是通过APP进行计算之后再请求接口的,通过官网分析可能要比APP逆向方便多&#xff0…

01-使用Git操作本地库,如初始化本地库,提交工作区文件到暂存区和本地库,查看版本信息,版本切换命令等

Git的使用 概述 Git是一个分布式版本控制工具, 通常用来管理项目中的源代码文件(Java类、xml文件、html页面等)进行管理,在软件开发过程中被广泛使用 Git可以记录文件修改的历史记录并形成备份从而实现代码回溯, 版本切换, 多人协作, 远程备份的功能Git具有廉价的本地库,方便…

开源图床Qchan本地部署远程访问,轻松打造个人专属轻量级图床

文章目录 前言1. Qchan网站搭建1.1 Qchan下载和安装1.2 Qchan网页测试1.3 cpolar的安装和注册 2. 本地网页发布2.1 Cpolar云端设置2.2 Cpolar本地设置 3. 公网访问测试总结 前言 图床作为云存储的一项重要应用场景,在大量开发人员的努力下,已经开发出大…

如果你想成为一名提示词工程师(Prompt Engineer),这款工具你不能错过

我的新书《Android App开发入门与实战》已于2020年8月由人民邮电出版社出版,欢迎购买。点击进入详情 前言 我们知道,如果想要通过AI得到更好更精确的答案,那么提示词Prompt的好坏至关重要。 因此,提示词工程师这个岗位应运而出。…

第一节:认识微服务

一、微服务技术对比 Dubbo SpringCloudSpringCloudAlibaba注册中心zookeeper、Redis Eureka、ConsulNacos、Eureka服务远程调用Dubbo协议Feign(http协议)Dubbo、Feign配置中心无SpringCloudGateway、ZuulSpringCloudConfig、Nacos服务网…

qemu网络通信

TAP(官网参考地址) TAP,即Tunneling traffic access point,是一种在Linux上使用的虚拟网卡技术,它可以为应用程序提供安全的网络连接。可以利用TAP搭建桥接网络,bridge两端分别为host和qemu虚拟机。 安装…

力扣 790. 多米诺和托米诺平铺(一维dp)

题目描述: 有两种形状的瓷砖:一种是 2 x 1 的多米诺形,另一种是形如 "L" 的托米诺形。两种形状都可以旋转。 给定整数 n ,返回可以平铺 2 x n 的面板的方法的数量。返回对 109 7 取模 的值。 平铺指的是每个正方形都…

具有标记和笔记功能的文件管理器TagSpaces(续)

熟悉老苏的读者都知道,老苏通常只是推荐软件,并简单介绍如何运行它们,而具体的功能则需要读者自行研究。这种方式让老苏能够在工作之余,还能保持每周发布 4 篇的更新。 然而,这种方式也存在明显的缺点。由于老苏没有深…

通义千问 Qwen-7B-Chat-Int4 模型本地化部署

如需在本地或离线环境下运行本项目,需要首先将项目所需的模型下载至本地,通常开源 LLM 与 Embedding 模型可以从 HuggingFace 下载。 以本项目中默认使用的 LLM 模型 THUDM/ChatGLM2-6B 与 Embedding 模型 moka-ai/m3e-base 为例: 下载模型…

WordPress采集器自动采集发布的工具

WordPress作为最受欢迎的内容管理系统之一,其强大的功能和灵活性使其成为许多网站、博客和电子商务平台的首选。WordPress采集器自动采集发布内置采集规则是一项备受关注的功能,让用户可以轻松收集并发布内容。WordPress采集器自动采集发布内置采集规则的…

「Verilog学习笔记」自动贩售机1

专栏前言 本专栏的内容主要是记录本人学习Verilog过程中的一些知识点,刷题网站用的是牛客网 自动贩售机中可能存在的几种金额:0,0.5,1,1.5,2,2.5,3。然后直接将其作为状态机的几种状…

面试数据库八股文十问十答第二期

面试数据库八股文十问十答第二期 作者:程序员小白条,个人博客 相信看了本文后,对你的面试是有一定帮助的! ⭐点赞⭐收藏⭐不迷路!⭐ 1.MySQL的主从复制 MySQL的主从复制是什么?MySQL主从复制是一种常见的…

11.28~11.29基本二叉树的性质、定义、复习;排序算法;堆

完全二叉树(Complete Binary Tree)是一种特殊的二叉树结构,它具有以下特点: 所有的叶子节点都集中在树的最后两层;最后一层的叶子节点都靠左排列;除了最后一层,其他层的节点数都达到最大值。 …

网络基础:网络通信基础

目录 1.网络通信基本单位 2.网络通信基础 3.调制技术 4.解调技术 5.载波调制 6.编码技术 6.1基本编码 6.2应用型编码 1.曼彻斯特编码 2.差分曼彻斯特编码 3.MLT-3编码 4.mB/nB编码 1.网络通信基本单位 Byte(字节)是用于计量存储容量的一种…

【开发PaaS】基于Postgresql的开发平台Supabase

Supadase是开源的。我们选择可扩展的开源工具,使其易于使用。 Supadase不是Firebase的1对1映射。虽然我们正在构建Firebase提供的许多功能,但我们不会以同样的方式进行: 我们的技术选择大不相同;我们使用的一切都是开源的&#…

xilinx系列FPGA基于VIVADO的pin delay列表生成说明

目录 1 概述2 示例平台3 操作说明4 注意事项 xilinx系列FPGA基于VIVADO的pin delay列表生成说明 1 概述 本文用于讲诉xilinx系列FPGA基于VIVADO的pin delay列表生成说明,以及一些注意事项,为FPGA设计人员探明道路。 Pin delay 即FPGA内部die到pin的延时…

L1-009:N个数求和

目录 ⭐题目描述⭐ ⭐分析 ⭐程序代码 运行结果 ⭐文案分享⭐ ⭐题目描述⭐ 本题的要求很简单,就是求N个数字的和。麻烦的是,这些数字是以有理数分子/分母的形式给出的,你输出的和也必须是有理数的形式。 输入格式: 输入第一行给出…

【Python表白系列】制作一个无法拒绝的表白界面(完整代码)

文章目录 无法拒绝的表白界面环境需求完整代码详细分析系列文章 无法拒绝的表白界面 当点击“不要”时弹出 当点击“”时弹出 环境需求 python3.11.4PyCharm Community Edition 2023.2.5pyinstaller6.2.0(可选,这个库用于打包,使程序没有…

.net framwork4.6操作MySQL报错Character set ‘utf8mb3‘ is not supported 解决方法

文章目录 .net framwork4.6操作MySQL报错Character set ‘utf8mb3‘ is not supported 解决方法详细报错内容解决方案修改数据修改表修改字段 .net framwork4.6操作MySQL报错Character set ‘utf8mb3‘ is not supported 解决方法 详细报错内容 System.NotSupportedException…