pytorch通过change_current_allocator获取所有的子Module实际的内存占用情况

pytorch通过change_current_allocator获取所有的子Module实际的内存占用情况

  • 1.背景介绍
  • 2.参考链接
  • 3.自己的内存分配器
  • 4.pytorch测试代码

1.背景介绍

  • 目的:需要准确统计pytorch每一层计算所需的设备内存
  • 问题:对齐的原因,直接使用torch.cuda.memory_allocated()并不准确
  • 方法:
    • 设置CUBLAS_WORKSPACE_CONFIG,排除CUBLAS_WORKSPACE的影响
    • 使用torch.cuda.memory.change_current_allocator设置自己的内存分配器
    • 在自己的内存分配器里记录内存分配情况

2.参考链接

  • Using custom memory allocators for CUDA
  • 跟踪一个Pytorch Module在训练过程中的内存分配情况
  • cuBLAS workspaces

3.自己的内存分配器

tee alloc.cc<<-'EOF'
#include <sys/types.h>
#include <cuda_runtime_api.h>
#include <iostream>
#include <assert.h>
#include <unordered_map>
#include <iostream>
#include <mutex>// 内存监视器类
class MemoryMonitor {
public:// 分配内存并记录void* allocate(size_t size) {void* ptr;cudaMalloc(&ptr,size);if (ptr) {std::lock_guard<std::mutex> lock(mtx);allocations[ptr] = size;totalAllocated += size;}return ptr;}// 释放内存并记录void deallocate(void* ptr) {if (ptr) {std::lock_guard<std::mutex> lock(mtx);auto it = allocations.find(ptr);if (it != allocations.end()) {totalAllocated -= it->second;allocations.erase(it);}cudaFree(ptr);}}// 获取当前的总分配大小size_t getTotalAllocated() const {std::lock_guard<std::mutex> lock(mtx);return totalAllocated;}private:std::unordered_map<void*, size_t> allocations; // 存储分配地址和大小的哈希表size_t totalAllocated = 0; // 当前总分配大小mutable std::mutex mtx; // 保护数据结构的互斥锁
};MemoryMonitor monitor;extern "C" {void* my_malloc(ssize_t size, int device, cudaStream_t stream) {return monitor.allocate(size);}void my_free(void* ptr, ssize_t size, int device, cudaStream_t stream) {monitor.deallocate(ptr);}unsigned long long getTotalAllocated(){return monitor.getTotalAllocated();}
}
EOF
g++ alloc.cc -o alloc.so -I/usr/local/cuda/include -shared -fPIC

4.pytorch测试代码


tee torch_mem_stat.py <<-'EOF'
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
os.environ['CUBLAS_WORKSPACE_CONFIG']=":0:0"
import ctypes
import numpy as np
import torch
from torch.nn import Module, Linear
import torch.nn as nn
from torch.optim import Adam,SGD
from dataclasses import dataclass
from typing import Any
import time
import torchvision.models as models
import syshook_allocator=int(sys.argv[1])if hook_allocator==1:os.environ['PYTORCH_NO_CUDA_MEMORY_CACHING']='1'lib = ctypes.CDLL('./alloc.so')lib.getTotalAllocated.restype = ctypes.c_ulonglongprint("hook_allocator")new_alloc = torch.cuda.memory.CUDAPluggableAllocator('./alloc.so', 'my_malloc', 'my_free')torch.cuda.memory.change_current_allocator(new_alloc)def get_memory_allocated():if hook_allocator:return lib.getTotalAllocated()else:return torch.cuda.memory_allocated()# 对象和类名缓存
object_cache = {}
class_name_count = {}def is_tensor(val):return isinstance(val, (torch.Tensor, nn.Parameter))def describe_tensor_data(tensor,desc=""):if is_tensor(tensor):desc+=f"[shape({','.join(map(str,list(tensor.shape)))})_dtype({tensor.dtype})]"elif isinstance(tensor, (tuple, list)):for idx, t in enumerate(tensor):desc=describe_tensor_data(t,f"{desc}idx({idx})")else:desc+=f"[dtype({type(tensor)})]"return descdef get_unique_name(class_name, obj_id):# 生成唯一的对象名称if class_name not in class_name_count:class_name_count[class_name] = 0uid = f"{class_name}_{obj_id}"if uid not in object_cache:class_name_count[class_name] += 1object_cache[uid] = {"idx": class_name_count[class_name]}return f'-{object_cache[uid]["idx"]}'def initialize_module_attributes(name,module):# 初始化模块属性if not hasattr(module, 'uuid'):module.uuid = name+get_unique_name(module.__class__.__name__, id(module))if not hasattr(module, 'backward_mem'):module.backward_mem = 0if not hasattr(module, 'forward_mem'):module.forward_mem = 0if not hasattr(module, 'fwd_mem_sz'):module.fwd_mem_sz = Noneif not hasattr(module, 'bwd_mem_sz'):module.bwd_mem_sz = None    def pre_backward_hook(module, grad_input):module.backward_mem=get_memory_allocated()def post_backward_hook(module, grad_input, grad_output):memory_allocated=get_memory_allocated()module.bwd_mem_sz=memory_allocated-module.backward_memrank=0if torch.distributed.is_initialized():rank=torch.distributed.get_rank()    if rank==0:with open("torch_module_mem_info.txt","a+") as f:f.write(f"bwd-{module.uuid}#{module.bwd_mem_sz}#{memory_allocated}#{describe_tensor_data(grad_input)}#{describe_tensor_data(grad_output)}\n")def pre_forward_hook(module, input):   module.forward_mem=get_memory_allocated()def post_forward_hook(module, input, output):memory_allocated=get_memory_allocated()module.fwd_mem_sz=memory_allocated-module.forward_memrank=0if torch.distributed.is_initialized():rank=torch.distributed.get_rank()    if rank==0:    with open("torch_module_mem_info.txt","a+") as f:f.write(f"fwd-{module.uuid}#{module.fwd_mem_sz}#{memory_allocated}#{describe_tensor_data(input)}#{describe_tensor_data(output)}\n")def register_forward_hooks(name,module):initialize_module_attributes(name,module)module.register_forward_pre_hook(pre_forward_hook)module.register_forward_hook(post_forward_hook)def register_backward_hooks(name,module):initialize_module_attributes(name,module)module.register_full_backward_pre_hook(pre_backward_hook)module.register_full_backward_hook(post_backward_hook)class HookModel(object):def __init__(self, model):output_dict = {}self.get_submodule_recrusicve(model, "", output_dict)for name, module in output_dict.items():if name.endswith("Sequential"):continueregister_forward_hooks(name,module)register_backward_hooks(name,module)def get_submodule_recrusicve(self,module, prefix, output_dict):prefix = prefix + "/" + type(module).__name__output_dict[prefix] = modulefor name, submodule in module.named_children():self.get_submodule_recrusicve(submodule, f"{prefix}.{name}", output_dict)class FeedForward(Module):def __init__(self,hidden_size,ffn_size):super().__init__()self.fc = nn.Sequential(Linear(in_features=hidden_size, out_features=ffn_size,bias=False),nn.ReLU(),Linear(in_features=ffn_size, out_features=ffn_size*2,bias=False),nn.Dropout(0.5),Linear(in_features=ffn_size*2, out_features=hidden_size,bias=False),)self.norm = nn.LayerNorm(normalized_shape=hidden_size, elementwise_affine=False)def forward(self, x):return x + self.fc(self.norm(x))def main():model=FeedForward(100,128) model=model.float().cuda()model.train()obj=HookModel(model)opt=Adam(model.parameters(),lr=0.001)input=torch.randn(1,100).float().cuda()with open("torch_module_mem_info.txt","w") as f:f.write("")for i in range(1):output=model(input)loss=-torch.log(output.sum())opt.zero_grad()loss.backward()opt.step()
main()
EOFpython torch_mem_stat.py 0
cat torch_module_mem_info.txt
python torch_mem_stat.py 1
cat torch_module_mem_info.txt

输出

#默认的分配器
fwd-/FeedForward.norm/LayerNorm-1#512#285696#idx(0)[shape(1,100)_dtype(torch.float32)]#[shape(1,100)_dtype(torch.float32)]
fwd-/FeedForward.fc/Sequential.0/Linear-1#512#286208#idx(0)[shape(1,100)_dtype(torch.float32)]#[shape(1,128)_dtype(torch.float32)]
fwd-/FeedForward.fc/Sequential.1/ReLU-1#512#286720#idx(0)[shape(1,128)_dtype(torch.float32)]#[shape(1,128)_dtype(torch.float32)]
fwd-/FeedForward.fc/Sequential.2/Linear-2#1024#287232#idx(0)[shape(1,128)_dtype(torch.float32)]#[shape(1,256)_dtype(torch.float32)]
fwd-/FeedForward.fc/Sequential.3/Dropout-1#1536#288768#idx(0)[shape(1,256)_dtype(torch.float32)]#[shape(1,256)_dtype(torch.float32)]
fwd-/FeedForward.fc/Sequential.4/Linear-3#512#288256#idx(0)[shape(1,256)_dtype(torch.float32)]#[shape(1,100)_dtype(torch.float32)]
fwd-/FeedForward-1#3072#288256#idx(0)[shape(1,100)_dtype(torch.float32)]#[shape(1,100)_dtype(torch.float32)]
bwd-/FeedForward-1#0#289792#idx(0)[dtype(<class 'NoneType'>)]#idx(0)[shape(1,100)_dtype(torch.float32)]
bwd-/FeedForward.fc/Sequential.4/Linear-3#102400#392192#idx(0)[shape(1,256)_dtype(torch.float32)]#idx(0)[shape(1,100)_dtype(torch.float32)]
bwd-/FeedForward.fc/Sequential.3/Dropout-1#512#392192#idx(0)[shape(1,256)_dtype(torch.float32)]#idx(0)[shape(1,256)_dtype(torch.float32)]
bwd-/FeedForward.fc/Sequential.2/Linear-2#131584#522752#idx(0)[shape(1,128)_dtype(torch.float32)]#idx(0)[shape(1,256)_dtype(torch.float32)]
bwd-/FeedForward.fc/Sequential.1/ReLU-1#0#521728#idx(0)[shape(1,128)_dtype(torch.float32)]#idx(0)[shape(1,128)_dtype(torch.float32)]
bwd-/FeedForward.fc/Sequential.0/Linear-1#0#521216#idx(0)[dtype(<class 'NoneType'>)]#idx(0)[shape(1,128)_dtype(torch.float32)]#自定义分配器
fwd-/FeedForward.norm/LayerNorm-1#400#285472#idx(0)[shape(1,100)_dtype(torch.float32)]#[shape(1,100)_dtype(torch.float32)]
fwd-/FeedForward.fc/Sequential.0/Linear-1#512#285984#idx(0)[shape(1,100)_dtype(torch.float32)]#[shape(1,128)_dtype(torch.float32)]
fwd-/FeedForward.fc/Sequential.1/ReLU-1#512#286496#idx(0)[shape(1,128)_dtype(torch.float32)]#[shape(1,128)_dtype(torch.float32)]
fwd-/FeedForward.fc/Sequential.2/Linear-2#1024#287008#idx(0)[shape(1,128)_dtype(torch.float32)]#[shape(1,256)_dtype(torch.float32)]
fwd-/FeedForward.fc/Sequential.3/Dropout-1#1280#288288#idx(0)[shape(1,256)_dtype(torch.float32)]#[shape(1,256)_dtype(torch.float32)]
fwd-/FeedForward.fc/Sequential.4/Linear-3#400#287664#idx(0)[shape(1,256)_dtype(torch.float32)]#[shape(1,100)_dtype(torch.float32)]
fwd-/FeedForward-1#2592#287664#idx(0)[shape(1,100)_dtype(torch.float32)]#[shape(1,100)_dtype(torch.float32)]
bwd-/FeedForward-1#0#287676#idx(0)[dtype(<class 'NoneType'>)]#idx(0)[shape(1,100)_dtype(torch.float32)]
bwd-/FeedForward.fc/Sequential.4/Linear-3#102400#390076#idx(0)[shape(1,256)_dtype(torch.float32)]#idx(0)[shape(1,100)_dtype(torch.float32)]
bwd-/FeedForward.fc/Sequential.3/Dropout-1#768#390840#idx(0)[shape(1,256)_dtype(torch.float32)]#idx(0)[shape(1,256)_dtype(torch.float32)]
bwd-/FeedForward.fc/Sequential.2/Linear-2#131584#521400#idx(0)[shape(1,128)_dtype(torch.float32)]#idx(0)[shape(1,256)_dtype(torch.float32)]
bwd-/FeedForward.fc/Sequential.1/ReLU-1#0#520376#idx(0)[shape(1,128)_dtype(torch.float32)]#idx(0)[shape(1,128)_dtype(torch.float32)]
bwd-/FeedForward.fc/Sequential.0/Linear-1#0#519864#idx(0)[dtype(<class 'NoneType'>)]#idx(0)[shape(1,128)_dtype(torch.float32)]

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/46930.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

[React 进阶系列] useSyncExternalStore hook

[React 进阶系列] useSyncExternalStore hook 前情提要&#xff0c;包括 yup 的实现在这里&#xff1a;yup 基础使用以及 jest 测试 简单的提一下&#xff0c;需要实现的功能是&#xff1a; yup schema 需要访问外部的 storage外部的 storage 是可变的React 内部也需要访问同…

【AI应用探讨】—粒子群算法(PSO)应用场景

目录 1. 神经网络训练 2. 工程设计 3. 电力系统 4. 数据挖掘 5. 控制工程 6. 机器人路径规划 7. 图像处理 8. 生物信息学 9. 其他领域 1. 神经网络训练 应用场景&#xff1a;粒子群算法可以用于神经网络的权重和阈值的优化&#xff0c;以提高神经网络的性能和预测准确…

产品经理-工作中5大类技术名词解析(19)

在产品经理与开发的团队协作中,如果自己知道一些专业术语,对业务的开展是有帮助的&#xff0c;很多时候,在沟通过程当中,就是因为自己不懂,所以才不知道怎么去做,想要什么样的结果 在力所能及的情况下,平时,多了解一些专业术语,是有好处的 数据结构 数据结构是技术人员将数据进…

【iOS】static、extern、const、auto关键字以及联合使用

目录 前言extern关键字static关键字const关键字 联合使用static和externstatic和constextern和const auto关键字 先了解一下静态变量所在的全局/静态区的特点&#xff1a;【iOS】内存五大分区 前言 上面提到的全局/静态区中存放的是全局变量或静态变量&#xff1a; 全局变量…

人工智能大模型发展的新形势及其省思

作者简介 肖仰华&#xff0c;复旦大学计算机科学技术学院教授、博导&#xff0c;上海市数据科学重点实验室主任。研究方向为知识图谱、知识工程、大数据管理与挖掘。主要著作有《图对称性理论及其在数据管理中的应用》、《知识图谱&#xff1a;概念与技术》&#xff08;合著&a…

C++基础语法:STL之容器(5)--序列容器中的list(二)

前言 "打牢基础,万事不愁" .C的基础语法的学习 引入 序列容器的学习.以<C Prime Plus> 6th Edition(以下称"本书")内容理解 本书中容器内容不多只有几页.最好是有数据结构方面的知识积累,如果没有在学的同时补上 接上一篇C基础语法:STL之容器…

Node:解决Error: error:0308010C:digital envelope routines::unsupported的解决方法

问题描述 在使用vuepress搭建博客的时候&#xff0c;运行项目发现报错了&#xff0c;检查了node的版本是18&#xff0c;之前用的是16或14的版本&#xff0c;现在报&#xff1a;Error: error:0308010C:digital envelope routines::unsupported错误。 查找了一些资料&#xff0…

excel系列(三) - 利用 easyexcel 快速实现 excel 文件导入导出

一、介绍 在上篇文章中&#xff0c;我们介绍了 easypoi 工具实现 excel 文件的导入导出。 本篇我们继续深入介绍另一款更优秀的 excel 工具库&#xff1a;easyexcel 。 二、easyexcel easyexcel 是阿里巴巴开源的一款 excel 解析工具&#xff0c;底层逻辑也是基于 apache p…

HOW - 保证 WebSocket 持续正常连接

一、基于 React 的 WebSocket 下面是一个React版本的WebSocket连接代码示例&#xff0c;展示了如何在React组件中实现WebSocket连接、心跳机制以及自动重连功能。 WebSocketManager.js 首先&#xff0c;我们可以创建一个 WebSocketManager 类来封装WebSocket的逻辑&#xff…

web前端 Vue 框架面试120题(六)

面试题 101 . 如何解决Vuex页面刷新数据丢失 &#xff1f; 参考回答&#xff1a; F5页面刷新&#xff0c;页面销毁之前的资源&#xff0c;重新请求&#xff0c;因此写在生命周期里的vuex数据是重新初始化&#xff0c;无法获取的&#xff0c;这也就是为什么会打印出空的原因。…

解决Oracle SQL语句性能问题——正确使用Hint(Hint概念、场景及具体语法)

10.5. 正确使用Hint 10.5.1. Hint概念及场景 调优SQL语句时,Oracle提供了很多可用的Hint。首先,你应该获取和分析SQL语句的执行计划,看能否通过改写SQL语句或其他方法来进行调整和优化,而不是直接使用Hint方法进行优化,最好在其他方法都确定无效或不合理之后再考虑使用H…

HTTPS 的加密过程 详解

HTTP 由于是明文传输&#xff0c;所以安全上存在以下三个风险&#xff1a; 窃听风险&#xff0c;比如通信链路上可以获取通信内容。篡改风险&#xff0c;比如通信内容被篡改。冒充风险&#xff0c;比如冒充网站。 HTTPS 在 HTTP 与 TCP 层之间加入了 SSL/TLS 协议&#xff0c…

概率论原理精解【4】

文章目录 度量空间概述理论基础定义特点高级概念广泛应用 性质例子应用 柯西数列柯西数列的定义柯西数列的例子 参考文献 度量空间 概述 设 f : R n → R m , f ˙ ( x ) 在 { x : ∣ x − x 0 ∣ < r } 内连续&#xff0c;则当 ∣ t ∣ < r 时&#xff0c; f:R^n\righ…

Spring Cloud LoadBalanced

负载均衡(Load Balance&#xff0c;简称 LB) 是⾼并发, ⾼可⽤系统必不可少的关键组件. 当服务流量增⼤时, 通常会采⽤增加机器的⽅式进⾏扩容, 负载均衡就是⽤来在多个机器或者其他资源中, 按照⼀定的规则合理分配负载. 负载均衡的⼀些实现 就像是eureka中对请求进行轮询的…

Java对象创建过程的解析

Java对象创建过程的解析 1. 类的加载与连接2. 内存分配2.1 分配方式2.2 本地线程缓冲分配&#xff08;TLAB&#xff09; 3. 初始化内存4. 设置对象头 &#x1f496;The Begin&#x1f496;点点关注&#xff0c;收藏不迷路&#x1f496; 对象的创建是一个涉及多个步骤的复杂过程…

Qt:26.Qt项目:贪吃蛇游戏

一、项目功能演示&#xff1a; 开始界面可以点击进入游戏。 点击进入游戏之后&#xff0c;切换到选项界面&#xff0c;该界面可以选择游戏难度&#xff0c;回退&#xff0c;以及查询最近一次游戏得分。 游戏具体界面如下。贴图啥的可以自己换&#xff0c;本人审美不咋行&#x…

[SUCTF 2019]EasySQL1

这是一个简单的SQL注入题&#xff0c;但是因为我的SQL基础约等于0&#xff0c;所以做起来很难。 首先试试引号是否被过滤 可以看到单引号、双引号都被过滤了&#xff0c;试试其他的盲注都不行&#xff0c;基本上可以确定不能用这种方法。 在测试的过程中发现&#xff0c;输入…

RICHTEK立锜科技 WIFI 7电源参考设计

什么是WIFI 7? WiFi 7&#xff08;Wi-Fi 7&#xff09;是下一代Wi-Fi标准&#xff0c;对应的是IEEE 802.11将发布新的修订标准IEEE 802.11be –极高吞吐量EHT&#xff08;Extremely High Throughput &#xff09;。Wi-Fi 7是在Wi-Fi 6的基础上引入了320MHz带宽、4096-QAM、Mu…

oceanbase架构、功能模块、数据存储、特性、sql流转层等概念详解

一、架构图 OceanBase 数据库采用无共享&#xff08;Shared-Nothing&#xff09;分布式集群架构&#xff0c;各个节点之间完全对等&#xff0c;每个节点都有自己的 SQL 引擎、存储引擎、事务引擎&#xff0c;运行在普通 PC 服务器组成的集群之上&#xff0c;具备高可扩展性、高…

cephrgw元数据和数据布局

提示&#xff1a;每个rados object有如下几个组成部分&#xff0c;分别是omap&#xff08;omapheader、omapkey、omapval&#xff09;、xattr、data&#xff0c;相关的CLI command rados getomapheader {radosobjectname} -p {poolname} [--namespace{ns}] rados listomapkeys…