Vision Transformers的注意力层概念解释和代码实现

2017年推出《Attention is All You Need》以来,transformers 已经成为自然语言处理(NLP)的最新技术。2021年,《An Image is Worth 16x16 Words》,成功地将transformers 用于计算机视觉任务。从那时起,许多基于transformers的计算机视觉体系结构被提出。

本文将深入探讨注意力层在计算机视觉环境中的工作原理。我们将讨论单头注意力和多头注意力。它包括注意力层的代码,以及基础数学的概念解释。

在NLP应用中,注意力通常被描述为句子中单词(标记)之间的关系。而在计算机视觉应用程序中,注意力关注图像中patches (标记)之间的关系。

有多种方法可以将图像分解为一系列标记。原始的ViT²将图像分割成小块,然后将小块平摊成标记。《token -to- token ViT》³开发了一种更复杂的从图像创建标记的方法。

点积注意力

《Attention is All You Need》中定义的点积(相当于乘法)注意力是目前我们最常见也是最简单的一种中注意力机制,他的代码实现非常简单:

classAttention(nn.Module):
def__init__(self,
dim: int,
chan: int,
num_heads: int=1,
qkv_bias: bool=False,
qk_scale: NoneFloat=None):
""" Attention ModuleArgs:dim (int): input size of a single tokenchan (int): resulting size of a single token (channels)num_heads(int): number of attention heads in MSAqkv_bias (bool): determines if the qkv layer learns an addative biasqk_scale (NoneFloat): value to scale the queries and keys by;if None, queries and keys are scaled by ``head_dim ** -0.5``"""
super().__init__()
## Define Constants
self.num_heads=num_heads
self.chan=chan
self.head_dim=self.chan//self.num_heads
self.scale=qk_scaleorself.head_dim**-0.5
assertself.chan%self.num_heads==0, '"Chan" must be evenly divisible by "num_heads".'
## Define Layers
self.qkv=nn.Linear(dim, chan*3, bias=qkv_bias)
#### Each token gets projected from starting length (dim) to channel length (chan) 3 times (for each Q, K, V)
self.proj=nn.Linear(chan, chan)
defforward(self, x):
B, N, C=x.shape
## Dimensions: (batch, num_tokens, token_len)
## Calcuate QKVs
qkv=self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
#### Dimensions: (3, batch, heads, num_tokens, chan/num_heads = head_dim)
q, k, v=qkv[0], qkv[1], qkv[2]
## Calculate Attention
attn= (q*self.scale) @k.transpose(-2, -1)
attn=attn.softmax(dim=-1)
#### Dimensions: (batch, heads, num_tokens, num_tokens)
## Attention Layer
x= (attn@v).transpose(1, 2).reshape(B, N, self.chan)
#### Dimensions: (batch, heads, num_tokens, chan)
## Projection Layers
x=self.proj(x)
## Skip Connection Layer
v=v.transpose(1, 2).reshape(B, N, self.chan)
x=v+x
#### Because the original x has different size with current x, use v to do skip connection
returnx

 单头注意力

对于单个注意力头,让我们逐步了解向前传递每一个patch,使用7 * 7=49作为起始patch大小(因为这是T2T-ViT模型中的起始标记大小)。通道数64这也是T2T-ViT的默认值。然后假设有100标记,并且使用批大小为13进行前向传播(选择这两个数值是为了不会与任何其他参数混淆)。

# Define an Input
token_len=7*7
channels=64
num_tokens=100
batch=13
x=torch.rand(batch, num_tokens, token_len)
B, N, C=x.shape
print('Input dimensions are\n\tbatchsize:', x.shape[0], '\n\tnumber of tokens:', x.shape[1], '\n\ttoken size:', x.shape[2])
# Define the Module
A=Attention(dim=token_len, chan=channels, num_heads=1, qkv_bias=False, qk_scale=None)
A.eval();

 输入的维度是这样的:

Input dimensions are

  batchsize: 13

  number of tokens: 100

  token size: 49

 根据查询、键和值矩阵定义的。第一步是通过一个可学习的线性层来计算这些。qkv_bias项表示这些线性层是否有偏置项。这一步还将标记的长度从输入49更改为chan参数(64)。

qkv=A.qkv(x).reshape(B, N, 3, A.num_heads, A.head_dim).permute(2, 0, 3, 1, 4)
q, k, v=qkv[0], qkv[1], qkv[2]
print('Dimensions for Queries are\n\tbatchsize:', q.shape[0], '\n\tattention heads:', q.shape[1], '\n\tnumber of tokens:', q.shape[2], '\n\tnew length of tokens:', q.shape[3])
print('See that the dimensions for queries, keys, and values are all the same:')
print('\tShape of Q:', q.shape, '\n\tShape of K:', k.shape, '\n\tShape of V:', v.shape)

 可以看到 查询、键和值的维度是相同的,13代表批次,1是我们的注意力头数,100是我们输入的标记长度(序列长度),64是我们的通道数。

Dimensions for Queries are

  batchsize: 13

  attention heads: 1

  number of tokens: 100

  new length of tokens: 64

See that the dimensions for queries, keys, and values are all the same:

  Shape of Q: torch.Size([13, 1, 100, 64])

  Shape of K: torch.Size([13, 1, 100, 64])

  Shape of V: torch.Size([13, 1, 100, 64])

 我们看看可注意力是如何计算的,它被定义为:

 

Q、K、V分别为查询、键和值;dₖ是键的维数,它等于键标记的长度,也等于键的长度。

第一步是计算:

 

然后是

最后

Q·K的矩阵乘法看起来是这样的

 这些就是我们注意力的主要部分,代码是这样的

attn= (q*A.scale) @k.transpose(-2, -1)print('Dimensions for Attn are\n\tbatchsize:', attn.shape[0], '\n\tattention heads:', attn.shape[1], '\n\tnumber of tokens:', attn.shape[2], '\n\tnumber of tokens:', attn.shape[3])

Dimensions for Attn are

  batchsize: 13

  attention heads: 1

  number of tokens: 100

  number of tokens: 100

 

 下一步就是计算A的softmax,这不会改变它的形状。

attn=attn.softmax(dim=-1)

 最后,我们计算出A·V=x:

x=attn@vprint('Dimensions for x are\n\tbatchsize:', x.shape[0], '\n\tattention heads:', x.shape[1], '\n\tnumber of tokens:', x.shape[2], '\n\tlength of tokens:', x.shape[3])

 就得到了我们最终的结果

Dimensions for x are

  batchsize: 13

  attention heads: 1

  number of tokens: 100

  length of tokens: 64

 因为只有一个头,所以我们去掉头数 1

x = x.transpose(1, 2).reshape(B, N, A.chan)

 然后我们将x输入一个可学习的线性层,这个线性层不会改变它的形状。

x=A.proj(x)

最后我们实现的跳过连接

orig_shape= (batch, num_tokens, token_len)
curr_shape= (x.shape[0], x.shape[1], x.shape[2])
v=v.transpose(1, 2).reshape(B, N, A.chan)
v_shape= (v.shape[0], v.shape[1], v.shape[2])
print('Original shape of input x:', orig_shape)
print('Current shape of x:', curr_shape)
print('Shape of V:', v_shape)
x=v+x
print('After skip connection, dimensions for x are\n\tbatchsize:', x.shape[0], '\n\tnumber of tokens:', x.shape[1], '\n\tlength of tokens:', x.shape[2])

Original shape of input x: (13, 100, 49)

Current shape of x: (13, 100, 64)

Shape of V: (13, 100, 64)

After skip connection, dimensions for x are

  batchsize: 13

  number of tokens: 100

  length of tokens: 64

 

 多头注意力

我们可以扩展到多头注意。在计算机视觉中,这通常被称为多头自注意力(MSA)。我们不会详细介绍所有步骤,而是关注矩阵形状不同的地方。

对于多头的注意力,注意力头的数量必须可以整除以通道的数量,所以在这个例子中,我们将使用4个注意头。

# Define an Input
token_len=7*7
channels=64
num_tokens=100
batch=13
num_heads=4
x=torch.rand(batch, num_tokens, token_len)
B, N, C=x.shape
print('Input dimensions are\n\tbatchsize:', x.shape[0], '\n\tnumber of tokens:', x.shape[1], '\n\ttoken size:', x.shape[2])
# Define the Module
MSA=Attention(dim=token_len, chan=channels, num_heads=num_heads, qkv_bias=False, qk_scale=None)
MSA.eval();

 

Input dimensions are

  batchsize: 13

  number of tokens: 100

  token size: 49

计算查询、键和值的过程与单头的过程相同。但是可以看到标记的新长度是chan/num_heads。Q、K和V矩阵的总大小没有改变;它们的内容只是分布在头部维度上。你可以把它看作是将单个矩阵分割为多个:

 我们将子矩阵表示为Qₕ对于查询头i。

qkv=MSA.qkv(x).reshape(B, N, 3, MSA.num_heads, MSA.head_dim).permute(2, 0, 3, 1, 4)
q, k, v=qkv[0], qkv[1], qkv[2]
print('Head Dimension = chan / num_heads =', MSA.chan, '/', MSA.num_heads, '=', MSA.head_dim)
print('Dimensions for Queries are\n\tbatchsize:', q.shape[0], '\n\tattention heads:', q.shape[1], '\n\tnumber of tokens:', q.shape[2], '\n\tnew length of tokens:', q.shape[3])
print('See that the dimensions for queries, keys, and values are all the same:')
print('\tShape of Q:', q.shape, '\n\tShape of K:', k.shape, '\n\tShape of V:', v.shape)

 

Head Dimension = chan / num_heads = 64 / 4 = 16

Dimensions for Queries are

  batchsize: 13

  attention heads: 4

  number of tokens: 100

  new length of tokens: 16

See that the dimensions for queries, keys, and values are all the same:

  Shape of Q: torch.Size([13, 4, 100, 16])

  Shape of K: torch.Size([13, 4, 100, 16])

  Shape of V: torch.Size([13, 4, 100, 16])

 这里需要注意的是

 我们需要除以头数。num_heads = 4个不同的Attn矩阵,看起来像:

attn= (q*MSA.scale) @k.transpose(-2, -1)print('Dimensions for Attn are\n\tbatchsize:', attn.shape[0], '\n\tattention heads:', attn.shape[1], '\n\tnumber of tokens:', attn.shape[2], '\n\tnumber of tokens:', attn.shape[3]

 

Dimensions for Attn are

  batchsize: 13

  attention heads: 4

  number of tokens: 100

  number of tokens: 100

 softmax 不会改变维度,我们略过,然后计算每一个头

 这在多个注意头中是这样的:

attn = attn.softmax(dim=-1)x = attn @ vprint('Dimensions for x are\n\tbatchsize:', x.shape[0], '\n\tattention heads:', x.shape[1], '\n\tnumber of tokens:', x.shape[2], '\n\tlength of tokens:', x.shape[3]

Dimensions for x are

  batchsize: 13

  attention heads: 4

  number of tokens: 100

  length of tokens: 16

 

 最后需要维度重塑并把把所有的xₕ` s连接在一起。这是第一步的逆操作:

x=x.transpose(1, 2).reshape(B, N, MSA.chan)print('Dimensions for x are\n\tbatchsize:', x.shape[0], '\n\tnumber of tokens:', x.shape[1], '\n\tlength of tokens:', x.shape[2])

 

Dimensions for x are

  batchsize: 13

  number of tokens: 100

  length of tokens: 64

 我们已经将所有头的输出连接在一起,注意力模块的其余部分保持不变。

x = MSA.proj(x)print('Dimensions for x are\n\tbatchsize:', x.shape[0], '\n\tnumber of tokens:', x.shape[1], '\n\tlength of tokens:', x.shape[2])
orig_shape = (batch, num_tokens, token_len)
curr_shape = (x.shape[0], x.shape[1], x.shape[2])
v = v.transpose(1, 2).reshape(B, N, A.chan
v_shape = (v.shape[0], v.shape[1], v.shape[2])
print('Original shape of input x:', orig_shape)
print('Current shape of x:', curr_shape)
print('Shape of V:', v_shape)
x = v + x    
print('After skip connection, dimensions for x are\n\tbatchsize:', x.shape[0], '\n\tnumber of tokens:', x.shape[1], '\n\tlength of tokens:', x.shape[2])

 

Dimensions for x are

  batchsize: 13

  number of tokens: 100

  length of tokens: 64

Original shape of input x: (13, 100, 49)

Current shape of x: (13, 100, 64)

Shape of V: (13, 100, 64)

After skip connection, dimensions for x are

  batchsize: 13

  number of tokens: 100

  length of tokens: 64

 总结

在这篇文章中我们完成了ViT中注意力层。为了更详细的说明我们进行了手动的代码编写,如果要实际的应用,可以使用PyTorch中的torch.nn. multiheadeattention(),因为他的实现要快的多。

最后参考文章:

[1] Vaswani et al (2017). Attention Is All You Need.https://doi.org/10.48550/arXiv.1706.03762

[2] Dosovitskiy et al (2020). An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale.https://doi.org/10.48550/arXiv.2010.11929

[3] Yuan et al (2021). Tokens-to-Token ViT: Training Vision Transformers from Scratch on ImageNet. https://doi.org/10.48550/arXiv.2101.11986GitHub code: https://github.com/yitu-opensource/T2T-ViT

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/737828.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

刘敏:楼氏动铁和麦克风助力听力健康技术发展 | 演讲嘉宾公布

一、助辅听器材Ⅱ专题论坛 助辅听器材Ⅱ专题论坛将于3月28日同期举办! 听力贯穿人的一生,听觉在生命的各个阶段都是至关重要的功能,听力问题一旦出现,会严重影响生活质量。助辅听器材能有效提高生活品质。在这里,我们将…

Redis哨兵模式(Sentinel)的搭建与配置

创建三个Redis实例所需的目录,生产环境需独立部署在不同主机上,提高稳定性。 Redis 哨兵模式(Sentinel)是一个自动监控处理 redis 间故障节点转移工作的一个redis服务端实例,它不提供数据存储服务,只进行普通 redis 节点监控管理,使用redis哨兵模式可以实现redis服务端故…

八、软考-系统架构设计师笔记-系统质量属性和架构评估

1、软件系统质量属性 软件架构的定义 软件架构是指在一定的设计原则基础上,从不同角度对组成系统的各部分进行搭配和安排,形成系统的多个结构而组成架构,它包括该系统的各个构件,构件的外部可见属性及构件之间的相互关系。 软件架…

STM32串口:DMA空闲中断实现接收不定长数据(基于HAL库)

STM32串口:DMA空闲中断实现接收不定长数据(基于HAL库): 第一步:设置rcc,时钟频率,下载方式 设置system core->RCC如图所示:(即High Speed Clock和Low Speed Clock都选…

ansible基础与基础命令模块

一Ansible 1. ansible 的概念 Ansible是一个基于Python开发的配置管理和应用部署工具,现在也在自动化管理领域大放异彩。它融合了众多老牌运维工具的优点,Pubbet和Saltstack能实现的功能,Ansible基本上都可以实现。 Ansible能批量配置、部署、…

手机群控软件开发必备源代码分享!

随着移动互联网的飞速发展,手机群控技术在市场推广、自动化测试、应用管理等领域的应用越来越广泛,手机群控软件作为一种能够同时控制多台手机设备的工具,其开发过程中,源代码的编写显得尤为重要。 1、设备连接与识别模块 设备连…

java Day7 正则表达式|异常

文章目录 1、正则表达式1.1 常用1.2 字符串匹配,提取,分割 2、异常2.1 运行时异常2.2 编译时异常2.3 自定义异常2.3.1 自定义编译时异常2.3.2 自定义运行时异常 1、正则表达式 就是由一些特定的字符组成,完成一个特定的规则 可以用来校验数据…

AHU 汇编 实验二

一、实验名称:实验二 不同寻址方式的灵活运用 二、实验内容:定义数组a[6],用多种寻址方式访问对应元素,实现(a[0]a[1])*(a[2]-a[3])/a[4],将结果保存在内存a[5]中,用debug查询结果。 实验过程&a…

压缩自定义格式压缩包<2>:python使用DEFLATE 算法打包并解压成功,但是解压后的文件格式是固定后缀。

打包 import zlib import osdef compress_folder(input_folder, output_filename):"""使用 DEFLATE 算法压缩文件夹下的所有文件。Parameters:input_folder: str要压缩的文件夹路径。output_filename: str输出压缩文件名。"""# 创建一个空的字节…

GPT与R 在生态环境领域数据统计分析

原文链接:GPT与R 在生态环境领域数据统计分析https://mp.weixin.qq.com/s?__bizMzUzNTczMDMxMg&mid2247597092&idx2&sn0a7ac5cf03d37c7b4659f870a7b71a77&chksmfa823dc3cdf5b4d5ee96a928a1b854a44aff222c82b2b7ebb7ca44b27a621edc4c824115babe&…

Linux Centos系统 磁盘分区和文件系统管理 (深入理解)

CSDN 成就一亿技术人! 作者主页:点击! Linux专栏:点击! CSDN 成就一亿技术人! 前言———— 磁盘 在Linux系统中,磁盘是一种用于存储数据的物理设备,可以是传统的硬盘驱动器&am…

鸿蒙Harmony应用开发—ArkTS声明式开发(基础手势:Progress)

进度条组件&#xff0c;用于显示内容加载或操作处理等进度。 说明&#xff1a; 该组件从API version 7开始支持。后续版本如有新增内容&#xff0c;则采用上角标单独标记该内容的起始版本。 子组件 无 接口 Progress(options: ProgressOptions<Type>) 创建进度组件&a…

【好书推荐-第十一期】《Java面试八股文:高频面试题与求职攻略一本通(视频解说版)》(博文视点出品)

&#x1f60e; 作者介绍&#xff1a;我是程序员洲洲&#xff0c;一个热爱写作的非著名程序员。CSDN全栈优质领域创作者、华为云博客社区云享专家、阿里云博客社区专家博主、前后端开发、人工智能研究生。公众号&#xff1a;洲与AI。 &#x1f388; 本文专栏&#xff1a;本文收录…

二,几何相交---4,BO算法---(2)比较和排序

在某一时刻xt&#xff0c;扫描线从左到右时&#xff0c;一部分线段会与扫描线相交&#xff0c;此时此刻&#xff0c;线段可以分成高低顺序&#xff0c; 那么对于给定两条线段&#xff0c;是如何变化的呢&#xff1f;有两个端点&#xff0c;左端点和右端点&#xff0c; 三种情况…

追寻工作与生活的和谐之道

在现代社会&#xff0c;人们往往被快节奏的工作和生活所困扰&#xff0c;如何在这两者之间找到平衡点&#xff0c;成为许多人关注的焦点。本文将为您介绍一些实用的方法和建议&#xff0c;帮助您实现工作与生活的和谐共处。 一、合理规划时间&#xff0c;提高工作效率 时间是实…

WorkPlus Meet提供高效、安全视频会议解决方案

WorkPlus Meet是一款私有部署和定制化的视频会议解决方案&#xff0c;为企业提供高效、安全的远程协作平台。随着全球数字化转型的加速&#xff0c;视频会议已成为企业必不可少的工作工具&#xff0c;而WorkPlus Meet的私有部署和定制化功能&#xff0c;为企业提供了更大的控制…

【MySQL系列 05】Schema 与数据类型优化

良好的数据库 schema 设计和合理的数据类型选择是 SQL 获得高性能的基石。 一、选择优化的数据类型 MySQL 支持的数据类型非常多&#xff0c;选择正确的数据类型对于获得高性能至关重要。不管存储哪种类型的数据&#xff0c;下面几个简单的原则都有助于做出更好的选择。 1. …

C语言学习-day19-函数2

自定义函数&#xff1a;自己定义的函数 以strcpy为例子&#xff1a; 自定义函数一样&#xff0c;需要函数名&#xff0c;返回值类型&#xff0c;函数参数。 函数的组成&#xff1a; ret_type fun_name(para1, *) { statement;//语句项 } ret_type 返回类型 fun_name 函数…

weiphp5.0存在远程代码执行漏洞

@[toc] 免责声明:请勿利用文章内的相关技术从事非法测试,由于传播、利用此文所提供的信息或者工具而造成的任何直接或者间接的后果及损失,均由使用者本人负责,所产生的一切不良后果与文章作者无关。该文章仅供学习用途使用。 1. weiphp5.0简介 微信公众号搜索:南风漏洞复…

c++ 开发环境 LNK1104: 无法打开文件“carve.lib” 已解决

别人分享&#xff0c; 和自己最近遇到问题一摸一样。以为没什么用的静态资源&#xff0c;结果 无法编译。 昨天安装配置了&#xff0c;结果今天早上打开电脑&#xff0c;所以dll的工程全部报错&#xff1a; 1>------ 已启动全部重新生成: 项目: Dll_test, 配置: Debug x64…