LLM Algorithms(1): Flash Attention

目录

Background

Flash Attention

Flash Attention Algorithm

参考


NIPS-2022:  Flash Attention: Fast and Memory-Efficient Exact Attention with IO-Awareness

  • idea:减少资源消耗,提升或保持模型性能。
  • 普通attention的空间复杂度是O(N^2) --》降低到Flash Attention O(N)
  • Exact 结果相等。这不是attention的近似计算,Flash Attention的计算结果和原始方法一致。
  • IO aware. 和传统attention相比,Flash Attention会考虑硬件特性,而不是把它当作黑盒。 

Background

Nvidia GPU (GPU性能指标 = FLOPS / GB/s,FLOPS, GPU计算能力--每秒计算速度;GB/s,GPU内存吞吐量

  1. 2016-P100
  2. 2018-V100
  3. 2020-A100
  4. 2022-H100

多年来,GPU的计算能力(FLOPS)的增长速度比增加内存吞吐量(TB/s)更快。 

这两者需要紧密配合去达到数据处理的最优比,但自从硬件失去了这种平衡,我们必须通过软件来进行补偿。因此需要算法能够感知IO (IO-aware)。根据计算和内存访问比例,一个操作可以分为:

  1. 计算受限型 (e.g. 矩阵乘法)
  2. 内存受限型
    1. Element-wise 逐元素操作: activation, dropout, masking.
    2. Reduction 操作: softmax, layer norm, sum. 

element-wise操作是指在计算时只依赖当前值,比如每个元素都乘以2。而reduction依赖所有值(比如整个矩阵或矩阵的行),比如softmax。 

attention的计算时内存受限的,因为它的大部分计算都是element-wise的。 

尽管masking、softmax和dropout操作占用了大部分时间,但大部分FLOPS都用在矩阵乘法中,虽然他们花的时间不多。即数据太庞大,attention计算内存不足,或者说内存利用效率太低!

可以通过内存调整去加速masking、softmax和dropout这些操作呢,但是具体咋办? 

人们都知道把大矩阵切分成小块,但如何保证切分小块的计算结果=原attention计算结果?  

扩展:在计算机体系结构里,内存不是单一的构建,内存存储都是分层的。一般规则是:Memory IO speed 内存速度越快,成本越高,容量越小。

  1. GPU SRAM,19TB/s (20 MB),Static RAM, 静态随机存储器
  2. GPU HBM,1.5TB/s (40 GB),high Boardwidth memory, 高带宽内存 
  3. GPU DRAM,12.8GB/s (>1 TB),main memory

实际上,要充分利用内存、实现IO-aware,关键在于充分利用静态随机存取存储器 (SPAM)比高带宽内存 (HBM)快得多的事实,确保减少两者之间的通信。

(HBM,这是导致CUDA内存溢出的因素之一) 

Flash Attention

Flash Attention 采样分而治之的思想,将大矩阵切块加载到SRAM中,计算每个分块的m和l值。利用上一轮m和l值结合新的子块迭代计算,最终计算出整个矩阵的树枝。Flash Attention基本上可以归结为两个主要思想:

  •  Tiling (在前向和后向传递中使用) - 简单讲就是将NxN的softmax分数矩阵划分为块。
  • 重新计算(因为每个块的系数不一样,Flash Attention每融合一个小块,就需要调整一下之前块的系数,去保持一致!)
  • 传统attention需要分配完整的NxN矩阵(S, P),这是main需要解决的瓶颈,这也是Flash Attention主要解决的问题,将复杂度从O(N^2)降低到O(N)

整个过程不用存储中间变量S和P矩阵,节省了效率因为Attention 操作最大的问题就是每次操作都要从HBM把数据加载到GPU SRAM,运算结束后又从SRAM复制到HBM。这类似于cpu的寄存器与内存的关系,因此最容易的优化方法就是避免这种数据的来回移动,即编译器行话"kernel fusion"。

Flash Attention Algorithm

假设输入一个一维向量x^{(i)} = [x_1,x_2,...,x_B],对应于QK=Sij相似度矩阵中的一行向量。 

1. softmax分块计算:

  • m(x) = max(xi),这是rowmax 操作这是单个值
  • f(x) = [e^{x_1-m(x),..., e^{x_B-m(x)}}]。对应原公式的\tilde{P}_{ij}then why xi-m(x)?这是为了数值稳定,每个数减去相同的任一常量,其softmax值不变。==》减去最大的元素,保证最大值为e^0=1,因为在0~1之间时,浮点数的精度是最大的。
  • l(x) = \sum_if(x)_i,对应原公式\tilde{l}_{ij}这是rowsum 操作
  • so\!ftmax = \frac{f(x)}{l(x)}, softmax除法可以写成diag(l(x))^{-1},把l(x)拉伸成diag的主要原因是把更新公式写成矩阵乘法的形式

2. Flash Attention每次都是合并两块:previous blocks result + latest block。如何保证每一个小块的合并结果与原有attention结果相同?搞好softmax系数的一致性!

  •  因为each step都需要重新计算m(x) = max(m^{(i)}),而m(x)是变的,前面blocks在合并之前,需要先通过m_i - m_i^{new}修正之前block的系数,\tilde{m}_{ij}是指第ij单个block的max(x),不涉及之前blocks的max值
  • m(x) = m([x^{(1), x^{(2)}}]) = max(m(x^{(1)}, m^{(2)}))
  • f(x) = [e^{m(x^{(1)})-m(x))}f(x^{(1)}, e^{m(x^{(2)})-m(x))}f(x^{(2)})]
  • l(x) = e^{m(x^{(1)})-m(x))}l(x^{(1)}, e^{m(x^{(2)})-m(x))}l(x^{(2)})修正系数m_i - m_i^{new}保持一致,因为这两个blocks的softmax系数不一致,m(x^{(2)})-m(x)保证最新的single block的softmax系数与之前的一致!
  • so\!ftmax = \frac{f(x)}{l(x)}

举例:假设x \in R^6,并且它被分成3块:x^{(1)} = [1,3]x^{(2)} = [2,4]x^{(3)} = [3,2]

我们先计算前两块:

  • m(x^{(1)})=3, f(x^{(1)})=[e^{-2},1], l(x^{(1)})=(e^{-2}+1)
  • m(x^{(2)})=4, f(x^{(2)})=[e^{-2},1], l(x^{(2)})=(e^{-2}+1)

我们根据上面的结果计算前两块的结果:

  • m(x) = max(m(x^{(1)}), m(x^{(2)})) = max(3,4)=4
  • f(x) = [e^{3-4}f(x^{(1)}), e^{4-4}f(x^{(2)})]
  • l(x) = e^{3-4}l(x^{(1)}) + e^{4-4}l(x^{(2)})

为什么上面的结果是正确的呢?首先m(x)应该非常明显,4个数中的最大数肯定就是分成两组后的最大中的较大者。而f(x)计算的核心就是在𝑓(𝑥(1))𝑓(𝑥(1))前乘以𝑒3−4𝑒3−4以及在𝑓(𝑥(2))𝑓(𝑥(2))前乘以𝑒4−4𝑒4−4。l(x)的计算和f(x)是类似的。为什么需要在𝑓(𝑥(1))𝑓(𝑥(1))前乘以𝑒3−4𝑒3−4?因为在计算𝑓(𝑥(1))𝑓(𝑥(1))时最大的数是3,因此前两个数的指数都乘以了𝑒−3𝑒−3。但是现在前4个数的最大是4了,后面两个数的指数乘以了𝑒−4𝑒−4,因此直接合并为[𝑓(𝑥(1)),𝑓(𝑥(2))][𝑓(𝑥(1)),𝑓(𝑥(2))]是不对的,需要把前面两个数再乘以𝑒3−4=𝑒−1𝑒3−4=𝑒−1。而后面两个数本来就乘以了𝑒−4𝑒−4,所以不用变

计算output Oi:我们把一个很大的x拆分成长度为B的blocks,用上面的算法计算block 1和block 2,然后合并其结果;接着计算第3块,并将above 结果与第三块合并; ... =》所以,我们在定义时,可以把空块x=[], m(x)=-inf, f(x)=[], l(x)=0,这样我们就可以把第一块block的计算转换成block 1和空块的合并,使得循环可以从第一块开始!

  • O_1 = diag(l_1)^{-1}(0 * 0 + e^{\tilde{m}_{ij}-m_i^{new}}\tilde{P}_{ij}V_j)
  •  O_2 = diag(l_i^{new})^{-1}(diag(l_i)O_ie^{m_i-m_i^{new}} + e^{\tilde{m}_{ij}-m_i^{new}}\tilde{P}_{ij}V_j)

因为Flash Attention不存储中间变量S和P矩阵,所以我们用diag(l_i)O_i反推出之前的PV值,再用e^{m_i-m_i^{new}}修正系数,最后加上第ij块e^{\tilde{m}_{ij}-m_i^{new}}\tilde{P}_{ij}V_j) with single e^{\tilde{m}_{ij}},得到的结果最后再除以diag(l_i^{new})^{-1}保持softmax运算完整性。

参考

Flash Attention论文解读 - 李理的博客

https://gordicaleksa.medium.com/eli5-flash-attention-5c44017022ad

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/25873.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

数据结构复习笔记

简答题 (3) 顺序表和链表的概念及异同 顺序表: 把逻辑上相邻的结点储存在物理位置上的相邻储存单元中,结点的逻辑关系由储存单元的邻接关系来体现.链表: 逻辑上相邻的结点存储再物理位置上非连续非顺序的存储单元中, 结点的逻辑关系由指向下一个结点的指针确保.相…

抓包工具 HttpAnalyzerFull_V7.6.4 的下载、安装、使用

目录 一、简介二、下载和安装三、如何注册四、使用介绍4.1 开始、停止、清空监控内容4.2 筛选监控内容4.3 监控内容显示 一、简介 Http Analyzer 是一款功能强大的数据包分析工具,它可以实时监控服务器返回的消息,支持64位Windows系统,可以同…

kaggle竞赛实战9——模型融合

有三种方法, 第一种:均值融合,代码如下 data pd.read_csv(\ result/submission_randomforest.csv\ ) data[randomforest] data[target].values temp pd.read_csv(\ result/submission_lightgbm.csv\ ) …

C++必修:探索C++的内存管理

✨✨ 欢迎大家来到贝蒂大讲堂✨✨ 🎈🎈养成好习惯,先赞后看哦~🎈🎈 所属专栏:C学习 贝蒂的主页:Betty’s blog 1. C/C的内存分布 我们首先来看一段代码及其相关问题 int globalVar 1; static…

微信小程序毕业设计-网吧在线选座系统项目开发实战(附源码+论文)

大家好!我是程序猿老A,感谢您阅读本文,欢迎一键三连哦。 💞当前专栏:微信小程序毕业设计 精彩专栏推荐👇🏻👇🏻👇🏻 🎀 Python毕业设计…

力扣 T62 不同路径

题目 连接 思路 思路1 &#xff1a; BFS爆搜 class Solution { public:queue<pair<int,int>>q;int uniquePaths(int m, int n) {q.push({1,1}); // 起始位置vector<pair<int, int>> actions;actions.push_back({0, 1}); // 向下actions.push_bac…

【网络编程开发】11.IO模型 12.IO多路复用

11.IO模型 什么是IO: IO 是 Input/Output 的缩写&#xff0c;指的是输入和输出。在计算机当中&#xff0c;IO 操作通常指将数据从一个设备或文件中读取到计算机内存中&#xff0c;或将内存中的数据写入设备或文件中。这些设备可以包括硬盘驱动器、网卡、键盘、屏幕等。 通常用…

selenium自动化测试入门 —— Alert/Confirm/Prompt 弹出窗口处理!

一、Alert/Confirm/Prompt弹出窗口特征说明 Alert弹出窗口&#xff1a; 提示用户信息只有确认按钮&#xff0c;无法通过页面元素定位&#xff0c;不关闭窗口无法在页面上做其他操作。 Confirm 弹出窗口&#xff1a; 有确认和取消按钮&#xff0c;该弹出窗口无法用页面元素定…

06_深度学习历史的里程碑--重读AlexNet

1.1 介绍 AlexNet是深度学习历史上一个非常重要的卷积神经网络&#xff08;Convolutional Neural Network, CNN&#xff09;模型&#xff0c;由Alex Krizhevsky、Ilya Sutskever和Geoffrey Hinton在2012年设计并提出。它因在ImageNet大规模视觉识别挑战赛中的卓越表现而闻名&a…

2024世界技能大赛某省选拔赛“网络安全项目”B模块--数据包分析(jsp流量解密)

2024世界技能大赛某省选拔赛“网络安全项目”B模块--数据包分析② 任务一、网络数据包分析取证解析:任务一、网络数据包分析取证解析: A 集团的网络安全监控系统发现有恶意攻击者对集团官方网站进行攻击,并抓取了部分可疑流量包。请您根据捕捉到的流量包,搜寻出网络攻击线…

冯喜运:6.10周一黄金还会再次拉升吗?日内黄金原油操作策略

【黄金消息面分析】&#xff1a;周一(6月10日)亚市盘中&#xff0c;现货黄金交在上周五暴跌后仍然承压&#xff0c;目前金价位于2294美元/盎司左右。因强劲非农数据刺激美元大涨&#xff0c;现货黄金上周五出现暴跌。此外&#xff0c;上周五数据显示&#xff0c;最大黄金消费国…

在python中关于元组的操作

创建元组 如上图所示&#xff0c;a&#xff08;&#xff09;和b tuple(),,这两种方式都可以创建出元组。 在创建元组的时候&#xff0c;指定初始值 如上图所示&#xff0c;也可以在创建元组的时候&#xff0c;指定初始值。 同列表一样元组中的元素也可以是任意类型的。 同列…

Qt 布局管理

布局基础 1)Qt 布局管理系统使用的类的继承关系如下图: QLayout 和 QLayoutItem 这两个类是抽象类,当设计自定义的布局管理器时才会使用到,通常使用的是由 Qt 实现的 QLayout 的几个子类。 2)Qt 使用布局管理器的步骤如下: 首先创建一个布局管理器类的对象。然后使用该…

封装了一个简单理解的iOS竖直文字轮播

效果图 原理 就是持有两个视图&#xff0c;并且两个视图同时改变origin.y 动画结束之后&#xff0c;判断哪个视图是在上面并且看不到的&#xff0c; 则将该视图移动到底部&#xff0c;并且该视图展示下一跳内容 在开始下一轮动画 代码 - (void)startAnimationWithDuration:(…

【Linux】网络配置(静态/动态/手动/nmcli)

目录 一、手动修改网络配置文件&#xff1a;静态 二、手动修改网络配置文件&#xff1a;动态 三、nmcli工具命令修改网络配置文件&#xff1a;静态 四、nmcli工具命令修改网络配置文件&#xff1a;动态 错误排查分析&#xff1a;编辑虚拟网络编辑器不生效 1、排除VMware启…

攻防世界---misc---gif

1、题目描述 2、下载附件&#xff0c;是一堆黑白图片&#xff0c;看到这里我一头雾水 3、看别人写的wp&#xff0c;说是白色表示0&#xff0c;黑色表示1。按照顺序写出来后得到 4、解码的时候&#xff0c;把逗号去掉。二进制转字符串得到&#xff1a; 5、 flag{FuN_giF}

阿里通义千问 Qwen2 大模型开源发布

阿里通义千问 Qwen2 大模型开源发布 Qwen2 系列模型是 Qwen1.5 系列模型的重大升级。该系列包括了五个不同尺寸的预训练和指令微调模型&#xff1a;Qwen2-0.5B、Qwen2-1.5B、Qwen2-7B、Qwen2-57B-A14B 以及 Qwen2-72B。 在中文和英文的基础上&#xff0c;Qwen2 系列的训练数…

深度学习与人工智能

深度学习&#xff0c;是一种特殊的人工智能&#xff0c;他与人工智能及机器学习的关系如下&#xff1a; 近些年来&#xff0c;基于人工神经网络的机器学习算法日益盛行起来&#xff0c;逐渐呈现出取代其他机器学习算法的态势&#xff0c;这主要的原因是因为人工神经网络中有一中…

php高级之框架源码、宏扩展原理与开发

在使用框架的时候我们经常会看到如下代码 类的方法不会显示地声明在代码里面&#xff0c;而是通过扩展的形式后续加进去&#xff0c;这么做的好处是可以降低代码的耦合度、保证源码的完整性、团队开发的时候可以分别写自己的服务去扩展类&#xff0c;减少代码冲突等等。我自己…

C语言之常用字符串函数总结、使用和模拟实现

文章目录 目录 一、strlen 的使用和模拟实现 二、strcpy 的使用及模拟实现 三、strcat 的使用和模拟实现 四、strcmp 的使用和模拟实现 五、strncpy 的使用和模拟实现 六、strncat 的使用和模拟实现 七、strncmp 的使用和模拟实现 八、strstr 的使用和模拟实现 九、st…