PyTorch AMP 混合精度中grad_scaler.py的scale函数解析

PyTorch AMP 混合精度中的 scale 函数解析

混合精度训练(AMP, Automatic Mixed Precision)是深度学习中常用的技术,用于提升训练效率并减少显存占用。在 PyTorch 的 AMP 模块中,GradScaler 类负责动态调整和管理损失缩放因子,以解决 FP16 运算中的数值精度问题。而 scale 函数是 GradScaler 的一个重要方法,用于将输出的张量按当前缩放因子进行缩放。

本文将详细解析 scale 函数的作用、代码逻辑,以及 apply_scale 子函数的递归作用。


函数代码回顾

以下是 scale 函数的完整代码:
Source: anaconda3/envs/xxx/lib/python3.10/site-packages/torch/amp/grad_scaler.py

torch 2.4.0+cu121版本

def scale(self,outputs: Union[torch.Tensor, Iterable[torch.Tensor]],
) -> Union[torch.Tensor, Iterable[torch.Tensor]]:"""Multiplies ('scales') a tensor or list of tensors by the scale factor.Returns scaled outputs.  If this instance of :class:`GradScaler` is not enabled, outputs are returnedunmodified.Args:outputs (Tensor or iterable of Tensors):  Outputs to scale."""if not self._enabled:return outputs# Short-circuit for the common case.if isinstance(outputs, torch.Tensor):if self._scale is None:self._lazy_init_scale_growth_tracker(outputs.device)assert self._scale is not Nonereturn outputs * self._scale.to(device=outputs.device, non_blocking=True)# Invoke the more complex machinery only if we're treating multiple outputs.stash: List[_MultiDeviceReplicator] = []  # holds a reference that can be overwritten by apply_scaledef apply_scale(val: Union[torch.Tensor, Iterable[torch.Tensor]]):if isinstance(val, torch.Tensor):if len(stash) == 0:if self._scale is None:self._lazy_init_scale_growth_tracker(val.device)assert self._scale is not Nonestash.append(_MultiDeviceReplicator(self._scale))return val * stash[0].get(val.device)if isinstance(val, abc.Iterable):iterable = map(apply_scale, val)if isinstance(val, (list, tuple)):return type(val)(iterable)return iterableraise ValueError("outputs must be a Tensor or an iterable of Tensors")return apply_scale(outputs)

1. 函数作用

scale 函数的主要作用是将输出张量(outputs)按当前的缩放因子(self._scale)进行缩放。它支持以下两种输入:

  1. 单个张量:直接将缩放因子乘以张量。
  2. 张量的可迭代对象(如列表或元组):递归地对每个张量进行缩放。

当 AMP 功能未启用时(即 self._enabledFalse),scale 函数会直接返回原始的 outputs,不执行任何缩放操作。

使用场景

  • 放大梯度:在反向传播之前,放大输出张量的数值,以减少数值舍入误差对 FP16 计算的影响。
  • 支持多设备:通过 _MultiDeviceReplicator 支持张量分布在多个设备(如多 GPU)的场景。

2. 核心代码解析

(1) 短路处理单个张量

当输入 outputs 是单个张量(torch.Tensor)时,函数直接对其进行缩放:

if isinstance(outputs, torch.Tensor):if self._scale is None:self._lazy_init_scale_growth_tracker(outputs.device)assert self._scale is not Nonereturn outputs * self._scale.to(device=outputs.device, non_blocking=True)
逻辑解析:
  1. 如果缩放因子 self._scale 尚未初始化,则调用 _lazy_init_scale_growth_tracker 方法在指定设备上初始化缩放因子。
  2. 使用 outputs * self._scale 对张量进行缩放。这里使用了 to(device=outputs.device) 确保缩放因子与张量在同一设备上。

这是单个张量输入的快速路径处理。


(2) 多张量递归处理逻辑

当输入为张量的可迭代对象(如列表或元组)时,函数调用子函数 apply_scale 进行递归缩放:

stash: List[_MultiDeviceReplicator] = []  # 用于存储缩放因子对象def apply_scale(val: Union[torch.Tensor, Iterable[torch.Tensor]]):if isinstance(val, torch.Tensor):if len(stash) == 0:if self._scale is None:self._lazy_init_scale_growth_tracker(val.device)assert self._scale is not Nonestash.append(_MultiDeviceReplicator(self._scale))return val * stash[0].get(val.device)if isinstance(val, abc.Iterable):iterable = map(apply_scale, val)if isinstance(val, (list, tuple)):return type(val)(iterable)return iterableraise ValueError("outputs must be a Tensor or an iterable of Tensors")return apply_scale(outputs)
apply_scale 子函数的作用
  1. 张量处理

    • 如果 val 是单个张量,检查 stash 是否为空。
    • 如果为空,初始化缩放因子对象 _MultiDeviceReplicator,并存储在 stash 中。
    • 使用 stash[0].get(val.device) 获取对应设备上的缩放因子,并对张量进行缩放。
  2. 递归处理可迭代对象

    • 如果 val 是一个可迭代对象,调用 map(apply_scale, val),对其中的每个元素递归地调用 apply_scale
    • 如果输入是 listtuple,则保持其原始类型。
  3. 类型检查

    • 如果 val 既不是张量也不是可迭代对象,抛出错误。

3. apply_scale 是递归函数吗?

是的,apply_scale 是一个递归函数。

递归逻辑

  • 当输入为嵌套结构(如张量的列表或列表中的列表)时,apply_scale 会递归调用自身,将缩放因子应用到最底层的张量。
  • 递归的终止条件是 val 为单个张量(torch.Tensor)。
示例:

假设输入为嵌套张量列表:

outputs = [torch.tensor([1.0, 2.0]), [torch.tensor([3.0]), torch.tensor([4.0, 5.0])]]
scaled_outputs = scaler.scale(outputs)

递归处理过程如下:

  1. outputs 调用 apply_scale

    • 第一个元素是张量 torch.tensor([1.0, 2.0]),直接缩放。
    • 第二个元素是列表,递归调用 apply_scale
  2. 进入嵌套列表 [torch.tensor([3.0]), torch.tensor([4.0, 5.0])]

    • 第一个元素是张量 torch.tensor([3.0]),缩放。
    • 第二个元素是张量 torch.tensor([4.0, 5.0]),缩放。

4. _MultiDeviceReplicator 的作用

_MultiDeviceReplicator 是一个工具类,用于在多设备场景下管理缩放因子对象的复用。它根据张量所在的设备返回正确的缩放因子。

  • 当张量分布在多个设备(如 GPU)时,_MultiDeviceReplicator 可以高效地为每个设备提供所需的缩放因子,避免重复初始化。

总结

scale 函数是 AMP 混合精度训练中用于梯度缩放的重要方法,其作用是将输出张量按当前缩放因子进行缩放。通过递归函数 apply_scale,该函数能够处理嵌套的张量结构,同时支持多设备场景。

关键点总结:

  1. 快速路径:单张量输入的情况下,直接进行缩放。
  2. 递归处理:对于张量的嵌套结构,递归地对每个张量进行缩放。
  3. 设备管理:通过 _MultiDeviceReplicator 支持多设备场景。

通过 scale 函数,PyTorch 的 AMP 模块能够高效地调整梯度数值范围,提升混合精度训练的稳定性和效率。

后记

2025年1月2日15点47分于上海,在GPT4o大模型辅助下完成。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/65985.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

MySQL图形化界面工具--DataGrip

之前介绍了在命令行进行操作,但是不够直观,本次介绍图形化界面工具–DataGrip。 安装DataGrip 官网链接:官网下载链接 常规的软件安装流程。 参考链接:DataGrip安装 使用DataGrip 添加数据源: 第一次使用最下面会…

el-table树形懒加载展开改为点击行展开

思路&#xff1a;获取el-table中小箭头&#xff0c;然后调它的click事件&#xff01; <el-tablerow-click"getOpenDetail":row-class-name"tableRowClassName">// 点击当前行展开节点getOpenDetail(row, column, event) {// 如果是叶子节点或点击的是…

【虚拟机】VMware 16图文安装和配置 AlmaLinux OS 9.5 教程

准备工作 下载AlmaLinux ISO文件&#xff1a;从AlmaLinux官方网站&#xff08;https://almalinux.org/&#xff09;下载最新版本的ISO文件。 安装VMware Workstation&#xff1a;确保您的计算机上已安装VMware Workstation。&#xff08;注&#xff1a;我这边使用的是VMware16…

中国联通首次推出一套量化大模型的新标准

新基准的诞生 中国联通的研究团队近日公布了一套创新性的量化标准&#xff0c;主要针对大型语言模型的能力评估。这一基准的灵感来源于动物智能演化的规律&#xff0c;为用户在选择语言模型时提供了科学依据。现代社会中&#xff0c;各种语言模型如雨后春笋般涌现&#xff0c;…

aardio —— 虚表 —— 使用ownerDrawCustom列类型制作喜马拉雅播放器列表

不会自绘也能做漂亮列表&#xff0c;你相信吗&#xff1f; 看看这个例子&#xff0c;虚表_vlistEx_ColType_OwnerDrawCustom列类型&#xff0c;移植自godking.customPlus&#xff0c;简单好用&#xff0c;做漂亮列表的大杀器&#xff0c;玩aardio必备利器&#xff01; 请更新…

网安数学基础期末复习

目录 整除同余同余方程群和环 整除 a的显然因数/平凡因数1&#xff0c;a整除的传递性和组合性 若 a ∣ b , b ∣ a a|b,b|a a∣b,b∣a 则 a b a\pm b ab欧几里得带余除法 公因数和最大公因数在整除里的定义&#xff0c;最大公因数为1则两数互质&#xff0c;注意公因数有正…

【论文阅读笔记】SCI算法与代码 | 低照度图像增强 | 2022.4.21

目录 一 SCI 1 SCI网络结构 核心代码&#xff08;model.py&#xff09; 2 SCI损失函数 核心代码&#xff08;loss.py&#xff09; 3 实验 二 SCI效果 1 下载代码 2 运行 一 SCI &#x1f49c;论文题目&#xff1a;Toward Fast, Flexible, and Robust Low-Light Image …

AcWing练习题:平均数2

读取三个浮点数 A&#xff0c;B 和 C 的值&#xff0c;对应于三个学生的成绩。 请你计算学生的平均分&#xff0c;其中 A 的成绩的权重为 2&#xff0c;B 的成绩的权重为 3&#xff0c;C 的成绩的权值为 5。 成绩的取值范围在 0 到 10 之间&#xff0c;且均保留一位小数。 输…

aardio —— 改变按钮文本颜色

import win.ui; /*DSG{{*/ var winform win.form(text"改变按钮颜色示例";right279;bottom239;composited1) winform.add( button{cls"button";text"点这里1";left16;top104;right261;bottom159;fontLOGFONT(h-14);z1}; button2{cls"butto…

蓝牙网关的传输距离有多远?

在物联网技术的快速发展中&#xff0c;蓝牙网关扮演着至关重要的角色&#xff0c;尤其是在扩展蓝牙设备的通信范围和连接能力方面。桂花网作为蓝牙网关的重要供应商&#xff0c;其产品在市场上得到了广泛的认可。那么小编今天带大家来了解下桂花网蓝牙网关的传输距离有多远&…

Scratch教学作品 | 白水急流——急流勇进,挑战反应极限! ‍♂️

今天为大家推荐一款刺激又好玩的Scratch冒险作品——《白水急流》&#xff01;由AgentFransidium制作&#xff0c;这款作品将带你体验惊险的急流救援任务&#xff0c;帮助那位“睡着的疯狂人”安全穿越湍急水域&#xff01;想要挑战自己的反应极限&#xff1f;快来试试吧&#…

YoloV8改进策略:Block改进|MCA,用于图像识别的深度卷积神经网络中的多维协同注意力|即插即用

摘要 论文介绍 研究背景:论文讨论了现有注意力模块(如ECA、SRM、CBAM等)在图像识别中的局限性,指出它们往往只关注通道间关系或空间维度中的特征相互作用,而忽略了它们之间的相关性。研究目的:旨在提出一种能够同时在通道、高度和宽度维度上学习互补注意力的方法,以提升…

【Infineon AURIX】AURIX缓存(CACHE)变量访问指南

AURIX缓存变量访问指南 引言 本文分析Infineon AURIX控制器在调试过程中访问缓存内存变量的问题及解决方案重点探讨了变量缓存对调试的影响以及多种解决方法的优劣第1部分:问题描述与成因分析 主要症状 变量值发生变化,但实时内存访问显示初始值Watch窗口和Memory窗口中的变…

【three.js】场景搭建

three.js由场景、相机、渲染器、灯光、控制器等几个要素组成。每个要素都有不同的类型&#xff0c;例如光照有太阳光、环境光、半球光等等。每种光照都有不同的属性可以进行配置。 场景 场景&#xff08;scene&#xff09;&#xff1a;场景是所有物体的容器&#xff0c;如果要…

CSS 图片廊:网页设计的艺术与技巧

CSS 图片廊&#xff1a;网页设计的艺术与技巧 引言 在网页设计中&#xff0c;图片廊是一个重要的组成部分&#xff0c;它能够以视觉吸引的方式展示图片集合&#xff0c;增强用户的浏览体验。CSS&#xff08;层叠样式表&#xff09;作为网页设计的主要语言之一&#xff0c;提供…

Android测试ABD环境及语句

1、什么是adb ADB 全称为 Android Debug Bridge&#xff0c;起到调试桥的作用&#xff0c;是一个客户端-服务器端程序。其中客户端是用来操作的电脑&#xff0c;服务端是 Android 设备。 ADB 也是 Android SDK 中的一个工具&#xff0c;可以直接操作管理 Android 模拟器或者真…

库伦值自动化功耗测试工具

1. 功能介绍 PlatformPower工具可以自动化测试不同场景的功耗电流&#xff0c;并可导出为excel文件便于测试结果分析查看。测试同时便于后续根据需求拓展其他自动化测试用例。 主要原理&#xff1a;基于文件节点 coulomb_count 实现&#xff0c;计算公式&#xff1a;电流&…

creating-custom-commands-in-flask

在烧瓶中创建自定义命令 原文:https://www . geesforgeks . org/creating-custom-commands-in-flask/ 本文围绕如何在 flask 中创建自定义命令展开。每次使用烧瓶运行运行烧瓶时&#xff0c;运行实际上是一个命令&#xff0c;在烧瓶配置文件中启动一个名为运行的函数。同样&…

机器学习基础-机器学习的常用学习方法

半监督学习的概念 少量有标签样本和大量有标签样本进行学习&#xff1b;这种方法旨在利用未标注数据中的结构信息来提高模型性能&#xff0c;尤其是在标注数据获取成本高昂或困难的情况下。 规则学习的概念 基本概念 机器学习里的规则 若......则...... 解释&#xff1a;如果…

深入解析希尔排序:原理、实现与优化

目录 一、希尔排序的基本思想 二、希尔排序的时间复杂度 三、优化与改进 希尔排序&#xff08;Shell Sort&#xff09;是一种基于插入排序的排序算法&#xff0c;其改进在于通过分组&#xff08;也叫增量&#xff09;的方式来减少数据移动的次数&#xff0c;从而提高了排序的…