Prefix-Tuning源码解析

Prefix-Tuning源码解析

Prefix-Tuning在PEFT包中的源码实现
改写自Based on https://github.com/THUDM/P-tuning-v2/blob/main/model/prefix_encoder.py

import torch
from transformers import PretrainedConfigclass PrefixEncoder(torch.nn.Module):r'''The torch.nn model to encode the prefixInput shape: (batch-size, prefix-length)Output shape: (batch-size, prefix-length, 2*layers*hidden)'''def __init__(self, config):super().__init__()self.prefix_projection = config.prefix_projectionif self.prefix_projection:# Use a two-layer MLP to encode the prefixself.embedding = torch.nn.Embedding(config.prefix_length, config.hidden_size)self.trans = torch.nn.Sequential(torch.nn.Linear(config.hidden_size, config.encoder_hidden_size),torch.nn.Tanh(),torch.nn.Linear(config.encoder_hidden_size, config.num_hidden_layers * 2 * config.hidden_size))else:self.embedding = torch.nn.Embedding(config.prefix_length, config.num_hidden_layers * 2 * config.hidden_size)def forward(self, prefix: torch.Tensor):if self.prefix_projection:prefix_tokens = self.embedding(prefix)past_key_values = self.trans(prefix_tokens)else:past_key_values = self.embedding(prefix)return past_key_valuesif __name__ == "__main__":configs = {"prefix_length":20,"hidden_size":768,"encoder_hidden_size":768,"num_hidden_layers":12,"prefix_projection":False}prefix_encoder = PrefixEncoder(config=PretrainedConfig.from_dict(configs))print(prefix_encoder)batch_size = 8prefix = torch.arange(20).long().expand(batch_size, -1)print(prefix.shape)output = prefix_encoder(prefix)print(output.shape)

下面我们以T5-large模型为例子:
不考虑Use a two-layer MLP to encode the prefix的话,prefix tuning主要包括以下代码:

class PrefixEncoder(torch.nn.Module):def __init__(self, config):super().__init__()...self.embedding = torch.nn.Embedding(num_virtual_tokens, num_layers * 2 * token_dim) #num_virtual_tokens=20,token_dim=1024,num_layers=24def forward(self, prefix: torch.Tensor):past_key_values = self.embedding(prefix)return past_key_values

得到的PrefixEncoder被传入peft->peft_model.py->prompt_encoder

PrefixEncoder((embedding): Embedding(20, 49152) # 1024*2*24
)

self.prompt_tokens初始化为长度2*20的向量,因为T5有编码器和解码器,需要两次prefix:

self.prompt_tokens[adapter_name] = torch.arange(config.num_virtual_tokens * config.num_transformer_submodules).long() #20*2# tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
#        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
#        36, 37, 38, 39])
prompt_tokens = (self.prompt_tokens[self.active_adapter].unsqueeze(0).expand(batch_size, -1).to(prompt_encoder.embedding.weight.device)) 
prompt_tokens = prompt_tokens[:, : peft_config.num_virtual_tokens]
# 此时prompt_tokens.shape = (batch_size=8, num_virtual_tokens=20)past_key_values = prompt_encoder(prompt_tokens)
torch.Size([8, 20, 49152])

但目前的past_key_values还是所有层的集合,我们需要把past_key_values分解为每一层:

past_key_values = past_key_values.view(batch_size, #8peft_config.num_virtual_tokens, #20peft_config.num_layers * 2, #24*2peft_config.num_attention_heads, #16peft_config.token_dim // peft_config.num_attention_heads, #1024/16)
# torch.Size([8, 20, 48, 16, 64])

因为有编码器和解码器,所以再复制一次

past_key_values = torch.cat([past_key_values, past_key_values], dim=2)
# torch.Size([8, 20, 96, 16, 64])# 重排:torch.Size([96, 8, 16, 20, 64])
# 然后split成一个长度为24的tuple,每个tuple的shape:torch.Size([4, 8, 16, 20, 64])
past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(peft_config.num_transformer_submodules * 2)

也就是说past_key_values是24个层的Prefix embedding,形状为`(num_transformer_submodules * 2, batch_size, num_attention_heads, num_virtual_tokens, token_dim/num_attention_heads])

注意这里*2是因为key+value.

transformers->models->t5->modeling_t5.py->T5Attention类,这里的关键步骤是project函数中的hidden_states = torch.cat([past_key_value, hidden_states], dim=2),注意project函数仅仅用于key和value。

def forward(self,hidden_states,mask=None,key_value_states=None,position_bias=None,past_key_value=None,layer_head_mask=None,query_length=None,use_cache=False,output_attentions=False,):"""Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states)."""# Input is (batch_size, seq_length, dim)# Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length)# past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head)batch_size, seq_length = hidden_states.shape[:2]real_seq_length = seq_lengthif past_key_value is not None:if len(past_key_value) != 2:raise ValueError(f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states")real_seq_length += past_key_value[0].shape[2] if query_length is None else query_lengthkey_length = real_seq_length if key_value_states is None else key_value_states.shape[1]def shape(states):"""projection"""return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)def unshape(states):"""reshape"""return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim)def project(hidden_states, proj_layer, key_value_states, past_key_value):"""projects hidden states correctly to key/query states"""if key_value_states is None:# self-attn# (batch_size, n_heads, seq_length, dim_per_head)hidden_states = shape(proj_layer(hidden_states))elif past_key_value is None:# cross-attn# (batch_size, n_heads, seq_length, dim_per_head)hidden_states = shape(proj_layer(key_value_states))if past_key_value is not None:if key_value_states is None:# self-attn# (batch_size, n_heads, key_length, dim_per_head)# 注意这里是重点:用串联方式hidden_states = torch.cat([past_key_value, hidden_states], dim=2)elif past_key_value.shape[2] != key_value_states.shape[1]:# checking that the `	sequence_length` of the `past_key_value` is the same as# the provided `key_value_states` to support prefix tuning# cross-attn# (batch_size, n_heads, seq_length, dim_per_head)hidden_states = shape(proj_layer(key_value_states))else:# cross-attnhidden_states = past_key_valuereturn hidden_statesreal_seq_length += past_key_value[0].shape[2] if query_length is None else query_length

分别计算query_states、key_states、value_states,用query和key计算attention score,得到score形状为torch.Size([8, 16, 2, 22]),所以输入X可以attend to itself以及prefix。

    # hidden_states shape: torch.Size([8, 2, 1024])   # get query statesquery_states = shape(self.q(hidden_states))  # (batch_size, n_heads, seq_length, dim_per_head) # query_states shape: torch.Size([8, 16, 2, 64])# get key/value stateskey_states = project(hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None)# key_states shape: torch.Size([8, 16, 22, 64])value_states = project(hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None)# value_states shape: torch.Size([8, 16, 22, 64])# compute scores# torch.Size([8, 16, 2, 22])scores = torch.matmul(query_states, key_states.transpose(3, 2))  # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9

接下来就是经典的attention操作了。用attn_weights ([8, 16, 2, 22]) 和value_states ([8, 16, 22, 64])相乘,把22消掉,就是每个输入X的输出了。

# if key and values are already calculated
# we want only the last query position bias
# position_bias.shape: torch.Size([8, 16, 2, 22])scores += position_bias_maskedattn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores)  # (batch_size, n_heads, seq_length, key_length)attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)  # (batch_size, n_heads, seq_length, key_length)attn_output = unshape(torch.matmul(attn_weights, value_states))  # (batch_size, seq_length, dim) torch.Size([8, 2, 1024])attn_output = self.o(attn_output)present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else Noneoutputs = (attn_output,) + (present_key_value_state,) + (position_bias,)if output_attentions:outputs = outputs + (attn_weights,)return outputs

参考

https://huggingface.co/docs/peft/task_guides/seq2seq-prefix-tuning

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/111447.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Linux-Jconsole连接远程服务器

Jconsole连接远程服务器 一、修改jmxremote.password.template文件二、启动jar项目三、jconsole远程连接1、打开的你jconsole2、远程连接 一、修改jmxremote.password.template文件 进去你的/idk/jre/lib/management目录下可以看到jmxremote.password.template文件 修改jmxr…

使用apose.pdf批量导出图片

今天遇到了,需要将pdf文件插到word里,好像word不支持直接插入pdf文件,所以现在通过将pdf转为图片的方式,逐个将图片插入到word。这里使用apose.pdf第三方库,将多个pdf文件读取,然后转为pdf。具体的实现代码…

sqoop 脚本密码管理

1:背景 生产上很多sqoop脚本的密码都是铭文,很不安全,找了一些帖子,自己尝试了下,记录下细节,使用的方式是将密码存在hdfs上然后在脚本里用别名来替代。 2:正文 第一步:创建密码对…

【网络空间实战攻防能力训练】DNS欺骗

DNS欺骗 0x01 环境准备0x02 实验过程1.设置Kali Linux主机、Windows Server 2016服务器与Windows 10在同一个可以上网的网段。分别记录各个主机的IP地址,并检查他们之间能否ping通。配置Windows Server 2016打开其IIS的Web服务。2.在Kali Linux的root终端中打开并编辑ettercap…

中文编程开发语言工具编程实际案例:美发店会员管理系统软件编程实例

中文编程开发语言工具编程实际案例:美发店会员管理系统软件编程实例 中文编程开发语言工具编程实际案例:美发店会员管理系统软件编程实例。 软件功能: 1、系统设置:参数设定,账号及权限设置,系统初始化&a…

minikube创建一个pod并暴露端口(使用docker驱动安装)

因为minikube使用service暴露端口是使用nodeIP:nodePort 而不是 localhost:nodePort 公开访问。我们只能使用kubectl的端口转发功能或者使用iptables的转发功能来实现外网服务暴露。 我这里使用shiro来举例 apiVersion: apps/v1 kind: Deployment metadata:name: shiro550 spe…

人工智能(pytorch)搭建模型20-基于pytorch搭建文本生成视频的生成对抗网络,技术创新点介绍

大家好,我是微学AI,今天给大家介绍一下人工智能(pytorch)搭建模型20-基于pytorch搭建文本生成视频的生成对抗网络,技术创新点介绍,随着人工智能和深度学习技术的飞速发展,文本到视频生成已经成为计算机视觉领域中一个重…

docker mysql 5.7

1.docker 安装mysql 5.7 docker pull mysql:5.72.配置容器MySQL数据、配置、日志挂载宿主机目录 # 宿主机创建数据存放目录映射到容器 mkdir -p /usr/local/docker_data/mysql/data# 宿主机创建配置文件目录映射到容器 mkdir -p /usr/local/docker_data/mysql/conf #(需要在…

Golang协程的概念、用法、场景及案例

在当今的软件开发领域中,高性能和并发性是很重要的。开发人员需要编写能够有效利用多核处理器的程序,以提高应用程序的性能和响应能力。Go语言(Golang)就是一种在这方面非常强大的编程语言,它提供了一种称为协程&#…

CentOS(5)——rpm包和源码包区别

目录 一、简介 二、区别 ①包名称 ②概念 ③优缺点 ④安装位置的区别 ⑤安装位置不同带来的影响 ⑥卸载方式的不同 一、简介 最近在公司内网离线升级Git时,遇见两个概念,分别是使用rpm包安装git,另一个这是编译源码包安装git&#x…

Spring MVC(上)

1、Spring MVC简介: MVC是一种软件架构的思想,将软件按照模型、视图、控制器来划分 M:Model,模型层,指工程中的JavaBean,作用是处理数据 JavaBean分为两类: 一类称为实体类Bean:专…

语法分析出错,不是 GROUP BY 表达式

报错 ### Cause: dm.jdbc.driver.DMException: 第 9 行, 第 69 列[30]附近出现错误: 语法分析出错 ; bad SQL grammar []; nested exception is dm.jdbc.driver.DMException: 第 9 行, 第 69 列[30]附近出现错误: 语法分析出错at org.springframework.jdbc.support.SQLState…

Java后端开发(五)-- 对象转换工具类

为避免返回给前端的字段信息太多,在缓解前、后端通信的带宽压力的前提下,对不必要的字段的信息进行不返回时,entity层对象需要向vo层对象进行转换,同事尽量减少geetter与setter方法的编码。 1. ConvertUtils工具类 import org.slf4j.Logger; import org.slf4j.LoggerFacto…

【Godot引擎开发】简单基础,外加一个小游戏DEMO

博主:_LJaXi 专栏: Godot | 横版游戏开发 Godot 物体规律移动内置虚函数浮点计算浮点数计算数组APIInput单例与自定义单例节点NodeSprite2DArea2DCollisionShape2DKinematicBody2DRigidBody2D Pong游戏场景安排玩家1玩家2小球记分系统文件概要 下面是介绍…

【C++】C++学习(模板+排序+测时)

本文主要记录使用模板函数来编写排序算法,并计算运行时间。 模板函数(Template Function)是一种通用函数,可以在其定义时不指定具体的参数类型,在调用时再根据需要指定具体类型。模板函数可以接受不同类型的参数&…

『力扣刷题本』:相交链表

咳咳,实在抱歉,刚开始心气太高了,叫『每日一题』,我是真的坚持不下了。 经过这次打击,我算是摸明白自己在写博客这件事情上几斤几两了,现在预计一周两更,再慢慢把更新频率提上来。 正在努力补…

DRM中render-node编号的分配

DRM系统 DRM是direct rendering manager的简称。DRM是linux kernel中与负责video cards功能的GPU打交道的子系统。DRM给出了一组API,可以供用户程序来发送命令和数据给GPU设备从而来控制比如display、render等功能。 render-node由来 在以前,DRM子系统…

Java数字处理类-- Math类--数学运算

在Java中提供了一个执行数学基本运算的Math类,该类包括了常用的数学运算方法和常量,包括【三角函数方法】,【指数函数方法】,【取整函数方法】、【取最大值函数方法】、【取最小值函数方法】、【取平均值函数方法】、【对数函数方法】&#x…

MSF入门

漏洞数据库、利用工具集MSF MSF: The Metasploit Framework 简称美少妇 MSF安装 安装平台 Kali Linux: 自带Linux: 阿里云CentOS7安装msfWindows.msi 使用方式 msfconsole 交互终端 msfcli . msfconsole -x"command:..."图形界面: artimate、viper 下面我会用…

ES6 Class和Class继承

1.class的基本语法 class可以理解为是一个语法糖,将js只能通过构造函数创建实例的方法进行了补充 构造函数: function Person ({ name, age18 }) {this.name namethis.age age } new Person({name: 张三}) Class类: class Person {con…