别再蒸馏3层BERT了!变矮又能变瘦的DynaBERT了解一下

一只小狐狸带你解锁炼丹术&NLP秘籍

神经网络模型除了部署在远程服务器之外,也会部署在手机、音响等智能硬件上。比如在自动驾驶的场景下,大部分模型都得放在车上的终端里,不然荒山野岭没有网的时候就尴尬了。对于BERT这类大模型来说,也有部署在终端的需求,但考虑到设备的运算速度和内存大小,是没法部署完整版的,必须对模型进行瘦身压缩。

说到模型压缩,常用的方法有以下几种:

  1. 量化:用FP16或者INT8代替模型参数,一是占用了更少内存,二是接近成倍地提升了计算速度。目前FP16已经很常用了,INT8由于涉及到更多的精度损失还没普及。

  2. 低轶近似/权重共享:低轶近似是用两个更小的矩阵相乘代替一个大矩阵,权重共享是12层transformer共享相同参数。这两种方法都在ALBERT中应用了,对速度基本没有提升,主要是减少了内存占用。但通过ALBRET方式预训练出来的Transformer理论上比BERT中的层更通用,可以直接拿来初始化浅层transformer模型,相当于提升了速度。

  3. 剪枝:通过去掉模型的一部分减少运算。最细粒度为权重剪枝,即将某个连接权重置为0,得到稀疏矩阵;其次为神经元剪枝,去掉矩阵中的一个vector;模型层面则为结构性剪枝,可以是去掉attention、FFN或整个层,典型的工作是LayerDrop[1]。这两种方法都是同时对速度和内存进行优化。

  4. 蒸馏:训练时让小模型学习大模型的泛化能力,预测时只是用小模型。比较有名的工作是DistillBERT[2]和TinyBERT[3]

实际工作中,减少BERT层数+蒸馏是一种常见且有效的提速做法。但由于不同任务对速度的要求不一样,可能任务A可以用6层的BERT,任务B就只能用3层的,因此每次都要花费不少时间对小模型进行调参蒸馏。

有没有办法一次获得多个尺寸的小模型呢?

今天rumor就给大家介绍一篇论文《DynaBERT: Dynamic BERT with Adaptive Width and Depth[4]。论文中作者提出了新的训练算法,同时对不同尺寸的子网络进行训练,通过该方法训练后可以在推理阶段直接对模型裁剪。依靠新的训练算法,本文在效果上超越了众多压缩模型,比如DistillBERT、TinyBERT以及LayerDrop后的模型

Arxiv访问慢的小伙伴也可以在订阅号后台回复关键词【0521】下载论文PDF。

原理

论文对于BERT的压缩流程是这样的:

  • 训练时,对宽度和深度进行裁剪,训练不同的子网络

  • 推理时,根据速度需要直接裁剪,用裁剪后的子网络进行预测

想法其实很简单,但如何能保证更好的效果呢?这就要看炼丹功力了 (..•˘_˘•..),请听我下面道来~

整体的训练分为两个阶段,先进行宽度自适应训练,再进行宽度+深度自适应训练。

宽度自适应 Adaptive Width

宽度自适应的训练流程是:

  1. 得到适合裁剪的teacher模型,并用它初始化student模型

  2. 裁剪得到不同尺寸的子网络作为student模型,对teacher进行蒸馏

最重要的就是如何得到适合裁剪的teacher。先说一下宽度的定义和剪枝方法。Transformer中主要有Multi-head Self-attention(MHA)和Feed Forward Network(FFN)两个模块,为了简化,作者用注意力头的个数和intermediate层神经元的个数来定义MHA和FFN的宽度,并使用同一个缩放系数来剪枝,剪枝后注意力头减小到个,intermediate层神经元减少到个。

在MHA中,我们认为不同的head抽取到了不同的特征,因此每个head的作用和权重肯定也是不同的,intermediate中的神经元连接也是。如果直接按照粗暴裁剪的话,大概率会丢失重要的信息,因此作者想到了一种方法,对head和神经元进行排序,每次剪枝掉不重要的部分,并称这种方法为Netword Rewiring

对于重要程度的计算参考了论文[5],核心思想是计算去掉head之前和之后的loss变化,变化越大则越重要。

利用Rewiring机制,便可以对注意力头和神经元进行排序,得到第一步的teacher模型,如图:

要注意的是,虽然随着参数更新,注意力头和神经元的权重会变化,但teacher模型只初始化一次(在后文有验证增加频率并没带来太大提升)。之后,每个batch会训练四种student模型,如图:

蒸馏的最终loss来源于三方面:logits、embedding和每层的hidden state。

深度自适应 Adaptive Depth

训好了width-adaptive的模型之后,就可以训自适应深度的了。浅层BERT模型的优化其实比较成熟了,主要的技巧就是蒸馏。作者直接使用训好的作为teacher,蒸馏裁剪深度后的小版本BERT。

对于深度,系数,设层的深度为[1,12],作者根据去掉深度为d的层。之所以取是因为研究表明最后一层比较重要[6]

最后,为了避免灾难性遗忘,作者继续对宽度进行剪枝训练,第二阶段的训练方式如图:

实验

根据训练时宽度和深度的裁剪系数,作者最终可得到12个大小不同的BERT模型,在GLUE上的效果如下:

可以看到,剪枝的BERT效果并没有太多下降,并且在9个任务中都超越了BERT-base。同时,这种灵活的训练方式也给BERT本身的效果带来了提升,在与BERT和RoBERTa的对比中都更胜一筹:

另外,作者还和DistillBERT、TinyBERT、LayerDrop进行了实验对比,DynaBERT均获得了更好的效果。

在消融实验中,作者发现在加了rewiring机制后准确率平均提升了2个点之多:

结论

本篇论文的创新点主要在于Adaptive width的训练方式,考虑到后续的裁剪,作者对head和neuron进行了排序,并利用蒸馏让子网络学习大网络的知识。

总体来说还是有些点可以挖的,比如作者为什么选择先对宽度进行自适应,再宽度+深度自适应?这样的好处可能是在第二阶段的蒸馏中学习到宽度自适应过的子网络知识。但直接进行同时训练不可以吗?还是希望作者再验证一下不同顺序的差距。

为了简化,作者在宽度上所做的压缩比较简单,之后可以继续尝试压缩hidden dim。另外,ALBERT相比原始BERT其实更适合浅层Transformer,也可以作为之后的尝试方向。

Arxiv访问慢的小伙伴也可以在订阅号后台回复关键词【0521】下载论文PDF。

参考文献

[1]LayerDrop: https://arxiv.org/abs/1909.11556

[2]DistillBERT: https://arxiv.org/abs/1910.01108
[3]TinyBERT: https://arxiv.org/abs/1909.10351
[4]DynaBERT: https://www.researchgate.net/publication/340523407_DynaBERT_Dynamic_BERT_with_Adaptive_Width_and_Depth
[5]Analyzing multi-head self-attention: https://arxiv.org/abs/1905.09418
[6]Minilm: https://arxiv.org/abs/2002.10957

  • 献给新一代人工智能后浪——《后丹》

  • 搜索中的 Query 理解及应用

  • ICLR认知科学@AI workshop一览

  • All in Linux:一个算法工程师的IDE断奶之路

  • 卖萌屋算法岗面试手册上线!通往面试自由之路

夕小瑶的卖萌屋

_

关注&星标小夕,带你解锁AI秘籍

订阅号主页下方「撩一下」有惊喜哦

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/480499.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

LS-GAN:把GAN建立在Lipschitz密度上

最近很多关心深度学习最新进展,特别是生成对抗网络的朋友可能注意到了一种新的GAN-- Wasserstein GAN。其实在WGAN推出的同时,一种新的LS-GAN (Loss Sensitive GAN,损失敏感GAN)也发表在预印本 [1701.06264] Loss-Sensitive Generative Adver…

java程序员必看经典书单,以及各个阶段学习建议

最近,部分读者一直希望我给大家推荐java程序员必读书籍,以及java程序员每个阶段的学习建议。 今天,先给大家推荐1.0版本,后面再不断完善程序员必读书籍2.0版本。 希望,你早日成为牛逼的程序员。 程序员进阶之路 上图是…

数据结构--链表--单链表归并排序mergesort

思路: 1.将链表的中点找到,对其切分成2条 2.继续步骤1,切成4条,8条。。。,直至每段链表只有1个元素 3.归并操作,对两两链表进行合并排序,并返回回并后的链表的头结点,依次向上递归回去 C代码…

我们的实践:事理图谱,下一代知识图谱

原文链接:https://mp.weixin.qq.com/s/iLfXeVeWE5CCs_sM_NAOSw 一、人工智能与认知智能 当前人工智能时代下,机器与人类之间的博弈一直在进行着。如图1所示,从1956年达特茅斯会议的召开标志人工智能诞生到深度学习模型在若干人工智能领域大规…

领域应用 | 偷偷告诉你,那些二次元萌妹都有个叫知识图谱的爸爸

本文转载自公众号:AI 时间。《AI108将》是AI时间全新的AI行业人物专访栏目。艾伦麦席森图灵说:有时,那些人们对他们并不抱有期望的人,却能做到人们不敢期望的事情。Sometimes Its very people who no one imagines angthing of wh…

白话生成对抗网络 GAN,50 行代码玩转 GAN 模型!【附源码】

今天,带大家一起来了解一下如今非常火热的深度学习模型:生成对抗网络(Generate Adversarial Network,GAN)。GAN 非常有趣,我就以最直白的语言来讲解它,最后实现一个简单的 GAN 程序来帮助大家加…

java架构师进阶之独孤九剑(一)-算法思想与经典算法

“ 这是整个架构师连载系列,分为9大步骤,我们现在还在第一个步骤:程序设计和开发->数据结构与算法。 我们今天讲解重点讲解算法。 算法思想 1 贪心思想 顾名思义,贪心算法总是作出在当前看来最好的选择。也就是说贪心算法并…

数据结构--链表--单链表中环的检测,环的入口,环的长度的计算

就如数字6一样的单链表结构,如何检测是否有6下部的○呢,并且求交叉点位置 思路 使用快慢指针(一个一次走2步,一个走1步),若快慢指针第一次相遇,则有环 慢指针路程 sabs absab 快指针路程 2sa…

ACL 2010-2020研究趋势总结

一只小狐狸带你解锁 炼丹术&NLP 秘籍作者:哈工大SCIR 车万翔教授导读2020年5月23日,有幸受邀在中国中文信息学会青年工作委员会主办的AIS(ACL-IJCAI-SIGIR)2020顶会论文预讲会上介绍了ACL会议近年来的研究趋势,特整…

架构师进阶之独孤九剑:设计模式详解

我们继续架构师进阶之独孤九剑进阶,目前我们仍然在第一阶段:程序设计和开发环节。 “ 设计模式不仅仅只是一种规范,更多的是一种设计思路和经验总结,目的只有一个:提高你高质量编码的能力。以下主要分为三个环节&…

知识表示发展史:从一阶谓词逻辑到知识图谱再到事理图谱

研究证实,人类从一出生即开始累积庞大且复杂的数据库,包括各种文字、数字、符码、味道、食物、线条、颜色、公式、声音等,大脑惊人的储存能力使我们累积了海量的资料,这些资料构成了人类的认知知识基础。实验表明,将数…

领域应用 | 基于知识图谱的警用安保机器人大数据分析技术研究

本文转载自公众号:警察技术杂志。 郝久月 樊志英 汪宁 王欣 摘 要:构建大数据支撑下的智能应用是公安信息化发展的趋势,警用安保机器人大数据分析平台的核心功能包括机器人智能人机交互和前…

数据挖掘学习指南!!

入门数据挖掘,必须理论结合实践。本文梳理了数据挖掘知识体系,帮助大家了解和提升在实际场景中的数据分析、特征工程、建模调参和模型融合等技能。完整项目实践(共100多页)后台回复 数据挖掘电子版 获取数据分析探索性数据分析&am…

数据结构--栈--顺序栈/链式栈(附: 字符括号合法配对检测)

栈结构:先进后出,后进先出,像叠盘子一样,先叠的后用。 代码github地址 https://github.com/hitskyer/course/tree/master/dataAlgorithm/chenmingming/stack 1.顺序栈(数组存储,需给定数组大小&#xff0c…

银行计考试-计算机考点2-计算机系统组成与基本工作原理

版权声明&#xff1a;本文为博主原创文章&#xff0c;未经博主允许不得转载。 https://blog.csdn.net/sinat_33363493/article/details/53647129 </div><link rel"stylesheet" href"https://csdnimg.cn/release/pho…

我们的实践: 400万全行业动态事理图谱Demo

历史经验知识在未来预测的应用 华尔街的独角兽Kensho&#xff0c;是智能金融Fintech的一个不得不提的成功案例&#xff0c;这个由高盛领投的6280万美元投资&#xff0c;总融资高达7280万美元的公司自推出后便名声大噪。Warren是kensho是一个代表产品&#xff0c;用户能够以通俗…

蚂蚁花呗团队面试题:LinkedHashMap+SpringCloud+线程锁+分布式

一面 自我介绍 map怎么实现hashcode和equals,为什么重写equals必须重写hashcode 使用过concurrent包下的哪些类&#xff0c;使用场景等等。 concurrentHashMap怎么实现&#xff1f;concurrenthashmap在1.8和1.7里面有什么区别 CountDownLatch、LinkedHashMap、AQS实现原理 …

肖仰华 | SIGIR 2018、WWW2018 知识图谱研究综述

本文转载自公众号&#xff1a;知识工场。全国知识图谱与语义计算大会&#xff08;CCKS: China Conference on Knowledge Graph and Semantic Computing&#xff09;由中国中文信息学会语言与知识计算专委会定期举办的全国年度学术会议。CCKS源于国内两个主要的相关会议&#xf…

数据结构--栈--共享顺序栈

共享顺序栈&#xff1a;内部也是一个数组 将两个栈放在数组的两端&#xff0c;一个从数组首端开始压栈&#xff0c;一个从数组尾部开始压栈&#xff0c;等到两边栈顶在中间相遇时&#xff0c;栈满。 共享顺序栈在某些情况下可以节省空间。 头文件 sharingStack.h //共享顺序…

一个励志PM小哥哥的Java转型之路

先给大家看张我朋友圈截图&#xff1a; 这哥们本科学英语的&#xff0c;毕业后做了产品经理&#xff0c;去年 9 月份开始学 Java&#xff0c;6 个月的时间&#xff0c;拿到了快手的 Offer。如果你对 Java 也有兴趣&#xff0c;不妨听完这个故事。你是不是也和他当时的处境…