拦截pytorch算子,dump输入输出

拦截pytorch算子,dump输入输出

  • 一.代码
  • 二.输出

希望dump出pytorch每个算子的输入输出,但pytorch普通的hook机制只能拦截module.以下提供一种方法可以拦截torch.add,torch.Tensor.add这类算子.原理是通过模板替换,劫持torch和torch.Tensor中的算子.遍历next_functions调用register_hook拦截backward.

一.代码

import os
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
import torch
from torch import nn
import math
import torch.nn.functional as F
from torch.autograd import Variable
import time
import os
import threading
import base64
from jinja2 import Templatedevice="cuda"class Attention(nn.Module):def __init__(self,max_seq_len,head_dim,flash):super().__init__()self.flash = flash #hasattr(torch.nn.functional, 'scaled_dot_product_attention')self.dropout=0self.attn_dropout = nn.Dropout(self.dropout)self.head_dim=head_dimif not self.flash:print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")mask = torch.full((1, 1, max_seq_len, max_seq_len), float("-inf")).to(device)mask = torch.triu(mask, diagonal=1).half().to(device)self.register_buffer("mask", mask)		def forward(self,xq: torch.Tensor,xk: torch.Tensor,xv: torch.Tensor):if self.flash:output = torch.nn.functional.scaled_dot_product_attention(xq, xk, xv,attn_mask=None, dropout_p=self.dropout if self.training else 0.0, is_causal=True)else:_xk=xk.clone()t=_xk.transpose(2, 3)scores = torch.matmul(xq,t)scores = scores/math.sqrt(self.head_dim)a=self.mask[:, :, :seqlen, :seqlen]scores = torch.add(scores,a)scores = F.softmax(scores.float(), dim=-1)scores = scores.type_as(xq)scores = self.attn_dropout(scores)output = torch.matmul(scores, xv)  return outputlock=threading.Lock()
gindex=0
def save_tensor(name,args,index=0):if isinstance(args,torch.Tensor):print(name,index,args.shape)global gindexlock.acquire()torch.save(args,"{}_{}_{}_{}.pt".format(device,gindex,name,index))gindex+=1lock.release()if isinstance(args,tuple):for idx,x in enumerate(args):save_tensor(name,x,index+idx)op_template=Template('''      
native1_{{new_name}}=getattr(torch.Tensor,'{{name}}')
def {{new_name}}(*args, **kwargs):save_tensor("{{name}}-input",args)    global native1_{{new_name}}             ret=native1_{{new_name}}(*args, **kwargs)save_tensor("{{name}}-output",ret)   return ret
setattr(torch.Tensor, '{{name}}', {{new_name}})
''')for op in dir(torch.Tensor):if op in ["__iter__","shape","dim","unbind","normal_","data","item","numel","save","has_names","data_ptr","untyped_storage","storage_offset","size","stride","triu","half","is_floating_point","to","ones","randint","ones_like"]:continueif getattr(torch.Tensor,op).__class__.__name__ not in ["method_descriptor"]:continuenew_name=base64.b64encode(str(f"torch.Tensor.{op}").encode('utf-8')).decode("utf-8").replace("=","")exec(op_template.render(name=op,new_name=new_name))op_template=Template('''      
native2_{{new_name}}=getattr(torch,'{{name}}')
def {{new_name}}(*args, **kwargs):save_tensor("{{name}}-input",args)    global native2_{{new_name}}             ret=native2_{{new_name}}(*args, **kwargs)save_tensor("{{name}}-output",ret) return ret
setattr(torch, '{{name}}', {{new_name}})
''')for op in dir(torch):if op in ["is_grad_enabled","__iter__","save","has_names","data_ptr","untyped_storage","storage_offset","size","stride","triu","is_floating_point","to","ones","randint","full","reshape","ones_like"]:continueif getattr(torch,op).__class__.__name__ not in ["builtin_function_or_method"]:continuenew_name=base64.b64encode(str(f"torch.{op}").encode('utf-8')).decode("utf-8").replace("=","")exec(op_template.render(name=op,new_name=new_name))def hook_backwards(loss, cached):if loss is None:return    def posthook(*args,**kwargs):save_tensor(loss.__class__.__name__,args)def prehook(*args,**kwargs):passloss.register_prehook(prehook)loss.register_hook(posthook)cached.add(loss)for _, child in enumerate(loss.next_functions):if child[0] not in cached:hook_backwards(child[0],cached)def main(flash,bs, n_local_heads, seqlen, head_dim):torch.random.manual_seed(1)q = torch.ones((bs, n_local_heads, seqlen, head_dim),dtype=torch.float32).half().to(device)k = torch.ones((bs, n_local_heads, seqlen, head_dim),dtype=torch.float32).half().to(device)v = torch.ones((bs, n_local_heads, seqlen, head_dim),dtype=torch.float32).half().to(device)q.data.normal_(0, 0.1)k.data.normal_(0, 0.1)v.data.normal_(0, 0.1)q=Variable(q, requires_grad=True).to(device)k=Variable(k, requires_grad=True).to(device)v=Variable(v, requires_grad=True).to(device)gt= torch.randint(0,head_dim,(bs*n_local_heads*seqlen,1)).reshape(-1).to(device)loss_func=nn.CrossEntropyLoss().to(device)model=Attention(seqlen,head_dim,flash).half().to(device)optim = torch.optim.SGD([q,k,v], lr=1.1)for i in range(1):output = model(q,k,v)loss=loss_func(output.reshape(-1,head_dim),gt)hook_backwards(loss.grad_fn, cached=set())loss.backward()  optim.step()print("{:.5f},{:.5f},{:.5f},{:.5f}".format(q.sum().item(),k.sum().item(),v.sum().item(),loss.item()))bs, n_local_heads, seqlen, head_dim = 8, 8, 512, 64
main(False,bs, n_local_heads, seqlen, head_dim)

二.输出

reshape-input 0 torch.Size([32768, 1])
reshape-output 0 torch.Size([32768])
clone-input 0 torch.Size([8, 8, 512, 64])
clone-output 0 torch.Size([8, 8, 512, 64])
transpose-input 0 torch.Size([8, 8, 512, 64])
transpose-output 0 torch.Size([8, 8, 64, 512])
matmul-input 0 torch.Size([8, 8, 512, 64])
matmul-input 1 torch.Size([8, 8, 64, 512])
matmul-output 0 torch.Size([8, 8, 512, 512])
__truediv__-input 0 torch.Size([8, 8, 512, 512])
__truediv__-output 0 torch.Size([8, 8, 512, 512])
add-input 0 torch.Size([8, 8, 512, 512])
add-input 1 torch.Size([1, 1, 512, 512])
add-output 0 torch.Size([8, 8, 512, 512])
float-input 0 torch.Size([8, 8, 512, 512])
float-output 0 torch.Size([8, 8, 512, 512])
softmax-input 0 torch.Size([8, 8, 512, 512])
softmax-output 0 torch.Size([8, 8, 512, 512])
type_as-input 0 torch.Size([8, 8, 512, 512])
type_as-input 1 torch.Size([8, 8, 512, 64])
type_as-output 0 torch.Size([8, 8, 512, 512])
matmul-input 0 torch.Size([8, 8, 512, 512])
matmul-input 1 torch.Size([8, 8, 512, 64])
matmul-output 0 torch.Size([8, 8, 512, 64])
reshape-input 0 torch.Size([8, 8, 512, 64])
reshape-output 0 torch.Size([32768, 64])
NllLossBackward0 0 torch.Size([32768, 64])
NllLossBackward0 1 torch.Size([])
LogSoftmaxBackward0 0 torch.Size([32768, 64])
LogSoftmaxBackward0 1 torch.Size([32768, 64])
ViewBackward0 0 torch.Size([8, 8, 512, 64])
ViewBackward0 1 torch.Size([32768, 64])
UnsafeViewBackward0 0 torch.Size([64, 512, 64])
UnsafeViewBackward0 1 torch.Size([8, 8, 512, 64])
BmmBackward0 0 torch.Size([64, 512, 512])
BmmBackward0 1 torch.Size([64, 512, 64])
BmmBackward0 1 torch.Size([64, 512, 64])
ViewBackward0 0 torch.Size([8, 8, 512, 64])
ViewBackward0 1 torch.Size([64, 512, 64])
ExpandBackward0 0 torch.Size([8, 8, 512, 64])
ExpandBackward0 1 torch.Size([8, 8, 512, 64])
AccumulateGrad 1 torch.Size([8, 8, 512, 64])
ViewBackward0 0 torch.Size([8, 8, 512, 512])
ViewBackward0 1 torch.Size([64, 512, 512])
ExpandBackward0 0 torch.Size([8, 8, 512, 512])
ExpandBackward0 1 torch.Size([8, 8, 512, 512])
ToCopyBackward0 0 torch.Size([8, 8, 512, 512])
ToCopyBackward0 1 torch.Size([8, 8, 512, 512])
SoftmaxBackward0 0 torch.Size([8, 8, 512, 512])
SoftmaxBackward0 1 torch.Size([8, 8, 512, 512])
ToCopyBackward0 0 torch.Size([8, 8, 512, 512])
ToCopyBackward0 1 torch.Size([8, 8, 512, 512])
AddBackward0 0 torch.Size([8, 8, 512, 512])
AddBackward0 1 torch.Size([8, 8, 512, 512])
DivBackward0 0 torch.Size([8, 8, 512, 512])
DivBackward0 1 torch.Size([8, 8, 512, 512])
UnsafeViewBackward0 0 torch.Size([64, 512, 512])
UnsafeViewBackward0 1 torch.Size([8, 8, 512, 512])
BmmBackward0 0 torch.Size([64, 512, 64])
BmmBackward0 1 torch.Size([64, 64, 512])
BmmBackward0 1 torch.Size([64, 512, 512])
ReshapeAliasBackward0 0 torch.Size([8, 8, 64, 512])
ReshapeAliasBackward0 1 torch.Size([64, 64, 512])
ExpandBackward0 0 torch.Size([8, 8, 64, 512])
ExpandBackward0 1 torch.Size([8, 8, 64, 512])
ViewBackward0 0 torch.Size([8, 8, 512, 64])
ViewBackward0 1 torch.Size([64, 512, 64])
ExpandBackward0 0 torch.Size([8, 8, 512, 64])
ExpandBackward0 1 torch.Size([8, 8, 512, 64])
AccumulateGrad 1 torch.Size([8, 8, 512, 64])
TransposeBackward0 0 torch.Size([8, 8, 512, 64])
TransposeBackward0 1 torch.Size([8, 8, 64, 512])
CloneBackward0 0 torch.Size([8, 8, 512, 64])
CloneBackward0 1 torch.Size([8, 8, 512, 64])
AccumulateGrad 1 torch.Size([8, 8, 512, 64])
sum-input 0 torch.Size([8, 8, 512, 64])
sum-output 0 torch.Size([])
sum-input 0 torch.Size([8, 8, 512, 64])
sum-output 0 torch.Size([])
sum-input 0 torch.Size([8, 8, 512, 64])
sum-output 0 torch.Size([])
45.56250,-12.76562,121.68750,4.16016

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/2307.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

文件File类的学习

File类 File类创建File实例创建文件删除文件创建目录 Reader小结 File类 在java中,通过java.io.File类来对一个文件进行抽象的描述. 下面我们来看看File类的构造方法:签名说明File(File parent, String child)根据父目录孩子文件路径,创建出一个新的File实例File(String pathn…

springboot整合mybatis-plus模版

1.创建springboot项目 Maven类型Lombok依赖Spring Web 依赖MySQL Driver依赖pom.xml&#xff1a;<?xml version"1.0" encoding"UTF-8"?> <project xmlns"http://maven.apache.org/POM/4.0.0" xmlns:xsi"http://www.w3.org/2001/…

Springboot+Vue项目-基于Java+MySQL的非物质文化网站设计与实现(附源码+演示视频+LW)

大家好&#xff01;我是程序猿老A&#xff0c;感谢您阅读本文&#xff0c;欢迎一键三连哦。 &#x1f49e;当前专栏&#xff1a;Java毕业设计 精彩专栏推荐&#x1f447;&#x1f3fb;&#x1f447;&#x1f3fb;&#x1f447;&#x1f3fb; &#x1f380; Python毕业设计 &…

力扣-2259移除指定数字得到的最大结果

思路&#xff1a; 1. def removeDigit(self, number: str, digit: str) -> str:&#xff1a;这是一个类方法&#xff0c;接受两个参数 number 和 digit&#xff0c;分别表示输入的数字字符串和要移除的数字字符&#xff0c;返回一个字符串。 2. n len(number)&#xff1a…

ui_admin_vue3启动

1、要求node版本16.20.2&#xff0c;小于这个版本npm run dev会报错UnhandledPromiseRejectionWarning: SyntaxError: Unexpected token ‘??‘ 逻辑空赋值(??)是ES2021的语法&#xff0c;node v15.0.0以上才支持逻辑空赋值(??)的语法。之前为了兼容旧代码使用的node版本…

激活虚拟环境.ps1“因为在此系统上禁止运行脚本”解决办法

激活虚拟环境.ps1“因为在此系统上禁止运行脚本”解决办法 1.问题收录 Django激活虚拟环境时遇到的&#xff0c;已解决&#xff0c;作以收录&#xff0c;希望能帮到大家 2.分析问题 核心是Powershell的安全策略&#xff0c;将XX命令视为不安全脚本&#xff0c;不允许执行&…

【火猫TV】意甲:CDK展现自身天赋,真蓝黑军团绝对不会客气

俗话说&#xff1a;树挪死&#xff0c;人挪活。在足坛有很多球员更换球队之后获得了新生&#xff0c;在新球队发挥出了自己的实力&#xff0c;比如从AC米兰租借到亚特兰大的小将德凯特拉雷&#xff08;简称CDK&#xff09;就让红黑军团看走眼。本赛季他在亚特兰大发挥出色&…

制作识货的商品购买页面(注释加讲解)

在制作此页面时运用了浮动&#xff0c;绝对定位&#xff0c;固定定位&#xff0c;相对定位。这些可以让页面整洁美观。 商品购买页面里有很多的商品可大家观看最上面的搜索栏里可以打字下面的&#xff0c;首页&#xff0c;优惠&#xff0c;识物&#xff0c;登录注册都可以进行…

HBase的简单学习三

一 过滤器 1.1相关概念 1.过滤器可以根据列族、列、版本等更多的条件来对数据进行过滤&#xff0c; 基于 HBase 本身提供的三维有序&#xff08;行键&#xff0c;列&#xff0c;版本有序&#xff09;&#xff0c;这些过滤器可以高效地完成查询过滤的任务&#xff0c;带有过滤…

Netty 入门

文章目录 1. 概述1.1 Netty 是什么&#xff1f;1.2 Netty 的作者1.3 Netty 的地位1.4 Netty 的优势 2. Hello World2.1 目标2.2 服务器端2.3 客户端2.4 流程梳理 3. 组件3.1 EventLoop&#x1f4a1; 优雅关闭演示 NioEventLoop 处理 io 事件&#x1f4a1; handler 执行中如何换…

【Python】基础知识(函数与数据容器)

笔者在C语言基础上学习python自用笔记 type() 返回数据类型 name "root" hei 1.8 wei 77 type_hei type(hei) type_wei type(wei) print(type(name)) print(type_hei) print(type_wei)在python中变量是没有类型的&#xff0c;它存储的数据是有类型的。 数据类…

查找两个字符串的最长公共子串

暴力解法 #include <iostream> #include <vector> #include <cstring> using namespace std; string a, b, minn ""; // a和b是我们输入的 // minn存储的是我们最小的那个字符串string cut(int l, int r) {string tmp "";for (int i …

类与对象(二) 构造函数与析构函数

目录 1.类的6个默认成员函数 2.构造函数 2.析构函数 1.类的6个默认成员函数 我们前面讲到了一个空类&#xff0c;也就是类里面没有声明成员&#xff0c;但是空类里面真的什么都没有吗&#xff1f;不然&#xff0c;任何类在什么都不写时&#xff0c;编译器自动生成以下六个默…

Introducing Meta Llama 3: The most capable openly available LLM to date

要点 今天&#xff0c;我们推出 Meta Llama 3&#xff0c;这是我们最先进的开源大型语言模型的下一代。Llama 3型号将很快在AWS&#xff0c;Databricks&#xff0c;Google Cloud&#xff0c;Hugging Face&#xff0c;Kaggle&#xff0c;IBM WatsonX&#xff0c;Microsoft Azur…

VsCode一直连接不上 timed out

前言 前段时间用VsCode连接远程服务器&#xff0c;正常操作后总是连接不上&#xff0c;折磨了半个多小时&#xff0c;后面才知道原来是服务器设置的问题&#xff0c;故记录一下&#xff0c;防止后面的小伙伴也踩坑。 我使用的是阿里云服务器&#xff0c;如果是使用其他平台服务…

JMeter组件--配置元件--响应断言

响应断言&#xff08;Response Assertion&#xff09; 当响应中有明显的业务标志时&#xff0c;我们可以采用该断言器检测响应报文返回的特征值&#xff0c;进而判断在业务上是否确定&#xff1b;使用频率非常高&#xff0c;大部分场景均可以使用该断言器。 右键 >>>…

CCS项目持续集成

​ 因工作需要&#xff0c;用户提出希望可以做ccs项目的持续集成&#xff0c;及代码提交后能够自动编译并提交到svn。调研过jenkins之后发现重新手写更有性价比&#xff0c;所以肝了几晚终于搞出来了&#xff0c;现在分享出来。 ​ 先交代背景&#xff1a; 1. 代码分两部分&am…

数据结构8:队列

文章目录 Queue.h 实现文件Queue.c 测试文件test.c #头文件 Queue.h #pragma once#include<stdio.h> #include<stdlib.h> #include<assert.h> #include<stdbool.h>typedef int QListDataType;typedef struct QListNode {QListDataType val;struct QLi…

IPRally巧用Google Kubernetes Engine和Ray改善AI

专利检索平台提供商 IPRally 正在快速发展&#xff0c;为全球企业、知识产权律师事务所以及多个国家专利和商标局提供服务。随着公司的发展&#xff0c;其技术需求也在不断增长。它继续训练模型以提高准确性&#xff0c;每周添加 200,000 条可供客户访问的可搜索记录&#xff0…

Python语言零基础入门——案例实战

目录 一、用户登录系统 二、计算天数 一、用户登录系统 1.功能需求&#xff1a;用户输入用户名、密码后&#xff0c;根据用户是否已经注册&#xff0c;用户是否在黑名单中&#xff0c;提示用户是否登录成功。 2.登录功能 输入用户名输入密码登录验证&#xff1a;①用户是否…