基于__torch_dispatch__机制的dump方法

基于__torch_dispatch__机制的dump方法

  • 1.参考链接
  • 2.原理
  • 3.代码
  • 4.效果

之前拦截torch和torch.Tensor的办法,在处理backward时,不能看到aten算子的细节.以下基于__torch_dispatch__机制的方案更节约代码,且能看到调用栈

1.参考链接

[原理] (https://dev-discuss.pytorch.org/t/what-and-why-is-torch-dispatch/557)

2.原理

在这里插入图片描述

3.代码

import os
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
import torch
from torch import nn
import math
import torch.nn.functional as F
from torch.autograd import Variable
import time
import os
import threadingdevice="cuda"
from torch.utils._python_dispatch import TorchDispatchMode
import inspect
import traceback
from dataclasses import dataclass
from typing import Any@dataclass
class _ProfilerState:cls: Anyobject: Any = Nonelock=threading.Lock()
gindex=0
def save_tensor(name,args,index=0):if isinstance(args,torch.Tensor):print(name,index,args.shape)global gindexlock.acquire()torch.save(args,"{}_{}_{}_{}.pt".format(device,gindex,name,index))gindex+=1lock.release()if isinstance(args,tuple):for idx,x in enumerate(args):save_tensor(name,x,index+idx)class TorchDumpDispatchMode(TorchDispatchMode):def __init__(self,parent):super().__init__()self.parent=parentdef __torch_dispatch__(self, func, types, args=(), kwargs=None):func_packet = func._overloadpacket        if kwargs is None:kwargs = {}        enable_dump=Falseif func_packet.__name__ not in ["detach"]:enable_dump=Trueprint(f"Profiling {func_packet.__name__}") for idx,stack in enumerate(inspect.stack()):print(f'{"*"*idx}{stack.filename}{stack.lineno}')if enable_dump:     save_tensor(f"{func_packet.__name__}-input",args)ret= func(*args, **kwargs)if enable_dump:save_tensor(f"{func_packet.__name__}-output",args)return retclass TorchDumper:_CURRENT_Dumper = Nonedef __init__(self,schedule: Any):self.p= _ProfilerState(schedule) def __enter__(self):assert TorchDumper._CURRENT_Dumper is NoneTorchDumper._CURRENT_Dumper = selfif self.p.object is None:o = self.p.cls(self)o.__enter__()self.p.object = oelse:self.p.object.step()return selfdef __exit__(self, exc_type, exc_val, exc_tb):TorchDumper._CURRENT_Dumper = Noneif self.p.object is not None:self.p.object.__exit__(exc_type, exc_val, exc_tb)class Attention(nn.Module):def __init__(self,max_seq_len,head_dim,flash):super().__init__()self.flash = flashself.dropout=0self.attn_dropout = nn.Dropout(self.dropout)self.head_dim=head_dimif not self.flash:print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")mask = torch.full((1, 1, max_seq_len, max_seq_len), float("-inf")).to(device)mask = torch.triu(mask, diagonal=1).half().to(device)self.register_buffer("mask", mask)		def forward(self,xq: torch.Tensor,xk: torch.Tensor,xv: torch.Tensor):if self.flash:output = torch.nn.functional.scaled_dot_product_attention(xq, xk, xv,attn_mask=None, dropout_p=self.dropout if self.training else 0.0, is_causal=True)else:_xk=xk.clone()t=_xk.transpose(2, 3)scores = torch.matmul(xq,t)scores = scores/math.sqrt(self.head_dim)a=self.mask[:, :, :seqlen, :seqlen]scores = scores+ascores = F.softmax(scores.float(), dim=-1)scores = scores.type_as(xq)scores = self.attn_dropout(scores)output = torch.matmul(scores, xv)  return outputdef main(flash,bs, n_local_heads, seqlen, head_dim):torch.random.manual_seed(1)q = torch.ones((bs, n_local_heads, seqlen, head_dim),dtype=torch.float32).half().to(device)k = torch.ones((bs, n_local_heads, seqlen, head_dim),dtype=torch.float32).half().to(device)v = torch.ones((bs, n_local_heads, seqlen, head_dim),dtype=torch.float32).half().to(device)q.data.normal_(0, 0.1)k.data.normal_(0, 0.1)v.data.normal_(0, 0.1)q=Variable(q, requires_grad=True).to(device)k=Variable(k, requires_grad=True).to(device)v=Variable(v, requires_grad=True).to(device)gt= torch.randint(0,head_dim,(bs*n_local_heads*seqlen,1)).reshape(-1).to(device)loss_func=nn.CrossEntropyLoss().to(device)model=Attention(seqlen,head_dim,flash).half().to(device)optim = torch.optim.SGD([q,k,v], lr=1.1)with TorchDumper(TorchDumpDispatchMode):for i in range(1):output = model(q,k,v)loss=loss_func(output.reshape(-1,head_dim),gt)loss.backward()  optim.step()print("{:.5f},{:.5f},{:.5f},{:.5f}".format(q.sum().item(),k.sum().item(),v.sum().item(),loss.item()))bs, n_local_heads, seqlen, head_dim = 8, 8, 512, 64
main(False,bs, n_local_heads, seqlen, head_dim)

4.效果

Profiling clone
/home/user/proj/attention/attention_torch_dispatch_dumper.py60
*/home/user/proj/attention/attention_torch_dispatch_dumper.py109
**/home/anaconda3/envs/nvidia_training/lib/python3.10/site-packages/torch/nn/modules/module.py1527
***/home/anaconda3/envs/nvidia_training/lib/python3.10/site-packages/torch/nn/modules/module.py1518
****/home/user/proj/attention/attention_torch_dispatch_dumper.py144
*****/home/user/proj/attention/attention_torch_dispatch_dumper.py151
clone-input 0 torch.Size([8, 8, 512, 64])
clone-output 0 torch.Size([8, 8, 512, 64])
Profiling transpose
/home/user/proj/attention/attention_torch_dispatch_dumper.py60
*/home/user/proj/attention/attention_torch_dispatch_dumper.py110
**/home/anaconda3/envs/nvidia_training/lib/python3.10/site-packages/torch/nn/modules/module.py1527
***/home/anaconda3/envs/nvidia_training/lib/python3.10/site-packages/torch/nn/modules/module.py1518
****/home/user/proj/attention/attention_torch_dispatch_dumper.py144
*****/home/user/proj/attention/attention_torch_dispatch_dumper.py151
transpose-input 0 torch.Size([8, 8, 512, 64])
transpose-output 0 torch.Size([8, 8, 512, 64])
Profiling expand
/home/user/proj/attention/attention_torch_dispatch_dumper.py60
*/home/user/proj/attention/attention_torch_dispatch_dumper.py111
**/home/anaconda3/envs/nvidia_training/lib/python3.10/site-packages/torch/nn/modules/module.py1527
***/home/anaconda3/envs/nvidia_training/lib/python3.10/site-packages/torch/nn/modules/module.py1518
****/home/user/proj/attention/attention_torch_dispatch_dumper.py144
*****/home/user/proj/attention/attention_torch_dispatch_dumper.py151
expand-input 0 torch.Size([8, 8, 512, 64])
expand-output 0 torch.Size([8, 8, 512, 64])
Profiling view
/home/user/proj/attention/attention_torch_dispatch_dumper.py60
*/home/user/proj/attention/attention_torch_dispatch_dumper.py111
**/home/anaconda3/envs/nvidia_training/lib/python3.10/site-packages/torch/nn/modules/module.py1527
***/home/anaconda3/envs/nvidia_training/lib/python3.10/site-packages/torch/nn/modules/module.py1518
****/home/user/proj/attention/attention_torch_dispatch_dumper.py144
*****/home/user/proj/attention/attention_torch_dispatch_dumper.py151
view-input 0 torch.Size([8, 8, 512, 64])
view-output 0 torch.Size([8, 8, 512, 64])
Profiling expand
/home/user/proj/attention/attention_torch_dispatch_dumper.py60
*/home/user/proj/attention/attention_torch_dispatch_dumper.py111
**/home/anaconda3/envs/nvidia_training/lib/python3.10/site-packages/torch/nn/modules/module.py1527
***/home/anaconda3/envs/nvidia_training/lib/python3.10/site-packages/torch/nn/modules/module.py1518
****/home/user/proj/attention/attention_torch_dispatch_dumper.py144
*****/home/user/proj/attention/attention_torch_dispatch_dumper.py151
expand-input 0 torch.Size([8, 8, 64, 512])
expand-output 0 torch.Size([8, 8, 64, 512])
Profiling view
/home/user/proj/attention/attention_torch_dispatch_dumper.py60
*/home/user/proj/attention/attention_torch_dispatch_dumper.py111
**/home/anaconda3/envs/nvidia_training/lib/python3.10/site-packages/torch/nn/modules/module.py1527
***/home/anaconda3/envs/nvidia_training/lib/python3.10/site-packages/torch/nn/modules/module.py1518
****/home/user/proj/attention/attention_torch_dispatch_dumper.py144
*****/home/user/proj/attention/attention_torch_dispatch_dumper.py151
view-input 0 torch.Size([8, 8, 64, 512])
view-output 0 torch.Size([8, 8, 64, 512])
Profiling bmm
/home/user/proj/attention/attention_torch_dispatch_dumper.py60
*/home/user/proj/attention/attention_torch_dispatch_dumper.py111
**/home/anaconda3/envs/nvidia_training/lib/python3.10/site-packages/torch/nn/modules/module.py1527
***/home/anaconda3/envs/nvidia_training/lib/python3.10/site-packages/torch/nn/modules/module.py1518
****/home/user/proj/attention/attention_torch_dispatch_dumper.py144
*****/home/user/proj/attention/attention_torch_dispatch_dumper.py151
bmm-input 0 torch.Size([64, 512, 64])
bmm-input 1 torch.Size([64, 64, 512])
bmm-output 0 torch.Size([64, 512, 64])
bmm-output 1 torch.Size([64, 64, 512])
Profiling _unsafe_view

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/3013.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

机器学习高频问答题总结

机器学习问答题总结 第一章 线性回归1.什么是线性回归?解释主要原理2.解释线性回归中最小二乘法的原理吗?3.如何评估线性回归模型的性能?4.线性回归中正则化的目的是什么吗?L1正则化和L2正则化有什么不同? 第二章 逻辑…

# 从浅入深 学习 SpringCloud 微服务架构(六)Feign(1)

从浅入深 学习 SpringCloud 微服务架构(六)Feign(1) 一、Feign 组件概述: Feign 是 Netflix 开发的声明式,模板化的HTTP客户端。 其灵感来自 Retrofit,JAXRS-2.0 以及 WebSocket。 Feign 可帮助我们更加…

Arduino中增加修改ESP32烧录固件的速度

在Arduino中,默认对ESP32-S3芯片的烧录速度只支持115200、230400、460800、921600这几种速率。只能够在 工具->Upload Speed中选择这些。 有的时候烧录还是觉得太慢了。那么能否更快一些呢? 首先你的USB转串口芯片要支持高速的。常见的芯片速率支持…

Java | 选择排序算法实现

大家可以关注一下专栏,方便大家需要的时候直接查找,专栏将持续更新~ 题目描述 编写一个Java程序,实现选择排序算法。程序需要能够接收一个整型数组作为输入,并输出排序后的数组。 选择排序是一种简单直观的排序算法&#xf…

机械臂模型更换成自己的urdf模块

1.将urdf生成slx文件 smimport(rm_65_flange.urdf);%生成Simscape物理模型 2.更换joint部分(对应与几个输入几个输出)(依次更换) 3.更改关节部分(依次更换) 找到urdf文件夹下的meshes文件夹,看…

基于单片机的羽毛球计分器(含proteus仿真和程序)

目录 完整文本及仿真、程序可私信我获取 前言 第一章 设计任务及方案 1.1 设计任务 1.2 总体设计分析 1.3 功能模块方案设计 1.4 方案确定 第二章、硬件设计 2.1 AT89C51 单片机芯片介绍 2.1.1 主要特性 2.1.2 管脚说明 2.1.3 元件清单 2.2 电路介绍 2…

自动化测试用例设计

知人者智,自知者明。大家好,给大家分享一下关于自动化测试用例的设计心得,首先完整的熟悉业务是第一步要做的,不熟悉业务的前提下不会设计出高效且合理的用例,其次是我们要有明确的测试目标,确保我们写的每…

Redis(单/多)线程

一、 Redis 单线程 与 多线程 怎么说? (1)重要的版本迭代 redis4 之前仅支持 单线程, redis 4之后慢慢 支持多线程, 直到redis6/7后才稳定 (2)redis 的 工作线程 是 单线程的 &#xff08…

Python构建学生信息管理系统:构建RESTful API - 学生信息管理系统的后端逻辑

在之前的博客里,我们已经完成了项目初始化,在本篇博客中,我们将深入探讨如何使用Flask框架实现学生信息管理系统的后端逻辑,特别是通过RESTful API来实现学生信息的增删改查(CRUD)操作。 Flask RESTful AP…

C系统编程:从零手搓一个shell

背景 这么久没更新就是在干这件事!!因为系统编程已经学的差不多了,所以想找几个项目练练手,之前就一直想写一个自己的shell!!现在终于有机会实现了。 首先说明一下我的操作系统:Arch linux 服务…

函数的查询

Oracle从入门到总裁:​​​​​​https://blog.csdn.net/weixin_67859959/article/details/135209645 在实际使用中经常会需要查询数据库中已有的函数或者某一个函数的内容,下面就介绍一下如何查询函数。 和存储过程类似,这也需要使用到数据字典user_s…

Spring - 4 ( 11000 字 Spring 入门级教程 )

一:Spring IoC&DI 在前⾯的章节中, 我们学习了 Spring Boot 和 Spring MVC 的开发, 可以完成⼀些基本功能的开发了, 但是什么是 Spring 呢? Spring, Spring Boot 和 SpringMVC 又有什么关系呢? 咱们还是带着问题去学习.我们先看什么是Spring 1.1 Spring 是什…

更新至2022年上市公司数字化转型数据合集(四份数据合集)

更新至2022年上市公司数字化转型数据合集(四份数据合集) 一、2000-2022年上市公司数字化转型数据(年报词频、文本统计) 二、2007-2022年上市公司数字化转型数据(年报和管理层讨论)(含原始数据…

微前端是如何实现作用域隔离的?

微前端是如何实现作用域隔离的? 一、前言 沙箱(Sandbox)是一种安全机制,目的是让程序运行在一个相对独立的隔离环境,使其不对外界的程序造成影响,保障系统的安全。作为开发人员,我们经常会同沙…

UE5 GAS开发P35,36,37,38,39 将药水修改为AbilitySystem效果

这几节课都是将药水修改成更方便使用的AbilitySystem效果的Actor,分别为增加血量,增加蓝量,暂时获得最大生命值上限 AuraEffectActor.h // Fill out your copyright notice in the Description page of Project Settings. #pragma once #include "CoreMinimal.h" #…

介绍一个开源IOT组态项目

项目介绍 金合可视化平台是一款强大而操作简便的低代码平台,专为满足物联网领域的可视化开发需求而设计。通过该平台,用户可以利用拖拽配置的方式,轻松创建个性化的可视化大屏,无需熟练的编程技能,大幅提高了开发效率。…

图搜索的经典启发式算法A星(A*、A Star)算法详解

文章目录 1. 引言2. 广度优先搜索3. Dijkstra 算法4. 启发式优先搜索(Heuristic)4.1 贪心最佳优先搜索4.2 A*搜索 1. 引言 在许多场景中,我们常会遇到一类问题,即“找到一个位置到另一个位置的距离最短(用时最少&…

使用 Rust 后,我​​使用 Python 的方式发生了变化

使用 Rust 后,我​​使用 Python 的方式发生了变化 Using type hints where possible, and sticking to the classic “make illegal state unrepresentable” principle. 尽可能使用类型提示,并坚持经典的“使非法状态不可表示”原则。 近年来&#xff…

【Pytorch】(十三)PyTorch模型部署: TorchScript

文章目录 (十三)PyTorch模型部署Pytorch动态图的优缺点TorchScriptPytorch模型转换为TorchScripttorch.jit.tracetorch.jit.scripttrace和script的区别总结script 和 trace 混合使用保存和加载模型 (十三)PyTorch模型部署 Pytorc…

科学高效备考AMC8和AMC10竞赛,吃透2000-2024年1850道真题和解析

如何科学、有效地备考AMC8、AMC10美国数学竞赛?多做真题,吃透真题是科学有效的方法之一,通过做真题,可以帮助孩子找到真实竞赛的感觉,而且更加贴近比赛的内容,可以通过真题查漏补缺,更有针对性的…