在K40小破卡训练50层BERT Large的宝藏trick

前言

虽然TPU的显存令人羡慕,但是由于众所周知的原因,绝大部分人还是很难日常化使用的。英伟达又一直在挤牙膏,至今单卡的最大显存也仅仅到32G(参考V100、DGX-2)。然而,训练一个24层的BERT Large模型的时候,如果sequence length开满512,那么batch size仅仅开到8(有时候能到10)就把这寥寥32G的显存打满了。如果想训练一个48层乃至100层的BERT Large,那完全是土豪们的游戏了,需要疯狂的模型并行+分布式多机训练。

但!是!万能的小夕前不久在Daxiang Dong大佬的安利下,发现了 @陈天奇 大佬2016年的一篇宝藏paper!

v2-f49c8e494a2e751c9d0af79768528a50_b.jpg

传送门:arxiv.org/pdf/1604.0617

简单的划一下重点:

这篇paper用时间换空间的思想,在前向时只保存部分中间节点,在反向时重新计算没保存的部分。论文通过这种机制,在每个batch只多计算一次前向的情况下,把n层网络的占用显存优化到了 ( √)O(n)。在极端情况下,仍可用 ( )O(nlogn)的计算时间换取到 ( )O(logn)的显存占用。在论文的实验中,他们成功将将1000层的残差网络从48G优化到了7G。且,这种方法同样可以直接应用于RNN结构中。

看完摘要,瞬间感觉在小破卡上训练BERT Large有救了!!!

此外,来快速过一遍paper中最重要的三点结论:

  1. 梯度计算等价,理论上没有精度损失

v2-5587763e2ba1543d78cf9bb09d40f34e_b.jpeg

2. 可以节省4倍+的显存开销

v2-d229b37fea16b7b96c6a328a2f321dbe_b.jpg

3. 训练速度仅仅会被拖慢30%

v2-116f907db1c1be19fe3bc4715604dad6_b.jpg

不过论文发表在2016年,当时还没有BERT,不过Baidu Paddle团队补了一个BERT的实验结果,发现在BERT上面只用22.5%的训练速度损失就能换来5倍+的显存开销节省!相关实验在本文末尾,不着急,接下来我们先一起分析一下在训练阶段时显存为什么容易不足。

感谢Baidu Paddle团队提供本节图文素材和测试数据

训练阶段显存为何不足

深度学习中,网络的一次训练包含前向计算、后向计算和优化三个步骤。

v2-3979e52c54fa910f8545a19c6c69329c_b.jpg

在这个过程中,前向计算会输出大量的隐层变量Tensor,当模型层数加深时,Tensor数量可达成千上万个。如Bert Large模型,单个Tensor可达到1GB,这些Tensor在显存中累积,显存很快就爆掉了

下图是Bert Large模型在一次训练过程中的显存使用情况,可以明显看到在前向计算过程中,显存累积趋势是一个陡峭的上升直线。而在反向计算过程中,这些隐层Tensor又会很快地被消耗掉,又是一个陡峭的下降曲线,显存直接降到低位。

v2-368a254252a1a1e603f8e7c1f73134f9_b.jpg

那么问题来了,为什么不直接删除这些前向计算的Tensor呢?

答案很简单,因为这些隐层的Tensor在反向的时会被用到(手动狗头

来个简单的证明。

假设前向计算中有一个矩阵乘法计算:

Y = W × X

对W求梯度:

v2-0a590e0246a9026561604e1ba0a87720_b.png

很容易发现,对W求梯度的公式里有X,而X就是那个巨能吃显存的隐层Tensor!

那我们是否可以暂时扔掉这些隐层Tensor,在反向计算时再把它们重新生成出来呢?当然可以,这正是上面这篇paper的思想。

重计算

顾名思义,"重计算"就是让每个训练迭代过程做两次前向计算,看起来有点奇怪,实际上却非常有效!对于刚刚那个吃显存的Bert Large,支持重计算机制后,显存占用直接从175GB降低到20GB,陡峭的显存上升直线变成了缓慢增长的Z形曲线,如下图所示。

v2-8b2e33037ced305aff661d0920f93e1e_b.jpg

核心思想是将前向计算分割成多个段,将每个段的起始Tensor作为这个段的检查点(checkpoints)。前向计算时,除了检查点以外的其他隐层Tensor占有的显存可以及时释放。反向计算用到这些隐层Tensor时,从前一个检查点开始,重新进行这个段的前向计算,就可以重新获得隐层Tensor。

重计算机制有点像玩单机游戏。每过一个关卡就会保存一个检查点,而隐层Tensor就相当于游戏中任何一个时刻的图像。普通的训练方式是打通一遍游戏,并且将游戏中所有时刻的图像保存下来;而重计算机制的思路是先把游戏通关,保存检查点,后面当收到某一时刻图像的请求时,再重打一遍这一关卡就可以了。

v2-ffbf879f3f5b9ad5056a892e7badadaf_b.jpg

如下图,举一个简单的例子,添加重计算机制前,前向计算中需要存储的隐层是4个红点;添加重计算机制后,需要存储的隐层变为2个蓝点, 从而节省了这部分内存。

v2-3bd3128214da4f8573d25e101892f960_b.jpg

虽然时间也是宝贵的,但重计算方法的性价比很高。在论文的实验中,作者用30%的计算时间换取了4倍的内存空间。并且重计算只是重复了一次前向的过程,理论上精度没有任何损失

这么宝藏的算法当然也少不了开源实现。

开源实现

调研了一波,似乎TF没有原生支持,但是生态里有OpenAI的第三方实现;pytorch和paddlepaddle中都有原生API支持

  • Pytorch:
    • torch.utils.checkpoint
  • PaddlePaddle:
    • optimizer.RecomputeOptimizer

不过pytorch的文档比较略,也没有提供更细致的示例和相关数据,有兴趣的小伙伴自行试一下。paddle框架中提供了详细到哭的文档,甚至还有一个现成的BERT+重计算的例子,以及非常详细的实验测试结果。这里直接贴过来(真香系列

Paddle中实现显存重计算大体分为三步:

  1. 定义一个经典的优化器,如SGD优化器;
  2. 在外面包一层重计算优化器;
  3. 设置检查点。

以MLP为例,只需要增加两行代码就可以进入重计算模式

        import paddle.fluid as fluid
# 定义MLP
def mlp(input_x, input_y, hid_dim=128, label_dim=2):print(input_x)fc_1 = fluid.layers.fc(input=input_x, size=hid_dim)prediction = fluid.layers.fc(input=[fc_1], size=label_dim, act='softmax')cost = fluid.layers.cross_entropy(input=prediction, label=input_y)sum_cost = fluid.layers.reduce_mean(cost)return sum_cost, fc_1, predictioninput_x = fluid.layers.data(name="x", shape=[32], dtype='float32')
input_y = fluid.layers.data(name="y", shape=[1], dtype='int64')
cost, fc_1, pred = mlp(input_x, input_y)# 定义RecomputeOptimizersgd = fluid.optimizer.SGD(learning_rate=0.01)
recompute_optimizer = fluid.optimizer.RecomputeOptimizer(sgd)
# 设置checkpoints
recompute_optimizer._set_checkpoints([fc_1, pred])
# 运行优化算法
recompute_optimizer.minimize(cost)

该示例github链接:

github.com/PaddlePaddle

此外,官方还给出了一个BERT中做重计算的示例

github链接:
github.com/PaddlePaddle

BERT实验结论(划重点

根据上面paddle官方提供的BERT示例和实验结果,得出以下几个结论

结论一

在32GB显存的Tesla V100显卡上应用重计算机制,可以训练更大、更深的深度学习模型。当num_tokens为4096(batch size=32,seqlen=128)时,可以训练100层的Bert网络。

v2-c2ecacc5a1eb24b484b0c124deb623fb_b.jpg

从Github的实验结果也可以看出,显存上的收益比速度的损失要大很多:

v2-9bad0123f620fa2766881e69c00078ef_b.jpg

在batch_size上提升了5倍,速度只降低了约1/5,且精度没有损失。

结论二

模型训练的batch size最大可提升为原来的5倍+,且只有少量的速度损失。

重计算机制在Bert Large这一模型上收益最大,最大batch size从93提升到562!而在VGG-16这种比较浅的模型上,重计算机制的收益则比较小。这充分符合重计算机制的设计理念:为了训练更大、更深的模型。

结论三

在古董显卡Tesla K40显卡(12G显存)上,训练BERT Large时batch size可以开到130

v2-e098ec96feff953fcd691fa7594da340_b.jpg

最后,希望本文可以帮助大家在小破卡上尽情训练BERT Large~

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/480654.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

领域应用 | 推荐算法不够精准?让知识图谱来解决

本文转载自公众号:微软研究院AI头条。 编者按:我们几乎每天都会接收到各种各样的推荐信息,从新闻、购物到吃饭、娱乐。个性化推荐系统作为一种信息过滤的重要手段,可以依据我们的习惯和爱好推荐合适的服务。但传统的推荐系统容易出…

BERT重计算:用22.5%的训练时间节省5倍的显存开销(附代码)

一只小狐狸带你解锁 炼丹术&NLP 秘籍作者:夕小瑶、rumor酱前言虽然TPU的显存令人羡慕,但是由于众所周知的原因,绝大部分人还是很难日常化使用的。英伟达又一直在挤牙膏,至今单卡的最大显存也仅仅到32G(参考V100、D…

去腾讯等BAT面试完的Mysql面试55题总结,含答案大赠送!

【BAT面试:mysql 面试55题】 1、一张表里面有ID自增主键,当insert了17条记录之后,删除了第15,16,17条记录,再把mysql重启,再insert一条记录,这条记录的ID是18还是15 ? 2、mysql的技术特点是什…

这十套练习,教你如何使用Pandas做数据分析

这十套练习,教你如何用Pandas做数据分析Pandas是入门Python做数据分析所必须要掌握的一个库。本文内容由科赛网翻译整理自Github,建议读者完成科赛网 从零上手Python关键代码 和 Pandas基础命令速查表 教程学习的之后,点击本篇Notebook右上角…

预训练模型超全知识点梳理与面试必备高频FAQ

一只小狐狸带你解锁 炼丹术&NLP 秘籍作者:JayLou娄杰来源:https://zhuanlan.zhihu.com/p/115014536预训练模型(Pre-trained Models,PTMs)的出现将NLP带入了一个全新时代。2020年3月18日,邱锡鹏老师发表了关于NLP预训练模型的综述《Pre-tr…

阿里P8架构师谈:MySQL数据库的索引原理、与慢SQL优化的5大原则

MySQL凭借着出色的性能、低廉的成本、丰富的资源,已经成为绝大多数互联网公司的首选关系型数据库。虽然性能出色,但所谓“好马配好鞍”,如何能够更好的使用它,已经成为开发工程师的必修课,我们经常会从职位描述上看到诸…

论文浅尝 | 利用类比推理优化知识图谱向量表示

链接:https://arxiv.org/pdf/1705.02426.pdf本文的主要创新点就是把类比推理应用到 KG embedding 中,通过对模型的 score function 添加某些约束来捕获 KG 中类比结构的信息,进而优化 KG 中实体和关系的 embedding 表示,并在 FB15…

推荐 10 个饱受好评且功能独特的开源人工智能项目

来自:开源中国链接:https://my.oschina.net/editorial-story/blog/1592254推荐 10 个饱受好评且功能独特的开源人工智能项目关于人工智能的项目,相信大家都看过或者用过不少了,但它们的大多数看上去都十分“高大上”,让…

如何以初学者角度写好一篇国际学术论文?

一只小狐狸带你解锁 炼丹术&NLP 秘籍人工智能顶会论文之争越来越激烈了,CVPR、AAAI、ICLR等各大会议虽然录取率逐年降低,但是投稿论文数量却在逐年增加。虽说发论文不是衡量一位学者的学术能力的唯一标准,但确是极为重要的标准。一篇好的…

领域应用 | 如何将知识图谱特征学习应用到推荐系统?

本文转载自公众号:微软研究院AI头条。 编者按:在上周发表的“推荐算法不够精准?让知识图谱来解决”一文中,我们为大家介绍了日常生活中几乎每天都会用到的推荐系统,以及用来提高推荐系统精准性、多样性和可解释性的推荐…

阿里P8架构师谈:MySQL行锁、表锁、悲观锁、乐观锁的特点与应用

我们在操作数据库的时候,可能会由于并发问题而引起的数据的不一致性(数据冲突)。如何保证数据并发访问的一致性、有效性,是所有数据库必须解决的一个问题,锁的冲突也是影响数据库并发访问性能的一个重要因素&#xff0…

谷歌、微软、OpenAI等巨头七大机器学习开源项目 看这篇就够了

在人工智能行业,2015-2016 出现了一个不同寻常的趋势:许多重量级机器学习项目纷纷走向开源,与全世界的开发者共享。加入这开源大潮的,不仅有学界师生,更有国内外的互联网巨头们:国内有百度和腾讯&#xff0…

推荐系统的发展与简单回顾

“本文结合百度和支付宝两段推荐系统相关的实习经历,针对工业界的模型发展做了简单梳理与回顾,涵盖表示学习,深度学习,强化学习知识图谱以及多任务学习”表示学习和深度学习在推荐系统中的应用是目前工业界比较成熟的,但是与强化学…

论文浅尝 | 嵌入常识知识的注意力 LSTM 模型用于特定目标的基于侧面的情感分析...

MaY, Peng H, Cambria E. Targeted aspect-based sentiment analysis via embedding commonsense knowledge into an attentive LSTM[C]//AAAI. 2018.任务简介特定目标的基于侧面的情感分析,在原来基于侧面的情感分析的基础上,进一步挖掘细粒度的信息&am…

阿里P8架构师谈:MySQL有哪些存储引擎,各自的优缺点,应用场景

经常面试都会问到MYSQL有哪些存储引擎,以及各自的优缺点。今天主要分享常见的存储引擎:MyISAM、InnoDB、MERGE、MEMORY(HEAP)、BDB(BerkeleyDB)等,以及最常用的MyISAM与InnoDB两个引擎 &#xf…

TensorFlow 全网最全学习资料汇总之TensorFlow的技术应用

谷歌于2015年11月发布了全新人工智能系统TensorFlow。该系统可被用于语音识别或照片识别等多项机器深度学习领域,主要针对2011年开发的深度学习基础架构DistBelief进行了各方面的改进,它可在小到一部智能手机、大到数千台数据中心服务器的各种设备上运行…

13个offer,8家SSP,谈谈我的秋招经验

本文转载自公众号“夕小瑶的卖萌屋”,专业带逛互联网算法圈的神操作 -----》我是传送门 关注后,回复以下口令: 回复【789】 :领取深度学习全栈手册(含NLP、CV海量综述、必刷论文解读) 回复【入群】&#xf…

领域应用 | 知识图谱的技术与应用

本文转载自公众号:贪心科技。作者 | 李文哲,人工智能、知识图谱领域专家导读:从一开始的Google搜索,到现在的聊天机器人、大数据风控、证券投资、智能医疗、自适应教育、推荐系统,无一不跟知识图谱相关。它在技术领域的…

阿里P8架构师谈:MySQL慢查询优化、索引优化、以及表等优化总结

MySQL优化概述 MySQL数据库常见的两个瓶颈是:CPU和I/O的瓶颈。 CPU在饱和的时候一般发生在数据装入内存或从磁盘上读取数据时候。 磁盘I/O瓶颈发生在装入数据远大于内存容量的时候,如果应用分布在网络上,那么查询量相当大的时候那么平瓶颈就…

医药领域知识图谱快速及医药问答项目

QABasedOnMedicaKnowledgeGraph self-implement of disease centered Medical graph from zero to full and sever as question answering base. 从无到有搭建一个以疾病为中心的一定规模医药领域知识图谱,并以该知识图谱完成自动问答与分析服务。 项目介绍 本项…