DLA 神经网络的极限训练方法:gradient checkpointing

gradient checkpointing

        一般来说,训练的过程需要保存中间结果(不管是GPU还是CPU)。前向传播根据输入(bottom_data)计算输出(top_data),后向传播由top_diff计算bottom_diff(如果某个变量打开梯度进行训练的话)。top和bottom是包含数据和梯度的两个结构体,整个网络的每层top和bottom在训练的过程中都会保存,这消耗了大量的内存。

        如果不保存这些变量,每次传播时重新分配和计算,会大大减少内存的使用量,但是也会使得网络的训练时间无限延长。为了平衡这两个矛盾,论文Training Deep Nets with Sublinear Memory Cost 使用亚线性内存成本训练深度网络:我们提出了一种系统方法来减少深度的内存消耗 神经网络训练。具体来说,我们设计了一种成本高昂的算法 O(sqrt(n)) 内存来训练 n 层网络,只需计算成本 每个小批量的额外前向传递。每隔 sqrt(n)保留一个检查点的feature map。

CODE

  • https://pytorch.org/docs/stable/checkpoint.html
// https://discuss.pytorch.org/t/trying-to-understand-torch-utils-checkpoint/95224
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import numpy as np
from tqdm.notebook import tqdmfrom torch import optim
import torchvision.models as models
from torch import nnCHECKPOINT = True
BATCH_SIZE = 32
dev = "cuda:0"class ImageDataset(Dataset):def __init__(self,length = 100000,size = 244):self.length = lengthself.size = 244def __len__(self):return self.lengthdef __getitem__(self,idx,display = False):return torch.from_numpy(np.random.randn(2,3,self.size,self.size))
train = ImageDataset()
trainloader = DataLoader(train,batch_size = BATCH_SIZE,num_workers = 24,pin_memory = True
)resnet = models.resnet50(pretrained = False)class MODEL(nn.Module):def __init__(self,model):super(MODEL,self).__init__()self.model = modelself.LR = nn.Linear(1000,1000)def forward(self,x):if CHECKPOINT == False:o1 = self.model(x[:,0])o2 = self.model(x[:,1])else:o1 = torch.utils.checkpoint.checkpoint(self.model,x[:,0])o2 = torch.utils.checkpoint.checkpoint(self.model,x[:,1])return torch.mean((self.LR(o1)-o2)**2)resnet = MODEL(resnet).to(dev)optimizer = optim.SGD(resnet.parameters(),lr = .001)for T in tqdm(trainloader):out = torch.mean(resnet(T.float().to(dev)))optimizer.zero_grad()out.backward()optimizer.step()

CG

在这里插入图片描述

  • https://github.com/merrymercy/dtr-prototype

ZeRO-Offload

  • https://arxiv.org/pdf/2101.06840.pdf 大规模模型训练一直是少数人的比赛场地 需要复杂的模型重构和访问昂贵的 GPU 集群。ZeRO-Offload 通过使 几乎每个人都可以访问大型模型训练。它可以训练模型 单个 GPU 上超过 13 亿个参数,与 GPU 相比,大小增加了 10 倍 流行的框架,如PyTorch,它不需要任何模型就可以做到这一点。 从数据科学家改变或牺牲计算效率。 ZeRO-卸载通过卸载数据和计算来实现大型模型训练 中央处理器。为了保持计算效率,它旨在最大限度地减少数据 移入/移出 GPU,减少 CPU 计算时间,同时最大化内存 节省 GPU 成本。因此,ZeRO-Offload可以在单个上实现40 TFlops / GPU。 NVIDIA V100 GPU 用于 10B 参数模型,与单独使用 PyTorch 的 30TF 相比 对于 1.4B 参数模型,可以训练而不会耗尽的最大参数模型 的记忆。ZeRO-Offload 还设计为在以下情况下在多个 GPU 上进行扩展 可用,可在多达 128 个 GPU 上提供近乎线性的加速。此外,它可以 与模型并行性协同工作,训练超过 70 亿的模型 单个 DGX-2 盒子上的参数,与模型尺寸相比增加了 4.5 倍 单独使用模型并行性。通过将计算和内存效率与 易于使用,ZeRO-Offload 使大规模模型训练民主化,使其成为 即使是数据科学家也可以访问,只需访问一个 GPU。

梯度累积

        训练时大的batch一般能得到更稳定的训练效果,梯度累积训练方法是一种用于训练深度神经网络的技术,旨在减少显存需求并提高训练效果。在传统的训练方法中,模型的参数是通过单个批次(batch)的数据计算得到的梯度平均值进行更新。但在梯度累积训练中,模型的参数更新是通过多个批次的梯度累积得到的。

以下是梯度累积训练的基本步骤:

  1. 设置梯度累积步数(accumulation steps),它决定了要累积多少个批次的梯度。

  2. 初始化模型的参数。

  3. 对于每个训练批次(batch):

    • 使用当前批次的数据进行前向传播计算损失。
    • 对损失进行反向传播计算梯度。
    • 累积当前批次的梯度到之前的梯度值上。
  4. 当累积达到设置的步数时,将累积的梯度应用于模型参数的更新:

    • 通过将累积的梯度平均化来获得参数的更新值。
    • 使用更新值来更新模型的参数。
  5. 重复步骤3和4,直到完成所有的训练批次。

梯度累积训练的主要优势在于能够降低每个批次所需的显存量,允许在具有有限显存的硬件上训练更大的模型。此外,梯度累积还可以改善模型的收敛性,提高模型的性能和泛化能力。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/23755.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

5个顶级的开源有限元分析软件

每当我参加数值分析课程的教学时,都会回顾有限元方法的基础知识,很自然地就会出现使用哪种软件的问题。 以下讨论基于三个基本考虑: 在实际应用中,很少有人从头开始编写 FEM 代码。商业 FEM 软件通常在某些预定义的情况下非常易于…

Pandas操作Excel

Pandas 是 Python 语言的一个扩展程序库,用于数据分析。 菜鸟教程:https://www.runoob.com/pandas/pandas-tutorial.html 读取Excel pd.read_excel(path,sheet_name,header) path:excel文件路径sheet_name:读取的sheet&#xff0…

3.netty和protobuf

1.ChannelGroup可以免遍历由netty提供,覆盖remove方法即可触发删除channel\ 2.群聊私聊 13.群聊私聊简单原理图 3.netty心跳检测机制,客户端对服务器有没有读写(读,写空闲) //IdleStateHandler(3,5,7,TimeUnite.SECONDS)是netty提供的检测状态的处理器,也加到pipeline,读,写,…

浅析 C 语言的共用体、枚举和位域

前言 最近在尝试阅读一些系统库的源码,但是其中存在很多让我感到既熟悉又陌生的语法。经过资料查阅,发现是 C 语言中的共用体和位域。于是,趁着课本还没有扔掉,将一些相关的知识点记录在本文。 文章目录 前言共用体 (union)枚举…

网络开发-IO模型

基本概念 I/O即数据的读取&#xff08;接收&#xff09;或写入&#xff08;发送&#xff09;操作 通常用户进程中的一个完整I/O分为两个阶段 用户进程空间<-->内核空间内核空间<-->设备空间&#xff08;磁盘、网卡等&#xff09; I/O分为内存I/O、网络I/O和磁盘…

【编程】典型题目:寻找数组第K大数(四种方法对比)

【编程】典型题目&#xff1a;寻找数组第K大数&#xff08;四种方法对比&#xff09; 文章目录 【编程】典型题目&#xff1a;寻找数组第K大数&#xff08;四种方法对比&#xff09;1. 题目2. 题解2.1 方法一&#xff1a;全局排序&#xff08;粗暴&#xff09;2.2 方法二&#…

123.买卖股票的最佳时机3

目录 一、题目 二、分析代码 一、题目 123. 买卖股票的最佳时机 III - 力扣&#xff08;LeetCode&#xff09; 二、分析代码 class Solution { public:int maxProfit(vector<int>& prices) {//0表示没有操作//1表示第1次买入&#xff0c;2表示第1次卖出//3表示第2…

用html+javascript打造公文一键排版系统11:改进单一附件说明排版

一、用htmljavascript打造公文一键排版系统10中的一个bug 在 用htmljavascript打造公文一键排版系统10&#xff1a;单一附件说明排版 中&#xff0c;我们对附件说明的排版函数是&#xff1a; function setAtttDescFmt(p) {var t p;var a ;if (-1 ! t.indexOf(:))//是半角冒…

学习源码,模仿编程

一.观察者模式: 1.创建事件 2.发布事件 3.监听事件 4.效果: 二.模板方法模式

FTP使用教程

FTP使用教程 目录 一&#xff0e;FTP简介二&#xff0e;FTP搭建三&#xff0e;FTP使用 一&#xff0e;FTP简介 FTP中文为文件传输协议&#xff0c;简称为文传协议。它也是一个应用程序&#xff0c;不同的操作系统有不同的FTP应用程序&#xff0c;这些应用程序都遵守同一种协议以…

LeetCode724. 寻找数组的中心下标

题干 给你一个整数数组 nums &#xff0c;请计算数组的 中心下标 。 数组 中心下标 是数组的一个下标&#xff0c;其左侧所有元素相加的和等于右侧所有元素相加的和。 如果中心下标位于数组最左端&#xff0c;那么左侧数之和视为 0 &#xff0c;因为在下标的左侧不存在元素。…

k8s概念-pv和pvc

回到目录 kubernetes存储卷的分类太丰富了,每种类型都要写相应的接口与参数才行&#xff0c;这就让维护与管理难度加大。 persistenvolume(PV) 是配置好的一段存储(可以是任意类型的存储卷) 也就是说将网络存储共享出来,配置定义成PV。 PersistentVolumeClaim(PVC)是用户pod使…

谁更适合搭配甜点显卡?i7-13700KF、锐龙7 7800X3D对比:游戏相当 生产力Intel强了50%...

一、前言&#xff1a;如果搭配2000元甜点显卡 i7-13700KF和锐龙7 7800X3D谁更有性价比&#xff1f; 现在AMD最受欢迎的处理器无疑是拥有96MB三级缓存的锐龙7 7800X3D&#xff0c;这是一颗专为游戏而生的处理器。 Intel这边&#xff0c;i7-13700KF以略高于i5-13600K的售价&#…

小鱼深度产品测评之:阿里云容器服务器ASK,一款不需购买节点,即可直接部署容器应用。

容器服务器ASK测评 1、引言2、帮助文档3、集群3.1集群列表3.1.1 详情3.1.1.1概览 tab3.1.1.2基本信息 tab3.1.1.4集群资源 tab3.1.1.5 集群日志 tab3.1.1.6 集群任务 tab 3.1.2 应用管理3.1.2.1 详情3.1.2.2 详情3.1.2.3 伸缩3.1.2.4 监控 3.1.3 查看日志3.1.3.1 集群日志3.1.3…

Julia 日期和时间

Julia 通过 Dates 模块提供了以下三个函数来处理日期和时间&#xff1a; Date&#xff1a;表示日期&#xff0c;精确到日&#xff0c;只显示日期。DateTime&#xff1a;表示日期和时间&#xff0c;精确到毫秒。DateTime&#xff1a;表示日时间&#xff0c;精确到纳秒&#xff…

【Python】数据分析+数据挖掘——掌握Python和Pandas中的单元格替换操作

1. 前言 数据处理和清洗是数据分析和机器学习中至关重要的步骤。在数据处理过程中&#xff0c;我们经常需要对数据集进行清洗和转换&#xff0c;其中单元格替换是一个常用的技术。Python作为一种功能强大且灵活的编程语言&#xff0c;为数据处理提供了丰富的工具和库。Pandas库…

邀请媒体现场报道,有哪些作用?

传媒如春雨&#xff0c;润物细无声&#xff0c;大家好&#xff0c;我是51媒体网胡老师。 邀请媒体现场报道活动具有多种重要作用和意义&#xff0c;可以为你的活动带来广泛的曝光和正面影响。以下是一些邀请媒体现场报道的作用和意义&#xff1a; 1. 增加活动曝光度&#xff…

通过nvm工具快捷切换node.js版本、以及nvm的安装

使用nvm可以实现多个Node.js版本之间切换 步骤目录&#xff1a; 先卸载掉本系统中原有的node版本 去github上下载nvm安装包 安装node 常用的一些nvm命令 1、先卸载掉本系统中原有的node版本 2、去github上下载nvm安装包 https://github.com/coreybutler/nvm-windows/re…

LeetCode933. 最近的请求次数

题干 写一个 RecentCounter 类来计算特定时间范围内最近的请求。 请你实现 RecentCounter 类&#xff1a; RecentCounter() 初始化计数器&#xff0c;请求数为 0 。int ping(int t) 在时间 t 添加一个新请求&#xff0c;其中 t 表示以毫秒为单位的某个时间&#xff0c;并返回…

Python学习笔记:List、Tuple、for循环

1.list list_demo [7, 7, 8, 9, 9, 9] print(list_demo.index(7)) # index 方法返回第一个index list_demo.sort() # 排序list list_demo.reverse() # 倒序list list_demo1 list_demo.copy() # 复制list 2.matrix 其实就是list嵌套&…