我的PyTorch模型比内存还大,怎么训练呀?

原文:我的PyTorch模型比内存还大,怎么训练呀? - 知乎

看了一篇比较老(21年4月文章)的不大可能训练优化方案,保存起来以后研究一下。

我的PyTorch模型比内存还大,怎么训练呀?

随着深度学习的飞速发展,模型越来越臃肿,哦不,先进,运行SOTA模型的主要困难之一就是怎么把它塞到 GPU 上,毕竟,你无法训练一个设备装不下的模型。改善这个问题的技术有很多种,例如,分布式训练和混合精度训练。

本文将介绍另一种技术: 梯度检查点(gradient checkpointing)简单的说,梯度检查点的工作原理是在反向时重新计算深层神经网络的中间值(而通常情况是在前向时存储的)。这个策略是用时间(重新计算这些值两次的时间成本)来换空间(提前存储这些值的内存成本)。

文末有一个示例基准测试,它显示了梯度检查点减少了模型 60% 的内存开销(以增加 25% 的训练时间为代价)。

详细代码请查看我的 GitHub 库: https://github.com/spellml/tweet-sentiment-extraction/blob/master/notebooks/5-checkpointing.ipynb

>>> 神经网络如何使用内存

为了理解梯度检查点是如何起作用的,我们首先需要了解一下模型内存分配是如何工作的。

神经网络使用的总内存基本上是两个部分的和。

第一部分是模型使用的静态内存。尽管 PyTorch 模型中内置了一些固定开销,但总的来说几乎完全由模型权重决定。当今生产中使用的现代深度学习模型的总参数在100万到10亿之间。作为参考,一个带 16GB GPU 内存的 NVIDIA T4 的实际限制大约在1-1.5亿个参数之间。

第二部分是模型的计算图所占用的动态内存。在训练模式下,每次通过神经网络的前向传播都为网络中的每个神经元计算一个激活值,这个值随后被存储在所谓的计算图中。必须为批中的每个单个训练样本存储一个值,因此数量会迅速的累积起来。总开销由模型大小和批次大小决定,一般设置最大批次大小限制来适配你的 GPU 内存。

要了解更多关于 PyTorch autograd 的信息,请查看我的 Kaggle 笔记本《PyTorch autograd 解释》: https://www.kaggle.com/residentmario/pytorch-autograd-explained

>>> 梯度检查点是如何起作用的

大型模型在静态和动态方面都很耗资源。首先,它们很难适配 GPU,而且哪怕你把它们放到了设备上,也很难训练,因为批次大小被迫限制的太小而无法收敛。

现有的各种技术可以改善这些问题中的一个或两个。梯度检查点就是这样一种技术; 分布式训练,是另一种技术。

梯度检查点(gradient checkpointing) 的工作原理是从计算图中省略一些激活值。这减少了计算图使用的内存,降低了总体内存压力(并允许在处理过程中使用更大的批次大小)。

但是,一开始存储激活的原因是,在反向传播期间计算梯度时需要用到激活。在计算图中忽略它们将迫使 PyTorch 在任何出现这些值的地方重新计算,从而降低了整体计算速度。

因此,梯度检查点是计算机科学中折衷的一个经典例子,即在内存和计算之间的权衡。

PyTorch 通过 torch.utils.checkpoint.checkpoint 和 torch.utils.checkpoint.checkpoint_sequential 提供梯度检查点,根据官方文档的 notes,它实现了如下功能,在前向传播时,PyTorch 将保存模型中的每个函数的输入元组。在反向传播过程中,对于每个函数,输入元组和函数的组合以实时的方式重新计算,插入到每个需要它的函数的梯度公式中,然后丢弃。网络计算开销大致相当于每个样本通过模型前向传播开销的两倍。

梯度检查点首次发表在2016年的论文 《Training Deep Nets With Sublinear Memory Cost》 中。论文声称提出的梯度检查点算法将模型的动态内存开销从 O(n)n 为模型中的层数)降低到 O(sqrt(n)),并通过实验展示了将 ImageNet 的一个变种从 48GB 压缩到了 7GB 内存占用。

>>> 测试 API

PyTorch API 中有两个不同的梯度检查点方法,都在 torch.utils.checkpoint 命名空间中。两者中比较简单的一个是 checkpoint_sequential,它被限制用于顺序模型(例如使用 torch.nn.Sequential wrapper 的模型)。另一个是更灵活的 checkpoint,可以用于任何模块。

下面是一个完整的代码示例,显示了 checkpoint_sequential 的实际用法:

import torch
import torch.nn as nnfrom torch.utils.checkpoint import checkpoint_sequential# a trivial model
model = nn.Sequential(nn.Linear(100, 50),nn.ReLU(),nn.Linear(50, 20),nn.ReLU(),nn.Linear(20, 5),nn.ReLU()
)# model input
input_var = torch.randn(1, 100, requires_grad=True)# the number of segments to divide the model into
segments = 2# finally, apply checkpointing to the model
# note the code that this replaces:
# out = model(input_var)
out = checkpoint_sequential(modules, segments, input_var)# backpropagate
out.sum().backwards()

如你所见,checkpoint_sequential 替换了 module 对象上的 forward 或 __call__ 方法。out 几乎和我们调用 model(input_var) 时得到的张量一样; 关键的区别在于它缺少了累积值,并且附加了一些额外的元数据,指示 PyTorch 在 out.backward() 期间需要这些值时重新计算。

值得注意的是,checkpoint_sequential 接受整数值的片段数作为输入。checkpoint_sequential 将模型分割成 n 个纵向片段,并对除了最后一个的每个片段应用检查点。

这工作很容易,但有一些主要的限制。你无法控制片段的边界在哪里,也无法对整个模块应用检查点(而是其中的一部分)。

替代方法是使用更灵活的 checkpoint API. 下面展示了一个简单的卷积模型:

class CIFAR10Model(nn.Module):def __init__(self):super().__init__()self.cnn_block_1 = nn.Sequential(*[nn.Conv2d(3, 32, 3, padding=1),nn.ReLU(),nn.Conv2d(32, 64, 3, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2),nn.Dropout(0.25)])self.cnn_block_2 = nn.Sequential(*[nn.Conv2d(64, 64, 3, padding=1),nn.ReLU(),nn.Conv2d(64, 64, 3, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2),nn.Dropout(0.25)])self.flatten = lambda inp: torch.flatten(inp, 1)self.head = nn.Sequential(*[nn.Linear(64 * 8 * 8, 512),nn.ReLU(),nn.Dropout(0.5),nn.Linear(512, 10)])def forward(self, X):X = self.cnn_block_1(X)X = self.cnn_block_2(X)X = self.flatten(X)X = self.head(X)return X

这种模型有两个卷积块,一些 dropout,和一个线性头(10个输出对应 CIFAR10 的10类)。

下面是这个模型使用梯度检查点的更新版本:

class CIFAR10Model(nn.Module):def __init__(self):super().__init__()self.cnn_block_1 = nn.Sequential(*[nn.Conv2d(3, 32, 3, padding=1),nn.ReLU(),nn.Conv2d(32, 64, 3, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2)])self.dropout_1 = nn.Dropout(0.25)self.cnn_block_2 = nn.Sequential(*[nn.Conv2d(64, 64, 3, padding=1),nn.ReLU(),nn.Conv2d(64, 64, 3, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2)])self.dropout_2 = nn.Dropout(0.25)self.flatten = lambda inp: torch.flatten(inp, 1)self.linearize = nn.Sequential(*[nn.Linear(64 * 8 * 8, 512),nn.ReLU()])self.dropout_3 = nn.Dropout(0.5)self.out = nn.Linear(512, 10)def forward(self, X):X = self.cnn_block_1(X)X = self.dropout_1(X)X = checkpoint(self.cnn_block_2, X)X = self.dropout_2(X)X = self.flatten(X)X = self.linearize(X)X = self.dropout_3(X)X = self.out(X)return X

在 forward 中显示的 checkpoint 接受一个模块(或任何可调用的模块,如函数)及其参数作为输入。参数将在前向时被保存,然后用于在反向时重新计算其输出值。

为了使其能够工作,我们必须对模型定义进行一些额外的更改。

首先,你会注意到我们从卷积块里删除了 nn.Dropout 层; 这是因为检查点与 dropout 不兼容(回想一下,样本有效地通过模型两次 —— dropout 会在每次通过时任意丢失不同的值,从而产生不同的输出)。基本上,任何在重新运行时表现出非幂等(non-idempotent )行为的层都不应该应用检查点(nn.BatchNorm 是另一个例子)。解决方案是重构模块,这样问题层就不会被排除在检查点片段之外,这正是我们在这里所做的。

其次,你会注意到我们在模型中的第二卷积块上使用了检查点,但是第一个卷积块上没有使用检查点。这是因为检查点简单地通过检查输入张量的 requires_grad 行为来决定它的输入函数是否需要梯度下降(例如,它是否处于 requires_grad=True 或 requires_grad=False模式)。模型的输入张量几乎总是处于 requires_grad=False 模式,因为我们感兴趣的是计算相对于网络权重而不是输入样本本身的梯度。因此,模型中的第一个子模块应用检查点没多少意义: 它反而会冻结现有的权重,阻止它们进行任何训练。更多细节请参考这个 PyTorch 论坛帖子:https://discuss.pytorch.org/t/use-of-torch-utils-checkpoint-checkpoint-causes-simple-model-to-diverge/116271

在 PyTorch 文档(https://pytorch.org/docs/stable/checkpoint.html#)中还讨论了 RNG 状态以及与分离张量不兼容的一些其他细节。

完整的训练代码示例可以看这里: https://gist.github.com/ResidentMario/e3254172b4706191089bb63ecd610e21

和这里: https://gist.github.com/ResidentMario/9c3a90504d1a027aab926fd65ae08139

>>> 基准测试

作为一个快速的基准测试,我在 tweet-sentiment-extraction 上启用了模型检查点,这是一个基于 Twitter 数据的带有 BERT 主干的情感分类器模型。你可以在这里看到代码:https://github.com/spellml/tweet-sentiment-extraction。transformers 已经将模型检查点作为 API 的一个可选部分来实现; 为我们的模型启用它就像翻转一个布尔值标记一样简单:

# code from model_5.pycfg = transformers.PretrainedConfig.get_config_dict("bert-base-uncased")[0]
cfg["output_hidden_states"] = True
cfg["gradient_checkpointing"] = True  # NEW!
cfg = transformers.BertConfig.from_dict(cfg)
self.bert = transformers.BertModel.from_pretrained("bert-base-uncased", config=cfg
)

我对这个模型进行了四次训练: 分别在 NVIDIA T4和 NVIDIA V100 GPU 上,包括检查点和无检查点模式。所有运行的批次大小为 64。以下是结果:

第一行是在模型检查点关闭的情况下进行的训练,第二行是在模型检查点开启的情况下进行的训练。

模型检查点降低了峰值模型内存使用量 60% ,同时增加了模型训练时间 25% 。

当然,你想要使用检查点的主要原因可能是,这样你就可以在 GPU 上使用更大的批次大小。在另一篇博文:https://qywu.github.io/2019/05/22/explore-gradient-checkpointing.html 中演示了这个很好的例子: 在他们的例子中,每批次样本从 24 个提高到惊人的 132 个!

要处理大型神经网络,模型检查点显然是一个非常强大和有用的工具。

原文: https://spell.ml/blog/gradient-checkpointing-pytorch-YGypLBAAACEAefHs

发布于 2021-04-27 22:39

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/672143.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

vue element 组件 form深层 :prop 验证失效问题解决

此图源自官网 借鉴。 当我们简单单层验证的时候发现是没有问题的,但是有的时候可能会涉及到深层prop,发现在去绑定的时候就不生效了。例如我们在form单里面循环验证,在去循环数据验证。 就如下图的写法了 :prop"pumplist. i .device…

Redis缓存设计及优化

缓存设计 缓存穿透 缓存穿透是指查询一个根本不存在的数据, 缓存层和存储层都不会命中, 通常出于容错的考虑, 如果从存储层查不到数据则不写入缓存层。 缓存穿透将导致不存在的数据每次请求都要到存储层去查询, 失去了缓存保护后…

Pandas 对带有 Multi-column(多列名称) 的数据排序并写入 Excel 中

Pandas 从Excel 中读取带有 Multi-column的数据 正文 正文 我们使用如下方式写入数据: import pandas as pd import numpy as npdf pd.DataFrame(np.array([[10, 2, 0], [6, 1, 3], [8, 10, 7], [1, 3, 7]]), columns[[Number, Name, Name, ], [col 1, col 2, co…

数据结构——C/栈和队列

🌈个人主页:慢了半拍 🔥 创作专栏:《史上最强算法分析》 | 《无味生》 |《史上最强C语言讲解》 | 《史上最强C练习解析》 🏆我的格言:一切只是时间问题。 ​ 1.栈 1.1栈的概念及结构 栈:一种特…

WPF是不是垂垂老矣啦?平替它的框架还有哪些

WPF(Windows Presentation Foundation)是微软推出的一种用于创建 Windows 应用程序的用户界面框架。WPF最初是在2006年11月推出的,它是.NET Framework 3.0的一部分,为开发人员提供了一种基于 XAML 的方式来构建丰富的用户界面。 W…

你的代码很丑吗?试试这款高颜值代码字体

Monaspace 是有 GitHub 开源的代码字体,包含 5 种变形字体的等宽代码字体家族,颜值 Up,很难不喜欢。 来看一下这 5 种字体分别是: 1️⃣ Radon 手写风格字体 2️⃣ Krypton 机械风格字体 3️⃣ Xenon 衬线风格字体 4️⃣ Argon…

【C++二维前缀和】黑格覆盖

题目描述 在一张由 M * N 个小正方形格子组成的矩形纸张上,有 k 个格子被涂成了黑色。给你一张由 m * n 个同样小正方形组成的矩形卡片,请问该卡片最多能一次性覆盖多少个黑格子? 输入 输入共 k1 行: 第 1 行为 5 个整数 M、N、…

【ES数据可视化】kibana实现数据大屏

目录 1.概述 2.绘制数据大屏 2.1.准备数据 2.2.绘制大屏 3.嵌入项目中 1.概述 再来重新认识一下kibana: Kibana 是一个用于数据可视化和分析的开源工具,是 Elastic Stack(以前称为 ELK Stack)中的一部分,由 Ela…

推理系统学习笔记

一些学习资料 最近对MLsys比较感兴趣,遂找些资料开始学习一下 https://fazzie-key.cool/2023/02/21/MLsys/https://qiankunli.github.io/2023/12/16/llm_inference.htmlhttps://dlsyscourse.orghttps://github.com/chenzomi12/DeepLearningSystem/tree/main/04Infe…

3、生成式 AI 如何帮助您改进数据可视化图表

生成式 AI 如何帮助您改进数据可视化图表 使用生成式 AI 加速和增强数据可视化。 图像来源:DALLE 3 5 个关键要点: 数据可视化图表的基本结构使用 Python Altair 构建数据可视化图表使用 GitHub Copilot 加快图表生成速度使用 ChatGPT 为您的图表生成相关内容使用 DALL-E 将…

[BUUCTF]-PWN:wustctf2020_easyfast解析

又是堆题,查看保护 再看ida 大致就是alloc创建堆块,free释放堆块,fill填充堆块,以及一个getshell的函数,但要满足条件。 值得注意的是free函数没有清空堆块指针 所以可以用double free 有两种解法 解法一&#xff0…

【Linux】打包压缩跨系统/网络传输文件常用指令完结

Hello everybody!在今天的文章中我会把剩下的3-4个常用指令讲完,然后开始权限的讲解。那废话不多说,咱们直接进入正题! 1.zip/unzip&tar命令 1.zip/unzip 在windows系统中,经常见到带有zip后缀的文件。那个东西就是压缩包。…

杭州融资融券利率一般最低是4.5%,两融有哪些核心注意事项?

融资融券利率行情 使用融资融券账户的投资者越来越多,对于准备开两融和想换两融券商的投资者来说,最关心的就是两融利率以及开两融或者换券商的便捷程度了。 目前市场上最低的融资融券利率是4.5%~5%,普遍的两融利率一般在5%-6.5%&#xff0…

龙芯安装使用搜狗输入法

CPU:龙芯3A6000 操作系统:Loongnix 桌面主题:Cartoon 龙芯系统切换输入法的按键一般为:Ctrl空格。 1 安装搜狗输入法 进入Loongnix系统自带的龙芯应用合作社,寻找搜狗输入法,点击安装。 按下Ctrl空格&…

计算机网络——网络

计算机网络——网络 小程一言专栏链接: [link](http://t.csdnimg.cn/ZUTXU)前些天发现了一个巨牛的人工智能学习网站,通俗易懂,风趣幽默,忍不住分享一下给大家, [跳转到网站](https://www.captainbed.cn/qianqiu) 无线网络和移动网…

用HTML5 + JavaScript实现下雪效果

用HTML5 JavaScript实现下雪效果 <canvas>是一个可以使用脚本 (通常为JavaScript) 来绘制图形的 HTML 元素。 <canvas> 标签/元素只是图形容器&#xff0c;必须使用脚本来绘制图形。 HTML5 canvas 图形标签基础https://blog.csdn.net/cnds123/article/details/…

ArcGIS的UTM与高斯-克吕格投影分带要点总结

UTM&#xff08;通用横轴墨卡托投影、等角横轴割椭圆柱投影&#xff09;投影分带投影要点&#xff1a; 1&#xff09;UTM投影采用6度分带 2&#xff09;可根据公式计算&#xff0c;带数&#xff08;经度整数位/6&#xff09;的整数部分31 3&#xff09;北半球地区&#xff0…

蓝桥杯Web应用开发-CSS3 新特性

CSS3 新特性 专栏持续更新中 在前面我们已经学习了元素选择器、id 选择器和类选择器&#xff0c;我们可以通过标签名、id 名、类名给指定元素设置样式。 现在我们继续选择器之旅&#xff0c;学习 CSS3 中新增的三类选择器&#xff0c;分别是&#xff1a; • 属性选择器 • 子…

STM32搭建开发环境

常用开发工具简介 集成开发环境 MDK&#xff1a;全名RealViewMDK&#xff0c;是Keil公司&#xff08;已被ARM收购的&#xff09;一款集成开发环境&#xff0c;界面美观&#xff0c;简单易用&#xff0c;是STM32最常用的集成开发环境EWARM&#xff1a;IAR公司的一款集成开发环…

洛谷_P1464 Function_python写法

目录 1.错误解法 2.学习记忆化搜索算法 2.1简介 2.2案例学习 3.解法 4.总结 1.错误解法 a 0 b 0 c 0 def w(a,b,c):if a<0 or b<0 or c<0:return 1elif a>20 or b>20 or c>20:return w(20,20,20)elif a<b and b<c:return w(a-1,b,c) w(a-1,…