混合精度训练(MAP)

一、介绍

使用精度低于32位浮点数的数字格式有很多好处。首先,它们需要更少的内存,可以训练和部署更大的神经网络。其次,它们需要更少的内存带宽,这加快了数据传输操作。第三,数学运算在降低精度的情况下运行得更快,特别是在支持Tensor Core的gpu上。混合精确训练实现了所有这些好处,同时确保与完全精确训练相比,没有任务特定的准确性损失。它通过识别需要完全精度的步骤,并仅对这些步骤使用32位浮点数,而在其他地方使用16位浮点数来实现这一点。

在大模型训练场景中,最占用显存的是中间激活值,而混合精度训练方法是采用半精度保存,显存空间直接减半而且还能加速计算; 中间激活值占用显存的直观感觉如下:
在这里插入图片描述

二、混合精度训练

混合精度训练以半精度格式执行操作,同时以单精度存储最小的信息,以尽可能多地保留网络关键部分的信息,从而显著提高了计算速度。自从在Volta和Turing架构中引入Tensor Cores以来,通过切换到混合精度,可以体验到显著的训练速度提升——在大多数算术密集的模型架构上,总体速度提升了3倍。使用混合精度训练需要两个步骤:

1. 移植模型以在适当的地方使用FP16数据类型。
2. 添加损失缩放以保持小的梯度值。

在Pascal架构中引入了以较低精度训练深度学习网络的能力,并在CUDA 8的NVIDIA深度学习SDK中首次得到支持。

混合精度是指在计算方法中组合使用不同的数值精度。

与更高精度的 FP32 相比,半精度(也称为 FP16)数据与 FP64 相比减少了神经网络的内存使用,允许训练和部署更大的网络,并且 FP16 数据传输比 FP32 或 FP64 传输花费的时间更少。

单精度(也称为 32 位)是一种常见的浮点格式( float 在 C 派生的编程语言中),而 64 位则称为双精度 ( double )。深度神经网络 (DNN) 在许多领域取得了突破,包括:

  • 图像处理和理解
  • 语言建模
  • 语言翻译
  • 语音处理
  • 玩游戏等等

为了实现这些结果,DNN 的复杂性一直在增加,这反过来又增加了训练这些网络所需的计算资源。降低所需资源的一种方法是使用精度较低的算术,它具有以下优点:

减少所需的内存量
精度浮点格式 (FP16) 使用 16 位,而单精度 (FP32) 使用 32 位。降低所需的内存可以训练更大的模型或使用更大的小批量进行训练。

缩短训练或推理时间
执行时间可能对内存或算术带宽敏感。半精度将访问的字节数减半,从而减少了在内存受限层中花费的时间。与单精度相比,NVIDIA GPU 的半精度算术吞吐量提高了 8 倍,从而加快了数学受限层的速度。

图 1.bigLSTM 英语语言模型的训练曲线显示了混合精度训练技术的好处。Y 轴是训练损失。不带损耗缩放的混合精度(灰色)在一段时间后会发散,而带损耗缩放的混合精度(绿色)与单精度模型(黑色)匹配
在这里插入图片描述
由于 DNN 训练传统上依赖于 IEEE 单精度格式,因此本指南将重点介绍如何以半精度进行训练,同时保持以单精度实现的网络精度(如图 1 所示)。这种技术称为混合精度训练,因为它同时使用单精度和半精度表示。

2.1 半精度格式

IEEE 754 标准定义了以下 16 位半精度浮点格式:1 个符号位、5 个指数位和 10 个小数位。

2.2 混合训练的流程

2.2.1 拷贝一份FP32的权重
2.2.2 用较大的值初始化缩放因子S.
2.2.3 进入迭代中:

  • a 生成一份FP16的权重
  • b. 前向传递(FP16权重与中间值)
  • c.计算的loss乘以缩放因子S.
  • d. 反向传递 (FP16权重, 中间激活值, 梯度)
  • e. 如果权重梯度中存在无穷大(Inf)或不是一个数字(NaN):
    1. 减小S的值
    2. 跳过权重更新,并进行下一次迭代。
  • f.将权重梯度乘以1/S
  • g.完成权重更新(包括梯度裁剪等操作)
  • h. 如果在最近的N次迭代中没有出现无穷大或不是一个数字的情况,则增加S的值。

在这里插入图片描述

三、混合精度相关问题

  1. 抓住主要矛盾,目的是减少中间激活的显存占用
  2. 在网络训练的后期,梯度值会变得非常小,缩放loss计算得到梯度后,可以用fp32存储,然后进行unscale,避免学习率*unscale *fp16梯度下溢,流程如下(最好是配上scale因子就更完美了,如果不加scale,会存在fp16的gradients存在下溢的可能):

四、PyTorch实现

from apex import amp
model, optimizer = amp.initialize(model, optimizer, opt_level="O1") # 这里是“欧一”,不是“零一”
with amp.scale_loss(loss, optimizer) as scaled_loss:scaled_loss.backward()

https://docs.nvidia.com/deeplearning/performance/mixed-precision-training/index.html

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/240266.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

YOLOv5算法改进(23)— 更换主干网络GhostNet + 添加CA注意力机制 + 引入GhostConv

前言:Hello大家好,我是小哥谈。本节课就让我们结合论文来对YOLOv5进行组合改进(更换主干网络GhostNet + 添加CA注意力机制 + 引入GhostConv),希望同学们学完本节课可以有所启迪,并且后期可以自行进行YOLOv5算法的改进!🌈 前期回顾: YOLOv5算法改进(1)— 如何去…

C++类与对象(中)第一篇

目录 前言: 类的六个默认成员函数 构造函数 析构函数 拷贝构造函数 拷贝场景一:函数参数类型为类类型对象 拷贝场景二:利用已存在的对象创建新对象 拷贝场景三:函数返回值类型为类类型对象 前言: 编译器编译类…

推箱子地图库1-49关

推箱子地图库1-49关 49关 local WALL1--{"墙","墙 "}4 10287 local DEST2--{"目的地",""}1 4001100 10157 local BOX3--{"箱子","¥"} 2 2000801 local PLAYER4--{"玩家","&&a…

influxdb-cluster集群部署

一.部署环境 * InfluxDB集群节点数:mate服务至少3个节点,节点数越多,集群性能越高。 * 操作系统:支持的操作系统包括Linux、Windows和MacOS。 * CPU:至少2核4线程,主频越高越好。 * 内存:至少8…

【已解决】Redis序列化反序列化不一致 - String类型值多了双引号问题

在项目中使用spring 的RedisTemplate从redis中获取数据的时候,发现字符串的value多了双引号。如下图所示: 产生的原因可以分一下几个方面: 一、采用的序列化对象不同 多服务之间调用时候,序列化服务A(向redis中写数据的)和反序…

【翼韵】数据上传沟通、决策、试错

韵达德邦来说说,翼达、翼韵、翼德、翼邦的小记录 翼达同学:沟通成本好大! 翼韵同学:决策成本很大! 翼德同学:试错成本更大! 翼邦同学:你们加起最大! QY成本沟通成本33%决…

Win7如何修改MAC地址

MAC地址,又叫做物理地址、硬件地址,是用来定义网络设备的位置,一般情况下,MAC地址在网卡中是固定的,但不排除有人手动去修改自己的MAC地址。win7如何修改MAC地址?其实修改MAC地址的方法很简单,可以通过硬件…

K8s出现问题时,如何排查解决!

K8s问题的排查 1. POD启动异常、部分节点无法启动pod2. 审视集群状态3. 追踪事件日志4. 聚焦Pod状态5. 检查网络连通性6. 审视存储配置7. 研究容器日志8. K8S集群网络通信9. 问题:Service 是否通过 DNS 工作?10. 总结1、POD启动异常、部分节点无法启动p…

普通Java项目打包可执行Jar

普通Java项目打包 IDEA配置 在项目配置中选择 Artifacts -> JAR -> From modules with dependencies 选择项目模块,程序主类、依赖引入方式、清单文件位置 确认Jar名称和Jar输出目录 通过 Build -> Build Artifact -> Build 打包Jar文件 Java打包可执…

JavaWeb笔记之SVN

一、版本控制 软件开发过程中 变更的管理; 每天的新内容;需要记录一下; 版本分支;整合到一起; 主要的功能对于文件变更的追踪; 多人协同开发的情况下,更好的管理我们的软件。 大型的项目;一个团队来进行开发; 1: 代码的整合 2: 代…

2023-强网杯-【强网先锋-ez_fmt】

文章目录 ez_fmt libc-2.31.so检查main思路exp 参考链接 ez_fmt libc-2.31.so 检查 没有地址随机化 main 简单粗暴的printf格式化字符串漏洞 思路 泄露地址,覆盖返回地址形成ROP链 printf执行时栈上存在__libc_start_main243的指令的地址,可以泄露…

C++哈希表的实现

C哈希表的实现 一.unordered系列容器的介绍二.哈希介绍1.哈希概念2.哈希函数的常见设计3.哈希冲突4.哈希函数的设计原则 三.解决哈希冲突1.闭散列(开放定址法)1.线性探测1.动图演示2.注意事项3.代码的注意事项4.代码实现 2.开散列(哈希桶,拉链法)1.概念2.动图演示3.增容问题1.拉…

MyBatis 架构分析

文章目录 三层架构一、基础支撑层1.1 类型转换模块1.2 日志模块1.3 反射工具模块1.4 Binding 模块1.5 数据源模块1.6 缓存模块1.6 解析器模块1.7 事务管理模块 二、核心处理层2.1 配置解析2.2 SQL 解析与 scripting 模块。2.3 MyBatis 中的 scripting 模块就是负责动态生成 SQL…

SpringCloud Alibaba(itheima)

SpringCloud Alibaba 第一章 微服务介绍1.1系统架构演变1.1.1单体应用架构1.1.2垂直应用架构1.1.3分布式架构1.1.4 SOA架构1.1.5微服务架构 1.2微服务架构介绍1.2.1微服务架构的常见问题1.2.2微服务架构的常见概念1.2.3微服务架构的常见解决方案 1.3 SpringCloud Alibaba介绍1.…

用23种设计模式打造一个cocos creator的游戏框架----(二十二)原型模式

1、模式标准 模式名称:原型模式 模式分类:创建型 模式意图:用原型实例指定创建对象的种类,并且通过复制这些原型创建新的对象 结构图: 适用于: 1、当一个系统应该独立于它的产品创建、构成和表示时 2、…

BUUCTF-Crypto合集-WP

获取CTF工具可关注CSJH网络安全团队,回复CTF工具 一眼就解密 下面的字符串解密后便能获得flag:ZmxhZ3tUSEVfRkxBR19PRl9USElTX1NUUklOR30 注意:得到的 flag 请包上 flag{} 提交 大小写字母加数字,而且等于号结尾,bas…

实在智能斩获钛媒体2023全球创新评选科技类「 大模型创新应用奖」

近日,历时三天的钛媒体2023 T-EDGE全球创新大会以“新视野新链接”为主题在北京隆重举办。作为科创领域全新高度的年度盛事,大会吸引了AI各产业链近百位海内外创投人、尖端企业家、商业领袖和国际嘉宾齐聚一堂,围绕新一轮AI革命、智慧数字化、…

Java中使用JTS实现WKB数据写入、转换字符串、读取

场景 Java中使用JTS实现WKT字符串读取转换线、查找LineString的list中距离最近的线、LineString做缓冲区扩展并计算点在缓冲区内的方位角: Java中使用JTS实现WKT字符串读取转换线、查找LineString的list中距离最近的线、LineString做缓冲区扩展并计算点在缓冲区内…

从Maven初级到高级

一.Maven简介 Maven 是 Apache 软件基金会组织维护的一款专门为 Java 项目提供构建和依赖管理支持的工具。 一个 Maven 工程有约定的目录结构,约定的目录结构对于 Maven 实现自动化构建而言是必不可少的一环,就拿自动编译来说,Maven 必须 能…

python调用DALL·E绘画

实现用gpt的api和他对话后,我们试着调用DALLE的api进行绘画 参考文档 OpenAI API 运行代码 from openai import OpenAIclient OpenAI()user_prompt input("请输入您想生成的图片描述: ")response client.images.generate(model"dall-e-3"…