Transformer 中 Self-Attention 的二次方复杂度(Quadratic Complexity )问题及改进方法:中英双语

Transformer 中 Self-Attention 的二次方复杂度问题及改进方法

随着大型语言模型(LLM)输入序列长度的增加,Transformer 结构中的核心模块——自注意力机制(Self-Attention) 的计算复杂度和内存消耗都呈现二次方增长。这不仅限制了模型处理长序列的能力,也成为训练和推理阶段的重要瓶颈。

本篇博客将详细解释 Transformer 中 Self-Attention 机制的二次方复杂度来源,结合代码示例展示这一问题,并介绍一些常见的改进方法。


1. Self-Attention 机制简介

原理与公式

在自注意力(Self-Attention)机制中,输入序列 ( X ∈ R n × d X \in \mathbb{R}^{n \times d} XRn×d ) 被映射到三个向量:查询(Query) ( Q Q Q )、键(Key) ( K K K ) 和 值(Value) ( V V V ),三者通过权重矩阵 ( W Q W_Q WQ )、( W K W_K WK )、( W V W_V WV ) 得到:

Q = X W Q , K = X W K , V = X W V Q = X W_Q, \quad K = X W_K, \quad V = X W_V Q=XWQ,K=XWK,V=XWV

自注意力输出的计算公式为:

Attention ( Q , K , V ) = softmax ( Q K T d k ) V \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{Q K^T}{\sqrt{d_k}}\right) V Attention(Q,K,V)=softmax(dk QKT)V

  • ( n n n ) 是输入序列的长度(token 数量)。
  • ( d d d ) 是输入特征的维度。
  • ( d k d_k dk ) 是键向量的维度(通常 ( d k = d / h d_k = d / h dk=d/h ),其中 ( h h h ) 是多头注意力的头数)。

时间复杂度分析

从公式可以看出,自注意力机制中的关键操作是:

  1. ( Q K T Q K^T QKT ):查询向量 ( Q ∈ R n × d k Q \in \mathbb{R}^{n \times d_k} QRn×dk ) 与键向量 ( K ∈ R n × d k K \in \mathbb{R}^{n \times d_k} KRn×dk ) 相乘,得到 ( n × n n \times n n×n ) 的注意力分数矩阵。

    • 计算复杂度为 ( O ( n 2 d k ) O(n^2 d_k) O(n2dk) )。
  2. softmax 操作:在 ( n × n n \times n n×n ) 的注意力矩阵上进行归一化,复杂度为 ( O ( n 2 ) O(n^2) O(n2) )。

  3. 注意力分数与 ( V V V ) 相乘:将 ( n × n n \times n n×n ) 的注意力分数矩阵与 ( V ∈ R n × d v V \in \mathbb{R}^{n \times d_v} VRn×dv ) 相乘,复杂度为 ( O ( n 2 d v ) O(n^2 d_v) O(n2dv) )。

综上,自注意力机制的时间复杂度为:

O ( n 2 d k + n 2 + n 2 d v ) ≈ O ( n 2 d ) O(n^2 d_k + n^2 + n^2 d_v) \approx O(n^2 d) O(n2dk+n2+n2dv)O(n2d)

  • 当 ( d d d ) 是常数时,复杂度主要取决于输入序列的长度 ( n n n ),即呈二次方增长

空间复杂度分析

自注意力的注意力分数矩阵 ( Q K T Q K^T QKT ) 具有 ( n × n n \times n n×n ) 的大小,需要 ( O ( n 2 ) O(n^2) O(n2) ) 的内存进行存储。


2. 代码示例:计算复杂度与空间消耗

以下代码展示了输入序列长度增加时,自注意力机制的时间和空间消耗情况:

import torch
import time# 定义自注意力机制
def self_attention(Q, K, V):attention_scores = torch.matmul(Q, K.transpose(-1, -2)) / torch.sqrt(torch.tensor(Q.shape[-1], dtype=torch.float32))attention_weights = torch.softmax(attention_scores, dim=-1)output = torch.matmul(attention_weights, V)return output# 测试输入序列长度不同的时间复杂度
def test_attention_complexity():d_k = 64  # 特征维度for n in [128, 256, 512, 1024, 2048]:  # 输入序列长度Q = torch.randn((1, n, d_k))  # QueryK = torch.randn((1, n, d_k))  # KeyV = torch.randn((1, n, d_k))  # Valuestart_time = time.time()output = self_attention(Q, K, V)end_time = time.time()print(f"Sequence Length: {n}, Time Taken: {end_time - start_time:.6f} seconds, Output Shape: {output.shape}")if __name__ == "__main__":test_attention_complexity()

运行结果示例

Sequence Length: 128, Time Taken: 0.001200 seconds, Output Shape: torch.Size([1, 128, 64])
Sequence Length: 256, Time Taken: 0.004500 seconds, Output Shape: torch.Size([1, 256, 64])
Sequence Length: 512, Time Taken: 0.015800 seconds, Output Shape: torch.Size([1, 512, 64])
Sequence Length: 1024, Time Taken: 0.065200 seconds, Output Shape: torch.Size([1, 1024, 64])
Sequence Length: 2048, Time Taken: 0.260000 seconds, Output Shape: torch.Size([1, 2048, 64])

从结果可以看出,随着序列长度的增加,计算时间呈现明显的二次方增长。


3. 二次方复杂度的改进方法

为了减少自注意力机制的计算复杂度,许多研究者提出了优化方案,主要包括:

1. 低秩近似方法

利用低秩矩阵分解减少 ( Q K T Q K^T QKT ) 的计算复杂度,例如:

  • Linformer:将 ( n × n n \times n n×n ) 的注意力矩阵通过低秩分解近似为 ( n × k n \times k n×k )(其中 ( k ≪ n k \ll n kn )),复杂度降为 ( O ( n k ) O(nk) O(nk) )。

2. 稀疏注意力(Sparse Attention)

  • LongformerBigBird:通过引入局部窗口和全局注意力机制,仅计算部分注意力分数,避免完整的 ( Q K T Q K^T QKT ) 计算,将复杂度降低为 ( O ( n log ⁡ n ) O(n \log n) O(nlogn) ) 或 ( O ( n ) O(n) O(n) )。

3. 线性注意力(Linear Attention)

  • Performer:使用核技巧将自注意力计算转化为线性操作,复杂度降为 ( O ( n d ) O(n d) O(nd) )。

4. 分块方法(Blockwise Attention)

将输入序列分成多个块,仅在块内或块间进行注意力计算,适用于长序列任务。


4. 总结

在 Transformer 的自注意力机制中,由于需要计算 ( Q K T Q K^T QKT ) 和存储 ( n × n n \times n n×n ) 的注意力矩阵,其时间和空间复杂度均为 ( O ( n 2 ) O(n^2) O(n2) )。这对于处理长序列任务(如长文本、DNA 序列分析等)来说是一个显著的挑战。

为了解决这一问题,近年来提出了多种优化方法,包括低秩近似、稀疏注意力、线性注意力等,成功将复杂度从 ( O ( n 2 ) O(n^2) O(n2) ) 降低到 ( O ( n ) O(n) O(n) ) 或 ( O ( n log ⁡ n ) O(n \log n) O(nlogn) ),从而使 Transformer 更加高效地处理长序列任务。

代码示例和实验结果清楚地展示了二次方复杂度的实际影响,同时也强调了优化方法的重要性。

英文版

The Quadratic Complexity of Self-Attention in Transformers and Possible Improvements

The core of the Transformer architecture in large language models (LLMs) is the self-attention mechanism. While it has proven revolutionary, its computational complexity and memory requirements grow quadratically as the input sequence length increases. This blog will explain the source of this quadratic complexity, demonstrate it with code, and discuss possible optimization methods.


1. Understanding Self-Attention

Mathematical Formulation

Given an input sequence ( X ∈ R n × d X \in \mathbb{R}^{n \times d} XRn×d ) with ( n n n ) tokens and ( d d d ) features, the self-attention mechanism computes the query (Q), key (K), and value (V) matrices as follows:

Q = X W Q , K = X W K , V = X W V Q = X W_Q, \quad K = X W_K, \quad V = X W_V Q=XWQ,K=XWK,V=XWV

The output of the self-attention mechanism is calculated as:

Attention ( Q , K , V ) = softmax ( Q K T d k ) V \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{Q K^T}{\sqrt{d_k}}\right) V Attention(Q,K,V)=softmax(dk QKT)V

Where:

  • ( n n n ): Sequence length
  • ( d d d ): Feature dimension
  • ( d k d_k dk ): Dimension of queries/keys (typically ( d k = d / h d_k = d/h dk=d/h ) for multi-head attention with ( h h h ) heads)

Time Complexity Analysis

The computational bottlenecks of self-attention are:

  1. Computing ( Q K T Q K^T QKT ):
    The query matrix ( Q ∈ R n × d k Q \in \mathbb{R}^{n \times d_k} QRn×dk ) is multiplied with the transposed key matrix ( K T ∈ R d k × n K^T \in \mathbb{R}^{d_k \times n} KTRdk×n ), producing an ( n × n n \times n n×n ) attention score matrix.
    Complexity: ( O ( n 2 d k ) O(n^2 d_k) O(n2dk) ).

  2. Softmax Operation:
    Softmax normalization is applied along each row of the ( n × n n \times n n×n ) attention matrix.
    Complexity: ( O ( n 2 ) O(n^2) O(n2) ).

  3. Computing Weighted Values:
    The ( n × n n \times n n×n ) attention scores are multiplied by the value matrix ( V ∈ R n × d v V \in \mathbb{R}^{n \times d_v} VRn×dv ).
    Complexity: ( O ( n 2 d v ) O(n^2 d_v) O(n2dv) ).

Combining all these steps, the overall time complexity of self-attention is:

O ( n 2 d ) O(n^2 d) O(n2d)

When ( d d d ) is fixed (a constant), the complexity primarily depends on ( n n n ), making it quadratic.


Space Complexity

The attention score matrix ( Q K T Q K^T QKT ) has a size of ( n × n n \times n n×n ), requiring ( O ( n 2 ) O(n^2) O(n2) ) memory to store. This quadratic memory cost limits the model’s ability to handle long sequences.


2. Code Demonstration: Quadratic Complexity in Practice

The following code measures the computation time of self-attention as the input sequence length increases:

import torch
import time# Self-attention function
def self_attention(Q, K, V):attention_scores = torch.matmul(Q, K.transpose(-1, -2)) / torch.sqrt(torch.tensor(Q.shape[-1], dtype=torch.float32))attention_weights = torch.softmax(attention_scores, dim=-1)output = torch.matmul(attention_weights, V)return output# Test different sequence lengths
def test_attention_complexity():d_k = 64  # Feature dimensionfor n in [128, 256, 512, 1024, 2048]:  # Sequence lengthsQ = torch.randn((1, n, d_k))  # QueryK = torch.randn((1, n, d_k))  # KeyV = torch.randn((1, n, d_k))  # Valuestart_time = time.time()output = self_attention(Q, K, V)end_time = time.time()print(f"Sequence Length: {n}, Time Taken: {end_time - start_time:.6f} seconds, Output Shape: {output.shape}")if __name__ == "__main__":test_attention_complexity()

Example Output

Sequence Length: 128, Time Taken: 0.001200 seconds, Output Shape: torch.Size([1, 128, 64])
Sequence Length: 256, Time Taken: 0.004500 seconds, Output Shape: torch.Size([1, 256, 64])
Sequence Length: 512, Time Taken: 0.015800 seconds, Output Shape: torch.Size([1, 512, 64])
Sequence Length: 1024, Time Taken: 0.065200 seconds, Output Shape: torch.Size([1, 1024, 64])
Sequence Length: 2048, Time Taken: 0.260000 seconds, Output Shape: torch.Size([1, 2048, 64])

From the output, it is clear that the computation time increases quadratically with the sequence length ( n ).


3. Solutions to Address the Quadratic Complexity

To address the inefficiency of quadratic complexity, several optimization methods have been proposed:

1. Low-Rank Approximation

Techniques like Linformer approximate the ( n × n n \times n n×n ) attention matrix using low-rank decomposition:

  • Complexity is reduced to ( O ( n k ) O(n k) O(nk) ), where ( k ≪ n k \ll n kn ).

2. Sparse Attention

Sparse attention mechanisms, such as Longformer and BigBird, compute attention only for selected tokens (e.g., local windows or global tokens):

  • Complexity is reduced to ( O ( n log ⁡ n ) O(n \log n) O(nlogn) ) or ( O ( n ) O(n) O(n) ).

3. Linear Attention

Linear attention, such as in Performer, uses kernel functions to approximate the attention mechanism, avoiding the ( Q K T Q K^T QKT ) operation:

  • Complexity becomes ( O ( n d ) O(n d) O(nd) ).

4. Blockwise and Sliding-Window Attention

Divide the input sequence into smaller chunks or sliding windows and compute attention locally within each block:

  • This approach significantly reduces the computational cost for long sequences.

4. Summary

The self-attention mechanism in Transformer models has a time and space complexity of ( O ( n 2 d ) O(n^2 d) O(n2d)), which grows quadratically with sequence length. This becomes a bottleneck for long input sequences, such as lengthy documents or DNA sequences.

Through our code example, we demonstrated the quadratic increase in computational time as the sequence length grows. To address this limitation, several optimizations—such as low-rank approximations, sparse attention, and linear attention—have been introduced to scale Transformers to longer sequences efficiently.

By understanding and leveraging these methods, we can improve the efficiency of self-attention and unlock the potential of Transformers for applications involving extremely long sequences.

后记

2024年12月17日22点26分于上海,在GPT4o大模型辅助下完成。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/63230.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

模型 A/B测试(科学验证)

系列文章 分享 模型,了解更多👉 模型_思维模型目录。控制变量法。 1 A/B测试的应用 1.1 Electronic Arts(EA)《模拟城市》5游戏网站A/B测试 定义目标: Electronic Arts(EA)在发布新版《模拟城…

Java修饰符详解:从基础到高级用法

在Java编程语言中,有许多修饰符可以使用,它们大致可以分为两大类:访问控制修饰符、其他类型的修饰符。 这些修饰符主要用于指定类、方法或变量的特性,并且通常位于声明语句的开头部分。下面通过一些示例来进一步说明这一点&#…

onnx文件转pytorch pt模型文件

onnx文件转pytorch pt模型文件 1.onnx2torch转换及测试2.存在问题参考文献 从pytorch格式转onnx格式,官方有成熟的API;那么假如只有onnx格式的模型文件,该怎样转回pytorch格式? https://github.com/ENOT-AutoDL/onnx2torch提供了…

Git merge 和 rebase的区别(附图)

在 Git 中,merge 和 rebase 是两种用于整合分支变化的方法。虽然它们都可以将一个分支的更改引入到另一个分支中,但它们的工作方式和结果是不同的。以下是对这两者的详细解释: Git Merge 功能:合并分支,将两个分支的…

【Web】0基础学Web—js运算符、选择结构、循环结构

0基础学Web—js运算符、选择结构、循环结构 js运算符选择结构循环结构 js运算符 算术运算符: - * / %取余 赋值运算符: - * / % 单目运算符: i i --i i– 单独使用是自增1 或 自减1 如果被使用&#xff0c;先看到啥先操作啥 比较运算符&#xff1a; > 、 >、 < 、…

系列3:基于Centos-8.6 Kubernetes使用nfs挂载pod的应用日志文件

每日禅语 古代&#xff0c;一位官员被革职遣返&#xff0c;心中苦闷无处排解&#xff0c;便来到一位禅师的法堂。禅师静静地听完了此人的倾诉&#xff0c;将他带入自己的禅房之中。禅师指着桌上的一瓶水&#xff0c;微笑着对官员说&#xff1a;​“你看这瓶水&#xff0c;它已经…

tkdiff安装:Linux下文本对比工具

tkdiff在Linux下源码安装 1.下载解压2.编译安装3.配置环境变量4.验证及运行 本文&#xff0c;在Linux下使用源码安装tkdiff工具&#xff0c;以tkdiff-4.2版本为例&#xff0c;其他版本根据需要替换即可。 1.下载解压 去 http://sourceforge.net/projects/tkdiff/files/tkdiff…

耐蚀镍基合金的焊接技术与质量控制

耐蚀镍基合金是一类在腐蚀环境中具有优异性能的合金材料&#xff0c;广泛应用于化工、海洋工程、石油天然气等领域。其焊接技术与质量控制对于确保合金的使用性能和安全性至关重要。以下是对耐蚀镍基合金焊接技术与质量控制的详细探讨。 一、焊接技术 焊条选择 耐蚀镍基合金的焊…

Django REST framework(DRF)在处理不同请求方法时的完整流程

文章目录 一、POST 请求创建对象的流程二、GET 请求获取对象列表的流程三、GET 请求获取单个对象的流程四、PUT/PATCH 请求更新对象的流程五、自定义方法的流程自定义 GET 方法自定义 POST 方法 一、POST 请求创建对象的流程 请求到达视图层 方法调用&#xff1a; dispatch说明…

机器视觉与OpenCV--01篇

计算机眼中的图像 像素 像素是图像的基本单位&#xff0c;每个像素存储着图像的颜色、亮度或者其他特征&#xff0c;一张图片就是由若干个像素组成的。 RGB 在计算机中&#xff0c;RGB三种颜色被称为RGB三通道&#xff0c;且每个通道的取值都是0到255之间。 计算机中图像的…

qemu源码解析【03】qom实例

目录 qemu源码解析【03】qom实例arm_sbcon_i2c实例 qemu源码解析【03】qom实例 arm_sbcon_i2c实例 以hw/i2c/arm_sbcon_i2c.c代码为例&#xff0c;这个实例很简单&#xff0c;只用100行左右的代码&#xff0c;调用qemu系统接口实现了一个i2c硬件模拟先看include/hw/i2c/arm_s…

小程序自定义tab-bar,踩坑记录

从官方下载代码 https://developers.weixin.qq.com/miniprogram/dev/framework/ability/custom-tabbar.html 1、把custom-tab-bar 文件放置 pages同级 修改下 custom-tab-bar 下的 JS文件 Component({data: {selected: 0,color: "#7A7E83",selectedColor: "#3…

操作系统(14)请求分页

前言 操作系统中的请求分页&#xff0c;也称为页式虚拟存储管理&#xff0c;是建立在基本分页基础上&#xff0c;为了支持虚拟存储器功能而增加了请求调页功能和页面置换功能的一种内存管理技术。 一、基本概念 分页&#xff1a;将进程的逻辑地址空间分成若干个大小相等的页&am…

git企业开发的相关理论(一)

目录 一.初识git 二.git的安装 三.初始化/创建本地仓库 四.配置用户设置/配置本地仓库 五.认识工作区、暂存区、版本库 六.添加文件__场景一 七.查看 .git 文件/添加到本地仓库后.git中发生的变化 1.执行git add后的变化 index文件&#xff08;暂存区&#xff09; log…

wxpython图形用户界面编程

wxpython图形用户界面编程 一、wxpython的基础 1.1 wxpython的基础 作为图形用户界面开发工具包 wxPython&#xff0c;主要提供了如下 GUI 内容&#xff1a; 窗口。控件。事件处理。布局管理。 1.2 wxpython的类层次机构 1.3 wxpython的安装 Windows 和 macOS 平台安装&a…

水仙花数(流程图,NS流程图)

题目&#xff1a;打印出所有的100-999之间的"水仙花数"&#xff0c;并画出流程图和NS流程图。所谓"水仙花数"是指一个三位数&#xff0c;其各位数字立方和等于该数本身。例如&#xff1a;153是一个"水仙花数"&#xff0c;因为1531的三次方&#…

不配置python环境,直接用PyCharm就可以?

有的伙伴可能遇到不安装python环境只安装pycharm也可以进行运行代码。 所以自认为是不需要解释器就可以运行&#xff1f; 这个是不现实的&#xff0c;有很多伙伴可能是安装了Pycharm&#xff0c;但Pycharm看你电脑上没有解释器&#xff0c;所以在安装的时候给你默认安装在C盘…

网络安全渗透测试概论

渗透测试&#xff0c;也称为渗透攻击测试是一种通过模拟恶意攻击者的手段来评估计算机系统、网络或应用程序安全性的方法。 目的 旨在主动发现系统中可能存在的安全漏洞、脆弱点以及潜在风险&#xff0c;以便在被真正的恶意攻击者利用之前&#xff0c;及时进行修复和加固&…

爬虫数据能用于商业吗?

在当今数字化时代&#xff0c;数据已成为企业获取竞争优势的关键资源。网络爬虫作为一种数据收集工具&#xff0c;能够从互联网上抓取大量数据&#xff0c;这些数据在商业分析中扮演着重要角色。然而&#xff0c;使用爬虫技术获取的数据是否合法、能否用于商业分析&#xff0c;…

前端面试汇总(不定时更新)

目录 HTML & CSS1. XML、HTML、XHTML 有什么区别&#xff1f;⭐2. XML和JSON的区别&#xff1f;3. 是否了解W3C的规范&#xff1f;⭐4. 什么是语义化标签&#xff1f;⭐⭐5. 行内元素和块级元素的区别&#xff1f;⭐6. 行内元素和块级元素的转换&#xff1f;⭐7. 常用的块级…