大语言模型的工程技巧(二)——混合精度训练

相关说明

这篇文章的大部分内容参考自我的新书《解构大语言模型:从线性回归到通用人工智能》,欢迎有兴趣的读者多多支持。
混合精度训练的示例请参考如下链接:regression2chatgpt/ch11_llm/gpt2_lora_optimum.ipynb

本文将讨论如何利用混合精度训练(Mixed Precision Training)来减少内存的开销,特别是GPU内存的开销。这在大语言模型的训练当中是非常重要的。关于GPU的计算可以参考

  • 大语言模型的工程技巧(一)——GPU计算

关于大语言模型的讨论请参考:

  • 理解大语言模型(二)——从零开始实现GPT-2

内容大纲

  • 相关说明
  • 一、概述
  • 二、什么是混合精度训练?
  • 三、算法细节
  • 四、代码实现

一、概述

在人工智能领域,反向传播算法(计算参数梯度的算法)是非常重要的。而在进行反向传播计算时,必须将经过膨胀的计算图存储在内存中(如果使用GPU运算,那么将存储在GPU的专用内存中)。然而,这种存储量相当庞大,在整个计算图的存储结构中,数值存储占据了最大的比例。这些数值包括各个节点的计算结果(来自向前传播的输出),以及相应的梯度(这些梯度是来自反向传播的结果)。虽然梯度累积技术可以通过分解计算图来限制计算图的膨胀,从而降低内存的使用,但面对庞大的模型时,即便是单个数据点的计算图,其所需的内存都是巨大的。例如,大语言模型的参数数量可能高达数十亿甚至上百亿。

二、什么是混合精度训练?

为了解决这个具有挑战性的问题,需要采取额外的优化策略来降低内存的使用。在深入探讨这些策略之前,我们需要更详细地了解数字在计算机中的存储方式。一般而言,数值计算结果使用32位浮点数(需要4字节来存储,使用32位的二进制的方式表示)存储。这种存储方式被称为单精度浮点数。那么,如果使用16位二进制数表示一个数值,会产生什么影响呢?

这种方法的好处之一是能够立即减少所需的存储空间,同时提升计算速度。然而,这种方法也存在一个明显的缺陷,即能够表示的数值范围受限。为了便于讨论,下面以能够表示的最小正数为例。使用16位浮点数,能够表示的最小正数是 2 − 24 2^{-24} 224(相比之下,32位浮点数能够表示的最小正数为 2 − 149 2^{-149} 2149)。当实际的数值小于这个阈值时,计算机会错误地将其视作0,这就是浮点数下溢(Underflow)。

为了尽可能地减少这类错误的发生,可以混合精度训练(Mixed Precision Training)算法,顾名思义,它是指在模型训练过程中使用不同的数值精度来处理不同部分的计算。

三、算法细节

这一算法包含两个主要部分。

  1. 精度分层处理:在这种训练中,模型本身(模型参数)依然使用32位浮点数进行存储,参数更新过程也使用32位浮点数。在模型的向前传播和反向传播过程中,转而使用16位浮点数进行计算。具体情况如图1所示。

图1

图1

  1. 引入比例因子(Scale Factor):在数学上,要防止浮点数下溢是相当容易的,只需要将模型损失乘以一个较大的常数n,该常数也被称为比例因子。根据链式法则,这将导致所有节点的梯度都增大n倍。这种方法确保了梯度落入16位浮点数表示的范围,从而解决浮点数下溢问题。在使用这些梯度进行参数更新时,需要将引入的缩放移除,也就是将梯度除以n。将这个过程与精度分层处理相结合,如图2所示。

图2

图2

混合精度训练方法的优势在于,在保持适当的模型表示能力的同时,显著降低了内存开锁。通过将高精度的32位浮点数与16位浮点数的计算相结合,在不牺牲模型性能的前提下,显著减少内存需求,使计算机能够处理更大规模的模型和数据集。

四、代码实现

在实际应用中,PyTorch已经提供了相应的封装函数,分别是torch.cuda.amp.autocast和torch.cuda.amp.GradScaler。其中autocast实现的是第一部分——精度分层处理;GradScaler实现的是第二部分——引入比例因子。借助这两个工具,在优化算法中使用混合精度训练就变得很容易了。示意代码如下:

# 常规的模型训练实现
for epoch in range(0): for input, target in zip(data, targets):# 启动混合精度训练with torch.autocast(device_type=device, dtype=torch.float16):output = net(input)loss = loss_fn(output, target)# 在触发反向传播之前,启动缩放因子scaler.scale(loss).backward()# 更新模型参数scaler.step(opt)scaler.update()opt.zero_grad()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/16106.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Java语法篇-易错

文章目录 类型转换switch case类之间关系及UMLtry catch finally 类型转换 隐式类型转换,不同数值类型参与计算时,低精度会转化为高精度参与运算 byte,short,char参与整数运算时会转成int float,int 参与浮点数运算时会转成double 强制类型转换 高精…

数据结构 —— 栈 与 队列

1.栈 1.1栈的结构和概念 栈(Stack)是一种特殊的线性数据结构,它遵循后进先出(LIFO,Last In First Out)的原则。栈只允许在一端插入和删除数据,这一端被称为栈顶(top)&a…

c++引用和内联函数

一、引用 1.引用概念 引用不是新定义一个变量,而是给已存在变量取了一个别名,编译器不会为引用变量开辟内存空 间,它和它引用的变量共用同一块内存空间。(引用类型必须和引用实体是同种类型的),如&#x…

MySQL--联合索引应用细节应用规范

目录 一、索引覆盖 1.完全覆盖 2.部分覆盖 3.不覆盖索引-where条件不包含联合索引的最左则不覆盖 二、MySQL8.0在索引中的新特性 1.不可见索引 2.倒序索引 三、索引自优化--索引的索引 四、Change Buffer 五、优化器算法 1.查询优化器算法 2.设置算法 3.索引下推 …

2024年NGFW防火墙安全基准-防火墙安全功效竞争性评估实验室总结报告

Check Point 委托 Miercom 对 Check Point 下一代防火墙 (NGFW) 开展竞争性安全有效性测试, 选择的竞品分别来自 Cisco、Fortinet 和 Palo Alto Networks。对 Zscaler 的测试涉及他们的 SWG(安全网关)。测试内容包括验证防病毒、反恶意软件、…

SpringBoot+Vue开发记录(六)-- 后端配置mybatis

原型图什么的就先不管,后面再写。 本篇文章的主要内容就是springboot通过mybatis操作数据库实现增删改查。 重点是mybatis配置与相关文件数据,以后开新项目忘记了怎么配置的话可以再照着这个搞。 这算是最基础的部分了吧。 文章目录 一,配置…

基于STM32的自动宠物喂食器的Proteus仿真

文章目录 一、宠物喂食器1.题目要求2.思路2.1 OLED显示汉字2.2 DS1302模块2.3 液位传感器2.4 压力传感器和步进电机驱动 3.仿真图3.1 未仿真时3.2 开始仿真,OLED初始界面显示实时时间3.3 通过设置按键进入模式选择和喂食时间设置3.4 进入喂食时间设置3.5 设置好喂食…

计算机毕业设计Python+Spark+PyTroch游戏推荐系统 游戏可视化 游戏爬虫 神经网络混合CF推荐算法 协同过滤推荐算法 steam 大数据

毕业设计(论文) 基于SpringBoot的游戏防沉迷系统的设计与实现 摘 要 随着网络游戏市场的持续火爆,其最明显的负面影响----“网络游戏沉迷问题”已成为当前社会普遍关心的热点问题。根据2010年8月1日实施的《网络游戏管理暂行办法》,网络游…

图书管理系统——Java版

找往期文章包括但不限于本期文章中不懂的知识点: 个人主页:我要学编程(ಥ_ಥ)-CSDN博客 所属专栏:JavaSE 顺序表的学习,点我 目录 图书管理系统菜单 基本框架: 书: 书架: 用户&#xff…

数字化转型必备:营销策划流程图,打造你的数字市场地图

制作营销策划流程图是一个系统化的过程,它可以帮助你清晰地规划和展示营销活动的各个阶段。 以下是制作营销策划流程图的步骤: 1.确定营销目标: 明确你的营销活动旨在实现的具体目标,比如提升品牌知名度、增加销售额、吸引新客…

Java进阶学习笔记25——Objects类

为啥比较两个对象是否相等,要用Objects的equals方法,而不是用对象自己的equals方法来解决呢? Objects: Objects类是一个工具类,提供了很多操作对象的静态方法供我们使用。 package cn.ensource.d14_objects;import ja…

Hadoop概览以及编译hadoop说明

一、Hadoop概述 Hadoop 是一个用于跨计算机集群存储和处理大型数据集的软件框架。它旨在处理大数据,即传统数据库无法有效管理的极其庞大和复杂的数据集。Hadoop不是传统意义上的数据仓库,因为它们的用途不同,架构也不同。Hadoop 是一个跨分布…

Vue2基础及其进阶面试(二)

vue2的生命周期 删除一些没用的 App.vue 删成这个样子就行 <template><router-view/></template><style lang"scss"></style>来到路由把没用的删除 import Vue from vue import VueRouter from vue-router import HomeView from .…

JAVASE之类和对象(2)

哪怕犯错&#xff0c;也不能什么都不做。 主页&#xff1a;趋早–Step 专栏&#xff1a;JAVASE gitte:https://gitee.com/good-thg 接上部分&#xff0c;我们继续来学习JAVAEE类和对象。 引言&#xff1a; 这篇文章接上一篇&#xff0c;后半部分&#xff0c;结束类和对象 目录 …

Spring Boot 3.0:未来企业应用开发的基石

文章目录 一、Spring Boot 3.0的核心特性二、Spring Boot 3.0的优势三、如何在项目中应用Spring Boot 3.01.更新项目依赖2.调整代码结构3.测试和部署 《学习Spring Boot 3.0》内容简介作者简介目录内容介绍 随着技术的飞速发展&#xff0c;企业应用开发的需求也在不断演变。Spr…

爽!AI手绘变插画,接单赚爆了!

我最近发现一款名叫Hyper-SD15-Scribble的AI项目&#xff0c;可以实现一键手绘变插画的功能&#xff0c;而且它搭载了字节出品的超快速生成图片的AI大模型Hyper-SD15&#xff0c;可以实现几乎实时生成图片&#xff0c;有了它&#xff0c;拿去接一些手绘商单分分钟出图&#xff…

跟TED演讲学英文:How to escape education‘s death valley by Sir Ken Robinson

How to escape education’s death valley Link: https://www.ted.com/talks/sir_ken_robinson_how_to_escape_education_s_death_valley Speaker: Sir Ken Robinson Date: April 2013 文章目录 How to escape educations death valleyIntroductionVocabularySummaryTranscri…

WPF学习日常篇(一)--开发界面视图布局

接下来开始日常篇&#xff0c;我在主线篇&#xff08;正文&#xff09;中说过要介绍一下我的界面排布&#xff0c;科学的排布才更科学更有效率的进行敲代码和开发。日常篇中主要记录我的一些小想法和所考虑的一些细节。 一、主界面设置 主界面分为左右两部分&#xff0c;分为…

有什么免费视频翻译软件?安利5款视频翻译软件给你

随着“跨文化交流”话题的热度不断攀升&#xff0c;越来越多的视频内容跨越国界&#xff0c;触及全球观众。 在这一趋势下&#xff0c;视频翻译行业迎来了巨大的发展机遇。然而&#xff0c;面对众多的视频翻译工具&#xff0c;如何挑选出最合心意的那款呢&#xff1f; 现在&a…

【C++】从零开始构建红黑树

送给大家一句话&#xff1a; 日子没劲&#xff0c;就过得特别慢&#xff0c;但凡有那么一点劲&#xff0c;就哗哗的跟瀑布似的拦不住。 – 巫哲 《撒野》 &#x1f30b;&#x1f30b;&#x1f30b;&#x1f30b;&#x1f30b;&#x1f30b;&#x1f30b;&#x1f30b; ⛰️⛰️…