pytorch 训练过程内存泄露/显存泄露debug记录:dataloader和dataset导致的泄露

背景

微调 mask-rcnn 代码,用的是 torchvision.models.detection.maskrcnn_resnet50_fpn 代码,根据该代码的注释,输入应该是:
images, targets=None
(List[Tensor], Optional[List[Dict[str, Tensor]]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]

所以我写的 dataset 是这样的:

def _load_ann(self):transformed_anns = {'boxes': boxes : List[List],'labels': categories: List[int],'masks': masks: List[str]}self.anns[filename] = transformed_annsdef __getitem__(self, item) -> (Tensor, Optional[Dict[str, Tensor]]):img_name = self.img_list[item]img = cv2.imread(os.path.join(self.root_path, self.split, img_name))if self.split == 'train':return img, self.anns[img_name]

大概思路是:先把所有的标注信息读入内存,然后按照 img_name 把标注信息(也就是 target )取出来

这里有个令人纠结的地方:__get_item__ 返回的到底是什么格式的数据?一开始我是直接把 boxes labels masks 都直接写成 tensor 返回,但是一次性把所有的 masks 都读到内存,太大了。再加上看了这个 pytorch内存泄露-dataloader - 知乎 ,这里建议 dataset 的 __get_item__ 返回的都是 python 的基础数据类型,所以我就改成了上面的样子。其实返回什么类型的都行,只要在 dataloader 的 collate_fn 方法里面都转成可以送入模型的数据形式就行了。

因为 dataset 是上面的写法,所以对应的 collate_fn 写法是:

def collate_fn(datas: List[Tuple[Tensor, Dict]]):imgs = []targets = []for data in datas:img, target = dataimgs.append(transforms.ToTensor()(img))target['boxes'] = torch.tensor(target['boxes'], dtype=torch.float)target['labels'] = torch.tensor(target['labels'], dtype=torch.int64)masks = target['masks']masks = [cv2.imread(mask, 0) for mask in masks]masks = np.stack(masks, axis=0)masks = masks / 255masks = masks.astype(np.uint8)target['masks'] = torch.from_numpy(masks)targets.append(target)return imgs, targets

错误排查及解决方法

把所有的数据送入 model 的代码都注释掉,只保留如下代码:

for e in range(epoch):for i, (imgs, targets) in enumerate(train_dataloader):imgs = [img.to(device) for img in imgs]targets = [_to_device(target, device) for target in targets]

watch -n 1 nvidia-smi 监控显存占用,发现一直在涨。毫无疑问肯定是 dataloader 导致的显存泄露 😭

然后就是排查,到底是谁?是谁想害朕??

排查方法是:分别注释掉 imgs / boxes / labels / masks ,观察注释掉谁的时候不会显存泄露。

发现,是 masks 导致的内存泄露。

但是这很怪啊,明明 masks 和 imgs 是一样的数据类型,为什么前者会显存泄露,但是后者不会?于是我把 masks 单独拿出来,像 imgs 一样放在 list 里面,不会内存泄露。但是一旦把 imgs 嵌套放在 targets 这个 dict 里面,就会显存泄露 orz

于是,既然是 masks 没有释放,所以我加一句:

for e in range(epoch):for i, (imgs, targets) in enumerate(train_dataloader):imgs = [img.to(device) for img in imgs]targets = [_to_device(target, device) for target in targets]# ... 传入模型的计算for target in targets:del target['masks']

但是没用,还是泄露。然后查了 pytorch 怎么释放 tensor,发现要主动调用 torch.cuda.empty_cache() 才会释放,所以我又加了一句:

for e in range(epoch):for i, (imgs, targets) in enumerate(train_dataloader):imgs = [img.to(device) for img in imgs]targets = [_to_device(target, device) for target in targets]# ... 传入模型的计算for target in targets:del target['masks']torch.cuda.empty_cache()

这回没有显存泄露了。

但是出现了新的问题,在 epoch=2 的时候报错 targets 没有 masks 这个 key;然后我 debug 发现,由 dataloader 取到的数据 label 和 boxes 在 collate_fn 之前就已经是 tensor 状态了,再往前倒,发现 dataset.anns 里面的数据居然被改了!这实在是太荒谬了。

所以我把 __get_item 改成:

def __getitem__(self, item) -> (Tensor, Optional[Dict[str, Tensor]]):img_name = self.img_list[item]img = cv2.imread(os.path.join(self.root_path, self.split, img_name))if self.split == 'train':return img, deepcopy(self.anns[img_name])

这样就没问题了

总结

  1. 查找内存泄露/显存泄露的位置:
    • 把数据送入模型的代码全部注释掉,观察显存是否上涨;上涨说明内存泄露出现在 dataloader(出现在非 dataloader 地方的最常见的显存泄露原因是,loss 打印/统计的时候没有写 loss.item()
    • 把不同的 data 组成部分注释掉,观察具体是哪个 data 导致的内存泄露
  2. pytorch 释放内存的方法:把 tensor 读到 gpu 就会有显存占用,一般可以自动释放,但是显存泄露的时候就没法释放。找到没有及时释放的代码位置,然后首先 del tensor 标记删除,随后需要调用 torch.cuda.empty_cache() 才能真正释放
  3. dataset 的 __get_item__ 方法注意,如果要返回内部维护的 list 类型的数据的话,不要直接返回该数据切片,而是返回 deepcopy() 防止内部维护的数据被外部修改

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/31433.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【大数据】-- 部署 Flink kubernetes operator

目录 1.说明 1.1 版本 1.2 kubernetes 环境 1.3 参考 2.安装步骤 2.1 安装本地 kubernetes 环境

Oracle 使用 CONNECT_BY_ROOT 解锁层次结构洞察:在 SQL 中导航数据关系

CONNECT_BY_ROOT 是一个在 Oracle 数据库中使用的特殊函数,它通常用于在层次查询中获取根节点的值。在使用 CONNECT BY 子句进行层次查询时,通过 CONNECT_BY_ROOT 函数,你可以在每一行中获取根节点的值,而不仅仅是当前行的值。 假…

Vue3 实现产品图片放大器

Vue3 实现类似淘宝、京东产品详情图片放大器功能 环境&#xff1a;vue3tsvite 1.创建picShow.vue组件 <script lang"ts" setup> import {ref, computed} from vue import {useMouseInElement} from vueuse/core/*获取父组件的传值*/ defineProps<{images:…

从支付或退款之回调处理的设计,看一看抽象类的使用场景

一、背景 抽象类&#xff0c;包含抽象方法和实例方法&#xff0c;抽象方法待继承类去实例化&#xff0c;正是利用该特性&#xff0c;以满足不同支付渠道的差异化需求。 我们在做多渠道支付的时候&#xff0c;接收支付或退款的回调报文&#xff0c;然后去处理。这就意味着&…

【python 深度学习】解决遇到的问题

目录 一、RuntimeError: module compiled against API version 0xc but this version of numpy is 0xb 二、AttributeError: module ‘tensorflow’ has no attribute ‘flags’ 三、conda 更新 Please update conda by running 四、to search for alternate channels that…

Tomcat部署SpringBoot项目

1.修改打包方式 pom.xml 里 加上 <packaging>war</packaging>2.移除内嵌的Tomcat <dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-tomcat</artifactId><scope>provided</scope…

Java-jar和war包的区别

jar包和war包的区别&#xff1a; 1、war是一个web模块&#xff0c;其中需要包括WEB-INF&#xff0c;是可以直接运行的WEB模块&#xff1b;jar一般只是包括一些class文件&#xff0c;在声明了Main_class之后是可以用java命令运行的。 2、war包是做好一个web应用后&#xff0c;通…

Kubernetes 调度 约束

调度约束 Kubernetes 是通过 List-Watch 的机制进行每个组件的协作&#xff0c;保持数据同步的&#xff0c;每个组件之间的设计实现了解耦。 用户是通过 kubectl 根据配置文件&#xff0c;向 APIServer 发送命令&#xff0c;在 Node 节点上面建立 Pod 和 Container。 APIServer…

腾讯云轻量应用服务器和云服务器有什么区别?

腾讯云轻量服务器和云服务器有什么区别&#xff1f;为什么轻量应用服务器价格便宜&#xff1f;是因为轻量服务器CPU内存性能比云服务器CVM性能差吗&#xff1f;轻量应用服务器适合中小企业或个人开发者搭建企业官网、博客论坛、微信小程序或开发测试环境&#xff0c;云服务器CV…

饿了么输入框限制只能输入数字,并且保留小数

可以使用饿了么ui中的input-number组件实现输入框只能输入数字&#xff0c;这样就不能输入数字以外的&#xff0c;controls隐藏输入框左右俩边的加减按钮&#xff0c;precision小数点保留多少位&#xff0c;2则是俩位&#xff0c;但是会导致默认值为0.00的情况&#xff0c;俩种…

开源数据库Mysql_DBA运维实战 (DDL语句)

DDL DDL语句 数据库定义语言&#xff1a;数据库、表、视图、索引、存储过程. 例如:CREATE DROP ALTER DDL库 定义库{ 创建业务数据库&#xff1a;CREAATE DATABASE ___数据库名___ ; 数据库名要求{ a.区分大小写 b.唯一性 c.不能使用关键字如 create select d.不能单独使用…

图像识别模型与训练策略

图像预处理 1.需要将图像Resize到相同大小输入到卷积网络中 2.翻转、裁剪、色彩偏移等操作 3.转化为Tensor数据格式 4.对RGB三种颜色通道进行标准化 data_transforms {train: transforms.Compose([transforms.Resize([96, 96]),transforms.RandomRotation(45),#随机旋转&…

unable to write symref for HEAD: Permission denied

今天从gitee上面克隆项目到本地时报错如下 warning: unable to unlink ‘D:/IDEAcode/ruiji1.0/.git/HEAD.lock’: Invalid argument error: unable to write symref for HEAD: Permission denied 解决方法&#xff1a;将要存放项目的文件夹权限修改为完全控制 原先权限&…

GO学习之 接口(Interface)

GO系列 1、GO学习之Hello World 2、GO学习之入门语法 3、GO学习之切片操作 4、GO学习之 Map 操作 5、GO学习之 结构体 操作 6、GO学习之 通道(Channel) 7、GO学习之 多线程(goroutine) 8、GO学习之 函数(Function) 9、GO学习之 接口(Interface) 文章目录 GO系列前言一、什么是…

什么是MVCC

问题描述 对于 MVCC 的理解&#xff0c;我觉得可以先从数据库的三种并发场景说起&#xff1a; 第一种&#xff1a;读读 线程 A 与线程 B 同时在进行读操作&#xff0c;这种情况下不会出现任何并发问题。 第二种&#xff1a;读写 线程 A 与线程 B 在同一时刻分别进行读和写…

W5100S-EVB-PICO 做TCP Server进行回环测试(六)

前言 上一章我们用W5100S-EVB-PICO开发板做TCP 客户端连接服务器进行数据回环测试&#xff0c;那么本章将用开发板做TCP服务器来进行数据回环测试。 TCP是什么&#xff1f;什么是TCP Server&#xff1f;能干什么&#xff1f; TCP (Transmission Control Protocol) 是一种面向连…

十一、结合数字孪生与时间技术进行多维分析设计与实施

大数据可视化中心以主题为分析对象,选择业务分类下的某个主题,可以在数据面板中展示其二维图表,在地图中标记其空间分布,并叠加其相应的二维或三维图层。 1、界面设计 其主界面设计详上图,各部分功能介绍如下: 1.1、主题与图层面板,从上到下,从左到右分别是: ①折…

【1++的数据结构】之二叉搜索树

&#x1f44d;作者主页&#xff1a;进击的1 &#x1f929; 专栏链接&#xff1a;【1的数据结构】 文章目录 一&#xff0c;什么是二叉搜索树二&#xff0c;二叉搜索树的操作及其实现2.1 插入操作及其实现2.2 查找操作及其实现2.3 删除操作及其实现 三&#xff0c;构造及其析构四…

分布式链路追踪概述

分布式链路追踪概述 文章目录 分布式链路追踪概述1.分布式链路追踪概述1.1.什么是 Tracing1.2.为什么需要Distributed Tracing 2.Google Dapper2.1.Dapper的分布式跟踪2.1.1.跟踪树和span2.1.2.Annotation2.1.3.采样率 3.OpenTracing3.1.发展历史3.2.数据模型 4.java探针技术-j…

TOMCAT部署及优化(Tomcat配置文件参数优化,Java虚拟机(JVM)调优)

TOMCAT tomcat &#xff1a;是一个开放源代码的web应用服务器&#xff0c;基于java代码开发的。也可以理解为tomacat就是处理动态请求和基于java代码的页面开发。可以在html当中写入java代码&#xff0c;tomcat可以解析html页面当中的java&#xff0c;执行动态请求&#xff0c;…