【深度学习中的注意力机制10】11种主流注意力机制112个创新研究paper+代码——交叉注意力(Cross-Attention)

【深度学习中的注意力机制10】11种主流注意力机制112个创新研究paper+代码——交叉注意力(Cross-Attention)

【深度学习中的注意力机制10】11种主流注意力机制112个创新研究paper+代码——交叉注意力(Cross-Attention)


文章目录

  • 【深度学习中的注意力机制10】11种主流注意力机制112个创新研究paper+代码——交叉注意力(Cross-Attention)
  • 1. 交叉注意力的起源与提出
  • 2. 交叉注意力的原理
  • 3. 交叉注意力的数学表示
  • 4. 交叉注意力的应用场景与发展
  • 5. 代码实现
  • 6. 代码解释
  • 7. 总结


欢迎宝子们点赞、关注、收藏!欢迎宝子们批评指正!
祝所有的硕博生都能遇到好的导师!好的审稿人!好的同门!顺利毕业!

大多数高校硕博生毕业要求需要参加学术会议,发表EI或者SCI检索的学术论文会议论文:
可访问艾思科蓝官网,浏览即将召开的学术会议列表。会议入口:https://ais.cn/u/mmmiUz

1. 交叉注意力的起源与提出

交叉注意力(Cross-Attention)是在深度学习中提出的一种重要注意力机制,用于在多个输入之间建立关联,主要用于多模态任务中(如图像和文本、视频和音频的联合处理)。

与常规的自注意力机制不同,交叉注意力专注于从两个不同的输入特征空间中提取和结合关键信息。这种机制最初在自然语言处理和计算机视觉的融合任务中得到应用,例如在多模态Transformer、机器翻译和图像-文本任务(如CLIP、DALL·E、VQA等)中。

  • 提出背景:交叉注意力通常用于处理两种不同类型的数据,通过这种机制,一个输入可以对另一个输入进行查询,捕捉和增强跨模态之间的关联。相比自注意力(仅在同一个输入中找到相关性),交叉注意力能够有效地捕捉多模态数据的交互信息。

2. 交叉注意力的原理

交叉注意力的核心思想是将一个输入(例如图像)作为查询(Query),另一个输入(例如文本)作为键(Key)和值(Value),通过注意力机制让查询能够从键和值中选择和关注相关信息。

交叉注意力的步骤:

  • 查询、键、值的生成: 假设有两个不同的输入数据 X1 和 X2,分别生成对应的 Query、Key 和 Value 矩阵。对于 X1,我们可以生成 Query 矩阵,而对于 X2,则可以生成 Key 和 Value 矩阵。
  • 注意力计算: 与自注意力类似,交叉注意力通过计算 Query 和 Key 的相似性来获得注意力权重:
    在这里插入图片描述
    其中 Q 来自 X1,而 K 和 V 来自 X2 。通过这种计算,Query 可以从X2 中提取与其最相关的信息,这种机制实现了两个输入数据之间的特征融合和信息传递。
  • 权重与输出: 计算出的注意力权重应用到 X2的 Value 矩阵上,得到 X1在
    X2上的相关信息。这种机制实现了两个输入数据之间的特征融合和信息传递。

3. 交叉注意力的数学表示

假设有两个输入特征 X 1 ∈ R T 1 × d X_1∈R^{T_1×d} X1RT1×d X 2 ∈ R T 2 × d X_2∈R^{T_2×d} X2RT2×d,其中 T 1 T_1 T1 T 2 T_2 T2分别表示两个输入的长度(如序列长度或特征维度), d d d 表示特征维度。

Query、Key 和 Value 的生成:

  • 对于 X 1 X_1 X1:生成查询矩阵 Q = W q X 1 Q=W_qX_1 Q=WqX1
  • 对于 X 2 X_2 X2:生成键矩阵 K = W k X 2 K=W_kX_2 K=WkX2和值矩阵 V = W v X 2 V=W_vX_2 V=WvX2

注意力计算:
在这里插入图片描述
其中, W q W_q Wq W k W_k Wk W v W_v Wv ∈ R d × d ∈R^{d×d} Rd×d是线性变换矩阵, d d d 是键的维度。

结果输出: 注意力权重应用于 V V V 后的结果,即:
在这里插入图片描述

4. 交叉注意力的应用场景与发展

交叉注意力在以下场景中得到广泛应用:

  • 多模态学习:交叉注意力在视觉和语言任务中的多模态联合建模中尤为常见,如图像与文本的对齐(CLIP)、视觉问答(VQA)和跨模态生成任务(如DALL·E)。
  • 机器翻译:交叉注意力在Transformer中的"解码器"部分用于让生成的序列(目标语言)参考源语言的表示,这大大提高了翻译质量。
  • Transformer架构的扩展:在诸如BERT、GPT等基于Transformer的模型中,交叉注意力也被用于各种任务,例如文本生成、序列到序列任务等。

发展过程中,交叉注意力机制已经被改进和扩展。例如,层次化交叉注意力(Hierarchical Cross-Attention)通过在不同层次上融合多模态信息,进一步提升了模型在多模态任务上的性能。

5. 代码实现

下面是一个基于PyTorch的交叉注意力机制的简单实现,用于展示如何在两个不同的输入(例如图像和文本)之间计算交叉注意力。

import torch
import torch.nn as nnclass CrossAttention(nn.Module):def __init__(self, dim, num_heads=8, dropout=0.1):super(CrossAttention, self).__init__()self.num_heads = num_headsself.dim = dimself.head_dim = dim // num_headsassert self.head_dim * num_heads == dim, "dim must be divisible by num_heads"# 线性变换,用于生成 Q, K, V 矩阵self.q_proj = nn.Linear(dim, dim)self.k_proj = nn.Linear(dim, dim)self.v_proj = nn.Linear(dim, dim)# 输出的线性变换self.out_proj = nn.Linear(dim, dim)self.dropout = nn.Dropout(dropout)self.softmax = nn.Softmax(dim=-1)def forward(self, x1, x2):# x1 是 Query,x2 是 Key 和 ValueB, T1, C = x1.shape  # x1 的形状: [batch_size, seq_len1, dim]_, T2, _ = x2.shape  # x2 的形状: [batch_size, seq_len2, dim]# 生成 Q, K, V 矩阵Q = self.q_proj(x1).view(B, T1, self.num_heads, self.head_dim).transpose(1, 2)K = self.k_proj(x2).view(B, T2, self.num_heads, self.head_dim).transpose(1, 2)V = self.v_proj(x2).view(B, T2, self.num_heads, self.head_dim).transpose(1, 2)# 计算注意力得分attn_scores = (Q @ K.transpose(-2, -1)) / (self.head_dim ** 0.5)attn_weights = self.softmax(attn_scores)  # 注意力权重attn_weights = self.dropout(attn_weights)  # dropout 防止过拟合# 使用注意力权重加权值矩阵attn_output = attn_weights @ Vattn_output = attn_output.transpose(1, 2).contiguous().view(B, T1, C)# 输出线性变换output = self.out_proj(attn_output)return output# 测试交叉注意力机制
if __name__ == "__main__":B, T1, T2, C = 2, 10, 20, 64  # batch_size, seq_len1, seq_len2, channelsx1 = torch.randn(B, T1, C)  # Query 输入x2 = torch.randn(B, T2, C)  # Key 和 Value 输入cross_attn = CrossAttention(dim=C, num_heads=4)output = cross_attn(x1, x2)print("输出形状:", output.shape)  # 输出应该为 [batch_size, seq_len1, channels]

6. 代码解释

CrossAttention 类:该类实现了交叉注意力机制,允许将两个不同的输入(x1x2)进行交叉信息融合。

  • q_proj, k_proj, v_proj:三个线性层,用于将输入分别映射到 Query、Key 和 Value 空间。
  • num_headshead_dim:定义了多头注意力机制的头数和每个头的维度。

forward 函数:实现前向传播过程。

  • Q, K, V:分别从 x1x2 中生成 Query、Key 和 Value 矩阵,形状为 [batch_size, num_heads, seq_len, head_dim]
  • attn_scores:计算 Query 和 Key 的点积,得到注意力得分。
  • attn_weights:通过 softmax 对得分进行归一化,得到注意力权重。
  • attn_output:利用注意力权重对 Value 矩阵进行加权求和,得到最终的注意力输出。

测试部分:随机生成两个输入张量 x1x2,并测试交叉注意力的输出形状,确保与预期一致。

7. 总结

交叉注意力在多模态学习中起到了至关重要的作用,能够有效融合不同类型的数据,使得模型可以同时处理图像、文本等多种信息。通过捕捉模态之间的相关性,交叉注意力为多模态任务中的特征融合提供了强大的工具。

欢迎宝子们点赞、关注、收藏!欢迎宝子们批评指正!
祝所有的硕博生都能遇到好的导师!好的审稿人!好的同门!顺利毕业!

大多数高校硕博生毕业要求需要参加学术会议,发表EI或者SCI检索的学术论文会议论文:
可访问艾思科蓝官网,浏览即将召开的学术会议列表。会议入口:https://ais.cn/u/mmmiUz

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/58709.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

安宝特案例 | AR技术在院外心脏骤停急救中的革命性应用

00 案例背景 在院外心脏骤停 (OHCA) 的突发救援中,时间与效率直接决定着患者的生命。传统急救模式下,急救人员常通过视频或电话与医院医生进行沟通,以描述患者状况并依照指令行动。然而,这种信息传递方式往往因信息不完整或传递延…

测试自动化工具的横向对比

MicroAgent:这个AI智能体一键写代码并且自动测试!它比 Aider 更好吗? 待研究:https://blog.csdn.net/gitblog_00054/article/details/139541862 playright:这是一个python驱动的自动化框架,部署简单&#…

第八天: C语言深度探索:指针与函数的奥秘

1 字符指针与字符串处理 字符指针与数组的关系 在C语言中,字符串可以被视为字符数组。数组名可以看作是指向其第一个元素的指针。因此,我们可以使用相同类型的指针来访问数组中的每个元素。这种对应关系使得字符指针在处理字符串时显得特别方便。 c …

C++初阶——类与对象(上篇)

一、写在前面 类与对象是C不同于C语言的一个板块,内容很多,笔者把这部分分为三篇博客来讲解,希望能够帮助各位读者更容易地理解这些知识点。弄清楚这一部分之后,C就算是成功入门了。 二、面向过程和面向对象 C语言就是典型的面向…

Java如何实现PDF转高质量图片

大家好,我是 V 哥。在Java中,将PDF文件转换为高质量的图片可以使用不同的库,其中最常用的库之一是 Apache PDFBox。通过该库,你可以读取PDF文件,并将每一页转换为图像文件。为了提高图像的质量,你可以指定分…

论文略读:OneChart: Purify the Chart Structural Extraction via One Auxiliary Token

2024 旷视的work 图表解析模型 1 背景 对于之前的视觉语言模型,论文认为其有两点不足需要改进: 需要充分训练一个真正会看 chart 的 vision encoder单纯对文本输出算交叉熵损失,并不是最优的(如上图所示,当ground-tr…

STM32CubeMX学习(三) SPI+DMA通信

STM32CubeMX学习(三) SPIDMA通信 一、简介二、新建STM32CubeMX项目并使用外部时钟三、SPI3配置四、相关代码五、测试 一、简介 本文将基于STM32F103RCT芯片介绍如何在STM32CubeMXKEIL5开发环境下进行SPIDMA通信。 操作系统:WIN10 x64硬件电…

iOS静态库(.a)及资源文件的生成与使用详解(OC版本)

引言 iOS静态库(.a)及资源文件的生成与使用详解(Swift版本)_xcode 合并 .a文件-CSDN博客 在前面的博客中我们已经介绍了关于iOS静态库的生成步骤以及关于资源文件的处理,在本篇博客中我们将会以Objective-C为基础语言…

Python爬虫:在1688上“拍立淘”——按图索骥的奇妙之旅

想象一下,你是一名古代的侦探,手中握着一张神秘的藏宝图,在1688的茫茫商品海洋中寻找与之匹配的宝藏。今天,我们将一起化身为代码界的“拍立淘”专家,使用Python爬虫技术,通过API接口按图搜索商品。准备好你…

如何在小红书发布笔记时显示外地IP地址

小红书平台在发布笔记时显示IP地址可能是由于网络爬虫或者某些技术手段抓取数据时所导致的。为了保护用户隐私和安全,显示外地IP地址,可以尝试以下几种方法: 1.检查发布环境: 确保你是在一个安全、可信的网络环境下发布笔记&…

Linux中查询Redis中的key和value(没有可视化工具)

1.进入redis安装目录 进入redis安装目录,找到redis-cli(redis的客户端) 2.登录redis客户端 登录redis的客户端,格式:redis-cli -h [host] -p [port] -a [password],懂的都懂!!! ./redis-cli -h 192.168.8.101 -p 6380 -a xxxx登录成功后就这样子 3.查看redis中所有的key和…

Unity Editor 快速移动资源

Editor 快速移动资源 🍔使用场景🌭功能 🍔使用场景 一般想要移动一个资源到另一个目录的办法是选中资源拖拽过去, 但在一个比较大的项目中你得一直拖啊拖直到找到那个目录 🤯。 使用本插件就可以省去拖拽的步骤&#…

特斯联巨亏数十亿:毛利率剧烈波动下滑,高管动荡引发关注

《港湾商业观察》施子夫 近期,重庆特斯联智慧科技股份有限公司(以下简称,特斯联)递表港交所,联席保荐机构中信证券和海通国际。 此番闯关港交所,特斯联三年半巨亏超70亿元,公司何时能扭亏为盈…

javaweb----VS code

前端开发神器:VS Code → 速度快、体积小、插件多 VS Code 安装官网:https://code.visualstudio.com/download VS Code一些必备的插件安装: 1、Chinese (Simplified) 简体中文 2、Code Spell Checker 检查拼写 3、HTML CSS Support 4…

使用 Kafka 和 MinIO 实现人工智能数据工作流

MinIO Enterprise Object Store 是用于创建和执行复杂数据工作流的基础组件。此事件驱动功能的核心是使用 Kafka 的 MinIO 存储桶通知。MinIO Enterprise Object Store 为所有 HTTP 请求(如 PUT、POST、COPY、DELETE、GET、HEAD 和 CompleteMultipartUpload&#xf…

Linux: network: hw csum failure

文章目录 简介openvswitchifb: fix packets checksummellanoxiommu=ptmellanox 2mellanox 3建议简介 这里总结一下几个checksum的问题,仅供参考。需要看所使用的系统是否已经有了相应的fix。也可能是一个新问题,如果是新问题,恭喜发现了新的宝藏。 openvswitch Fix setti…

【Python】数据容器详解:列表、元组、字典与集合的推导式与公共方法

目录 🍔 列表集合字典的推导式 1.1 什么是推导式 1.2 为什么需要推导式 1.3 列表推导式 1.4 列表推导式 if条件判断 1.5 for循环嵌套列表推导式 1.6 字典推导式 1.7 集合推导式 🍔 数据序列中的公共方法 2.1 什么是公共方法 2.2 常见公共方法…

【PythonWeb开发】Flask-RESTful视图类基础知识

flask_restful 是一个扩展库,它为 Flask 提供了快速构建 RESTful API 的功能。使用 flask_restful 可以简化 RESTful API 的开发过程,减少样板代码,并且提供了一些高级特性,如 HTTP 方法的映射、资源路由的定义等。 在flask_restf…

1 C++ 编译属性 __attribute__((X))

__attribute__是GNU对标准C的扩展,可以用来设置函数属性(Function Attribute)、变量属性(Variable Attribute)和类型属性(Type Attribute)。 __attribute__使用 1 __attribute__((used))函数…

uniapp 使用uni.getRecorderManager录音,wav格式采样率低于44100,音频播放不了问题解决

如题:uniapp开发app端,使用uni.getRecorderManager录wav格式音频,采样率8000/16000都无法播放,44100可以播放。但由于项目需求需要录制采样率为8000的音频,于是引用了如下插件 插件地址(具体可以参考该插件的使用说明…