Prompt-Tuning源码分析

Prompt-Tuning源码分析

源码

我们这里的代码解析以huggingface peft源码为主
从模型类结构可以看到,Prompt Tuning 只在输入层加入 prompt virtual tokens,其他地方均没有变化,具体可查看 PromptEmbedding 的源码。

伪代码示例

soft_prompt=torch.nn.Parameter(#Make tensor trainable 
torch.rand(num_tokens,embed_dim))#Initialize soft prompt tensor 
def input_with_softprompt(x,soft_prompt):x=concatenate([soft_prompt,x] #Prepend soft prompt to input dim=seq_len)return x 
model(input_with_softprompt(x))

peft源码

class PromptEmbedding(torch.nn.Module):"""```py>>> from peft import PromptEmbedding, PromptTuningConfig>>> config = PromptTuningConfig(...     peft_type="PROMPT_TUNING",...     task_type="SEQ_2_SEQ_LM",...     num_virtual_tokens=20,...     token_dim=768,...     num_transformer_submodules=1,...     num_attention_heads=12,...     num_layers=12,...     prompt_tuning_init="TEXT",...     prompt_tuning_init_text="Predict if sentiment of this review is positive, negative or neutral",...     tokenizer_name_or_path="t5-base",... )>>> # t5_model.shared is the word embeddings of the base model>>> prompt_embedding = PromptEmbedding(config, t5_model.shared)```Input Shape: (`batch_size`, `total_virtual_tokens`)Output Shape: (`batch_size`, `total_virtual_tokens`, `token_dim`)"""def __init__(self, config, word_embeddings):super().__init__()total_virtual_tokens = config.num_virtual_tokens * config.num_transformer_submodulesself.embedding = torch.nn.Embedding(total_virtual_tokens, config.token_dim)if config.prompt_tuning_init == PromptTuningInit.TEXT:from transformers import AutoTokenizertokenizer = AutoTokenizer.from_pretrained(config.tokenizer_name_or_path)init_text = config.prompt_tuning_init_textinit_token_ids = tokenizer(init_text)["input_ids"]# Trim or iterate until num_text_tokens matches total_virtual_tokensnum_text_tokens = len(init_token_ids)if num_text_tokens > total_virtual_tokens:init_token_ids = init_token_ids[:total_virtual_tokens]elif num_text_tokens < total_virtual_tokens:num_reps = math.ceil(total_virtual_tokens / num_text_tokens)init_token_ids = init_token_ids * num_repsinit_token_ids = init_token_ids[:total_virtual_tokens]word_embedding_weights = word_embeddings(torch.LongTensor(init_token_ids)).detach().clone()word_embedding_weights = word_embedding_weights.to(torch.float32)self.embedding.weight = torch.nn.Parameter(word_embedding_weights)def forward(self, indices):# Just get embeddingsprompt_embeddings = self.embedding(indices)return prompt_embeddings

输出的模型权重文件如下所示:

/data/nfs/llm/model/bloomz-560m_PROMPT_TUNING_CAUSAL_LM
├── [ 500]  adapter_config.json
├── [ 33K]  adapter_model.bin
└── [ 111]  README.md0 directories, 3 files

其中,adapter_config.json 为 Prompt Tuning 配置文件;adapter_model.bin 为 Prompt Tuning 权重文件。

推理

from peft import PeftModel, PeftConfigpeft_model_id = f"{model_name_or_path}_{peft_config.peft_type}_{peft_config.task_type}"# 加载PEFT配置
config = PeftConfig.from_pretrained(peft_model_id)# 加载基础模型
model = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path)
# 加载PEFT模型
model = PeftModel.from_pretrained(model, peft_model_id)# Tokenizer编码
inputs = tokenizer(f'{text_column} : {dataset["test"][i]["Tweet text"]} Label : ', return_tensors="pt")# 模型推理
outputs = model.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], max_new_tokens=10, eos_token_id=3)# Tokenizer 解码
print(tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=True))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/119044.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【STM32】RCC时钟模块(使用HAL库)

https://gitee.com/linhir-linhir/stm32-f103-c8/blob/master/STM32%E6%9C%80%E6%96%B0%E5%9B%BA%E4%BB%B6%E5%BA%93v3.5/Libraries/STM32F10x_StdPeriph_Driver/inc/stm32f10x_rcc.h STM32最新固件库v3.5/Libraries/CMSIS/CM3/DeviceSupport/ST/STM32F10x/system_stm32f10x.c…

完成比写得好更重要,先完成初稿再说

我发现自己有个毛病&#xff0c;总想着满意了才动手。于是&#xff0c;经常做到一半跑去看文献&#xff0c;然后陷入文献中觉得这个比自己好&#xff0c;那个比自己好。于是&#xff0c;暂时中断手边工作&#xff0c;最后进度被推迟&#xff0c;甚至啥也没做出来。 今晚再次听…

Centos使用tomcat部署jenkins

jenkins的最新版本已经不在支持jdk8&#xff0c;支持的jdk环境如下&#xff1a; 安装jdk环境 yum -y install java-11-openjdk.x86_64 java-11-openjdk-devel.x86_64安装tomcat tomcat官网 cd /optwget https://dlcdn.apache.org/tomcat/tomcat-9/v9.0.82/bin/apache-tomcat…

【项目管理】如何开展高质量的团队管理

&#x1f449;博__主&#x1f448;&#xff1a;米码收割机 &#x1f449;技__能&#x1f448;&#xff1a;C/Python语言 &#x1f449;公众号&#x1f448;&#xff1a;测试开发自动化【获取源码商业合作】 &#x1f449;荣__誉&#x1f448;&#xff1a;阿里云博客专家博主、5…

0026Java程序设计-中学走读生信息管理系统设计与实现

文章目录 摘要**目录**系统设计开发环境 摘要 目前&#xff0c;中学走读生信息管理系统已经发展成为学校的学生走读管理工作中必不可少的一个组成部分&#xff0c;没有该系统&#xff0c;学生的日常工作就会变得繁琐、效率低下。在信息化的社会发展下&#xff0c;有必要建立一…

Spring和SpringMVC,SpringBoot区别的文章

Spring、SpringMVC和SpringBoot是Java开发中常用的三大框架&#xff0c;它们各有特点&#xff0c;适用于不同的开发场景。下面我们将从它们的基本概念、区别和适用场景等方面进行介绍。 Spring框架 Spring是一个轻量级的开源框架&#xff0c;它最初是为了解决企业应用开发的复…

【VPX610】 青翼科技基于6U VPX总线架构的高性能实时信号处理平台

板卡概述 VPX610是一款基于6U VPX架构的高性能实时信号处理平台&#xff0c;该平台采用2片TI的KeyStone系列多核DSP TMS320C6678作为主处理单元&#xff0c;采用1片Xilinx的Virtex-7系列FPGA XC7VX690T作为协处理单元&#xff0c;具有2个FMC子卡接口&#xff0c;各个处理节点之…

普通人做抖店,需要具备什么条件?一篇详解!

我是电商珠珠 抖音小店的热度一直很高&#xff0c;对于想开店的新手来说&#xff0c;不知道需要什么条件&#xff0c;今天我就来给大家详细的讲一下。 一、营业执照 在入驻抖音小店之前&#xff0c;需要准备一张营业执照。 营业执照一共有两种类型&#xff0c;一种为个体工…

成功率高达99%!美国伊利诺伊大学研究人员实现镱量子比特无损测量

研究人员通过无损测量镱-171量子比特实现了实时控制。&#xff08;图片来源&#xff1a;网络&#xff09; 金属镱-171原子可能在自然界中最接近完美量子比特。最近的一项研究展示了如何使用它们来进行重复的量子测量和量子比特自旋&#xff0c;这一研究成果将有助于可扩展量子…

蓝桥云课--1024 第 2 场算法双周赛

2-铺地板【算法赛】&#xff08;找规律&#xff09; 一、题目要求 二、思路 &#xff08;1&#xff09;因为每块地砖都是2*3的规格&#xff1a; 1.n<2或者m<2的时候&#xff0c;则不能使用上述规格的瓷砖 No 2.n<3或者m<3的时候&#xff0c;也不能使用上述规格…

rust重载比较运算符

要重载比较运算符&#xff0c;需要为类型实现对应的trait。 重载和!&#xff0c;需要实现PartialEq或者Eq 重载<、<、> 、 >&#xff0c;需要实现PartialOrd或者Ord 一、Eq/PartialEq 为什么有两个trait呢&#xff1f; 因为相等关系有两种&#xff1a;一种是完全…

30天精通Nodejs--第二天:模块系统与npm

深入了解Node.js&#xff1a;模块系统与npm Node.js作为一款强大的服务器端JavaScript运行环境&#xff0c;模块系统和npm&#xff08;Node Package Manager&#xff09;是其成功的重要组成部分。为我们平时提供了便捷的工具和资源&#xff0c;使得在Node.js平台上构建应用变得…

现在java和大数据选什么?

现在java和大数据选什么&#xff1f; 到底是选择大数据还是JAVA&#xff1f;”相信这个问题困惑着许多转行待定人士和高校专业待选的学生。 在普通人眼里可能会觉得这两个专业或者行业没啥区别&#xff0c;都是IT里的&#xff0c;能有啥大不同。这是第一层。最近很多小伙伴找我…

【Linux】MAC帧协议 + ARP协议

文章目录 &#x1f4d6; 前言1. 数据链路层2. MAC帧格式3. 再谈局域网4. ARP协议4.1 路由器的转发过程&#xff1a;4.2 ARP协议格式&#xff1a; 5. 如何获得目的MAC地址 &#x1f4d6; 前言 在学完网络层IP协议之后&#xff0c;本章我们将继续向下沉一层&#xff0c;进入到数…

深入浅出排序算法之希尔排序

目录 1. 原理 2. 代码实现 3. 性能分析 1. 原理 希尔排序法又称缩小增量法。希尔排序法的基本思想是&#xff1a;先选定一个整数&#xff0c;把待排序文件中所有记录分成个组&#xff0c;所有距离为的记录分在同一组内&#xff0c;并对每一组内的记录进行排序。然后&#xf…

Flink 维表关联

1、实时查询维表 实时查询维表是指用户在 Flink 算子中直接访问外部数据库&#xff0c;比如用 MySQL 来进行关联&#xff0c;这种方式是同步方式&#xff0c;数据保证是最新的。但是&#xff0c;当我们的流计算数据过大&#xff0c;会对外 部系统带来巨大的访问压力&#xff0…

ui设计要学插画吗?优漫动游

现如今很多UI设计培训班都开设了商业插画的课程&#xff0c;有不少同学表示真的要学吗&#xff1f;商业插画都有什么用处呢&#xff1f;今天我们就来给大家介绍一下商业插画在UI设计中的运用。 ui设计要学插画吗&#xff1f;   商业插画属于实用型插画&#xff0c;是一种…

详解预处理(1)

目录 预定义符号 预处理指令#define #define定义符号 #define定义宏 #define替换规则 #和##&#xff08;C语言预处理操作符&#xff09; # ## 带副作用的宏参数 宏和函数的对比 命名约定 在之前我们学习了一个文本文件.c生成一个可执行程序。今天我们详细讲解其中的…

腾讯云国际站服务器端口开放失败怎么办?

腾讯云服务器是腾讯公司推出的一种云服务&#xff0c;用户能够经过这种方式在互联网上进行数据存储和计算。然而&#xff0c;用户在运用腾讯云服务器时或许会遇到各种问题&#xff0c;其间端口敞开失利是一个常见问题。本文将具体介绍如何解决腾讯云服务器端口敞开失利的问题。…