PyTorch AMP 混合精度中grad_scaler.py的scale函数解析

PyTorch AMP 混合精度中的 scale 函数解析

混合精度训练(AMP, Automatic Mixed Precision)是深度学习中常用的技术,用于提升训练效率并减少显存占用。在 PyTorch 的 AMP 模块中,GradScaler 类负责动态调整和管理损失缩放因子,以解决 FP16 运算中的数值精度问题。而 scale 函数是 GradScaler 的一个重要方法,用于将输出的张量按当前缩放因子进行缩放。

本文将详细解析 scale 函数的作用、代码逻辑,以及 apply_scale 子函数的递归作用。


函数代码回顾

以下是 scale 函数的完整代码:
Source: anaconda3/envs/xxx/lib/python3.10/site-packages/torch/amp/grad_scaler.py

torch 2.4.0+cu121版本

def scale(self,outputs: Union[torch.Tensor, Iterable[torch.Tensor]],
) -> Union[torch.Tensor, Iterable[torch.Tensor]]:"""Multiplies ('scales') a tensor or list of tensors by the scale factor.Returns scaled outputs.  If this instance of :class:`GradScaler` is not enabled, outputs are returnedunmodified.Args:outputs (Tensor or iterable of Tensors):  Outputs to scale."""if not self._enabled:return outputs# Short-circuit for the common case.if isinstance(outputs, torch.Tensor):if self._scale is None:self._lazy_init_scale_growth_tracker(outputs.device)assert self._scale is not Nonereturn outputs * self._scale.to(device=outputs.device, non_blocking=True)# Invoke the more complex machinery only if we're treating multiple outputs.stash: List[_MultiDeviceReplicator] = []  # holds a reference that can be overwritten by apply_scaledef apply_scale(val: Union[torch.Tensor, Iterable[torch.Tensor]]):if isinstance(val, torch.Tensor):if len(stash) == 0:if self._scale is None:self._lazy_init_scale_growth_tracker(val.device)assert self._scale is not Nonestash.append(_MultiDeviceReplicator(self._scale))return val * stash[0].get(val.device)if isinstance(val, abc.Iterable):iterable = map(apply_scale, val)if isinstance(val, (list, tuple)):return type(val)(iterable)return iterableraise ValueError("outputs must be a Tensor or an iterable of Tensors")return apply_scale(outputs)

1. 函数作用

scale 函数的主要作用是将输出张量(outputs)按当前的缩放因子(self._scale)进行缩放。它支持以下两种输入:

  1. 单个张量:直接将缩放因子乘以张量。
  2. 张量的可迭代对象(如列表或元组):递归地对每个张量进行缩放。

当 AMP 功能未启用时(即 self._enabledFalse),scale 函数会直接返回原始的 outputs,不执行任何缩放操作。

使用场景

  • 放大梯度:在反向传播之前,放大输出张量的数值,以减少数值舍入误差对 FP16 计算的影响。
  • 支持多设备:通过 _MultiDeviceReplicator 支持张量分布在多个设备(如多 GPU)的场景。

2. 核心代码解析

(1) 短路处理单个张量

当输入 outputs 是单个张量(torch.Tensor)时,函数直接对其进行缩放:

if isinstance(outputs, torch.Tensor):if self._scale is None:self._lazy_init_scale_growth_tracker(outputs.device)assert self._scale is not Nonereturn outputs * self._scale.to(device=outputs.device, non_blocking=True)
逻辑解析:
  1. 如果缩放因子 self._scale 尚未初始化,则调用 _lazy_init_scale_growth_tracker 方法在指定设备上初始化缩放因子。
  2. 使用 outputs * self._scale 对张量进行缩放。这里使用了 to(device=outputs.device) 确保缩放因子与张量在同一设备上。

这是单个张量输入的快速路径处理。


(2) 多张量递归处理逻辑

当输入为张量的可迭代对象(如列表或元组)时,函数调用子函数 apply_scale 进行递归缩放:

stash: List[_MultiDeviceReplicator] = []  # 用于存储缩放因子对象def apply_scale(val: Union[torch.Tensor, Iterable[torch.Tensor]]):if isinstance(val, torch.Tensor):if len(stash) == 0:if self._scale is None:self._lazy_init_scale_growth_tracker(val.device)assert self._scale is not Nonestash.append(_MultiDeviceReplicator(self._scale))return val * stash[0].get(val.device)if isinstance(val, abc.Iterable):iterable = map(apply_scale, val)if isinstance(val, (list, tuple)):return type(val)(iterable)return iterableraise ValueError("outputs must be a Tensor or an iterable of Tensors")return apply_scale(outputs)
apply_scale 子函数的作用
  1. 张量处理

    • 如果 val 是单个张量,检查 stash 是否为空。
    • 如果为空,初始化缩放因子对象 _MultiDeviceReplicator,并存储在 stash 中。
    • 使用 stash[0].get(val.device) 获取对应设备上的缩放因子,并对张量进行缩放。
  2. 递归处理可迭代对象

    • 如果 val 是一个可迭代对象,调用 map(apply_scale, val),对其中的每个元素递归地调用 apply_scale
    • 如果输入是 listtuple,则保持其原始类型。
  3. 类型检查

    • 如果 val 既不是张量也不是可迭代对象,抛出错误。

3. apply_scale 是递归函数吗?

是的,apply_scale 是一个递归函数。

递归逻辑

  • 当输入为嵌套结构(如张量的列表或列表中的列表)时,apply_scale 会递归调用自身,将缩放因子应用到最底层的张量。
  • 递归的终止条件是 val 为单个张量(torch.Tensor)。
示例:

假设输入为嵌套张量列表:

outputs = [torch.tensor([1.0, 2.0]), [torch.tensor([3.0]), torch.tensor([4.0, 5.0])]]
scaled_outputs = scaler.scale(outputs)

递归处理过程如下:

  1. outputs 调用 apply_scale

    • 第一个元素是张量 torch.tensor([1.0, 2.0]),直接缩放。
    • 第二个元素是列表,递归调用 apply_scale
  2. 进入嵌套列表 [torch.tensor([3.0]), torch.tensor([4.0, 5.0])]

    • 第一个元素是张量 torch.tensor([3.0]),缩放。
    • 第二个元素是张量 torch.tensor([4.0, 5.0]),缩放。

4. _MultiDeviceReplicator 的作用

_MultiDeviceReplicator 是一个工具类,用于在多设备场景下管理缩放因子对象的复用。它根据张量所在的设备返回正确的缩放因子。

  • 当张量分布在多个设备(如 GPU)时,_MultiDeviceReplicator 可以高效地为每个设备提供所需的缩放因子,避免重复初始化。

总结

scale 函数是 AMP 混合精度训练中用于梯度缩放的重要方法,其作用是将输出张量按当前缩放因子进行缩放。通过递归函数 apply_scale,该函数能够处理嵌套的张量结构,同时支持多设备场景。

关键点总结:

  1. 快速路径:单张量输入的情况下,直接进行缩放。
  2. 递归处理:对于张量的嵌套结构,递归地对每个张量进行缩放。
  3. 设备管理:通过 _MultiDeviceReplicator 支持多设备场景。

通过 scale 函数,PyTorch 的 AMP 模块能够高效地调整梯度数值范围,提升混合精度训练的稳定性和效率。

后记

2025年1月2日15点47分于上海,在GPT4o大模型辅助下完成。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/65985.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

MySQL图形化界面工具--DataGrip

之前介绍了在命令行进行操作,但是不够直观,本次介绍图形化界面工具–DataGrip。 安装DataGrip 官网链接:官网下载链接 常规的软件安装流程。 参考链接:DataGrip安装 使用DataGrip 添加数据源: 第一次使用最下面会…

【虚拟机】VMware 16图文安装和配置 AlmaLinux OS 9.5 教程

准备工作 下载AlmaLinux ISO文件:从AlmaLinux官方网站(https://almalinux.org/)下载最新版本的ISO文件。 安装VMware Workstation:确保您的计算机上已安装VMware Workstation。(注:我这边使用的是VMware16…

中国联通首次推出一套量化大模型的新标准

新基准的诞生 中国联通的研究团队近日公布了一套创新性的量化标准,主要针对大型语言模型的能力评估。这一基准的灵感来源于动物智能演化的规律,为用户在选择语言模型时提供了科学依据。现代社会中,各种语言模型如雨后春笋般涌现,…

aardio —— 虚表 —— 使用ownerDrawCustom列类型制作喜马拉雅播放器列表

不会自绘也能做漂亮列表,你相信吗? 看看这个例子,虚表_vlistEx_ColType_OwnerDrawCustom列类型,移植自godking.customPlus,简单好用,做漂亮列表的大杀器,玩aardio必备利器! 请更新…

网安数学基础期末复习

目录 整除同余同余方程群和环 整除 a的显然因数/平凡因数1,a整除的传递性和组合性 若 a ∣ b , b ∣ a a|b,b|a a∣b,b∣a 则 a b a\pm b ab欧几里得带余除法 公因数和最大公因数在整除里的定义,最大公因数为1则两数互质,注意公因数有正…

【论文阅读笔记】SCI算法与代码 | 低照度图像增强 | 2022.4.21

目录 一 SCI 1 SCI网络结构 核心代码(model.py) 2 SCI损失函数 核心代码(loss.py) 3 实验 二 SCI效果 1 下载代码 2 运行 一 SCI 💜论文题目:Toward Fast, Flexible, and Robust Low-Light Image …

AcWing练习题:平均数2

读取三个浮点数 A,B 和 C 的值,对应于三个学生的成绩。 请你计算学生的平均分,其中 A 的成绩的权重为 2,B 的成绩的权重为 3,C 的成绩的权值为 5。 成绩的取值范围在 0 到 10 之间,且均保留一位小数。 输…

aardio —— 改变按钮文本颜色

import win.ui; /*DSG{{*/ var winform win.form(text"改变按钮颜色示例";right279;bottom239;composited1) winform.add( button{cls"button";text"点这里1";left16;top104;right261;bottom159;fontLOGFONT(h-14);z1}; button2{cls"butto…

Scratch教学作品 | 白水急流——急流勇进,挑战反应极限! ‍♂️

今天为大家推荐一款刺激又好玩的Scratch冒险作品——《白水急流》!由AgentFransidium制作,这款作品将带你体验惊险的急流救援任务,帮助那位“睡着的疯狂人”安全穿越湍急水域!想要挑战自己的反应极限?快来试试吧&#…

Android测试ABD环境及语句

1、什么是adb ADB 全称为 Android Debug Bridge,起到调试桥的作用,是一个客户端-服务器端程序。其中客户端是用来操作的电脑,服务端是 Android 设备。 ADB 也是 Android SDK 中的一个工具,可以直接操作管理 Android 模拟器或者真…

库伦值自动化功耗测试工具

1. 功能介绍 PlatformPower工具可以自动化测试不同场景的功耗电流,并可导出为excel文件便于测试结果分析查看。测试同时便于后续根据需求拓展其他自动化测试用例。 主要原理:基于文件节点 coulomb_count 实现,计算公式:电流&…

creating-custom-commands-in-flask

在烧瓶中创建自定义命令 原文:https://www . geesforgeks . org/creating-custom-commands-in-flask/ 本文围绕如何在 flask 中创建自定义命令展开。每次使用烧瓶运行运行烧瓶时,运行实际上是一个命令,在烧瓶配置文件中启动一个名为运行的函数。同样&…

机器学习基础-机器学习的常用学习方法

半监督学习的概念 少量有标签样本和大量有标签样本进行学习;这种方法旨在利用未标注数据中的结构信息来提高模型性能,尤其是在标注数据获取成本高昂或困难的情况下。 规则学习的概念 基本概念 机器学习里的规则 若......则...... 解释:如果…

python使用AprilTag 3

python使用AprilTag 3 最近想测试一下AprilTag精度,看看能不能用的上。 1 安装 法1:github源码编译安装(放弃) 一开始找到了AprilTag 3的官方github网址https://github.com/AprilRobotics/apriltag,但是按着操作下…

小程序学习07—— uniapp组件通信props和$emit和插槽语法

目录 一 父组件向子组件传递消息 1.1 props (a)传递静态或动态的 Prop (b)单向数据流 二 子组件通知父组件 2.1 $emit (a)定义自定义事件 (b)绑定自定义事件 三 插槽语法…

纵览!报表控件 Stimulsoft Reports、Dashboards 和 Forms 2025.1 新版本发布!

Stimulsoft 2025.1 新版发布,旨在增强您创建报告、仪表板和 PDF 表单的体验!此最新版本为您带来了许多改进和新功能,使数据处理更加高效和用户友好。亮点包括对 .NET 9 的支持、Microsoft Analysis Services 的新数据适配器、发布向导中适用于…

Unity Pico 应用失去焦点后,追踪功能被禁用(原生 UI 界面弹出)

在 Unity 中,如果正在使用新的输入系统,任何触发 OnApplicationFocus(false) 的事件都可能会禁用追踪功能。 负责此功能的组件是附加到主摄像机的 "Tracked Pose Driver (Input System)" 组件。由于非输入系统版本不是新输入系统的一部分&…

面试准备备备备

职业技能 放到简历的黄金位置(HR刷选简历的重要参考) 基本准则:写在简历上的必须能聊,不然就别写 参考公式:职业技能 必要技术 其他技术 针对性的引导面试官(让他问一些你想让他问的) 寻找合…

多光谱图像的处理和分析方法有哪些?

一、预处理方法 1、辐射校正: 目的:消除或减少传感器本身、大气条件以及太阳光照等因素对多光谱图像辐射亮度值的影响,使得图像的辐射值能够真实反映地物的反射或发射特性。 方法:包括传感器校正和大气校正。传感器校正主要是根…

艾体宝方案丨全面提升API安全:AccuKnox 接口漏洞预防与修复

一、API 安全:现代企业的必修课 在现代技术生态中,应用程序编程接口(API)扮演着不可或缺的角色。从数据共享到跨平台集成,API 成为连接企业系统与外部服务的桥梁。然而,伴随云计算的普及与微服务架构的流行…