AMP 混合精度训练中的动态缩放机制: grad_scaler.py函数解析( torch._amp_update_scale_)

AMP 混合精度训练中的动态缩放机制

在深度学习中,混合精度训练(AMP, Automatic Mixed Precision)是一种常用的技术,它利用半精度浮点(FP16)计算来加速训练,同时使用单精度浮点(FP32)来保持数值稳定性。为了在混合精度训练中避免数值溢出,PyTorch 提供了一种动态缩放机制来调整 “loss scale”(损失缩放值)。本文将详细解析动态缩放机制的实现原理,并通过代码展示其内部逻辑。


动态缩放机制简介

动态缩放机制的核心思想是通过一个可动态调整的缩放因子(scale factor)放大 FP16 的梯度,从而降低舍入误差对训练的影响。当检测到数值不稳定(例如 NaN 或无穷大)时,缩放因子会被降低;当连续多步未检测到数值问题时,缩放因子会被提高。其调整策略基于以下两个参数:

  • growth_factor: 连续成功步骤后用于增加缩放因子的乘数(通常大于 1,如 2.0)。
  • backoff_factor: 检测到数值溢出时用于减少缩放因子的乘数(通常小于 1,如 0.5)。

此外,动态缩放还使用 growth_interval 参数控制连续成功步骤的计数阈值。当达到这个阈值时,缩放因子才会增加。


AMP 缩放更新核心代码解析

PyTorch 实现了一个用于更新缩放因子的 CUDA 核函数以及相关的 Python 包装函数。以下是核心代码解析:

CUDA 核函数实现

// amp_update_scale_cuda_kernel 核函数实现
__global__ void amp_update_scale_cuda_kernel(float* current_scale,int* growth_tracker,const float* found_inf,double growth_factor,double backoff_factor,int growth_interval) {if (*found_inf) {// 如果发现梯度中存在 NaN 或 Inf,缩放因子乘以 backoff_factor,并重置 growth_tracker。*current_scale = (*current_scale) * backoff_factor;*growth_tracker = 0;} else {// 未发现数值问题,增加 growth_tracker 的计数。auto successful = (*growth_tracker) + 1;if (successful == growth_interval) {// 当 growth_tracker 达到 growth_interval,尝试增长缩放因子。auto new_scale = static_cast<float>((*current_scale) * growth_factor);if (isfinite_ensure_cuda_math(new_scale)) {*current_scale = new_scale;}*growth_tracker = 0;} else {*growth_tracker = successful;}}
}
核函数逻辑
  1. 发现数值溢出(found_inf > 0):

    • 缩放因子 current_scale 乘以 backoff_factor
    • 重置成功计数器 growth_tracker 为 0。
  2. 未发现数值溢出:

    • 增加成功计数器 growth_tracker
    • 如果 growth_tracker 达到 growth_interval,则将缩放因子乘以 growth_factor
    • 保证缩放因子不会超过 FP32 的数值上限。

C++ 包装函数实现

在 PyTorch 中,这一 CUDA 核函数通过 C++ 包装函数 _amp_update_scale_cuda_ 被调用。以下是实现代码:

Tensor& _amp_update_scale_cuda_(Tensor& current_scale,Tensor& growth_tracker,const Tensor& found_inf,double growth_factor,double backoff_factor,int64_t growth_interval) {TORCH_CHECK(growth_tracker.is_cuda(), "growth_tracker must be a CUDA tensor.");TORCH_CHECK(current_scale.is_cuda(), "current_scale must be a CUDA tensor.");TORCH_CHECK(found_inf.is_cuda(), "found_inf must be a CUDA tensor.");// 核函数调用amp_update_scale_cuda_kernel<<<1, 1, 0, at::cuda::getCurrentCUDAStream()>>>(current_scale.mutable_data_ptr<float>(),growth_tracker.mutable_data_ptr<int>(),found_inf.const_data_ptr<float>(),growth_factor,backoff_factor,growth_interval);C10_CUDA_KERNEL_LAUNCH_CHECK();return current_scale;
}

Python 调用入口

AMP 的 GradScaler 类通过 _amp_update_scale_ 函数更新缩放因子,以下是相关代码:
代码来源:anaconda3/envs/xxxx/lib/python3.10/site-packages/torch/amp/grad_scaler.py

具体调用过程可以参考笔者的另一篇博文:PyTorch到C++再到 CUDA 的调用链(C++ ATen 层) :以torch._amp_update_scale_调用为例

def update(self, new_scale: Optional[Union[float, torch.Tensor]] = None) -> None:"""更新缩放因子"""if not self._enabled:return_scale, _growth_tracker = self._check_scale_growth_tracker("update")if new_scale is not None:# 设置用户定义的新缩放因子。self._scale.fill_(new_scale)else:# 收集所有优化器中的 found_inf 数据。found_infs = [found_inf.to(device=_scale.device, non_blocking=True)for state in self._per_optimizer_states.values()for found_inf in state["found_inf_per_device"].values()]found_inf_combined = found_infs[0]if len(found_infs) > 1:for i in range(1, len(found_infs)):found_inf_combined += found_infs[i]# 更新缩放因子。torch._amp_update_scale_(_scale,_growth_tracker,found_inf_combined,self._growth_factor,self._backoff_factor,self._growth_interval,)

总结

PyTorch 的动态缩放机制通过 CUDA 核函数和 Python 包装函数协作完成。其核心逻辑是:

  1. 检测数值不稳定(如 NaN 或 Inf),通过缩小缩放因子提高数值稳定性。
  2. 当连续多次未出现数值不稳定时,逐步增大缩放因子以充分利用 FP16 的动态范围。
  3. 所有更新操作都在 GPU 上异步完成,最大限度地减少同步开销。

通过动态调整缩放因子,AMP 有效地加速了深度学习模型的训练,同时避免了梯度溢出等数值问题。


推荐阅读

  • PyTorch 官方文档
  • 混合精度训练介绍

后记

2025年1月2日15点38分于上海,在GPT4o大模型辅助下完成。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/66404.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

细说STM32F407单片机FSMC连接外部SRAM的方法及HAL驱动

目录 一、FSMC连接外部SRAM的原理 1、 FSMC控制区域的划分 2、SRAM芯片与MCU的连接 二、访问外部SRAM的HAL驱动程序 1、外部SRAM初始化与控制 2、外部SRAM读写函数 3、直接通过指针访问外部SRAM 4、DMA方式读写外部SRAM 本文介绍STM32F407单片机FSMC连接外部SRAM及以轮…

Gitee图形界面上传(详细步骤)

目录 1.软件安装 2.安装顺序 3.创建仓库 4.克隆远程仓库到本地电脑 提交代码的三板斧 1.软件安装 Git - Downloads (git-scm.com) Download – TortoiseGit – Windows Shell Interface to Git 2.安装顺序 1. 首先安装git-2.33.1-64-bit.exe&#xff0c;顺序不能搞错2. …

基于单片机洗衣机控制器的设计(论文+源码)

1需求分析 在智能洗衣机系统设计中&#xff0c;考虑到洗衣机在实际应用过程中&#xff0c;需要满足用户对于不同衣物清洁、消毒的应用要求&#xff0c;对设计功能进行分析&#xff0c;具体如下&#xff1a; 通过按键实现洗衣机不同工作模式的切换&#xff0c;包括标准模式&…

【学Rust开发CAD】2 创建第一个工作空间、项目及库

文章目录 一、 创建工作空间二、新建项目&#xff08;可执行文件&#xff09;三、 新建库&#xff08;库文件&#xff09;四、更新项目依赖五、编写代码七、总结 在 Rust 中&#xff0c;工作空间&#xff08;workspace&#xff09;允许你管理多个相关的包&#xff08;crate&…

STM32的LED点亮教程:使用HAL库与Proteus仿真

学习目标&#xff1a;掌握使用STM32 HAL库点亮LED灯&#xff0c;并通过Proteus进行仿真验证&#xff01; 建立HAL库标准工程 1.新建工程文件夹 新建工程文件夹建议路径尽量为中文。建立文件夹的目的为了更好分类去管理项目工程中需要的各类工程文件。 首先需要在某个位置建立工…

Unity Excel转Json编辑器工具

功能说明&#xff1a;根据 .xlsx 文件生成对应的 JSON 文件&#xff0c;并自动创建脚本 注意事项 Excel 读取依赖 本功能依赖 EPPlus 库&#xff0c;只能读取 .xlsx 文件。请确保将该脚本放置在 Assets 目录下的 Editor 文件夹中。同时&#xff0c;在 Editor 下再创建一个 Exc…

牛客网刷题 ——C语言初阶(6指针)——字符逆序

1. 题目描述&#xff1a;字符逆序 牛客网题目链接 将一个字符串str的内容颠倒过来&#xff0c;并输出。 输入描述: 输入一个字符串&#xff0c;可以有空格 输出描述: 输出逆序的字符串 示例1 输入 I am a student 输出 tneduts a ma I 2. 思路 首先字符串逆序&#xff0c;之…

【USRP】教程:在Macos M1(Apple芯片)上安装UHD驱动(最正确的安装方法)

Apple芯片 前言安装Homebrew安装uhd安装gnuradio使用b200mini安装好的路径下载固件后续启动频谱仪功能启动 gnu radio关于博主 前言 请参考本文进行安装&#xff0c;好多人买了Apple芯片的电脑&#xff0c;这种情况下&#xff0c;可以使用UHD吗&#xff1f;答案是肯定的&#…

141.《mac m系列芯片安装mongodb详细教程》

文章目录 下载从官网下载安装包 下载后双击解压出文件夹安装文件名修改为 mongodb配置data存放位置和日志log的存放位置启动方式一方式二方式二:输入mongo报错以及解决办法 本人电脑 m2 pro,属于 arm 架构 下载 官网地址: mongodb官网 怎么查看自己电脑应该下载哪个版本,输入…

Elasticsearch:基础概念

这里写目录标题 一、什么是Elasticsearch1、基础介绍2、什么是全文检索3、倒排索引4、索引&#xff08;1&#xff09;创建索引a 创建索引基本语法b 只定义索引名&#xff0c;setting、mapping取默认值c 创建一个名为student_index的索引&#xff0c;并设置一些自定义字段 &…

Dexcap复现代码数据预处理全流程(四)——demo_clipping_3d.py

此脚本的主要功能是可视化点云数据文件&#xff08;.pcd 文件&#xff09;&#xff0c;并通过键盘交互选择演示数据的起始帧和结束帧&#xff0c;生成片段标记文件 (clip_marks.json) 主要流程包括&#xff1a; 用户指定数据目录&#xff1a;检查目录是否存在并处理标记文件 -…

安装Cockpit服务,使用Web页面管理你的Linux服务器

说起管理 Linux 服务器&#xff0c;大家首先想到的使用 SecureCRT、Xshell、MobaXterm 等工具远程到服务器&#xff0c;然后使用命令行管理服务器。今天给大家介绍一个好玩的工具&#xff0c;名字叫Cockpit&#xff0c; Cockpit 是一个免费开源的基于 web 的 Linux 服务器管理…

[A-25]ARMv8/v9-GIC的系统架构(中断的硬件基础)

ver0.1 前言 我们在观看很多的影视剧过程中,尤其是军旅体裁类型的布景中,经常会看见高级干部的办公桌上都会有几部电话机。这样的电话可不能小看,重要的事情尤其是突发和紧急的情况都要通过这几部电话第一时间通知给决策者。这几部电话,必须举报几个特点:及时性好、稳定…

13-线段的转折点样式

13-线段的转折点样式_哔哩哔哩_bilibili13-线段的转折点样式是一次性学会 Canvas 动画绘图&#xff08;核心精讲50个案例&#xff09;2023最新教程的第14集视频&#xff0c;该合集共计53集&#xff0c;视频收藏或关注UP主&#xff0c;及时了解更多相关视频内容。https://www.bi…

计算机网络 (28)虚拟专用网VPN

前言 虚拟专用网络&#xff08;VPN&#xff09;是一种在公共网络上建立私有网络连接的技术&#xff0c;它允许远程用户通过加密通道访问内部网络资源&#xff0c;实现远程办公和安全通信。 一、基本概念 定义&#xff1a;VPN是一种通过公共网络&#xff08;如互联网&#xff09…

基于transformer的目标检测:DETR

目录 一、背景介绍 二、DETR的工作流程 三、DETR的架构 1. 损失函数 2. 网络框架讲解及举例 一、背景介绍 在深度学习和计算机视觉领域&#xff0c;目标检测一直是一个核心问题。传统方法依赖于复杂的流程和手工设计的组件&#xff0c;如非极大值抑制&#xff08;nms&…

Vue Amazing UI 组件库(Vue3+TypeScript+Vite 等最新技术栈开发)

Vue Amazing UI 一个 Vue 3 组件库 使用 TypeScript&#xff0c;都是单文件组件 (SFC)&#xff0c;支持 tree shaking 有点意思 English | 中文 Vue Amazing UI 是一个基于 Vue 3、TypeScript、Vite 等最新技术栈开发构建的现代化组件库&#xff0c;包含丰富的 UI 组件和常…

C语言----指针

目录 1.概念 2.格式 3.指针操作符 4.初始化 1. 将普通变量的地址赋值给指针变量 a. 将数组的首地址赋值给指针变量 b. 将指针变量里面保存的地址赋值给另一个指针变量 5.指针运算 5.1算术运算 5.2 关系运算 指针的大小 总结&#xff1a; 段错误 指针修饰 1. con…

Python应用——将Matplotlib图形嵌入Tkinter窗口

Python应用——将Matplotlib图形嵌入Tkinter窗口 目录 Python应用——将Matplotlib图形嵌入Tkinter窗口1 模块简介2 示例代码2.1 Matplotlib嵌入Tkinter2.2 Matplotlib嵌入Tkinter并显示工具栏 1 模块简介 Tkinter是Python的标准GUI&#xff08;图形用户界面&#xff09;库&…

【linux基础I/O(2)】理解文件系统|文件缓冲区|软硬链接|动静态库

目录 前言1. 理解C语言的缓冲区2. 对文件系统的初认识3. 理解软硬链接1. 软硬链接的特征2.软硬链接的作用 4. 理解动静态库5. 总结 前言 对于文件来讲,有打开的在内存中的文件,也有没有打开的在磁盘上文件,上一篇文章讲解的是前者,本篇文章将带大家了解后者! 本章重点: 本篇文…