Pytorch使用DataLoader, num_workers!=0时的内存泄露

  • 描述一下背景,和遇到的问题:

我在做一个超大数据集的多分类,设备Ubuntu 22.04+i9 13900K+Nvidia 4090+64GB RAM,第一次的训练的训练集有700万张,训练成功。后面收集到更多数据集,数据增强后达到了1000万张。但第二次训练4个小时后,就被系统杀掉进程了,原因是Out of Memory。找了很久的原因,发现内存随着训练step的增加而线性增加,猜测是内存泄露,最后定位到了DataLoader的num_workers参数(只要num_workers=0就没有问题)。

  • 真正原因:

Python(Pytorch)中的list转换成tensor时,会发生内存泄漏,要避免list的使用,可以通过使用np.array来代替list。

  • 解决办法:

自定义DataLoader中的Dataset类,然后Dataset类中的list全部用np.array来代替。这样的话,DataLoader将np.array转换成Tensor的过程就不会发生内存泄露。

  • 下面给两个错误的示例代码和一个正确的代码:(都是我自己犯过的错误)

1.错误的DataLoader加载数据集方法1

# 加载数据
train_data = datasets.ImageFolder(root=TRAIN_DIR_ARG, transform=transform)
valid_data = datasets.ImageFolder(root=VALIDATION_DIR, transform=transform)
test_data = datasets.ImageFolder(root=TEST_DIR, transform=transform)train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=8)
valid_loader = DataLoader(valid_data, batch_size=BATCH_SIZE, shuffle=False, num_workers=8)
test_loader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False, num_workers=8)

2.错误的DataLoader加载数据集方法2(重写了Dataset方法)


class CustomDataset(Dataset):def __init__(self, data_dir, transform=None):self.data_dir = data_dirself.transform = transformself.image_paths = []self.labels = []# 遍历数据目录并收集图像文件路径和对应的标签classes = os.listdir(data_dir)for i, class_name in enumerate(classes):class_dir = os.path.join(data_dir, class_name)if os.path.isdir(class_dir):for image_name in os.listdir(class_dir):image_path = os.path.join(class_dir, image_name)self.image_paths.append(image_path)self.labels.append(i)def __len__(self):return len(self.image_paths)def __getitem__(self, idx):image_path = self.image_paths[idx]label = self.labels[idx]# # 在需要时加载图像image = Image.open(image_path)if self.transform:image = self.transform(image)return image, labeltrain_data = CustomDataset(data_dir=TRAIN_DIR_ARG, transform=transform)
valid_data = CustomDataset(data_dir=VALIDATION_DIR, transform=transform)
test_data = CustomDataset(data_dir=TEST_DIR, transform=transform)train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=18)
valid_loader = DataLoader(valid_data, batch_size=BATCH_SIZE, shuffle=False, num_workers=8)
test_loader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False, num_workers=8, pin_memory=False)

3.重写Dataset的正确方法(重写了Dataset方法,list全部转成np.array)

class CustomDataset(Dataset):def __init__(self, data_dir, transform=None):self.data_dir = data_dirself.transform = transformself.image_paths = []  # 使用Python列表self.labels = []  # 使用Python列表# 遍历数据目录并收集图像文件路径和对应的标签classes = os.listdir(data_dir)for i, class_name in enumerate(classes):class_dir = os.path.join(data_dir, class_name)if os.path.isdir(class_dir):for image_name in os.listdir(class_dir):image_path = os.path.join(class_dir, image_name)self.image_paths.append(image_path)  # 添加到Python列表self.labels.append(i)  # 添加到Python列表# 转换为NumPy数组,这里就是解决内存泄露的关键代码self.image_paths = np.array(self.image_paths)self.labels = np.array(self.labels)def __len__(self):return len(self.image_paths)def __getitem__(self, idx):image_path = self.image_paths[idx]label = self.labels[idx]# 在需要时加载图像image = Image.open(image_path)if self.transform:image = self.transform(image)# 将图像数据转换为NumPy数组image = np.array(image)return image, labeltrain_data = CustomDataset(data_dir=TRAIN_DIR_ARG, transform=transform)
valid_data = CustomDataset(data_dir=VALIDATION_DIR, transform=transform)
test_data = CustomDataset(data_dir=TEST_DIR, transform=transform)train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=18)
valid_loader = DataLoader(valid_data, batch_size=BATCH_SIZE, shuffle=False, num_workers=8)
test_loader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False, num_workers=8, pin_memory=False)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/97103.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Linux命令定位与查找:which、whereis和find的用法详解

文章目录 Linux命令的定位与查找1. 简介Linux路径环境变量命令行和Shell 2. which命令which命令的作用使用which命令定位可执行文件多个可执行文件的定位which命令的选项及其使用 3. whereis命令whereis命令的作用使用whereis命令查找二进制文件查找源代码文件whereis命令的选项…

H5+Css3文本溢出添加省略号(包括插件)

一、单行 溢出隐藏 添加省略号 p{overflow: hidden;text-overflow:ellipsis;white-space: nowrap; }二、多行 溢出隐藏 省略号 p{display: -webkit-box;-webkit-box-orient: vertical;/*设置省略号在容器第四行文本后*/-webkit-line-clamp: 4; overflow: hidden; }局限性&…

Holographic MIMO Surfaces (HMIMOS)以及Reconfigurable Holographic Surface(RHS)仿真

这里写目录标题 Simulation setupchatgpt帮我总结代码 Holographic MIMO Surfaces (HMIMOS)以及Reconfigurable Holographic Surface(RHS)仿真: Simulation setup In this section, we evaluate the performance of …

Git 学习笔记 | 安装 Git 及环境配置

Git 学习笔记 | 安装 Git 及环境配置 Git 学习笔记 | 安装 Git 及环境配置安装 Git配置 Git查看配置 Git 学习笔记 | 安装 Git 及环境配置 安装 Git 官方网站:https://git-scm.com/ 官网下载太慢,我们可以使用淘宝镜像下载:https://regist…

信号量机制之整型信号量,记录型信号量

1.信号量机制 用户进程可以通过使用操作系统提供的一对原语来对信号量进行操作,从而很方便的实现了进程互斥、进程同步。 1.信号量 信号量其实就是一个变量(可以是一个整数,也可以是更复杂的记录型变量),可以用一个信号量来表示…

DRM全解析 —— CRTC详解(4)

接前一篇文章:DRM全解析 —— CRTC详解(3) 本文继续对DRM中CRTC的核心结构struct drm_crtc的成员进行释义。 3. drm_crtc结构释义 (21)struct drm_object_properties properties /** properties: property tracking …

网络中的一些基本概念

数据共享本质是网络数据传输 ,即计算机之间通过网络来传输数据,也称为 网络通信 。 根据网络互连的规模不同,可以划分为局域网和广域网。 局域网 LAN 局域网,即 Local Area Network ,简称 LAN 。 Local 即标识了局…

Lua系列文章(1)---Lua5.4参考手册学习总结

windows系统上安装lua,下载地址: Github 下载地址:https://github.com/rjpcomputing/luaforwindows/releases 可以有一个叫SciTE的IDE环境执行lua程序 1 – 简介 Lua 是一种强大、高效、轻量级、可嵌入的脚本语言。 它支持过程编程, 面向对…

【C语言】结构类型的定义和使用

目录 1.结构体(struct)类型 2.结构标记 3.typedef 4.定义结构数组的方法 5.调用结构数组的方法 6.将结构体传入函数 7.结构体使用实例 1.结构体(struct)类型 在C语言中,结构体(struct&#xf…

前端el-select 单选和多选

el-select单选 <el-form-item label"部门名称" prop"departId"><el-select v-model"dataForm.departId" placeholder"请选择" clearable:style{ "width": "100%" } :multiple"false" filtera…

接口自动化测试框架(pytest+allure+aiohttp+ 用例自动生成)

近期准备优先做接口测试的覆盖&#xff0c;为此需要开发一个测试框架&#xff0c;经过思考&#xff0c;这次依然想做点儿不一样的东西。 接口测试是比较讲究效率的&#xff0c;测试人员会希望很快能得到结果反馈&#xff0c;然而接口的数量一般都很多&#xff0c;而且会越来越…

【ARM CoreLink 系列 5 -- CI-700 控制器介绍 】

文章目录 1.1 什么是 CI-700?1.1.1 关于 CI-7001.1.2 CI-700 特点1.2 全局配置参数1.2.1 寻址能力1.3 组件和配置1.3.1 CI-700 互联的结构1.3.2 Crosspoint(XP)1.3.3 外部接口1.4 组件(Components)1.1 什么是 CI-700? CI-700是一种AMBA 5 CHI互连,具有可定制的网状拓扑结构…

<HarmonyOS第一课>ArkTS开发语言介绍——闯关习题及答案

判断题 1.循环渲染ForEach可以从数据源中迭代获取数据&#xff0c;并为每个数组项创建相应的组件。&#xff08; 对 &#xff09; 2.Link变量不能在组件内部进行初始化。&#xff08; 对 &#xff09; 单选题 1.用哪一种装饰器修饰的struct表示该结构体具有组件化能力&#…

Maven 下载安装配置

Maven 下载安装配置 下载 maven maven 官网&#xff1a;https://maven.apache.org/ maven 下载页面&#xff1a;https://maven.apache.org/download.cgi 安装 maven 将下载的apache-maven.zip文件解压到安装目录 将加压后的apache-maven目录改名为maven maven 配置环…

Kafka 高可用

正文 一、高可用的由来 1.1 为何需要Replication 在Kafka在0.8以前的版本中&#xff0c;是没有Replication的&#xff0c;一旦某一个Broker宕机&#xff0c;则其上所有的Partition数据都不可被消费&#xff0c;这与Kafka数据持久性及Delivery Guarantee的设计目标相悖。同时Pr…

threejs 透明贴图,模型透明,白边

问题 使用Threejs加载模型时&#xff0c;模型出现了上面的问题。模型边缘部分白边&#xff0c;或者模型出现透明问题。 原因 出现这种问题是模型制作时使用了透明贴图。threejs无法直接处理贴图。 解决 场景一 模型有多个贴图时&#xff08;一个透贴和其他贴图&#xff0…

笔记01:随机过程——随机游动

一、伯努利随机过程 1. n次伯努利实验中&#xff08;x1&#xff09;发生的总次数Yn&#xff1a; (二项分布) 2. 伯努利实验中事件第一次发生的时间L1&#xff1a; &#xff08;几何分布&#xff09; 3. n次伯努利实验中事件第k次发生的时间Lk&#xff1a; &#xff08;帕斯卡分…

list的模拟实现

全部代码 #pragma once namespace HQJ {template<class T>struct __list_node//节点类{T __data;__list_node<T>* __prev;__list_node<T>* __next;__list_node(const T& x T())//由于不知道要存储的数据类型&#xff0c;使用匿名对象进行初始化:__data…

HiveServer2 Service Crashes(hiveServer2 服务崩溃)

Troubleshooting Hive | 5.9.x | Cloudera Documentation 原因&#xff1a;别人用的都好好的&#xff0c;我的集群为什么会崩溃&#xff1f; 1.hive分区表太多(这里没有说具体数量。) 2.并发连接太多&#xff0c;我记的以前默认是200个连接 3.复杂的hive查询访问表的的分区…

(一)实现一个简易版IoC容器【手撸Spring】

一、前言 相信大家在看本篇文章的时候&#xff0c;对IoC应该有一个比较清晰的理解&#xff0c;我在这里再重新描述下&#xff1a;它的作用就是实现一个容器将一个个的Bean&#xff08;这里的Bean可以是一个Java的业务对象&#xff0c;也可以是一个配置对象&#xff09;统一管理…