transfomer中Multi-Head Attention的源码实现

简介

Multi-Head Attention是一种注意力机制,是transfomer的核心机制.
在这里插入图片描述
Multi-Head Attention的原理是通过将模型分为多个头,形成多个子空间,让模型关注不同方面的信息。每个头独立进行注意力运算,得到一个注意力权重矩阵。输出的结果再通过线性变换和拼接操作组合在一起。这样可以提高模型的表示能力和泛化性能。
在Multi-Head Attention中,每个头的权重矩阵是随机初始化生成的,并在训练过程中通过梯度下降等优化算法进行更新。通过这种方式,模型可以学习到如何将输入序列的不同部分关联起来,从而捕获更多的上下文信息。
总之,Multi-Head Attention通过将模型分为多个头,形成多个子空间,让模型关注不同方面的信息,提高了模型的表示能力和泛化性能。它的源码实现基于Scaled Dot-Product Attention,通过并行运算和组合输出来实现多头注意力机制。

源码实现:

具体源码及其注释如下,配好环境可直接运行:

import torch
from torch import nnclass MultiheadAttention(nn.Module):def __init__(self,embed_dim,num_heads,att_dropout=0.1,out_dropout=0.1,average_attn_weights=True):super(MultiheadAttention, self).__init__()self.embed_dim = embed_dimself.num_heads = num_headsself.att_dropout = nn.Dropout(att_dropout)self.out_dropout = nn.Dropout(out_dropout)self.average_attn_weights = average_attn_weightsself.head_dim = embed_dim // num_headsself.scale = self.head_dim ** 0.5assert self.embed_dim == self.num_heads * self.head_dim, \'embed_dim <{}> must be divisible by num_heads <{}>'.format(self.embed_dim, self.num_heads)self.fuse_heads = nn.Linear(self.embed_dim, self.embed_dim)def forward(self,query: torch.Tensor,key: torch.Tensor,value: torch.Tensor,identity=None,query_pos=None,key_pos=None):assert query.dim() == 3 and key.dim() == 3 and value.dim() == 3assert key.shape == value.shape, f"key shape {key.shape} does not match value shape {value.shape}"tgt_len, bsz, embed_dim = query.shape  # [查询数量 batch数量 特征维度]src_len, _, _ = key.shape  # [被查询数量,_,_]# 默认和query进行shortcut(要在位置编码前,因为output为输出特征,特征和原特征shortcut,下一层再重新加位置编码,否则不就重了)if identity is None:identity = query# 位置编码if query_pos is not None:query = query + query_posif key_pos is not None:key = key + key_pos# 特征划分为self.num_heads 份 [tgt,b,embed_dim] -> [b,n_h, tgt, d_h]# [n,b,n_h*d_h] -> [b,n_h,n,d_h] 主要是target和source之前的特征匹配和提取, batch和n_h维度不处理query = query.contiguous().view(tgt_len, bsz, self.num_heads, self.head_dim).permute(1, 2, 0, 3)key = key.contiguous().view(src_len, bsz, self.num_heads, self.head_dim).permute(1, 2, 0, 3)value = value.contiguous().view(src_len, bsz, self.num_heads, self.head_dim).permute(1, 2, 0, 3)# [b,n_h,tgt_len,src_len] Scaled Dot-Product Attentionattention = query @ key.transpose(-2, -1)attention /= self.scale  # 参考: https://blog.csdn.net/zwhdldz/article/details/135462127attention = torch.softmax(attention, dim=-1)  # 行概率矩阵attention = self.att_dropout(input=attention)  # 正则化方法 DropKey,用于缓解 Vision Transformer 中的过拟合问题# [b,n_h,tgt_len,d_h] = [b,n_h,tgt_len,src_len] * [b,n_h,src_len,d_h]output = attention @ value# [b,n_h,tgt_len,d_h] -> [b,tgt_len,embed_dim]output = output.permute(0, 2, 1, 3).contiguous().view(tgt_len, bsz, embed_dim)# 头之间通过全连接融合一下output = self.fuse_heads(output)output = self.out_dropout(output)# shortcutoutput = output + identity# 多头head求平均if self.average_attn_weights:attention = attention.sum(dim=1) / self.num_heads# [tgt_len,b,embed_dim],[b,tgt_len,src_len]return output, attentionif __name__ == '__main__':query = torch.rand(size=(10, 2, 64))key = torch.rand(size=(5, 2, 64))value = torch.rand(size=(5, 2, 64))query_pos = torch.rand(size=(10, 2, 64))key_pos = torch.rand(size=(5, 2, 64))att = MultiheadAttention(64, 4)# 返回特征采样结果和attention矩阵output = att(query=query, key=key, value=value,query_pos=query_pos,key_pos=key_pos)pass

具体流程说明:

  1. 将input映射为qkv,如果是cross_attention,q与kv的行数可以不同,但列数(编码维度/通道数)必须相同
  2. q和v附加位置编码
  3. Scaled Dot-Product :通过计算Query和Key之间的点积除以scale得到注意力权重,经过dropout再与Value矩阵相乘得到输出。*scale和dropout的说明参考我的上一篇博客
  4. 输出的结果再通过线性变换融合多头信息。

在实现中,为了提高模型的表示能力和泛化性能,将Scaled Dot-Product Attention过程多次并行运行,形成多个头(head)。每个头分别进行注意力运算,输出的结果再通过线性变换和拼接操作组合在一起。每个头的权重矩阵是随机初始化生成的,并在训练过程中通过梯度下降等优化算法进行更新。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/621074.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

SVN切换账户

前言&#xff08;svn切换&#xff09; 本文章简单写下SVN账户切换操作 linux 1.删除目录 ~/.subversion/auth/ 下的所有文件。 2.再次操作svn时可重新输入用户名和密码。 windows (1)在工程中单击右键,单击"TortoiseSVN"。 (2)选择"Setting"。 (3)选择&quo…

C语言实现快排核心思想(双指针法)

核心代码&#xff1a; 这就是每一趟快排的实现代码&#xff0c;由上面的动图&#xff0c;我们能知道前后指针法的核心是玩好cur和prev这两个指针&#xff0c;具体的逻辑是cur找比key小的值&#xff0c;找到就prev&#xff0c;然后prev和cur的值就进行交换&#xff0c;但是总不能…

统信UOS操作系统上禁用IPv6

原文链接&#xff1a;统信UOS操作系统上禁用IPv6 hello&#xff0c;大家好啊&#xff01;继之前我们讨论了如何在麒麟KYLINOS上禁用IPv6之后&#xff0c;今天我要给大家带来的是在统信UOS操作系统上禁用IPv6的方法。IPv6是最新的网络通信协议&#xff0c;但在某些特定的网络环境…

PiflowX-DorisWrite组件

DorisWrite组件 组件说明 往Doris存储写入数据。 计算引擎 flink 组件分组 doris 端口 Inport&#xff1a;默认端口 outport&#xff1a;默认端口 组件属性 名称展示名称默认值允许值是否必填描述例子fenodesFenodes“”无是Doris FE http地址&#xff0c; 支持多个…

基于企业级SaaS低代码平台的协同制造产品解决方案

万界星空科技低代码平台提供的MES&#xff0c;WMS&#xff0c;QMS等应用&#xff0c;是助力企业从数字化工厂向数字化企业升级的落地管道及载体&#xff0c;能帮助企业在数字化转型的过程中&#xff0c;实现制造企业与其供应链的协同制造。从订单发出、供应商确认、供应商生产、…

使用setdefault撰写文本索引脚本(出自Fluent Python案例)

背景介绍 由于我们主要介绍撰写脚本的方法&#xff0c;所以用一个简单的文本例子进行分析 a[(19,18),(20,53)] Although[(11,1),(16,1),(18,1)] ambiguity[(14,16)] 以上内容可以保存在一个txt文件中&#xff0c;任务是统计文件中每一个词&#xff08;包括字母&#xff0c;数…

评估LLM在细胞数据上的实用性(2)-细胞层面的评估

本文衔接上一篇&#xff1a;评估LLM在细胞数据上的实用性(1)-基本概述 目录 定义参数和任务批次整合多模态整合细胞类型注释 细胞层面的评估批次整合多模态整合细胞类型注释 定义 我们考虑一个预训练LLM表示为 M ( x , θ ) M(x,\theta) M(x,θ)&#xff0c;其基于单细胞数据…

RAG 评估框架 -- RAGAS

原文 引入 RAG&#xff08;Retrieval Augmented Generation&#xff09;的原因 随着ChatGPT的推出&#xff0c;很多人都理所当然直接用LLM当作知识库回答问题。这种想法有两个明显的缺点&#xff1a; LLM无法得知在训练之后所发生的事情&#xff0c;因此无法回答相关的问题存…

从零开始学习Python基础语法:打开编程大门的钥匙

文章目录 一、引言1.1 Python作为一种流行的编程语言的介绍1.2 Python的应用领域和适用性 二、为什么选择Python2.1 Python的简洁性和易读性2.2 Python的跨平台特性 三、Python在数据科学和人工智能领域的应用3.1 第一个Python程序3.1.1 Hello, World!&#xff1a;编写并运行你…

统信UOS_麒麟KYLINOS上使用Remmina远程Windows并传输文件

原文链接&#xff1a;统信UOS/麒麟KYLINOS上使用Remmina远程Windows并传输文件 hello&#xff0c;大家好啊&#xff01;继之前我们讨论了在统信UOS/麒麟KYLINOS与Windows之间通过Open SSH实现文件传输之后&#xff0c;今天我要给大家带来的是如何使用Remmina软件在统信UOS/麒麟…

12.2内核空间基于SPI总线的OLED驱动

在内核空间编写SPI设备驱动的要点 在SPI总线控制器的设备树节点下增加SPI设备的设备树节点&#xff0c;节点中必须包含 reg 属性、 compatible 属性、 spi-max-frequency 属性&#xff0c; reg 属性用于描述片选索引&#xff0c; compatible属性用于设备和驱动的匹配&#xff…

【数据结构】树和二叉树堆(基本概念介绍)

&#x1f308;个人主页&#xff1a;秦jh__https://blog.csdn.net/qinjh_?spm1010.2135.3001.5343&#x1f525; 系列专栏&#xff1a;《数据结构》https://blog.csdn.net/qinjh_/category_12536791.html?spm1001.2014.3001.5482 ​​ 目录 前言 树的概念 树的常见名词 树与…

2024.1.14每日一题

LeetCode 83.删除排序链表中的重复元素 83. 删除排序链表中的重复元素 - 力扣&#xff08;LeetCode&#xff09; 题目描述 给定一个已排序的链表的头 head &#xff0c; 删除所有重复的元素&#xff0c;使每个元素只出现一次 。返回 已排序的链表 。 示例 1&#xff1a; 输…

关闭免费版pycharm社区版双击shift时出现的搜索框

Pycharm 在双击 shift 的时候总是弹出搜索框&#xff0c;但作为中国玩家&#xff0c;经常需要双击 shift 循环切换中英文。这就很困恼。 下面就解决这个问题。单独关闭双击shift的功能。 步骤 1.左上角 File -> Settings 2. 如图&#xff0c;输入‘advan’ 找到高级设置&…

RibbonGroup 添加QRadioButton

RibbonGroup添加QRadioButton&#xff1a; QRadioButton * pRadio new QRadioButton(tr("Radio")); pRadio->setToolTip(tr("Radio")); groupClipboard->addWidget(pRadio); connect(pRadio, SIGNAL(clicked(…

扩展卡尔曼滤波(Extended Kalman Filter, EKF):理论和应用

扩展卡尔曼滤波&#xff08;Extended Kalman Filter, EKF&#xff09;&#xff1a;理论、公式和应用 引言 卡尔曼滤波是一种广泛应用于估计动态系统状态的技术&#xff0c;但当系统的动态模型或测量模型是非线性的时候&#xff0c;传统的卡尔曼滤波方法就显得无能为力。扩展卡…

【保姆级教程|YOLOv8添加注意力机制】【1】添加SEAttention注意力机制步骤详解、训练及推理使用

《博主简介》 小伙伴们好&#xff0c;我是阿旭。专注于人工智能、AIGC、python、计算机视觉相关分享研究。 ✌更多学习资源&#xff0c;可关注公-仲-hao:【阿旭算法与机器学习】&#xff0c;共同学习交流~ &#x1f44d;感谢小伙伴们点赞、关注&#xff01; 《------往期经典推…

SpringBoot+thymeleaf实战遇到的问题

目录 一、控制台&#xff1a; 二、数据库查询异常&#xff1a; 三、前后端错误校验 四、在serviceImp中需要添加一个eq条件&#xff0c;表示和数据库中的哪个字段进行比较&#xff0c;否则会查出所有数据&#xff0c;导致500 五、使用流转换数据更简洁 六、重复报错&…

动态规划篇-03:打家劫舍

198、打家劫舍 状态转移方程 base case 边界问题就是&#xff1a;走到最后一间房子门口也没抢&#xff0c;那么最终抢到的金额为0 明确状态 “原问题和子问题中会变化的变量” 抢到的金额数就是状态&#xff0c;因为随着在每一件房子门口做选择&#xff0c;抢到的金额数会随…

大模型训练营Day3 基于 InternLM 和 LangChain 搭建你的知识库 作业

本篇记录大模型训练营第三次的作业&#xff0c;属实是拖延症本症患者。 主要步骤前面的安装各种包和依赖如前面作业一样&#xff0c;按照文档操作即可&#xff1a; 再按照文档进行各种克隆&#xff0c;把知识库复制到本地&#xff1a; 复制粘贴操作文档中的构建向量数据库的文…