大模型微调---Prompt-tuning微调

目录

    • 一、前言
    • 二、Prompt-tuning实战
      • 2.1、下载模型到本地
      • 2.2、加载模型与数据集
      • 2.3、处理数据
      • 2.4、Prompt-tuning微调
      • 2.5、训练参数配置
      • 2.6、开始训练
    • 三、模型评估
    • 四、完整训练代码

一、前言

Prompt-tuning通过修改输入文本的提示(Prompt)来引导模型生成符合特定任务或情境的输出,而无需对模型的全量参数进行微调。
在这里插入图片描述
Prompt-Tuning 高效微调只会训练新增的Prompt的表示层,模型的其余参数全部固定,其核心在于将下游任务转化为预训练任务

在这里插入图片描述
新增的 Prompt 内容可以分为 Hard PromptSoft Prompt 两类:

  • Soft prompt 通常指的是一种较为宽泛或模糊的提示,允许模型在生成结果时有更大的自由度,通常用于启发模型进行创造性的生成;
  • Hard prompt 是一种更为具体和明确的提示,要求模型按照给定的信息生成精确的结果,通常用于需要模型提供准确答案的任务;

Soft Prompt 在 peft 中一般是随机初始化prompt的文本内容,而 Hard prompt 则一般需要设置具体的提示文本内容;

对于不同任务的Prompt的构建示例如下:
在这里插入图片描述

例如,假设我们有兴趣将英语句子翻译成德语。我们可以通过各种不同的方式询问模型,如下图所示。

1)“Translate the English sentence ‘{english_sentence}’ into German: {german_translation}”
2)“English: ‘{english sentence}’ | German: {german translation}”
3)“From English to German:‘{english_sentence}’-> {german_translation}”

上面说明的这个概念被称为硬提示调整

软提示调整(soft prompt tuning)将输入标记的嵌入与可训练张量连接起来,该张量可以通过反向传播进行优化,以提高目标任务的建模性能。

例如下方伪代码:

# 定义可训练的软提示参数
# 假设我们有 num_tokens 个软提示 token,每个 token 的维度为 embed_dim
soft_prompt = torch.nn.Parameter(torch.rand(num_tokens, embed_dim)  # 随机初始化软提示向量
)# 定义一个函数,用于将软提示与原始输入拼接
def input_with_softprompt(x, soft_prompt):# 假设 x 的维度为 (batch_size, seq_len, embed_dim)# soft_prompt 的维度为 (num_tokens, embed_dim)# 将 soft_prompt 在序列维度上与 x 拼接# 拼接后的张量维度为 (batch_size, num_tokens + seq_len, embed_dim)x = concatenate([soft_prompt, x], dim=seq_len)return x# 将包含软提示的输入传入模型
output = model(input_with_softprompt(x, soft_prompt))
  1. 软提示参数:

使用 torch.nn.Parameter 将随机初始化的向量注册为可训练参数。这意味着在训练过程中,soft_prompt 中的参数会随梯度更新而优化。

  1. 拼接输入:

函数 input_with_softprompt 接收原始输入 x(通常是嵌入后的 token 序列)和 soft_prompt 张量。通过 concatenate(伪代码中使用此函数代指张量拼接操作),将软提示向量沿着序列长度维度与输入拼接在一起。

  1. 传递给模型:

将包含软提示的输入张量传给模型,以引导模型在执行特定任务(如分类、生成、QA 等)时更好地利用这些可训练的软提示向量。

二、Prompt-tuning实战

预训练模型与分词模型——Qwen/Qwen2.5-0.5B-Instruct
数据集——lyuricky/alpaca_data_zh_51k

2.1、下载模型到本地

# 下载数据集
dataset_file = load_dataset("lyuricky/alpaca_data_zh_51k", split="train", cache_dir="./data/alpaca_data")
ds = load_dataset("./data/alpaca_data", split="train")# 下载分词模型
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
# Save the tokenizer to a local directory
tokenizer.save_pretrained("./local_tokenizer_model")#下载与训练模型
model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path="Qwen/Qwen2.5-0.5B-Instruct",  # 下载模型的路径torch_dtype="auto",low_cpu_mem_usage=True,cache_dir="./local_model_cache"  # 指定本地缓存目录
)

2.2、加载模型与数据集

#加载分词模型
tokenizer_model = AutoTokenizer.from_pretrained("../local_tokenizer_model")# 加载数据集
ds = load_dataset("../data/alpaca_data", split="train[:10%]")# 记载模型
model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path="../local_llm_model/models--Qwen--Qwen2.5-0.5B-Instruct/snapshots/7ae557604adf67be50417f59c2c2f167def9a775",torch_dtype="auto",device_map="cuda:0")

2.3、处理数据

"""
并将其转换成适合用于模型训练的输入格式。具体来说,
它将原始的输入数据(如用户指令、用户输入、助手输出等)转换为模型所需的格式,
包括 input_ids、attention_mask 和 labels。
"""
def process_func(example, tokenizer=tokenizer_model):MAX_LENGTH = 256input_ids, attention_mask, labels = [], [], []instruction = tokenizer("\n".join(["Human: " + example["instruction"], example["input"]]).strip() + "\n\nAssistant: ")if example["output"] is not None:response = tokenizer(example["output"] + tokenizer.eos_token)else:returninput_ids = instruction["input_ids"] + response["input_ids"]attention_mask = instruction["attention_mask"] + response["attention_mask"]labels = [-100] * len(instruction["input_ids"]) + response["input_ids"]if len(input_ids) > MAX_LENGTH:input_ids = input_ids[:MAX_LENGTH]attention_mask = attention_mask[:MAX_LENGTH]labels = labels[:MAX_LENGTH]return {"input_ids": input_ids,"attention_mask": attention_mask,"labels": labels}# 分词
tokenized_ds = ds.map(process_func, remove_columns=ds.column_names)

2.4、Prompt-tuning微调

soft Prompt

# Soft Prompt
config = PromptTuningConfig(task_type=TaskType.CAUSAL_LM, num_virtual_tokens=10) # soft_prompt会随机初始化

Hard Prompt

# Hard Prompt
prompt = "下面是一段人与机器人的对话。"
config = PromptTuningConfig(task_type=TaskType.CAUSAL_LM, prompt_tuning_init=PromptTuningInit.TEXT,prompt_tuning_init_text=prompt,num_virtual_tokens=len(tokenizer_model(prompt)["input_ids"]),tokenizer_name_or_path="../local_tokenizer_model")

加载peft配置

peft_model = get_peft_model(model, config)print(peft_model.print_trainable_parameters())

在这里插入图片描述
可以看到要训练的模型相比较原来的全量模型要少很多

2.5、训练参数配置

# 配置模型参数
args = TrainingArguments(output_dir="chatbot",   # 训练模型的输出目录per_device_train_batch_size=1,gradient_accumulation_steps=4,logging_steps=10,num_train_epochs=1,
)

2.6、开始训练

# 创建训练器
trainer = Trainer(args=args,model=model,train_dataset=tokenized_ds,data_collator=DataCollatorForSeq2Seq(tokenizer_model, padding=True )
)
# 开始训练
trainer.train()

可以看到 ,损失有所下降

在这里插入图片描述

三、模型评估

# 模型推理
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, pipelinemodel = AutoModelForCausalLM.from_pretrained("../local_llm_model/models--Qwen--Qwen2.5-0.5B-Instruct/snapshots/7ae557604adf67be50417f59c2c2f167def9a775", low_cpu_mem_usage=True)
peft_model = PeftModel.from_pretrained(model=model, model_id="./chatbot/checkpoint-643")
peft_model = peft_model.cuda()#加载分词模型
tokenizer_model = AutoTokenizer.from_pretrained("../local_tokenizer_model")
ipt = tokenizer_model("Human: {}\n{}".format("我们如何在日常生活中减少用水?", "").strip() + "\n\nAssistant: ", return_tensors="pt").to(peft_model.device)
print(tokenizer_model.decode(peft_model.generate(**ipt, max_length=128, do_sample=True)[0], skip_special_tokens=True))print("-----------------")
#预训练的管道流
# 构建prompt
ipt = "Human: {}\n{}".format("我们如何在日常生活中减少用水?", "").strip() + "\n\nAssistant: "
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer_model)
output = pipe(ipt, max_length=256, do_sample=True, truncation=True)
print(output)

训练了一轮,感觉效果不大,可以增加训练轮数试试
在这里插入图片描述

四、完整训练代码


from datasets import load_dataset
from peft import PromptTuningConfig, TaskType, PromptTuningInit, get_peft_model, PeftModel
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModelForCausalLM, TrainingArguments, \DataCollatorForSeq2Seq, Trainer# 下载数据集
# dataset_file = load_dataset("lyuricky/alpaca_data_zh_51k", split="train", cache_dir="./data/alpaca_data")
# ds = load_dataset("./data/alpaca_data", split="train")
# print(ds[0])# 下载分词模型
# tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
# Save the tokenizer to a local directory
# tokenizer.save_pretrained("./local_tokenizer_model")#下载与训练模型
# model = AutoModelForCausalLM.from_pretrained(
#     pretrained_model_name_or_path="Qwen/Qwen2.5-0.5B-Instruct",  # 下载模型的路径
#     torch_dtype="auto",
#     low_cpu_mem_usage=True,
#     cache_dir="./local_model_cache"  # 指定本地缓存目录
# )#加载分词模型
tokenizer_model = AutoTokenizer.from_pretrained("../local_tokenizer_model")# 加载数据集
ds = load_dataset("../data/alpaca_data", split="train[:10%]")# 记载模型
model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path="../local_llm_model/models--Qwen--Qwen2.5-0.5B-Instruct/snapshots/7ae557604adf67be50417f59c2c2f167def9a775",torch_dtype="auto",device_map="cuda:0")# 处理数据
"""
并将其转换成适合用于模型训练的输入格式。具体来说,
它将原始的输入数据(如用户指令、用户输入、助手输出等)转换为模型所需的格式,
包括 input_ids、attention_mask 和 labels。
"""
def process_func(example, tokenizer=tokenizer_model):MAX_LENGTH = 256input_ids, attention_mask, labels = [], [], []instruction = tokenizer("\n".join(["Human: " + example["instruction"], example["input"]]).strip() + "\n\nAssistant: ")if example["output"] is not None:response = tokenizer(example["output"] + tokenizer.eos_token)else:returninput_ids = instruction["input_ids"] + response["input_ids"]attention_mask = instruction["attention_mask"] + response["attention_mask"]labels = [-100] * len(instruction["input_ids"]) + response["input_ids"]if len(input_ids) > MAX_LENGTH:input_ids = input_ids[:MAX_LENGTH]attention_mask = attention_mask[:MAX_LENGTH]labels = labels[:MAX_LENGTH]return {"input_ids": input_ids,"attention_mask": attention_mask,"labels": labels}# 分词
tokenized_ds = ds.map(process_func, remove_columns=ds.column_names)prompt = "下面是一段人与机器人的对话。"# prompt-tuning
# Soft Prompt
# config = PromptTuningConfig(task_type=TaskType.CAUSAL_LM, num_virtual_tokens=10) # soft_prompt会随机初始化
# Hard Prompt
config = PromptTuningConfig(task_type=TaskType.CAUSAL_LM, prompt_tuning_init=PromptTuningInit.TEXT,prompt_tuning_init_text=prompt,num_virtual_tokens=len(tokenizer_model(prompt)["input_ids"]),tokenizer_name_or_path="../local_tokenizer_model")peft_model = get_peft_model(model, config)print(peft_model.print_trainable_parameters())# 训练参数args = TrainingArguments(output_dir="./chatbot",per_device_train_batch_size=1,gradient_accumulation_steps=8,logging_steps=10,num_train_epochs=1
)# 创建训练器
trainer = Trainer(model=peft_model, args=args, train_dataset=tokenized_ds,data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer_model, padding=True))# 开始训练
trainer.train()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/63635.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Edge Scdn用起来怎么样?

Edge Scdn:提升网站安全与性能的最佳选择 在当今互联网高速发展的时代,各种网络攻击层出不穷,特别是针对网站的DDoS攻击威胁,几乎每个行业都可能成为目标。为了确保网站的安全性与稳定性,越来越多的企业开始关注Edge …

通信技术以及5G和AI保障电网安全与网络安全

摘 要:电网安全是电力的基础,随着智能电网的快速发展,越来越多的ICT信息通信技术被应用到电力网络。本文分析了历史上一些重大电网安全与网络安全事故,介绍了电网安全与网络安全、通信技术与电网安全的关系以及相应的电网安全标准…

批量提取zotero的论文构建知识库做问答的大模型(可选)——含转存PDF-分割统计PDF等

文章目录 提取zotero的PDF上传到AI平台保留文件名代码分成20个PDF视频讲解 提取zotero的PDF 右键查看目录 发现目录为 C:\Users\89735\Zotero\storage 写代码: 扫描路径‘C:\Users\89735\Zotero\storage’下面的所有PDF文件,全部复制一份汇总到"C:\Users\89735\Downl…

精准采集整车信号:风丘混合动力汽车工况测试

一 背景 混合动力汽车是介于纯电动汽车与燃油汽车两者之间的一种新能源汽车。它既包含纯电动汽车无污染、启动快的优势,又拥有燃油车续航便捷、不受电池容量限制的特点。在当前环境下,混合动力汽车比纯电动汽车更符合目前的市场需求。 然而&#xff0c…

带标题和不带标题的内部表

什么是工作区? 什么是工作区?简单来说,工作区是单行数据。它们应具有与任何内部表相同的格式。它用于一次处理一行内部表中的数据。 内表和工作区的区别 ? 一图胜千言 内表的类型 有两种类型的内表: 带 Header 行…

【图像分类实用脚本】数据可视化以及高数量类别截断

图像分类时,如果某个类别或者某些类别的数量远大于其他类别的话,模型在计算的时候,更倾向于拟合数量更多的类别;因此,观察类别数量以及对数据量多的类别进行截断是很有必要的。 1.准备数据 数据的格式为图像分类数据集…

React系列(八)——React进阶知识点拓展

前言 在之前的学习中,我们已经知道了React组件的定义和使用,路由配置,组件通信等其他方法的React知识点,那么本篇文章将针对React的一些进阶知识点以及React16.8之后的一些新特性进行讲解。希望对各位有所帮助。 一、setState &am…

PCIe_Host驱动分析_地址映射

往期内容 本文章相关专栏往期内容,PCI/PCIe子系统专栏: 嵌入式系统的内存访问和总线通信机制解析、PCI/PCIe引入 深入解析非桥PCI设备的访问和配置方法 PCI桥设备的访问方法、软件角度讲解PCIe设备的硬件结构 深入解析PCIe设备事务层与配置过程 PCIe的三…

【阅读记录-章节6】Build a Large Language Model (From Scratch)

文章目录 6. Fine-tuning for classification6.1 Different categories of fine-tuning6.2 Preparing the dataset第一步:下载并解压数据集第二步:检查类别标签分布第三步:创建平衡数据集第四步:数据集拆分 6.3 Creating data loa…

梳理你的思路(从OOP到架构设计)_简介设计模式

目录 1、 模式(Pattern) 是较大的结构​编辑 2、 结构形式愈大 通用性愈小​编辑 3、 从EIT造形 组合出设计模式 1、 模式(Pattern) 是较大的结构 组合与创新 達芬奇說:簡單是複雜的終極形式 (Simplicity is the ultimate form of sophistication) —Leonardo d…

【libuv】Fargo信令2:【深入】client为什么收不到服务端响应的ack消息

客户端处理server的ack回复,判断链接连接建立 【Fargo】28:字节序列【libuv】Fargo信令1:client发connect消息给到server客户端启动后理解监听read消息 但是,这个代码似乎没有触发ack消息的接收: // 客户端初始化 void start_client(uv_loop_t

Python-基于Pygame的小游戏(贪吃蛇)(一)

前言:贪吃蛇是一款经典的电子游戏,最早可以追溯到1976年的街机游戏Blockade。随着诺基亚手机的普及,贪吃蛇游戏在1990年代变得广为人知。它是一款休闲益智类游戏,适合所有年龄段的玩家,其最初为单机模式,后来随着技术发…

使用k6进行MongoDB负载测试

1.安装环境 安装xk6-mongo扩展 ./xk6 build --with github.com/itsparser/xk6-mongo 2.安装MongoDB 参考Docker安装MongoDB服务-CSDN博客 连接成功后新建test数据库和sample集合 3.编写脚本 test_mongo.js import xk6_mongo from k6/x/mongo;const client xk6_mongo.new…

2024 年最新前端ES-Module模块化、webpack打包工具详细教程(更新中)

模块化概述 什么是模块?模块是一个封装了特定功能的代码块,可以独立开发、测试和维护。模块通过导出(export)和导入(import)与其他模块通信,保持内部细节的封装。 前端 JavaScript 模块化是指…

最小堆及添加元素操作

【小白从小学Python、C、Java】 【考研初试复试毕业设计】 【Python基础AI数据分析】 最小堆及添加元素操作 [太阳]选择题 以下代码执行的结果为? import heapq heap [] heapq.heappush(heap, 5) heapq.heappush(heap, 3) heapq.heappush(heap, 2) heapq.…

【计算机网络】期末考试预习复习|中

作业讲解 转发器、网桥、路由器和网关(4-6) 作为中间设备,转发器、网桥、路由器和网关有何区别? (1) 物理层使用的中间设备叫做转发器(repeater)。 (2) 数据链路层使用的中间设备叫做网桥或桥接器(bridge)。 (3) 网络层使用的中间设备叫做路…

前端工程化-Vue脚手架安装

在现代前端开发中,Vue.js已成为一个流行的框架,而Vue CLI(脚手架)则为开发者提供了一个方便的工具,用于快速创建和管理Vue项目。本文将详细介绍如何安装Vue脚手架,创建新项目以及常见问题的解决方法。 什么…

罗德与施瓦茨ZN-Z129E网络分析仪校准套件具体参数

罗德与施瓦茨ZN-Z129E网络校准件ZN-Z129E网络分析仪校准套件 1,频率范围从9kHz到4GHz(ZNB4),8.5GHz(ZNB8),20GHz(ZNB20),40GHz(ZNB40) 2,动态范围宽,高达140 dB 3,扫描时间短达4ms…

如何为IntelliJ IDEA配置JVM参数

在使用IntelliJ IDEA进行Java开发时,合理配置JVM参数对于优化项目性能和资源管理至关重要。IntelliJ IDEA提供了两种方便的方式来设置JVM参数,以确保你的应用程序能够在最佳状态下运行。本文将详细介绍这两种方法:通过工具栏编辑配置和通过服…

unity is running as administrator 管理员权限问题

每次打开工程弹出unity is running as administrator的窗口 unity版本2022.3.34f1,电脑系统是win 11系统解决方法一:解决方法二: unity版本2022.3.34f1,电脑系统是win 11系统 每次打开工程都会出现unity is running as administr…