在Pytorch的2.2版本更新文档中,官方重点强调了通过实现FlashAtteneion-v2
实现了对scaled_dot_product_attention
约2X左右的加速。
今天抽空亲自试了下,看看加速效果是否如官方所说。测试前需要将Pytorch的版本更新到2.2及以上,下面是测试代码,一个是原始手写的Self-Attention
的实现,一个是使用Pytorch官方的scaled_dot_product_attention
接口:
import time
import torch
import torch.nn.functional as Fdef main():repeat = 100device = torch.device("cuda:0")dtype = torch.float16query = torch.rand(32, 8, 128, 64, dtype=dtype, device=device)key = torch.rand(32, 8, 128, 64, dtype=dtype, device=device)value = torch.rand(32, 8, 128, 64, dtype=dtype, device=device)scale_factor = 0.125ori_time_list = []for _ in range(repeat):torch.cuda.synchronize(device=device)time_start = time.perf_counter()# 原始Self-Attention实现res = torch.softmax(query @ key.transpose(-2, -1) * scale_factor, dim=-1) @ valuetorch.cuda.synchronize(device=device)time_end = time.perf_counter()ori_time_list.append(time_end - time_start)fa_time_list = []for _ in range(repeat):torch.cuda.synchronize(device=device)time_start = time.perf_counter()with torch.backends.cuda.sdp_kernel(enable_math=False):# 使用Pytorch官方提供的FA实现res_fa = F.scaled_dot_product_attention(query, key, value, scale=scale_factor)torch.cuda.synchronize(device=device)time_end = time.perf_counter()fa_time_list.append(time_end - time_start)diff = (res - res_fa).abs().max()ratio = [ori_time_list[i] / fa_time_list[i] for i in range(repeat)]avg_ratio = sum(ratio[1:]) / len(ratio[1:])print(f"max diff: {diff}")print(f"avg speed up ratio: {avg_ratio}")if __name__ == '__main__':main()
执行以上代码,终端输出如下:
max diff: 0.00048828125
avg speed up ratio: 2.2846881043417118
这里使用的设备是RTX4070
,跑了很多次发现确实加速2X左右,看来以后训练或者推理时可以考虑直接使用官方的scaled_dot_product_attention
接口了。但是这里也发现了两个问题,一个是原始手写的Self-Attention
的计算结果和直接调用scaled_dot_product_attention
接口得到的结果差异有点大(注意,这里计算的Tensor都是FP16精度的),如果我切换到FP32精度差异会再小两个数量级。第二个问题是如果使用FP32的话实测没有明显加速,这个就很奇怪了,官方文档里并没有说专门针对FP16精度优化的(后面找了个A100的GPU试了下,发现FP32也是有加速的)。