参数高效微调PEFT(一)快速入门BitFit、Prompt Tuning、Prefix Tuning

参数高效微调PEFT(一)快速入门BitFit、Prompt Tuning、Prefix Tuning

目前,模型最全的网站是HuggingFace,但是国内需要魔法流量才能访问。另外,现在大模型权重文件都较大,也会浪费不少流量,因此这里推荐使用魔搭社区下载模型文件,下面以Llama-2-7b模型为示例演示,其他模型只需换个模型名称(model_id)即可

模型库首页 · 魔搭社区 (modelscope.cn)

# 利用魔搭社区下载模型文件
# https://modelscope.cn/models# 1、利用conda新建环境
(base) C:\Users\Undo>conda create -n ms python=3.9 -y# 2、激活环境
(base) C:\Users\Undo>conda activate ms# 3、安装modelscope、jupyterlab
(ms) C:\Users\Undo>pip install modelscope jupyterlab
  • .safetensors 侧重于安全性和效率,适合于那些希望快速部署且对安全有较高要求的场景,尤其在Hugging Face生态中。
  • .ckpt 文件是 PyTorch Lightning 框架采用的模型存储格式,它不仅包含了模型参数,还包括优化器状态以及可能的训练元数据信息,使得用户可以无缝地恢复训练或执行推理。
  • .bin 文件不是标准化的模型保存格式,但在某些情况下可用于存储原始二进制权重数据,加载时需额外处理。
  • .pth 是PyTorch的标准模型保存格式,方便模型的持久化和复用,支持完整模型结构和参数的保存与恢复。

在这里插入图片描述

# 4、利用modelscope hub下载模型
from modelscope.hub.snapshot_download import snapshot_download# 我们这里只下载.safetensors的权重文件
# 可以看到下载速度还是很快的
snapshot_download(model_id='modelscope/Llama-2-7b-ms',           # 需要下载的模型cache_dir=r'D:\python\models\model-download',  # 缓存到本地路径ignore_file_pattern=['.bin']                   # 不需要下载的文件)
Downloading: 100%|██████████| 21.0/21.0 [00:00<?, ?B/s]
Downloading: 100%|██████████| 583/583 [00:00<00:00, 38.4kB/s]
Downloading: 100%|██████████| 183/183 [00:00<?, ?B/s] 
Downloading: 100%|██████████| 179/179 [00:00<?, ?B/s] 
Downloading: 100%|██████████| 6.86k/6.86k [00:00<?, ?B/s]
Downloading: 100%|█████████▉| 9.29G/9.29G [09:40<00:00, 17.2MB/s]
Downloading: 100%|█████████▉| 3.26G/3.26G [03:11<00:00, 18.3MB/s]
Downloading: 100%|██████████| 26.2k/26.2k [00:00<00:00, 572kB/s]
Downloading: 100%|██████████| 1.67k/1.67k [00:00<00:00, 113kB/s]
Downloading: 100%|██████████| 12.8k/12.8k [00:00<00:00, 425kB/s]
Downloading: 100%|██████████| 1.20M/1.20M [00:00<00:00, 3.65MB/s]
Downloading: 100%|██████████| 435/435 [00:00<00:00, 28.6kB/s]
Downloading: 100%|██████████| 1.76M/1.76M [00:00<00:00, 3.33MB/s]
Downloading: 100%|██████████| 488k/488k [00:00<00:00, 2.46MB/s]
Downloading: 100%|██████████| 746/746 [00:00<?, ?B/s] 
Downloading: 100%|██████████| 4.65k/4.65k [00:00<00:00, 315kB/s]

1 参数高效微调(PEFT)简介

1.1 预训练语言模型 + 下游任务微调

  • 随着Transformer在2017年发布后,2018年谷歌又发布了BERT模型;
  • Bert的结构如下图所示,左边是Bert模型预训练过程,右边是对于具体任务的微调过程。
    • 其中,微调阶段是后续用于一些下游任务的时候进行微调,例如:文本分类,词性标注,问答系统
    • BERT 无需调整结构就可以在不同的任务上进行微调。
    • 通过预训练语言模型 + 下游任务微调的任务设计,带来了强大的模型效果。从此,预训练语言模型 + 下游任务微调便成为了 NLP 领域主流训练范式。

在这里插入图片描述

  • 但是,随着模型变得越来越大(以GPT为代表的预训练语言模型),在消费级硬件上对模型进行全部参数的微调(full fine-tuning)变得不可行。
  • 此外,为每个下游任务独立存储和部署微调模型变得非常昂贵,因为微调模型(调整模型的所有参数)与原始预训练模型的大小相同

1.2 PEFT的简介

  • 在上述的情况下,参数高效微调 (Parameter-efficient fine-tuning,PEFT)应运而生。
  • 参数高效微调方法仅对模型的一小部分参数(这一小部分可能是模型自身的,也可能是外部引入的)进行训练便可以为模型带来显著的性能变化,一些场景下甚至不输于全量微调。
  • 由于训练一小部分参数,极大程度降低了训练大模型的算力需求,不需要多机多卡,单卡即可完成对一些大模型的训练,不仅如此,少量的训练参数对存储的要求同样降低了很多,大多数的参数高效微调方法只需要保存训练部分的参数,与动辄几十GB的原始大模型相比,几乎可以忽略。
  • 除此之外,模型全量微调还会损失多样性,存在灾难性遗忘的问题。
  • 如下图所示,高效微调技术可以粗略分为以下三大类:
    • 增加额外参数(Additive)。而在增加额外参数这类方法中,又主要分为:
      • 类适配器(Adapter-like)方法
      • 软提示(Soft prompts)
    • 选取一部分参数更新(Selective)
    • 引入重参数化(Reparametrization-based)

在这里插入图片描述

2 全量微调bloom模型

2.1 生成式问答机器人

数据集:https://huggingface.co/datasets/shibing624/alpaca-zh

预训练模型:澜舟科技开源的Bloom预训练生成模型-中文-389m(langboat/bloom-389m-zh)。如果机器GPU内存比较大,可以下载langboat/bloom-1b4-zh模型。

模型魔搭社区链接:

Bloom预训练生成模型-中文-389m · 模型库 (modelscope.cn)

在这里插入图片描述

我们使用此模型,先全量微调基于生成式的问答机器人。

  • 预训练任务类型是因果语言模型(自回归模型),因此我们需要AutoModelForCausalLM
  • 基于上文的token预测当前token,需要注意的是结束位置要有特殊token,即eos_token。

在这里插入图片描述

  • 我们基于指令微调的方式,赋予回答问题的能力。将输入作为Prompt,我们这里基于单轮问答模型,在计算Loss时只计算Output部分,因此将输入的Label设置为-100。

  • 指令微调通常涉及将模型训练为根据给定的指令执行特定任务。以下是一些用于指令微调的数据示例,每个示例包括一个指令和相应的期望输出:

  • 1、摘要:
    指令(Instruction):对以下文章进行摘要。
    文章(Input):(一段较长的文本)
    期望输出(Output):(文章的简短摘要)2、情感分析:
    指令(Instruction):分析以下句子的情感倾向。
    句子(Input):我今天感觉非常开心!
    期望输出(Output):正面3、问答系统:
    指令(Instruction):回答以下问题。
    问题(Input):太阳系中的行星有哪些?
    期望输出(Output):太阳系中的行星包括水星、金星、地球、火星、木星、土星、天王星和海王星。4、文本分类:
    指令(Instruction):将以下文本分类为“科技”、“体育”或“艺术”中的一个类别。
    文本(Input):苹果公司发布了最新款的 iPhone。
    期望输出(Output):科技5、语言翻译:
    指令(Instruction):将以下句子从英语翻译成法语。
    句子(Input):Hello, how are you?
    期望输出(Output):Bonjour, comment ça va?
    

在这里插入图片描述

2.1 模型的全量微调

2.1.1 加载数据集

from datasets import Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, DataCollatorForSeq2Seq, TrainingArguments, Trainerds = Dataset.load_from_disk("./alpaca_data_zh/")
ds[:3]
{'output': ['以下是保持健康的三个提示:\n\n1. 保持身体活动。每天做适当的身体运动,如散步、跑步或游泳,能促进心血管健康,增强肌肉力量,并有助于减少体重。\n\n2. 均衡饮食。每天食用新鲜的蔬菜、水果、全谷物和脂肪含量低的蛋白质食物,避免高糖、高脂肪和加工食品,以保持健康的饮食习惯。\n\n3. 睡眠充足。睡眠对人体健康至关重要,成年人每天应保证 7-8 小时的睡眠。良好的睡眠有助于减轻压力,促进身体恢复,并提高注意力和记忆力。','4/16等于1/4是因为我们可以约分分子分母都除以他们的最大公约数4,得到(4÷4)/ (16÷4)=1/4。分数的约分是用分子和分母除以相同的非零整数,来表示分数的一个相同的值,这因为分数实际上表示了分子除以分母,所以即使两个数同时除以同一个非零整数,分数的值也不会改变。所以4/16 和1/4是两种不同的书写形式,但它们的值相等。','朱利叶斯·凯撒,又称尤利乌斯·恺撒(Julius Caesar)是古罗马的政治家、军事家和作家。他于公元前44年3月15日被刺杀。 \n\n根据历史记载,当时罗马元老院里一些参议员联合起来策划了对恺撒的刺杀行动,因为他们担心恺撒的统治将给罗马共和制带来威胁。在公元前44年3月15日(又称“3月的艾达之日”),恺撒去参加元老院会议时,被一群参议员包围并被攻击致死。据记载,他身中23刀,其中一刀最终致命。'],'input': ['', '输入:4/16', ''],'instruction': ['保持健康的三个提示。', '解释为什么以下分数等同于1/4', '朱利叶斯·凯撒是如何死亡的?']
}

2.1.2 数据集预处理

mode_path = r'/root/autodl-fs/models/langboat/bloom-389m-zh'# tokenizer = AutoTokenizer.from_pretrained("Langboat/bloom-389m-zh")
tokenizer = AutoTokenizer.from_pretrained(mode_path)tokenizer
BloomTokenizerFast(name_or_path='/root/autodl-fs/models/langboat/bloom-389m-zh', # 模型路径vocab_size=42437,         # 词典大小model_max_length=1000000000000000019884624838656, is_fast=True, padding_side='left',      # 在左边进行填充truncation_side='right',  # 在右边进行截断special_tokens={'bos_token': '<s>', 'eos_token': '</s>', 'unk_token': '<unk>', 'pad_token': '<pad>'},        # 特殊tokenclean_up_tokenization_spaces=False
)
def process_func(example):MAX_LENGTH = 256input_ids, attention_mask, labels = [], [], []# 将prompt进行tokenize,这里我们没有利用tokenizer进行填充和截断# 这里我们自己进行截断,在DataLoader的collate_fn函数中进行填充instruction = tokenizer("\n".join(["Human: " + example["instruction"], example["input"]]).strip() + "\n\nAssistant: ")# 将output进行tokenize,注意添加eos_tokenresponse = tokenizer(example["output"] + tokenizer.eos_token)# 将instruction + output组合为inputinput_ids = instruction["input_ids"] + response["input_ids"]attention_mask = instruction["attention_mask"] + response["attention_mask"]# prompt设置为-100,不计算losslabels = [-100] * len(instruction["input_ids"]) + response["input_ids"]# 设置最大长度,进行截断if len(input_ids) > MAX_LENGTH:input_ids = input_ids[:MAX_LENGTH]attention_mask = attention_mask[:MAX_LENGTH]labels = labels[:MAX_LENGTH]return {"input_ids": input_ids,"attention_mask": attention_mask,"labels": labels}tokenized_ds = ds.map(process_func, remove_columns=ds.column_names)
tokenized_ds
Dataset({features: ['input_ids', 'attention_mask', 'labels'],num_rows: 26858
})
# 展示两条数据
print(tokenized_ds[:2])
{'input_ids': [[23069, 29, 210, 6583, 24772, 8995, 13533, 671, 189, 4122, 15263, 29, 210, 4744, 583, 6583, 24772, 8995, 13533, 1022, 189, 189, 20, 17, 210, 6583, 8416, 3228, 420, 8634, 1900, 13648, 8416, 5625, 355, 1202, 29011, 553, 30355, 1298, 15599, 355, 961, 4872, 34650, 5980, 355, 10915, 15342, 7761, 355, 1403, 11472, 6189, 20465, 671, 189, 21, 17, 210, 20122, 13660, 420, 8634, 13869, 20189, 373, 17070, 553, 16382, 553, 1204, 6165, 1430, 641, 14562, 16130, 24251, 15502, 7984, 355, 7981, 1220, 6538, 553, 1220, 14562, 641, 13545, 10249, 355, 714, 6583, 24772, 13660, 11297, 671, 189, 22, 17, 210, 17672, 16272, 420, 17672, 1063, 13966, 5980, 18688, 355, 30645, 8634, 1638, 7900, 954, 3779, 210, 38858, 17672, 420, 14054, 17672, 11472, 15375, 10891, 355, 4872, 8416, 7442, 355, 1403, 5323, 4001, 16885, 14721, 1249, 420, 2], [23069, 29, 210, 8254, 6744, 4744, 25703, 24676, 937, 27599, 189, 8996, 1022, 10273, 1323, 189, 189, 4122, 15263, 29, 210, 10273, 1323, 14359, 27599, 12675, 15295, 1714, 1008, 8460, 1008, 3101, 1266, 30984, 2388, 20263, 3899, 1665, 23, 355, 4290, 24270, 15445, 36981, 18, 375, 1323, 15445, 36981, 32, 27599, 420, 1008, 17620, 1714, 1008, 20576, 8460, 641, 1008, 3101, 30984, 14855, 1717, 7330, 35941, 355, 1042, 3648, 25703, 6317, 14855, 3469, 355, 857, 3047, 25703, 11490, 3648, 657, 8460, 30984, 1008, 3101, 355, 3293, 9043, 5101, 1665, 4232, 30984, 26919, 1717, 7330, 35941, 355, 1008, 17620, 3469, 18476, 8041, 420, 3293, 10273, 1323, 9623, 27599, 583, 11845, 7153, 28585, 5339, 355, 1379, 16030, 3469, 28807, 420, 2]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], 'labels': [[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 4744, 583, 6583, 24772, 8995, 13533, 1022, 189, 189, 20, 17, 210, 6583, 8416, 3228, 420, 8634, 1900, 13648, 8416, 5625, 355, 1202, 29011, 553, 30355, 1298, 15599, 355, 961, 4872, 34650, 5980, 355, 10915, 15342, 7761, 355, 1403, 11472, 6189, 20465, 671, 189, 21, 17, 210, 20122, 13660, 420, 8634, 13869, 20189, 373, 17070, 553, 16382, 553, 1204, 6165, 1430, 641, 14562, 16130, 24251, 15502, 7984, 355, 7981, 1220, 6538, 553, 1220, 14562, 641, 13545, 10249, 355, 714, 6583, 24772, 13660, 11297, 671, 189, 22, 17, 210, 17672, 16272, 420, 17672, 1063, 13966, 5980, 18688, 355, 30645, 8634, 1638, 7900, 954, 3779, 210, 38858, 17672, 420, 14054, 17672, 11472, 15375, 10891, 355, 4872, 8416, 7442, 355, 1403, 5323, 4001, 16885, 14721, 1249, 420, 2], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 10273, 1323, 14359, 27599, 12675, 15295, 1714, 1008, 8460, 1008, 3101, 1266, 30984, 2388, 20263, 3899, 1665, 23, 355, 4290, 24270, 15445, 36981, 18, 375, 1323, 15445, 36981, 32, 27599, 420, 1008, 17620, 1714, 1008, 20576, 8460, 641, 1008, 3101, 30984, 14855, 1717, 7330, 35941, 355, 1042, 3648, 25703, 6317, 14855, 3469, 355, 857, 3047, 25703, 11490, 3648, 657, 8460, 30984, 1008, 3101, 355, 3293, 9043, 5101, 1665, 4232, 30984, 26919, 1717, 7330, 35941, 355, 1008, 17620, 3469, 18476, 8041, 420, 3293, 10273, 1323, 9623, 27599, 583, 11845, 7153, 28585, 5339, 355, 1379, 16030, 3469, 28807, 420, 2]]}
# 解码一条数据验证预处理是否成功
print(tokenizer.decode(tokenized_ds[1]["input_ids"]))
print(tokenizer.decode(list(filter(lambda x: x != -100, tokenized_ds[1]["labels"]))))
'Human: 解释为什么以下分数等同于1/4\n输入:4/16\n\nAssistant: 4/16等于1/4是因为我们可以约分分子分母都除以他们的最大公约数4,得到(4÷4)/ (16÷4)=1/4。分数的约分是用分子和分母除以相同的非零整数,来表示分数的一个相同的值,这因为分数实际上表示了分子除以分母,所以即使两个数同时除以同一个非零整数,分数的值也不会改变。所以4/16 和1/4是两种不同的书写形式,但它们的值相等。</s>''4/16等于1/4是因为我们可以约分分子分母都除以他们的最大公约数4,得到(4÷4)/ (16÷4)=1/4。分数的约分是用分子和分母除以相同的非零整数,来表示分数的一个相同的值,这因为分数实际上表示了分子除以分母,所以即使两个数同时除以同一个非零整数,分数的值也不会改变。所以4/16 和1/4是两种不同的书写形式,但它们的值相等。</s>'

2.1.3 创建模型、配置训练器

# model = AutoModelForCausalLM.from_pretrained("Langboat/bloom-389m-zh")
model = AutoModelForCausalLM.from_pretrained(mode_path)# 在训练之前,我们先推理一下
from transformers import pipelinepipe = pipeline("text-generation", model=model, tokenizer=tokenizer, device=0)
ipt = "Human: {}\n{}".format("考试有哪些技巧?", "").strip() + "\n\nAssistant: "
pipe(ipt, max_length=256, do_sample=True, )
[{'generated_text': "Human: 考试有哪些技巧?\n\nAssistant: 考试技巧:\n(上)\nInglis skills (上)\n(下)\nIn other words I have a lot of experience in English\nI have a lot of experience with our children\nI have a lot of experience with our students\nI have more experience in the English of the world\nWhen did you get out of college?\nI feel an sense of helth in my country\nI feel like I am a big man in someone\nI feel like I am in there and I'm not going away\nI see the bad that I live in\nI see the bad that is done around me\nI see the bad that was done my family and I have to live with it to the end\nBut also I see the good that I still have to live with it and give it an ending\nI see the good that the bad that is done in me may come right back to me\nAnd the bad that was made to me so that I may live with it\nIt can stop me not being the same\nIt can stop me not being someone that I need to be\nAnd I can live with a bad that made me to feel like a big man in someone\nI"}]

显存消耗分析

模型参数:386m 大约 0.38G

模型占用:默认是fp32,占用4个字节,0.38G * 4 = 1.52G

梯度占用:0.38G * 4 = 1.52G

优化器:0.38G * 4 * 2 = 3.04G

一共占用显存:6.08G

  • 这里我们为了显存接近计算的显存,先设置batch_size=1,并且不进行梯度累加
  • 对于后面高效微调方法,对比显存占用时同样设置batch_size=1,并且不进行梯度累加
# 我们为了显存接近计算的显存,设置batch_size=1,并且不进行梯度累加
args = TrainingArguments(output_dir="./chatbot",         # 输出文件夹per_device_train_batch_size=4,  # 训练时的batch_sizegradient_accumulation_steps=8,  # 梯度累积次数,相当于batch_size=32logging_steps=100,              # log 打印的频率num_train_epochs=1              # 训练轮数
)trainer = Trainer(model=model,    # 预训练模型args=args,      # 训练参数train_dataset=tokenized_ds,  # 训练集data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True) 
)
# 我们可以看到消耗内存 6884MiB = 6.72G
(base) root@autodl-container-adbc11ae52-f2ebff02:~# nvidia-smi 
Sat May 25 10:19:17 2024       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.89.02    Driver Version: 525.89.02    CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA GeForce ...  On   | 00000000:B5:00.0 Off |                  N/A |
| 34%   59C    P2   182W / 250W |   6884MiB / 11264MiB |     72%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------++-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
+-----------------------------------------------------------------------------+
  • 我们测试后,为了加快训练速度,设置batch_size为4,并设置梯度累积次数为8
  • 后面高效微调方法真正训练时,我们也这样设置,加快训练速度。
args = TrainingArguments(output_dir="./chatbot",         # 输出文件夹per_device_train_batch_size=4,  # 训练时的batch_sizegradient_accumulation_steps=8,  # 梯度累积次数,相当于batch_size=32logging_steps=100,              # log 打印的频率num_train_epochs=1              # 训练轮数
)trainer = Trainer(model=model,    # 预训练模型args=args,      # 训练参数train_dataset=tokenized_ds,  # 训练集data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True) 
)
(base) root@autodl-container-adbc11ae52-f2ebff02:~# nvidia-smi 
Sat May 25 10:21:24 2024       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.89.02    Driver Version: 525.89.02    CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA GeForce ...  On   | 00000000:B5:00.0 Off |                  N/A |
| 34%   58C    P2   225W / 250W |   8636MiB / 11264MiB |     96%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------++-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
+-----------------------------------------------------------------------------+

2.1.4 模型训练及推理

# 模型训练
trainer.train()

在这里插入图片描述

# 我们再次推理,可以发现效果变好了
from transformers import pipelinepipe = pipeline("text-generation", model=model, tokenizer=tokenizer, device=0)
ipt = "Human: {}\n{}".format("考试有哪些技巧?", "").strip() + "\n\nAssistant: "
pipe(ipt, max_length=256, do_sample=True, )
[{'generated_text': 'Human: 考试有哪些技巧?\n\nAssistant: 考试有着不同的技巧,这些技巧可以帮助学生在考试中取得好成绩。不同的考试需要不同的学习方法,不同的学生应该利用不同的学习方法来学习。每种考试都有其具体的步骤和考试选项,每种考试都有自己独特的特点和考试要求。这些考试技巧可以有效地帮助学生理解考试的需求和不适,并提高他们的学习效率。\n\n考试的技巧和学习方法是一对一的,这有助于学生提高学习能力、保持良好学习习惯并提高学业成绩。考试的指导和考试题海量且丰富,这可以帮助学生在学习中不断提高自我,提高求学能力,并为未来考试做好准备。此外,考试的每套答案都是有答案的,这有助于学生更好地掌握考试技巧。'}]

3 BitFit

3.1 BitFit概述

  • BitFit是一种稀疏的微调方法,它训练时只更新bias的参数或者部分bias参数

  • 论文地址:2106.10199v2 (arxiv.org)

  • 对于BERT encoder模型而言,只更新bias参数跟特定任务的分类层参数。

    • 涉及到的bias参数有attention模块中计算query,key,value;

    在这里插入图片描述

    • MLP层中的bias;
    • Layernormalization层的bias参数。

在这里插入图片描述

  • 在Bert-Base/Bert-Large这种模型里,bias参数仅占模型全部参数量的0.08%~0.09%。

    • 如下图,BitFit在参数量远小于Adapter、Diff-Pruning的情况下,效果与Adapter、Diff-Pruning相当,甚至在某些任务上略优于Adapter、Diff-Pruning。

    在这里插入图片描述

    • 如下图所示,虽不及全量参数微调,但是远超固定全部模型参数的Frozen方式。另外,发现计算query和将特征维度从N放大到4N的FFN层(intermediate)的bias参数变化最为明显,只更新这两类bias参数也能达到不错的效果。

在这里插入图片描述

3.2 BitFit轻量微调bloom模型

HuggingFace上有一个专门用于高效微调的库peft,地址:https://github.com/huggingface/peft

我们可以直接通过pip进行安装:

pip install peft

BitFit轻量微调bloom模型步骤很简单,peft库并未集成此方法,我们可以手动实现。

  • 加载数据集、数据集预处理和2.1中一样。
  • 创建模型后,我们需要进行BitFit微调。
# 加载模型
model = AutoModelForCausalLM.from_pretrained(mode_path, low_cpu_mem_usage=True)# 模型的参数量
print(sum(param.numel() for param in model.parameters()))# 加载模型后,我们可以进行BitFit微调
num_param = 0
for name, param in model.named_parameters():if "bias" not in name:param.requires_grad = False # 非baias的参数进行冻结else:num_param += param.numel()  # 272384
print(num_param)
# 0.0007877630195608073
print(num_param / sum(param.numel() for param in model.parameters()))
  • 配置训练器、模型训练及推理和2.1一样。

  • 显存消耗情况:

# 设置batch_size=1,并且不进行梯度累加
(base) root@autodl-container-adbc11ae52-f2ebff02:~/autodl-fs/models/langboat/bloom-389m-zh# nvidia-smi 
Mon May 27 17:42:23 2024       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.89.02    Driver Version: 525.89.02    CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA GeForce ...  On   | 00000000:B4:00.0 Off |                  N/A |
| 34%   60C    P2   176W / 250W |   2674MiB / 11264MiB |     40%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------++-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
+-----------------------------------------------------------------------------+

4 Prefix Tuning

4.1 Prefix Tuning概述

  • 论文链接:2101.00190 (arxiv.org)
  • 传统的微调范式利用预训练模型去对不同的下游任务进行微调,对每个任务都要保存一份微调后的模型权重,一方面微调整个模型耗时长;另一方面也会占很多存储空间。
  • 在Prefix Tuning之前的工作主要是人工设计离散的模版或者自动化搜索离散的模版。
    • 对于人工设计的模版,模版的变化对模型最终的性能特别敏感,加一个词、少一个词或者变动位置都会造成比较大的变化。
    • 而对于自动化搜索模版,成本也比较高;同时,以前这种离散化的token搜索出来的结果可能并不是最优的。
  • 基于上述两点,Prefix Tuning提出为LM添加可训练任务特定的前缀,这样就可以为不同任务保存不同的前缀,微调成本也小(如下图所示);
  • 这种Prefix实际就是连续可微的Virtual Token(Soft Prompt),相比离散的Token,更好优化,效果更好。

在这里插入图片描述

  • 如下图,针对不同的模型结构,需要构造不同的Prefix。

    • 针对自回归架构模型:在句子前面添加前缀,得到 z = [PREFIX; x; y],合适的上文能够在固定 LM 的情况下去引导生成下文。
    • 针对编码器-解码器架构模型:Encoder和Decoder都增加了前缀,得到 z = [PREFIX; x; PREFIX0; y]。Encoder端增加前缀是为了引导输入部分的编码,Decoder 端增加前缀是为了引导后续token的生成。

    在这里插入图片描述

  • 该方法其实和构造Prompt类似,只是Prompt是人为构造的“显式”的提示,并且无法更新参数,而Prefix则是可以学习的“隐式”的提示。

  • 如下图,为了防止直接更新Prefix的参数导致训练不稳定和性能下降的情况,在Prefix层前面加了MLP结构,训练完成后,只保留Prefix的参数。除此之外,通过消融实验证实,只调整embedding层的表现力不够,将导致性能显著下降,因此,在每层都加了prompt的参数,改动较大。

在这里插入图片描述

4.2 Prefix Tuning源码分析

PEFT代码主要集中在src/peft目录下,其中tuners目录下实现了PrefixTuning、PromptTuning、PTuning、Adapter、LoRA、AdaLoRA这些方法配置文件的构造、解析,新增训练参数模型的构造,各种PEFT方法配置文件类之间的继承关系如下所示:

在这里插入图片描述

如下图所示,PeftModel类对于不同的任务,有不同的实现类:

在这里插入图片描述

我们下面以peft_model.py文件中PeftModelForCausalLM的forward函数实现为例,看一下在推理阶段如何对于PrefixTuning进行操作:

  • 1、首先通过配置文件的所继承的父类类型来判断PEFT方法是否属于Prompt相关的,如果不是,就表示使用的是Adapter、LoRA等方法,直接执行推理。
    # peft_model.py# PeftModelForCausalLM的forward函数def forward(self,input_ids=None,attention_mask=None,inputs_embeds=None,labels=None,output_attentions=None,output_hidden_states=None,return_dict=None,**kwargs,):peft_config = self.active_peft_config# 如果不是prompt相关算法,直接执行这个分支计算Adapter、LoRA的推理过程if not peft_config.is_prompt_learning:......
  • 2、如果通过配置文件的所继承的父类类型来判断PEFT方法属于Prompt相关的,因为要在序列的开始位置添加虚拟token的embedding,所以也要补全attention mask
      	# 如果是Prompt相关算法,执行下面代码batch_size = _get_batch_size(input_ids, inputs_embeds)if attention_mask is not None:# concat prompt attention mask# 计算添加Prompt之后的attention maskprefix_attention_mask = torch.ones(batch_size, peft_config.num_virtual_tokens).to(attention_mask.device)attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)if kwargs.get("position_ids", None) is not None:warnings.warn("Position ids are not supported for parameter efficient tuning. Ignoring position ids.")kwargs["position_ids"] = Noneif kwargs.get("token_type_ids", None) is not None:warnings.warn("Token type ids are not supported for parameter efficient tuning. Ignoring token type ids")kwargs["token_type_ids"] = Nonekwargs.update({"attention_mask": attention_mask,"labels": labels,"output_attentions": output_attentions,"output_hidden_states": output_hidden_states,"return_dict": return_dict,})
  • 3、通过配置文件的类型来判断PEFT方法到底是PrefixTuning/PTuningV2,还是PromptTuning/PTuningV1
    • 如果是PromptTuning/PTuningV1,则将虚拟token的embedding直接concat到原始输入序列的前面,送入base model模型进行推理。
    • 如果是PrefixTuning/PTuningV2,需要给每一个transformer block的key和value添加虚拟token的embedding。
        #  如果为PREFIX_TUNING,需要给每一个transformer block的key和value添加虚拟token的embeddingif peft_config.peft_type == PeftType.PREFIX_TUNING:# Note: 重点关注self.get_prompt(batch_size)这一行代码# self.get_prompt是其父类PeftModel中的方法past_key_values = self.get_prompt(batch_size)return self.base_model(input_ids=input_ids, inputs_embeds=inputs_embeds, past_key_values=past_key_values, **kwargs)else:# PromptTuning/PTuningV1 分支if inputs_embeds is None:# 计算prompt以外输入内容的embeddinginputs_embeds = self.word_embeddings(input_ids)# concat prompt labelsif labels is not None:prefix_labels = torch.full((batch_size, peft_config.num_virtual_tokens), -100).to(labels.device)kwargs["labels"] = torch.cat((prefix_labels, labels), dim=1)# prompt内容的embedding    prompts = self.get_prompt(batch_size=batch_size)prompts = prompts.to(inputs_embeds.dtype)# 将prompt embedding 和原始的embedding 一起送到base model进行推理计算inputs_embeds = torch.cat((prompts, inputs_embeds), dim=1)return self.base_model(inputs_embeds=inputs_embeds, **kwargs)
  • 4、这里,我们先重点关注父类PeftModel中的方法self.get_prompt(batch_size)

    • 我们先看父类的初始化方法,会从base_model获取相关模型参数,比如这里获取的模型参数为:

      BloomConfig {"apply_residual_connection_post_layernorm": false,"attention_dropout": 0.0,"bos_token_id": 1,"eos_token_id": 2,"hidden_dropout": 0.0,"hidden_size": 64,         # hidden_size"initializer_range": 0.02,"layer_norm_epsilon": 1e-05,"model_type": "bloom","n_head": 8,               # 8头"n_layer": 2,              # 2个BloomBlock"pretraining_tp": 1,"slow_but_exact": false,"transformers_version": "4.30.1","use_cache": true,"vocab_size": 250880
      }
      
    • 我们这里是prefix tuning,因此会进入self.add_adapter方法。

      # peft/peft_model.py
      class PeftModel(PushToHubMixin, torch.nn.Module):def __init__(self, model: PreTrainedModel, peft_config: PeftConfig, adapter_name: str = "default"):super().__init__()self.base_model = model# 从base_model中获取相关参数self.config = getattr(self.base_model, "config", {"model_type": "custom"})self.modules_to_save = Noneself.peft_config = {}self.active_adapter = adapter_nameself.peft_type = peft_config.peft_typeif not peft_config.is_prompt_learning:# 如果不是prompt相关方法self.peft_config[adapter_name] = peft_configself.base_model = PEFT_TYPE_TO_MODEL_MAPPING[peft_config.peft_type](self.base_model, self.peft_config, adapter_name)self.set_additional_trainable_modules(peft_config, adapter_name)else:# 如果是prompt相关方法,prefix tuning会进入此方法self.add_adapter(adapter_name, peft_config)......    
      
    • self.add_adapter方法中最重要的就是_setup_prompt_encoder方法。通过此方法,就会获取PrefixEncoder用来初始化PeftModel类中的属性self.prompt_encoder

          # peft/peft_model.pydef add_adapter(self, adapter_name: str, peft_config: PeftConfig):......try:if peft_config.is_prompt_learning:if hasattr(self.config, "to_dict"):dict_config = self.config.to_dict()else:dict_config = self.configpeft_config = _prepare_prompt_learning_config(peft_config, dict_config)    # 这里会初始化prompt_encoderself._setup_prompt_encoder(adapter_name)......
      
    • 我们看下PrefixEncoder这个类,可以看到 Prefix Tuning 与 P-Tuning v2 最主要的差别就是是否进行重新参数化编码

      # peft/tuners/prefix_tuning.py# Based on https://github.com/THUDM/P-tuning-v2/blob/main/model/prefix_encoder.py
      # with some refactor
      class PrefixEncoder(torch.nn.Module):def __init__(self, config):super().__init__()self.prefix_projection = config.prefix_projectiontoken_dim = config.token_dimnum_layers = config.num_layersencoder_hidden_size = config.encoder_hidden_sizenum_virtual_tokens = config.num_virtual_tokensif self.prefix_projection and not config.inference_mode:# Use a two-layer MLP to encode the prefix# Prefix Tuning 进行重新参数化编码(通过MLP)self.embedding = torch.nn.Embedding(num_virtual_tokens, token_dim)self.transform = torch.nn.Sequential(torch.nn.Linear(token_dim, encoder_hidden_size),torch.nn.Tanh(),torch.nn.Linear(encoder_hidden_size, num_layers * 2 * token_dim),)else:# P-Tuning v2 self.embedding = torch.nn.Embedding(num_virtual_tokens, num_layers * 2 * token_dim)def forward(self, prefix: torch.Tensor):if self.prefix_projection:# 先进行Embedding 此时shape为:(batch_size, num_virtual_tokens)# 再进行重新参数化编码,此时shape为:(batch_size, num_virtual_tokens, 2*layers*hidden)prefix_tokens = self.embedding(prefix)past_key_values = self.transform(prefix_tokens)else:past_key_values = self.embedding(prefix)return past_key_values
      
    • 我们继续看self.get_prompt(batch_size)方法

      • 我们这里设置batch_size=1,num_virtual_tokens=10,此时获取到的prompt_encoder为:
      PrefixEncoder((embedding): Embedding(10, 64)(transform): Sequential((0): Linear(in_features=64, out_features=64, bias=True)(1): Tanh()(2): Linear(in_features=64, out_features=256, bias=True))
      )
      
      • 通过下面代码,我们就得到了每一个block,前面需要拼接的past_key_value了。
             def get_prompt(self, batch_size: int):"""Returns the virtual prompts to use for Peft. Only applicable when `peft_config.peft_type != PeftType.LORA`."""peft_config = self.active_peft_configprompt_encoder = self.prompt_encoder[self.active_adapter]prompt_tokens = (self.prompt_tokens[self.active_adapter].unsqueeze(0).expand(batch_size, -1).to(prompt_encoder.embedding.weight.device))  # shape = (1, 10)if peft_config.peft_type == PeftType.PREFIX_TUNING:prompt_tokens = prompt_tokens[:, : peft_config.num_virtual_tokens]if peft_config.inference_mode:past_key_values = prompt_encoder.embedding.weight.repeat(batch_size, 1, 1)else:# 经过prompt_encoder,shape=(1,10,256)past_key_values = prompt_encoder(prompt_tokens)if self.base_model_torch_dtype is not None:past_key_values = past_key_values.to(self.base_model_torch_dtype)# past_key_values,torch.Size([1, 10, 4, 8, 8])past_key_values = past_key_values.view(batch_size,peft_config.num_virtual_tokens,peft_config.num_layers * 2,peft_config.num_attention_heads,peft_config.token_dim // peft_config.num_attention_heads,)# 如果是编码器-解码器架构,就再复制一次,我们这里不复制if peft_config.num_transformer_submodules == 2:past_key_values = torch.cat([past_key_values, past_key_values], dim=2)      # 重排:torch.Size([4, 1, 8, 10, 8])# 然后split成一个长度为2的tuple(2层bloom block),每个tuple的shape:torch.Size([2, 1, 8, 10, 8])# 也就是说past_key_values是2个层的Prefix embedding,形状为(num_transformer_submodules * 2, batch_size, num_attention_heads, num_virtual_tokens, token_dim/num_attention_heads])  注意:这里*2是因为key+value.past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(peft_config.num_transformer_submodules * 2)if TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING.get(self.config.model_type, None) is not None:post_process_fn = TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING[self.config.model_type]past_key_values = post_process_fn(past_key_values)# 第一个bloom block中key 和 value 如下:    # past_key_values[0][0].shape, torch.Size([8, 8, 10])# past_key_values[0][1].shape, torch.Size([8, 10, 8])return past_key_valueselse:if peft_config.inference_mode:prompts = prompt_encoder.embedding.weight.repeat(batch_size, 1, 1)else:prompts = prompt_encoder(prompt_tokens)return prompts
      
  • 5、我们得到每一层的past_key_value后,就需要在每一层的transformer block的key和value前面,拼接past_key_value。我们这里查看BloomModel中BloomBlock中的BloomAttention前向传播代码。

# transformers/models/bloom/modeling_bloom.py
class BloomAttention(nn.Module):......def forward(self,hidden_states: torch.Tensor,   # torch.Size([1, 48, 64])residual: torch.Tensor,        # torch.Size([1, 48, 64])alibi: torch.Tensor,           attention_mask: torch.Tensor,layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,head_mask: Optional[torch.Tensor] = None,use_cache: bool = False,output_attentions: bool = False,):# fused_qkv, torch.Size([1, 48, 192])fused_qkv = self.query_key_value(hidden_states)  # [batch_size, seq_length, 3 x hidden_size]# 3 x [batch_size, seq_length, num_heads, head_dim]# 3 x torch.Size([1, 48, 8, 8])(query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)batch_size, q_length, _, _ = query_layer.shape# query_layer, torch.Size([8, 48, 8])query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim)# key_layer, torch.Size([8, 8, 48])key_layer = key_layer.permute(0, 2, 3, 1).reshape(batch_size * self.num_heads, self.head_dim, q_length)# value_layer, torch.Size([8, 48, 8])value_layer = value_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim)# 可以看到将past_key和past_value 拼接到key_layer和value_layer前面if layer_past is not None:# past_key, torch.Size([8, 8, 10])# past_value, torch.Size([8, 10, 8])past_key, past_value = layer_past# 拼接后, key_layer,torch.Size([8, 8, 58])#       value_layer,torch.Size([8, 58, 8])key_layer = torch.cat((past_key, key_layer), dim=2)value_layer = torch.cat((past_value, value_layer), dim=1)......    

4.3 Prefix Tuning轻量微调bloom模型

我们现在已经知道,我们只需要在加载原模型后、配置训练器前加peft的代码即可。

from peft import PrefixTuningConfig, get_peft_model, TaskTypeconfig = PrefixTuningConfig(task_type=TaskType.CAUSAL_LM, num_virtual_tokens=10, prefix_projection=True
)config
PrefixTuningConfig(peft_type=<PeftType.PREFIX_TUNING: 'PREFIX_TUNING'>, auto_mapping=None, base_model_name_or_path=None, revision=None, task_type=<TaskType.CAUSAL_LM: 'CAUSAL_LM'>, inference_mode=False, num_virtual_tokens=10, token_dim=None, num_transformer_submodules=None, num_attention_heads=None, num_layers=None, encoder_hidden_size=None, prefix_projection=True)
model = get_peft_model(model, config)print(model.prompt_encoder)
print(model.print_trainable_parameters())
ModuleDict((default): PrefixEncoder((embedding): Embedding(10, 1024)(transform): Sequential((0): Linear(in_features=1024, out_features=1024, bias=True)(1): Tanh()(2): Linear(in_features=1024, out_features=49152, bias=True)))
)trainable params: 51,440,640 || all params: 397,209,600 || trainable%: 12.950502706883217
  • 配置训练器、模型训练及推理和2.1一样。
  • 显存消耗情况:
# 设置batch_size=1,并且不进行梯度累加
(base) root@autodl-container-adbc11ae52-f2ebff02:~/autodl-fs/models/langboat/bloom-389m-zh# nvidia-smi 
Mon May 27 17:40:44 2024       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.89.02    Driver Version: 525.89.02    CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA GeForce ...  On   | 00000000:B4:00.0 Off |                  N/A |
| 31%   54C    P2   193W / 250W |   3346MiB / 11264MiB |     51%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------++-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
+-----------------------------------------------------------------------------+

5 Prompt Tuning

5.1 Prompt Tuning概述

  • 论文链接:2104.08691v1 (arxiv.org)

  • Prompt-Tuning的思想: 冻结主模型全部参数,在训练数据前加入一小段Prompt,只训练Prompt的表示层,即一个Embedding模块。总的来说就是,只要模型规模够大,简单加入 Prompt tokens 进行微调,就能取得很好的效果。

在这里插入图片描述

  • 该方法可以看作是 Prefix Tuning 的简化版本,只在输入层加入 prompt tokens,并不需要加入 MLP 进行调整来解决难训练的问题,主要在 T5 预训练模型上做实验。

  • 作者做了一系列对比实验,都在说明:随着预训练模型参数的增加,一切的问题都不是问题,最简单的设置也能达到极好的效果

    • Prompt 初始化方式影响(下图a):Random Uniform 方式明显弱于其他两种,但是当模型参数达到一定量级,这种差异也不复存在。

    • Prompt 长度影响(下图b):模型参数达到一定量级时,Prompt 长度为 1 也能达到不错的效果,Prompt 长度为 20 就能达到极好效果。

    • 预训练的方式(下图c):LM Adaptation 的方式效果好,但是当模型达到一定规模,差异又几乎没有了。

    • 微调步数影响(下图d):模型参数较小时,步数越多,效果越好。同样随着模型参数达到一定规模,zero shot 也能取得不错效果。

在这里插入图片描述

5.2 Prompt Tuning轻量微调bloom模型

在peft库中,Prompt又存在两种形式,一种是hard prompt,一种是soft prompt。

5.2.1 soft prompt

我们现在已经知道,我们只需要在加载原模型后、配置训练器前加peft的代码即可。

# 设置batch_size=1,并且不进行梯度累加
# 加载模型
model = AutoModelForCausalLM.from_pretrained(mode_path, low_cpu_mem_usage=True)from peft import PromptTuningConfig, get_peft_model, TaskType, PromptTuningInit# Soft Prompt
config = PromptTuningConfig(task_type=TaskType.CAUSAL_LM, num_virtual_tokens=10)model = get_peft_model(model, config)# 打印可训练的参数
model.print_trainable_parameters()# 
trainable params: 10,240 || all params: 345,779,200 || trainable%: 0.0029614274080106613
  • 配置训练器、模型训练及推理和2.1一样。
  • 显存消耗情况:
(base) root@autodl-container-adbc11ae52-f2ebff02:~/autodl-fs/models/langboat/bloom-389m-zh# nvidia-smi 
Mon May 27 17:45:46 2024       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.89.02    Driver Version: 525.89.02    CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA GeForce ...  On   | 00000000:B4:00.0 Off |                  N/A |
| 37%   63C    P2   188W / 250W |   2830MiB / 11264MiB |     45%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------++-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
+-----------------------------------------------------------------------------+

我们可以看下prompt_tuning.py中的源码:

  • 我们发现prompt_tuning_init默认为RANDOM
  • 此时,prompt会nn.Embedding进行随机初始化
class PromptTuningConfig(PromptLearningConfig):# prompt_tuning_init默认为RANDOMprompt_tuning_init: Union[PromptTuningInit, str] = field(default=PromptTuningInit.RANDOM,metadata={"help": "How to initialize the prompt tuning parameters"},)......     class PromptEmbedding(torch.nn.Module):def __init__(self, config, word_embeddings):super().__init__()total_virtual_tokens = config.num_virtual_tokens * config.num_transformer_submodules# 将num_virtual_tokens进行Embedding,即随机初始化(即soft prompt)self.embedding = torch.nn.Embedding(total_virtual_tokens, config.token_dim)# 下面代码只有在prompt_tuning_init=TEXT才会执行(即hard prompt)if config.prompt_tuning_init == PromptTuningInit.TEXT:    from transformers import AutoTokenizertokenizer = AutoTokenizer.from_pretrained(config.tokenizer_name_or_path)init_text = config.prompt_tuning_init_text# 对prompt_tuning_init_text进行tokenizerinit_token_ids = tokenizer(init_text)["input_ids"]......# 然后作为embedding的初始化权重word_embedding_weights = word_embeddings(torch.LongTensor(init_token_ids)).detach().clone()word_embedding_weights = word_embedding_weights.to(torch.float32)self.embedding.weight = torch.nn.Parameter(word_embedding_weights)

5.2.2 hard prompt

  • 设置num_virtual_tokens和tokenizer后prompt_tuning_init_text长度相等,此时不会进行截取

  • 设置prompt_tuning_init等于PromptTuningInit.TEXT,此时会对prompt_tuning_init_text进行tokenizer,然后embedding的初始化权重

# 加载模型
model = AutoModelForCausalLM.from_pretrained(mode_path, low_cpu_mem_usage=True)from peft import PromptTuningConfig, get_peft_model, TaskType, PromptTuningInit# Hard Prompt
config = PromptTuningConfig(task_type=TaskType.CAUSAL_LM,  # 因果语言任务prompt_tuning_init=PromptTuningInit.TEXT,prompt_tuning_init_text="下面是一段人与机器人的对话。",num_virtual_tokens=len(tokenizer("下面是一段人与机器人的对话。")["input_ids"]),tokenizer_name_or_path=mode_path)
PromptTuningConfig(peft_type=<PeftType.PROMPT_TUNING: 'PROMPT_TUNING'>, auto_mapping=None, base_model_name_or_path=None, revision=None, task_type=<TaskType.CAUSAL_LM: 'CAUSAL_LM'>, inference_mode=False, num_virtual_tokens=8, token_dim=None, num_transformer_submodules=None, num_attention_heads=None, num_layers=None, prompt_tuning_init=<PromptTuningInit.TEXT: 'TEXT'>, prompt_tuning_init_text='下面是一段人与机器人的对话。', tokenizer_name_or_path='/root/autodl-fs/models/langboat/bloom-389m-zh')
model = get_peft_model(model, config)# 打印可训练的参数
model.print_trainable_parameters()trainable params: 8,192 || all params: 345,777,152 || trainable%: 0.002369155958575308
  • 配置训练器、模型训练及推理和2.1一样。
  • 显存消耗情况:
(base) root@autodl-container-adbc11ae52-f2ebff02:~/autodl-fs/models/langboat/bloom-389m-zh# nvidia-smi 
Mon May 27 17:51:07 2024       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.89.02    Driver Version: 525.89.02    CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA GeForce ...  On   | 00000000:B4:00.0 Off |                  N/A |
| 38%   64C    P2   205W / 250W |   2834MiB / 11264MiB |     42%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------++-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
+-----------------------------------------------------------------------------+

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/17099.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Spring+SpringBoot面试总结(近两万字)

SpringSpringBoot面试总结 一、Spring Bean1.1、bean的生命周期&#xff08;对象的创建使用销毁&#xff09;1.1.1、准备工作1.1.2、创建Bean对象1.1.3、注册销毁 1.2、 bean的作用域1.2.1、配置方式 1.3、 spring 自动装配 bean 有哪些方式&#xff08;存疑存疑&#xff09;1.…

软件测试金字塔,对号入座,你在哪层?

自从学习了软件测试,脑袋也清晰了,目标也明确了,就是不知道学到哪里了.中间有很多的困难也有很多成就感,你目前在那个阶段呢? 初级测试工程师 技能要求:需求分析,使用等价类边界值等方法进行用例设计,执行功能测试,发现提交跟踪bug,使用禅道,会在测试中会操作数据库进行检查和…

数学建模--LaTeX的基本使用

目录 1.回顾 2.设置这个页眉和页脚 3.对于字体的相关设置 4.对于这个分级标题的设置 5.列表的使用 6.插入图片 1.回顾 &#xff08;1&#xff09;昨天我们了解到了这个latex的使用基本常识&#xff0c;以及这个宏包的概念&#xff0c;区域的划分&#xff0c;不同的代码代…

电磁仿真--CST综合建模练习1

1. 简介 本文展示一个CST自带的示例&#xff0c;在三维空间中使用带线计算传输线的S参数。基板顶部的带线通过小圆柱连接到底部的短带线&#xff0c;以便绕过可能存在的障碍。 结构生成 该结构完全通过参数输入进行建模&#xff0c;参考波长为10毫米&#xff0c;因此可以轻松…

JavaWeb开发 1.Web开发 介绍

我的生命是一万次的春和景明 —— 24.5.27 一、什么是Web Web&#xff1a; 全球广域网&#xff0c;也称为万维网(www World Wide Web)&#xff0c;能够通过浏览器访问的网站 Web网站的工作流程 学习流程

kafka的安装

windows下kafka的安装 【Kafka】Windows下安装Kafka&#xff08;图文记录详细步骤&#xff09;_windows安装kafka-CSDN博客 kafka生产消息 kafka消费消息

指纹识别经典图书、开源算法库、开源数据库

目录 1. 指纹识别书籍 1.1《精通Visual C指纹模式识别系统算法及实现》 1.2《Handbook of Fingerprint Recognition》 2. 指纹识别开源算法库 2.1 Hands on Fingerprint Recognition with OpenCV and Python 2.2 NIST Biometric Image Software (NBIS) 3. 指纹识别开源数…

【StableDiffusion】SD1.4、1.5、2.0、2.1 和 SDXL0.9-1.0、SDXL turbo 等的区别

总览 1.基础sd base model家族&#xff1a;SD1.4、SD1.5、SD1.5-LCM、SD2.0、SD2.0-768、SD2.1、SD2.1-768、SD2.1-UNCLIP 2.升级sdxl base model家族&#xff1a;SDXL0.9、SDXL1.0、SDXL1.0-LCM、SDXL-DISTILLED、SDXL-TURBO 3.专门用于视频生成的 SVD 家族&#xff1a;SVD、…

C++习题(1)

一、题目描述&#xff1a; 二、代码展示&#xff1a; #include <iostream> #include <iomanip> using namespace std; struct Student{char name[20];int id;int age;float score; }; int main() {int n;cin>>n;Student student[n];float sum0.0;for(int i0…

QQ名片满级会员展示生成HTML源码

源码介绍 QQ名片满级会员展示生成HTML源码&#xff0c;源码由HTMLCSSJS组成&#xff0c;双击html文件可以本地运行效果&#xff0c;也可以上传到服务器里面&#xff0c;保存素材去选择QQ个性名片-选择大图模板-把图上传照片墙即可 源码效果 源码下载 蓝奏云&#xff1a;http…

大数据开发面试题【Mysql篇】

181、mysql数据库中的引擎 用于数据存储、处理和保护数据的核心服务&#xff0c;不同的数据库引擎有其各自的特点&#xff0c;常见的引擎&#xff1a;InnoDB&#xff0c;Mylsam、Memory、Mrg_Mylsam、Blackhole innodb&#xff1a;是一个事务性存储引擎&#xff0c;提供了对事…

Docker基础篇之常用命令

文章目录 1. 帮助启动类命令2. 镜像命令3. 容器命令4. 总结 1. 帮助启动类命令 启动docker&#xff1a; systemctl start docker停止docker&#xff1a; systemctl stop docker重启docker&#xff1a; systemctl restart docker查看docker 的运行状态&#xff1a; systemc…

MER 2024 第二届多模态情感识别挑战赛

多模态情感识别是人工智能领域的一个活跃研究课题。它的主要目标是整合多种模态来识别人类的情绪状态。当前的工作通常为基准数据集假设准确的情感标签&#xff0c;并专注于开发更有效的架构。然而&#xff0c;现有技术难以满足实际应用的需求。 清华大学陶建华教授联合中国科学…

课时138:变量进阶_变量实践_综合案例

2.1.3 综合案例 学习目标 这一节&#xff0c;我们从 免密认证、脚本实践、小结 三个方面来学习 免密认证 案例需求 A 以主机免密码认证 连接到 远程主机B我们要做主机间免密码认证需要做三个动作1、本机生成密钥对2、对端机器使用公钥文件认证3、验证手工演示 本地主机生成…

预热 618,编程好书推荐——提升你的代码力

文章目录 &#x1f4cb;前言&#x1f3af;编程好书推荐&#x1f4d8; Java领域的经典之作&#x1f40d; Python学习者的宝典&#x1f310; 前端开发者的权威指南&#x1f512; 并发编程的艺术&#x1f916; JVM的深入理解&#x1f3d7; 构建自己的编程语言&#x1f9e0; 编程智…

WJ2EDGKA-5.08-8P功能和参数介绍及PDF资料

WJ2EDGKA-5.08-8P 是一款接线端子&#xff0c;以下是它的主要功能和参数介绍&#xff1a; 间距: 5.08mm&#xff08;0.2英寸&#xff09;&#xff0c;这是指相邻针脚之间的中心距离。 针脚数: 8个针脚&#xff08;1X8Pins&#xff09;&#xff0c;这意味着该端子可以连接8根导线…

基于Zynq 7000 SoC的迁移设计

基于Zynq 7000 SoC的迁移设计 Vivado IDE工具使用IP集成器进行嵌入式开发。各种IP Vivado IDE IP目录中提供&#xff0c;以适应复杂的设计。您也可以添加 自定义IP到IP目录。 您可以将基于Zynq 7000平台处理器的设计迁移到Vivado design Suite中 使用以下步骤。 1.生成系统基础…

知攻善防应急响应靶机训练-Web3

前言 本次应急响应靶机采用的是知攻善防实验室的Web-3应急响应靶机 靶机下载地址为&#xff1a; https://pan.quark.cn/s/4b6dffd0c51a 相关账户密码 用户:administrator 密码:xj123456xj123456 解题过程 第一题-攻击者的两个IP地址 直接查看apache的log日志搜索.php 发现…

三维大场景管理-3Dtiles规范

简介 &#xff1a; 这篇文章都是三年前写的了&#xff0c;一直在笔记库存中&#xff0c;今天把他放出来。主要是讲Cesium 的3Dtiles 格式&#xff0c;当然3Dtiles主要是解决场景管理大场景的LOD实现的问题&#xff0c;不管是剔除渲染性能优化之Culling 剔除或者 LOD 、3Dtiles…

SSM基于微信小程序的校园表白墙的设计与实现-计算机毕业设计源码58219

摘 要 随着我国经济迅速发展&#xff0c;人们对手机的需求越来越大&#xff0c;各种手机软件也都在被广泛应用&#xff0c;但是对于手机进行数据信息管理&#xff0c;对于手机的各种软件也是备受用户的喜爱&#xff0c;校园表白墙微信小程序被用户普遍使用&#xff0c;为方便用…