Perseus-BERT——业内性能极致优化的BERT训练方案【阿里云弹性人工智能】

一,背景——横空出世的BERT全面超越人类

2018年在自然语言处理(NLP)领域最具爆炸性的一朵“蘑菇云”莫过于Google Research提出的BERT(Bidirectional Encoder Representations from Transformers)模型。作为一种新型的语言表示模型,BERT以“摧枯拉朽”之势横扫包括语言问答、理解、预测等各项NLP锦标的桂冠,见图1和图2。

【图1】SQuAD是基于Wikipedia文章的标准问答数据库的NLP锦标。目前SQuAD2.0排名前十名均为基于BERT的模型(图中列出前五名),前20名有16席均是出自BERT 

【图2】GLUE是一项通用语言理解评估的benchmark,包含11项NLP任务。BERT自诞生日起长期压倒性霸占榜首(目前BERT排名第二,第一为Microsoft提交的BIGBIRD模型,由于没有URL链接无从知晓模型细节,网传BIGBIRD的名称上有借鉴BERT BIG模型之嫌)

业内将BERT在自然语言处理的地位比作ResNet之于计算机视觉领域的里程碑地位。在BERT横空出世之后,所有的自然语言处理任务都可以基于BERT模型为基础展开。

一言以蔽之,现如今,作为NLP的研究者,如果不了解BERT,那就是落后的科技工作者;作为以自然语言处理为重要依托的科技公司,如果不落地BERT,那就是落后生产力的代表。

 

二,痛点——算力成为BERT落地的拦路虎

BERT强大的原因在哪里?让我们拂去云霭,窥探下硝烟下的奥秘。

BERT模型分为预训练模型(Pretrain)和精调模型(Finetune)。Pretrain模型为通用的语言模型。Finetune只需要在Pretrain的基础上增加一层适配层就可以服务于从问答到语言推理等各类任务,无需为具体任务修改整体模型架构,如图3所示。这种设计方便BERT预处理模型适配于各类具体NLP模型(类似于CV领域基于ImageNet训练的各种Backbone模型)。

【图3】左图基于BERT pretrain的模型用于语句问答任务(SQuAD)的finetune模型,右图为用于句对分类(Sentence Pair Classification Tasks)的finetune模型。他们均是在BERT Pretrain模型的基础上增加了一层具体任务的适配层

 

因此,BERT的强大主要归功于精确度和鲁棒性俱佳的Pretrain语言模型。大部分的计算量也出自Pretrain模型。其主要运用了以下两项技术,都是极其耗费计算资源的模块。

  1. 双向Transformer架构

图4可见,与其他pre-training的模型架构不同,BERT从左到右和从右到左地同时对语料进行transformer处理。这种双向技术能充分提取语料的时域相关性,但同时也大大增加了计算资源的负担。【关于Transformer是Google 17年在NLP上的大作,其用全Attention机制取代NLP常用的RNN及其变体LSTM等的常用架构,大大改善了NLP的预测准确度。本文不展开,该兴趣的同学可以自行搜索一下】。

【图4】Pretrain架构对比。其中OpenAI GPT采用从左到右的Transformer架构,ELMo采用部分从左到右和部分从右到左的LSTM的级联方式。BERT采用同时从左到右和从右到左的双向Transformer架构。

 

  1. 词/句双任务随机预测

BERT预训练模型在迭代计算中会同时进行单词预测和语句预测两项非监督预测任务。

其一,单词预测任务对语料进行随机MASK操作(Masked LM)。在所有语料中随机选取15%的单词作为Mask数据。被选中Mask的语料单词在迭代计算过程中80%时间会被掩码覆盖用于预测、10%时间保持不变、10%时间随机替换为其他单词,如图5所示。

其二,语句预测任务(Next Sentence Prediction)。对选中的前后句A和B,在整个迭代预测过程中,50%的时间B作为A的真实后续语句(Label=IsNext),另外50%的时间则从语料库里随机选取其他语句作为A的后续语句(Label=NotNext),如图5所示

【图5】词/句双任务随机预测输入语料实例。蓝框和红框为同一个语料输入在不同时刻的随机状态。对单词预测任务,蓝框中的“went”为真实数据,到了红框则被[MASK],红框中的“the” 则相反;对于语句预测任务,蓝框中的句组为真实的前后句,而红框中的句组则为随机的组合。

这种随机选取的单词/语句预测方式在功能上实现了非监督数据的输入的功能,有效防止模型的过拟合。但是按比例随机选取需要大大增加对语料库的迭代次数才能消化所有的语料数据,这给计算资源带来了极大的压力。

综上,BERT预处理模型功能需要建立在极强的计算力基础之上。BERT论文显示,训练BERT BASE 预训练模型(L=12, H=768, A=12, Total Parameters=110M, 1000,000次迭代)需要1台Cloud TPU工作16天;而作为目前深度学习主流的Nvidia GPU加速卡面对如此海量的计算量更是力不从心。即使是目前主流最强劲的Nvidia V100加速卡,训练一个BERT-Base Pretrain模型需要一两个月的时间。而训练Large模型,需要花至少四五个月的时间。

花几个月训练一个模型,对于绝大部分在GPU上训练BERT的用户来说真是伤不起。

三,救星——擎天云加速框架为BERT披荆斩棘

 阿里云弹性人工智能团队依托阿里云强大的基础设施资源打磨业内极具竞争力的人工智能创新方案。基于BERT的训练痛点,团队打造了擎天优化版的Perseus-BERT, 极大地提升了BERT pretrain模型的训练速度。在云上一台V100 8卡实例上,只需4天不到即可训练一份BERT模型。

Perseus-BERT是如何打造云上最佳的BERT训练实践?以下干货为您揭秘Perseus-BERT的独门绝技。

1.  Perseus 统一分布式通信框架 —— 赋予BERT分布式训练的轻功

Perseus(擎天)统一分布式通信框架是团队针对人工智能云端训练的痛点,针对阿里云基础设施极致优化的分布式训练框架。其可轻便地嵌入主流人工智能框架的单机训练代码,在保证训练精度的同时高效地提升训练的多机扩展性。擎天分布式框架的干货介绍详见团队另一篇文章《Perseus(擎天):统一深度学习分布式通信框架》

针对tensorflow代码的BERT,Perseus提供horovod的python api方便嵌入BERT预训练代码。基本流程如下:

  • 让每块GPU对应一个Perseus rank进程;
  • 对global step和warmup step做基于rank数的校准;
  • 对训练数据根据rank-id做划分;
  • 给Optimizer增加DistributeOptimizer的wrapper。

值得注意的是,BERT源码用的自定义的Optimizer,在计算梯度时采用了以下api

grads = tf.gradients(loss, tvars)

Perseus的DistributeOptimizer继承标准的Optimizer实现,并在`compute_gradients` api 上实现分布式的梯度更新计算。因此对grads获取做了如下微调

grads_and_vars  = optimizer.compute_gradients(loss, tvars)grads = list()for grad, var in grads_and_vars:grads.append(grad)

 

2.  混合精度训练和XLA编译优化——提升BERT单机性能的内功

混合精度

在深度学习中,混合精度训练指的是float32和float16混合的训练方式,一般的混合精度模式如图6所示

图6】混合精度训练示例。在Forward+Backward计算过程中用float16做计算,在梯度更新时转换为float32做梯度更新。

混合梯度对Bert训练带来如下好处,

  • 增大训练时的batch size和sequence_size以保证模型训练的精度。

      目前阿里云上提供的主流的Nvidia显卡的显存最大为16GB,对一个BERT-Base模型在float32模式只能最高设置为sequence_size=256,batch_size=26。BERT的随机预测模型设计对sequence_size和batch_size的大小有一定要求。为保证匹配BERT的原生训练精度,需要保证sequece_size=512的情况下batch_size不小于16。Float16的混合精度可以保证如上需求。

  • 混合精度能充分利用硬件的加速资源。

      NVidia从Volta架构开始增加了Tensor Core资源,这是专门做4x4矩阵乘法的fp16/fp32混合精度的ASIC加速器,一块V100能提供125T的Tensor Core计算能力,只有在混合精度下计算才能利用上这一块强大的算力。

   受限于float16的表示精度,混合精度训练的代码需要额外的编写,NVidia提供了在Tensorflow下做混合精度训练的教程 。其主要思路是通过tf.variable_scope的custom_getter 参数保证存储的参数为float32并用float16做计算。

   在BERT预训练模型中,为了保证训练的精度,Perseus-BERT没有简单的利用custom_getter参数,而是显式指定训地参数中哪些可以利用float16不会影响精度,哪些必须用float32已保证精度。我们的经验如下:

  • Embedding部分要保证float32精度;
  • Attetion部分可以利用float16加速;
  • Gradients相关的更新和验证需要保证float32精度;
  • 非线性激活等模块需要保证float32精度。

XLA编译器优化

XLA是Tensorflow新近提出的模型编译器,其可以将Graph编译成IR表示,Fuse冗余Ops,并对Ops做了性能优化、适配硬件资源。然而官方的Tensorflow release并不支持xla的分布式训练,为了保证分布式训练可以正常进行和精度,我们自己编译了带有额外patch的tensorflow来支持分布式训练,Perseus-BERT 通过启用XLA编译优化加速训练过程并增加了Batch size大小。

3.  数据集预处理的加速

Perseus BERT 同时对文本预处理做的word embedding和语句划分做了并行化的优化。这里就不展开说明。

 

四,性能——计算时间单位从月降低到天

图7展示了Perseus BERT在P100实例上的性能,与开源主流的horovod相比,Peseus-BERT双机16卡的分布式性能是前者的5倍之多。

目前某大客户已在阿里云P100集群上大规模上线了Perseus BERT,用10台4卡P100只需要2.5天即可训练完成业务模型,如果用开源的horovod(Tensorflow分布式性能优化版)大概需要1个月的时间

【图7】Bert在阿里云上P100实例的对比(实验环境Bert on P100; Batch size: 22 ;Max seq length: 256 ;Data type:float32; Tensorflow 1.12; Perseus: 0.9.1;Horovod: 0.15.2)

 

为了和Google TPU做对比,我们量化了TPU的性能,性能依据如图8。一个Cloud TPU可计算的BERT-Base性能 256 *(1000000/4/4/24/60/60) = 185 exmaples/s。 而一台阿里云上的V100 单机八卡实例在相同的sequence_size=512下, 通过Perseus-BERT优化的Base模型训练可以做到 680 examples/s,接近一台Cloud TPU的4倍性能。对一台Cloud TPU花费16天才能训练完的BERT模型,一台阿里云的V100 8卡实例只需要4天不到便可训练完毕。

【图8】BERT Pretain在Google Cloud TPU上的性能依据

【图8】BERT Pretain在Google Cloud TPU上的性能依据

五,总结——基于阿里云基础设施的AI极致性能优化

弹性人工智能团队一直致力基于阿里云基础设施的AI极致性能优化的创新方案。Perseus-BERT就是一个非常典型的案例,我们在框架层面上基于阿里云的基础设施做深度优化,充分释放阿里云上基础资源的计算能力,让阿里云的客户充分享受云上的AI计算优势,让天下没有难算的AI。


原文链接
本文为云栖社区原创内容,未经允许不得转载。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/519776.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Kubernetes的共享GPU集群调度

问题背景 全球主要的容器集群服务厂商的Kubernetes服务都提供了Nvidia GPU容器调度能力,但是通常都是将一个GPU卡分配给一个容器。这可以实现比较好的隔离性,确保使用GPU的应用不会被其他应用影响;对于深度学习模型训练的场景非常适合&#…

华为云WeLink正式发布,这是更懂企业的智能工作平台一枚!

今日,华为云在京发布智能工作平台WeLink。 对此,华为云副总裁、联接与协同业务总裁薛浩表示:“华为云WeLink源自华为数字化转型实践,是更懂企业的智能工作平台,具备智能高效、安全可靠、开放共赢三大核心优势&#xff…

项目管理过程组和知识领域 简介重点记忆

文章目录一、5大过程组 10个知识领域 49个过程二、10个知识领域 49个过程的输入、工具、输出一、5大过程组 10个知识领域 49个过程 知识领域启动过程组规划过程组执行过程组监控过程组收尾过程组4.项目整合管理4.1 制定项目章程4.2 制定项目管理计划4.3 指导与管理…

一致性协议浅析:从逻辑时钟到Raft

前言 春节在家闲着没事看了几篇论文,把一致性协议的几篇论文都过了一遍。在看这些论文之前,我一直有一些疑惑,比如同样是有Leader和两阶段提交,Zookeeper的ZAB协议和Raft有什么不同,Paxos协议到底要怎样才能用在实际工…

PMP 随堂笔记

CPi挣值管理 临界比值 不属于挣值管理 临界比值 1为分界点 党校与1时,差 大于1时为好 成本激励由有3种场景: 第一种场景:超出目标费用 目标10w 利润1w 分摊比例70/30 实际成本12w 也就是多花了(12w-10w(目标费用)2w 甲方罚乙方利润费用&…

Objective-C中的associated object释放时机问题

如果对象A持有对象B,B作为A的associated object,并且表面上B没有其他被强引用的地方,那么对象A被释放时,对象B一定会同时释放吗?大部分情况下是,但真有不是的时候。最近实现代码的时候不小心就碰到了这样的…

开放共赢,华为云WeLink生态联盟正式成立!

今日,华为在京发布了“更懂企业”的智能工作平台华为云WeLink,并携手合作伙伴成立华为云WeLink生态联盟。其中首批加入华为云WeLink生态联盟的伙伴主要包括(排名不分先后):金山办公、中软国际、致远互联、罗技、华为商…

指明方向与趋势!2019开发者技能报告出炉!!!

近日国外开发者平台 HankerRank 发布了 2019 年开发者技能调查报告,该报告根据对71,281开发者的调查得出。 2018 年最受欢迎的开发语言  经过调查,2018年的所有开发语言中,JavaScript是最受欢迎的语言,2017年最受欢…

webservie报文格式

<![CDATA[ 里面方内容]]>

阿里研究院入选中国企业智库系统影响力榜

2019年2月1日&#xff0c;上海社会科学院智库研究中心发布《2018年中国智库影响力评价与排名》。阿里研究院入围三项排名榜单&#xff0c;位居企业智库系统影响力榜单第2位&#xff0c;中国智库社会影响力榜单第13位&#xff0c;中国智库综合影响力排名榜单第42位。 阿里研究院…

如何给女朋友解释什么是3PC?

戳蓝字“CSDN云计算”关注我们哦&#xff01;一顿愉快的小火锅之后&#xff0c;悠哉悠哉的回家了&#xff0c;于是只能开始新一轮的家庭科普了。分布式一致性幸好在《漫话&#xff1a;如何给女朋友解释什么是2PC&#xff08;二阶段提交&#xff09;&#xff1f;》中介绍过关于2…

Tensorflow源码解析1 -- 内核架构和源码结构

1 主流深度学习框架对比 当今的软件开发基本都是分层化和模块化的&#xff0c;应用层开发会基于框架层。比如开发Linux Driver会基于Linux kernel&#xff0c;开发Android app会基于Android Framework。深度学习也不例外&#xff0c;框架层为上层模型开发提供了强大的多语言接…

如何在Kubernetes集群动态使用 NAS 持久卷

1. 介绍&#xff1a; 本文介绍的动态生成NAS存储卷的方案&#xff1a;在一个已有文件系统上&#xff0c;自动生成一个目录&#xff0c;这个目录定义为目标存储卷&#xff1b; 镜像地址&#xff1a;registry.cn-hangzhou.aliyuncs.com/acs/alicloud-nas-controller:v1.11.5.4-…

Java循环结构

顺序结构的程序语句只能被执行一次。如果您想要同样的操作执行多次,&#xff0c;就需要使用循环结构。 Java中有三种主要的循环结构&#xff1a; while 循环 do…while 循环 for 循环 一般能够确定循环次数的用for循环 不确定循环次数的用whlie 循环和 do while 循环 do while…

基于 Kubernetes 实践弹性的 CI/CD 系统

大家好,我是来自阿里云容器服务团队的华相。首先简单解释一下何为 Kubernetes 来帮助大家理解。Kuberentes 是一个生产可用的容器编排系统。Kuberentes 一方面在集群中把所有 Node 资源做一个资源池&#xff0c;然后它调度的单元是 Pod&#xff0c;当然 Pod 里面可以有多个容器…

35岁真的是一个坎吗?听完35岁码农的话,我放心了!

戳蓝字“CSDN云计算”关注我们哦&#xff01;之前看过一个有关程序员从刚入职到中年状态的一个视频&#xff0c;刚入职的程序员激情澎湃&#xff0c;一心想做自己想做的事情&#xff0c;并且想创业&#xff0c;就想拉拢身边的程序员同事一起创业&#xff0c;可是身边的同事就一…

开年巨制!千人千面回放技术让你“看到”Flutter用户侧问题

导语 发布app后&#xff0c;开发者最头疼的问题就是如何解决交付后的用户侧问题的还原和定位&#xff0c;是业界缺乏一整套系统的解决方案的空白领域&#xff0c;闲鱼技术团队结合自己业务痛点在flutter上提出一套全新的技术思路解决这个问题。 我们透过系统底层来捕获ui事件流…

Java中数组的打印

数组不能直接打印&#xff0c;打印出来是一个地址&#xff0c;所以编写以下方法用于打印一个数组 public static void printArr(int[] arr) { for (int i : arr) { System.out.print(i" "); } } println 打印一个对象&#xff0c;默认调用这个对象的toString方法