pytorch微调bert_小版BERT也能出奇迹:最火的预训练语言库探索小巧之路

选自Medium

作者:Victor Sanh

机器之心编译

参与:魔王

过去一段时间,大模型层出不穷。在大家纷纷感叹「大力出奇迹」的时候,作为调用预训练语言模型最流行的库,HuggingFace 尝试用更少的参数量、更少的训练资源实现同等的性能,于是知识蒸馏版 BERT——DistilBERT 应运而生!

近日,HuggingFace 发布了 NLP transformer 模型——DistilBERT,该模型与 BERT 的架构类似,不过它仅使用了 6600 万参数(区别于 BERT_base 的 1.1 亿参数),却在 GLUE 基准上实现了后者 95% 的性能。

在过去 18 个月中,基于大规模语言模型的迁移学习显著提升了自然语言处理任务的 SOTA 成绩。这些预训练语言模型通常基于 Vaswani 等人提出的 Transformer 架构,这些模型的规模越来越大,训练数据集的规模也越来越大。最近英伟达提出的预训练语言模型拥有 83 亿参数:是 BERT-large 参数量的 24 倍、GPT-2 参数量的 5 倍。而 Facebook AI 最近提出的 RoBERTa 模型在 160GB 文本上训练得到。

6d998ac06812413b9a80f1647b79346b

社区中的一些人质疑训练越来越大 Transformer 的必要性,尤其是考虑到训练的资金成本和环境成本时。该图展示了部分近期大模型及其参数量。

Hugging Face 直接体会到这些模型的流行度,因为其预训练语言库(包含这些模型中的大部分)在近几个月的安装量超过 40 万次。

  • NLP 库地址:https://github.com/huggingface/pytorch-transformers

然而,尽管这些模型被更大的 NLP 社区接受,一个重要且有挑战性的问题出现了。如何将这些庞然大物投入到生产中?如何在低延迟约束下使用这些大模型?我们需要用(昂贵的)GPU 服务器执行大规模服务吗?

755d59fcecae4032bc3c833dc573cb66

为了构建更尊重隐私的系统,Hugging Face 注意到在边缘设备上运行机器学习系统的需求在不断增长,而不是调用云 API,将隐私数据发往服务器。在智能手机等设备上运行的模型需要是轻量级、响应快和能源利用率高的!

最后但同样重要的一点,Hugging Face 越来越担忧这些大模型所需的指数级计算成本增长。

1214cee65166460cac3b6b46138d7d1f

有很多技术可以解决前述问题。最常见的工具是量化(使用更小精度逼近全精度模型)和权重剪枝(移除网络中的部分连接)。想了解更多,可以查看这篇关于 BERT 量化的精彩博客:https://blog.rasa.com/compressing-bert-for-faster-prediction-2/。

Hugging Face 的研究者决定把重点放在知识蒸馏(distillation)上。蒸馏即,将较大模型(教师模型)压缩成较小模型(学生模型)的方法。

知识蒸馏:迁移泛化能力

知识蒸馏是一种模型压缩方法,又叫师生学习(teacher-student learning)。它训练一个小模型,使之复制大模型(或模型集成)的行为。知识蒸馏由 Bucila 等人提出,几年后被 Hinton 等人推广(参见论文《Distilling the Knowledge in a Neural Network》)。Hugging Face 研究者使用的是 Hinton 等人的方法。

在监督学习中,分类模型通常用于预测类别,它利用对数似然信号最大化类别概率。在很多案例中,高性能模型预测的输出分布中,正确的类别具备高概率,而其他类别的概率则接近于零。

df3b357412674f4f9fa39597a988eeae

例如,desk chair(办公椅)可能会被误分类为 armchair(扶手椅),但通常不会被误认为是 mushroom(蘑菇)。这种不确定性被称为「暗知识」。

理解蒸馏的另一种方式是,它阻止模型对预测结果过于自信(类似于标签平滑)。

以下是一个实践示例。在语言建模过程中,我们通过观察词汇分布,可以轻松发现这种不确定性。下图展示了 Bert 对电影《卡萨布兰卡》中某句著名台词的 top 20 补全结果:

92d14a27689c4af08a1525e7f41def9c

BERT_base 对被遮蔽 token 的 top 20 补全结果。该语言模型确定了两个概率较高的 token(day 和 life)。

如何复制暗知识?

在师生训练中,我们训练学生网络来模拟教师网络的完整输出分布(它的知识)。

我们使学生网络和教师网络具备同样的输出分布,从而使学生网络实现同样的泛化。

我们不对硬目标类别(正确类别的 one-hot 编码)使用交叉熵来进行训练,而是对软目标类别(教师网络的概率)执行交叉熵,从而将教师网络的知识迁移到学生网络。这样训练损失函数就变成了:

72951db9a9a24f399691011697dceae8

其中 t 表示教师网络的 logit 值,s 表示学生网络的 logit 值。该损失函数具备更丰富的训练信号,因为软目标类别比单个硬目标类别提供更多约束。

为了进一步揭示类别分布的多样性,Hinton 等人提出了 softmax-temperature:

283bb7a972704a17a1512cf50d5fb6c7

其中T 表示温度参数,当 T → 0 时,分布接近于 one-hot 目标向量,当 T →+∞ 时,则得到均匀分布。在训练过程中对教师网络和学生网络使用同样的温度参数,进而为每一个训练样本提供更多信号。在推断时,T 被设置为 1,恢复标准的 Softmax 函数。

PyTorch 动手实践:压缩 BERT

Hugging Face 研究者想利用知识蒸馏压缩大型语言模型。对于蒸馏,研究者使用 KL 散度作为损失函数,因为最优化过程与交叉熵是等价的:

c40079e1a260402db18fae4d7efe2983

在计算 q(学生网络的分布)的梯度时,得到了同样的梯度。这允许研究者利用 PyTorch 实现执行更快速的计算:

ff0d5b3efb974077858b814958c9bb50

使用教师网络 BERT 的监督信号,研究者训练得到较小的语言模型——DistilBERT。(研究者使用的是 Bert 的英语 bert-base-uncased 版本)。

按照 Hinton 等人的方法,训练损失是蒸馏损失和遮蔽语言建模损失的线性组合。学生模型是 BERT 的较小版本,研究者移除了 token 类型的嵌入和 pooler(用于下一句分类任务),保留了 BERT 的其余架构,不过网络层数只有原版的 1/2。

14ed01b0beea4a91ad4f912ca0de9e94

备注 1:为什么不减少隐藏层大小呢?将它从 768 减少到 512 即可将参数总量减少约 1/2。但是,在现代框架中,大部分运算是经过高度优化的,张量最后一维(隐藏维度)的变化对 Transformer 架构中使用的大部分运算影响较小。在研究者的实验中,相比隐藏层大小,层数才是推断阶段的决定性因素。

研究者的早期实验表明,在该案例中,交叉熵损失会带来更好的性能。因此,他们假设在语言建模设置中,输出空间(词汇)要比下游任务输出空间的维度大得多。而在 L2 损失中,logit 可能会相互抵消。

训练子网络不仅仅关乎架构,它还需要找出子网络收敛的合适初始化(例如彩票假设论文《The Lottery Ticket Hypothesis: Finding Sparse, Trainable Neural Networks》)。因此,研究者基于教师网络 Bert 对学生网络 DistilBERT 进行初始化,将层数减半,对学生网络使用与教师网络一样的隐藏层大小。

研究者还使用 RoBERTa 论文中的一些训练 trick,这篇论文证明 Bert 的训练方式对最终性能有着重要影响。遵循 RoBERTa 的训练方式,研究者利用梯度累积以非常大的批次(每个批次多达 4000 个样本)训练 DistilBERT,训练使用了动态遮挡(dynamic masking),并移除了下一句预测目标。

该训练设置主动接受资源方面的限制。研究者在 8 块 16GB V100 GPU 上训练 DistilBERT,训练时长接近三天半,训练数据为 Toronto Book Corpus 和英文维基百科(与原版 BERT 的训练数据相同)。

DistilBert 的部分代码来自于 Facebook XLM 的代码,部分代码来自 Google AI BERT 的 Hugging Face PyTorch 版本。这些代码可在 Hugging Face 的 NLP 库中获取,该库还包含多个 DistilBert 训练版本和微调版本,及其复现代码。

模型性能:DistilBERT 测试

研究者在 GLUE 基准的开发集上对比了 DistilBERT 和两个基线模型的性能,基线模型分别是 BERT base(DistilBERT 的教师模型)和来自纽约大学的强大非 transformer 基线模型:ELMo + BiLSTMs。研究者使用纽约大学发布的 ELMo 基线 jiant 库和 BERT 基线模型的 PyTorch-Transformers 版本。

如下表所示,DistilBERT 在参数量分别是基线模型的 1/2 和 1/3 的情况下,性能可与后者媲美。在 9 项任务中,DistilBERT 的性能通常等同于或优于 ELMo 基线(在 QNLI 任务上的准确率超出后者 14 个百分点)。令人惊讶的是,DistilBERT 的性能堪比 BERT:在参数量比 BERT 少 40% 的情况下,准确率达到了后者的 95%。

74fcc28e8605470cb1b6a5c4dfc0cc3d

在 GLUE 基准开发集上的模型对比结果。ELMo 的性能结果来自原论文,BERT 和 DistilBERT 的性能结果是使用不同种子进行 5 次运行后的中位数

至于推断性能,DistilBERT 的推断速度比 BERT 快 60%,规模也比后者小;DistilBERT 的推断速度比 ELMo+BiLSTM 快 120%,规模也比后者小很多。

为了进一步调查 DistilBERT 的加速/规模权衡(speed-up/size trade-off),研究者对比了每个模型的参数量和在 STS-B 开发集上使用一块 CPU、批大小为 1 的情况下完成一个完整 pass 的推断时间,如上表所示。

下游任务:蒸馏和迁移学习

研究者进一步研究了在高效推断约束下,DistilBERT 在下游任务上的应用。研究者对紧凑的预训练语言模型 DistilBERT 进行微调,用于分类任务。这是结合蒸馏预训练和迁移学习的绝妙方式!

研究者使用 IMDB 评论情感分类数据集,该数据集包含 5 万条英文评论(被标注为积极或消极):其中 2.5 万条作为训练数据,另外一半作为测试数据(均类别均衡)。研究者使用一块 12GB K80 GPU 进行训练。

首先,在数据集上训练 bert-base-uncased。该模型达到了 99.98% 的准确率(3 次运行的平均值),几乎完美!

然后使用同样的超参数训练 DistilBERT。该模型达到了 99.53% 的准确率(3 次运行的平均值),在延迟降低 60%、规模减少 40% 的情况下,DistilBERT 的性能仅比原版 BERT 低 0.5%!

少即是多:小模型也能出奇迹

Hugging Face 对 DistilBERT 的潜力非常看好。DistilBERT 只是个开始,它也提出了许多问题:使用知识蒸馏技术,我们可以把大模型压缩到什么程度?这些技术可用于进一步探索和洞察大模型中存储的知识吗?在压缩过程中损失了语言学/语义学的哪些方面?……

目前,HuggingFace 的这项研究已经开源,并进行知识共享。他们认为这是每个人参与 NLP 进展,并收获最新进展果实的最快和最公平的路径。

  • GitHub 地址:https://github.com/huggingface
  • Medium 页面:http://www.medium.com/huggingface

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/502723.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

if test 多条件_秒懂Python编程中的if __name__ == #39;main#39; 作用和原理

在大多数编排得好一点的脚本或者程序里面都有这段if __name__ main:1 这段代码的功能一个python的文件有两种使用的方法:第一是直接作为脚本执行,第二是import到其他的python脚本中被调用(模块重用)执行。因此if __name__ main: 的作用就是控制这两种情…

python背景颜色怎么随机_Python中的随机颜色

我同意TigerhawkT3(1)你教授对pick_color()的实现是垃圾。但我不认为random.choice(),或者你教授滥用random.shuffle()的方式是最好的选择。两者的问题是,在连续调用时可以获得相同的颜色,这是在正方形内绘制正方形时不…

python 解决手机拍的书籍图片发灰的问题

老师给发的作业经常是手机拍的,而不是扫描,背景发灰,如果二次打印就没有看了,象这样: 如果使用photoshop 处理,有些地方还是扣不干净,不如python 做的好,处理后如下: 具体…

2016年cypher资源_2021-2027年中国鱿鱼行业市场供需规模及未来前景分析报告

报告类型:产业研究报告格式:电子版、纸介版、电子纸介出品单位:智研咨询官网链接:中国产业信息网 - 产业前景投资趋势门户-智研旗下产业信息咨询平台​www.chyxx.com报告链接:2021-2027年中国鱿鱼行业市场供需规模及未…

地面控制点的定义与作用_什么是地面塌陷

地面塌陷2020年1月13日,青海西宁市城中区一公交车站附近地面突然塌陷,一辆搭载乘客的公交车掉入坑中,致使9人遇难。2019年12月12日,厦门吕厝路口地铁1号线和2号线外的配套物业开发项目施工现场发生约500平方米地面塌陷&#xff0c…

animate动画案例_animate动画案例——小小购物狂

如今各平台小动画层出不穷,大部分这种二维动画都是animate或者flash做的,例如下面这种效果animate既可以将各种内容做成动画。既可以设计适合游戏、电视节目和 Web 的交互式动画。让卡通和横幅广告栩栩如生。也可以用来创作动画涂鸦和头像。并向电子学习…

男孩子不上学了学计算机要学历吗,十三岁男孩不上学,能学什么手艺?

十三岁男孩不上学,能学什么手艺?十三岁时的孩子,有些学校要求我们先上过义务教育再去学习,有些学校是允许十三岁就直接接受教育的,有些学校是对十三岁还在上半学的学生进行补习一下的。那么,十三岁男孩不上学,可以学什么手艺?其实,有很多孩子对自己在学校学习时未能掌握的知识…

numpy 拼接_数据分析-numpy的拼接与交换

1.数组的拼接import numpy as npt1np.arange(24).reshape((4,6))t2np.arange(100,124).reshape((4,6))print(t1)print("*"*50)print(t2)print("*"*50)#竖直拼接t3np.vstack((t1,t2))print(t3)print("*"*50)#水平拼接t4np.hstack((t1,t2))print(t…

iptables 指定网卡_LINUX系统下的IPTABLES防火墙系统讲解(二)实战操作

iptables数据流方向iptables操作命令:#iptables --helpUsage: iptables -[AD] chain rule-specification [options]iptables -[RI] chain rulenum rule-specification [options]iptables -D chain rulenum [options]iptables -[LFZ] [chain] [options]iptables -[NX] chainipta…

java接口文档生成工具_接口文档生成

一、为什么要写接口文档?1.正规的团队合作或者是项目对接,接口文档是非常重要的,一般接口文档都是通过开发人员写的。一个工整的文档显得是非重要。2.项目开发过程中前后端工程师有一个统一的文件进行沟通交流开发,项目维护中或者…

联想计算机如何设置用户名和密码忘了,联想(Lenovo)路由器无线wifi密码忘记了怎么办啊?...

联想(Lenovo)路由器无线wifi密码忘记了怎么办?忘记wifi密码这个问题,很多用户都会遇到。因为手机、笔记本、平板电脑在首次连接wifi信号后,会自动保存该wifi信号密码,以后会自动进行连接,无需用户手动输入wifi密码&…

mysql binlog查看_MySQL--17 配置binlog-server 及中间件

配置binlog-server修改mha配置文件[rootmysql-db03 ~]# vim /etc/mha/app1.cnf[binlog1]no_master1hostname10.0.0.53master_binlog_dir/data/mysql/binlog/备份binlog#创建备份binlog目录[rootmysql-db03 ~]# mkdir -p /data/mysql/binlog/#进入该目录[rootmysql-db03 ~]# cd …

桥梁在线计算机监测系统,桥梁在线监测系统

原标题:桥梁在线监测系统监测背景我国是个桥梁大国,据最新数据统计,我国超过100万座公路桥梁(不含市政桥梁)。影响桥梁的因素居多,人为因素、车辆长期超载、材料自身退化等,缺乏及时到位的管理养护导致结构各部分在远没…

idea黑色好还是白色好_白色牛仔裤,好看又好搭

白色是属于夏天的颜色,也是最纯洁、最惹人注目的颜色。无论时尚如何轮回迭代,白色给我们的代名词永远是优雅、高贵、纯洁、干净、高贵、永恒等这些美好的词汇。白色是时光、流动、轻巧的颜色,它代表着东方的安静和中庸,也是留白含…

c 子类对象 访问父类对象受保护成员_面向对象编程(OOP)

这节讲一下,什么是面向对象(Object Oriented Programming)。说面向对象之前,我们不得不提的是面向过程(Process Oriented Programming),C语言就是面向过程的语言,这两者的区别在哪呢?我们可以设想一个情景——厨房做菜…

linux数据泵导入command not found_MySQL:数据库结构优化、高可用架构设计、数据库索引优化...

一、SQL查询优化(重要)1.1 获取有性能问题SQL的三种方式通过用户反馈获取存在性能问题的SQL;通过慢查日志获取存在性能问题的SQL;实时获取存在性能问题的SQL;1.1.2 慢查日志分析工具相关配置参数:slow_query_log # 启动停止记录慢…

武汉船舶职业技术学院计算机分数线,武汉船舶职业技术学院录取分数线2021是多少分(附历年录取分数线)...

武汉船舶职业技术学院录取分数线2020是多少分,各专业录取分数线是多少,是每个填报武汉船舶职业技术学院的考生最关注的问题,随着各省高考录取批次相继公布,考生也开始关心是否被录取,本站小编整理相关信息供参考&#…

linux std::queue 怎么释放内存_电脑卡慢怎么办?一个小工具帮你轻松释放内存,瞬间提升电脑性能...

有一种电脑叫“卡巴死机”大家有没有发现,如今的电子产品越来越不耐用了。无论是电脑,还是手机,超过一年以上,就得考虑更换了。1G变2G,2G升4G,按理说电脑应该会更快更好,实际却是相反&#xff0…

2015计算机二级公共基础知识,2015年计算机二级公共基础知识考点测试题(8)

排序技术1[单选题]对长度n的线性表排序,在最坏情况下,比较次数不是n(n一1)/2的排序方法是(  )。参考答案:D参考解析:排序技术有:①交换类排序法(冒泡排序法、快速排序法);②插入类排序法(简单插入排序、希尔排序);③…

2020年周数和日期对应表_2020年雅思考试报名截止日期、准考证打印日期和成绩单寄送日期...

2020年雅思考试报名截止日期、准考证打印日期和成绩单寄送日期考试日期类别口试预定 开始日期*报名截止日期准考证 打印日期成绩单 寄送日期*04/01/2020A14/12/201916/12/201925/12/201917/01/202011/01/2020A+G21/12/201923/12/201901/01/202031/01/202016/01/2020…