PyTorch 基础学习(13)- 混合精度训练

系列文章:
《PyTorch 基础学习》文章索引

基本概念

混合精度训练是深度学习中一种优化技术,旨在通过结合高精度(torch.float32)和低精度(如 torch.float16torch.bfloat16)数据类型的优势,提高计算效率和内存利用率。

  • 高精度(torch.float32:适合需要大动态范围的操作,如损失计算、缩减操作(如求和、平均)等。这些操作对数值稳定性要求较高,使用高精度能确保计算结果的准确性。

  • 低精度(torch.float16torch.bfloat16:适合计算密集型操作,如卷积和矩阵乘法。这些操作在低精度下可以显著提升计算速度,同时减少显存占用。

混合精度训练的核心思想是在模型中自动选择合适的数据类型,以在加速计算的同时,尽可能保持结果的准确性。PyTorch 提供了 torch.amp 模块,该模块封装了一些便捷的工具,使得混合精度的实现更加直观和高效。

重要方法及其作用

torch.autocast

torch.autocast 是混合精度训练中的核心工具。它是一个上下文管理器或装饰器,用于在代码的特定部分启用混合精度。在这些被启用的区域内,autocast 将根据操作的特性自动选择合适的数据类型。例如,卷积操作可以自动转换为 float16,而损失计算则保持为 float32

主要参数:

  • device_type:指定设备类型,如 cudacpuxpu
  • dtype:指定在 autocast 区域内使用的低精度数据类型。对于 CUDA 设备,默认是 torch.float16;对于 CPU 设备,默认是 torch.bfloat16
  • enabled:是否启用混合精度。默认为 True
  • cache_enabled:是否启用权重缓存。默认是 True,可以在某些场景下提高性能。

torch.amp.GradScaler

在低精度(如 float16)下,梯度值较小的操作可能会出现下溢现象,导致梯度值变为零,从而影响模型的训练。为了避免这种情况,PyTorch 提供了 GradScaler,它通过在反向传播之前动态缩放损失值,从而放大梯度值,使其在低精度下也能被有效表示。之后,优化器会在更新参数之前对梯度进行反缩放,以确保不会影响学习率。

主要参数:

  • init_scale:初始的缩放因子,默认是 65536.0
  • growth_factor:在没有发生下溢的情况下,缩放因子增长的倍数,默认是 2.0
  • backoff_factor:发生下溢时,缩放因子减少的倍数,默认是 0.5
  • growth_interval:在多少个步骤之后,如果没有下溢,缩放因子会增长,默认是 2000
  • enabled:是否启用梯度缩放,默认为 True

适用的场景

GPU 训练
在使用 CUDA 设备进行深度学习模型训练时,启用混合精度可以显著提升模型的训练速度。尤其是在使用大规模数据和复杂模型(如卷积神经网络、Transformer 模型)时,torch.autocast(device_type="cuda") 能够有效地减少 GPU 的计算负载,并提高吞吐量。

CPU 训练与推理
虽然 GPU 在深度学习中更常用,但在一些特定场景下(如低资源环境或需要在 CPU 上进行部署),混合精度在 CPU 上同样具有优势。使用 torch.autocast(device_type="cpu", dtype=torch.bfloat16) 可以在推理过程中降低计算复杂度,同时保持较高的精度。

3.3 自定义操作
在某些高级用例中,用户可能需要为自定义的自动微分函数实现混合精度支持。通过 torch.amp.custom_fwdtorch.amp.custom_bwd,用户可以定义在特定设备(如 cuda)上执行的前向和反向操作,并确保这些操作在混合精度模式下正常运行。

应用实例

以下是一个在 CUDA 设备上使用混合精度进行训练的完整示例,展示了如何在实践中应用 torch.autocasttorch.amp.GradScaler

import torch
import torch.nn as nn
import torch.optim as optim
from torch.cuda.amp import autocast, GradScaler# 定义简单的神经网络模型
class SimpleModel(nn.Module):def __init__(self):super(SimpleModel, self).__init__()self.fc1 = nn.Linear(100, 50)self.fc2 = nn.Linear(50, 10)def forward(self, x):x = torch.relu(self.fc1(x))x = self.fc2(x)return x# 创建模型和优化器,使用默认精度(float32)
model = SimpleModel().cuda()
optimizer = optim.SGD(model.parameters(), lr=0.01)# 定义损失函数
loss_fn = nn.CrossEntropyLoss()# 创建GradScaler
scaler = GradScaler()# 训练循环
for epoch in range(10):  # 假设有10个epochfor input, target in data_loader:  # 假设有一个data_loaderinput, target = input.cuda(), target.cuda()optimizer.zero_grad()# 在前向传播过程中启用自动混合精度with autocast(device_type="cuda"):output = model(input)loss = loss_fn(output, target)# 使用GradScaler进行反向传播scaler.scale(loss).backward()scaler.step(optimizer)scaler.update()print(f"Epoch {epoch+1} completed.")

代码说明

  • 首先,我们定义了一个简单的神经网络模型,并将其放置在 CUDA 设备上。
  • 在每次训练循环中,我们使用 torch.autocast(device_type="cuda") 上下文管理器包裹前向传播过程,使得模型的计算自动使用混合精度。
  • 使用 GradScaler 对损失进行缩放,并在缩放后的损失上调用 backward() 进行反向传播。这一步骤有助于防止梯度下溢。
  • scaler.step(optimizer) 用于更新模型参数,scaler.update() 则是调整缩放因子。

这种方法既能提高训练速度,又能在较低精度下保持数值稳定性,是在实际项目中应用混合精度训练的有效方案。

注意事项

  • 弃用警告:从 PyTorch 1.10 开始,原有的 torch.cuda.amp.autocasttorch.cpu.amp.autocast 方法被弃用,推荐使用通用的 torch.autocast 代替。这不仅简化了接口,也为未来的设备扩展提供了灵活性。

  • 数据类型匹配:在使用 autocast 时,确保输入数据类型的一致性非常重要。如果在混合精度区域内生成的张量在退出后与其他不同精度的张量混合使用,可能会导致类型不匹配错误。因此,在必要时,需要手动将张量转换为 float32 或其他合适的精度。

  • GradScaler 的适用性:虽然 GradScaler 对大多数模型都有效,但在某些情况下(例如使用 bf16 预训练模型),可能会出现梯度溢出的情况。因此,在使用混合精度训练时,需要根据具体模型的特性进行调整。

通过对这些概念、方法、使用场景和实例的深入理解,您可以在实际项目中更好地应用混合精度训练,从而提升深度学习模型的训练效率和性能。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/51324.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Python】自然语言处理(NLP)技术简介

紧紧握着 青花信物 信守着承诺 离别总在 失意中度过 记忆油膏 反复涂抹 无法愈合的伤口 你的回头 划伤了沉默 🎵 周传雄《青花》 自然语言处理(NLP)技术是一种使计算机能够理解和处理人类自然语言的技术。以下是一些NLP…

appium学习记录

免责声明 本文内容仅供参考,将appuim与爬虫技术相结合可能违反某些app的使用条款和法律法规。作者不对因此产生的法律问题或技术风险负责。建议读者在进行爬取操作前,充分了解相关法律法规并确保合规。 1、初识appium 背景:部分APP需要反编译…

C#用户控件usercontrol中的子控件事件及属性的传递

也不知道这个标题怎么写,但是问题是个老问题,大家都可能遇到过,不过有同学问到,那就写出来。其实很简单。只不过有的同学看了其他博文后脑子还是懵懵的。所以这里就分两部分来说明一下。 文章目录 一、属性的传递1、原理2、步骤3…

tensorflow新建op (cpp)

为什么使用cpp新建op 一些操作表示成现有操作的组合不好实现或者无法实现。已有操作的组合效率不高。想要自定义一些基本操作的组合,因为未来编译器做这种融合可能会比较困难。 如何使用cpp新建op 注册op,注册op会定义一个接口(规范&#…

Mac M1Pro 安装Java性能监控工具VisualVM 2.1.9

本地已经安装了java8,在终端输入jvisualvm提示没有安装 zhiniansara ~ % jvisualvm The operation couldn’t be completed. Unable to locate a Java Runtime that supports jvisualvm. Please visit http://www.java.com for information on installing Java.官网…

RPA自动化流程机器人助力企业财务数字化转型

在数字经济时代,企业需要快速响应市场变化,而财务数字化转型是企业适应现代商业环境、提升竞争力的必要步骤。财务数字化转型不仅涉及企业财务能力的提升,推动了财务管理与决策模式的转变。RPA自动化流程机器人因其能通过自动化技术帮助企业实…

[云计算] 虚拟化笔记

原著: 韩冰,[云计算课程], 有删改。 目的 对 IT 资源简化,用户通过标准接口访问。 资源是提高一定功能的实现 。可以是硬件, 如CPU, 也可以是软件。 发展史 1961 IBM CPU 分时间片, 一个CPU 虚拟化为多…

【Nature】在科研中应用ChatGPT:如何与数据对话

随着人工智能技术的迅猛发展,大型语言模型(LLMs)正逐渐成为科研领域的一种创新工具。这些模型通过自然语言处理技术,使得研究人员能够以直观的方式与数据进行交互,从而简化了数据分析和解释的过程。在《自然》杂志2024…

Matlab自学笔记三十四:表table的排序、查找、提取、删除、计算、与结构数组的转换

1.表格的统计分析 表的统计分析包括计算均值、方差等,这些参数可以通过函数summary一次计算出来,程序示例如下: xingming{zhangsan;lisi;wangwu}; %首先创建表变量 xuehao{1001;1002;1003}; chengji[89 95;90 87;88 84]; ttable(xingmin…

当外接硬盘接入到macOS上,只读不可写时,应当格式化

当windows磁盘格式例如 NTFS 的硬盘接入到macOS上时,会发现无法新建文件夹,无法删除、重命名。原因是磁盘格式对不上macOS,需要进行格式化。格式化时请注意备份重要数据。具体做法如下,在macOS中找到磁盘工具,然后对磁…

QT Quick QML 实例之定制 TableView

QT Quick QML 实例之定制 TableView 一、演示二、C关键步骤1. beginInsertRows()(用户插入行)2. roleNames() (表格中列映射)3. data() (用户获取数据)4. headerData() (表头)5. fla…

影视会员官方渠道api对接

API对接是指两个不同的软件系统或应用程序之间通过API(应用程序编程接口)进行交互的过程。这种交互允许数据和功能的共享,而不必暴露系统的内部工作原理。在影视会员充值场景中,API对接具有以下几个关键特点和优势: 数…

【从Qwen2,Apple Intelligence Foundation,Gemma 2,Llama 3.1看大模型的性能提升之路】

从早期的 GPT 模型到如今复杂的开放式 LLM,大型语言模型 (LLM) 的发展已经取得了长足的进步。最初,LLM 训练过程仅侧重于预训练,但后来扩展到包括预训练和后训练。后训练通常包括监督指令微调和校准,这是由 ChatGPT 推广的。 自 …

11、Redis高级:Key设置、BigKey解决、批处理优化、集群下批处理、慢查询

Redis高级篇之最佳实践 今日内容 Redis键值设计批处理优化服务端优化集群最佳实践 1、Redis键值设计 1.1、优雅的key结构 Redis的Key虽然可以自定义,但最好遵循下面的几个最佳实践约定: 遵循基本格式:[业务名称]:[数据名]:[id]长度不超过…

浅说数据链

一支军队能否制胜战场?影响因素有很多,高效的信息采集、传送、交换就是其中之一。从冷兵器时代的流星探马、八百里加急,到绵延千里的烽火狼烟;从近现代战场上“滴滴、滴滴滴”声不断的电报,到枪林弹雨中官兵手中的电话…

沉浸式解压小视频在哪找?非常减压的几个视频素材网站分享

沉浸式解压小视频,以其独特的舒缓音乐、宁静自然景观和柔和动态图像,成为了迅速消解压力的有效途径。这些视频能够帮助我们暂时离开紧张的现实,重获内心的平和。如果你正在寻找优质的解压视频素材,不用担心,接下来我会…

【HarmonyOS NEXT星河版开发学习】综合测试案例-各平台评论部分

目录 前言 功能展示 整体页面布局 最新和最热 写评论 点赞功能 界面构建 初始数据的准备 列表项部分的渲染 底部区域 index部分 知识点概述 List组件 List组件简介 ListItem组件详解 ListItemGroup组件介绍 ForEach循环渲染 列表分割线设置 列表排列方向设…

图像分割论文阅读:BCU-Net: Bridging ConvNeXt and U-Net for medical image segmentation

本文提出了一种集合ConvNeXt和U-Net优势的网络模型来分割医学图像。 当然,模型整体结构就是并列双分支,如果只是这些内容,不值得拿出来讲。 主要有意思的部分是其融合两分支的多标签召回模块(multilabel recall loss module&…

如何使用midjourney?MidJourney订阅计划及国内订阅教程

国内如何订阅MidJourney 第三方代理 参考: zhangfeidezhu.com/?p474 使用信用卡订阅教程 办理国外信用卡: 这个各自找国外的银行办理就好了。 登录MidJourney: 登录MidJourney网站,进入订阅中心。如果是在Discord频道&#x…

ES 模糊查询 wildcard 的替代方案探索

一、Wildcard 概述 Wildcard 是一种支持通配符的模糊检索方式。在 Elasticsearch 中,它使用星号 * 代表零个或多个字符,问号 ? 代表单个字符。 其使用方式多样,例如可以通过 {"wildcard": {"field_name": "value&…