【BBuf的CUDA笔记】十四,OpenAI Triton入门笔记三 FusedAttention

0x0. 前言

继续Triton的学习,这次来到 https://triton-lang.org/main/getting-started/tutorials/06-fused-attention.html 教程。也就是如何使用Triton来实现FlashAttention V2。对于FlashAttention和FlashAttention V2网上已经有非常多的介绍了,大家如果感兴趣的话我推荐FlashAttention V1看 《图解大模型计算加速系列:FlashAttention V1,从硬件到计算逻辑》https://zhuanlan.zhihu.com/p/669926191 这篇文章的讲解 以及 FlashAttention V2 看 《图解大模型计算加速系列:Flash Attention V2,从原理到并行计算》 https://mp.weixin.qq.com/s/5K6yNj23NmNLcAQofHcT4Q ,原理和公式推导都非常清晰,不过想一口气读完还是要花一些精力的。同时你也可以在 https://github.com/BBuf/how-to-optim-algorithm-in-cuda 找到更多相关资料(此外Meagtron-LM,DeepSpeed等训练Infra框架的迅速跟进也说明了FlashAttention这个系列工作影响之大),例如:

在这里插入图片描述

这篇文章主要的问题是读懂如何使用Triton来实现FlashAttention V2的前向,所以我不会去复述FlashAttention的公式细节,而是从更加工程的角度来说FlashAttention Forward的代码应该如何实现,我在这个过程中也会提供FlashAttention V1/V2 Forward的一个最简Python实现来非常直观的把握代码流程,在这个基础上才会展开对Triton FlashAttention实现的解读,让我们开始吧。(后续如果有精力也会写一下Backward的实现

FlashAttention V1/V2的paper链接为:https://arxiv.org/abs/2205.14135 和 https://tridao.me/publications/flash2/flash2.pdf 。
本文涉及到的实验代码见我的个人仓库:https://github.com/BBuf/how-to-optim-algorithm-in-cuda/tree/master/triton ,也欢迎大家点star。

0x1. BenchMark

跑了一下 https://triton-lang.org/main/getting-started/tutorials/06-fused-attention.html 这个教程里的FlashAttention V2的BenchMark。

对于Batch=4,Head=48,HeadDim=64,causal=True的Flash Attention V2 Forward,对比不同序列长度下Triton实现和cutlass实现版本的性能:

在这里插入图片描述

对于Batch=4,Head=48,HeadDim=64,causal=False的Flash Attention V2 Forward,对比不同序列长度下Triton实现和cutlass实现版本的性能:

在这里插入图片描述

对于Batch=4,Head=48,HeadDim=64,causal=True的Flash Attention V2 Backward,对比不同序列长度下Triton实现和cutlass实现版本的性能:

在这里插入图片描述

在这组配置下Triton在各种Sequence Length下都实现了比cutlass更优的性能,然后在Triton的kernel实现里面有assert Lk in {16, 32, 64, 128},也就是说Triton的实现需要注意力头的隐藏层维度在[16, 32, 64, 128]里,我这里再测一组16的看下表现。

对于Batch=4,Head=48,HeadDim=16,causal=True的Flash Attention V2 Forward,对比不同序列长度下Triton实现和cutlass实现版本的性能:

在这里插入图片描述

对于Batch=4,Head=48,HeadDim=16,causal=False的Flash Attention V2 Forward,对比不同序列长度下Triton实现和cutlass实现版本的性能:

在这里插入图片描述

对于Batch=4,Head=48,HeadDim=16,causal=True的Flash Attention V2 Backward,对比不同序列长度下Triton实现和cutlass实现版本的性能:

在这里插入图片描述

这一组case下虽然Forward Pass还是Triton更快,但是Backward Pass却是cutlass更快了。

另外之前在Triton的issue里面还刷如果HeadDim=128,Triton的Bakcward会比cutlass慢更多:https://github.com/openai/triton/issues/1975 ,参数设置为 BATCH, N_HEADS, N_CTX, D_HEAD = 8, 32, 4096, 128 这里也测试一下:

在这里插入图片描述

反向的耗时对比图:

在这里插入图片描述

结果很神奇,这个反向耗时的差距非常大,且Triton的速度远好于Cutlass的实现,并且随着序列长度的增加Triton的反向的耗时竟然是接近横定的。。保险起见还是建议大家用官方FlashAttention库提供的实现,我现在使用的Triton版本为2.1.0。

0x2. 标准Attention流程以及Python极简实现

从FlashAttention的paper里面截一下标准Attention流程:

在这里插入图片描述

我这里再描述一下流程,首先从HBM中加载 Q , K Q,K Q,K, V V V矩阵,接着执行 S = Q K T S=QK^T S=QKT的计算,并将结果 S S S写回HBM;然后将 S S S再从HBM中读取出来,执行 P = s o f t m a x ( S ) P=softmax(S) P=softmax(S)的计算,再将 P P P写回HBM;然后将 P P P V V V从HBM中读取出来,执行 O = P V O=PV O=PV的计算,最后把结果写回HBM中。对于, Q , K , V , O Q,K,V,O Q,K,V,O,他们的维度都是 N × d N\times d N×d,中间变量 S S S P P P的维度都是 N × N N\times N N×N。这里还有个问题就是对于S和P可能还会有一些其它的操作比如Mask和Dropout,所以上面也提到了有不少的fuse kernel的工作,比如把softmax和mask fuse起来。最后,这里的softmax是PyTorch的softmax算子,也是safe softmax的实现,safe的意思就是在naive softmac的基础上对指数上的每个原始输入值都减掉所有原始输入值中的最大值。具体请参考下面的图片,来源于 https://arxiv.org/pdf/2205.14135.pdf :

在这里插入图片描述

对于safe softmax来说,所有的值都减掉了输入向量里面的最大值,保证了指数部分的最大值是0,避免了数值溢出。

为了验证正确性,我写了一个脚本,这个地方以经典的GPT2为例,然后硬件以A100为例 。这里的 N N N d d d分别设置成1024和64,那么Q,K,V的shape都是 ( N , d ) = ( 1024 , 64 ) (N, d)=(1024, 64) (N,d)=(1024,64),S和P的维度都是 ( N , N ) (N, N) (N,N)

代码实现具体如下:

import torchN, d = 1024, 64  # 更新N和d的值Q_mat = torch.rand((N, d))
K_mat = torch.rand((N, d))
V_mat = torch.rand((N, d))def standard_softmax_attention(Q, K, V):"""执行标准的pytorch softmax和attention计算。"""expected_softmax = torch.softmax(Q @ K.T, dim=1)expected_attention = expected_softmax @ Vreturn expected_softmax, expected_attentiondef safe_softmax_attention(Q, K, V):"""执行安全的softmax和attention计算。"""S_mat = Q @ K.Trow_max = torch.max(S_mat, dim=1).values[:, None]input_safe = S_mat - row_maxsoftmax_numerator = torch.exp(input_safe)softmax_denominator = torch.sum(softmax_numerator, dim=1)[:, None]safe_softmax = softmax_numerator / softmax_denominatormatmul_result = safe_softmax @ Vreturn safe_softmax, matmul_result# 使用标准softmax和attention计算
expected_softmax, expected_attention = standard_softmax_attention(Q_mat, K_mat, V_mat)
# 使用安全softmax和attention计算
safe_softmax, safe_attention = safe_softmax_attention(Q_mat, K_mat, V_mat)# 断言两种方法计算的softmax和attention结果是否接近
assert torch.allclose(safe_softmax, expected_softmax), "error in safe softmax"
assert torch.allclose(safe_attention, expected_attention), "error in safe attention"

测试可以正确通过,也说明了PyTorch的torch.softmax算子的确是用safe softmax的方法来实现的。

0x3. FlashAttention V1 Forward Pass以及Python极简实现

FlashAttention V1通过分块计算的方法,将Q、K和V切块成很多小块,然后将这些切分后的小块放进SRAM(shared memory)中执行计算,最后再写回HBM中。算法流程如下:

在这里插入图片描述

如果你想完全搞清楚这个伪代码的来龙去脉推荐看 https://zhuanlan.zhihu.com/p/669926191 这篇文章,但是从源码实现的角度来看,有了这个伪代码已经接近够了。只需要知道这些看似奇奇怪怪的公式是因为在分块遍历的时候每次计算的是一部分token,而自注意力机制要计算的最终结果是所有token间的,所以从局部到整体的更新就会用到在线的softmax算法以及在线更新最后的输出。这也是上面那堆复杂的公式由来。

我这里尝试用Python来模拟一下这个算法的流程,实现之后对Triton的实现会有帮助,因为从前面几节Triton的教程来看,相比于单纯的Python实现Triton kernel只是多了一个块级别的kernel启动过程而已。沿用上一节GPT2的设置, N N N d d d分别设置成1024和64,那么Q,K,V的shape都是 ( N , d ) = ( 1024 , 64 ) (N, d)=(1024, 64) (N,d)=(1024,64),注意在FlashAttention里面就没有全局的S和P了。假设硬件是A100,A100的Shared Memory大小为192KB=196608B,那么可以计算出这里Flash Attention的分块大小,也就是上面的伪代码的第一行。

B c = M / 4 / 64 = 768 B_c=M/4/64=768 Bc=M/4/64=768 B r = m i n ( 768 , 64 ) = 64 B_r=min(768, 64)=64 Br=min(768,64)=64

然后伪代码的第2行初始化了一个全0的输出矩阵 O O O,shape的大小也是 ( N , d ) = ( 1024 , 64 ) (N, d)=(1024, 64) (N,d)=(1024,64),同时初始化了一个 l l l m m m矩阵,维度大小都是 ( N ) (N) (N),不过 l l l被初始化为全0矩阵, m m m被初始化为负无穷大。

接下来可以根据上面的参数直接计算出 T r T_r Tr T c T_c Tc,对应伪代码的第3行, T r = 向上取整 ( N / B r ) = 1024 / 64 = 16 T_r=向上取整(N/B_r)=1024/64=16 Tr=向上取整(N/Br)=1024/64=16 T c = 向上取整 ( N / B c ) = 1024 / 768 = 2 T_c=向上取整(N/B_c)=1024/768=2 Tc=向上取整(N/Bc)=1024/768=2

接下来的伪代码解析我直接放到下面的Python实现里,每一行代码都可以对应到上面的伪代码:

import torchN, d = 1024, 64  # 更新N和d的值Q_mat = torch.rand((N, d))
K_mat = torch.rand((N, d))
V_mat = torch.rand((N, d))def standard_softmax_attention(Q, K, V):"""执行标准的pytorch softmax和attention计算。"""expected_softmax = torch.softmax(Q @ K.T, dim=1)expected_attention = expected_softmax @ Vreturn expected_softmax, expected_attentiondef flash_attention(Q, K, V, B_r=64, B_c=768):"""使用分块计算和在线softmax校正执行flash attention算法。"""O = torch.zeros((N, d))  # 初始化输出矩阵,对应伪代码的第2行l = torch.zeros((N, 1))  # 存储softmax分母,对应伪代码的第2行m = torch.full((N, 1), -torch.inf)  # 存储每个block的最大值,对应伪代码的第2行# 对应伪代码的第5行,for 1<=j<=T_c,注意这里是把K, V分成了T_c=[N/B_c]块,每一块的大小是[B_c, d]这么大# 所以在python实现的时候就直接通过一个步长为B_c的循环来处理for j in range(0, N, B_c):# 下面三行就对应了伪代码的第6行,Load Kj, Vj from HBM to on-chip SRAM# 但是这里是单纯的 python 实现,我们不可能真的把这一块内存从HBM上放到SRAM上# 这里只是一个伪代码的逻辑说明,可以假装它做到了,因为在Triton里面真的可以在Python层做到。j_end = j + B_cKj = K[j:j_end, :]Vj = V[j:j_end, :]# 对应伪代码的第7行,for 1<=i<T_r,注意这里是把Q分成了Tr=[N/B_r]块,每一块的大小是[B_r, d]这么大# 所以在python实现的时候就直接通过一个步长为B_r的循环来处理for i in range(0, N, B_r):i_end = i + B_rmi = m[i:i_end, :]li = l[i:i_end, :]Oi = O[i:i_end, :]Qi = Q[i:i_end, :]# 对应伪代码的第9行:on chip, compute Sij,Sij的形状是[B_r, B_c]Sij = Qi @ Kj.T# 对应伪代码的第10行mij_hat = torch.max(Sij, dim=1).values[:, None]pij_hat = torch.exp(Sij - mij_hat)lij_hat = torch.sum(pij_hat, dim=1)[:, None]# 对应伪代码的第11行求mi_new的操作,注意这里要对两个张量求整体的max,所以才有这个stack操作mi_new = torch.max(torch.column_stack([mi, mij_hat]), dim=1).values[:, None]# 对应伪代码的第11行求li_new的操作li_new = torch.exp(mi - mi_new) * li + torch.exp(mij_hat - mi_new) * lij_hat# 对应伪代码的第12行,更新O_i。这里容易有一个疑问,伪代码上有一个diag操作,为什么下面的实现忽略了# 这是因为这个diag是作用在vector上的,实际上是为了在伪代码上能对应上维度,而PyTorch的实现是自动# 支持张量广播机制的,所以这里可以直接计算。O_i = (li * torch.exp(mi - mi_new) * Oi / li_new) + (torch.exp(mij_hat - mi_new) * pij_hat / li_new) @ Vj# 对应伪代码的第13行,更新m_i,l_i,O_i。m[i:i_end, :] = mi_newl[i:i_end, :] = li_newO[i:i_end, :] = O_ireturn O# 执行flash attention计算
flash_attention_output = flash_attention(Q_mat, K_mat, V_mat)# 执行标准的pytorch softmax和attention计算
expected_softmax, expected_attention = standard_softmax_attention(Q_mat, K_mat, V_mat)# 断言flash attention计算的结果与标准计算结果是否接近
assert torch.allclose(flash_attention_output, expected_attention), "error in flash attention calculation"

需要说明的是在上面的Attention Forward Pass流程中没有考虑到Dropout以及Mask的操作,如果考虑这两个操作整体的流程有一些变化,具体如Flash Attention V1的paper里的Algorithm2所示:

在这里插入图片描述

相比于Algorithm1,多了Mask和Dropout的操作,其它的没有变化。

0x4. FlashAttention V2 Forward Pass以及Python极简实现

如果你想很清晰的了解FlashAttention V2背后的改进原理请阅读 《图解大模型计算加速系列:Flash Attention V2,从原理到并行计算》 https://mp.weixin.qq.com/s/5K6yNj23NmNLcAQofHcT4Q 。我这里只做一个简单的原理解析,重点是关注代码层面相比于FlashAttention V1 Forward Pass的变化,并基于FlashAttention V1的版本实现FlashAttention V2 Forward Pass。

有了上一节代码的铺垫,Flash Attention V1 Forward Pass其实可以抽象为下面的图(从上面的《图解大模型计算加速系列:Flash Attention V2,从原理到并行计算》文章copy来的):

在这里插入图片描述

这个图和我们的Flash Attention V1实现是完全对应的,需要注意的是图中有6个O的小块,但实际上横着的O只有一个并且是逐步更新的,这里是为了体现分块的思想才画出来的。

这里以 O 0 O_0 O0为例子,我们可以看到 O 00 O_{00} O00 O 01 O_{01} O01共用了 Q 0 Q_0 Q0,FlashAttention V2基于这个观察调整了Flash Attention V1的循环顺序,现在外层循环遍历Q不就可以避免重复访问Q了吗?调整训练的顺序只是FlashAttention V2的操作之一,另外两个比较重要的操作是对计算公式进行了改写尽量减少non-matmul FLOPs,具体来说在计算局部attention时,先不考虑softmax的分母以及将rescale的时机后移,只能感叹作者大佬的数学太强,具体的大家可以参考一下《FlashAttention2详解(性能比FlashAttention提升200%)》https://zhuanlan.zhihu.com/p/645376942 这篇文章的Algorthm的解释。此外,Paper中还提了一个重要的并行性方面的改进,即加入了序列并行,具体说来 FlashAttention V1 在 batch 和 heads 两个维度上进行了并行化,使用一个线程块来处理一个注意力头,总共需要的线程块的数量等于batch和注意力头的乘积。每个block被调到到一个SM上运行,例如A100 GPU上有108个SMs。当block数量很大时(例如≥80),这种调度方式是高效的,因为这几乎可以有效利用GPU上所有计算资源。但是在处理长序列输入(目前训练100k,200k的长文本模型需求逐步增长)时,由于内存限制,通常会减小batch和注意力头数量,这样GPU并行化程度就降低了。基于此,FlashAttention-2在序列长度这一维度上进行并行化,显著提升了GPU的并行度并提升了性能。这些改进我们都可以在 https://github.com/openai/triton/blob/main/python/tutorials/06-fused-attention.py 这个Triton实现中找到,留在下一节细讲。

这里仍然是贴出Flash AttentionV2的算法伪代码,并且使用Python来模拟一下流程。

在这里插入图片描述

对应的python代码以及流程如下,由于这里只考虑了forward pass所以代码里只计算了Attention的输出O没有计算logsumexp L(这个是给backward pass用的):

import torchN, d = 1024, 64  # 更新N和d的值Q_mat = torch.rand((N, d))
K_mat = torch.rand((N, d))
V_mat = torch.rand((N, d))def standard_softmax_attention(Q, K, V):"""执行标准的PyTorch softmax和attention计算。"""expected_softmax = torch.softmax(Q @ K.T, dim=1)expected_attention = expected_softmax @ Vreturn expected_softmax, expected_attentiondef flash_attention_v2(Q, K, V, B_r=64, B_c=768):"""使用分块计算和在线softmax校正执行flash attention v2算法。"""O = torch.zeros((N, d))  # 初始化O为(N, d)的形状,实际上对应伪代码第5行的O初始化l = torch.zeros((N, 1))  # 初始化l为(N)的形状,实际上对应伪代码第5行的l初始化m = torch.full((N, 1), -torch.inf)  # 存储每个block的最大值,初始化为负无穷大,对应伪代码的第5行# 对应伪代码的第3行,for 1<=i<T_r,注意这里是把Q分成了Tr=[N/B_r]块,每一块的大小是[B_r, d]这么大# 所以在python实现的时候就直接通过一个步长为B_r的循环来处理for i in range(0, N, B_r):Qi = Q[i:i+B_r, :]# 对应伪代码的第 6 行,for 1<=j<=T_c,注意这里是把K, V分成了T_c=[N/B_c]块,每一块的大小是[B_c, d]这么大# 所以在python实现的时候就直接通过一个步长为B_c的循环来处理 for j in range(0, N, B_c):  # 内循环遍历Q的块Kj = K[j:j+B_c, :]Vj = V[j:j+B_c, :]# 对应伪代码的第8行:on chip, compute Sij,Sij的形状是[B_r, B_c]Sij = Qi @ Kj.T# 对应伪代码的第9行求m_i^(j)的操作,mi_new的形状是B_rmi_new = torch.max(torch.column_stack([m[i:i+B_r], torch.max(Sij, dim=1).values[:, None]]), dim=1).values[:, None]# 对应伪代码的第9行求Pij_hat的操作,Pij_hat的形状是(B_r x B_c),和Sij一致Pij_hat = torch.exp(Sij - mi_new)# 对应伪代码的第9行求lij的操作l[i:i+B_r] = torch.exp(m[i:i+B_r] - mi_new) * l[i:i+B_r] + torch.sum(Pij_hat, dim=1)[:, None]# 对应伪代码的第10行求O_ij的操作O[i:i+B_r] = O[i:i+B_r] * torch.exp(m[i:i+B_r] - mi_new) + Pij_hat @ Vjm[i:i+B_r] = mi_newO = O / l  # 对应伪代码第12行,根据softmax的分母校正输出return O# 执行flash attention计算
flash_attention_v2_output = flash_attention_v2(Q_mat, K_mat, V_mat)# 执行标准的PyTorch softmax和attention计算
_, expected_attention = standard_softmax_attention(Q_mat, K_mat, V_mat)# 断言flash attention计算的结果与标准计算结果是否接近
assert torch.allclose(flash_attention_v2_output, expected_attention), "Error in flash attention calculation"

然后FlashAttention V2里面还有两节和GPU并行性相关的话,在对Triton实现的解读之前我先把这两节翻译一下。

在这里插入图片描述

翻译:FlashAttention V1在batch和heads两个维度上进行了并行化:使用一个thread block来处理一个attention head,总共需要thread block的数量等于batch size × number of heads。每个block被调到到一个SM上运行,例如A100 GPU上有108个SMs。当block数量很大时(例如≥80),这种调度方式是高效的,因为几乎可以有效利用GPU上所有计算资源。

但是在处理长序列输入时,由于内存限制,通常会减小batch size和head数量,这样并行化成都就降低了。因此,FlashAttention V2还在序列长度这一维度上进行并行化,显著提升了计算速度。此外,当batch size和head数量较小时,在序列长度上增加并行性有助于提高GPU占用率。

Forward pass 这里大概就是说,FlashAttention V1伪代码中有两个循环,K,V在外循环j,Q在内循环i。FlashAttention V2将Q移到了外循环i,K,V移到了内循环 j,由于改进了算法使得warps之间不再需要相互通信去处理,所以外循环可以放在不同的 thread block 上。这个交换的优化方法是由Phil Tillet在Triton提出并实现的,也就是下一节要解读的Triton代码了。我们会看到它启动kernel的时候线程网格有两个维度,其中一个维度是序列长度,另外一个维度是batch和注意力头数的乘积。

在这里插入图片描述

翻译:paper 的3.2节讨论了如何分配thread block,然而在每个thread block内部,我们也需要决定如何在不同的warp之间分配工作。我们通常在每个thread block中使用4或8个warp,如Figure3所示。

FlashAttention forward pass. 这里精简一下,如Figure3所示,外循环对K,V在输入序列N上遍历,内循环对Q在N上遍历。对于每个块,FlashAttention V1将K和V分别分成4个warp,并且所有的warp都可以访问Q。K的warp乘以Q得到S的一部分 S i j S_{ij} Sij,然后 S i j S_{ij} Sij经过局部softmax后还需要乘以V的一部分得到 O i O_i Oi。但是,每次外循环 j + + j++ j++都要更新一次 O i O_i Oi(对上一次的 O i O_i Oi先rescale再加上当前的值),这就导致每个warp需要从HBM里面频繁读写 O i O_i Oi来累计最后的结果,这种方案也被称为"Split-K"方案,整体是低效的,因为所有warp都需要从HBM频繁读写中间结果 ( Q i , O i , m i , l i ) (Q_i, O_i, m_i, l_i) (Qi,Oi,mi,li)。FlashAttention V2 将Q移到了外循环i,K,V移到了内循环j,并将Q分为4个warp,所有warp都可以访问K,V。这样做的好处是,原来FlashAttention每次内循环i++会导致 O i O_i Oi也变换(而 O i O_i Oi需要通过HBM读写),现在每次内循环j++处理的都是 O i O_i Oi,此时 O i O_i Oi是存储在SRAM上的,代价远小于HBM。

0x5. FlashAttention V2 Forward Pass Triton 实现解读

有了上面的铺垫,就可以直接来看Triton的实现了,这里只关注 Forward Pass部分,Triton的核心计算逻辑在下面的这个函数:

@triton.jit
def _attn_fwd_inner(acc, l_i, m_i, q,  #K_block_ptr, V_block_ptr,  #start_m, qk_scale,  #BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr,  #STAGE: tl.constexpr, offs_m: tl.constexpr, offs_n: tl.constexpr,  #N_CTX: tl.constexpr):# range of values handled by this stage# 根据STAGE的值,函数定义了处理的键(K)和值(V)的范围。# 不同的STAGE对应不同的处理范围,支持因果(causal)和非因果(non-causal)的自注意力。if STAGE == 1: # causal = Truelo, hi = 0, start_m * BLOCK_Melif STAGE == 2:lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_Mlo = tl.multiple_of(lo, BLOCK_M)# causal = Falseelse:lo, hi = 0, N_CTX# 使用tl.advance函数调整K和V指针的位置,以正确地从相应的内存位置加载数据。K_block_ptr = tl.advance(K_block_ptr, (0, lo))V_block_ptr = tl.advance(V_block_ptr, (lo, 0))# loop over k, v and update accumulator# 在一个循环中,函数加载键(K)的一个块,计算查询(Q)与这个键块的点积,# 然后根据当前STAGE调整计算结果。如果是STAGE 2并且因果关系为真,会应用一个掩码来屏蔽未来的信息。for start_n in range(lo, hi, BLOCK_N):start_n = tl.multiple_of(start_n, BLOCK_N)# -- compute qk ----k = tl.load(K_block_ptr)qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)qk += tl.dot(q, k)if STAGE == 2:mask = offs_m[:, None] >= (start_n + offs_n[None, :])qk = qk * qk_scale + tl.where(mask, 0, -1.0e6)m_ij = tl.maximum(m_i, tl.max(qk, 1))qk -= m_ij[:, None]else:# 对应算法流程伪代码的第9行的m_ij的计算,和伪代码的区别是这里应用了qk_scalem_ij = tl.maximum(m_i, tl.max(qk, 1) * qk_scale)qk = qk * qk_scale - m_ij[:, None]# 计算p,对应伪代码的第9行的p的计算p = tl.math.exp2(qk)l_ij = tl.sum(p, 1)# -- update m_i and l_ialpha = tl.math.exp2(m_i - m_ij)l_i = l_i * alpha + l_ij# -- update output accumulator --acc = acc * alpha[:, None]# update accv = tl.load(V_block_ptr)acc += tl.dot(p.to(tl.float16), v)# update m_i and l_im_i = m_ijV_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))return acc, l_i, m_i

需要说明的是这个_attn_fwd_inner函数负责的是一小块Q(入参中的q)和KV的计算,代码中的for循环对应的就是伪代码中的对KV的循环,而Q的循环实际上是体现在triton kernel启动的设置,见下面的代码和注释:

# 定义了一个_attention类,它继承自torch.autograd.Function。这允许我们自定义一个操作的前向和后向传播
#(即计算梯度的方式),使其能够与PyTorch的自动梯度计算系统无缝集成。
class _attention(torch.autograd.Function):@staticmethod# forward方法定义了这个自定义操作的前向传播逻辑。ctx是一个上下文对象,用于存储用于反向传播的信息。# q, k, v分别代表query, key, value三个输入Tensor,causal和sm_scale是额外的控制参数。def forward(ctx, q, k, v, causal, sm_scale):# shape constraints# 这几行代码检查输入Tensor的最后一个维度,确保它们的大小相等且为特定的值(16, 32, 64, 或 128)。这是由于实现的特定性能优化需要。Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]assert Lq == Lk and Lk == Lvassert Lk in {16, 32, 64, 128}# 初始化一个与q相同形状和类型的空Tensoro,用于存储输出结果。o = torch.empty_like(q)# 这几行设置了几个关键的性能调优参数,包括处理块的大小(BLOCK_M, BLOCK_N)和# 计算阶段的数量(num_stages)。num_warps指的是每个CUDA block中warp的数量。BLOCK_M = 128BLOCK_N = 64 if Lk <= 64 else 32num_stages = 4 if Lk <= 64 else 3num_warps = 4stage = 3 if causal else 1# 根据CUDA设备的能力(这里检查的是计算能力9.x,即NVIDIA Volta架构及以后的架构),进一步调整num_warps和num_stages。# Tuning for H100if torch.cuda.get_device_capability()[0] == 9:num_warps = 8num_stages = 7 if Lk >= 64 else 3# 计算Triton kernel的网格尺寸。triton.cdiv是一个辅助函数,用于计算向上取整的除法。# q.shape[2]是序列长度,q.shape[0]和q.shape[1]分别是batch和seq lengthgrid = (triton.cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1], 1)# 初始化另一个TensorM,用于在计算过程中存储中间结果。M = torch.empty((q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)# 调用Triton kernel _attn_fwd执行实际的attention计算。这里传递了大量的参数,包括输入Tensor的各个维度、步长(stride)、形状、调优参数等。_attn_fwd[grid](q, k, v, sm_scale, M, o,  #q.stride(0), q.stride(1), q.stride(2), q.stride(3),  #k.stride(0), k.stride(1), k.stride(2), k.stride(3),  #v.stride(0), v.stride(1), v.stride(2), v.stride(3),  #o.stride(0), o.stride(1), o.stride(2), o.stride(3),  #q.shape[0], q.shape[1],  #N_CTX=q.shape[2],  #BLOCK_M=BLOCK_M,  #BLOCK_N=BLOCK_N,  #BLOCK_DMODEL=Lk,  #STAGE=stage,  #num_warps=num_warps,  #num_stages=num_stages  #)

这里的triton.cdiv(q.shape[2], BLOCK_M)其实就是对Q进行分块,需要说明的是这个地方输入的Q,K,V的形状是(Batch, NHeads, Seq, HeadDim),所以这里启动的线程网格有2个维度都是有值的,除了x维度为triton.cdiv(q.shape[2], BLOCK_M),它的y维度则为q.shape[0] * q.shape[1]的乘积(这里的x是在序列维度上切分也导致了后面构造内存指针的时候有一个特殊的order=(1, 0),参数)。也就是说这里的Block数量其实是比较多的,更容易让GPU的SM用满,这个启动方式和FlashAttention V2 paper中提到的启动方式是一致的,具体请看上一节的末尾翻译部分。至于,我们在计算的时候使用多少个warp,这个也是和Paper的设置保持一致,一般是用4个,只有针对H100才用8个。另外就是由于现在的Q,K,V形状和paper中的 ( N , d ) (N, d) (N,d)不一样,所以分块的个数也是不一样的,这里是写死了分块数:

BLOCK_M = 128
BLOCK_N = 64 if Lk <= 64 else 32

最后还有一个_attn_fwd要解析,内容如下:

@triton.jit
# 定义了一个名为_attn_fwd的函数。这个函数是实现注意力机制前向pass的kernel。函数参数包括输入的Query(Q)、Key(K)、Value(V)张量,
# softmax缩放因子(sm_scale),一个中间计算结果(M)和输出张量(Out),以及多个关于这些张量的步长(stride)参数和其他配置常量。
def _attn_fwd(Q, K, V, sm_scale, M, Out,  #stride_qz, stride_qh, stride_qm, stride_qk,  #stride_kz, stride_kh, stride_kn, stride_kk,  #stride_vz, stride_vh, stride_vk, stride_vn,  #stride_oz, stride_oh, stride_om, stride_on,  #Z, H,  #N_CTX: tl.constexpr,  #BLOCK_M: tl.constexpr,  #BLOCK_DMODEL: tl.constexpr,  #BLOCK_N: tl.constexpr,  #STAGE: tl.constexpr  #):# 注意,输入参数里的Z和H分别表示batch size和注意力头数# start_m表示当前kernel program 实例对应的seq维度的偏移,而off_hz表示的是batch*heads维度的偏移。start_m = tl.program_id(0)off_hz = tl.program_id(1)# 这些行计算了两个偏移量off_z和off_h,它们分别代表在batch(或heads)中的位置。off_z = off_hz // Hoff_h = off_hz % H# 计算用于定位Q、K和V张量中当前处理块的偏移量。这是基于先前计算的偏移量和提供的步长参数。qvk_offset = off_z.to(tl.int64) * stride_qz + off_h.to(tl.int64) * stride_qh# block pointers# 使用tl.make_block_ptr创建一个指向Q张量当前处理块的指针。这个函数调用指定了基础地址、形状、步长、偏移量和块形状等,以及如何在内存中访问这个数据块。# N_CTX 是q.shape[2],表示的是序列长度,BLOCK_DMODEL是Lk,表示的是每个注意力头的隐藏层维度大小# 下面几个make_block_ptr创建的张量类似,分别是对K,V以及输出O创建指向当前处理块的指针Q_block_ptr = tl.make_block_ptr(base=Q + qvk_offset,shape=(N_CTX, BLOCK_DMODEL),strides=(stride_qm, stride_qk),offsets=(start_m * BLOCK_M, 0),block_shape=(BLOCK_M, BLOCK_DMODEL),order=(1, 0),)V_block_ptr = tl.make_block_ptr(base=V + qvk_offset,shape=(N_CTX, BLOCK_DMODEL),strides=(stride_vk, stride_vn),offsets=(0, 0),block_shape=(BLOCK_N, BLOCK_DMODEL),order=(1, 0),)K_block_ptr = tl.make_block_ptr(base=K + qvk_offset,shape=(BLOCK_DMODEL, N_CTX),strides=(stride_kk, stride_kn),offsets=(0, 0),block_shape=(BLOCK_DMODEL, BLOCK_N),order=(0, 1),)O_block_ptr = tl.make_block_ptr(base=Out + qvk_offset,shape=(N_CTX, BLOCK_DMODEL),strides=(stride_om, stride_on),offsets=(start_m * BLOCK_M, 0),block_shape=(BLOCK_M, BLOCK_DMODEL),order=(1, 0),)# initialize offsets# 计算M维度(seq维度)上每个线程应处理的元素的起始偏移量。offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)# 计算N维度(batch*heads维度)上每个线程应处理的元素的偏移量。offs_n = tl.arange(0, BLOCK_N)# initialize pointer to m and l# 初始化m向量,m用于存储每个m维度上的最大logit,初始化为负无穷大。m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")# 初始化l向量,l用于累计softmax的分母,初始化为1。l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0# 初始化累加器,用于累积注意力加权和。注意这里的shape是(BLOCK_M, BLOCK_DMODEL)acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)# load scalesqk_scale = sm_scale     # 加载softmax缩放因子。qk_scale *= 1.44269504  # 将softmax缩放因子乘以1/log(2),用于后续计算。# load q: it will stay in SRAM throughoutq = tl.load(Q_block_ptr) # 将Q矩阵的当前块加载到SRAM中,此数据在整个计算过程中保持不变。# stage 1: off-band# For causal = True, STAGE = 3 and _attn_fwd_inner gets 1 as its STAGE# For causal = False, STAGE = 1, and _attn_fwd_inner gets 3 as its STAGEif STAGE & 1:acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_block_ptr,  #start_m, qk_scale,  #BLOCK_M, BLOCK_DMODEL, BLOCK_N,  #4 - STAGE, offs_m, offs_n, N_CTX  #)# stage 2: on-bandif STAGE & 2:# barrier makes it easier for compielr to schedule the# two loops independentlytl.debug_barrier()acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_block_ptr,  #start_m, qk_scale,  #BLOCK_M, BLOCK_DMODEL, BLOCK_N,  #2, offs_m, offs_n, N_CTX  #)# epiloguem_i += tl.math.log2(l_i)acc = acc / l_i[:, None]m_ptrs = M + off_hz * N_CTX + offs_mtl.store(m_ptrs, m_i)tl.store(O_block_ptr, acc.to(Out.type.element_ty))

需要特别注意的是这段代码最后的epilogue部分就对应了FlashAttention V2伪代码中的12行以后的内容,根据softmax的分母部分较正输出。此外,Triton的实现里面考虑了一些paper里面没有的东西比如qk_scalecausal mask,对Q*K的结果S应用了减掉m,使得整个实现看起来要复杂不少,但整体的算法逻辑和并行设置和paper还是一致的。

0x6. 总结

这篇文章主要是对FlasAttention V1/V2进行简单的原理解析和Python精简实现,然后重点是阅读FlashAttention V2的Triton代码实现并做了Benchmark对比。

0x7. 相关资料

  • https://zhuanlan.zhihu.com/p/646084771
  • https://tridao.me/publications/flash2/flash2.pdf
  • https://zhuanlan.zhihu.com/p/681154742
  • https://triton-lang.org/main/getting-started/tutorials/06-fused-attention.html
  • https://mp.weixin.qq.com/s/5K6yNj23NmNLcAQofHcT4Q

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/712838.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Win11系统安装安卓子系统教程

随着Win11系统的不断普及&#xff0c;以及硬件设备的更新换代&#xff0c;我相信很多同学都已经更新并使用到了最新的Win11系统。那么&#xff0c;Win11系统最受期待的功能“Windows Subsystem for Android”&#xff08;简称WSA&#xff09;&#xff0c;即《安卓子系统》。他可…

spring.factories的常用配置项

概述 spring.factories 实现是依赖 spring-core 包里的 SpringFactoriesLoader 类&#xff0c;这个类实现了检索 META-INF/spring.factories 文件&#xff0c;并获取指定接口的配置的功能。 Spring Factories机制提供了一种解耦容器注入的方式&#xff0c;帮助外部包&am…

掘根宝典之C语言字符串输入函数(gets(),fgets(),get_s())

字符串输入前的注意事项 如果想把一个字符串读入程序&#xff0c;首先必须预留该字符串的空间&#xff0c;然后用输入函数获取该字符串 这意味着必须要为字符串分配足够的空间。 不要指望计算机在读取字符串时顺便计算它的长度&#xff0c;然后再分配空间(计算机不会这样做&a…

ai图生文的软件!分享4个受欢迎的!

在数字化时代&#xff0c;随着人工智能技术的飞速发展&#xff0c;AI图生文软件已经成为自媒体人、创作者和广告从业者手中的得力助手。这些软件能够将静态的图片转化为生动的文字&#xff0c;为图片注入灵魂&#xff0c;让观者仿佛置身于画面之中。今天&#xff0c;就让我们一…

LabVIEW和Python开发微细车削控制系统

LabVIEW和Python开发微细车削控制系统 为满足现代精密加工的需求&#xff0c;开发了一套基于LabVIEW和Python的微细车削控制系统。该系统通过模块化设计&#xff0c;实现了高精度的加工控制和G代码的自动生成&#xff0c;有效提高了微细车削加工的自动化水平和编程效率。 项目…

1950-2022年各省逐年平均降水量数据

1950-2022年各省逐年平均降水量数据 1、时间&#xff1a;1950-2022年 2、指标&#xff1a;省逐年平均降水量 3、范围&#xff1a;33省&#xff08;不含澳门&#xff09; 4、指标解释&#xff1a;逐年平均降水数据是指当年的日降水量的年平均值&#xff0c;不是累计值&#…

ONLYOFFICE 桌面编辑器 v8.0 更新内容详细攻略

文章目录 引言PDF 表单RTL 支持电子表格中的新增功能Moodle 集成用密码保护 PDF 文件从“开始”菜单快速创建文档本地界面主题下载安装桌面编辑工具总结 引言 官网链接&#xff1a; ONLYOFFICE 官方网址 ONLYOFFICE 桌面编辑器是一款免费的文档处理软件&#xff0c;适用于 Li…

uniapp实现-审批流程效果

一、实现思路 需要要定义一个变量, 记录当前激活的步骤。通过数组的长度来循环数据&#xff0c;如果有就采用3元一次进行选择。 把循环里面的变量【name、status、time】, 全部替换为取出的那一项的值。然后继续下一次循环。 虚拟的数据都是请求来的, 组装为好渲染的格式。 二…

【打工日常】使用docker部署在线PDF工具

一、Stirling-PDF介绍 Stirling-PDF是一款功能强大的本地托管的基于 Web 的 PDF 操作工具&#xff0c;使用 docker部署。该自托管 Web 应用程序最初是由ChatGPT全权制作的&#xff0c;现已发展到包含广泛的功能来处理您的所有 PDF 需求。允许对 PDF 文件执行各种操作&#xff0…

基于session注册JAva篇springboot

springboot3全家桶&#xff0c;数据库 &#xff1a;redis&#xff0c;mysql 背景环境&#xff1a;邮箱验证码&#xff0c;验证注册 流程&#xff1a;先通过邮箱验证&#xff0c;发送验证码&#xff0c;将获取到的session和验证码&#xff0c;存入redis里&#xff08;发送邮箱…

【leetcode】链表的回文结构

大家好&#xff0c;我是苏貝&#xff0c;本篇博客带大家刷题&#xff0c;如果你觉得我写的还不错的话&#xff0c;可以给我一个赞&#x1f44d;吗&#xff0c;感谢❤️ 点击查看题目 思路: 1.找中间节点 找中间节点的方法在下面这个博文中详细提过 【点击进入&#xff1a;【l…

鸿蒙Harmony应用开发—ArkTS声明式开发(通用属性:布局约束)

通过组件的宽高比和显示优先级约束组件显示效果。 说明&#xff1a; 从API Version 7开始支持。后续版本如有新增内容&#xff0c;则采用上角标单独标记该内容的起始版本。 aspectRatio aspectRatio(value: number) 指定当前组件的宽高比。 卡片能力&#xff1a; 从API vers…

浅谈 Linux 孤儿进程和僵尸进程

文章目录 前言孤儿进程僵尸进程 前言 本文介绍 Linux 中的 孤儿进程 和 僵尸进程。 孤儿进程 在 Linux 中&#xff0c;就是父进程已经结束了&#xff0c;但是子进程还在运行&#xff0c;这个子进程就被称作 孤儿进程。 需要注意两点&#xff1a; 孤儿进程最终会进入孤儿院…

软考-计算题

1.二维矩阵转换成一维矩阵 2.算术表达式&#xff1a; 3.计算完成项目的最少时间&#xff1a;之前和的max&#xff08;必须之前的所有环节都完成&#xff09; 松弛时间&#xff1a;最晚开始时间-最早开始时间 最早&#xff1a;之前环节都完成的和的max 最晚&#xff1a;总时间…

黑猫的牌面

解法&#xff1a; 桶 #include <iostream> #include <vector> #include <algorithm> using namespace std; #define endl \nint main() {ios::sync_with_stdio(false);cin.tie(0); cout.tie(0);vector<int> tong(1001);int t 4;int k, pai;long lon…

LeetCode 每日一题 树合集 Day 16 - 27

终于是开学了&#xff0c;想了想每日一更频率太高&#xff0c;以后每周更新一周的每日一题。 103. 二叉树的锯齿形层序遍历 给你二叉树的根节点 root &#xff0c;返回其节点值的 锯齿形层序遍历 。&#xff08;即先从左往右&#xff0c;再从右往左进行下一层遍历&#xff0c…

嵌入式开发——面试题操作系统(调度算法)

linux7种进程调度算法 1&#xff1a;先来先服务&#xff08;FCFS&#xff09;调度算法 原理&#xff1a;按照进程进入就绪队列的先后次序进行选择。对于进程调度来说&#xff0c;一旦一个进程得到处理机会&#xff0c;它就一直运行下去&#xff0c;直到该进程完成任务或者因等…

阿里云降价,这泼天的富贵你接不接?附云服务器价格表

阿里云能处&#xff0c;关键时刻ta真降价啊&#xff01;2024新年伊始阿里云带头降价了&#xff0c;不只是云服务器&#xff0c;云数据库和存储产品都降价&#xff0c;阿里云新老用户均可购买99元服务器、199元服务器&#xff0c;续费不涨价&#xff0c;阿里云百科aliyunbaike.c…

【力扣hot100】刷题笔记Day17

前言 今天竟然不用开组会&#xff01;天大的好消息&#xff0c;安心刷题了 46. 全排列 - 力扣&#xff08;LeetCode&#xff09; 回溯&#xff08;排列&#xff09; class Solution:def permute(self, nums: List[int]) -> List[List[int]]:# 回溯def backtrack():if len(…

关于游戏报错提示x3daudio1_7.dll丢失怎么修复?多个实测有效方法分享

x3daudio1_7.dll 是一个与 Microsoft DirectX 相关的重要动态链接库&#xff08;DLL&#xff09;文件&#xff0c;它主要服务于Windows操作系统下的多媒体和游戏应用程序。 一、以下是关于 x3daudio1_7.dll 文件的详细介绍 名称与位置&#xff1a; 文件名&#xff1a;x3daud…