PyTorch 自动混合精度AMP Grad Scaler 源码解析:_unscale_grads_ 与 unscale_ 函数

PyTorch AMP Grad Scaler 源码解析:_unscale_grads_ 与 unscale_ 函数

引言

本文详细解析 PyTorch 自动混合精度(AMP)模块中 grad_scaler.py 文件的两个关键函数:_unscale_grads_unscale_。这些函数在梯度缩放与反缩放过程中起到了关键作用,特别适用于训练大规模深度学习模型时的数值稳定性优化。我们还将给出详细的示例与数值模拟,帮助理解其具体应用。


1. _unscale_grads_ 函数解析

def _unscale_grads_(self,optimizer: torch.optim.Optimizer,inv_scale: torch.Tensor,found_inf: torch.Tensor,allow_fp16: bool,) -> Dict[torch.device, torch.Tensor]:per_device_inv_scale = _MultiDeviceReplicator(inv_scale)per_device_found_inf = _MultiDeviceReplicator(found_inf)# To set up _amp_foreach_non_finite_check_and_unscale_, split grads by device and dtype.# There could be hundreds of grads, so we'd like to iterate through them just once.# However, we don't know their devices or dtypes in advance.# https://stackoverflow.com/questions/5029934/defaultdict-of-defaultdict# Google says mypy struggles with defaultdicts type annotations.per_device_and_dtype_grads: Dict[torch.device, Dict[torch.dtype, List[torch.Tensor]]] = defaultdict(lambda: defaultdict(list))with torch.no_grad():for group in optimizer.param_groups:for param in group["params"]:assert isinstance(param, torch.Tensor)if param.grad is None:continueif (not allow_fp16) and param.grad.dtype == torch.float16:raise ValueError("Attempting to unscale FP16 gradients.")if param.grad.is_sparse:# is_coalesced() == False means the sparse grad has values with duplicate indices.# coalesce() deduplicates indices and adds all values that have the same index.# For scaled fp16 values, there's a good chance coalescing will cause overflow,# so we should check the coalesced _values().if param.grad.dtype is torch.float16:param.grad = param.grad.coalesce()to_unscale = param.grad._values()else:to_unscale = param.grad# TODO: is there a way to split by device and dtype without appending in the inner loop?per_device_and_dtype_grads[to_unscale.device][to_unscale.dtype].append(to_unscale)for device, per_dtype_grads in per_device_and_dtype_grads.items():for grads in per_dtype_grads.values():torch._amp_foreach_non_finite_check_and_unscale_(grads,per_device_found_inf.get(device),per_device_inv_scale.get(device),)return per_device_found_inf._per_device_tensors

1.1 函数定义

def _unscale_grads_(self,optimizer: torch.optim.Optimizer,inv_scale: torch.Tensor,found_inf: torch.Tensor,allow_fp16: bool,) -> Dict[torch.device, torch.Tensor]:

该函数主要用于将梯度从缩放状态恢复到原始大小,同时检查是否存在数值溢出情况。

1.2 参数说明

  • optimizer:优化器对象,包含训练过程中使用的所有参数。
  • inv_scale:缩放因子的倒数,用于恢复梯度。
  • found_inf:用于记录是否存在无穷大或 NaN 值。
  • allow_fp16:是否允许 FP16 精度的梯度反缩放,默认设置为 False。

1.3 核心实现步骤

  1. 按设备与数据类型分类梯度:

    • 将优化器中的参数按设备和数据类型进行分组,便于批量处理。
    • 使用 defaultdict 对分组存储。
  2. 检查梯度并分类:

    • 遍历每个参数,如果存在稀疏梯度,使用 coalesce() 消除重复索引。关于这个方法, 可以参考笔者的另一篇博客:PyTorch 中 coalesce() 函数详解与应用示例
    • 将梯度分组存储到 per_device_and_dtype_grads 中。
  3. 调用 PyTorch 内部函数反缩放梯度:

    • 使用 torch._amp_foreach_non_finite_check_and_unscale_() 批量反缩放梯度并检查是否存在 NaN 或无穷大值。 这个具体解析请参考笔者的另一篇博客:PyTorch源码_amp_foreach_non_finite_check_and_unscale_cpu_kernel 函数解析:自动混合精度AMP的一部分
  4. 返回各设备上的溢出检查结果:

    • 输出包含各设备是否发现溢出的布尔值张量。

1.4 关键代码片段

with torch.no_grad():for group in optimizer.param_groups:for param in group["params"]:if param.grad is None:continueif (not allow_fp16) and param.grad.dtype == torch.float16:raise ValueError("Attempting to unscale FP16 gradients.")to_unscale = param.grad._values() if param.grad.is_sparse else param.gradper_device_and_dtype_grads[to_unscale.device][to_unscale.dtype].append(to_unscale)for device, per_dtype_grads in per_device_and_dtype_grads.items():for grads in per_dtype_grads.values():torch._amp_foreach_non_finite_check_and_unscale_(grads, per_device_found_inf.get(device), per_device_inv_scale.get(device))

2. unscale_ 函数解析

def unscale_(self, optimizer: torch.optim.Optimizer) -> None:"""Divides ("unscales") the optimizer's gradient tensors by the scale factor.:meth:`unscale_` is optional, serving cases where you need to:ref:`modify or inspect gradients<working-with-unscaled-gradients>`between the backward pass(es) and :meth:`step`.If :meth:`unscale_` is not called explicitly,  gradients will be unscaled  automatically during :meth:`step`.Simple example, using :meth:`unscale_` to enable clipping of unscaled gradients::...scaler.scale(loss).backward()scaler.unscale_(optimizer)torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)scaler.step(optimizer)scaler.update()Args:optimizer (torch.optim.Optimizer):  Optimizer that owns the gradients to be unscaled... note:::meth:`unscale_` does not incur a CPU-GPU sync... warning:::meth:`unscale_` should only be called once per optimizer per :meth:`step` call,and only after all gradients for that optimizer's assigned parameters have been accumulated.Calling :meth:`unscale_` twice for a given optimizer between each :meth:`step` triggers a RuntimeError... warning:::meth:`unscale_` may unscale sparse gradients out of place, replacing the ``.grad`` attribute."""if not self._enabled:returnself._check_scale_growth_tracker("unscale_")optimizer_state = self._per_optimizer_states[id(optimizer)]if optimizer_state["stage"] is OptState.UNSCALED:raise RuntimeError("unscale_() has already been called on this optimizer since the last update().")elif optimizer_state["stage"] is OptState.STEPPED:raise RuntimeError("unscale_() is being called after step().")# FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64.assert self._scale is not Noneinv_scale = self._scale.double().reciprocal().float()found_inf = torch.full((), 0.0, dtype=torch.float32, device=self._scale.device)optimizer_state["found_inf_per_device"] = self._unscale_grads_(optimizer, inv_scale, found_inf, False)optimizer_state["stage"] = OptState.UNSCALED

2.1 函数定义

def unscale_(self, optimizer: torch.optim.Optimizer) -> None:

该函数是 PyTorch AMP 提供的外部接口,供用户调用以解除梯度缩放。

2.2 参数说明

  • optimizer:包含所有待训练参数的优化器对象。

2.3 核心实现步骤

  1. 状态检查:
    • 检查是否已经调用过 unscale_step
  2. 计算反缩放因子:
    • 使用 FP64 精度计算缩放因子的倒数,以避免精度误差。reciprocal这是取倒数的函数,具体可以参考笔者的另一篇博客:PyTorch 中 reciprocal(取倒数)函数的深入解析:分析底层实现CPP代码
  3. 调用内部函数 _unscale_grads_
    • 执行反缩放过程,包含稀疏梯度与 NaN 检查。
  4. 更新状态记录:
    • 将优化器状态更新为 “UNSCALED”。

2.4 关键代码片段

if optimizer_state["stage"] is OptState.UNSCALED:raise RuntimeError("unscale_() has already been called on this optimizer since the last update().")inv_scale = self._scale.double().reciprocal().float()
found_inf = torch.full((), 0.0, dtype=torch.float32, device=self._scale.device)optimizer_state["found_inf_per_device"] = self._unscale_grads_(optimizer, inv_scale, found_inf, False
)
optimizer_state["stage"] = OptState.UNSCALED

3. 使用示例与数值模拟

3.1 示例代码

import torch
from torch.cuda.amp import GradScaler, autocast# 创建模型和优化器
model = torch.nn.Linear(10, 1).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
scaler = GradScaler()# 模拟训练循环
for epoch in range(2):for step in range(5):data = torch.randn(16, 10).cuda()target = torch.randn(16, 1).cuda()optimizer.zero_grad()# 使用混合精度训练with autocast():output = model(data)loss = torch.nn.functional.mse_loss(output, target)# 缩放梯度scaler.scale(loss).backward()# 手动解除梯度缩放scaler.unscale_(optimizer)# 使用梯度裁剪torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)# 更新权重与缩放器scaler.step(optimizer)scaler.update()print(f"Epoch {epoch}, Step {step}, Loss: {loss.item()}")

3.2 数值模拟分析

  1. 梯度缩放影响:
    缩放因子 = 65536 时,梯度放大至 10^4 量级,有助于 FP16 避免下溢问题。
  2. 反缩放结果验证:
    对比反缩放前后的梯度值,可观察到恢复精度并避免溢出错误。
  3. 梯度裁剪测试:
    执行 torch.nn.utils.clip_grad_norm_(),确认反缩放后的梯度值能够被安全裁剪。

4. 注意事项与总结

  1. 注意 API 使用顺序:
    调用 unscale_ 应在反向传播完成后、优化器更新前进行。
  2. 防止重复调用:
    多次调用可能导致状态不一致,应确保每轮训练仅调用一次。
  3. 稀疏梯度支持:
    自动处理稀疏梯度的特殊情况,避免溢出。

这两个函数是 AMP 核心模块,提供了稳定高效的混合精度训练支持。通过示例与数值分析,开发者可以更好地理解 AMP 工作原理并优化深度学习模型训练过程。


后记

2025年1月2日18点49分于上海,在GPT4o大模型辅助下完成。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/65998.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

docker内外如何实现ROS通信

写在前面 在一台电脑上装有docker&#xff0c;docker内外均装有ROS系统&#xff0c;现在想要实现docker内外的ROS通信&#xff0c;怎么办呢&#xff1f; 首先&#xff0c;因为是同一台电脑的docker内外&#xff0c;所以IP本身是互通的&#xff0c;不需要在/etc/hosts中添加IP…

双指针与滑动窗口

双指针 相向双指针 两数之和 题意是找到不同两个数使得它们相加和为target&#xff0c;数组有序 利用数组有序的性质&#xff0c;判断指针前后的区间的性质 例如&#xff1a;2 3 4 6 8, target 9 2 8 10 > 9, 因为非递减序列&#xff0c;2之后的每个数都会大等于2&…

unity开发之shader 管道介质流动特效

效果 shader graph 如果出现下面的效果&#xff0c;那是因为你模型的问题&#xff0c;建模做贴图的时候没有设置好UV映射&#xff0c;只需重新设置下映射即可

python +tkinter绘制彩虹和云朵

python tkinter绘制彩虹和云朵 彩虹&#xff0c;简称虹&#xff0c;是气象中的一种光学现象&#xff0c;当太阳光照射到半空中的水滴&#xff0c;光线被折射及反射&#xff0c;在天空上形成拱形的七彩光谱&#xff0c;由外圈至内圈呈红、橙、黄、绿、蓝、靛、紫七种颜色。事实…

stable diffusion安装mov2mov

第一步&#xff1a; 下载mov2mov&#xff0c;地址&#xff1a;https://gitcode.com/gh_mirrors/sd/sd-webui-mov2mov 下载包到web-ui的sd-webui-aki-v4.10\extensions文件夹面解压 第二步&#xff1a;在文件夹中调出cmd窗口&#xff0c;执行下列命令&#xff0c; git restore…

SpringSpringBoot常用注解总结

目录 1. SpringBootApplication 2. Spring Bean 相关 2.1. Autowired 2.2. Component,Repository,Service, Controller 2.3. RestController 2.4. Scope 2.5. Configuration 3. 处理常见的 HTTP 请求类型 3.1. GET 请求 3.2. POST 请求 3.3. PUT 请求 3.4. DELETE 请…

STM32 软件I2C读写

单片机学习&#xff01; 目录 前言 一、软件I2C读写代码框架 二、I2C初始化 三、六个时序基本单元 3.1 引脚操作的封装和改名 3.2 起始条件执行逻辑 3.3 终止条件执行逻辑 3.4 发送一个字节 3.5 接收一个字节 3.5 发送应答&接收应答 3.5.1 发送应答 3.5.2 接…

七种改进爬山算法的方法

一、爬山算法 爬山算法(Hill Climbing Algorithm)是一种启发式的基于局部最优解的搜索算法,用于在给定的搜索空间中寻找全局最优解或足够好的解。它属于局部搜索算法,通常用于解决优化问题,包括连续和离散问题。 爬山算法模拟了爬山的过程,从某个随机起始点开始,不断向更…

MYSQL--------MYSQL中的运算符

以下是 MySQL 中各种运算符的介绍及代码示例&#xff1a; 算术运算符 算术运算符用于执行基本的数学运算&#xff0c;包括加、减、乘、除、取模&#xff08;取余&#xff09;。 -- 创建一个名为 operator_demo 的表 CREATE TABLE operator_demo (a INT,b INT );-- 插入示例数…

MySQL图形化界面工具--DataGrip

之前介绍了在命令行进行操作&#xff0c;但是不够直观&#xff0c;本次介绍图形化界面工具–DataGrip。 安装DataGrip 官网链接&#xff1a;官网下载链接 常规的软件安装流程。 参考链接&#xff1a;DataGrip安装 使用DataGrip 添加数据源&#xff1a; 第一次使用最下面会…

【虚拟机】VMware 16图文安装和配置 AlmaLinux OS 9.5 教程

准备工作 下载AlmaLinux ISO文件&#xff1a;从AlmaLinux官方网站&#xff08;https://almalinux.org/&#xff09;下载最新版本的ISO文件。 安装VMware Workstation&#xff1a;确保您的计算机上已安装VMware Workstation。&#xff08;注&#xff1a;我这边使用的是VMware16…

中国联通首次推出一套量化大模型的新标准

新基准的诞生 中国联通的研究团队近日公布了一套创新性的量化标准&#xff0c;主要针对大型语言模型的能力评估。这一基准的灵感来源于动物智能演化的规律&#xff0c;为用户在选择语言模型时提供了科学依据。现代社会中&#xff0c;各种语言模型如雨后春笋般涌现&#xff0c;…

aardio —— 虚表 —— 使用ownerDrawCustom列类型制作喜马拉雅播放器列表

不会自绘也能做漂亮列表&#xff0c;你相信吗&#xff1f; 看看这个例子&#xff0c;虚表_vlistEx_ColType_OwnerDrawCustom列类型&#xff0c;移植自godking.customPlus&#xff0c;简单好用&#xff0c;做漂亮列表的大杀器&#xff0c;玩aardio必备利器&#xff01; 请更新…

网安数学基础期末复习

目录 整除同余同余方程群和环 整除 a的显然因数/平凡因数1&#xff0c;a整除的传递性和组合性 若 a ∣ b , b ∣ a a|b,b|a a∣b,b∣a 则 a b a\pm b ab欧几里得带余除法 公因数和最大公因数在整除里的定义&#xff0c;最大公因数为1则两数互质&#xff0c;注意公因数有正…

【论文阅读笔记】SCI算法与代码 | 低照度图像增强 | 2022.4.21

目录 一 SCI 1 SCI网络结构 核心代码&#xff08;model.py&#xff09; 2 SCI损失函数 核心代码&#xff08;loss.py&#xff09; 3 实验 二 SCI效果 1 下载代码 2 运行 一 SCI &#x1f49c;论文题目&#xff1a;Toward Fast, Flexible, and Robust Low-Light Image …

AcWing练习题:平均数2

读取三个浮点数 A&#xff0c;B 和 C 的值&#xff0c;对应于三个学生的成绩。 请你计算学生的平均分&#xff0c;其中 A 的成绩的权重为 2&#xff0c;B 的成绩的权重为 3&#xff0c;C 的成绩的权值为 5。 成绩的取值范围在 0 到 10 之间&#xff0c;且均保留一位小数。 输…

aardio —— 改变按钮文本颜色

import win.ui; /*DSG{{*/ var winform win.form(text"改变按钮颜色示例";right279;bottom239;composited1) winform.add( button{cls"button";text"点这里1";left16;top104;right261;bottom159;fontLOGFONT(h-14);z1}; button2{cls"butto…

Scratch教学作品 | 白水急流——急流勇进,挑战反应极限! ‍♂️

今天为大家推荐一款刺激又好玩的Scratch冒险作品——《白水急流》&#xff01;由AgentFransidium制作&#xff0c;这款作品将带你体验惊险的急流救援任务&#xff0c;帮助那位“睡着的疯狂人”安全穿越湍急水域&#xff01;想要挑战自己的反应极限&#xff1f;快来试试吧&#…

Android测试ABD环境及语句

1、什么是adb ADB 全称为 Android Debug Bridge&#xff0c;起到调试桥的作用&#xff0c;是一个客户端-服务器端程序。其中客户端是用来操作的电脑&#xff0c;服务端是 Android 设备。 ADB 也是 Android SDK 中的一个工具&#xff0c;可以直接操作管理 Android 模拟器或者真…

库伦值自动化功耗测试工具

1. 功能介绍 PlatformPower工具可以自动化测试不同场景的功耗电流&#xff0c;并可导出为excel文件便于测试结果分析查看。测试同时便于后续根据需求拓展其他自动化测试用例。 主要原理&#xff1a;基于文件节点 coulomb_count 实现&#xff0c;计算公式&#xff1a;电流&…