Pytorch自动混合精度的计算:torch.cuda.amp.autocast

1 autocast介绍

1.1 什么是AMP?

默认情况下,大多数深度学习框架都采用32位浮点算法进行训练。2017年,NVIDIA研究了一种用于混合精度训练的方法,该方法在训练网络时将单精度(FP32)与半精度(FP16)结合在一起,并使用相同的超参数实现了与FP32几乎相同的精度。

FP16也即半精度是一种计算机使用的二进制浮点数据类型,使用2字节存储。而FLOAT就是FP32。

1.2 autocast作用

torch.cuda.amp.autocast是PyTorch中一种混合精度的技术(仅在GPU上训练时可使用),可在保持数值精度的情况下提高训练速度和减少显存占用。

    def __init__(self, enabled : bool = True, dtype : torch.dtype = torch.float16, cache_enabled : bool = True):

它是一个自动类型转换器,可以根据输入数据的类型自动选择合适的精度进行计算,从而使得计算速度更快,同时也能够节省显存的使用。使用autocast可以避免在模型训练过程中手动进行类型转换,减少了代码实现的复杂性。

在深度学习中,通常会使用浮点数进行计算,但是浮点数需要占用更多的显存,而低精度数值可以在减少精度的同时,减少缓存使用量。因此,对于正向传播和反向传播中的大多数计算,可以使用低精度型的数值,提高内存使用效率,进而提高模型的训练速度。

1.3 autocast原理

autocast的要做的事情,简单来说就是:在进入算子计算之前,选择性的对输入进行cast操作。为了做到这点,在PyTorch1.9版本的架构上,可以分解为如下两步:

  • 在PyTorch算子调用栈上某一层插入处理函数
  • 在处理函数中对算子的输入进行必要操作

核心代码:autocast_mode.cpp

2 autocast优缺点

PyTorch中的autocast功能是一个性能优化工具,它可以自动调整某些操作的数据类型以提高效率。具体来说,它允许自动将数据类型从32位浮点(float32)转换为16位浮点(float16),这通常在使用深度学习模型进行训练时使用。

2.1 autocast优点

  • 提高性能:使用16位浮点数(half precision)进行计算可以在支持的硬件上显著提高性能,特别是在最新的GPU上。

  • 减少内存占用:16位浮点数占用的内存比32位少,这意味着在相同的内存限制下可以训练更大的模型或使用更大的批量大小。

  • 自动管理autocast能够自动管理何时使用16位浮点数,何时使用32位浮点数,这降低了手动管理数据类型的复杂性。

  • 保持精度:尽管使用了较低的精度,但autocast通常能够维持足够的数值精度,对最终模型的准确度影响不大。

2.2 autocast缺点

  • 硬件要求:并非所有的GPU都支持16位浮点数的高效运算。在不支持或优化不足的硬件上,使用autocast可能不会带来性能提升。

  • 精度问题:虽然在大多数情况下精度损失不显著,但在某些应用中,尤其是涉及到小数值或非常大的数值范围时,降低精度可能会导致问题。

  • 调试复杂性:由于autocast在模型的不同部分自动切换数据类型,这可能会在调试时增加额外的复杂性。

  • 算法限制:某些特定的算法或操作可能不适合在16位精度下运行,或者在半精度下的实现可能还不成熟。

  • 兼容性问题:某些PyTorch的特性或第三方库可能还不完全支持半精度运算。

在实际应用中,是否使用autocast通常取决于特定任务的需求、所使用的硬件以及对性能和精度的权衡。通常,对于大多数现代深度学习应用,特别是在使用最新的GPU时,使用autocast可以带来显著的性能优势。

3 使用示例

3.1 autocast混合精度计算

with autocast(): 语句块内的代码会自动进行混合精度计算,也就是根据输入数据的类型自动选择合适的精度进行计算,并且这里使用了GPU进行加速。使用示例如下:

# 导入相关库
import torch
from torch.cuda.amp import autocast# 定义一个模型
class MyModel(torch.nn.Module):def __init__(self):super(MyModel, self).__init__()self.linear = torch.nn.Linear(10, 1)def forward(self, x):with autocast():x = self.linear(x)return x# 初始化数据和模型
x = torch.randn(1, 10).cuda()
model = MyModel().cuda()# 进行前向传播
with autocast():output = model(x)# 计算损失
loss = output.sum()# 反向传播
loss.backward()

3.2 autocast与GradScaler一起使用

因为autocast会损失部分精度,从而导致梯度消失的问题,并且经过中间层时可能计算得到inf导致最终loss出现nan。所以我们通常将GradScaler与autocast配合使用来对梯度值进行一些放缩,来缓解上述的一些问题。

from torch.cuda.amp import autocast, GradScalerdataloader = ...
model = Model.cuda(0)
optimizer = ...
scheduler = ...
scaler = GradScaler()  # 新建GradScale对象,用于放缩
for epoch_idx in range(epochs):for batch_idx, (dataset) in enumerate(dataloader):optimizer.zero_grad()dataset = dataset.cuda(0)with autocast():  # 自动混精度logits = model(dataset)loss = ...scaler.scale(loss).backward()  # scaler实现的反向误差传播scaler.step(optimizer)  # 优化器中的值也需要放缩scaler.update()  # 更新scalerscheduler.step()
...

4 可能出现的问题

使用autocast技术进行混精度训练时loss经常会出现'nan',有以下三种可能原因:

  • 精度损失,有效位数减少,导致输出时数据末位的值被省去,最终出现nan的现象。该情况可以使用GradScaler(上文所示)来解决。
  • 损失函数中使用了log等形式的函数,或是变量出现在了分母中,并且训练时,该数值变得非常小时,混精度可能会让该值更接近0或是等于0,导致了数学上的log(0)或是x/0的情况出现,从而出现'inf'或'nan'的问题。这种时候需要针对该问题设置一个确定值。例如:当log(x)出现-inf的时候,我们直接将输出中该位置的-inf设置为-100,即可解决这一问题。
  • 模型内部存在的问题,比如模型过深,本身梯度回传时值已经非常小。这种问题难以解决。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/143655.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Alter database open fails with ORA-00600 kcratr_nab_less_than_odr

Alter database open fails with ORA-00600 kcratr_nab_less_than_odr (Doc ID 1296264.1)​编辑To Bottom APPLIES TO: Oracle Database - Enterprise Edition - Version 11.2.0.1 to 11.2.0.1 [Release 11.2] Oracle Database - Enterprise Edition - Version 12.1.0.1 to …

servlet 的XML Schema从哪边获取

servlet 6.0的规范定义: https://jakarta.ee/specifications/servlet/6.0/ 其中包含的三个XML Schema:web-app_6_0.xsd、web-common_6_0.xsd、web-fragment_6_0.xsd。但这个页面没有给出下载的链接地址。 正好我本机有Tomcat 10.1.15版本的源码&#…

深入解析JavaScript中的变量作用域与声明提升

JS中的变量作用域 背景: ​ 之前做js逆向的时候,有一个网站很有意思,就是先出现对其赋值,但是后来的变量赋值没有对其发生修改,决定说一下js中的作用域问题. 全局作用域: ​ 全局作用域的变量可以在任何…

PDF自动打印

​ 最近接到用户提过来的需求,需要一个能够自动打印图纸的功能,经过几天的研究整出来个初版了的,分享出来给大家,希望能有帮助。 需求描述: ​ 生产车间现场每天都有大量的图纸需要打印,一个一个打印太慢了&#xff0…

什么是3D建模中的“高模”和“低模”?

3D建模中什么是高多边形和低多边形? 高多边形建模和低多边形建模之间的主要区别正如其名称所暗示的那样:您是否在模型中使用大量多边形或少量多边形。 然而,在决定每个模型的细节和多边形级别时,还需要考虑其他事项。最值得注意的…

一文解码语言模型:语言模型的原理、实战与评估

在本文中,我们深入探讨了语言模型的内部工作机制,从基础模型到大规模的变种,并分析了各种评价指标的优缺点。文章通过代码示例、算法细节和最新研究,提供了一份全面而深入的视角,旨在帮助读者更准确地理解和评估语言模…

NI USRP软件无线设备的特点

NI USRP软件无线设备 NI的USRP(Universal Software Radio Peripheral)设备是RF应用中使用的软件无线(SDR)。NI的USRP收发器可以在多个频段发送和接收RF信号,因此可用于通信工程教育和研究。通过与LabVIEW开发环境相结合,USRP可以实现使用无线信号验证无…

接口开放太麻烦?试试阿里云API网关吧

前言 我在多方合作时,系统间的交互是怎么做的?这篇文章中写过一些多方合作时接口的调用规则和例子,然而,接口开放所涉及的安全、权限、监控、流量控制等问题,可不是简简单单就可以解决的,这一般需要专业的…

使用pixy计算群体遗传学统计量

1 数据过滤 过滤参数:过滤掉次等位基因频率(minor allele frequency,MAF)低于0.05、哈达-温伯格平衡(Hardy– Weinberg equilibrium,HWE)对应的P值低于1e-10或杂合率(heterozygosit…

【科研新手指南3】chatgpt辅助论文优化表达

chatgpt辅助论文优化表达 写在最前面最终版什么是好的论文整体上:逻辑/连贯性细节上一些具体的修改例子 一些建议,包括具体的提问范例1. 明确你的需求2. 提供上下文信息3. 明确问题类型4. 测试不同建议5. 请求详细解释综合提问范例: 常规技巧…

Spring6(一):入门案例

文章目录 1. 概述1.1 Spring简介1.2 Spring 的狭义和广义1.3 Spring Framework特点1.4 Spring模块组成 2 入门2.1 构建模块2.2 程序开发2.2.1 引入依赖2.2.2 创建java类2.2.3 创建配置文件2.2.4 创建测试类测试 2.3 程序分析2.4 启用Log4j2日志框架2.4.1 引入Log4j2依赖2.4.2 加…

轻量封装WebGPU渲染系统示例<32>- 若干线框对象(源码)

当前示例源码github地址: https://github.com/vilyLei/voxwebgpu/blob/feature/rendering/src/voxgpu/sample/WireframeEntityTest.ts 当前示例运行效果: 此示例基于此渲染系统实现,当前示例TypeScript源码如下: export class WireframeEntityTest {private mRsc…

人工智能基础_机器学习030_ElasticNet弹性网络_弹性回归的使用---人工智能工作笔记0070

然后我们再来看elastic-net弹性网络,之所以叫弹性是因为,他融合了L1和L2正则,可以看到 他的公式 公式中有L1正则和L2正则两个都在这个公式中 可以看到弹性网络,在很多特征互相联系的时候,非常有用,比如, 相关性,如果数学好,那么物理也好,如果语文好,那么英语也好 这种联系 正…

JZ22:链表中倒数第k个结点

JZ22:链表中倒数第k个结点 题目描述: 输入一个链表,输出该链表中倒数第k个结点。 示例1 输入: 1,{1,2,3,4,5} 返回值: {5} 分析: 快慢指针思想: 需要两个指针,快指针fast&…

python 基础语法 (常常容易漏掉)

同一行显示多条语句 python语法中要求缩进,但是同一行可以显示多条语句 在 Python 中,可以使用分号 (;) 将多个语句放在同一行上。这样可以在一行代码中执行多个语句,但需要注意代码的可读性和维护性。 x 5; y 10; z x y; print(z) 在…

使用c++程序,实现图像平移变换,图像缩放、图像裁剪、图像对角线镜像以及图像的旋转

数字图像处理–实验三A图像的基本变换 实验内容 A实验: (1)使用VC设计程序:实现图像平移变换,图像缩放、图像裁剪、图像对角线镜像。 (2)使用VC设计程序:对一幅高度与宽度均相等的…

linux 系统下文本编辑常用的命令

一、是什么 Vim是从 vi 发展出来的一个文本编辑器,代码补全、编译及错误跳转等方便编程的功能特别丰富,在程序员中被广泛使用。 简单的来说, vi 是老式的字处理器,不过功能已经很齐全了,但是还是有可以进步的地方 而…

Playwright UI 自动化测试实战

📢专注于分享软件测试干货内容,欢迎点赞 👍 收藏 ⭐留言 📝 如有错误敬请指正!📢交流讨论:欢迎加入我们一起学习!📢资源分享:耗时200小时精选的「软件测试」资…

Kylin-Server-V10-SP3+Gbase+宝兰德信创环境搭建

目录 一、Kylin-Server-V10-SP3 安装1.官网下载安装包2.创建 VMware ESXi 虚拟机3.加载镜像,安装系统 二、Gbase 安装1.下载 Gbase 安装包2.创建组和用户、设置密码3.创建目录4.解压包5.安装6.创建实例7.登录8.常见问题 三、宝兰德安装1.获取安装包2.解压安装3.启动…