【PyTorch笔记】训练时显存一直增加到 out-of-memory?真相了!

最近用 Pytorch 训模型的过程中,发现总是训练几轮后,出现显存爆炸 out-of-memory 的问题,询问了 ChatGPT、查找了各种文档。。。

在此记录这次 debug 之旅,希望对有类似问题的小伙伴有一点点帮助。

问题描述:

训练过程中,网络结构做了一些调整,forward 函数增加了部分计算过程,突然发现 16G 显存不够用了。

用 nvidia-smi 观察显存变化,发现显存一直在有规律地增加,直到 out-of-memory。

解决思路:

尝试思路1:

计算 loss 的过程中是否使用了 item() 取值,比如:

train_loss += loss.item()

发现我不存在这个问题,因为 loss 是最后汇总计算的。

尝试思路2:

训练主程序中添加两行下面的代码,实测发现并没有用。

torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True

这两行代码是干啥的?

大白话:设置为 True,意味着 cuDNN 会自动寻找最适合当前配置的高效算法,来获得最佳运行效率。这两行通常一起是哦那个

所以:

  • 如果网络的输入数据在尺度或类型上变化不大,设置 torch.backends.cudnn.benchmark = True 可以增加运行效率;
  • 如果网络的输入数据在每次迭代都变化,比如多尺度训练,会导致 cnDNN 每次都会去寻找一遍最优配置,这样反而会降低运行效率

尝试思路3:

及时删除临时变量和清空显存的 cache,例如在每轮训练后添加:

torch.cuda.empty_cache()

依旧没有解决显存持续增长的问题,而且如果频繁使用 torch.cuda.empty_cache(),会显著增加模型训练时长。

尝试思路4:

排查显存增加的代码位置,既然是增加了部分代码导致的显存增加,那么问题肯定出现在这部分代码中。

为此,可以逐段输出显存占用量,确定问题点在哪。

举个例子:

print("训练前:{}".format(torch.cuda.memory_allocated(0)))
train_epoch(model,data)
print("训练后:{}".format(torch.cuda.memory_allocated(0)))
eval(model,data)
print("评估后:{}".format(torch.cuda.memory_allocated(0)))

最终方案:

最终发现的问题是:我在模型中增加了 register_buffer

self.register_buffer("positives", torch.randn(1, 256))
self.register_buffer("negatives", torch.randn(256, self.num_negatives))

register_buffer 注册的是非参数的 Tensor,它只是被保存在模型的状态字典中,并不会进行梯度计算啊。

为了验证这一点,还打印出来验证了下:

# for name, param in model.named_parameters():
for name, param in model.named_buffers():print(name, param.shape, param.requires_grad)# 输出如下:
positives torch.Size([1, 256]) False
negatives torch.Size([256, 20480]) False

但是这个 buffer 却是导致显存不断增加的罪魁祸首。

为此,赶紧把和 buffer 相关的操作放在 torch.no_grad() 上下文中,问题解决!

@torch.no_grad()
def dequeue_samples(self, positives, negatives):if positives.shape[0] > 0:self.positives = 0.99*self.positives + 0.01*positives.mean(0, keepdim=True)self.negatives[:, self.ptr:self.ptr+negatives.shape[1]] = F.normalize(negatives, dim=0)with torch.no_grad():keys = F.normalize(self.positives.clone().detach(), dim=1).expand(cur_positives.shape[0], -1)negs = self.negatives.clone().detach()

结论:

如果是训练过程中显存不断增加,问题大概率出现在 forward 过程中,可以通过尝试思路4逐步排查出问题点所在,把不需要梯度计算的操作放在 torch.no_grad() 上下文中。

如果本文对你有帮助,欢迎点赞收藏备用!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/51435.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

mysql的唯一索引和普通索引有什么区别

在MySQL中,唯一索引(UNIQUE Index)和普通索引(普通索引,也称为非唯一索引)有一些关键的区别。以下是它们的比较以及性能分析: 唯一索引与普通索引的区别 唯一性: 唯一索引&#xff…

也来聊一聊反复开关空调是否更费电

文章目录 为了制造噱头而刻意开展的毫无实际价值的实验空调制冷的基本原理空调主要耗电部件分析空调主要耗电阶段分析启动阶段:瞬时功率较高,但持续时间较短制冷运行阶段:压缩机持续运行,耗电量最大温度达到设定值后的阶段&#x…

深入探索非线性数据结构:树与图的世界

在数据结构的广阔天地中,非线性结构以其独特的逻辑关系和广泛的应用场景,成为计算机科学领域的重要组成部分。其中,树和图作为两种典型的非线性数据结构,不仅深刻影响了算法的设计与分析,也广泛应用于各种实际问题的解…

基于tkinter的学生信息管理系统之登录界面和主界面菜单设计

目录 一、tkinter的介绍 二、登陆界面的设计 1、登陆界面完整代码 2、部分代码讲解 3、登录的数据模型设计 4、效果展示 三、学生主界面菜单设计 1、学生主界面菜单设计完整代码 2、 部分代码讲解 3、效果展示 四、数据库的模型设计 欢迎大家进来学习和支持&#xff01…

灯具外贸公司用什么企业邮箱好

灯具外贸公司面对海外市场的推广、产品销售、客户沟通、市场信息收集等多重需求,选择一个合适的企业邮箱显得尤为重要。本文将介绍灯具外贸公司为什么应选择Zoho Mail企业邮箱,并详细探讨其优势和功能。 一、公司背景 广东省深圳市光明新区&#xff0c…

持久化存储:Mojo模型中模型保存与加载的艺术

持久化存储:Mojo模型中模型保存与加载的艺术 在机器学习项目中,模型的持久化存储是一个关键环节,它允许我们将训练好的模型保存下来,并在需要时重新加载使用。Mojo模型,作为一个虚构的高级机器学习框架,支…

Redis 安装和数据类型

Redis 安装和数据类型 一、Redis 1、Redis概念 redis 缓存中间件:缓存数据库 nginx web服务 php 转发动态请求 tomcat web页面,也可以转发动态请求 springboot 自带tomcat 数据库不支持高并发,一旦访问量激增,数据库很快就…

vTESTstudio中如何添加DLL文件?

文章目录 一、CANoe添加DLL二、vTESTstudio中添加DLL1.手动添加2.代码添加 一、CANoe添加DLL 在CANoe中添加DLL的路径如下图,在Simulation Setup中选择需要添加的节点,右键选择Configuration进行添加DLL。 二、vTESTstudio中添加DLL 1.手动添加 在打…

java中 VO DTO BO PO DAO

VO、DTO、BO、PO、DO、POJO 数据模型的理解和实际使用_vo dto bo-CSDN博客 深入理解Java Web开发中的PO、VO、DTO、DAO和BO概念_java dto dao-CSDN博客

【计算机网络】WireShark和简单http抓包实验

一:实验目的 1:熟悉WireShark的安装流程和界面操作流程。 2:学会简单http的抓取和过滤,并分析导出结果。 二:实验仪器设备及软件 硬件: Windows 2019操作系统的计算机等。 软件:WireShark、…

【算法/训练】:动态规划(线性DP)

一、路径类 1. 字母收集 思路: 1、预处理 对输入的字符矩阵我们按照要求将其转换为数字分数,由于只能往下和往右走,因此走到(i,j)的位置要就是从(i - 1, j)往下走&#…

vector清空

https://www.zhihu.com/question/592055868/answer/2967078686

java使用hutool工具检查远程端口是否开启

使用java校验ip地址或域名的端口是否开启 1.导入hutool工具的maven依赖 <dependency><groupId>cn.hutool</groupId><artifactId>hutool-all</artifactId><version>5.8.16</version></dependency>2.复制一下代码案例直接运行 …

前端面试基础题(微信公众号:前端面试成长之路)

BFC、IFC、GFC、FFC CSS2.1中只有BFC和IFC, CSS3中才有GFC和FFC。 到底什么是BFC、IFC、GFC和FFC Whats FC&#xff1f; 一定不是KFC&#xff0c;FC的全称是&#xff1a;Formatting Contexts&#xff0c;是W3C CSS2.1规范中的一个概念。它是页面中的一块渲染区域&#xff0c;并…

量度卓越:Mojo模型中自定义评估与模型比较的艺术

量度卓越&#xff1a;Mojo模型中自定义评估与模型比较的艺术 在机器学习项目中&#xff0c;模型评估是衡量算法性能的关键步骤。Mojo模型&#xff0c;作为一个先进的机器学习框架&#xff0c;提供了丰富的工具来支持模型评估和比较。本文将深入探讨如何在Mojo模型中实现自定义…

openj9-17.0.2_8-jre-alpine 和 openjdk:17-alpine 的区别是什么?

openj9-17.0.2_8-jre-alpine 和 openjdk:17-alpine 都是用于运行 Java 应用程序的 Docker 镜像&#xff0c;但它们之间有一些关键的区别&#xff1a; JVM Implementation: openj9-17.0.2_8-jre-alpine 使用的是 Eclipse OpenJ9&#xff0c;这是一种高效、低内存消耗的 JVM 实现…

go-sql-driver/mysql 查询 latin1 中文字符集

select name from table; table是 latin1 编码&#xff0c; 返回后查询结果后&#xff0c;即使将 name 转为 utf-8&#xff0c;日志输出中文仍然乱码。 // 配置数据库连接字符串&#xff0c;确保指定charsetlatin1dsn : "user:passwordtcp(127.0.0.1:3306)/dbname?chars…

免费【2024】springboot 宠物领养救助平台的开发与设计

博主介绍&#xff1a;✌CSDN新星计划导师、Java领域优质创作者、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和学生毕业项目实战,高校老师/讲师/同行前辈交流✌ 技术范围&#xff1a;SpringBoot、Vue、SSM、HTML、Jsp、PHP、Nodejs、Python、爬虫、数据可视化…

每日一练,java07

目录 题目1.请问运行主要的程序会打印出的是什么&#xff08;&#xff09;2.下面论述正确的是&#xff08;&#xff09;&#xff1f;3.下面哪些Java中的流对象是字节流?4.关于以下代码的说明&#xff0c;正确的是&#xff08; &#xff09;5.若需要定义一个类&#xff0c;下列…

普元EOS学习笔记-EOS项目HTTP访问安全和权限控制

前言 对于企业应用系统&#xff0c;出于安全和权限控制的目的&#xff0c;需要对http请求做若干控制。 比如文件上传的时候要控制不允许上传的文件后缀。 又比如控制应用程序中的哪些资源不允许被访问。 EOS项目通过 xml配置文件来实现这一需求。 Http访问管理模块 在EOS项…