大语言模型的工程技巧(四)——梯度检查点

相关说明

这篇文章的大部分内容参考自我的新书《解构大语言模型:从线性回归到通用人工智能》,欢迎有兴趣的读者多多支持。

本文将讨论如何利用梯度检查点算法来减少模型在训练时候(更准确地说是运行反向传播算法时)的内存开支。这在训练超大规模的模型时会用到。

关于其他的工程技巧可以参考:

  • 大语言模型的工程技巧(一)——GPU计算
  • 大语言模型的工程技巧(二)——混合精度训练
  • 大语言模型的工程技巧(三)——分布式计算

关于大语言模型的讨论请参考:

  • 理解大语言模型(二)——从零开始实现GPT-2

内容大纲

  • 相关说明
  • 一、标准反向传播
  • 二、内存极简算法
  • 三、梯度检查点

一、标准反向传播

根据梯度的定义,变量的梯度与其本身的值密切相关。因此,要想得到某个变量的梯度,必须先知道这个变量的值。这也是为什么在进行反向传播算法之前,需要先对计算图进行向前传播,并记录每个节点的计算结果,如图1左侧部分所示。这样在计算节点的梯度时,可以利用这些事先缓存的结果,直接启动反向传播过程,从而得到梯度,如图1中的节点d所示。这种方法也被称为标准反向传播。这种方式能够确保梯度计算以最高效的方式进行。

图1

图1

二、内存极简算法

然而,采用标准反向传播算法会造成较大的内存开销。为了在计算过程中尽可能地压缩内存使用,可以采用一种以时间换空间的方法。在这种算法中,一旦向前传播完成,仅会保留顶点的计算结果,而中间节点的结果会被清空(叶子节点的值会保留)。在反向传播遇到中间计算节点没有缓存时,则重新触发向前传播,以获取所需节点的结果。这就是内存极简的反向传播算法。以节点d为例,为了计算其梯度,需要首先从节点a开始重新触发向前传播直到节点d,并缓存计算结果。然后使用这个缓存的结果以及节点e的梯度,计算出节点d的梯度。对于其他节点,也采用类似的步骤计算梯度。通过这种方式,在完成反向传播的同时,节省了内存开销。以图1为例,内存极简算法只需要3个存储空间,而标准算法需要5个存储空间。

三、梯度检查点

尽管内存极简算法在降低内存开销方面取得了显著成果,但它涉及大量的重复计算,运行时间相对较长。为了在内存使用和运行时间之间取得平衡,下面引入梯度检查点(Gradient Checkpoint)。这一算法的核心思想是选择一些中间节点作为存储点,以便在再次触发向前传播时,以这些存储点作为起点开始传播,避免从头开始重复计算。这种方式在一定程度上减少重复计算,从而提高运行效率。需要注意的是,由于需要存储额外的中间结果,梯度检查点会稍微增加一些内存开销。

关于梯度检查点算法,PyTorch中已经提供了便捷的封装函数,即torch.utils.checkpoint。这个工具能够帮助我们更方便地应用梯度检查点算法,以平衡内存开锁和运行时间。更多细节请参考这个链接。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/15029.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

机器学习-决策树算法

前言 本篇介绍决策树与随机森林的内容,先完成了决策树的部分。 决策树 决策树(Decision Tree)是一种有监督学习的方法,可以同时解决分类和回归问题,它能够从一系列有特征和标签的数据中总结出决策规则,并用树状图的结构来呈现这…

SecureCRT for Mac注册激活版:专业终端SSH工具

SecureCRT是一款支持SSH(SSH1和SSH2)的终端仿真程序,简单地说是Windows下登录UNIX或Linux服务器主机的软件。 SecureCRT支持SSH,同时支持Telnet和rlogin协议。SecureCRT是一款用于连接运行包括Windows、UNIX和VMS的理想工具。通过…

大摩:AI到“临界点”了,资管公司到了广泛部署的时刻

大摩表示,尽管AI技术在资产管理行业中的应用仍处于早期阶段,但其潜力巨大,能够为行业带来根本性的变革。预计生成式AI能够在资产管理公司的运营模型中带来20%至40%的生产力提升。 正文介绍 在全球经济面临诸多不确定因素的当下,…

【全开源】答题考试系统源码(FastAdmin+ThinkPHP+Uniapp)

答题考试系统源码:构建高效、安全的在线考试平台 引言 在当今数字化时代,在线考试系统已成为教育机构和企业选拔人才的重要工具。一个稳定、高效、安全的答题考试系统源码是构建这样平台的核心。本文将深入探讨答题考试系统源码的关键要素,…

大佬大讲堂(1)电机及其驱动内核-自适应观察器

点击上方 “机械电气电机杂谈 ” → 点击右上角“...” → 点选“设为星标 ★”,为加上机械电气电机杂谈星标,以后找夏老师就方便啦!你的星标就是我更新动力,星标越多,更新越快,干货越多! 关注…

Java面试八股之可重入锁ReentrantLock是怎么实现可重入的

可重入锁ReentrantLock是怎么实现可重入的 ReentrantLock实现可重入性的机制主要依赖于以下几个核心组件和步骤: 状态计数器:ReentrantLock内部维护一个名为state的整型变量作为状态计数器,这个计数器不仅用来记录锁是否被持有,…

Java进阶学习笔记9——子类中访问其他成员遵循就近原则

正确访问成员的方法。 在子类方法中访问其他成员(成员变量、成员方法),是依照就近原则的。 F类: package cn.ensource.d13_extends_visit;public class F {String name "父类名字";public void print() {System.out.p…

langchian进阶二:LCEL表达式,轻松进行chain的组装

LangChain表达式语言-LCEL,是一种声明式的方式,可以轻松地将链条组合在一起。 你会在这些情况下使用到LCEL表达式: 流式支持 当你用LCEL构建你的链时,你可以得到最佳的首次到令牌的时间(输出的第一块内容出来之前的时间)。对于一些链&#…

Springboot+Vue项目-基于Java+MySQL的酒店管理系统(附源码+演示视频+LW)

大家好!我是程序猿老A,感谢您阅读本文,欢迎一键三连哦。 💞当前专栏:Java毕业设计 精彩专栏推荐👇🏻👇🏻👇🏻 🎀 Python毕业设计 &…

手撕算法|斯坦福大学教授用60页PPT搞定了八大神经网络

人工智能领域深度学习的八大神经网络常见的是以下几种 1.卷积神经网络(CNN): 卷积神经网络是用于图像和空间数据处理的神经网络,通过卷积层和池化层来捕捉图像的局部特征,广泛应用于图像分类、物体检测等领域。 2.循…

blender 布尔运算,切割模型。

1.创建一个立方体和球体。 2.选中立方体,在属性面板添加布尔修改器。点击物体属性右边的按钮选中球体。参数如下。 3.此时隐藏球体,就可以看到被切掉的效果了。

【算法】前缀和算法——和可被K整除的子数组

题解:和可被K整除的子数组(前缀和算法) 目录 1.题目2.前置知识2.1同余定理2.2CPP中‘%’的计算方式与数学‘%’的差异 及其 修正2.3题目思路 3.代码示例4.总结 1.题目 题目链接:LINK 2.前置知识 2.1同余定理 注:这里的‘/’代表的是数学…

Creating Server TCP listening socket *:6379: listen: Unknown error

错误: 解决方法: 在redis安装路径中打开cmd命令行窗口,输入 E:\Redis-x64-3.2.100>redis-server ./redis.windows.conf结果:

动态链接学习总结

背景 之前了解了静态链接的原理,就想着把动态链接的原理也学习一下,提高编程能力。 关键知识点 动态链接的工作原理: 编译时的处理: 当程序被编译时,编译器知道程序需要某些库函数,但并不把这些函数的代…

【C++】C++11(一)

C11是一次里程碑式的更新,我们一起来看一看~ 目录 列表初始化:{ }初始化:std::initializer_list: 声明:auto:decltype: STL的一些变化: 列表初始化: { }初始化&#xf…

学习记录16-反电动势

一、反电动势公式 在负载下反电势和端电压的关系式为:𝑈𝐼𝑅𝐿*(𝑑𝑖 / 𝑑𝑡)𝐸 E为线圈电动势、 𝜓 为磁链、f为频率、N…

博客说明 5/12~5/24【个人】

博客说明 5/12~5/24【个人】 前言版权博客说明 5/12~5/24【个人】对比最后 前言 2024-5-24 13:39:23 对我在2024年5月12日到5月24日发布的博客做一下简要的说明 以下内容源自《【个人】》 仅供学习交流使用 版权 禁止其他平台发布时删除以下此话 本文首次发布于CSDN平台 作…

python水果分类字典构建指南

新书上架~👇全国包邮奥~ python实用小工具开发教程http://pythontoolsteach.com/3 欢迎关注我👆,收藏下次不迷路┗|`O′|┛ 嗷~~ 目录 一、引言 二、理解需求 三、构建字典 1. 数据结构选择 2. 代码实现 3. 结果展示 四、总…

【Sync FIFO介绍及基于Verilog的实现】

Sync FIFO介绍及实现 1 Intro2 Achieve2.1 DFD2.2 Intf2.3 Module 本篇博客介绍无论是编码过程中经常用到的逻辑–FIFO;该FIFO是基于单时钟下的同步FIFO; FiFO分类:同步FiFO VS 异步FiFO; 1 Intro FIFO可以自己实现,但…

使用pygame绘制图形

参考链接:https://www.geeksforgeeks.org/pygame-tutorial/?reflbp 在窗口中绘制单个图形 import pygame from pygame.locals import * import sys pygame.init()window pygame.display.set_mode((600,600)) window.fill((255,255,255))# pygame.draw.rect(wind…