[内存泄漏][PyTorch](create_graph=True)

PyTorch保存计算图导致内存泄漏

  • 1. 内存泄漏定义
  • 2. 问题发现背景
  • 3. github中pytorch源码关于这个问题的讨论

1. 内存泄漏定义

  内存泄漏(Memory Leak)是指程序中已动态分配的堆内存由于某种原因程序未释放或无法释放,造成系统内存的浪费,导致程序运行速度减慢甚至系统崩溃等严重后果。

2. 问题发现背景

  在使用深度学习求解PDE时,由于经常需要计算高阶导数,在pytorch框架下写的代码需要用到torch.autograd.grad(create_graph=True)或者torch.backward(create_graph=True)这个参数,然后发现了这个内存泄漏的问题。如果要保存计算图用来计算高阶导数,那么其所占的内存不会被释放,会一直占用。也就是如果设置create_graph=True,那么其保存的计算图所占的内存只有在程序运行结束时才会释放,这样导致了一个问题,即如果在循环中需要保存计算图,例如每个循环都需要计算一次黑塞矩阵,那么这个内存占用就会越来越多,最终导致out of memory报错。
在这里插入图片描述

3. github中pytorch源码关于这个问题的讨论

  官网中关于这个问题的讨论见https://github.com/pytorch/pytorch/issues/7343,这里提出的内存泄漏的例子如下:

import torch
import gc_ = torch.randn(1, device='cuda')
del _
torch.cuda.synchronize()
gc.collect()
print(torch.cuda.memory_allocated())
x = torch.randn(1, device='cuda', requires_grad=True)
y = x.tanh()
y.backward(torch.ones_like(y), create_graph=True)
del x, y
torch.cuda.synchronize()
gc.collect()
print(torch.cuda.memory_allocated())

在这里插入图片描述
可以看到虽然删除了变量,依然造成了内存泄漏。这里红色的警告就是关于这个内存泄漏的问题。

UserWarning: Using backward() with create_graph=True will create a reference cycle between
the parameter and its gradient which can cause a memory leak. We recommend using autograd.grad 
when creating the graph to avoid this. If you have to use this function, make sure to reset 
the .grad fields of your parameters to None after use to break the cycle and avoid the leak. 
(Triggered internally at C:\cb\pytorch_1000000000000\work\torch\csrc\autograd\engine.cpp:1000.)
allow_unreachable=True, accumulate_grad=True) 
# Calls into the C++ engine to run the backward pass

看这个UserWarning,提示我们使用torch.autograd.grad()函数可以避免这个梯度泄漏,然后对代码进行改动:

import torch
import gc
from torch.autograd import grad_ = torch.randn(1, device='cuda')
del _
torch.cuda.synchronize()
gc.collect()
print(torch.cuda.memory_allocated())
x = torch.randn(1, device='cuda', requires_grad=True)
y = x.tanh()
z = grad(y, x, retain_graph=True, create_graph=True)
# y.backward(torch.ones_like(y), create_graph=True)
del x, y, z
torch.cuda.synchronize()
gc.collect()
print(torch.cuda.memory_allocated())

在这里插入图片描述
结果显示没有梯度泄漏。进一步,我们求一下二阶导数:

import torch
import gc
from torch.autograd import grad_ = torch.randn(1, device='cuda')
del _
torch.cuda.synchronize()
gc.collect()
print(torch.cuda.memory_allocated())
x = torch.randn(1, device='cuda', requires_grad=True)
y = x.tanh()
z = grad(y, x, retain_graph=True, create_graph=True)
print(torch.cuda.memory_allocated())
q = grad(z, x)
del x, y, z, q
torch.cuda.synchronize()
gc.collect()
print(torch.cuda.memory_allocated())

在这里插入图片描述
结果也没有内存泄漏。但是,如果我们不删除结果二阶导数q,这样是出于如果写在一个函数中,需要将q作为return值返回的情况。

import torch
import gc
from torch.autograd import grad_ = torch.randn(1, device='cuda')
del _
torch.cuda.synchronize()
gc.collect()
print(torch.cuda.memory_allocated())
x = torch.randn(1, device='cuda', requires_grad=True)
y = x.tanh()
z = grad(y, x, retain_graph=True, create_graph=True)
print(torch.cuda.memory_allocated())
q = grad(z, x)
del x, y, z
torch.cuda.synchronize()
gc.collect()
print(torch.cuda.memory_allocated())

在这里插入图片描述
可以看到,这还是会导致一部分内存泄漏。这里记录一下这个问题,有读者有遇到相同问题欢迎讨论。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/148960.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【2021集创赛】 RISC-V杯三等奖:基于E203 处理器的SM4算法硬件加速

杯赛题目:基于蜂鸟E203 RISC-V处理器内核的SoC设计 参赛要求:研究生组/本科生组 赛题内容: 基于芯来科技的开源蜂鸟E203 Demo SoC进行扩展,在限定的可编程逻辑平台上构建面向专用应用领域(譬如人工智能、信息安全、工业…

手机,蓝牙开发板,TTL/USB模块,电脑四者之间的通讯

一,意图 通过手机蓝牙连接WeMosD1R32开发板,开发板又通过TTL转USB与电脑连接.手机通过蓝牙控制开发板上的LED灯的开,关,闪等动作,在电脑上打开串口监视工具观察其状态.也可以通过电脑上的串口监视工具来控制开发板上LED灯的动作,而在手机蓝牙监测工具中显示灯的状态. 二,原料…

美团外卖9元每周星期一开工外卖红包优惠券怎么领取?

美团外卖9元周一开工红包活动时间是什么时候? 美团外卖9元周一开工红包优惠券是指每周星期一可以领取的美团外卖红包优惠券,在美团外卖周一开工红包领取活动时间内可领取到9元周一开工美团外卖红包优惠券;(温馨提醒:如…

【ArcGIS】批量对栅格图像按要素掩膜提取

要把一张大的栅格图裁成分省或者分县市的栅格集,一般是用ArcGIS里的按掩膜提取。 但是有的时候所要求的栅格集量非常大,所以用代码来做批量掩膜(按字段)会非常方便。 import arcpy , shutil , os from arcpy import env from ar…

PHP常用的数组函数

PHP是一种流行的服务器端脚本语言,广泛用于Web开发。数组是PHP中最重要且最常用的数据类型之一,它提供了许多强大的数组函数,用于在数组上执行各种操作。在本文中,我们将深入解析PHP中一些常用的数组函数,以便更好地理…

【C/C++笔试练习】继承和派生的概念、虚函数的概念、派生类的析构函数、纯虚函数的概念、动态编译、多态的实现、参数解析、跳石板

文章目录 C/C笔试练习选择部分(1)继承和派生的概念(2)程序分析(3)虚函数的概念(4)派生类的析构函数(5)纯虚函数的概念(6)动态编译&…

uniapp App 端 版本更新检测

function checkVersion() { var req { //升级检测数据 appid: plus.runtime.appid, version: plus.runtime.version }; const timestamp Date.parse(new Date()); config.server.query_news uni.reque…

LangChain 2模块化prompt template并用streamlit生成网站 实现给动物取名字

上一节实现了 LangChain 实现给动物取名字, 实际上每次给不同的动物取名字,还得修改源代码,这周就用模块化template来实现。 1. 添加promptTemplate from langchain.llms import OpenAI # 导入Langchain库中的OpenAI模块 from langchain.p…

优思学院|什么是精益生产管理?从一个生活上的故事出发来说明。

你关掉电脑,离开办公室。 一个小时后,你进入家门和孩子们在一起。 你和家人一起吃晚饭。 你的老板打电话来查看你的项目进展。 你哄孩子入睡并给他们读个故事。 作为一个负责任的父母,你想要与孩子们的互动时间增加并提高生活的质量&…

ChatGPT + DALL·E 3

参考链接: https://chat.xutongbao.top/

Linux中安装部署环境(JAVA)

目录 在Linux中安装jdk 包管理器yum安装jdk JDK安装过程中的问题 验证安装jdk 在Linux中安装tomcat 安装mysql 在Linux中安装jdk jdk在Linux中的安装方式有很多种, 这里介绍最简单的方法, 也就是包管理器方法: 包管理器yum安装jdk Linux中常见的包管理器有: yumaptp…

图论| 827. 最大人工岛 127. 单词接龙

827. 最大人工岛 题目:给你一个大小为 n x n 二进制矩阵 grid 。最多 只能将一格 0 变成 1 。返回执行此操作后,grid 中最大的岛屿面积是多少? 岛屿 由一组上、下、左、右四个方向相连的 1 形成。 题目链接:[827. 最大人工岛](ht…

前端为什么要工程化

前端为什么要工程化 文章目录 前端为什么要工程化传统开发的弊端一个常见的案例更多问题 工程化带来的优势开发层面的优势团队协作的优势统一的项目结构统一的代码风格可复用的模块和组件代码健壮性有保障团队开发效率高 求职竞争上的优势 现在前端的工作与以前的前端开发已经完…

深度学习交通车辆流量分析 - 目标检测与跟踪 - python opencv 计算机竞赛

文章目录 0 前言1 课题背景2 实现效果3 DeepSORT车辆跟踪3.1 Deep SORT多目标跟踪算法3.2 算法流程 4 YOLOV5算法4.1 网络架构图4.2 输入端4.3 基准网络4.4 Neck网络4.5 Head输出层 5 最后 0 前言 🔥 优质竞赛项目系列,今天要分享的是 🚩 *…

Linux 时区设置

对于服务器来说,linux的时区影响着运行之上的数据库和后端程序的时区 应该和数据库和后端及其他程序的时区保持一致 其他相关时区的设置 pgsql时区设置: php时区设置: 1.显示当前的时间和时区 date结果类似下面,图中显示的是ut…

球幕投影有哪些常见的物理表现形式?

近年来,投影技术不断发展完善,给内容的表达方式带来了突破,使其展示形式不再局限于平面,即使在弧面、球面等异形幕墙上,也能呈现出令人惊叹的视觉画面。其中球幕投影备受关注,它以半球形屏幕将图像投影到球…

Selenium安装WebDriver(含116/117/118/119)

1、确认浏览器的版本 在浏览器的地址栏,输入chrome://version/,回车后即可查看到对应版本 2、找到对应的chromedriver版本 2.1 114及之前的版本可以通过点击下载chromedriver,根据版本号(只看大版本)下载对应文件 2.2 116版…

解决 VS2022 关于 c++17 报错: C2131 表达式必须含有常量值

使用 VS2022 编译 ORB-SLAM3 加载Vocabulary 二进制ORBvoc.bin 时,在 DBOW2 里修改 TemplatedVocabulary.h 代码显示这样的错误: 编译器错误 C2131 表达式的计算结果不是常数 定位到我的代码中: char buf [size_node] ; 原因 : …

PyTorch - 高效快速配置 Conda + PyTorch 环境 (解决 segment fault )

欢迎关注我的CSDN:https://spike.blog.csdn.net/ 本文地址:https://spike.blog.csdn.net/article/details/134463035 在配置算法项目时,因网络下载速度的原因,导致默认的 conda 与 pytorch 包安装缓慢,需要配置新的 co…