在 C++/DirectX着色器 的基础上速成CUDA编程,还好思维模式基本通用,就多了线程组排布和共享内存方面的东西,入门还行,高级加速方面就不太行了。
代码仓库:https://github.com/One-sixth/flash-linear-attention-pytorch
虽然相比纯 pytorch 的实现,更大幅度地减少了显存的消耗,但是速度没有办法,缺乏精力和时间去深究 CUDA优化,非常消耗时间。
算了。作为其他类型的 线性注意力算子的起步
5种写法的算子介绍
normal_linear_attention_ops.py
原始方式,显存占用最大,速度最快
可读性:最佳
显存消耗:1X (O^2)
速度:1X
flash_linear_attention_ops.py
原始分块方式
内部使用torch.split,不需要填充到指定长度
可读性:佳
显存消耗:0.7X
速度:0.5X
flash_linear_attention_ops_2.py
基于 flash_linear_attention_ops.py 改为块索引方式,略快一丁点
内部需要填充到指定倍数长度
可读性:佳
显存消耗:0.7X
速度:0.505X
flash_linear_attention_ops_3.py
基于 flash_linear_attention_ops_2.py 加入显式内存复用方式,略快一丁点
即在一开始就分配所有需要的显存,在计算过程中,完全不需要新的显存分配
内部需要填充到指定倍数长度
可读性:中
显存消耗:0.7X
速度:0.51X
flash_linear_attention_ops_4.py
基于 flash_linear_attention_ops_3.py,改为CUDA/C++算子方式
本人的CUDA/C++技术有限,没有精力继续研究了
内部需要填充到指定倍数长度
限制很多,不支持float32以外的数据类型
可以作为其他类型线性注意力的参考实现
算子已经通过 pytorch2.1 + CUDA12.1 环境测试
可读性:较差
显存消耗:0.3X
速度:0.33X