Pytorch官方FlashAttention速度测试

在Pytorch的2.2版本更新文档中,官方重点强调了通过实现FlashAtteneion-v2实现了对scaled_dot_product_attention约2X左右的加速。
在这里插入图片描述
今天抽空亲自试了下,看看加速效果是否如官方所说。测试前需要将Pytorch的版本更新到2.2及以上,下面是测试代码,一个是原始手写的Self-Attention的实现,一个是使用Pytorch官方的scaled_dot_product_attention接口:

import time
import torch
import torch.nn.functional as Fdef main():repeat = 100device = torch.device("cuda:0")dtype = torch.float16query = torch.rand(32, 8, 128, 64, dtype=dtype, device=device)key = torch.rand(32, 8, 128, 64, dtype=dtype, device=device)value = torch.rand(32, 8, 128, 64, dtype=dtype, device=device)scale_factor = 0.125ori_time_list = []for _ in range(repeat):torch.cuda.synchronize(device=device)time_start = time.perf_counter()# 原始Self-Attention实现res = torch.softmax(query @ key.transpose(-2, -1) * scale_factor, dim=-1) @ valuetorch.cuda.synchronize(device=device)time_end = time.perf_counter()ori_time_list.append(time_end - time_start)fa_time_list = []for _ in range(repeat):torch.cuda.synchronize(device=device)time_start = time.perf_counter()with torch.backends.cuda.sdp_kernel(enable_math=False):# 使用Pytorch官方提供的FA实现res_fa = F.scaled_dot_product_attention(query, key, value, scale=scale_factor)torch.cuda.synchronize(device=device)time_end = time.perf_counter()fa_time_list.append(time_end - time_start)diff = (res - res_fa).abs().max()ratio = [ori_time_list[i] / fa_time_list[i] for i in range(repeat)]avg_ratio = sum(ratio[1:]) / len(ratio[1:])print(f"max diff: {diff}")print(f"avg speed up ratio: {avg_ratio}")if __name__ == '__main__':main()

执行以上代码,终端输出如下:

max diff: 0.00048828125
avg speed up ratio: 2.2846881043417118

这里使用的设备是RTX4070,跑了很多次发现确实加速2X左右,看来以后训练或者推理时可以考虑直接使用官方的scaled_dot_product_attention接口了。但是这里也发现了两个问题,一个是原始手写的Self-Attention的计算结果和直接调用scaled_dot_product_attention接口得到的结果差异有点大(注意,这里计算的Tensor都是FP16精度的),如果我切换到FP32精度差异会再小两个数量级。第二个问题是如果使用FP32的话实测没有明显加速,这个就很奇怪了,官方文档里并没有说专门针对FP16精度优化的(后面找了个A100的GPU试了下,发现FP32也是有加速的)。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/807285.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【QT教程】QT6SVG处理

QT6SVG处理 使用AI技术辅助生成 QT界面美化视频课程 QT性能优化视频课程 QT原理与源码分析视频课程 QT QML C扩展开发视频课程 免费QT视频课程 您可以看免费1000个QT技术视频 免费QT视频课程 QT统计图和QT数据可视化视频免费看 免费QT视频课程 QT性能优化视频免费看 免费QT视…

全量知识系统 程序详细设计 之 先验逻辑-实现:从“平凡”回到“平凡” (百度文库QA)

Q1. 思考:数学中的平凡,和程序中的平凡(比如POJO)、语言中的平凡(比如纯文本),数据中的平凡(比如 Number)。因为我设计中的全知系统将设计的三个方面刻画为语言设计、程序…

Prometheus-Grafana基础篇安装绘图

首先Prometheus安装 1、下载 https://prometheus.io/download/ 官网路径可以去这儿下载 2、如图: 3.解压: tar -xf prometheus-2.6.1.linux-amd64 cd prometheus-2.6.1.linux-amd64 4.配置文件说明: vim prometheus.yml 5.启动Promethe…

JavaScript教程(一)--- 语法和数据类型

基础 JavaScript 借鉴了 Java 的大部分语法,但同时也受到 Awk、Perl 和 Python 的影响。 JavaScript 是区分大小写的,并使用 Unicode 字符集。举个例子,可以将单词 Frh(在德语中意思是“早”)用作变量名。 var Frh …

问题解决四步法

一、界定问题 1.问题陈述表 1.问题的定义:smart原则 2.背景信息:对解决问题的影响 3.决策人: 4.利益相关者: 5.成功标准: 6.约束条件: 7.问题边界:包含范围不包含问题,地域范…

Liunx和Windows中重启MySql

MySQL 通常不需要在更改密码或执行大多数配置更改后重启。但在某些情况下,如果您更改了配置文件(如 my.cnf 或 my.ini)或需要重置整个数据库状态,您可能需要重启 MySQL 服务。以下是如何查看 MySQL 状态和重启 MySQL 服务的方法。…

kail渗透工具之nmap的使用方法

准备工作:开启两台虚拟机和一台Windows主机 kail Linux攻击机:192.168.80.131 red hat靶机:192.168.80.129 Windows主机:192.168.252.42 1、nmap扫描工具的简介 nmap是用来探测计算机网络上的主机和服务的一种安全扫描器。为了绘…

20240411金融读报:生态保护补偿条例印发外汇局优化贸易外汇业务通知

1、通过《生态保护补偿条例》(生态保护到位的地方、单位、个人予以激励,市场层面含碳排放权、排污权、用水权、碳汇权益等),6月1日起执行。 (可以搞个一站式领取吗,申请绿色贷款的时候,判断其是…

2024年视频号小店无货源,你一定要尝试一下,出九单收入1W+

大家好,我是电商花花。 如果说去年视频号的流量还差点意思,那么今年的视频号销量一定是非常高的,随着视频号的扩展,也让更多的创业者和博主入驻视频号,让更多人了解到了视频号小店,是这样赚钱的。 首先&am…

系统架构评估_1.相关概念

1.系统架构评估 系统架构评估是在对架构分析、评估的基础上,对架构策略的选取进行决策。它利用数学或逻辑分析技术,针对系统的一致性、正确性、质量属性、规划结果等不同方面,提供描述性、预测性和指令性的分析结果。 2.系统架构评估的方法 …

计算机科学与技术CS考研408资料

在github上整理了考研的一些资料: 内容包括: 王道数据结构、组成原理、操作系统、计算机网络.408笔记PDF408思维导图408真题2009-2021真题无logo版408真题2029-2023王道真题(持续更新)历年真题考频统计灰灰考研择校,…

深水采样器小口径特氟龙材质FEP贝勒管

FEP贝勒管,深水采样器(bailers tube),是一种经济型便携式水质采样器,操作简单,使用方便,性价比高,能大限度的保证样品的真实性。采样管直径很小,能够采取小口径的深水井水样。是一款简单实用&am…

磁盘管理显示u盘无媒体怎么恢复数据

随着科技的发展,U盘已成为我们日常生活和工作中不可或缺的数据存储工具。然而,当我们在使用U盘时,有时会遇到“磁盘管理显示U盘无媒体”的困扰。面对这一问题,许多用户可能会感到惊慌失措,担心数据丢失。本文将为您详细…

c语言例题,计算1/1-1/2+1/3-1/4+1/5……+1/99-1/100的值,打印结果

例题:计算分式1/1-1/21/3-1/41/5……1/99-1/100的值,打印结果 根据题目,我们知道需要计算的是一个固定值, 先定义三个变量来当作分式里的三个值,变量i当作分式里的分母部分,通过for循环来实现分母每次循环…

存储器层次结构

内存 对于内存已经不像曾经那般陌生了,在汇编中,我们大量接触了内存,但是我们还没有对它有个确切、深入的了解。 内存其实叫做随机访问存储器(RAM,Random Access Memory),最基本的存储单位称为…

怎么获取OpenAI的api-key【人工智能】

怎么获取OpenAI的api-key【人工智能】 前言版权推荐怎么获取OpenAI的api-key1.访问控制台2.点击API keys3.点击Start verification4.点击新建密钥 最后 前言 2024-4-11 11:32:06 以下内容源自《【人工智能】》 仅供学习交流使用 版权 禁止其他平台发布时删除以下此话 本文首…

自动挡变速箱 相关的东西研究

1. 原来大众POLO变速箱故障时,详细的维修流程是这样的!_易车 原来大众POLO变速箱故障时,详细的维修流程是这样的!_易车 大众POLO七速干式双离合变速箱OAM 本文主要讲解的是大众POLO双离合变速箱的维修案例,首先说一…

【C++】模版

目录 一、泛型编程二、函数模板2.1 函数模板概念2.2 函数模板格式2.3 函数模板的原理2.4 函数模板的实例化2.5 模板参数的匹配原则 三、类模板3.1 类模板的定义格式3.2 类模板的实例化 四、非类型模板参数五、模板的特化5.1 概念5.2 函数模板特化5.3 类模板特化5.3.1 全特化5.3…

C++初阶:模板进阶

非类型模板参数 模板参数分为类型形参与非类型形参 。 类型形参即:出现在模板参数列表中,跟在 class 或者 typename 之类的参数类型名称 。 非类型形参,就是用一个常量作为类 ( 函数 ) 模板的一个参数,在类 ( 函数 ) 模板中可将…

http与https的区别?

1、HTTP 的最大弊端——不安全 HTTP 之所以被 HTTPS 取代,最大的原因就是不安全,因为http中所有的数据都是明文传输的,自然没有安全性可言,特别是一些敏感数据,比如用户密码和信用卡信息等,一旦被第三方获取…