- FlashAttention最基础的方案来自使用高速的share memory来加速Softmax操作,实现Softmax的tiling方案。(Q,K,V之间的乘法可由gemm实现。)
左侧为GPU各部分的访问速度比较
- FlashAttention使用平铺来防止大型实体化𝑁 ×𝑁 注意力矩阵(虚线框)在(相对)慢的GPU HBM上。
中间为实现过程
- softmax的计算公式
注:我也比较好奇,softmax公式怎么好像变得复杂了?我在参考文献60中找到了答案:
不幸的是,在所表示的数字范围有限的实际硬件上,算法1的第3行(求分母的时候)可能由于指数而上溢或下溢。得到这这种安全形式的改写。 - 作者提出的分解方法
右侧为融合核函数和pytorch实现的速度比较
CG
-
https://github.com/Dao-AILab/flash-attention
-
Jax上继承了Numpy计算加速,XLA加速,JIT编译,自动微分等,以下代码不用自己实现cuda函数Implementation of Flash Attention in Jax
-
cuda实现 https://github.com/lucidrains/flash-cosine-sim-attention/tree/main
-
https://github.com/jundaf2/INT8-Flash-Attention-FMHA-Quantization
-
https://github.com/kyegomez/FlashAttention20Triton
-
https://github.com/Lightning-AI/lit-llama
-
Add Flash-Attention to Huggingface Models https://github.com/conceptofmind/flash-gpt
-
https://www.zhihu.com/question/611236756/answer/3136806315