pytorch 优化训练显存方式

转载自(侵删):

 https://www.cnblogs.com/chentiao/p/17901639.html 

 

1. 显存都用在哪儿了?

一般在训练神经网络时,显存主要被网络模型中间变量占用。

  • 网络模型中的卷积层,全连接层和标准化层等的参数占用显存,而诸如激活层和池化层等本质上是不占用显存的。
  • 中间变量包括特征图和优化器等,是消耗显存最多的部分。
  • 其实 pytorch 本身也占用一些显存的,但占用不多,以下方法大致按照推荐的优先顺序。

2. 技巧 1:使用就地操作

就地操作 (inplace) 字面理解就是在原地对变量进行操作,对应到 pytorch 中就是在原内存上对变量进行操作而不申请新的内存空间,从而减少对内存的使用。具体来说就地操作包括三个方面的实现途径:

  • 使用将 inplace 属性定义为 True 的激活函数,如 nn.ReLU(inplace=True)
  • 使用 pytorch 带有就地操作的方法,一般是方法名后跟一个下划线 “_”,如 tensor.add_()tensor.scatter_()F.relu_()
  • 使用就地操作的运算符,如 y += xy *= x

3. 技巧 2:避免中间变量

在自定义网络结构的成员方法 forward 函数里,避免使用不必要的中间变量,尽量在之前已申请的内存里进行操作,比如下面的代码就使用太多中间变量,占用大量不必要的显存:

def forward(self, x):x0 = self.conv0(x)  # 输入层x1 = F.relu_(self.conv1(x0) + x0)x2 = F.relu_(self.conv2(x1) + x1)x3 = F.relu_(self.conv3(x2) + x2)x4 = F.relu_(self.conv4(x3) + x3)x5 = F.relu_(self.conv5(x4) + x4)x6 = self.conv(x5)  # 输出层return x6

为了减少显存占用,可以将上述 forward 函数修改如下:

def forward(self, x):x = self.conv0(x)  # 输入层x = F.relu_(self.conv1(x) + x)x = F.relu_(self.conv2(x) + x)x = F.relu_(self.conv3(x) + x)x = F.relu_(self.conv4(x) + x)x = F.relu_(self.conv5(x) + x)x = self.conv(x)  # 输出层return x上述两段代码实现的功能是一样的,但对显存的占用却相去甚远,后者能节省前者占用显存的接近 90% 之多。

4. 技巧 3:优化网络模型

网络模型对显存的占用主要指的就是卷积层,全连接层和标准化层等的参数,具体优化途径包括但不限于:

  • 减少卷积核数量 (=减少输出特征图通道数)
  • 不使用全连接层
  • 全局池化 nn.AdaptiveAvgPool2d() 代替全连接层 nn.Linear()
  • 不使用标准化层
  • 跳跃连接跨度不要太大太多 (避免产生大量中间变量)

5. 技巧 4:减小 BATCH_SIZE

  • 在训练卷积神经网络时,epoch 代表的是数据整体进行训练的次数,batch 代表将一个 epoch 拆分为 batch_size 批来参与训练。
  • 减小 batch_size 是一个减小显存占用的惯用技巧,在训练时显存不够一般优先减小 batch_size ,但 batch_size 不能无限变小,太大会导致网络不稳定,太小会导致网络不收敛。

6. 技巧 5:拆分 BATCH

拆分 batch 跟技巧 4 中减小 batch_size 本质是不一样的, 这种拆分 batch 的操作可以理解为将两次训练的损失相加再反向传播,但减小 batch_size 的操作是训练一次反向传播一次。拆分 batch 操作可以理解为三个步骤,假设原来 batch 的大小 batch_size=64

  • 将 batch 拆分为两个 batch_size=32 的小 batch
  • 分别输入网络与目标值计算损失,将得到的损失相加
  • 进行反向传播

7. 技巧 6:降低 PATCH_SIZE

  • 在卷积神经网络训练中,patch_size 指的是输入神经网络的图像大小,即(H*W)。
  • 网络输入 patch 的大小对于后续特征图的大小等影响非常大,训练时可能采用诸如 [64*64],[128*128] 等大小的 patch,如果显存不足可以进一步缩小 patch 的大小,比如 [32*32],[16*16]。
  • 但这种方法存在问题,可能极大地影响网络的泛化能力,在裁剪的时候一定要注意在原图上随机裁剪,一般不建议。

8. 技巧 7:优化损失求和

一个 batch 训练结束会得到相应的一个损失值,如果要计算一个 epoch 的损失就需要累加之前产生的所有 batch 损失,但之前的 batch 损失在 GPU 中占用显存,直接累加得到的 epoch 损失也会在 GPU 中占用显存,可以通过如下方法进行优化:

epoch_loss += batch_loss.detach().item()  # epoch 损失

上边代码的效果就是首先解除 batch_loss 张量的 GPU 占用,将张量中的数据取出再进行累加。

9. 技巧 8:调整训练精度

  • 降低训练精度
    pytorch 中训练神经网络时浮点数默认使用 32 位浮点型数据,在训练对于精度要求不是很高的网络时可以改为 16 位浮点型数据进行训练,但要注意同时将数据和网络模型都转为 16 位浮点型数据,否则会报错。降低浮点型数据的操作实现过程非常简单,但如果优化器选择 Adam 时可能会报错,选择 SGD 优化器则不会报错,具体操作步骤如下:
model.cuda().half()  # 网络模型设置半精度
# 网络输入和目标设置半精度
x, y = Variable(x).cuda().half(), Variable(y).cuda().half()

  • 混合精度训练
    混合精度训练指的是用 GPU 训练网络时,相关数据在内存中用半精度做储存和乘法来加速计算,用全精度进行累加避免舍入误差,这种混合经度训练的方法可以令训练时间减少一半左右,也可以很大程度上减小显存占用。在 pytorch1.6 之前多使用 NVIDIA 提供的 apex 库进行训练,之后多使用 pytorch 自带的 amp 库,实例代码如下:
import torch
from torch.nn.functional import mse_loss
from torch.cuda.amp import autocast, GradScalerEPOCH = 10  # 训练次数
LEARNING_RATE = 1e-3  # 学习率x, y = torch.randn(3, 100).cuda(), torch.randn(3, 5).cuda()  # 定义网络输入输出
myNet = torch.nn.Linear(100, 5).cuda()  # 实例化网络,一个全连接层optimizer = torch.optim.SGD(myNet.parameters(), lr=LEARNING_RATE)  # 定义优化器
scaler = GradScaler()  # 梯度缩放for i in range(EPOCH):  # 训练with autocast():  # 设置混合精度运行y_pred = myNet(x)loss = mse_loss(y_pred, y)scaler.scale(loss).backward()  # 将张量乘以比例因子,反向传播scaler.step(optimizer)  # 将优化器的梯度张量除以比例因子。scaler.update()  # 更新比例因子

10. 技巧 9:分割训练过程

  • 如果训练的网络非常深,比如 resnet101 就是一个很深的网络,直接训练深度神经网络对显存的要求非常高,一般一次无法直接训练整个网络。在这种情况下,可以将复杂网络分割为两个小网络,分别进行训练。
  • checkpoint 是 pytorch 中一种用时间换空间的显存不足解决方案,这种方法本质上减少的是参与一次训练网络整体的参数量,如下是一个实例代码。
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint# 自定义函数
def conv(inplanes, outplanes, kernel_size, stride, padding):return nn.Sequential(nn.Conv2d(inplanes, outplanes, kernel_size, stride, padding),nn.BatchNorm2d(outplanes),nn.ReLU())class Net(nn.Module):  # 自定义网络结构,分为三个子网络def __init__(self):super().__init__()self.conv0 = conv(3, 32, 3, 1, 1)self.conv1 = conv(32, 32, 3, 1, 1)self.conv2 = conv(32, 64, 3, 1, 1)self.conv3 = conv(64, 64, 3, 1, 1)self.conv4 = nn.Linear(64, 10)  # 全连接层def segment0(self, x):  # 子网络1x = self.conv0(x)return xdef segment1(self, x):  # 子网络2x = self.conv1(x)x = self.conv2(x)x = self.conv3(x)return xdef segment2(self, x):  # 子网络3x = self.conv4(x)return xdef forward(self, x):x = checkpoint(self.segment0, x)  # 使用 checkpointx = checkpoint(self.segment1, x)x = checkpoint(self.segment2, x)return x
  • 使用 checkpoint 进行网络训练要求输入属性 requires_grad=True ,在给出的代码中将一个网络结构拆分为 3 个子网络进行训练,对于没有 nn.Sequential() 构建神经网络的情况无非就是自定义的子网络里多几项,或者像例子中一样单独构建网络块。
  • 对于由 nn.Sequential() 包含的大网络块 (小网络块时没必要),可以使用 checkpoint_sequential 包来简化实现,具体实现过程如下:
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequentialclass Net(nn.Module):  # 自定义网络结构,分为三个子网络def __init__(self):super().__init__()linear = [nn.Linear(10, 10) for _ in range(100)]self.conv = nn.Sequential(*linear)  # 网络主体,100 个全连接层def forward(self, x):num_segments = 2  # 拆分为两段x = checkpoint_sequential(self.conv, num_segments, x)return x

11. 技巧10:清理内存垃圾

  • python 中定义的变量一般在使用结束时不会立即释放资源,在训练循环开始时可以利用如下代码来回收内存垃圾。
import gc 
gc.collect()  # 清理内存

12. 技巧11:使用梯度累积

  • 由于显存大小的限制,训练大型网络模型时无法使用较大的 batch_size ,而一般较大的 batch_size 能令网络模型更快收敛。
  • 梯度累积就是将多个 batch 计算得到的损失平均后累积再进行反向传播,类似于技巧 5 中拆分 batch 的思想(但技巧 5 是将大 batch 拆小,训练的依旧是大 batch,而梯度累积训练的是小 batch)。
  • 可以采用梯度累积的思想来模拟较大 batch_size 可以达到的效果,具体实现代码如下:
output = myNet(input_)  # 输入送入网络
loss = mse_loss(target, output)  # 计算损失
loss = loss / 4  # 累积 4 次梯度
loss.backward()  # 反向传播
if step % 4 == 0:  # 如果执行了 4 步optimizer.step()  # 更新网络参数optimizer.zero_grad()  # 优化器梯度清零

13. 技巧12:清除不必要梯度

在运行测试程序时不涉及到与梯度有关的操作,因此可以清楚不必要的梯度以节约显存,具体包括但不限于如下操作:

  • 用代码 model.eval() 将模型置于测试状态,不启用标准化和随机舍弃神经元等操作。
  • 测试代码放入上下文管理器 with torch.no_grad(): 中,不进行图构建等操作。
  • 在训练或测试每次循环开始时加梯度清零操作
myNet.zero_grad()  # 模型参数梯度清零
optimizer.zero_grad()  # 优化器参数梯度清零

14. 技巧13:周期清理显存

  • 同理也可以在训练每次循环开始时利用 pytorch 自带清理显存的代码来释放不用的显存资源。
torch.cuda.empty_cache()  # 释放显存

执行这条语句释放的显存资源在用 Nvidia-smi 命令查看时体现不出,但确实是已经释放。其实 pytorch 原则上是如果变量不再被引用会自动释放,所以这条语句可能没啥用,但个人觉得多少有点用。

15. 技巧14:多使用下采样

下采样从实现上来看类似池化,但不限于池化,其实也可以用步长大于 1 来代替池化等操作来进行下采样。从结果上来看就是通过下采样得到的特征图会缩小,特征图缩小自然参数量减少,进而节约显存,可以用如下两种方式实现:

nn.Conv2d(32, 32, 3, 2, 1)  # 步长大于 1 下采样nn.Conv2d(32, 32, 3, 1, 1)  # 卷积核接池化下采样
nn.MaxPool2d(2, 2)

16. 技巧15:删除无用变量

del 功能是彻底删除一个变量,要再使用必须重新创建,注意 del 删除的是一个变量而不是从内存中删除一个数据,这个数据有可能也被别的变量在引用,实现方法很简单,比如:

def forward(self, x):input_ = xx = F.relu_(self.conv1(x) + input_)x = F.relu_(self.conv2(x) + input_)x = F.relu_(self.conv3(x) + input_)del input_  # 删除变量 input_x = self.conv4(x)  # 输出层return x

17. 技巧16:改变优化器

进行网络训练时比较常用的优化器是 SGD 和 Adam,抛开训练最后的效果来谈,SGD 对于显存的占用相比 Adam 而言是比较小的,实在没有办法时可以尝试改变参数优化算法,两种优化算法的调用是相似的:

import torch.optim as optim
from torchvision.models import resnet18LEARNING_RATE = 1e-3  # 学习率
myNet = resnet18().cuda()  # 实例化网络optimizer_adam = optim.Adam(myNet.parameters(), lr=LEAENING_RATE)  #  adam 网络参数优化算法
optimizer_sgd = optim.SGD(myNet.parameters(), lr=LEAENING_RATE)  # sgd 网络参数优化算法

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/657037.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

R语言【taxlist】——subset():取taxlist对象的子集

Package taxlist version 0.2.4 Description taxlist对象的子集将通过逻辑操作或模式匹配来完成。子集可以引用包含在插槽taxonNames、taxonRelations或taxonTraits中的信息。 Usage ## S4 method for signature taxlist subset(x,subset,slot "names",keep_child…

【游戏服务器部署】幻兽帕鲁服务器一键部署保姆级教程,游戏私服还是自己搭建的香

在帕鲁的世界,你可以选择与神奇的生物「帕鲁」一同享受悠闲的生活,也可以投身于与偷猎者进行生死搏斗的冒险。帕鲁可以进行战斗、繁殖、协助你做农活,也可以为你在工厂工作。你也可以将它们进行售卖,或肢解后食用。—幻兽帕鲁 想要…

ThinkPHP5.0.0~5.0.23反序列化利用链分析

本次测试环境仍然是ThinkPHP v5.0.22版本,我们将分析其中存在的一条序列化链。 一道CTF题 这次以一道CTF题作为此次漏洞研究的开头。题中涉及PHP的死亡绕过技巧,是真实环境中存在的情况。 $payload; $filename$payload.468bc8d30505000a2d7d24702b2cda…

春季选品策略:如何在Shopee平台上脱颖而出

在Shopee平台上进行春季选品时,卖家需要制定有效的策略来吸引消费者的注意并提高销售业绩。本文将介绍一些关键的选品策略,帮助卖家在春季市场中脱颖而出。 先给大家推荐一款shopee知虾数据运营工具知虾免费体验地址(复制浏览器打开&#xf…

MyBatis 源码系列:MyBatis 解析配置文件、二级缓存、SQL

文章目录 解析全局配置文件二级缓存解析解析二级缓存缓存中的调用过程缓存中使用的设计模式 解析SQL 解析全局配置文件 启动流程分析 String resource "mybatis-config.xml"; //将XML配置文件构建为Configuration配置类 reader Resources.getResourceAsReader(re…

探索ESP32 C++ OOP开发:与传统面向过程编程的比较

探索ESP32 OOP开发:与传统面向过程编程的比较 在嵌入式系统开发中,ESP32是一个强大的平台,可以应用于各种项目和应用场景。在编写ESP32代码时,我们可以选择使用面向对象编程(OOP)的方法,将代码…

数据结构—栈实现前缀表达式的计算

前缀表达式计算 过程分析 中缀表达式:(1 5)*3 > 前缀表达式:*153 (可参考这篇文章:中缀转前缀) 第一步:从右至左扫描前缀表达式(已存放在字符数组中)&a…

termux 玩法(一)

termux基础 termux基础玩法推荐国光写的手册:Termux 高级终端安装使用配置教程 | 国光 (sqlsec.com) termux安装 个人使用F-Droid安装的termux:Termux | F-Droid - Free and Open Source Android App Repository 基础知识 这些基础知识简单了解一下…

自定义模块加载(Python)

加载自定义模块,系统抛出“找不到文件”异常提示信息。 (笔记模板由python脚本于2024年01月28日 12:50:00创建,本篇笔记适合初通Python的coder翻阅) 【学习的细节是欢悦的历程】 Python 官网:https://www.python.org/ Free:大咖免…

LeetCode316. Remove Duplicate Letters——单调栈

文章目录 一、题目二、题解 一、题目 Given a string s, remove duplicate letters so that every letter appears once and only once. You must make sure your result is the smallest in lexicographical order among all possible results. Example 1: Input: s “bca…

盲人程序员是怎么编程的?闭眼编程

解决 读代码,读键盘变成听。 盲人程序员可以借助屏幕阅读器来使用计算机,绝大多数编程工具也可以正常访问,所以,盲人掌握编程语言是没有问题的。 具体的工作流程如下: 使用屏幕阅读器来“阅读”屏幕上的文本和代码。…

链表——超详细

一、无头单向非循环链表 1.结构(两个部分): typedef int SLTDataType; typedef struct SListNode {SLTDataType data;//数据域struct SListNode* next;//指针域 }SLNode; 它只有一个数字域和一个指针域,里面数据域就是所存放的…

3分钟搞定springboot 定时任务cron表达式

在开发过程中经常需要使用定时任务在特定的时间执行一些特定程序。而 springboot Scheduled注解中可以方便的使用 cron 表达式来配置定时任务。在这SpringBoot 实现定时任务一篇文章中我们介绍了如何使用Scheduled实现定时任务,下面我们看下cron该如何编写。 cron表…

万户 ezOFFICE wf_accessory_delete.jsp SQL注入漏洞复现

0x01 产品简介 万户OA ezoffice是万户网络协同办公产品多年来一直将主要精力致力于中高端市场的一款OA协同办公软件产品,统一的基础管理平台,实现用户数据统一管理、权限统一分配、身份统一认证。统一规划门户网站群和协同办公平台,将外网信息维护、客户服务、互动交流和日…

继电器模块详解

继电器,一种常见的电控制装置,其应用几乎无处不在。在家庭生活,继电器被广泛应用于照明系统、电视机、空调等电器设备的控制;在工业领域,它们用于控制电机、泵站、生产线等高功率设备的运行;继电器还在通信…

【论文收集】

Collaborative Diffusion for Multi-Modal Face Generation and Editing https://arxiv.org/abs/2304.10530 code:https://github.com/ziqihuangg/collaborative-diffusion 现有的扩散模型主要集中在单模态控制上,即扩散过程仅由一种状态模态驱动。为…

Docker的使用方式

一、Docker概念 Docker类似于一个轻量的虚拟机。 容器和镜像是Docker中最重要的两个概念,镜像可以保存为tar文件,Dockerfile是配置文件,仓库保存了很多第三方已经做好的镜像。 基本指令 查找镜像 docker search nginx 拉取nginx镜像 do…

携程获取景点详情 API 返回值说明

公共参数 请求地址:​​前往测试​​ 名称 类型 必须 描述 key String 是 调用key,必须以GET方式拼接在URL中) secret String 是 调用密钥 api_name String 是 API接口名称(包括在请求地址中)[item_se…

搭建Jmeter分布式压测与监控,轻松实践

对于运维工程师来说,需要对自己维护的服务器性能瓶颈了如指掌,比如我当前的架构每秒并发是多少,我服务器最大能接受的并发是多少,是什么导致我的性能有问题;如果当前架构快达到性能瓶颈了,是横向扩容性能提…

香港服务器IP段4c和8c的区别及SEO选择建议

随着互联网的快速发展,服务器IP段的选择对于网站SEO优化至关重要。香港服务器IP段4C和8C是两种常见的IP段,它们在SEO优化中具有不同的特点和优势。本文将详细介绍这两种IP段的区别,并给出相应的SEO选择建议。 一、香港服务器IP段4C和8C的区别…