基于__torch_dispatch__机制的dump方法

基于__torch_dispatch__机制的dump方法

  • 1.参考链接
  • 2.原理
  • 3.代码
  • 4.效果

之前拦截torch和torch.Tensor的办法,在处理backward时,不能看到aten算子的细节.以下基于__torch_dispatch__机制的方案更节约代码,且能看到调用栈

1.参考链接

[原理] (https://dev-discuss.pytorch.org/t/what-and-why-is-torch-dispatch/557)

2.原理

在这里插入图片描述

3.代码

import os
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
import torch
from torch import nn
import math
import torch.nn.functional as F
from torch.autograd import Variable
import time
import os
import threadingdevice="cuda"
from torch.utils._python_dispatch import TorchDispatchMode
import inspect
import traceback
from dataclasses import dataclass
from typing import Any@dataclass
class _ProfilerState:cls: Anyobject: Any = Nonelock=threading.Lock()
gindex=0
def save_tensor(name,args,index=0):if isinstance(args,torch.Tensor):print(name,index,args.shape)global gindexlock.acquire()torch.save(args,"{}_{}_{}_{}.pt".format(device,gindex,name,index))gindex+=1lock.release()if isinstance(args,tuple):for idx,x in enumerate(args):save_tensor(name,x,index+idx)class TorchDumpDispatchMode(TorchDispatchMode):def __init__(self,parent):super().__init__()self.parent=parentdef __torch_dispatch__(self, func, types, args=(), kwargs=None):func_packet = func._overloadpacket        if kwargs is None:kwargs = {}        enable_dump=Falseif func_packet.__name__ not in ["detach"]:enable_dump=Trueprint(f"Profiling {func_packet.__name__}") for idx,stack in enumerate(inspect.stack()):print(f'{"*"*idx}{stack.filename}{stack.lineno}')if enable_dump:     save_tensor(f"{func_packet.__name__}-input",args)ret= func(*args, **kwargs)if enable_dump:save_tensor(f"{func_packet.__name__}-output",args)return retclass TorchDumper:_CURRENT_Dumper = Nonedef __init__(self,schedule: Any):self.p= _ProfilerState(schedule) def __enter__(self):assert TorchDumper._CURRENT_Dumper is NoneTorchDumper._CURRENT_Dumper = selfif self.p.object is None:o = self.p.cls(self)o.__enter__()self.p.object = oelse:self.p.object.step()return selfdef __exit__(self, exc_type, exc_val, exc_tb):TorchDumper._CURRENT_Dumper = Noneif self.p.object is not None:self.p.object.__exit__(exc_type, exc_val, exc_tb)class Attention(nn.Module):def __init__(self,max_seq_len,head_dim,flash):super().__init__()self.flash = flashself.dropout=0self.attn_dropout = nn.Dropout(self.dropout)self.head_dim=head_dimif not self.flash:print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")mask = torch.full((1, 1, max_seq_len, max_seq_len), float("-inf")).to(device)mask = torch.triu(mask, diagonal=1).half().to(device)self.register_buffer("mask", mask)		def forward(self,xq: torch.Tensor,xk: torch.Tensor,xv: torch.Tensor):if self.flash:output = torch.nn.functional.scaled_dot_product_attention(xq, xk, xv,attn_mask=None, dropout_p=self.dropout if self.training else 0.0, is_causal=True)else:_xk=xk.clone()t=_xk.transpose(2, 3)scores = torch.matmul(xq,t)scores = scores/math.sqrt(self.head_dim)a=self.mask[:, :, :seqlen, :seqlen]scores = scores+ascores = F.softmax(scores.float(), dim=-1)scores = scores.type_as(xq)scores = self.attn_dropout(scores)output = torch.matmul(scores, xv)  return outputdef main(flash,bs, n_local_heads, seqlen, head_dim):torch.random.manual_seed(1)q = torch.ones((bs, n_local_heads, seqlen, head_dim),dtype=torch.float32).half().to(device)k = torch.ones((bs, n_local_heads, seqlen, head_dim),dtype=torch.float32).half().to(device)v = torch.ones((bs, n_local_heads, seqlen, head_dim),dtype=torch.float32).half().to(device)q.data.normal_(0, 0.1)k.data.normal_(0, 0.1)v.data.normal_(0, 0.1)q=Variable(q, requires_grad=True).to(device)k=Variable(k, requires_grad=True).to(device)v=Variable(v, requires_grad=True).to(device)gt= torch.randint(0,head_dim,(bs*n_local_heads*seqlen,1)).reshape(-1).to(device)loss_func=nn.CrossEntropyLoss().to(device)model=Attention(seqlen,head_dim,flash).half().to(device)optim = torch.optim.SGD([q,k,v], lr=1.1)with TorchDumper(TorchDumpDispatchMode):for i in range(1):output = model(q,k,v)loss=loss_func(output.reshape(-1,head_dim),gt)loss.backward()  optim.step()print("{:.5f},{:.5f},{:.5f},{:.5f}".format(q.sum().item(),k.sum().item(),v.sum().item(),loss.item()))bs, n_local_heads, seqlen, head_dim = 8, 8, 512, 64
main(False,bs, n_local_heads, seqlen, head_dim)

4.效果

Profiling clone
/home/user/proj/attention/attention_torch_dispatch_dumper.py60
*/home/user/proj/attention/attention_torch_dispatch_dumper.py109
**/home/anaconda3/envs/nvidia_training/lib/python3.10/site-packages/torch/nn/modules/module.py1527
***/home/anaconda3/envs/nvidia_training/lib/python3.10/site-packages/torch/nn/modules/module.py1518
****/home/user/proj/attention/attention_torch_dispatch_dumper.py144
*****/home/user/proj/attention/attention_torch_dispatch_dumper.py151
clone-input 0 torch.Size([8, 8, 512, 64])
clone-output 0 torch.Size([8, 8, 512, 64])
Profiling transpose
/home/user/proj/attention/attention_torch_dispatch_dumper.py60
*/home/user/proj/attention/attention_torch_dispatch_dumper.py110
**/home/anaconda3/envs/nvidia_training/lib/python3.10/site-packages/torch/nn/modules/module.py1527
***/home/anaconda3/envs/nvidia_training/lib/python3.10/site-packages/torch/nn/modules/module.py1518
****/home/user/proj/attention/attention_torch_dispatch_dumper.py144
*****/home/user/proj/attention/attention_torch_dispatch_dumper.py151
transpose-input 0 torch.Size([8, 8, 512, 64])
transpose-output 0 torch.Size([8, 8, 512, 64])
Profiling expand
/home/user/proj/attention/attention_torch_dispatch_dumper.py60
*/home/user/proj/attention/attention_torch_dispatch_dumper.py111
**/home/anaconda3/envs/nvidia_training/lib/python3.10/site-packages/torch/nn/modules/module.py1527
***/home/anaconda3/envs/nvidia_training/lib/python3.10/site-packages/torch/nn/modules/module.py1518
****/home/user/proj/attention/attention_torch_dispatch_dumper.py144
*****/home/user/proj/attention/attention_torch_dispatch_dumper.py151
expand-input 0 torch.Size([8, 8, 512, 64])
expand-output 0 torch.Size([8, 8, 512, 64])
Profiling view
/home/user/proj/attention/attention_torch_dispatch_dumper.py60
*/home/user/proj/attention/attention_torch_dispatch_dumper.py111
**/home/anaconda3/envs/nvidia_training/lib/python3.10/site-packages/torch/nn/modules/module.py1527
***/home/anaconda3/envs/nvidia_training/lib/python3.10/site-packages/torch/nn/modules/module.py1518
****/home/user/proj/attention/attention_torch_dispatch_dumper.py144
*****/home/user/proj/attention/attention_torch_dispatch_dumper.py151
view-input 0 torch.Size([8, 8, 512, 64])
view-output 0 torch.Size([8, 8, 512, 64])
Profiling expand
/home/user/proj/attention/attention_torch_dispatch_dumper.py60
*/home/user/proj/attention/attention_torch_dispatch_dumper.py111
**/home/anaconda3/envs/nvidia_training/lib/python3.10/site-packages/torch/nn/modules/module.py1527
***/home/anaconda3/envs/nvidia_training/lib/python3.10/site-packages/torch/nn/modules/module.py1518
****/home/user/proj/attention/attention_torch_dispatch_dumper.py144
*****/home/user/proj/attention/attention_torch_dispatch_dumper.py151
expand-input 0 torch.Size([8, 8, 64, 512])
expand-output 0 torch.Size([8, 8, 64, 512])
Profiling view
/home/user/proj/attention/attention_torch_dispatch_dumper.py60
*/home/user/proj/attention/attention_torch_dispatch_dumper.py111
**/home/anaconda3/envs/nvidia_training/lib/python3.10/site-packages/torch/nn/modules/module.py1527
***/home/anaconda3/envs/nvidia_training/lib/python3.10/site-packages/torch/nn/modules/module.py1518
****/home/user/proj/attention/attention_torch_dispatch_dumper.py144
*****/home/user/proj/attention/attention_torch_dispatch_dumper.py151
view-input 0 torch.Size([8, 8, 64, 512])
view-output 0 torch.Size([8, 8, 64, 512])
Profiling bmm
/home/user/proj/attention/attention_torch_dispatch_dumper.py60
*/home/user/proj/attention/attention_torch_dispatch_dumper.py111
**/home/anaconda3/envs/nvidia_training/lib/python3.10/site-packages/torch/nn/modules/module.py1527
***/home/anaconda3/envs/nvidia_training/lib/python3.10/site-packages/torch/nn/modules/module.py1518
****/home/user/proj/attention/attention_torch_dispatch_dumper.py144
*****/home/user/proj/attention/attention_torch_dispatch_dumper.py151
bmm-input 0 torch.Size([64, 512, 64])
bmm-input 1 torch.Size([64, 64, 512])
bmm-output 0 torch.Size([64, 512, 64])
bmm-output 1 torch.Size([64, 64, 512])
Profiling _unsafe_view

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/3013.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

机器学习高频问答题总结

机器学习问答题总结 第一章 线性回归1.什么是线性回归?解释主要原理2.解释线性回归中最小二乘法的原理吗?3.如何评估线性回归模型的性能?4.线性回归中正则化的目的是什么吗?L1正则化和L2正则化有什么不同? 第二章 逻辑…

# 从浅入深 学习 SpringCloud 微服务架构(六)Feign(1)

从浅入深 学习 SpringCloud 微服务架构(六)Feign(1) 一、Feign 组件概述: Feign 是 Netflix 开发的声明式,模板化的HTTP客户端。 其灵感来自 Retrofit,JAXRS-2.0 以及 WebSocket。 Feign 可帮助我们更加…

代码随想录算法训练营第五十天| 123.买卖股票的最佳时机III ,188.买卖股票的最佳时机IV

题目与题解 123.买卖股票的最佳时机III 题目链接:123.买卖股票的最佳时机III 代码随想录题解:​​​​​​​123.买卖股票的最佳时机III 视频讲解:动态规划,股票至多买卖两次,怎么求? | LeetCode&#xff…

Vector里常用的操作(C++)

1 引言 编程时常用的Vector操作有创建、访问元素、增加元素、删除元素、修改元素、查找索引以及一些常用的函数操作&#xff0c;本文总结了一下这些方法在C里面的实现方式&#xff08;并不是唯一的&#xff09;。 2 创建 操作类别注释C创建创建空向量vector<int>a;创建…

Arduino中增加修改ESP32烧录固件的速度

在Arduino中&#xff0c;默认对ESP32-S3芯片的烧录速度只支持115200、230400、460800、921600这几种速率。只能够在 工具->Upload Speed中选择这些。 有的时候烧录还是觉得太慢了。那么能否更快一些呢&#xff1f; 首先你的USB转串口芯片要支持高速的。常见的芯片速率支持…

Java | 选择排序算法实现

大家可以关注一下专栏&#xff0c;方便大家需要的时候直接查找&#xff0c;专栏将持续更新~ 题目描述 编写一个Java程序&#xff0c;实现选择排序算法。程序需要能够接收一个整型数组作为输入&#xff0c;并输出排序后的数组。 选择排序是一种简单直观的排序算法&#xf…

机械臂模型更换成自己的urdf模块

1.将urdf生成slx文件 smimport(rm_65_flange.urdf);%生成Simscape物理模型 2.更换joint部分&#xff08;对应与几个输入几个输出&#xff09;&#xff08;依次更换&#xff09; 3.更改关节部分&#xff08;依次更换&#xff09; 找到urdf文件夹下的meshes文件夹&#xff0c;看…

electron实现静默打印(各种踩坑解决)

前车之鉴 也是阅读了很多资料和前人踩的坑&#xff0c;直接使用webContent.print方法进行打印。其他方式要不就是Bug多&#xff0c;官方修复也有问题&#xff1b;要不就是官方升级版本后不再支持等 不赘述 需求思路 在main里面实现printerHandle&#xff0c;暴露给渲染线程去…

基于单片机的羽毛球计分器(含proteus仿真和程序)

目录 完整文本及仿真、程序可私信我获取 前言 第一章 设计任务及方案 1.1 设计任务 1.2 总体设计分析 1.3 功能模块方案设计 1.4 方案确定 第二章、硬件设计 2.1 AT89C51 单片机芯片介绍 2.1.1 主要特性 2.1.2 管脚说明 2.1.3 元件清单 2.2 电路介绍 2…

自动化测试用例设计

知人者智&#xff0c;自知者明。大家好&#xff0c;给大家分享一下关于自动化测试用例的设计心得&#xff0c;首先完整的熟悉业务是第一步要做的&#xff0c;不熟悉业务的前提下不会设计出高效且合理的用例&#xff0c;其次是我们要有明确的测试目标&#xff0c;确保我们写的每…

Redis(单/多)线程

一、 Redis 单线程 与 多线程 怎么说&#xff1f; &#xff08;1&#xff09;重要的版本迭代 redis4 之前仅支持 单线程&#xff0c; redis 4之后慢慢 支持多线程&#xff0c; 直到redis6/7后才稳定 &#xff08;2&#xff09;redis 的 工作线程 是 单线程的 &#xff08…

阿里云难题学习笔记

1、下列内存区段增长方是向低地址方向的有&#xff08; &#xff09;&#xff1f; A: 文本段 B: 数据段 C: 堆区 D: 栈区 解析&#xff1a; 在内存管理中&#xff0c;不同的内存区段增长方向是不同的。栈区&#xff08;Stack&#xff09;的增长方向是向低地址方向的&…

Nacos和Eureka有什么区别!!!

一致性模型&#xff1a; Eureka&#xff1a;采用的是 AP&#xff08;Availability, Partition Tolerance&#xff09;模型&#xff0c;即在面临网络分区或部分节点故障时优先保证系统的可用性&#xff0c;牺牲一定的数据一致性。Eureka 通过自我保护机制&#xff0c;允许在节点…

Python构建学生信息管理系统:构建RESTful API - 学生信息管理系统的后端逻辑

在之前的博客里&#xff0c;我们已经完成了项目初始化&#xff0c;在本篇博客中&#xff0c;我们将深入探讨如何使用Flask框架实现学生信息管理系统的后端逻辑&#xff0c;特别是通过RESTful API来实现学生信息的增删改查&#xff08;CRUD&#xff09;操作。 Flask RESTful AP…

go的内存分配机制

Go 语言的内存分配机制可以分为几个主要类别&#xff0c;每个类别都有其特定的行为和优化&#xff1a; 1. 栈&#xff08;Stack&#xff09;分配 局部变量&#xff1a;在函数内部定义的变量通常分配在栈上。大小限制&#xff1a;栈的大小有限&#xff0c;适用于生命周期短、大…

C系统编程:从零手搓一个shell

背景 这么久没更新就是在干这件事&#xff01;&#xff01;因为系统编程已经学的差不多了&#xff0c;所以想找几个项目练练手&#xff0c;之前就一直想写一个自己的shell&#xff01;&#xff01;现在终于有机会实现了。 首先说明一下我的操作系统&#xff1a;Arch linux 服务…

pandas 读取JSON字符串解析长整形丢失数据精度,读取值与实际值不一致

目录 背景&#xff1a; JSON字符串 解析代码 解决方案 背景&#xff1a; 在使用pandas read_json方法读取JSON存为Excel文件时&#xff0c;发现Excel中order_no的值与JSON字符串中的值不一致&#xff0c;开始怀疑是Excel保存精度问题&#xff0c;但是Excel输出实际为字符串…

【OceanBase系列】—— 常用 SQL

作者简介&#xff1a; 花名&#xff1a;绪宁&#xff0c;OceanBase 数据库解决方案架构师 对使用OB过程中常用的一些SQL进行了整理&#xff0c;对应的版本是 4.x。 集群信息 查看版本 show variables like version_comment; 查看集群ID和集群名 show parameters like %clust…

函数的查询

Oracle从入门到总裁:​​​​​​https://blog.csdn.net/weixin_67859959/article/details/135209645 在实际使用中经常会需要查询数据库中已有的函数或者某一个函数的内容&#xff0c;下面就介绍一下如何查询函数。 和存储过程类似&#xff0c;这也需要使用到数据字典user_s…

Spring - 4 ( 11000 字 Spring 入门级教程 )

一&#xff1a;Spring IoC&DI 在前⾯的章节中, 我们学习了 Spring Boot 和 Spring MVC 的开发, 可以完成⼀些基本功能的开发了, 但是什么是 Spring 呢? Spring, Spring Boot 和 SpringMVC 又有什么关系呢? 咱们还是带着问题去学习.我们先看什么是Spring 1.1 Spring 是什…