模型训练太慢?显存不够用?这个算法让你的GPU老树开新花

一只小狐狸带你解锁NLP/ML/DL秘籍

作者:小鹿鹿鹿,夕小瑶

老板,咱们就一台Titan Xp,训不动BERT呀

没钱买机器,自己想办法。

委屈T^T

我听说混合精度训练可以从算法上缓解这个问题?

喵喵喵??

         

其实小夕的内心是拒绝的,就一台破Xp,再优化能快到哪里去呀T^T

燃鹅

 

小夕找了一份开源代码,结果刚开始跑小夕就震惊了!什么鬼?训练速度怎么这么快?出bug了吧????

一毛一样的模型、超参数和硬件环境,竟然可以获得2.X倍的加速。

 

关键的关键,这不是一个特例,在各类网络训练问题上都提速明显,遍地开花~~~      

混合精度训练

一切还要从2018年ICLR的一篇论文说起。。。

《MIXED PRECISION TRAINING》

 

这篇论文是百度&Nvidia研究院一起发表的,结合N卡底层计算优化,提出了一种灰常有效的神经网络训练加速方法,不仅是预训练,在全民finetune BERT的今天变得异常有用哇。而且小夕调研了一下,发现不仅百度的paddle框架支持混合精度训练,在Tensorflow和Pytorch中也有相应的实现。下面我们先来讲讲理论,后面再分析混合精度训练在三大深度学习框架中的打开方式

 

理论原理

训练过神经网络的小伙伴都知道,神经网络的参数和中间结果绝大部分都是单精度浮点数(即float32)存储和计算的,当网络变得超级大时,降低浮点数精度,比如使用半精度浮点数显然是提高计算速度,降低存储开销的一个很直接的办法。然而副作用也很显然,如果我们直接降低浮点数的精度直观上必然导致模型训练精度的损失。但是呢,天外有天,这篇文章用了三种机制有效地防止了模型的精度损失。待小夕一一说来o(* ̄▽ ̄*)ブ

 

权重备份(master weights)

我们知道半精度浮点数(float16)在计算机中的表示分为1bit的符号位,5bits的指数位和10bits的尾数位,所以它能表示的最小的正数即2^-24(也就是精度到此为止了)。当神经网络中的梯度灰常小的时候,网络训练过程中每一步的迭代(灰常小的梯度 ✖ 也黑小的learning rate)会变得更小,小到float16精度无法表示的时候,相应的梯度就无法得到更新。

 

论文统计了一下在Mandarin数据集上训练DeepSpeech 2模型时产生过的梯度,发现在未乘以learning rate之前,就有接近5%的梯度直接悲剧的变成0(精度比2^-24还要高的梯度会直接变成0),造成重大的损失呀/(ㄒoㄒ)/~~

      

还有更难的,假设迭代量逃过一劫准备奉献自己的时候。。。由于网络中的权重往往远大于我们要更新的量,当迭代量小于Float16当前区间内能表示的最小间隔的时候,更新也会失败(哭瞎┭┮﹏┭┮我怎么这么难鸭)

              

所以怎么办呢?作者这里提出了一个非常simple but effective的方法,就是前向传播和梯度计算都用float16,但是存储网络参数的梯度时要用float32!这样就可以一定程度上的解决上面说的两个问题啦~~~

 

我们来看一下训练曲线,蓝色的线是正常的float32精度训练曲线,橙色的线是使用float32存储网络参数的learning curve,绿色滴是不使用float32存储参数的曲线,两者一比就相形见绌啦。

     

损失放缩(loss scaling)

有了上面的master weights已经可以足够高精度的训练很多网络啦,但是有点强迫症的小夕来说怎么还是觉得有点不对呀o((⊙﹏⊙))o.

虽然使用float32来存储梯度,确实不会丢失精度了,但是计算过程中出现的指数位小于 -24 的梯度不还是会丢失的嘛!相当于用漏水的筛子从河边往村里运水,为了多存点水,村民们把储水的碗换成了大缸,燃鹅筛子依然是漏的哇,在路上的时候水就已经漏的木有了。。

 

于是loss scaling方法来了。首先作者统计了一下训练过程中激活函数梯度的分布情况,由于网络中的梯度往往都非常小,导致在使用FP16的时候右边有大量的范围是没有使用的。这种情况下, 我们可以通过放大loss来把整个梯度右移,减少因为精度随时变为0的梯度。

      

那么问题来了,怎么合理的放大loss呢?一个最简单的方法是常数缩放,把loss一股脑统一放大S倍。float16能表示的最大正数是2^15*(1+1-2^-10)=65504,我们可以统计网络中的梯度,计算出一个常数S,使得最大的梯度不超过float16能表示的最大整数即可。

 

当然啦,还有更加智能的动态调整(automatic scaling) o(* ̄▽ ̄*)ブ

我们先初始化一个很大的S,如果梯度溢出,我们就把S缩小为原来的二分之一;如果在很多次迭代中梯度都没有溢出,我们也可以尝试把S放大两倍。以此类推,实现动态的loss scaling。

              

运算精度(precison of ops)

精益求精再进一步,神经网络中的运算主要可以分为四大类,混合精度训练把一些有更高精度要求的运算,在计算过程中使用float32,存储的时候再转换为float16。

  • matrix multiplication: linear, matmul, bmm, conv

  • pointwise: relu, sigmoid, tanh, exp, log

  • reductions: batch norm, layer norm, sum, softmax

  • loss functions: cross entropy, l2 loss, weight decay

像矩阵乘法和绝大多数pointwise的计算可以直接使用float16来计算并存储,而reductions、loss function和一些pointwise(如exp,log,pow等函数值远大于变量的函数)需要更加精细的处理,所以在计算中使用用float32,再将结果转换为float16来存储。

 

总结陈词

混合精度训练做到了在前向和后向计算过程中均使用半精度浮点数,并且没有像之前的一些工作一样还引入额外超参,而且重要的是,实现非常简单却能带来非常显著的收益,在显存half以及速度double的情况下保持模型的精度,简直不能再厉害啦。

三大深度学习框架的打开方式

看完了硬核技术细节之后,我们赶紧来看看代码实现吧!如此强大的混合精度训练的代码实现不要太简单了吧????

Pytorch

导入Automatic Mixed Precision (AMP),不要998不要288,只需3行无痛使用!

from apex import amp
model, optimizer = amp.initialize(model, optimizer, opt_level="O1") # 这里是“欧一”,不是“零一”
with amp.scale_loss(loss, optimizer) as scaled_loss:scaled_loss.backward()

来看个例子,将上面三行按照正确的位置插入到自己原来的代码中就可以实现酷炫的半精度训练啦!

import torch
from apex import amp
model = ... 
optimizer = ...#包装model和optimizer
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")for data, label in data_iter: out = model(data) loss = criterion(out, label) optimizer.zero_grad() #loss scaling,代替loss.backward()with amp.scaled_loss(loss, optimizer) as scaled_loss:   scaled_loss.backward() 
optimizer.step()

Tensorflow

一句话实现混合精度训练之修改环境变量,在python脚本中设置环境变量

os.environ['TF_ENABLE_AUTO_MIXED_PRECISION'] = '1'

 

除此之外,也可以用类似pytorch的方式来包装optimizer。

Graph-based示例

opt = tf.train.AdamOptimizer()#add a line
opt = tf.train.experimental.enable_mixed_precision_graph_rewrite(opt,loss_scale='dynamic')train_op = opt.miminize(loss)

Keras-based示例

opt = tf.keras.optimizers.Adam()#add a line
opt = tf.train.experimental.enable_mixed_precision_graph_rewrite(opt,loss_scale='dynamic')model.compile(loss=loss, optimizer=opt)
model.fit(...)

 

PaddlePaddle

一句话实现混合精度训练之添加config(惊呆????毕竟混合精度训练是百度家提出的,内部早就熟练应用了叭)

--use_fp16=true

举个栗子,基于BERT finetune XNLI任务时,只需在执行时设置use_fp16为true即可。

export FLAGS_sync_nccl_allreduce=0
export FLAGS_eager_delete_tensor_gb=1
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7BERT_BASE_PATH="chinese_L-12_H-768_A-12"
TASK_NAME='XNLI'
DATA_PATH=/path/to/xnli/data/
CKPT_PATH=/path/to/save/checkpoints/python -u run_classifier.py --task_name ${TASK_NAME} \--use_fp16=true \  #!!!!!!add a line--use_cuda true \--do_train true \--do_val true \--do_test true \--batch_size 32 \--in_tokens false \--init_pretraining_params ${BERT_BASE_PATH}/params \--data_dir ${DATA_PATH} \--vocab_path ${BERT_BASE_PATH}/vocab.txt \--checkpoints ${CKPT_PATH} \--save_steps 1000 \--weight_decay  0.01 \--warmup_proportion 0.1 \--validation_steps 100 \--epoch 3 \--max_seq_len 128 \--bert_config_path ${BERT_BASE_PATH}/bert_config.json \--learning_rate 5e-5 \--skip_steps 10 \--num_iteration_per_drop_scope 10 \--verbose true

  • 跨平台NLP/ML文章索引

  • 训练太慢,GPU利用率上不去?快来试试别人家的tricks

  • 高维空间的局部最优点分析

  • 神经网络调参指南

  • 如何与GPU服务器优雅交互

求关注 求投喂 拉你进高端群哦~

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/480941.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Java】Junit、反射和注解的笔记

1 Junit 黑盒测试:不需要写代码,给输入值,看程序输出是否符合期望 白盒测试:需要写代码,关注程序具体的执行流程 Junit —> 白盒测试 步骤 定义一个测试类(测试用例) 【命名:类…

我对Spring的理解

1、什么是Spring? Spring是Java企业级应用的开源开发框架。Spring主要用来开发Java应用,但是有些扩展是针对构建J2EE平台的web应用。Spring框架目标是简化Java企业级应用开发,并通过POJO为基础的编程模型促进良好的编程习惯。 2、使用Spring…

k-means+python︱scikit-learn中的KMeans聚类实现( + MiniBatchKMeans)

版权声明&#xff1a;博主原创文章&#xff0c;微信公众号&#xff1a;素质云笔记,转载请注明来源“素质云博客”&#xff0c;谢谢合作&#xff01;&#xff01; https://blog.csdn.net/sinat_26917383/article/details/70240628 </div><link rel"stylesh…

想让推荐和搜索引擎更聪明?基于知识图谱的篇章标签生成

一只小狐狸带你解锁NLP/ML/DL秘籍正文来源&#xff1a;丁香园大数据NLP 老板&#xff5e;我们的推荐系统笨笨的你怎么对文档处理的这么糙&#xff01;抽个关键词就应付过去了&#xff1f;啊啊啊我错惹&#xff0c;那那&#xff0c;不用关键词用什么呢&#xff1f;知识图…

论文浅尝 | Dynamic Weighted Majority for Incremental Learning

Yang Lu , Yiu-ming Cheung , Yuan Yan Tang. Dynamic Weighted Majority for Incremental Learning ofImbalanced Data Streams with Concept Drift. In Proceedings of the Twenty-Sixth International Joint Conference on Artificial Intelligence (IJCAI-17)论文链接&…

【JavaWeb】数据库基础复习

1 MySQL 数据库特点&#xff1a; 持久化存储数据&#xff0c;数据库就是一个文件系统便于存储和管理数据使用统一的方式操作数据库 启动MySQL服务&#xff1a; 管理员cmd&#xff1a;net start mysql 停止MySQL服务&#xff1a; 管理员cmd&#xff1a;net stop mysql 打开服…

Python的多行输入与多行输出

因为在OJ上做编程&#xff0c;要求标准输入&#xff0c;特别是多行输入。特意查了资料&#xff0c;自己验证了可行性。if __name__ "__main__":strList []for line in sys.stdin: #当没有接受到输入结束信号就一直遍历每一行tempStr line.split()#对字符串利用空…

微服务Dubbo和SpringCloud架构设计、优劣势比较

一、微服务介绍 微服务架构是互联网很热门的话题&#xff0c;是互联网技术发展的必然结果。它提倡将单一应用程序划分成一组小的服务&#xff0c;服务之间互相协调、互相配合&#xff0c;为用户提供最终价值。虽然微服务架构没有公认的技术标准和规范或者草案&#xff0c;但业界…

搜索引擎核心技术与算法 —— 词项词典与倒排索引优化

一只小狐狸带你解锁NLP/ML/DL秘籍作者&#xff1a;QvQ老板&#xff5e;我会写倒排索引啦&#xff01;我要把它放进咱们自研搜索引擎啦&#xff01;我呸&#xff01;你这种demo级代码&#xff0c;都不够当单元测试的&#xff01;嘤嘤嘤&#xff0c;课本上就是这样讲的呀?!来来&…

论文浅尝 | Distant Supervision for Relation Extraction

Citation: Ji,G., Liu, K., He, S., & Zhao, J. (2017). Distant Supervision for RelationExtraction with Sentence-Level Attention and Entity Descriptions. Ai,3060–3066.动机关系抽取的远程监督方法通过知识库与非结构化文本对其的方式&#xff0c;自动标注数据&am…

使用sklearn做单机特征工程

目录 1 特征工程是什么&#xff1f;2 数据预处理  2.1 无量纲化    2.1.1 标准化    2.1.2 区间缩放法    2.1.3 标准化与归一化的区别  2.2 对定量特征二值化  2.3 对定性特征哑编码  2.4 缺失值计算  2.5 数据变换  2.6 回顾3 特征选择  3.1 Filte…

【JavaWeb】JDBC的基本操作和事务控制+登录和转账案例

1 JDBC操作数据库 1.1 连接数据库 首先导入jar包到lib public class JdbcDemo1 {public static void main(String[] args) throws ClassNotFoundException, SQLException {//1.注册驱动Class.forName("com.mysql.jdbc.Driver");//2.获取数据库连接对象Connection…

Restful、SOAP、RPC、SOA、微服务之间的区别

一、介绍Restful、SOAP、RPC、SOA以及微服务 1.1、什么是Restful&#xff1f; Restful是一种架构设计风格&#xff0c;提供了设计原则和约束条件&#xff0c;而不是架构&#xff0c;而满足这些约束条件和原则的应用程序或设计就是 Restful架构或服务。 主要的设计原则&#xf…

详解深度语义匹配模型DSSM和他的兄弟姐妹

一只小狐狸带你解锁NLP/ML/DL秘籍正文作者&#xff1a;郭耀华正文来源&#xff1a;https://www.cnblogs.com/guoyaohua/p/9229190.html前言在NLP领域&#xff0c;语义相似度的计算一直是个难题&#xff1a;搜索场景下Query和Doc的语义相似度、feeds场景下Doc和Doc的语义相似度、…

行业新闻 | 阿里发力知识图谱研究 悉数囊括顶尖学者探讨合作

12 月 20 日&#xff0c;阿里巴巴联合中国中文信息学会语言与知识计算专委会(KG专委)举办的知识图谱研讨会在杭州召开。研讨会由阿里巴巴集团副总裁墙辉&#xff08;花名&#xff1a;玄难&#xff09;主持&#xff0c;国内知识图谱领域多位顶级专家参加此次研讨会。在阿里巴巴持…

Kaggle入门,看这一篇就够了

New Update&#xff1a;之前发表了这篇关于 Kaggle 的专栏&#xff0c;旨在帮助对数据科学( Data Science )有兴趣的同学们更好的了解这个平台。专栏发表至今收到了不少的关注和肯定&#xff0c;还有很多小伙伴私信相关的问题。因此&#xff0c;我特邀了一位海外一线 Data Scie…

【JavaWeb】JDBC优化 之 数据库连接池、Spring JDBC

1 数据库连接池 为什么要使用数据库连接池&#xff1f; 数据库连接是一件费时的操作&#xff0c;连接池可以使多个操作共享一个连接使用连接池可以提高对数据库连接资源的管理节约资源且高效 概念&#xff1a;数据库连接池其实就是一个容器&#xff0c;存放数据库连接的容器…

Java远程通讯技术及原理分析

在分布式服务框架中&#xff0c;一个最基础的问题就是远程服务是怎么通讯的&#xff0c;在Java领域中有很多可实现远程通讯的技术&#xff0c;例如&#xff1a;RMI、MINA、ESB、Burlap、Hessian、SOAP、EJB和JMS等&#xff0c;这些名词之间到底是些什么关系呢&#xff0c;它们背…

CUDA层硬件debug之路

前记 众所周知&#xff0c;夕小瑶是个做NLP的小可爱。 虽然懂点DL框架层知识&#xff0c;懂点CUDA和底层&#xff0c;但是我是做算法的哎&#xff0c;平时debug很少会遇到深度学习框架层的bug&#xff08;上一次还是三年前被pytorch坑&#xff09;&#xff0c;更从没遇到过CUDA…

研讨会 | 知识图谱大咖云集阿里,他们都说了啥

前言12月20日&#xff0c;由阿里巴巴联合中国中文信息学会语言与知识计算专委会(KG专委)举办的知识图谱研讨会在杭州召开。研讨会由阿里巴巴集团副总裁墙辉&#xff08;玄难&#xff09;主持&#xff0c;知识图谱领域国内知名专家参与了此次研讨。在阿里巴巴持续发力知识图谱这…