问题描述
最近在跑一个基于pytorch的强化学习代码,在训练过程中显存增大非常明显,迭代不到200个iteration就可以占据70G+的显存。由于博主是第一次在pytorch实现的强化学习算法上加入自己的实现,很没有应对经验,现将调试过程记录下来供有同样问题的人参考。
解决方案
1 通过逐行注释观察显存变化来确定到底是哪里出现了泄露
这个思想来自于[1],原答主是这样回答的:
看上去很简单,但是非常有效,博主的bug就是这样找到泄露的地方的。找到了泄露的地方之后,就可以去查找对应的解决方案了(问度娘、看issues,balabala)。
2 查看常用错误对号入座
[2][3][4]记录了许多有关显存泄露的普遍错误,多与深度学习相关,感兴趣的读者可以看看,了解其中的原理,查看一下自己的代码有没有类似的错误。
3 使用显存使用展示工具memory_profiler
memory_profiler的下载链接如下:
memory-profiler · PyPI
可以使用以下指令进行安装:
pip install -U memory_profiler
在想进行分析的函数前面加上@profile,再运行
python -m memory_profiler example.py
即可打印出各行代码的内存占用结果。
附一个官方例子:
参考链接
[Debug记录] | Pytorch训练网络时出现内存泄漏 - 知乎
torch代码运行时显存溢出问题 - 简书
Tensor是如何让你的内存/显存泄漏的 - 知乎
PyTorch显存分析 - 知乎