pytorch笔记:自动混合精度(AMP)

1 理论部分

1.1 FP16 VS FP32

  • FP32具有八个指数位和23个小数位,而FP16具有五个指数位和十个小数位
  • Tensor内核支持混合精度数学,即输入为半精度(FP16),输出为全精度(FP32)

1.1.1 使用FP16的优缺点

  • 优点
    • FP16需要较少的内存,因此更易于训练和部署大型神经网络,同时还减少了数据移动(同时可以使用更大的batch)
    • 数学运算的运行速度大大降低了
      • NVIDIA提供的Volta GPU的确切数量是:FP16中为125 TFlops,而FP32中为15.7 TFlops(加速8倍)
  • 缺点:
    • 从FP32转到FP16时,必然会降低精度
      • 但有的时候,这个精度的降低可以忽略不计
      • FP16实际上可以很好地表示大多数权重和渐变。
      • ——>拥有存储和使用FP32所需的所有这些额外位只是浪费。
    • 溢出错误
      • 由于FP16的动态范围比FP32位的狭窄很多,因此,在计算过程中很容易出现上溢出和下溢出
      • 溢出之后就会出现"NaN"的问题

1.2 解决上述FP16的问题

1.2.1 混合精度训练

  • 用FP16做储存和乘法,而用FP32做累加避免舍入误差
  • ——>混合精度训练的策略有效地缓解了舍入误差的问题

1.2.2 损失放大(Loss scaling)

  • 即使使用了混合精度训练,还是存在无法收敛的情况
    • 原因是激活梯度的值太小,造成了溢出。
  • ——>通过使用torch.cuda.amp.GradScaler,通过放大loss的值来防止梯度的下溢出
    • 只在BP时传递梯度信息使用,真正更新权重时还是要把放大的梯度再unscale回去
      • 反向传播前,将损失变化手动增大2^k倍

        • 因此反向传播时得到的中间变量(激活函数梯度)不会溢出;

      • 反向传播后,将权重梯度缩小2^k倍,恢复正常值。

2 torch.cuda.amp

  • AMP(自动混合精度)的关键词有两个:
    • 自动
      • Tensor的dtype类型会自动变化,框架按需自动调整tensor的dtype,当然有些地方还需手动干预
    • 混合精度
      • 采用不止一种精度的Tensor,torch.FloatTensor和torch.HalfTensor

2.1 Pytorch中不同类型的tensor

类型名称位数
torch.DoubleTensor64bit
torch.LongTensor64bit
torch.FloatTensor(默认)32bit
torch.IntTensor32bit
torch.HalfTensor16bit
torch.BFloat16Tensor16bit
torch.ShortTensor16bit
torch.ByteTensor(无符号)8bit
torch.CharTensor8bit
torch.BoolTensorBoolean

2.2 在AMP上下文中,被自动转化为半精度浮点型的参数:

__matmul__
addbmm
addmm
addmv
addr
baddbmm
bmm
chain_matmul
conv1d
conv2d
conv3d
conv_transpose1d
conv_transpose2d
conv_transpose3d
linear
matmul
mm
mv
prelu

2.3 autocast

from torch.cuda.amp import autocast as autocastmodel = Net().cuda()
#首先初始化一个网络模型Net(),并使用.cuda()方法将模型移至GPU上以利用GPU加速
#Net中的参数默认是torch.FloatTensoroptimizer = optim.SGD(model.parameters(), ...)for input, target in data:optimizer.zero_grad()with autocast():output = model(input)loss = loss_fn(output, target)'''自动混合精度环境包含了前向过程(模型的输出)和loss的计算把支持参数对应tensor的dtype转换为半精度浮点型,从而在不损失训练精度的情况下加快运算进入autocast的上下文时,tensor可以是任何类型不需要在model或者input上手工调用.half() ,框架会自动做'''loss.backward()optimizer.step()# 反向传播在autocast上下文之外

 2.4 GradScaler

在2.3的基础上增加,反向传播时增加梯度,以防止下溢出

from torch.cuda.amp import autocast as autocast
from torch.cuda.amp import GradScalermodel = Net().cuda()
#首先初始化一个网络模型Net(),并使用.cuda()方法将模型移至GPU上以利用GPU加速
#Net中的参数默认是torch.FloatTensoroptimizer = optim.SGD(model.parameters(), ...)scaler = GradScaler()
# 在训练最开始之前实例化一个GradScaler对象for epoch in epochs:for input, target in data:optimizer.zero_grad()with autocast():output = model(input)loss = loss_fn(output, target)'''自动混合精度环境包含了前向过程(模型的输出)和loss的计算把支持参数对应tensor的dtype转换为半精度浮点型,从而在不损失训练精度的情况下加快运算进入autocast的上下文时,tensor可以是任何类型不需要在model或者input上手工调用.half() ,框架会自动做'''scaler.scale(loss).backward()# Scales loss. 为了梯度放大,防止下溢出# 代替原来的loss.backward()scaler.step(optimizer)'''scaler.step() 首先把梯度的值unscale回来.如果梯度的值不是 infs 或者 NaNs, 那么调用optimizer.step()来更新权重,否则,忽略step调用,从而保证权重不更新(不被破坏)'''scaler.update()'''准备着,看是否要增大scaler'''
  •  scaler的大小在每次迭代中动态的估计
    • 为了尽可能的减少梯度underflow,scaler应该更大
    • 但是如果太大的话,半精度浮点型的tensor又容易overflow(变成inf或者NaN)。
  • ——>动态估计的原理就是在不出现inf或者NaN梯度值的情况下尽可能的增大scaler的值

3 一些tips

  • 为了保证计算不溢出,首先保证人工设定的常数不溢出。如epsilon,INF等
  • Dimension最好是8的倍数:维度是8的倍数,性能最好
  • 涉及sum的操作要小心,容易溢出
    • 比如softmax操作,建议用官方API,并定义成layer写在模型初始化里
  • 如果遇到以下的报错:
    • RuntimeError: expected scalar type float but found c10::Half
    • 需要手动在tensor上调用.float()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/23894.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

MySQL主从同步优化指南:架构、瓶颈与解决方案

前言 ​ 在现代数据库架构中,MySQL 主从同步是实现高可用性和负载均衡的关键技术。本文将深入探讨主从同步的架构、延迟原因以及优化策略,并提供专业的监控建议。 MySQL 主从同步架构 ​ 主从复制流程: 从库生成两个线程,一个…

20 - 每月交易 I(高频 SQL 50 题基础版)

20 - 每月交易 I -- 考点:日期转换格式 -- date_format(trans_date,%Y-%m)select date_format(trans_date,%Y-%m) month,country,count(*) trans_count,sum(if(stateapproved,1,0)) approved_count,sum(amount) trans_total_amount,sum(if(state"approved&qu…

【主题广泛|稳定检索】2024年食品安全与生物技术国际会议(ICFSB 2024)

2024年食品安全与生物技术国际会议(ICFSB 2024) 2024 International Conference on Food Safety and Biotechnology 【重要信息】 大会地点:贵阳 大会官网:http://www.icicfsb.com 投稿邮箱:icicfsbsub-conf.com 【注…

语言大模型qwen1.5全流程解析:微调,量化与推理

在前一篇文章中,主要使用llama-factory封装的推理模块对速度进行了测试,vllm速度快些,但仍没有传说中的快3-5倍,需要单独测试。这里使用qwen1.5-1.8B作为测试模型。 qwen1.5是qwen2的先行版,24年2月发布,与…

jenkins插件之Jdepend

JDepend插件是一个为构建生成JDepend报告的插件。 安装插件 JDepend Dashboard -->> 系统管理 -->> 插件管理 -->> Available plugins 搜索 Jdepend, 点击安装构建步骤新增执行shell #执行pdepend if docker exec phpfpm82 /tmp/composer/vendor/bin/pdepe…

ComfyUI工作流分享-黏土特效工作流

大家给的教程都是苹果端使用Remini的软件制作,免费白嫖7天,7天后就要收费,作为ComfyUI技术党,当然是选择自己实现了,搭建一套工作流就搞定,这不,今天就来分享一套对应的黏土效果工作流&#xff…

TSINGSEE青犀视频:城市道路积水智能监管,智慧城市的守护者

随着城市化进程的加快,城市道路网络日益复杂,尤其在夏季,由于暴雨频发,道路积水问题成为影响城市交通和市民生活的重要因素之一。传统的道路积水监测方式往往依赖于人工巡逻和简单的监控设备,这些方法存在效率低下、响…

CAN总线学习笔记-CAN帧结构

数据帧 数据帧:发送设备主动发送数据(广播式) 标准格式的11ID不够用了,由此产生了扩展格式 SOF:帧起始,表示后面一段波形为传输的数据位 ID:标识符,区分功能,同时决定优…

先进制造aps专题十一 国内软件/erp行业的现状及对aps行业的启示

看到一个帖子 中国软件行业几乎全军覆没 OSC开源社区 2024-06-03 15:58 广东 刚刚网上冲浪刷到的 网友锐评:都是客户关系型公司。 知名大 V 「Fenng」评论称: 这里所谓的软件行业公司如果立刻倒闭,才能够利好中国整个行业软件生态。有个网…

巨详细Linux安装Nacos教程

巨详细Linux安装Nacos教程 1、检查是否有残留nacos版本2、上传安装包至服务器2.1安装包获取2.2创建相关目录 3、安装Nacos4、配置Nacos4.1修改数据源4.2新建nacos数据库4.3启动nacos4.4把nacos进程交给systemctl管理4.5设置nacos开机自启动 1、检查是否有残留nacos版本 rpm -q…

Unity基础实践小项目

项目流程: 需求分析 开始界面 选择角色面板 排行榜面板 设置面板 游戏面板 确定退出面板 死亡面板 UML类图 准备工作 1.导入资源 2.创建需要的文件夹 3.创建好面板基类 开始场景 开始界面 1.拼面板 2.写脚本 注意事项:注意先设置NGUI的分辨率大小&…

问题:律师会见委托人的方式包括团体会见和( )。 #职场发展#笔记#学习方法

问题:律师会见委托人的方式包括团体会见和( )。 参考答案如图所示

【Python报错】已解决TypeError: can only concatenate str (not “int“) to str

成功解决“TypeError: can only concatenate str (not “int”) to str”错误的全面指南 一、引言 在Python编程中,字符串(str)和整数(int)是两种基本的数据类型。然而,当我们尝试将这两种类型的对象进行连…

[matlab]折线图之多条折线如何绘制实心圆作为标记点

使用MarkerFaceColor是标记点填充的颜色,b,表示blue,蓝色 plot(x, a, d--, MarkerFaceColor, b); % 绘制仿真结果的曲线如果一张图多条曲线那么每条曲线需要单独调用一次plot,每个plot间用hold on 连接 plot(x, a, d--, MarkerF…

通配符SSL证书快速申请攻略

一、什么是通配符SSL证书 通配符SSL证书又叫泛域名SSL证书,可以保护一个主域名及其所有二级子域名,并对该级子域名数量无限制,且添加新的该级子域名无须另外付费。 二、通配符SSL证书有哪些优势 1.节省时间和金钱:与购买和安装…

Spring Boot + URule 实现可视化规则引擎,太优雅了!

Spring Boot URule 实现可视化规则引擎,太优雅了! 一、背景二、介绍三、安装使用四、基础概念整体介绍库文件变量库文件常量库文件参数库文件动作库文件规则集向导式规则集脚本式规则集 决策表其他 五、运用场景六、总结 一、背景 前段时间&#xff0c…

2、Tomcat 线程模型详解

2、Tomcat 线程模型详解 Tomcat I/O模型详解Linux I/O模型详解I/O要解决什么问题Linux的I/O模型分类 Tomcat支持的 I/O 模型Tomcat I/O 模型如何选型 网络编程模型Reactor线程模型单 Reactor 单线程单 Reactor 多线程主从 Reactor 多线程 Tomcat NIO实现Tomcat 异步IO实现 Tomc…

CentOs7 JDK21 安装

CentOs7 JDK21 安装 准备工作 先检查系统是否之前已经安装了jdk 。如果已经安装的话需要卸载。 方式一:使用压缩包的方式 下载jdk21的压缩包 https://www.oracle.com/java/technologies/downloads/ 将下载的gz压缩包上传到服务器并解压 # 创建文件夹 (你可以自…

java web如何调用py脚本文件

Controller public class IndexController {RequestMapping("/pythonTest")ResponseBodypublic String pythonTest(){// 假设你的Python脚本名为script.pyString pythonScriptPath "D:\\project\\c1\\hello.py";ProcessBuilder processBuilder new Proce…

C51学习归纳6 --- UART串口数据通信

这一部分我认为是十分重要的,没有了数据的传输,我们做的很多事情将是没有意义的。我们一般利用串口做两件事,一是单片机向电脑发送信息,二是单片机接收电脑的信息。 一、UART原理 TXD:发送信息通道,RXD: 接收信息通道。我发送你接…