【深度学习】注意力机制的改进:稀疏注意力、局部注意力、低秩/线性注意力

文章目录

  • 稀疏注意力
      • PVT v2中的稀疏注意力
        • 公式
      • 代码示例
  • 局部注意力
      • 局部注意力
      • Swin Transformer中的基于窗口的注意力
        • 公式
      • 代码示例
      • 窗口移位操作(Shifted Window)
      • 实现细节
      • 公式
      • 代码示例
  • 低秩/线性注意力
      • 低秩/线性注意力
      • Linformer
        • 公式
      • 代码示例
      • 代码解释

稀疏注意力

稀疏注意力(Sparse Attention)是一种通过选择性地处理部分token来减少整体计算负荷的方法。这在自然语言处理和计算机视觉中的注意力机制中尤为重要,因为它可以显著降低计算复杂度和内存使用。

在标准的全连接注意力机制中,每个token(词或图像patch)都与其他所有token计算注意力权重,这会导致计算复杂度为 O ( N 2 ) O(N^2) O(N2),其中 N N N 是token的数量。这种全连接的计算在处理长序列或高分辨率图像时会非常耗时且内存消耗巨大。稀疏注意力则通过只计算部分token之间的注意力权重,从而将复杂度降低到 O ( N log ⁡ N ) O(N \log N) O(NlogN) O ( N ) O(N) O(N)

PVT v2中的稀疏注意力

PVT v2(Pyramid Vision Transformer v2)是一种改进的视觉Transformer模型,它通过使用卷积核来压缩key和value的空间,从而降低计算注意力的复杂性。这意味着在计算注意力权重时,模型不需要考虑所有token,而只考虑压缩后的key和value。

公式

标准注意力机制计算如下:

Attention ( Q , K , V ) = softmax ( Q K T d k ) V \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V Attention(Q,K,V)=softmax(dk QKT)V

其中:

  • ( Q ) 是查询矩阵(query)
  • ( K ) 是键矩阵(key)
  • ( V ) 是值矩阵(value)
  • ( d_k ) 是键的维度

在PVT v2中,通过使用卷积核对key和value进行空间压缩,这一过程可以表示为:

K ′ = Conv ( K ) K' = \text{Conv}(K) K=Conv(K)

V ′ = Conv ( V ) V' = \text{Conv}(V) V=Conv(V)

其中 (\text{Conv}) 表示卷积操作。这样,计算注意力时使用的是压缩后的key和value:

SparseAttention ( Q , K ′ , V ′ ) = softmax ( Q K ′ T d k ) V ′ \text{SparseAttention}(Q, K', V') = \text{softmax}\left(\frac{QK'^T}{\sqrt{d_k}}\right)V' SparseAttention(Q,K,V)=softmax(dk QKT)V

代码示例

下面是一个使用PyTorch实现的简化版本的稀疏注意力机制:

import torch
import torch.nn as nn
import torch.nn.functional as Fclass SparseAttention(nn.Module):def __init__(self, d_model, d_k, kernel_size=3, stride=1, padding=1):super(SparseAttention, self).__init__()self.query_conv = nn.Linear(d_model, d_k)self.key_conv = nn.Conv2d(d_model, d_k, kernel_size=kernel_size, stride=stride, padding=padding)self.value_conv = nn.Conv2d(d_model, d_k, kernel_size=kernel_size, stride=stride, padding=padding)def forward(self, q, k, v):q = self.query_conv(q)  # (batch_size, seq_len, d_k)# Assuming k and v are of shape (batch_size, d_model, height, width)k = self.key_conv(k)  # (batch_size, d_k, new_height, new_width)v = self.value_conv(v)  # (batch_size, d_k, new_height, new_width)# Flatten spatial dimensionsk = k.flatten(2)  # (batch_size, d_k, new_height * new_width)v = v.flatten(2)  # (batch_size, d_k, new_height * new_width)# Compute attention weightsattn_weights = F.softmax(torch.bmm(q, k.transpose(1, 2)) / (k.size(1) ** 0.5), dim=-1)  # (batch_size, seq_len, new_height * new_width)# Compute attention outputattn_output = torch.bmm(attn_weights, v.transpose(1, 2))  # (batch_size, seq_len, d_k)return attn_output# Example usage
batch_size = 2
seq_len = 10
d_model = 64
d_k = 32
height = width = 16q = torch.randn(batch_size, seq_len, d_model)
k = torch.randn(batch_size, d_model, height, width)
v = torch.randn(batch_size, d_model, height, width)sparse_attention = SparseAttention(d_model, d_k)
output = sparse_attention(q, k, v)
print(output.shape)  # (batch_size, seq_len, d_k)

这个示例展示了如何使用卷积操作压缩key和value,并在稀疏注意力机制中计算注意力权重和输出。通过这种方法,可以显著减少计算复杂度,提高模型的效率。

局部注意力

局部注意力(Local Attention)是一种通过将注意力集中在输入序列或图像的局部区域上来减少计算负荷的方法。这种方法通过限制每个token仅与其附近的token计算注意力,从而降低了计算复杂度。

局部注意力

局部注意力机制的主要思想是将输入划分为若干个固定大小的窗口,然后在每个窗口内独立计算注意力权重。这种方法适用于处理长序列或高分辨率图像时,因为它能够显著减少计算量,同时保留足够的上下文信息。

Swin Transformer中的基于窗口的注意力

Swin Transformer(Shifted Window Transformer)是一种基于局部注意力的Transformer模型,它通过引入基于窗口的注意力机制,将计算限制在指定的窗口大小内进行。具体来说,Swin Transformer将输入图像划分为多个非重叠的窗口,然后在每个窗口内独立计算注意力。为了进一步增强模型的全局上下文捕捉能力,Swin Transformer还引入了窗口的移位操作。

公式

标准的全连接注意力计算公式为:

Attention ( Q , K , V ) = softmax ( Q K T d k ) V \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V Attention(Q,K,V)=softmax(dk QKT)V

在Swin Transformer中,输入图像被划分为多个窗口,每个窗口内计算局部注意力:

LocalAttention ( Q w , K w , V w ) = softmax ( Q w K w T d k ) V w \text{LocalAttention}(Q_w, K_w, V_w) = \text{softmax}\left(\frac{Q_w K_w^T}{\sqrt{d_k}}\right)V_w LocalAttention(Qw,Kw,Vw)=softmax(dk QwKwT)Vw

其中 (Q_w), (K_w), 和 (V_w) 分别是窗口内的查询、键和值矩阵。

为了捕捉全局信息,Swin Transformer引入了窗口移位操作(Shifted Window)。通过在每个注意力计算层之间对窗口进行移位,可以实现窗口之间的信息交互。

代码示例

下面是一个使用PyTorch实现的简化版基于窗口的局部注意力机制:

import torch
import torch.nn as nn
import torch.nn.functional as Fclass WindowAttention(nn.Module):def __init__(self, d_model, window_size):super(WindowAttention, self).__init__()self.window_size = window_sizeself.query_conv = nn.Linear(d_model, d_model)self.key_conv = nn.Linear(d_model, d_model)self.value_conv = nn.Linear(d_model, d_model)self.proj = nn.Linear(d_model, d_model)def forward(self, x):B, N, C = x.shapex = x.view(B, N, C)windows = self.window_partition(x)# Apply linear transformationsQ = self.query_conv(windows)K = self.key_conv(windows)V = self.value_conv(windows)# Calculate attention within each windowattn_weights = F.softmax(torch.matmul(Q, K.transpose(-2, -1)) / (C ** 0.5), dim=-1)attn_output = torch.matmul(attn_weights, V)# Concatenate windows back to original shapeattn_output = self.window_reverse(attn_output, B, N)return self.proj(attn_output)def window_partition(self, x):B, N, C = x.shapewindow_size = self.window_sizex = x.view(B, int(N**0.5), int(N**0.5), C)  # Assume input is a square imagewindows = x.unfold(1, window_size, window_size).unfold(2, window_size, window_size)windows = windows.contiguous().view(-1, window_size * window_size, C)return windowsdef window_reverse(self, windows, B, N):window_size = self.window_sizeC = windows.shape[-1]windows = windows.view(B, int(N**0.5) // window_size, int(N**0.5) // window_size, window_size, window_size, C)x = windows.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, N, C)return x# Example usage
batch_size = 2
seq_len = 16
d_model = 64
window_size = 2x = torch.randn(batch_size, seq_len, d_model)
window_attention = WindowAttention(d_model, window_size)
output = window_attention(x)
print(output.shape)  # (batch_size, seq_len, d_model)

这个代码实现了一个简化的基于窗口的局部注意力机制。在这里,我们假设输入是一个形状为 (B \times N \times C) 的张量,其中 (B) 是批次大小,(N) 是序列长度(例如图像被展平后的长度),(C) 是每个token的维度。窗口大小由 window_size 参数决定。通过这种方法,可以有效减少计算复杂度,同时保持较好的局部上下文信息。

窗口移位操作(Shifted Window)

窗口移位操作是Swin Transformer中的一个关键技术,它通过在注意力计算层之间对窗口进行移位,实现窗口之间的信息交互,从而增强模型的全局上下文捕捉能力。

在标准的窗口注意力机制中,每个窗口内的token只与同一窗口内的其他token进行注意力计算。这虽然减少了计算复杂度,但限制了跨窗口的信息交流。为了克服这一限制,Swin Transformer引入了窗口移位操作。在每个注意力计算层之间,对窗口进行移位,使得原本不在同一个窗口内的token可以在后续的注意力计算中进行交互。

实现细节

假设输入图像被划分为大小为 ( M \times M ) 的窗口,在第一个注意力计算层中,窗口是不重叠的。在第二个注意力计算层之前,对窗口进行移位(例如,水平和垂直方向各移位 ( M/2 ) 个像素),然后再进行注意力计算。

公式

标准的基于窗口的注意力计算公式为:

LocalAttention ( Q w , K w , V w ) = softmax ( Q w K w T d k ) V w \text{LocalAttention}(Q_w, K_w, V_w) = \text{softmax}\left(\frac{Q_w K_w^T}{\sqrt{d_k}}\right)V_w LocalAttention(Qw,Kw,Vw)=softmax(dk QwKwT)Vw

其中 (Q_w), (K_w), 和 (V_w) 分别是窗口内的查询、键和值矩阵。

在引入窗口移位操作后,计算公式保持不变,但窗口的划分方式在每层之间会有所不同。

代码示例

下面是一个包含窗口移位操作的局部注意力机制的PyTorch实现:

import torch
import torch.nn as nn
import torch.nn.functional as Fclass WindowAttention(nn.Module):def __init__(self, d_model, window_size):super(WindowAttention, self).__init__()self.window_size = window_sizeself.query_conv = nn.Linear(d_model, d_model)self.key_conv = nn.Linear(d_model, d_model)self.value_conv = nn.Linear(d_model, d_model)self.proj = nn.Linear(d_model, d_model)def forward(self, x, shift_size=0):B, N, C = x.shapex = x.view(B, int(N**0.5), int(N**0.5), C)  # Assume input is a square imageif shift_size > 0:x = self.shift_window(x, shift_size)  # Shift the windowwindows = self.window_partition(x)# Apply linear transformationsQ = self.query_conv(windows)K = self.key_conv(windows)V = self.value_conv(windows)# Calculate attention within each windowattn_weights = F.softmax(torch.matmul(Q, K.transpose(-2, -1)) / (C ** 0.5), dim=-1)attn_output = torch.matmul(attn_weights, V)# Concatenate windows back to original shapeattn_output = self.window_reverse(attn_output, B, N)if shift_size > 0:attn_output = self.reverse_shift_window(attn_output, shift_size)  # Reverse the shiftreturn self.proj(attn_output)def window_partition(self, x):B, H, W, C = x.shapewindow_size = self.window_sizex = x.view(B, H // window_size, window_size, W // window_size, window_size, C)windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size * window_size, C)return windowsdef window_reverse(self, windows, B, N):window_size = self.window_sizeC = windows.shape[-1]x = windows.view(B, int(N**0.5) // window_size, int(N**0.5) // window_size, window_size, window_size, C)x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, N, C)return xdef shift_window(self, x, shift_size):B, H, W, C = x.shapex = torch.roll(x, shifts=(-shift_size, -shift_size), dims=(1, 2))return xdef reverse_shift_window(self, x, shift_size):B, H, W, C = x.shapex = torch.roll(x, shifts=(shift_size, shift_size), dims=(1, 2))return x# Example usage
batch_size = 2
seq_len = 16
d_model = 64
window_size = 2x = torch.randn(batch_size, seq_len, d_model)
window_attention = WindowAttention(d_model, window_size)# Without shift
output_no_shift = window_attention(x)
print(output_no_shift.shape)  # (batch_size, seq_len, d_model)# With shift
output_shift = window_attention(x, shift_size=1)
print(output_shift.shape)  # (batch_size, seq_len, d_model)

这个代码实现了一个简化的基于窗口的局部注意力机制,并且包含了窗口移位操作。窗口移位操作通过 shift_windowreverse_shift_window 方法实现。在注意力计算前进行窗口移位,注意力计算后再将窗口移回原位置。这样可以有效地增强跨窗口的信息交互能力,从而提升模型的全局上下文捕捉能力。

低秩/线性注意力

低秩/线性注意力(Low-Rank/Linear Attention)是一种通过对自注意力机制进行低秩近似来减少计算复杂性的方法。Linformer 是一种具体的实现,它通过低秩近似大大降低了自注意力机制的计算和内存需求。

低秩/线性注意力

在标准的自注意力机制中,每个查询(query)与所有键(key)计算注意力权重,计算复杂度为 O ( N 2 ) O(N^2) O(N2),其中 N N N 是序列长度。低秩/线性注意力通过对注意力矩阵进行低秩近似,将计算复杂度降低为 O ( N ) O(N) O(N)

Linformer

Linformer 是一种通过低秩近似优化自注意力机制的模型。它的主要思想是将高维的键(key)和值(value)投影到一个低维空间,从而减少计算量。

公式

标准的自注意力计算公式为:

Attention ( Q , K , V ) = softmax ( Q K T d k ) V \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V Attention(Q,K,V)=softmax(dk QKT)V

在Linformer中,键和值矩阵被投影到一个低维空间:

K ′ = E K , V ′ = E V K' = EK, \quad V' = EV K=EK,V=EV

其中,(E) 是一个投影矩阵,将原始的高维键和值矩阵 (K) 和 (V) 投影到一个低维空间。

因此,Linformer的自注意力计算公式变为:

LinformerAttention ( Q , K ′ , V ′ ) = softmax ( Q K ′ T d k ) V ′ \text{LinformerAttention}(Q, K', V') = \text{softmax}\left(\frac{QK'^T}{\sqrt{d_k}}\right)V' LinformerAttention(Q,K,V)=softmax(dk QKT)V

代码示例

下面是一个使用PyTorch实现的Linformer低秩注意力机制的简化版本:

import torch
import torch.nn as nn
import torch.nn.functional as Fclass LinformerAttention(nn.Module):def __init__(self, d_model, seq_len, k_dim):super(LinformerAttention, self).__init__()self.seq_len = seq_lenself.k_dim = k_dimself.query_proj = nn.Linear(d_model, d_model)self.key_proj = nn.Linear(d_model, d_model)self.value_proj = nn.Linear(d_model, d_model)self.E = nn.Parameter(torch.randn(seq_len, k_dim))self.proj = nn.Linear(d_model, d_model)def forward(self, q, k, v):B, N, C = q.shapeQ = self.query_proj(q)K = self.key_proj(k)V = self.value_proj(v)# Project keys and values to low-dimensional spaceK = torch.matmul(self.E.T, K)V = torch.matmul(self.E.T, V)# Calculate attentionattn_weights = F.softmax(torch.matmul(Q, K.transpose(-2, -1)) / (C ** 0.5), dim=-1)attn_output = torch.matmul(attn_weights, V)return self.proj(attn_output)# Example usage
batch_size = 2
seq_len = 16
d_model = 64
k_dim = 8q = torch.randn(batch_size, seq_len, d_model)
k = torch.randn(batch_size, seq_len, d_model)
v = torch.randn(batch_size, seq_len, d_model)linformer_attention = LinformerAttention(d_model, seq_len, k_dim)
output = linformer_attention(q, k, v)
print(output.shape)  # (batch_size, seq_len, d_model)

代码解释

  1. LinformerAttention类: 定义了Linformer低秩注意力机制。

    • __init__方法中,定义了投影矩阵 E 和用于查询、键、值的线性变换 query_proj, key_proj, value_proj
    • forward方法中,首先对查询、键和值进行线性变换,然后使用投影矩阵 E 将键和值投影到低维空间,最后计算注意力权重和输出。
  2. Example usage: 创建输入张量 q, k, vLinformerAttention 实例,并进行前向计算。结果输出张量的形状为 (batch_size, seq_len, d_model)

通过这种方法,Linformer能够有效降低自注意力机制的计算复杂性和内存使用,同时保持良好的性能。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/38960.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

事务的影子拷贝-系统架构师(二十)

1、(重点)企业信息集成按照组织范围分为企业内部的信息集成和外部信息集成。在企业内部信息集成中,()实现了不同系统之间的互操作,使的不同系统之间能够实现数据和方法的共享。()实现…

Unity小知识

1.当我们把摄像机的内容渲染到RenderTexture上而不是屏幕上时,那么相机的Aspect默认会设置成和RenderTexture的分辨率一样.不过最终如果把RenderTexture作为贴图贴到模型上去的时候还是会被UV拉伸和缩小的。 2.要想自定义UnityPackage的内容,只要找到UnityProject/L…

H5与小程序:两者有何不同?

H5,即HTML5,是构建Web内容的一种语言描述方式,也是互联网的下一代标准,被认为是互联网的核心技术之一。HTML5是在HTML4.01的基础上进行了一定的改进后的规范,用户在使用任何手段进行网页浏览时看到的内容原本都是HTML格…

GPT对话代码库——HAL库下 USART 的配置及问题(STM32G431CBT6)

目录 1,问: 1,答: 示例代码 正确的HAL库初始化方式 自定义初始化方式(不推荐) 总结 2,问: 2,答: 代码详细解释 初始部分 主初始化部分 初始化调用…

QT学习积累——如何提高Qt遍历list的效率

目录 引出Qt遍历list提高效率显示函数的调用使用&与不使用&除法的一个坑 总结自定义信号和槽1.自定义信号2.自定义槽3.建立连接4.进行触发 自定义信号重载带参数的按钮触发信号触发信号拓展 lambda表达式返回值mutable修饰案例 引出 QT学习积累——如何提高Qt遍历list…

python 操作网页

使用selenium库获取网页元素的属性值是一个常见的需求。以下是一个Python代码示例,展示了如何使用selenium来获取一个链接的href属性以及一个输入框的value属性。 首先,请确保您已经安装了selenium库,并且配置了WebDriver(如ChromeDriver)以驱动浏览器。 pythonfrom sele…

如何避免Java中的内存泄漏?

如何避免Java中的内存泄漏? 大家好,我是免费搭建查券返利机器人省钱赚佣金就用微赚淘客系统3.0的小编,也是冬天不穿秋裤,天冷也要风度的程序猿! 在Java开发中,内存泄漏(Memory Leak&#xff0…

CSF视频文件格式转换WMV格式(2024年可用)

如果大家看过一些高校教学讲解视频的话,很可能见过这样一个难得的格式,".csf ",非常漂亮 。 用暴风影音都可以打开观看,会自动下载解码。 但是一旦我们想要利用或者上传视频的时候就麻烦了,一般网站不认这…

STM32重定向printf到串口(重写fputc不生效)

使用STM32开发,想用printf把输出打印到串口,需要重定向printf函数。 网上一搜全都是重写fpuc的,但这只针对使用了MicroLIB的情况,如果你使用STM32CubeMX配置了CMake或者Makefile项目,这种方法是根本不可行的&#xff0…

为什么PS5运行游戏的效果往往比号称更强大的Xbox Series X更好?

在第九代游戏机即将进入第四个年头之际,有一个问题仍未得到解答:索尼的 PS5 游戏机的性能如何经常超越纸面性能更强大的微软 Xbox X 系列? 几个明显的例子包括《生化危机 4》、《使命召唤:黑色行动:冷战》和新一代更新…

【支撑文档】系统安全保证措施(word原件)

软件安全保证措施word 软件所有全套资料获取进主页或者本文末个人名片直接。

跨平台营销的智能协同:Kompas.ai如何整合多渠道策略

引言 在数字化营销的今天,消费者的注意力分散在多个平台上。品牌要想有效地吸引和保持消费者的关注,就必须采取跨平台营销策略。Kompas.ai,作为一款智能营销工具,能够帮助品牌实现这一目标。 跨平台营销的重要性 跨平台营销能够…

智慧园区大数据云平台建设方案(Word原件)

第一章 项目建设背景及现状 第二章 园区创新发展趋势 第三章 工业园区大数据存在的问题 第四章 智慧工业园区大数据建设目的 第五章 智慧园区总体构架 第六章 系统核心组件 第七章 智慧工业园区大数据平台规划设计 获取方式:本文末个人名片直接获取。 软件资料清单…

【报错】安装clang-14 的时候,报错E: 无法定位软件包 clang-14

1 报错 安装clang-14,命令如下: sudo apt-get install clang-14 报错为E: 无法定位软件包 clang-14 2 解决方法 使用其他的安装方法,命令如下: wget https://apt.llvm.org/llvm.sh # 添加权限chmod +x llvm.shsudo ./llvm.sh 14 all # 卸载第3步安装过程中安装无用的…

mysql 获取枚举的随机值

mysql 获取枚举的随机值 1.需求描述2.使用到的函数elt函数语法示例 RAND() 函数FLOOR()函数 3.解决方案手写生成:少量数据从表中生成:多数据 4.实战 1.需求描述 在MySQL中,您可以使用ENUM类型定义列,并且可以从中选择随机值。但是…

超融合服务器挂载硬盘--linux系统

项目中需要增加服务器的硬盘容量,通过超融合挂载了硬盘后,还需要添加到指定的路径下,这里记录一下操作步骤。 一:通过管理界面挂载硬盘 这一步都是界面操作,登录超融合控制云台后,找到对应的服务器&#…

uniapp中实现跳转到外部链接(也就是a标签的功能)

uniapp中实现跳转到外部链接(也就是a标签的功能) 项目中需要做到跳转到外部链接,网上找了很多都不是很符合自己的要求,需要编译成app后是跳转到游览器打开链接,编译成web是在新窗口打开链接。实现的代码如下&#xff1…

矩阵、混剪、大盘,3大功能升级优化!助力企业高效管理!

在数字化转型的浪潮中,企业对于工具与技术的需求愈发强烈。 为满足市场需求,本月【云略】为各企业上线了便捷功能,赋能企业经营决策和业务增长。 矩阵管理 √【矩阵号管理】抖音支持设置城市IP 内容管理 √【混剪任务】支持关联智能发布计…

PDF文档如何统计字数,统计PDF文档字数的方法有哪些?

在平时使用pdf阅读或者是处理文档的时候,常常需要统计文档的字数。pdf在查看文字时其实很简单。 PDF文档是一种常见的电子文档格式,如果需要对PDF文档中的字数进行统计,可以使用以下方法: Adobe Acrobat DC:Adobe Ac…

AI大模型,爆发了

随着ChatGPT用户增速放缓,AI创业公司马太效应加剧,第一轮AI投资热潮逐渐褪去,AI大模型进入“冷静期”。擅长后发制人的腾讯,姗姗来迟,推出了混元大模型,为这一轮AI热潮画上了句号。 AI大模型,开…