Prompt-Tuning 提示词微调

1. Hard Prompt

定义: Hard prompt 是一种更为具体和明确的提示,要求模型按照给定的信息生成精确的结果,通常用于需要模型提供准确答案的任务.

原理: Prompt Tuning原理如下图所示:冻结主模型全部参数,在训练数据前加入一小段Prompt,只训练Prompt的表示层,即一个Embedding模块。论文实验表明,只要模型规模够大,简单加入 Prompt tokens 进行微调,就能取得很好的效果。

优点: 资源消耗比 Soft Prompt 小,容易收敛。

2. Soft Prompt

**定义:**Soft prompt 通常指的是一种较为宽泛或模糊的提示,允许模型在生成结果时有更大的自由度,通常用于启发模型进行创造性的生成。

缺点: 资源消耗大,参数都是随机初始化的,Loss 会有震荡。

在这里插入图片描述

from datasets import load_from_disk
from transformers import AutoTokenizer, AutoModelForCausalLM, DataCollatorForSeq2Seq
from transformers import pipeline, TrainingArguments, Trainer
from peft import PromptTuningConfig, get_peft_model, TaskType, PromptTuningInit, PeftModel# 分词器
tokenizer = AutoTokenizer.from_pretrained("Langboat/bloom-1b4-zh")# 函数内将instruction和response拆开分词的原因是:
# 为了便于mask掉不需要计算损失的labels, 即代码labels = [-100] * len(instruction["input_ids"]) + response["input_ids"]
def process_func(example):MAX_LENGTH = 256input_ids, attention_mask, labels = [], [], []instruction = tokenizer("\n".join(["Human: " + example["instruction"], example["input"]]).strip() + "\n\nAssistant: ")response = tokenizer(example["output"] + tokenizer.eos_token)input_ids = instruction["input_ids"] + response["input_ids"]attention_mask = instruction["attention_mask"] + response["attention_mask"]labels = [-100] * len(instruction["input_ids"]) + response["input_ids"]if len(input_ids) > MAX_LENGTH:input_ids = input_ids[:MAX_LENGTH]attention_mask = attention_mask[:MAX_LENGTH]labels = labels[:MAX_LENGTH]return {"input_ids": input_ids,"attention_mask": attention_mask,"labels": labels}if __name__ == "__main__":# 加载数据集dataset = load_from_disk("./PEFT/data/alpaca_data_zh")# 处理数据tokenized_ds = dataset.map(process_func, remove_columns = dataset.column_names)# print(tokenizer.decode(tokenized_ds[1]["input_ids"]))# print(tokenizer.decode(list(filter(lambda x: x != -100, tokenized_ds[1]["labels"]))))# 创建模型model = AutoModelForCausalLM.from_pretrained("Langboat/bloom-1b4-zh", low_cpu_mem_usage=True)# 设置 Prompt-Tuning# Soft Prompt# config = PromptTuningConfig(task_type=TaskType.CAUSAL_LM, num_virtual_tokens=10) # soft_prompt会随机初始化# Hard Promptconfig = PromptTuningConfig(task_type = TaskType.CAUSAL_LM,prompt_tuning_init = PromptTuningInit.TEXT,prompt_tuning_init_text = "下面是一段人与机器人的对话。", # 设置hard_prompt的具体内容num_virtual_tokens = len(tokenizer("下面是一段人与机器人的对话。")["input_ids"]),tokenizer_name_or_path = "Langboat/bloom-1b4-zh")model = get_peft_model(model, config) # 生成Prompt-Tuning对应的modelprint(model.print_trainable_parameters())# 训练参数args = TrainingArguments(output_dir = "/tmp_1203",per_device_train_batch_size = 1,gradient_accumulation_steps = 8,logging_steps = 10,num_train_epochs = 1)# trainertrainer = Trainer(model = model,args = args,train_dataset = tokenized_ds,data_collator = DataCollatorForSeq2Seq(tokenizer = tokenizer, padding = True))# 训练模型trainer.train()# 模型推理model = AutoModelForCausalLM.from_pretrained("Langboat/bloom-1b4-zh", low_cpu_mem_usage=True)peft_model = PeftModel.from_pretrained(model = model, model_id = "/tmp_1203/checkpoint-500/")peft_model = peft_model.cuda()ipt = tokenizer("Human: {}\n{}".format("考试有哪些技巧?", "").strip() + "\n\nAssistant: ", return_tensors="pt").to(peft_model.device)print(tokenizer.decode(peft_model.generate(**ipt, max_length=128, do_sample=True)[0], skip_special_tokens=True))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/76366.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Vue生命周期的演变:从Vue 2到Vue 3的深度剖析】

Vue生命周期的演变:从Vue 2到Vue 3的深度剖析 1. 生命周期钩子的概念与意义 Vue框架通过生命周期钩子函数使开发者可以在组件不同阶段执行自定义逻辑。这些钩子函数是Vue组件生命周期中的关键切入点,对于控制组件行为至关重要。 2. Vue 2中的生命周期…

java ai 图像处理

Java AI 图像处理 图像处理是人工智能(AI)领域中非常重要的一个应用方向。通过使用Java编程语言和相应的库,我们可以实现各种图像处理任务,如图像识别、图像分类、图像分割等。本文将介绍一些常见的图像处理算法,并通过…

从 0~1 保姆级 详细版 PostgreSQL 数据库安装教程

PostgreSQL数据库安装 PostgreSQL官网 【PostgreSQL官网】 | 【PostgreSQL安装官网_Windows】 安装步骤 step1: 选择与电脑相对应的PostgreSQL版本进行下载。 step2: 双击打开刚才下载好的文件。 step3: 在弹出的setup窗口中点击 …

Keil MDK中禁用半主机(No Semihosting)

在 ARM 编译器(如 Keil MDK) 中禁用半主机(Semihosting)并实现标准库的基本功能,需要以下步骤: 1. 禁用半主机 #pragma import(__use_no_semihosting) // 禁用半主机模式作用:防止标准库函数&…

github | 仓库权限管理 | 开权限

省流版总结: github 给别人开权限:仓库 -> Setting -> Cllaborate -> Add people GitHub中 将公开仓库改为私有:仓库 -> Setting -> Danger Zone(危险区) ->Change repository visibility( 更改仓…

快速部署大模型 Openwebui + Ollama + deepSeek-R1模型

背景 本文主要快速部署一个带有web可交互界面的大模型的应用,主要用于开发测试节点,其中涉及到的三个组件为 open-webui Ollama deepSeek开放平台 首先 Ollama 是一个开源的本地化大模型部署工具,提供与OpenAI兼容的Api接口,可以快速的运…

极狐GitLab 项目导入导出设置介绍?

极狐GitLab 是 GitLab 在中国的发行版,关于中文参考文档和资料有: 极狐GitLab 中文文档极狐GitLab 中文论坛极狐GitLab 官网 导入导出设置 (BASIC SELF) 导入和导出相关功能的设置。 配置允许的导入源 在从其他系统导入项目之前,必须为该…

信奥还能考吗?未来三年科技特长生政策变化

近年来,科技特长生已成为名校录取的“黄金敲门砖”。 从CSP-J/S到NOI,编程竞赛成绩直接关联升学优势。 未来三年,政策将如何调整?家长该如何提前布局? 一、科技特长生政策趋势:2025-2027关键变化 1. 竞…

AI测试用例生成平台

AI测试用例生成平台 项目背景技术栈业务描述项目展示项目重难点 项目背景 针对传统接口测试用例设计高度依赖人工经验、重复工作量大、覆盖场景有限等行业痛点,基于大语言模型技术实现接口测试用例智能生成系统。 技术栈 LangChain框架GLM-4模型Prompt Engineeri…

操作系统-PV

🧠 背景:为什么会有 PV? 类比:内存(生产者) 和 CPU(消费者) 内存 / IO / 磁盘 / 网络下载 → 不断“生产数据” 例如:读取文件、下载视频、从数据库加载信息 CPU → 负…

工厂方法模式详解及在自动驾驶场景代码示例(c++代码实现)

模式定义 工厂方法模式(Factory Method Pattern)是一种创建型设计模式,通过定义抽象工厂接口将对象创建过程延迟到子类实现,实现对象创建与使用的解耦。该模式特别适合需要动态扩展产品类型的场景。 自动驾驶感知场景分析 自动驾…

基于 S2SH 架构的企业车辆管理系统:设计、实现与应用

在企业运营中,车辆管理是一项重要工作。随着企业规模的扩大,车辆数量增多,传统管理方式效率低下,难以满足企业需求。本文介绍的基于 S2SH 的企业车辆管理系统,借助现代化计算机技术,实现车辆、驾驶员和出车…

IntelliJ IDEA download JDK

IntelliJ IDEA download JDK 自动下载各个版本JDK,步骤 File - Project Structure (快捷键 Ctrl Shift Alt S) 如果下载失败,换个下载站点吧。一般选择Oracle版本,因为java被Oracle收购了 好了。 花里胡哨&#…

MCP协议在纳米材料领域的深度应用:从跨尺度协同到智能研发范式重构

MCP协议在纳米材料领域的深度应用:从跨尺度协同到智能研发范式重构 文章目录 MCP协议在纳米材料领域的深度应用:从跨尺度协同到智能研发范式重构一、MCP协议的技术演进与纳米材料研究的适配性分析1.1 MCP协议的核心架构升级1.2 纳米材料研发的核心挑战与…

OpenAI发布GPT-4.1:开发者专属模型的深度解析 [特殊字符]

最近OpenAI发布了GPT-4.1模型,却让不少人感到困惑。今天我们就来深入剖析这个新模型的关键信息! 重要前提:API专属模型 💻 首先需要明确的是,GPT-4.1仅通过API提供,不会出现在聊天界面中。这是因为该模型主…

DemoGen:用于数据高效视觉运动策略学习的合成演示生成

25年2月来自清华、上海姚期智研究院和上海AI实验室的论文“DemoGen: Synthetic Demonstration Generation for Data-Efficient Visuomotor Policy Learning”。 视觉运动策略在机器人操控中展现出巨大潜力,但通常需要大量人工采集的数据才能有效执行。驱动高数据需…

界面控件DevExpress WPF v25.1新功能预览 - 文档处理类功能升级

DevExpress WPF拥有120个控件和库,将帮助您交付满足甚至超出企业需求的高性能业务应用程序。通过DevExpress WPF能创建有着强大互动功能的XAML基础应用程序,这些应用程序专注于当代客户的需求和构建未来新一代支持触摸的解决方案。 无论是Office办公软件…

Muduo网络库实现 [十六] - HttpServer模块

目录 设计思路 类的设计 模块的实现 公有接口 私有接口 疑问点 设计思路 本模块就是设计一个HttpServer模块,提供便携的搭建http协议的服务器的方法。那么这个模块需要如何设计呢? 这还需要从Http请求说起。 首先从http请求的请求行开始分析&…

多模态记忆融合:基于LSTM的连续场景生成——突破AI视频生成长度限制

一、技术背景与核心挑战 2025年视频生成领域面临的关键难题是长时程连贯性——传统方法在生成超过5分钟视频时会出现场景跳变、物理规则不一致等问题。本研究提出时空记忆融合架构(ST-MFA),通过LSTM记忆门控与多模态对齐技术,在R…

LabVIEW油气井井下集成监测系统

LabVIEW平台开发的油气井井下集成监测系统通过实时监控油气井的井下环境参数,如温度、压力和有害气体含量,有效提高了油气采收率并确保了作业安全。系统利用高精度传感器和强大的数据处理能力,通过综合监测和分析,实现了对油气井环…