跟踪一个Pytorch Module在训练过程中的内存分配情况

跟踪一个Pytorch Module在训练过程中的内存分配情况

  • 代码
  • 输出

目的:跟踪一个Pytorch Module在训练过程中的内存分配情况
方法:
1.通过pre_hook module的来区分module的边界
2.通过__torch_dispatch__拦截所有的aten算子,计算在该算子中新创建tensor的总内存占用量
3.通过tensor.data_ptr()为tensor去重,表示一块独立的内存

代码


import numpy as np
import torch
from torch.nn import Module, Linear
import torch.nn as nn
from torch.optim import Adam,SGD
from torch.utils._python_dispatch import TorchDispatchMode
from dataclasses import dataclass
from typing import Any
import time@dataclass
class _ProfilerState:cls: Anyobject: Any = Nonecurrent_module=None
tesor_cache=set()def get_current_mem():global current_moduleprint(f'[INFO]{current_module["name"]}:{np.sum(current_module["size"])}')current_module=Noneclass InputDescriptor:def __init__(self) -> None:self.total_input_size=0def _save_var(self,v):class_name=v.__class__.__name__if class_name in ["Tensor","Parameter"]:global tesor_cachetensorid=v.data_ptr()if v.device.type!="cuda":return            if tensorid not in tesor_cache:tesor_cache.add(tensorid)sz=v.numel()*v.element_size()print(v.shape,v.dtype)self.total_input_size += szif class_name=="Parameter" and v.grad is not None:                tensorid=v.grad.data_ptr()if tensorid not in tesor_cache:tesor_cache.add(tensorid)sz=v.grad.numel()*v.grad.element_size()print("grad",v.grad.shape,v.grad.dtype)self.total_input_size += szelif class_name in ["list","tuple"]:for t in v:self._save_var(t)else:passdef save_vars(self,ret,*args,**kwargs):for arg in args:self._save_var(arg)        for k,v in kwargs.items():self._save_var(v)self._save_var(ret)global current_module        if current_module is None:current_module={"name":"Other","size":[]}current_module["size"].append(self.total_input_size)# 对象和类名缓存
object_cache = {}
class_name_count = {}def get_unique_name(class_name, obj_id):# 生成唯一的对象名称if class_name not in class_name_count:class_name_count[class_name] = 0uid = f"{class_name}_{obj_id}"if uid not in object_cache:class_name_count[class_name] += 1object_cache[uid] = {"idx": class_name_count[class_name]}return f'{class_name}-{object_cache[uid]["idx"]}'def initialize_module_attributes(module):# 初始化模块属性if not hasattr(module, 'uuid'):module.uuid = get_unique_name(module.__class__.__name__, id(module))if not hasattr(module, 'backward_mem'):module.backward_mem = []if not hasattr(module, 'forward_mem'):module.forward_mem = []def pre_backward_hook(module, grad_input):# 反向传播前的钩子函数initialize_module_attributes(module)global current_moduleif current_module is not None and np.sum(current_module["size"])>0:print(f'[INFO]{current_module["name"]}:{np.sum(current_module["size"])}')module.backward_mem.clear()current_module={"name":f"backward-{module.uuid}","size":module.backward_mem}def post_backward_hook(module, grad_input, grad_output):# 反向传播后的钩子函数initialize_module_attributes(module)def pre_forward_hook(module, input):# 前向传播前的钩子函数initialize_module_attributes(module)global current_moduleif current_module is not None and np.sum(current_module["size"])>0:print(f'[INFO]{current_module["name"]}:{np.sum(current_module["size"])}')module.forward_mem.clear()current_module={"name":f"forward-{module.uuid}","size":module.forward_mem}def post_forward_hook(module, input, output):# 前向传播后的钩子函数initialize_module_attributes(module)def register_forward_hooks(module):# 注册反向传播钩子module.register_forward_pre_hook(pre_forward_hook)module.register_forward_hook(post_forward_hook)def register_backward_hooks(module):# 注册反向传播钩子module.register_full_backward_pre_hook(pre_backward_hook)module.register_full_backward_hook(post_backward_hook)class HookModel(object):def __init__(self, model):output_dict = {}self.get_submodule_recrusicve(model, "", output_dict)for name, module in output_dict.items():if name.endswith("Sequential"):continueregister_forward_hooks(module)register_backward_hooks(module)def get_submodule_recrusicve(self,module, prefix, output_dict):prefix = prefix + "/" + type(module).__name__output_dict[prefix] = modulefor name, submodule in module.named_children():self.get_submodule_recrusicve(submodule, f"{prefix}[{name}]", output_dict)class TorchDumpDispatchMode(TorchDispatchMode):def __init__(self,parent):super().__init__()self.parent=parentdef __torch_dispatch__(self, func, types, args=(), kwargs=None):if kwargs is None:kwargs = {}  ret= func(*args, **kwargs)desc=InputDescriptor()desc.save_vars(ret,*args,**kwargs)if desc.total_input_size>0:print(f"{func.__name__}:{desc.total_input_size}")return retclass TorchDebugDumper:_CURRENT_Dumper = Nonedef __init__(self):self.p= _ProfilerState(TorchDumpDispatchMode)def __enter__(self):assert TorchDebugDumper._CURRENT_Dumper is NoneTorchDebugDumper._CURRENT_Dumper = selfif self.p.object is None:o = self.p.cls(self)o.__enter__()self.p.object = oelse:self.p.object.step()return selfdef __exit__(self, exc_type, exc_val, exc_tb):TorchDebugDumper._CURRENT_Dumper = Noneif self.p.object is not None:self.p.object.__exit__(exc_type, exc_val, exc_tb)del self.p.objectclass FeedForward(Module):def __init__(self,hidden_size,ffn_size):super().__init__()self.fc = nn.Sequential(Linear(in_features=hidden_size, out_features=ffn_size,bias=False),nn.ReLU(),Linear(in_features=ffn_size, out_features=ffn_size*2,bias=False),nn.Dropout(0.5),Linear(in_features=ffn_size*2, out_features=hidden_size,bias=False),)self.norm = nn.LayerNorm(normalized_shape=hidden_size, elementwise_affine=False)def forward(self, x):return x + self.fc(self.norm(x))def main():model=FeedForward(100,128)model=model.float().cuda()model.train()obj=HookModel(model)global current_modulewith TorchDebugDumper():opt=Adam(model.parameters(),lr=0.001)input=torch.randn(1,100).float().cuda()output=model(input)get_current_mem()loss=-torch.log(output.sum())opt.zero_grad()loss.backward()get_current_mem()current_module=Noneopt.step()    get_current_mem()num_model_params = sum(p.numel() for p in model.parameters())print(f"[INFO]Number of model parameters: {num_model_params}")
main()

输出

torch.Size([1, 100]) torch.float32
_to_copy.default:400
[INFO]Other:400
torch.Size([1, 100]) torch.float32
torch.Size([1, 1]) torch.float32
torch.Size([1, 1]) torch.float32
native_layer_norm.default:408
[INFO]forward-LayerNorm-1:408
torch.Size([128, 100]) torch.float32
t.default:51200
[INFO]forward-Linear-1:51200
torch.Size([256, 128]) torch.float32
t.default:131072
torch.Size([1, 256]) torch.float32
mm.default:1024
[INFO]forward-Linear-2:132096
torch.Size([1, 256]) torch.float32
native_dropout.default:1024
[INFO]forward-Dropout-1:1024
torch.Size([100, 256]) torch.float32
t.default:102400
torch.Size([1, 100]) torch.float32
add.Tensor:400
[INFO]forward-Linear-3:102800
torch.Size([]) torch.float32
log.default:4
torch.Size([]) torch.float32
neg.default:4
torch.Size([]) torch.float32
neg.default:4
torch.Size([]) torch.float32
div.Tensor:4
[INFO]Other:16
torch.Size([100, 256]) torch.float32
mm.default:102400
torch.Size([1, 256]) torch.float32
mm.default:1024
[INFO]backward-Linear-3:103424
torch.Size([128, 100]) torch.float32
mm.default:51200
[INFO]backward-Linear-1:51200
torch.Size([128, 100]) torch.float32
zeros_like.default:51200
torch.Size([128, 100]) torch.float32
zeros_like.default:51200
torch.Size([256, 128]) torch.float32
zeros_like.default:131072
torch.Size([256, 128]) torch.float32
zeros_like.default:131072
torch.Size([100, 256]) torch.float32
zeros_like.default:102400
torch.Size([100, 256]) torch.float32
zeros_like.default:102400
torch.Size([128, 100]) torch.float32
torch.Size([256, 128]) torch.float32
torch.Size([100, 256]) torch.float32
_foreach_sqrt.default:284672
[INFO]Other:854016
[INFO]Number of model parameters: 71168

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/20607.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Java八股文面试全套真题

Java八股文面试全套真题 一、Redis1.1、你在最近的项目中哪些场景使用了redis呢?1.2、缓存穿透1.3、布隆过滤器1.4、缓存击穿1.5、缓存雪崩1.6、redis做为缓存,mysql的数据如何与redis进行同步呢?(双写一致性)1.6.1、读…

进程与线程(一)

进程与线程(一) 理解什么是并发编程进程的相关概念什么是进程对比进程和程序理解进程是一个独立的可调度的任务理解进程是程序执行和资源管理的最小单位进程状态转换图进程的种类 进程相关命令进程状态标志ps命令-aux:-axj:(可以查看到进程的PPID)pstree…

浅析R16移动性增强那些事儿(DAPS/CHO/MRO)

R16移动性增强相关技术总结 Dual Active Protocol Handover Dual Active Protocol Handover意为双激活协议栈切换,下文简称DAPS切换,DAPS切换的核心思想是切换过程中,在UE成功连接到目标基站前继续保持和源基站的连接和数据传输,…

示波器眼图怎么看

目录 什么是眼图? 怎么看? 眼图的电压幅度(Y轴) 眼睛幅度和高度 信噪比 抖动 上升时间和下降时间 眼宽 什么是眼图? 眼图(Eye Diagram)是一种用于分析高速数字信号传输质量的重要工具。通…

OpenJDK优化技术之标量替换(Scalar Replacement)

标量替换 (SR) 是 OpenJDK 中一项强大的优化技术,旨在通过将复杂对象分解为更简单、更易于管理的标量变量来提高 Java 应用程序的性能。 1.前言 OpenJDK JVM 有两个即时编译器,C1 和 C2。C2 是一种应用许多优化来生成非常高效的编译版本程序的编译器。…

【全开源】Java共享台球室无人系统支持微信小程序+微信公众号+H5

智能引领台球新体验 一、引言:共享经济的新篇章 在共享经济的大潮中,各类共享服务层出不穷,为人们的生活带来了极大的便利。共享台球室作为其中的一员,以其独特的魅力吸引了众多台球爱好者的目光。而今天,我们要介绍…

【JavaScript脚本宇宙】JavaScript日期处理神器: 6款顶级库解析

提升编程效率:六个强大的JavaScript日期时间库介绍 前言 在信息化社会,日期和时间的处理是任何编程语言必不可少的部分。本文将介绍六个优秀的JavaScript日期和时间库,这些库各有特色,可以应对多样的使用场景。 欢迎订阅专栏&am…

RAG检索增强生成

Retrieval-Augmented Generation for Knowledge-Intensive NLP Tasks Lewis P, Perez E, Piktus A, et al. Retrieval-augmented generation for knowledge-intensive nlp tasks[J]. Advances in Neural Information Processing Systems, 2020, 33: 9459-9474.

【通信专题】I2C上拉电阻计算方法

I2C 通信总线是电子设计中常见的总线之一,由于 I2C 的硬件芯片内部为开漏输出,所以要求在外部增加一个上拉电阻,总线上拉电阻的选取受多个因素的影响,因此如何计算 I2C 总线的上拉电阻阻值成为硬件工程师在使用 I2C总统时需要关注的话题。 从本质上讲: I2C 总线电容和上升…

算法第三天力扣第69题:X的平方根

69. x 的平方根 (可点击下面链接或复制网址进行做题) https://leetcode.cn/problems/sqrtx/https://leetcode.cn/problems/sqrtx/ 给你一个非负整数 x ,计算并返回 x 的 算术平方根 。 由于返回类型是整数,结果只保留 整数部分 ,小数部分将被 舍去 。 注意:不允许使用任何内…

密码和密钥的联系与区别

密码和密钥是两个非常重要的概念,但容易混淆这两者,以下内容介绍了它们的联系和区别: 一、定义 密码(Password),在日常语境中,通常指的是个人为了验证自己的身份而设置的一段秘密的字符序列&am…

动态规划:优化问题求解的艺术

引言: 在计算机科学和数学中,动态规划是一种强大的算法设计技术,用于解决具有重叠子问题和最优子结构特性的复杂问题。动态规划不仅可以简化问题的求解过程,还能显著提高效率。本文将介绍动态规划的基本概念、工作原理、算法设计步…

周末总结(2024/06/01)

工作 人际关系核心实践: 要学会随时回应别人的善意。执行时间控制在5分钟以内 坚持每天早会打招呼 遇到接不住的话题时拉低自己,抬高别人(无阴阳气息) 工作上的要点 现状(接受破烂现状,改变状态) - 我很不满意现在的…

基于Qt GraphicView 解析 CIM/G 电力接线图文件

本文讲述了如何使用Qt的框架来渲染展示标准的CIM/G格式的图形文件,也就是公用信息模型(common information model,CIM)中的G文件部分的内容。这是一种电力系统图形的交换规则,用于电网图形交换。 [by amjieker] CIM/G …

【自动驾驶】点与向量从ego系转odometry系

1.点从ego系转odometry系(ego -> odometry) struct Point {float x;float y;float angle; }; Point trans; // is the odom to ego transform Point odom_coord; is the odom coord Point ego_coord; is the ego coordfloat odom_coord.x = (ego_coord.x - trans.x) * st…

Selenium番外篇文本查找、元素高亮、截图、无头运行

Selenium根据文本查找元素 ​ python def find_element_with_text(self, loc, attribute, text):try:WebDriverWait(self.driver, 5).until(EC.all_of(EC.text_to_be_present_in_element_attribute(loc, attribute, text)))element self.driver.find_element(*loc)if isinsta…

C++青少年简明教程:break语句、continue语句

C青少年简明教程:break语句、continue语句 break语句 只能用在switch语句和循环语句(for循环、while循环和do-while循环)中。作用:跳出switch语句或提前终止循环。 break语句的基本语法如下: break; break语句的示例…

Nutanix在.NEXT大会宣布AI战略升级:GPT-in-a-Box 2.0集成NVIDIA,强化企业级AI应用支持

Nutanix在巴塞罗那举行的.NEXT大会上宣布了一系列新动向,旨在借助与思科的合作、Broadcom收购VMware、生成式人工智能(GenAI)的兴起、容器化技术、PostgreSQL数据库的广泛应用以及绿色能源倡议,进一步扩大其在人工智能领域的影响力…

macbook配置前端环境:深度解析与实战指南

macbook配置前端环境:深度解析与实战指南 在数字时代的浪潮中,前端开发已成为构建互动、生动且富有吸引力的用户界面的关键。而MacBook,以其卓越的性能和稳定的系统,成为前端开发者们的首选工具。然而,对于初学者或新…

C# WinForm —— 26 ImageList 介绍

1. 简介 图片集合,用于存储图像的资源,并在关联控件中显示出来 可以通过 索引、键名 访问每张图片 没有事件 2. 属性 属性解释(Name)控件ID,在代码里引用的时候会用到,一般以 imgList 开头ClolorDepth用于呈现图像的颜色数,默…