大语言模型的工程技巧(四)——梯度检查点

相关说明

这篇文章的大部分内容参考自我的新书《解构大语言模型:从线性回归到通用人工智能》,欢迎有兴趣的读者多多支持。

本文将讨论如何利用梯度检查点算法来减少模型在训练时候(更准确地说是运行反向传播算法时)的内存开支。这在训练超大规模的模型时会用到。

关于其他的工程技巧可以参考:

  • 大语言模型的工程技巧(一)——GPU计算
  • 大语言模型的工程技巧(二)——混合精度训练
  • 大语言模型的工程技巧(三)——分布式计算

关于大语言模型的讨论请参考:

  • 理解大语言模型(二)——从零开始实现GPT-2

内容大纲

  • 相关说明
  • 一、标准反向传播
  • 二、内存极简算法
  • 三、梯度检查点

一、标准反向传播

根据梯度的定义,变量的梯度与其本身的值密切相关。因此,要想得到某个变量的梯度,必须先知道这个变量的值。这也是为什么在进行反向传播算法之前,需要先对计算图进行向前传播,并记录每个节点的计算结果,如图1左侧部分所示。这样在计算节点的梯度时,可以利用这些事先缓存的结果,直接启动反向传播过程,从而得到梯度,如图1中的节点d所示。这种方法也被称为标准反向传播。这种方式能够确保梯度计算以最高效的方式进行。

图1

图1

二、内存极简算法

然而,采用标准反向传播算法会造成较大的内存开销。为了在计算过程中尽可能地压缩内存使用,可以采用一种以时间换空间的方法。在这种算法中,一旦向前传播完成,仅会保留顶点的计算结果,而中间节点的结果会被清空(叶子节点的值会保留)。在反向传播遇到中间计算节点没有缓存时,则重新触发向前传播,以获取所需节点的结果。这就是内存极简的反向传播算法。以节点d为例,为了计算其梯度,需要首先从节点a开始重新触发向前传播直到节点d,并缓存计算结果。然后使用这个缓存的结果以及节点e的梯度,计算出节点d的梯度。对于其他节点,也采用类似的步骤计算梯度。通过这种方式,在完成反向传播的同时,节省了内存开销。以图1为例,内存极简算法只需要3个存储空间,而标准算法需要5个存储空间。

三、梯度检查点

尽管内存极简算法在降低内存开销方面取得了显著成果,但它涉及大量的重复计算,运行时间相对较长。为了在内存使用和运行时间之间取得平衡,下面引入梯度检查点(Gradient Checkpoint)。这一算法的核心思想是选择一些中间节点作为存储点,以便在再次触发向前传播时,以这些存储点作为起点开始传播,避免从头开始重复计算。这种方式在一定程度上减少重复计算,从而提高运行效率。需要注意的是,由于需要存储额外的中间结果,梯度检查点会稍微增加一些内存开销。

关于梯度检查点算法,PyTorch中已经提供了便捷的封装函数,即torch.utils.checkpoint。这个工具能够帮助我们更方便地应用梯度检查点算法,以平衡内存开锁和运行时间。更多细节请参考这个链接。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/15029.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Python集合与字典的概念与使用-课后作业[python123题库]

集合与字典的概念与使用-课后作业 一、单项选择题 1、S和T是两个集合,哪个选项对S^T的描述是正确的?‪‬‪‬‪‬‪‬‪‬‮‬‪‬‫‬‪‬‪‬‪‬‪‬‪‬‮‬‫‬‪‬‪‬‪‬‪‬‪‬‪‬‮‬‫‬‮‬‪‬‪‬‪‬‪‬‪‬‮‬‪‬‭‬‪‬‪‬‪‬…

机器学习-决策树算法

前言 本篇介绍决策树与随机森林的内容,先完成了决策树的部分。 决策树 决策树(Decision Tree)是一种有监督学习的方法,可以同时解决分类和回归问题,它能够从一系列有特征和标签的数据中总结出决策规则,并用树状图的结构来呈现这…

SecureCRT for Mac注册激活版:专业终端SSH工具

SecureCRT是一款支持SSH(SSH1和SSH2)的终端仿真程序,简单地说是Windows下登录UNIX或Linux服务器主机的软件。 SecureCRT支持SSH,同时支持Telnet和rlogin协议。SecureCRT是一款用于连接运行包括Windows、UNIX和VMS的理想工具。通过…

杭州师范大学物理学院女教授贾婷

主持国家自然科学基金两项,安徽省基金一项。 2003年9月--2007年7月,四川师范大学物理系,学士; 2007年7月--2012年7月,中科院合肥研究院固体物理研究所,博士; 2012年7月—2015年3月&#xff0…

大摩:AI到“临界点”了,资管公司到了广泛部署的时刻

大摩表示,尽管AI技术在资产管理行业中的应用仍处于早期阶段,但其潜力巨大,能够为行业带来根本性的变革。预计生成式AI能够在资产管理公司的运营模型中带来20%至40%的生产力提升。 正文介绍 在全球经济面临诸多不确定因素的当下,…

【全开源】答题考试系统源码(FastAdmin+ThinkPHP+Uniapp)

答题考试系统源码:构建高效、安全的在线考试平台 引言 在当今数字化时代,在线考试系统已成为教育机构和企业选拔人才的重要工具。一个稳定、高效、安全的答题考试系统源码是构建这样平台的核心。本文将深入探讨答题考试系统源码的关键要素,…

大佬大讲堂(1)电机及其驱动内核-自适应观察器

点击上方 “机械电气电机杂谈 ” → 点击右上角“...” → 点选“设为星标 ★”,为加上机械电气电机杂谈星标,以后找夏老师就方便啦!你的星标就是我更新动力,星标越多,更新越快,干货越多! 关注…

Java面试八股之可重入锁ReentrantLock是怎么实现可重入的

可重入锁ReentrantLock是怎么实现可重入的 ReentrantLock实现可重入性的机制主要依赖于以下几个核心组件和步骤: 状态计数器:ReentrantLock内部维护一个名为state的整型变量作为状态计数器,这个计数器不仅用来记录锁是否被持有,…

Java进阶学习笔记9——子类中访问其他成员遵循就近原则

正确访问成员的方法。 在子类方法中访问其他成员(成员变量、成员方法),是依照就近原则的。 F类: package cn.ensource.d13_extends_visit;public class F {String name "父类名字";public void print() {System.out.p…

langchian进阶二:LCEL表达式,轻松进行chain的组装

LangChain表达式语言-LCEL,是一种声明式的方式,可以轻松地将链条组合在一起。 你会在这些情况下使用到LCEL表达式: 流式支持 当你用LCEL构建你的链时,你可以得到最佳的首次到令牌的时间(输出的第一块内容出来之前的时间)。对于一些链&#…

Springboot+Vue项目-基于Java+MySQL的酒店管理系统(附源码+演示视频+LW)

大家好!我是程序猿老A,感谢您阅读本文,欢迎一键三连哦。 💞当前专栏:Java毕业设计 精彩专栏推荐👇🏻👇🏻👇🏻 🎀 Python毕业设计 &…

因说错一句话,京东实习生转正12天被辞退

大家好,我是程序员小灰。 就在昨天,我的老东家京东被爆出了一个大瓜,有一个刚刚加入京东的应届生,仅仅转正12天就被辞退了。 是谁这么倒霉?为什么被辞退呢? 事情的起因是京东新近推出的考评制度。根据京东的…

Spring Boot集成Banner快速入门demo

1.banner介绍 Spring Boot Banner 是一个用于在应用程序启动时显示自定义 ASCII 艺术和信息的功能。这个 ASCII 艺术通常包括项目名称、版本号、作者信息等。Banner 的主要作用是增强应用程序的品牌标识,同时提供一种友好的欢迎方式,让用户或开发人员在…

手撕算法|斯坦福大学教授用60页PPT搞定了八大神经网络

人工智能领域深度学习的八大神经网络常见的是以下几种 1.卷积神经网络(CNN): 卷积神经网络是用于图像和空间数据处理的神经网络,通过卷积层和池化层来捕捉图像的局部特征,广泛应用于图像分类、物体检测等领域。 2.循…

blender 布尔运算,切割模型。

1.创建一个立方体和球体。 2.选中立方体,在属性面板添加布尔修改器。点击物体属性右边的按钮选中球体。参数如下。 3.此时隐藏球体,就可以看到被切掉的效果了。

【算法】前缀和算法——和可被K整除的子数组

题解:和可被K整除的子数组(前缀和算法) 目录 1.题目2.前置知识2.1同余定理2.2CPP中‘%’的计算方式与数学‘%’的差异 及其 修正2.3题目思路 3.代码示例4.总结 1.题目 题目链接:LINK 2.前置知识 2.1同余定理 注:这里的‘/’代表的是数学…

MySQL云数据库5.5导入到自建MySQL数据库5.7补充

之前写过一篇关于MySQL云数据库5.5导入到自建MySQL数据库5.7的文章。 最近又搞了一次MySQL5.5导入5.7,发现有些细节需要补充。 这里mysql都是linxu系统的服务器。 1.使用mysqldump备份MySQL5.5再使用mysql导入mysql会报错。 解决方法:分两步走。 先使用mysqldu…

Creating Server TCP listening socket *:6379: listen: Unknown error

错误: 解决方法: 在redis安装路径中打开cmd命令行窗口,输入 E:\Redis-x64-3.2.100>redis-server ./redis.windows.conf结果:

动态链接学习总结

背景 之前了解了静态链接的原理,就想着把动态链接的原理也学习一下,提高编程能力。 关键知识点 动态链接的工作原理: 编译时的处理: 当程序被编译时,编译器知道程序需要某些库函数,但并不把这些函数的代…

JAVA学习-练习试用Java实现“螺旋矩阵”

问题: 给你一个 m 行 n 列的矩阵 matrix ,请按照 顺时针螺旋顺序 ,返回矩阵中的所有元素。 示例 1: 输入:matrix [[1,2,3],[4,5,6],[7,8,9]] 输出:[1,2,3,6,9,8,7,4,5] 示例 2: 输入&#…