如何在冻结的MSA内部更改q,k,v的形状

在冻结多头自注意力(MSA)层的参数的情况下,若希望更改 q(查询)、k(键)、v(值)的形状,可以通过修改这些矩阵的输出维度或重新排列它们的维度,而不需要改变 MSA 内部的参数或对它们进行反向传播更新。这可以通过以下方式实现:

方法 1:使用视图变换或重排维度

通过重新排列 qkv 的维度,直接改变其形状。这种方法对冻结的参数没有影响,只是在其输出上进行操作。

import torch
import torch.nn as nnclass ModifiedMSA(nn.Module):def __init__(self, attention_layer):super().__init__()self.attention_layer = attention_layer  # 引用原始冻结的 MSA 层def forward(self, x):# 获取 MSA 层的 q、k、v 并更改形状B, N, C = x.shapeqkv = self.attention_layer.qkv(x)  # 原始 qkv 的输出qkv = qkv.reshape(B, N, 3, self.attention_layer.num_heads, C // self.attention_layer.num_heads)q, k, v = qkv.permute(2, 0, 3, 1, 4)  # 分割 q, k, v# 更改 q, k, v 的形状,如扩展到更多维度q = q.view(B, self.attention_layer.num_heads, N, -1)  # 改变查询向量形状k = k.permute(0, 1, 3, 2)  # 例如对键进行转置v = v.view(B, -1, self.attention_layer.num_heads * (C // self.attention_layer.num_heads))# 使用修改后的 q, k, v 计算注意力分数attn = (q @ k) * self.attention_layer.scaleattn = attn.softmax(dim=-1)out = (attn @ v).transpose(1, 2).reshape(B, N, C)out = self.attention_layer.proj(out)  # 使用原始冻结的投影层return out# 将 MSA 层替换为 ModifiedMSA 层
for i, block in enumerate(model.blocks):block.attn = ModifiedMSA(block.attn)

这种方式对 qkv 进行了重排和视图变换,可以有效改变它们的形状,适应不同的计算需求,而不会对原始参数产生影响。

方法 2:插入新的层来处理形状变换

如果需要更灵活的变换,可以在 qkv 后插入新的层,比如 nn.Linear 层,用于扩展或压缩维度。这种方式在 MSA 的输出上添加了一层处理,保持了原始 MSA 参数的冻结状态。

class ExtendedMSA(nn.Module):def __init__(self, attention_layer, new_dim):super().__init__()self.attention_layer = attention_layerself.q_proj = nn.Linear(attention_layer.qkv.out_features // 3, new_dim, bias=False)self.k_proj = nn.Linear(attention_layer.qkv.out_features // 3, new_dim, bias=False)self.v_proj = nn.Linear(attention_layer.qkv.out_features // 3, new_dim, bias=False)def forward(self, x):B, N, C = x.shapeqkv = self.attention_layer.qkv(x)qkv = qkv.reshape(B, N, 3, self.attention_layer.num_heads, C // self.attention_layer.num_heads)q, k, v = qkv.permute(2, 0, 3, 1, 4)  # 分割出 q, k, v# 使用新的层改变 q, k, v 的维度q = self.q_proj(q)k = self.k_proj(k)v = self.v_proj(v)# 使用修改后的 q, k, v 继续 MSA 的注意力计算attn = (q @ k.transpose(-2, -1)) * self.attention_layer.scaleattn = attn.softmax(dim=-1)out = (attn @ v).transpose(1, 2).reshape(B, N, -1)out = self.attention_layer.proj(out)return out# 替换模型中的 MSA 层
for i, block in enumerate(model.blocks):block.attn = ExtendedMSA(block.attn, new_dim=64)  # new_dim 设置为新的维度大小

通过插入新的 Linear 层,可以在不更改原始 MSA 内部参数的情况下扩展或压缩 qkv 的形状。

方法 3:增加动态维度处理

如果希望在不同批次或条件下动态调整 qkv 的形状,可以加入自定义的条件逻辑来动态更改维度。

class DynamicMSA(nn.Module):def __init__(self, attention_layer, dynamic_dim_func):super().__init__()self.attention_layer = attention_layerself.dynamic_dim_func = dynamic_dim_funcdef forward(self, x):B, N, C = x.shapeqkv = self.attention_layer.qkv(x)qkv = qkv.reshape(B, N, 3, self.attention_layer.num_heads, C // self.attention_layer.num_heads)q, k, v = qkv.permute(2, 0, 3, 1, 4)# 动态调整 q, k, v 的形状q = q.view(B, self.attention_layer.num_heads, N, -1)k = k.view(B, self.attention_layer.num_heads, -1, self.dynamic_dim_func(N))v = v.view(B, -1, self.attention_layer.num_heads * self.dynamic_dim_func(C // self.attention_layer.num_heads))# 继续原始注意力计算attn = (q @ k) * self.attention_layer.scaleattn = attn.softmax(dim=-1)out = (attn @ v).transpose(1, 2).reshape(B, N, C)out = self.attention_layer.proj(out)return out# 使用动态维度函数替换 MSA 层
for i, block in enumerate(model.blocks):block.attn = DynamicMSA(block.attn, dynamic_dim_func=lambda dim: dim // 2)  # 例如,将维度缩减一半

这种方式可以根据输入的形状动态调整 qkv 的维度,适应更灵活的场景需求。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/56412.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【MongoDB】mongodb | 部署 | 常用命令

一、概述 基于mongodb的tcp连接无数据上报,服务器强踢监测。 物联网项目,tcp协议,基于4G卡,设备由于某些原因会断开重连,但是tcp没有断开,导致tcp持续累加,浪费资源。 建立机制: 当t…

解决一个android service启动无法开文件的问题

问题描述 android hal层一般是通过service给系统提供服务的。一般需要将service配置为开机启动。调试阶段,我直接将service push到板卡上,进行调试,未出现问题无法开的问题。在最后集成完成后,放到板卡上,出现启动无法…

【win10】VMware Workstation 16安装win10专业版及安装VMware Tools操作说明

参考链接 VMware虚拟机安装win10系统教程(巨细)_vmware安装win10-CSDN博客https://blog.csdn.net/gdidea/article/details/129523700 win10专业版安装说明 下载win10安装包 百度网盘 链接: https://pan.baidu.com/s/1kf4ORdXYgcqwAz2j86LSZw?pwdk4…

MySQL-数据库的基础操作 o(´^`)o

文本目录: ❄️一、数据库操作: ☑ 1、查看所有的数据库: ☑ 2、创建数据库: ☑ 3、使用数据库: ☑ 4、删除数据库: ❄️二、常用的数据类型: ➷ 1、数值类型: ➷ 2、字符串类型&a…

【2D/3D-Lidar-SLAM】 Cartographer详细解读

【2D/3D-Lidar-SLAM】 Cartographer详细解读 1. 摘要2. Cartographer系统数据处理流程2.1. 数据获取(Input Sensor Data)2.2 姿态外推器(PoseExtrapolator)2.3 局部建图(Local SLAM) 3. 关键模块实现 3.1 局…

【无标题】react组件封装

子组件制作 import { useState,useRef, useEffect} from "react"const Table (data)> {const {value ,option} dataconsole.log(value)const [stata,setValue] useState()const useRefs useRef(value)useEffect(()> {useRefs.current.value value })c…

MyBatis XML映射文件

XML映射文件 XML映射文件的名称与Mapper接口名称一致,并且将XML映射文件和Mapper接口放置在相同包下(同包同名)XML映射文件的namespace属性为Mapper接口全限定名一致XML映射文件中SQL语句的id与Mapper接口中的方法名一致,并保持返…

某知名国企面试题

引言 金九银十,求职热潮再度来袭。最近,有位同学去一家知名国企应聘,回来后带回了一套面试题。这套面试题非常典型,其中包含了许多供应链金融方面的典型问题。这些问题很有分享的价值,大家也可以先自己独立思考一下&a…

Chromium cookies数据存储位置介绍c++

一、cookies数据库存储位置: C:\Users\Administrator\AppData\Local\Chromium\User Data\Default\Network\Cookies 二 、数据库操作类: net\extras\sqlite\sqlite_persistent_cookie_store.cc net\extras\sqlite\sqlite_persistent_cookie_store.h …

C#读取和写入txt文档(在unity中示例)

本篇内容简单介绍如何在c#中内容读取和写入txt文档 注意:先在Unity的StreamingAssets文件夹中创建一个txt文档 一、读取txt 1.1全部一起读取 private void ReadText01() {string filePath Path.Combine(Application.streamingAssetsPath, "testTXT.txt&qu…

[Java基础] 基本数据类型

[Java基础] 运算符 ​​​​​​​[Java基础] Java HashMap 的数据结构和底层原理 目录 Java基本数据类型 byte short int long float double char boolean 存在的一些坑 最佳实践 常见面试题 Java有哪些基本数据类型? 各基本数据类型所占的内存空间…

Spring 和 javaEE的关系

我的理解: 相当于其实只用javaee的规范其实已经可以直接写后端系统了。但是Spring集成扩展了javaee,提供了一套更方便好用的编程规范,可以更高效便捷的写后端系统。 具体介绍: Java EE(现在称为 Jakarta EE&am…

003 Springboot操作RabbitMQ

Springboot整合RabbitMQ 文章目录 Springboot整合RabbitMQ1.pom依赖2.yml配置3.配置队列、交换机方式一:直接通过配置类配置bean方式二:消息监听通过注解配置 4.编写消息监听发送测试5.其他类型交换机配置1.FanoutExchange2.TopicExchange3.HeadersExcha…

AsyncTask的工作原理和缺陷

AsyncTask的工作原理及其缺陷 AsyncTask是Android平台提供的一个轻量级的异步任务类,它允许开发者在后台线程中执行耗时操作,并在操作完成后将结果回调到主线程以更新UI。AsyncTask内部封装了线程池和Handler机制,简化了多线程编程的复杂性。…

4D-fy: Text-to-4D Generation Using Hybrid Score Distillation Sampling技术路线

这篇文章分为四部分,首先从2021年的CLIP说起。 这篇论文的主要工作是提出了一种名为 CLIP(Contrastive Language-Image Pre-training) 的模型,它通过自然语言监督学习视觉模型,以实现视觉任务的零样本(zer…

20 Shell Script输入与输出

标出输入、标准输出、错误输出 一、程序的基本三个IO流 一)文件描述符 ​ 任何程序在Linux系统中都有3个基本的文件描述符 ​ 比如: ​ cd/proc/$$/fd ​ 进入当前shell程序对于内核在文件系统的映射目录中: [rootlocalhost ~]# cd /proc/$$/fd [rootlocalhos…

springcloud之基于RabbitMQ消息总线方式刷新配置服务

前言 在微服务架构中,为了更方便的向微服务实例广播消息,我们通常会构建一个消息中心,让所有的服务实例都连接上来,而该消息中心所发布的消息都会被微服务实例监听和消费,我们把这种机制叫做消息总线(SpringCloud Bus)…

Web集群服务-代理和负载均衡

1. 概述 1. 用户----->代理--->Web节点,后面只有一个节点,一般使用的是nginx代理功能即可 2. 后面如果是集群需要使用nginx负载均衡功能 2. 代理分类 代理分类方向应用正向代理用户(服务器)-->代理--->外部(某网站)服务器通过代理实现共享上网/访问公网反向代理用…

Linux:进程控制(三)——进程程序替换

目录 一、概念 二、使用 1.单进程程序替换 2.多进程程序替换 3.exec接口 4.execle 一、概念 背景 当前进程在运行的时候,所执行的代码来自于自己的源文件。使用fork创建子进程后,子进程执行的程序中代码内容和父进程是相同的,如果子进…

Python基础语法条件

注释 注释的作用 通过用自己熟悉的语言,在程序中对某些代码进行标注说明,这就是注释的作用,能够大大增强程序的可读性。 注释的分类及语法 注释分为两类:单行注释 和 多行注释。 单行注释 只能注释一行内容,语法如下…