FlashAttention算法详解

这篇文章的目的是详细的解释Flash Attention,为什么要解释FlashAttention呢?因为FlashAttention 是一种重新排序注意力计算的算法,它无需任何近似即可加速注意力计算并减少内存占用。所以作为目前LLM的模型加速它是一个非常好的解决方案,本文介绍经典的V1版本,最新的V2做了其他优化我们这里暂时不介绍。因为V1版的FlashAttention号称可以提速5-10倍,所以我们来研究一下它到底是怎么实现的。

介绍

论文的标题是:

“FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness”

内存的效率与普通注意力相比(序列长度是二次的,O(N²)),FlashAttention是次二次的/线性的N (O(N))。并且它不是注意力机制的近似值(例如,稀疏或低秩矩阵近似值方法)-它的输出与“传统”注意力机制相同。与普通的注意力相比,FlashAttention的注意力是”有感知“的。

它利用底层硬件的内存层次知识(例如gpu,但其他AI加速器也应该工作,我这里使用gpu作为示例)。一些[近似]方法在序列长度上将计算要求降低到线性或近线性,但其中许多方法专注于减少FLOP,而忽略内存访问(IO)的开销。

经过多年的发展gpu的FLOPS的增长速度一直在以比内存吞吐量(TB/s)更快。内存的瓶颈应该引起重视。FLOPS和内存吞吐量需要紧密结合,由于硬件上的差距,我们就需要软件层面上的工作进行平衡。

根据计算和内存访问之间的比率,操作可以分为以下两种:

  • 计算约束 :矩阵乘法
  • 内存约束:元素操作(激活,dropout,masking),归并操作(softmax, layer norm,sum等)

在当前的AI加速器(GPU)上是受内存大小限制的。因为它“主要由元素操作组成”,或者更准确地说,注意力的算术密度不是很高。

我们看看这个图:

可以看到,masking,softmax和dropout是占用大量时间的操作,而不是矩阵乘法(即使大部分FLOPS是在matmul中)。内存不是一个单一的工件,它在本质上是分层的,一般的规则是:内存越快,越昂贵,容量越小。

我们在上面说的,FlashAttention的注意力是”有感知“的可以归结为利用SRAM比HBM(高带宽内存)快得多来确保减少两者之间的通信。

以A100为例:

A100 GPU有40-80GB的高带宽内存(HBM),带宽为1.5-2.0 TB/s,而每108个流处理器有192KB的SRAM,带宽估计在19TB/s左右。

可以看到大小小了很多,但是速度却提升了10倍,所以如何高效的利用SRAM是提速的关键,让我们看看标准注意力实现背后的计算:

标准实现如何显示对HW操作方式不大尊重。它基本上将HBM加载/存储操作视为0成本(它不是“io感知”)。

我们首先考虑如何使这个实现更有效(时间和内存方面)。最简单的方法是删除冗余的HBM读/写。

如何把S写回HBM只是为了(重新)加载它来计算softmax,那么我们可以将其保存在SRAM中,执行所有中间步骤,然后将最终结果写回HBM。

内核基本上是“GPU操作”的一种奇特的说法(参考我们以前发布的CUDA入门,往简单了说就是一个函数)。融合则可以将多个操作融合在一起。所以只从HBM加载一次,执行融合的op,然后将结果写回来。这样做可以减少通信开销。

这里还有一个专业名词术语是“materialization”(物化/实体化)。它指的是,在上面的标准注意力实现中,已经分配了完整的NxN矩阵(S, P)。下面我们将看到如何直接将内存复杂度从O(N²)降低到O(N)。

Flash attention基本上可以归结为两个主要观点:

Tiling (在向前和向后传递时使用)-基本上将NxN softmax/scores矩阵分块成块。

Recomputation (仅在向后传递中使用)

算法如下:

上面我们提到了很多名词,你可能还不了解。没关系下面我们开始逐行解释算法。

FlashAttention算法

让Tiling方法的主要障碍是softmax。因为softmax需要将所有的分数列耦合在一起。

看到分母了吗?这就是问题所在。

要计算输入序列中的特定第i个标记对序列中其他标记的关注程度,需要在SRAM中随时可用所有这些分数(这里用z_j表示)。

但是SRAM的容量是有限的。N(序列长度)可以是1000甚至100000个令牌。所以N²爆炸得很快。所以论文使用了一个技巧:把softmax的计算分成更小的块,最终仍然得到完全相同的结果。

我们可以只获取前一个B分数(x_1到x_B)并为它们计算softmax。然后通过迭代,“收敛”到正确的结果。以一种聪明的方式组合这些每块部分softmax的数字,这样最终的结果实际上是正确的。方法如下:

基本上,为了计算属于前2个块(大小为B)的分数的softmax,必须要跟踪每个块的2个统计数据:m(x)(最大分数)和l(x) (exp分数总和)。然后就可以用归一化系数将它们无缝地融合在一起。

这里主要是一些基本的代数运算,通过展开f(x)和l(x)项并与e^x相乘一些项会相互抵消,这里就不写了。这个逻辑递归地一直持续到最后一个(N/B)块,这样就得到了N维正确的softmax输出!

为了详细的介绍这个算法,假设有一个大小为1的批处理(即单个序列)和单个注意力头,稍后会扩展它(通过简单地跨GPU的并行化-稍后会详细介绍)。我们暂时忽略了dropout和masking,因为稍后再添加。

我们开始计算:

初始化:HBM的容量以GB为单位测量(例如RTX 3090有24 GB的VRAM/HBM, A100有40-80 GB等),因此分配Q, K和V不是问题。

第1步

计算行/列块大小。为什么ceil(M / 4 d) ?因为查询、键和值向量是d维的,所以我们还需要将它们组合成输出的d维向量。所以这个大小基本上允许我们用q k v和0个向量最大化SRAM的容量。

比如说,假设M = 1000, d = 5。那么块大小为(1000/4*5)= 50。所以一次加载50个q, k, v, o个向量的块,这样可以减少HBM/SRAM之间的读/写次数。

对于B_r,我也不太确定他们为什么要用d执行最小运算?如果有人知道,请评论指教!

第2步:

用全0初始化输出矩阵O。它将作为一个累加器,l也类似它的目的是保存softmax的累积分母——exp分数的总和)。M(保存逐行最大分数)初始化为-inf,因为我们将对其进行Max运算符,因此无论第一个块的Max是什么-它肯定大于-inf 。

第3步:

步骤1中的块大小将Q, K和V分成块。

第4步:

将O, l, m分割成块(与Q的块大小相同)。

第5步:

开始跨列循环,即跨键/值向量(上图中的外部循环)。

第6步:

将K_j和V_j块从HBM加载到SRAM。在这个时间点上我们仍然有50%的SRAM未被占用(专用于Q和O)。所以SRAM是这样的:

第7步:

开始跨行内部循环,即跨查询向量。

第8步:

将Q_i (B_r x d)和O_i (B_r x d)块以及l_i (B_r)和m_i (B_r)加载到SRAM中。

这里需要保证l_i和m_i能够载入SRAM(包括所有中间变量),这块可能是CUDA的知识,我不太确定如何计算,所以如果你有相关的信息,请留言

第9步:

计算Q_i (B_r x d)和K_j转置(d x B_c)之间的点积,得到分数(B_r x B_c)。并没有将整个nxns(分数)矩阵“物化”。

假设外部循环索引为j (j=3),内部循环索引为i (i=2), N为25,块大小为5,下面就是刚刚计算的结果(假设以1为基础的索引):

也就是输入序列中标记11-15的标记6-10的注意力得分。这里的一个要点是,这些都是精确的分数,它们永远不会改变。

第10步:

使用上一步计算的分数计算m_i_j、l*i_j和P~*i_j。M ~_i_j是按行计算的,找到上面每一行的最大元素。

然后通过应用元素运算得到P~_i_j:

归一化-取行最大值并从行分数中减去它,然后EXP

l~_i_j是矩阵P的逐行和。

第11步:

计算m_new_i和l_new_i。同样非常简单,可以重复使用上面的图表:

M_i包含之前所有块的逐行最大值(j=1 & j=2,用绿色表示)。M _i_j包含当前块的逐行最大值(用黄色表示)。为了得到m_new_i我们只需要在m_i_j和m_i之间取一个最大值,l_new_i也类似。

第12步(最重要):

这是算法中最难的部分。

它允许我们用矩阵的形式做逐行标量乘法。如果你有一列标量s (N)和一个矩阵a (NxN)如果你做diag(s)* a你基本上是在用这些标量做a行的元素乘法。

公式1(为了方便再次粘贴在这里):

第12步的第一项所做的(用绿色下划线)是:更新了在同一行块中当前块之前的块的当前softmax估计。如果j=1(这是这一行的第一个块。

第一项乘以diag(l_i)是为了抵消之前迭代中除以的相同常数(这个常数隐藏在O_i中)。

表达式的第二项(黄色下划线)是不需要消去的,因为可以看到我们直接将P~_i_j矩阵与V向量块(V_j)相乘。

e^x项是用来修改矩阵P~_i_j & O_i的,方法是消去前一次迭代中的m,用最新的估计(m_new_i)来更新它,该估计包含到目前为止逐行最大值。

以下是我的逐步分析(实际上只需要5分钟,希望能有所帮助!)

重点是这些外面的e项和P/O矩阵里面的e项消掉了,所以总是得到最新的m_new_1估计!

第三次迭代也是类似的,得到了正确的最终结果!

回想一下:这只是对最终O_i的当前估计。只有在我们遍历上图中的所有红色块之后,我们才能最终得到确切的结果。

第13步

将最新的累加到统计数据(l_i & m_i)写回HBM。注意它们的维数是B_r。

第13、14、15、1步

嵌套的for循环结束,O (Nxd)将包含最终结果:每个输入令牌的注意力加权值向量!

简单汇总

算法可以很容易地扩展到“block-sparse FlashAttention”,这是一种比FlashAttention快2-4的稀疏注意力算法,扩展到64k的序列长度!通过使用一个块形式的掩码矩阵,可以跳过上面嵌套的for循环中的某些加载/存储,这样我们可以按比例节省稀疏系数,比如下图

现在让我们简单地讨论一下复杂性。

复杂度分析

空间:在HBM中分配了Q, K, V, O (Nxd), l和m (N)。等于4Nd + 2*N。去掉常量,并且知道d也是一个常量并且通常比N小得多(例如d={32,64,128}, N={1024,…,100k}),可以得到O(N)的空间,这有助于扩展到64k序列长度(再加上一些其他“技巧”,比如ALiBi)。

时间:这里不会严格地进行时间复杂度分析,但是我们将使用一个好的指标:HBM访问的数量。

论文的解释如下:

他们是怎么得到这个数字的?让我们来分析嵌套的for循环:

我们的块大小是M/4d。这意味着向量被分割成N/(M/4d)块。取它的2次方(因为要遍历行/列块)得到O(N²d²/ M²)

我们不能一次获取整个块,如果做一个大O分析,可能会让我们认为这并不比标准注意力好多少,但对于典型的数字,这导致访问次数减少了9倍(根据上面的论文截图)。

我们的伪算法集中在一个单头注意力,假设批处理大小为1。下面我们就开始进行扩展了

多头注意力

要扩展到batch_size > 1和num_heads > 1实际上并不难。

算法基本上是由单个线程块(CUDA编程术语)处理的。这个线程块在单个流多处理器(SM)上执行(例如,A100上有108个这样的处理器)。为了并行化计算,只需要在不同的SMs上并行运行batch_size * num_heads线程块。该数字与系统上可用的SMs数量越接近,利用率就越高(理想情况下是多个,因为每个SM可以运行多个线程块)。

反向传播

对于GPU内存的占用,另外一个大头就是反向传播,通过存储输出O (Nxd)和softmax归一化统计数据(N),我们可以直接从SRAM中的Q, K和V (Nxd)块中反向计算注意力矩阵S (NxN)和P (NxN) !从而使内存保持在O (N)。这个比较专业了,我们了解以下就可以了,所以需要详细的内容请看原论文。

代码实现

最后,让我们看看在使用flash attention时可能出现的一些问题。因为涉及到显存的操作,所以我们只能深入CUDA,但是CUDA又比较复杂。

这就是OpenAI的Triton等项目的优势(参见他们的FlashAttention实现)。Triton基本上是一种DSL(领域特定语言),介于CUDA和其他领域特定语言(例如TVM)之间的抽象级别。可以编写超级优化的Python代码(一旦编译),而不必直接处理CUDA。这样Python代码可以部署在任意的加速器上(这是Triton任务)。

另外一个好消息是Triton最近已经与PyTorch 2.0集成了。

另外对于某些用例,比如对于超过1K的序列长度,一些近似注意方法(如Linformer)开始变得更快。但是flash attention的块稀疏实现优于所有其他方法。

总结

你有没有想过,对于这种底层优化的算法为什么是一个斯坦福大学的学生发布,而不是NVIDIA的工程师?

我认为有2种可能的解释:

1、FlashAttention更容易/只能在最新的gpu上实现(原始代码库不支持V100)。

2、通常“局外人”是那些以初学者的眼光看待问题,能够看到问题的根源并从基本原则出发解决问题

最后我们还是要进行个总结

FlashAttention能够让BERT-large训练中节省15%,将GPT训练速度提高2/3,并且是在不需要修改代码的情况下,这是一个非常重要的进步,它为LLM的研究又提出了一个新的方向。

论文地址:

https://avoid.overfit.cn/post/9d812b7a909e49e6ad4fb115cc25cdc1

作者:Aleksa Gordić

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/47528.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

拒绝摆烂!C语言练习打卡第五天

🔥博客主页:小王又困了 📚系列专栏:每日一练 🌟人之为学,不日近则日退 ❤️感谢大家点赞👍收藏⭐评论✍️ 目录 一、选择题 📝1.第一题 📝2.第二题 &#x1f4d…

Linux的基础指令

目录 1、ls指令 .和..意义 2、pwd指令 3、cd指令 ①cd ~ ②cd - 关于cd ..的用法 绝对路径和相对路径 4、touch指令 5、mkdir指令 tree指令 6、rmdir指令 7、rm指令 * 8、man指令 9、cp指令 nano: 10、mv指令 11、cat指令 12、more指令 13、less…

0009Java程序设计-jsp在线学习平台设计与实现

摘 要目 录系统实现开发环境 摘 要 在线学习平台,是一个利用因特网作为平台传送教学内容,实施网上教学,进行网上交流和学习的信息系统。构建在线学习系统平台,可以克服传统课堂教育的局限性,形成一种主动的、协作的、…

[线程/C++]线程同(异)步和原子变量

文章目录 1.线程的使用1.1 函数构造1.2 公共成员函数1.2.1 get_id()1.2.2 join()2.2.3 detach()2.2.5 joinable()2.2.6 operator 1.3 静态函数1.4 call_once 2. this_thread 命名空间2.1 get_id()2.2 sleep_for()2.3 sleep_until()2.4 yield() 3. 线程同步之互斥锁3.1 std:mute…

c#中lambda表达式缩写推演

Del<string> ml new Del<string>(Notify);//泛型委托的实例化&#xff0c;并关联Nofity方法 Del<string> ml new Del<string>(delegate (string str) { return str.Length; });//将Nofity变更为匿名函数 Del<string> ml delegate(string str)…

Ubuntu软件源、pip源大全,国内网站网址,阿里云、网易163、搜狐、华为、清华、北大、中科大、上交、山大、吉大、哈工大、兰大、北理、浙大

文章目录 一、企业镜像源1、阿里云2、网易1633、搜狐镜像4、华为 二&#xff1a;高校镜像源1、清华源2、北京大学3、中国科学技术大学源 &#xff08;USTC&#xff09;4、 上海交通大学5、山东大学6、 吉林大学开源镜像站7、 哈尔滨工业大学开源镜像站8、 西安交通大学软件镜像…

【数据结构OJ题】用栈实现队列

原题链接&#xff1a;https://leetcode.cn/problems/implement-queue-using-stacks/ 目录 1. 题目描述 2. 思路分析 3. 代码实现 1. 题目描述 2. 思路分析 用两个栈实现&#xff0c;一个栈进行入队操作&#xff0c;另一个栈进行出队操作。 出队操作&#xff1a; 当出队的栈…

Jmeter对websocket进行测试

JMeterWebSocketSampler-1.0.2-SNAPSHOT.jar下载 公司使用websocket比较奇怪&#xff0c;需要带认证信息进行长连接&#xff0c;通过websocket插件是请求失败&#xff0c;如下图&#xff0c;后面通过代码实现随再打包jar包完成websocket测试 本地实现代码如下&#xff1a; pa…

前馈神经网络解密:深入理解人工智能的基石

目录 一、前馈神经网络概述什么是前馈神经网络前馈神经网络的工作原理应用场景及优缺点 二、前馈神经网络的基本结构输入层、隐藏层和输出层激活函数的选择与作用网络权重和偏置 三、前馈神经网络的训练方法损失函数与优化算法反向传播算法详解避免过拟合的策略 四、使用Python…

【HCIP】08.ISIS中间系统

链路状态协议&#xff0c;传递LSA信息ISIS基于数据链路层封装在OSI时&#xff0c;也有自己的网络层地址和自己的路由协议&#xff0c;即ISIS。之前的ISIS支持OSI的网络层地址&#xff0c;是为OSI中的CLNP&#xff08;无连接网络协议&#xff09;网络设计的路由协议&#xff0c;…

情人节特别定制:多种语言编写动态爱心网页(附完整代码)

写在前面案例1&#xff1a;HTML Three.js库案例2&#xff1a;HTML CSS JavaScript案例3&#xff1a;Python环境 Flask框架结语 写在前面 随着七夕节的临近&#xff0c;许多人都在寻找独特而令人难忘的方式来表达爱意。在这个数字时代&#xff0c;结合创意和技术&#xff0…

计算机视觉入门 3)最大池化

目录 一、最大池化最大池化进行压缩平移不变性 二、代码示例步骤2&#xff1a;图像读取转换步骤2&#xff1a;Filter & ReLU步骤3&#xff1a;Pool 一、最大池化 最大池化进行压缩 在Keras中&#xff0c;通过一个 MaxPool2D 层&#xff0c;将压缩步骤添加到之前的模型中&…

电脑找不到MSVCR120.dll怎么办?MSVCR120.dll是什么?

在我们的日常生活和工作中&#xff0c;电脑故障是难以避免的问题。而MSVCR120.dll文件是Windows系统中的一个重要组件&#xff0c;如果出现损坏或丢失&#xff0c;可能会导致程序无法正常运行&#xff0c;这个问题可能是由于系统文件损坏、病毒感染等原因导致的。因此&#xff…

记录一次wordpress项目的发布过程

背景&#xff1a;发布一套已完成的代码到线上&#xff0c;有完整的代码包&#xff0c;sql文件&#xff0c;环境是linux 宝塔。无wordpress相关经验。 过程&#xff1a;正常的发布代码 问题1&#xff1a;访问自己的域名后跳转到别的域名。 解决&#xff1a; 修改数据表wp_optio…

Apipost中自定义接口字段如何配置

Apipost项目设置中可以配置接口文档中的自定义接口字段&#xff0c;创建状态码字典。分享分档时会展示到文档页面 状态码字典 在状态码字典中可以自定义状态码即其含义 自定义的状态码会在分享的API文档中展示 接口属性 接口属性中可以自定义接口和接口文档展示字段&#xf…

MySQL索引

目录 一、什么是索引 二、索引的原理 三、优缺点 四、分类 1、聚簇索引--顺序IO 2、非聚簇索引--随机IO 五、索引的设计原则 六、创建索引 1、创建表时创建索引 2、在已经存在的表上创建索引 3、使用ALTER TABLE语句来创建索引 1)普通索引 2&#xff09;唯一性索引 …

蓝奥声智能工业安全用电监测与智慧能源解决方案

能源管理变得越来越重要。如今&#xff0c;能源成本已成为国内预算的核心因素&#xff0c;因此用电监控对大多数现代企业来说都很重要。许多企业在日常能源消耗监控中面临着一些挑战&#xff0c;因为它们的规模庞大&#xff0c;基础设施多样化&#xff0c;灵活性低&#xff0c;…

Java之包,权限修饰符,final关键字详解

包 2.1 包 包在操作系统中其实就是一个文件夹。包是用来分门别类的管理技术&#xff0c;不同的技术类放在不同的包下&#xff0c;方便管理和维护。 在IDEA项目中&#xff0c;建包的操作如下&#xff1a; 包名的命名规范&#xff1a; 路径名.路径名.xxx.xxx // 例如&#xff…

sql数据导出到excel

一、打开Navicat Premium 12 二、导出

R语言处理缺失数据(1)-mice

#清空 rm(listls()) gc()###生成模拟数据### #生成100个随机数 library(magrittr) set.seed(1) asd<-rnorm(100, mean 60, sd 10) %>% round #平均60&#xff0c;标准差10 #将10个数随机替换为NA NA_positions <- sample(1:100, 10) asd[NA_positions] <- NA #转…