【大模型上下文长度扩展】FlashAttention:高效注意力计算的新纪元

FlashAttention:高效注意力计算的新纪元

    • 核心思想
      • 核心操作融合,减少高内存读写成本
      • 分块计算(Tiling),避免存储一次性整个矩阵
      • 块稀疏注意力,处理长序列时的效率问题
      • 利用快速 SRAM,处理内存与计算速度不匹配
      • 算术强度优化,处理计算与内存访问的不平衡
      • 重计算,解决后向传递中存储大型中间矩阵的需求
    • 当前FlashAttention实现的局限性,并提出了未来发展的方向
      • 低级语言编程的复杂性
      • IO-感知优化的普遍性
      • 多GPU并行计算的IO优化

 


论文:https://arxiv.org/pdf/2205.14135.pdf

 

核心思想

FlashAttention 提出的是为了解决 Transformers 在处理长序列时的速度慢和内存消耗大的问题。

这个问题主要是因为,自注意力模块在长序列上的时间和内存复杂度都是二次方的。

FlashAttention的本质是通过创新的算法设计,实现了对Transformer模型中注意力机制的高效计算。

  • FlashAttention通过减少HBM访问次数和避免存储大型中间矩阵,使BERT模型比MLPerf 1.1的速度记录快15%,GPT-2的训练速度提高了最高3倍。

  • 使用FlashAttention的GPT-2模型,在4K的上下文长度下训练比Megatron在1K上下文长度下训练还快,同时困惑度(perplexity)更低,说明模型质量提高。

  • FlashAttention在常见序列长度(最高2K)上比标准注意力实现快3倍,并且其内存占用随序列长度线性增长,证明了其在效率和内存使用上的优势。

  • 块稀疏FlashAttention通过仅计算重要的注意力块来减少计算量和内存使用,使得Transformer模型能够处理高达64K序列长度,且在Path-256任务上达到了63.1%的准确率,显示了其在处理长序列任务上的能力。

它通过以下核心方法和策略,解决了传统注意力计算在长序列处理时遇到的速度慢和内存消耗大的问题:

  1. IO-感知优化:FlashAttention深入考虑了GPU内存层次之间的交互,特别是高带宽内存(HBM)与片上SRAM之间的读写操作,通过优化这些操作来减少内存访问成本,从而提高计算效率。

  2. 分块计算(Tiling):通过将输入序列分成小块并逐块处理,FlashAttention避免了一次性加载整个序列到内存中,减轻了内存压力,并使得注意力计算更加高效。

  3. 重计算策略:为了减少后向传播时对大型中间矩阵的存储需求,FlashAttention采用了在需要时重新计算这些矩阵的策略,从而节省了大量的内存空间。

  4. 核心融合:FlashAttention通过将多个计算步骤融合到一个CUDA核心中执行,减少了内存访问次数,并提高了执行速度。

这些策略共同作用,使FlashAttention能够以更少的内存访问和更低的时间复杂度,准确地计算出注意力,从而在保持模型质量的同时,显著提高了训练速度和效率。

此外,FlashAttention的设计还支持块稀疏注意力,进一步提高了处理长序列能力,使得在资源有限的情况下,Transformer模型能够处理更长的上下文信息,这在自然语言处理和其他需要长序列处理的领域中尤为重要。

FlashAttention本质上是对传统Transformer注意力机制的一个高效、内存友好的改进,它通过深入挖掘和优化计算机内存和计算资源的使用方式,推动了深度学习模型在复杂任务上的应用和发展。

 

核心操作融合,减少高内存读写成本

  • 子解法: IO-感知算法(IO-Awareness)
    • 解释: 传统的注意力算法没有考虑到 GPU 内存层次之间的读写成本,导致了大量的内存访问,进而增加了计算时间和内存消耗。
    • FlashAttention 通过考虑 IO,即输入/输出操作,特别是在 GPU 高带宽存储器(HBM)与 GPU 上的 SRAM 之间的读写操作,来降低这些成本。
    • 例子: 在传统的 Transformer 模型中,整个注意力矩阵需要从 HBM 读入到 SRAM 中进行计算,
    • 然后结果再写回 HBM,这个过程中的读写操作非常耗时和耗内存。
    • FlashAttention 通过减少这种读写操作的次数,来减少内存访问成本。

在标准注意力计算中,每个操作(如 softmax、矩阵乘法等)都需要从 HBM 读取输入,计算后再将结果写回 HBM,导致高内存访问成本。

如果我们可以将多个操作合并为一个操作(核心融合),那么输入只需从 HBM 加载一次,这样就减少了内存访问次数,从而降低了内存访问成本。
 

分块计算(Tiling),避免存储一次性整个矩阵

  • 子解法: 增量式 softmax 计算(Tiling)
    • 解释: 标准的注意力机制需要存储整个注意力矩阵以便于后向传播,这在长序列上是非常内存消耗的。
    • FlashAttention 通过将输入分块(tiling)并多次通过输入块逐步执行 softmax 减少(也称为 tiling),避免了一次性处理整个大矩阵。
    • 例子: 假设有一个很长的序列,传统方法需要一次性计算和存储整个序列的注意力矩阵。
    • FlashAttention 则将序列分成小块,每次只处理一个块,并逐步累积计算结果,从而不需要存储整个大矩阵。

在标准注意力机制中,整个注意力矩阵需要一次性计算并存储,导致对 HBM 的大量访问。

通过将输入矩阵 Q、K、V 分块并逐块计算,我们可以逐步生成注意力输出,减少了一次性对大量数据的访问需求。

一个大型矩阵乘法,通过将矩阵分为小块,每次只处理一部分数据,就可以减少内存的即时需求。

 

块稀疏注意力,处理长序列时的效率问题

  • 子解法: 块稀疏注意力(Block-sparse Attention)
    • 解释: 长序列上的注意力计算复杂度高,导致计算缓慢。
    • FlashAttention 引入了块稀疏技术,通过只计算序列中重要部分的注意力,忽略其他不重要的部分,从而减少计算量。
    • 例子: 在处理一个长文本时,可能只有部分词语之间存在强关联,而其他词语的关联性较弱。块稀疏注意力允许模型只关注那些重要的词语间的关联,忽略其他,从而加速计算并降低内存使用。

 

利用快速 SRAM,处理内存与计算速度不匹配

  • 子解法: 利用快速 SRAM
    • 原因: 现代 GPU 的计算速度相比内存速度增长得更快,使得大多数操作成为内存访问受限。
    • 例子: 通过更多地利用每个流式多处理器上的快速 SRAM(与 HBM 相比,SRAM 速度快得多但容量小得多),我们可以加速那些内存访问受限的操作,例如通过在 SRAM 中计算部分结果来减少对 HBM 的访问。

 

算术强度优化,处理计算与内存访问的不平衡

  • 子解法: 算术强度优化
    • 原因: 操作可以根据计算和内存访问之间的平衡被分类为计算密集型或内存访问密集型。
    • 标准注意力实现中,很多操作(如 softmax)是内存访问密集型的。
    • 例子: 通过优化算术强度,即每字节内存访问的算术操作数量,我们可以尽量将操作转变为计算密集型,从而减轻内存访问的瓶颈。

 

重计算,解决后向传递中存储大型中间矩阵的需求

  • 子解法: 重计算(Recomputation)
    • 原因: 标准实现中,后向传递需要访问前向传递计算时产生的大型中间矩阵(如 S 和 P 矩阵)。通过存储必要的统计量而非整个矩阵,并在需要时重计算这些矩阵,可以避免大量的内存使用。
    • 例子: 类似于梯度检查点技术,我们不存储整个计算过程中的中间状态,而是仅存储关键节点,需要时再重建整个状态。

 

通过子解法的组合,FlashAttention 成功地解决了 Transformers 在处理长序列时速度慢和内存消耗大的问题。

FlashAttention 提出了一种计算精确注意力的算法,其关键在于通过减少对高带宽内存(HBM)的读写操作以及避免在后向传递中存储大型中间矩阵,从而实现了既节省内存又加速计算的目标。

在探索传统注意力机制在现代硬件(尤其是 GPU)上的执行效率时,遇到了一系列的具体问题,这些问题导致了处理速度慢和高内存消耗。

每种解决方案都直接针对了标准注意力实现中的效率瓶颈,通过改善内存访问模式、减少不必要的内存写入和读取、以及优化计算流程来提高整体性能。

 
在这里插入图片描述

左侧:展示了在GPU中的内存层次结构和FlashAttention如何在这种结构中工作。

它说明了:

  • GPU的不同内存层次及其带宽和大小,包括片上SRAM(20MB, 19TB/s),高带宽内存HBM(40GB, 1.5TB/s),以及主内存DRAM(12.8GB/s, 大于1TB)。
  • FlashAttention使用分块计算(Tiling)来避免实现大型 N×N 注意力矩阵。
  • 在外部循环(红色箭头)中,FlashAttention遍历K和V矩阵的块,并将它们加载到快速的片上SRAM中。
  • 在每个块中,FlashAttention遍历Q矩阵的块(蓝色箭头),加载到SRAM中,并将注意力计算的输出写回到HBM。

右侧:显示了使用PyTorch实现的注意力计算与FlashAttention实现在GPT-2模型上的速度对比。

它说明了:

  • FlashAttention与PyTorch实现相比在各个组件(矩阵乘法、Dropout、Softmax、Mask和Fused Kernel)上的时间消耗。
  • FlashAttention没有读写大型 N×N 注意力矩阵到HBM,因此在注意力计算上得到了约7.6倍的加速。

 


当前FlashAttention实现的局限性,并提出了未来发展的方向

 

低级语言编程的复杂性

  • 子解法1: 高级语言到CUDA的自动编译
    • 原因: 目前,IO-感知的注意力实现需要在CUDA中手动编写新的核函数,这不仅需要在比PyTorch这样的高级语言更低级的语言中编程,而且还需要大量的工程努力。
    • 例子: 类似于图像处理领域的Halide工具,可以让研究人员用高级语言编写算法,然后自动编译成优化的CUDA代码,减少直接使用CUDA编程的复杂性。

 

IO-感知优化的普遍性

  • 子解法2: 扩展IO-感知实现到其他模块
    • 原因: 虽然注意力计算是Transformer模型中最耗内存的部分,但模型的每一层都需要与GPU的高带宽内存(HBM)交互。
    • 例子: 在深度学习模型的其他组件,如卷积层或循环层,也采用IO-感知的实现方法,可以进一步提高整个模型的效率。

 

多GPU并行计算的IO优化

  • 子解法3: 多GPU间的IO-感知方法
    • 原因: FlashAttention的当前实现在单GPU上是最优的,但注意力计算可以跨多GPU并行化,这引入了考虑GPU间数据传输的额外IO分析层。
    • 例子: 通过设计能够优化GPU间数据传输的IO-感知算法,可以在不牺牲性能的前提下,实现更大规模的模型训练和更高效的并行计算。

从提高开发效率、扩展IO-感知优化的应用范围,到优化多GPU并行计算的效率。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/672161.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【大模型上下文长度扩展】LongQLoRA:单GPU(V100)环境下的语言模型优化方案

LongQLoRA 核心问题子问题1: 预定义的上下文长度限制子问题2: 训练资源的需求高子问题3: 保持模型性能分析不足 LongQLoRA方法拆解子问题1: 上下文长度限制子问题2: 高GPU内存需求子问题3: 精确量化导致的性能损失分析不足效果 论文:https://arxiv.org/pdf/2311.048…

docker镜像结构

# 基础镜像 FROM openjdk:11.0-jre-buster # 设定时区 ENV TZAsia/Shanghai RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone # 拷贝jar包 COPY docker-demo.jar /app.jar # 入口 ENTRYPOINT ["java", "-jar"…

游泳耳机推荐性价比排行榜,四大高性价比游泳耳机推荐

随着运动健康意识的提高,越来越多的朋友选择在游泳馆进行锻炼。然而,在水中享受音乐并非易事,这就需要一款真正防水的耳机。尽管市面上有许多声称具备防水功能的耳机产品,但实际使用中往往难以达到理想的防水效果。为了帮助大家找…

之前看过的前序遍历的线索二叉树感觉写的有点问题 这里更新一下我的思路

前序线索化 #include<iostream> using namespace std;typedef int datatype; typedef struct BitNode {datatype Data;struct BitNode* leftchild;struct BitNode* rightchild;int lefttag;int righttag; }Node; #pragma region 前序线索化递归遍历 Node* previous NUL…

maven依赖报错处理(或者maven怎么刷新都下载不了依赖)

maven依赖报错&#xff0c;或者不报错&#xff0c;但是怎么刷新maven都没反应&#xff0c;可以试一下以下操作 当下载jar的时候&#xff0c;如果断网&#xff0c;或者连接超时的时候&#xff0c;会自动在文件夹中创建一个名为*lastupdate的文件&#xff0c;当有了这个文件之后…

网络工程师专属春节对联,不要太真实了!

中午好&#xff0c;我的网工朋友。 都放假了没&#xff1f;龙年将至&#xff0c;都有啥新年计划&#xff1f; 过年&#xff0c;讲究的就是一个热闹&#xff0c;可以暂时告别辛苦的一年&#xff0c;重新整装出发。 热闹可少不了春联啊&#xff0c;红红火火又一年&#xff0c;…

Vue源码系列讲解——虚拟DOM篇【一】(Vue中的虚拟DOM)

目录 1. 前言 2. 虚拟DOM简介 2.1什么是虚拟DOM&#xff1f; 2.2为什么要有虚拟DOM&#xff1f; 3. Vue中的虚拟DOM 3.1 VNode类 3.2 VNode的类型 3.2.1 注释节点 3.2.2 文本节点 3.2.3 克隆节点 3.2.4 元素节点 3.2.5 组件节点 3.2.6 函数式组件节点 3.2.7 小结 3…

OpenCV-31 获得形态学卷积核

OpenCV提供了获取卷积核的API&#xff0c;不需要我们手动创建卷积核。 通过下面API---getStructuringElement(shape&#xff0c;ksize&#xff0c;[, anchor]) shape是指卷积核的型状&#xff0c;注意不是指长宽&#xff0c;是指卷积核中1形成的形状。MORPH_RECT 卷积核中的1…

(三)elasticsearch 源码之启动流程分析

https://www.cnblogs.com/darcy-yuan/p/17007635.html 1.前面我们在《&#xff08;一&#xff09;elasticsearch 编译和启动》和 《&#xff08;二&#xff09;elasticsearch 源码目录 》简单了解下es&#xff08;elasticsearch&#xff0c;下同&#xff09;&#xff0c;现在我…

SPSS基础操作:对数据进行加权处理

对数据进行加权处理是我们使用SPSS提供某些分析方法的重要前提。对数据进行加权后&#xff0c;当前的权重将被保存在数据中。当进行相应的分析时&#xff0c;用户无须再次进行加权操作。本节以对广告的效果观测为例&#xff0c;讲解数据的加权操作。本例给出了消费者购买行为与…

Arthas使用教程—— 阿里开源线上监控诊断产品

文章目录 1 简介2背景3 图形界面工具 arthas 阿里开源3.1 &#xff1a;启动 arthas3.2 help :查看arthas所有命令3.3 查看 dashboard3.4 thread 列出当前进程所有线程占用CPU和内存情况3.5 jvm 查看该进程的各项参数 &#xff08;类比 jinfo&#xff09;3.6 通过 jad 来反编译 …

端口扫描神器:御剑 保姆级教程(附链接)

一、介绍 御剑&#xff08;YooScan&#xff09;是一款网络安全工具&#xff0c;主要用于进行端口扫描。它具有直观的用户界面&#xff0c;方便用户进行端口扫描和信息收集。以下是御剑端口扫描工具的一些主要特点和功能&#xff1a; 图形用户界面&#xff1a; 御剑提供直观的图…

告别mPDF迎来TCPDF和中文打印遇到的问题

mPDF是一个用PHP编写的开源PDF生成库。它最初由Claus Holler创建&#xff0c;于2004年发布。原来用开源软件打印中文没有问题&#xff0c;最近发现新的软件包中mPDF被TCPDF代替了&#xff0c;当然如果只用西文的PDF是没有发现问题&#xff0c;但要打印中文就有点抓瞎了如图1&am…

我的PyTorch模型比内存还大,怎么训练呀?

原文&#xff1a;我的PyTorch模型比内存还大&#xff0c;怎么训练呀&#xff1f; - 知乎 看了一篇比较老&#xff08;21年4月文章&#xff09;的不大可能训练优化方案&#xff0c;保存起来以后研究一下。 随着深度学习的飞速发展&#xff0c;模型越来越臃肿&#xff0c;哦不&a…

vue element 组件 form深层 :prop 验证失效问题解决

此图源自官网 借鉴。 当我们简单单层验证的时候发现是没有问题的&#xff0c;但是有的时候可能会涉及到深层prop&#xff0c;发现在去绑定的时候就不生效了。例如我们在form单里面循环验证&#xff0c;在去循环数据验证。 就如下图的写法了 :prop"pumplist. i .device…

Redis缓存设计及优化

缓存设计 缓存穿透 缓存穿透是指查询一个根本不存在的数据&#xff0c; 缓存层和存储层都不会命中&#xff0c; 通常出于容错的考虑&#xff0c; 如果从存储层查不到数据则不写入缓存层。 缓存穿透将导致不存在的数据每次请求都要到存储层去查询&#xff0c; 失去了缓存保护后…

Pandas 对带有 Multi-column(多列名称) 的数据排序并写入 Excel 中

Pandas 从Excel 中读取带有 Multi-column的数据 正文 正文 我们使用如下方式写入数据&#xff1a; import pandas as pd import numpy as npdf pd.DataFrame(np.array([[10, 2, 0], [6, 1, 3], [8, 10, 7], [1, 3, 7]]), columns[[Number, Name, Name, ], [col 1, col 2, co…

数据结构——C/栈和队列

&#x1f308;个人主页&#xff1a;慢了半拍 &#x1f525; 创作专栏&#xff1a;《史上最强算法分析》 | 《无味生》 |《史上最强C语言讲解》 | 《史上最强C练习解析》 &#x1f3c6;我的格言&#xff1a;一切只是时间问题。 ​ 1.栈 1.1栈的概念及结构 栈&#xff1a;一种特…

WPF是不是垂垂老矣啦?平替它的框架还有哪些

WPF&#xff08;Windows Presentation Foundation&#xff09;是微软推出的一种用于创建 Windows 应用程序的用户界面框架。WPF最初是在2006年11月推出的&#xff0c;它是.NET Framework 3.0的一部分&#xff0c;为开发人员提供了一种基于 XAML 的方式来构建丰富的用户界面。 W…

你的代码很丑吗?试试这款高颜值代码字体

Monaspace 是有 GitHub 开源的代码字体&#xff0c;包含 5 种变形字体的等宽代码字体家族&#xff0c;颜值 Up&#xff0c;很难不喜欢。 来看一下这 5 种字体分别是&#xff1a; 1️⃣ Radon 手写风格字体 2️⃣ Krypton 机械风格字体 3️⃣ Xenon 衬线风格字体 4️⃣ Argon…