训练softmax分类器实例_知识蒸馏:如何用一个神经网络训练另一个神经网络

公众号关注 “ML_NLP”设为 “星标”,重磅干货,第一时间送达!

e86023434cdf6f084708d2c5648e5254.png


转载自:AI公园

作者:Tivadar Danka

编译:ronghuaiyang

导读

知识蒸馏的简单介绍,让大家了解知识蒸馏背后的直觉。

如果你曾经用神经网络来解决一个复杂的问题,你就会知道它们的尺寸可能非常巨大,包含数百万个参数。例如著名的BERT模型约有1亿1千万参数。

为了说明这一点,参见下图中的NLP中最常见架构的参数数量。

818c0d1a508db81ef56c29a85844502d.png

各种模型结构的参数数量

在Kaggle竞赛中,胜出的模型通常是由几个模型组成的集合。尽管它们在精确度上可以大大超过简单模型,但其巨大的计算成本使它们在实际应用中完全无法使用。

有没有什么方法可以在不扩展硬件的情况下利用这些强大但庞大的模型来训练最先进的模型?

目前,有三种方法可以压缩神经网络,同时保持预测性能:

  • 权值裁剪
  • 量化
  • 知识蒸馏

在这篇文章中,我的目标是向你介绍“知识蒸馏”的基本原理,这是一个令人难以置信的令人兴奋的想法,它的基础是训练一个较小的网络来逼近大的网络。

什么是知识蒸馏?

让我们想象一个非常复杂的任务,比如对数千个类进行图像分类。通常,你不能指望ResNet50能达到99%的准确度。所以,你建立一个模型集合,平衡每个模型的缺陷。现在你有了一个巨大的模型,尽管它的性能非常出色,但无法将其部署到生产环境中,并在合理的时间内获得预测。

然而,该模型可以很好地概括未见的数据,因此可以放心地相信它的预测。(我知道,情况可能不是这样的,但我们现在就开始进行思维实验吧。)

如果我们使用来自大而笨重的模型的预测来训练一个更小的,所谓的“学生”模型来逼近大模型会怎么样?

这本质上就是知识的蒸馏,这是由Geoffrey Hinton、Oriol Vinyals和Jeff Dean在论文Distilling the Knowledge in a Neural Network中介绍的。

大致说来,过程如下。

  1. 训练一个能够性能很好泛化也很好的大模型。这被称为教师模型
  2. 利用你所拥有的所有数据,计算出教师模型的预测。带有这些预测的全部数据集被称为知识,预测本身通常被称为soft targets。这是知识蒸馏步骤。
  3. 利用先前获得的知识来训练较小的网络,称为学生模型

为了使过程可视化,见下图。

146acc4652aa6f1de970756045a1dce0.png

知识蒸馏

让我们关注一下细节。知识是如何获得的?

在分类器模型中,类的概率由softmax层给出,将logits转换为概率:

4270b21a175a9b3695bd895a961f1a25.png

其中:

8b44e02cf0d14c5d6ed25ee7a860ca7d.png

是最后一层生成的logits。替换一下,得到一个稍有修改的版本:

2d279c7030e91b2ae771b1a856c35210.png

其中,T是一个超参数,称为温度。这些值叫做soft targets。

如果T变大,类别概率会变软,也就是说会相互之间更加接近,极端情况下,T趋向于无穷大。

d7f0e68675435dc2f60aefecc5555213.png

如果T = 1,就是原来的softmax函数。出于我们的目的,T被设置为大于1,因此叫做蒸馏

Hinton, Vinyals和Dean证明了一个经过蒸馏的模型可以像由10个大型模型的集成一样出色。

222ce92557afd1908cd9bb467163f8d6.png

Geoffrey Hinton, Oriol Vinyals和Jeff Dean的论文Distilling the Knowledge in a Neural Network中对一个语音识别问题的知识蒸馏的结果

为什么不重头训练一个小网络?

你可能会问,为什么不从一开始就训练一个更小的网络呢?这不是更容易吗?当然,但这并不一定有效。

实验结果表明,参数越多,泛化效果越好,收敛速度越快。例如,Sanjeev Arora, Nadav Cohen和Elad Hazan在他们的论文“On the Optimization of Deep Networks: Implicit Acceleration by Overparameterization”中对此进行了研究。

bbbd9a8766040a2fa17ca594c5746f5f.png

左:单层网络与4层和8层的线性网络。右:使用TensorFlow教程中的MNIST分类的参数化和基线模型。

对于复杂的问题,简单的模型很难在给定的训练数据上很好地泛化。然而,我们拥有的远不止训练数据:教师模型对所有可用数据的预测。

这对我们有两方面的好处。

  • 首先,教师模型的知识可以教学生模型如何通过训练数据集之外的可用预测进行泛化。回想一下,我们使用教师模型对所有可用数据的预测来训练学生模型,而不是原始的训练数据集。
  • 其次,soft targets提供了比类标签更有用的信息:它表明两个类是否彼此相似。例如,如果任务是分类狗的品种,像“柴犬和秋田犬非常相似”这样的信息对于模型泛化是非常有价值的。

55c359669da83735c8c17a904c520aec.png

左:秋田犬,右:柴犬

与迁移学习的区别

Hinton等人也提到,最早的尝试是复用训练好的集成模型中的一些层来迁移知识,从而压缩模型。

用Hinton等人的话来说,

“……我们倾向于用学习的参数值在训练过的模型中识别知识,这使得我们很难看到如何改变模型的形式而保持相同的知识。知识的一个更抽象的观点是,它是一个从输入向量到输出向量的学习好的映射,它将知识从任何特定的实例化中解放出来。—— Distilling the Knowledge in a Neural Network

因此,与转移学习相反,知识蒸馏不会直接使用学到的权重。

使用决策树

如果你想进一步压缩模型,你可以尝试使用更简单的模型,如决策树。尽管它们的表达能力不如神经网络,但它们的预测可以通过单独观察节点来解释。

这是由Nicholas Frosst和Geoffrey Hinton完成的,他们在他们的论文Distilling a Neural Network Into a Soft Decision Tree中对此进行了研究。

391c5e12955857f4f0d217048877b9cb.png

他们的研究表明,尽管更简单的神经网络的表现比他们的研究要好,但蒸馏确实起到了一点作用。在MNIST数据集上,经过蒸馏的决策树模型的测试准确率达到96.76%,较基线模型的94.34%有所提高。然而,一个简单的两层深卷积网络仍然达到了99.21%的准确率。因此,在性能和可解释性之间存在权衡。

Distilling BERT

到目前为止,我们只看到了理论结果,没有实际的例子。为了改变这种情况,让我们考虑近年来最流行和最有用的模型之一:BERT。

来自于谷歌的Jacob Devlin等人的论文BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding,很快被广泛应用于各种NLP任务,如文档检索或情绪分析。这是一个真正的突破,推动了几个领域的技术发展。

然而,有一个问题。BERT包含约1.1亿个参数,需要大量的时间来训练。作者报告说,训练需要4天,使用4个pods中的16个TPU芯片。训练成本将约为10000美元,不包括碳排放等环境成本。

Hugging Face成功地尝试减小BERT的尺寸和计算成本。他们使用知识蒸馏来训练DistilBERT,这是原始模型大小的60%,同时速度提高了60%,语言理解能力保持在97%。

597c8039a8448e23551049077cc1e474.png

DistilBERT的性能。

较小的架构需要更少的时间和计算资源:在8个16GB V100 gpu上花费90小时。如果你对更多的细节感兴趣,你可以阅读原始论文"DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter"或者文章的综述,写的很精彩,强烈推荐。

总结

知识蒸馏是压缩神经网络的三种主要方法之一,使其适合于性能较弱的硬件。

与其他两种强大的压缩方法权值剪枝和量化不同,知识蒸馏不直接对网络进行缩减。相反,它使用最初的模型来训练一个更小的模型,称为“学生模型”。由于教师模型甚至可以对未标记的数据提供预测,因此学生模型可以学习如何像教师那样进行泛化。在这里,我们看到了两个关键的结果:最初的论文,它介绍了这个想法,和一个后续的论文,展示了简单的模型,如决策树,也可以用作学生模型。

下载1:四件套

在机器学习算法与自然语言处理公众号后台回复“四件套”

即可获取学习TensorFlow,Pytorch,机器学习,深度学习四件套!

f3cdb96aeecce712ebd343bf6077acdd.png

下载2:仓库地址共享

在机器学习算法与自然语言处理公众号后台回复“代码”

即可获取195篇NAACL+295篇ACL2019有代码开源的论文。开源地址如下:https://github.com/yizhen20133868/NLP-Conferences-Code

重磅!机器学习算法与自然语言处理交流群已正式成立

群内有大量资源,欢迎大家进群学习!

额外赠送福利资源!邱锡鹏深度学习与神经网络,pytorch官方中文教程,利用Python进行数据分析,机器学习学习笔记,pandas官方文档中文版,effective java(中文版)等20项福利资源

643b477c4958f231e65b978633a320e9.png

获取方式:进入群后点开群公告即可领取下载链接

注意:请大家添加时修改备注为 [学校/公司 + 姓名 + 方向]

例如 —— 哈工大+张三+对话系统。

号主,微商请自觉绕道。谢谢!

6425dec160b7429cd640007196595d61.png

33303e8ec8afd81c4a68821f1e8c502b.png

推荐阅读:

开放域知识库问答研究回顾

使用PyTorch Lightning自动训练你的深度神经网络

PyTorch常用代码段合集

834150b01cf34ccbb11722e9b21ea2a4.png

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/522418.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

10个业界最流行的Kubernetes发行版

戳蓝字“CSDN云计算”关注我们哦!作者 | Serdar来源 | RancherLabs如果你需要大规模的容器编排,想必Kubernetes毋庸置疑是你的首要选择,这一由谷歌推出的开源容器编排系统近年来发展飞速,大受业界及广大用户好评。尽管如此&#x…

神经进化是深度学习的未来

摘要: 本文主要讲了神经进化是深度学习的未来,以及如何用进化计算方法(EC)优化深度学习(DL)。过去几年时间里,我们有一个完整的团队致力于人工智能研究和实验。该团队专注于开发新的进化计算方法…

深度学习的关键术语

摘要: 本文着重介绍了深度学习的一些关键术语,其中包括生物神经元,多层感知器(MLP),前馈神经网络和递归神经网络。对于初学者来说,掌握它们可以防止在学习请教时的尴尬~深度学习已经成为编程界的…

虚拟化精华问答 | 虚拟化技术分类

虚拟化是一种资源管理技术, 是将计算机的各种物理资源, 如服务器、网络、内存及存储等,予以抽象、转换后呈现出来,打破物理设备结构间的不可切割的障碍,使用户可以比原本的架构更好的方式来应用这些资源。这些资源的虚拟部分是不受现有资源的…

远程服务器 上传公钥,SecureCRT+Ubuntu SSH服务器的远程公钥登陆

有耐心地往下看,哥是实现了的,并且所有细节会给的相当的丰富哈。Ubuntu: Ubuntu 14.04 LTSopensshWindow10(64位):SecureCRT8.0看网上的列为同牛们说gitssh用,自己搭建git服务器,so嗨,所以行动起来,先给win…

理解卷积神经网络的利器:9篇重要的深度学习论文(上)

摘要: 为了更好地帮助你理解卷积神经网络,在这里,我总结了计算机视觉和卷积神经网络领域内许多新的重要进步及有关论文。手把手教你理解卷积神经网络(一)手把手教你理解卷积神经网络(二)本文将介绍过去五年内发表的一些重要论文,并…

理解卷积神经网络的利器:9篇重要的深度学习论文(下)

摘要: 为了更好地帮助你理解卷积神经网络,在这里,我总结了计算机视觉和卷积神经网络领域内许多新的重要进步及有关论文。手把手教你理解卷积神经网络(一)手把手教你理解卷积神经网络(二)继“理解卷积神经网络的利器:9篇重要的深度…

工作流实战篇_01_flowable 流程Demo案例

由于群里有些朋友对这个flowable还不是 很熟悉,在群里的小伙伴的建议下,师傅(小学生05101)制作一个开源的项目源码,一共大家学习和交流,希望对有帮助,少走弯路 如果有不懂的问题可以入群:633168411 里面都是…

antd 进行ajax请求,react+dva+antd接口调用方式

一丶 安装通过 npm 安装 dva-cli 并确保版本是0.8.1或以上。$ npm install dva-cli -g$ dva -v0.8.1二丶创建新应用安装完dva-cli之后,就可以在命令行里访问到dva命令(不能访问?)。现在,你可以通过dva new创建新应用。$ dva new dva-quicksta…

基于MaxCompute的拉链表设计

摘要: 简单的拉链表设计 背景信息: 在数据仓库的数据模型设计过程中,经常会遇到这样的需求: 数据量比较大; 表中的部分字段会被update,如用户的地址,产品的描述信息,订单的状态、手机号码等等; 需要查看…

2019全球编程语言高薪排行榜登场;余承东正式宣布华为IFA2019 或发布麒麟990;OPPO、vivo和小米成立互传联盟…...

关注并标星星CSDN云计算极客头条:速递、最新、绝对有料。这里有企业新动、这里有业界要闻,打起十二分精神,紧跟fashion你可以的!每周三次,打卡即read更快、更全了解泛云圈精彩newsgo go go 全新的索尼PS5(图…

python文件输出log_Python同时向控制台和文件输出日志logging的方法

#-*- coding:utf-8 -*- import logging # 配置日志信息 logging.basicConfig(levellogging.DEBUG, format%(asctime)s %(name)-12s %(levelname)-8s %(message)s, datefmt%m-%d %H:%M, filenamemyapp.log, filemodew) # 定义一个Handler打印INFO及以上级别的日志到sys.stderr c…

MaxCompute使用常见问题总结

摘要: Maxcompute常见问题的总结,方便广大用户可以快速排查问题 计费相关 存储计费:按照存储在 MaxCompute 的数据的容量大小进行阶梯计费。 计算计费:MaxCompute 分按量后付费和按 CU 预付费两种计算计费方式。 按量后付费&#…

工作流实战_02_flowable 流程模板导入

由于群里有些朋友对这个flowable还不是很熟悉,在群里的小伙伴的建议下,师傅(小学生05101)制作一个开源的项目源码,一共大家学习和交流,希望对有帮助,少走弯路 如果有不懂的问题可以入群:633168411 里面都是…

华为服务器raid1装系统,服务器raid1系统安装

服务器raid1系统安装 内容精选换一换需要创建两台ECS,一台使用Linux系统安装SAP应用与DB2,另外一台用于安装SAP GUI和作为跳板机,两台ECS详情如下所示,下表均为示例,请根据实际情况购买Avago 3408iMR RAID卡不支持虚拟…

关于大数据你应该了解的五件事儿

摘要: 本文从基本概念、行业趋势、学习途径等几个方面介绍了大数据的相关内容,适合对大数据感兴趣的读者作为入门材料阅读。随着科技的发展,目前已经步入了大数据的时代,很多社交媒体和互联网公司也非常关注大数据这一行业。那么对…

当我们谈AI时,到底该谈什么?

报名倒计时仅剩1天,即刻扫描下方二维码,或者点击【阅读原文】免费报名,让我们不见不散。

前端电子表数字字体_爬虫:如何优雅应对字体反爬

目录THE BEGIN一 什么是字体反爬二 如何解密1.人工解密2.工具解密三 建立映射关系四 解密THE BEGIN网页数据爬取可以简单分为三步:抓取页面,分析页面,存储数据。其中第一二步最为头疼,因为每个站点各有特色,你要不断检…

ECS云资源可视化--资源概览

摘要: 随着越来越多的业务接入云计算,云上拥有的各类资源也越来越多,用户如何时时对其拥有的各类资源进行统计分析成为一个难题。ECS控制台针对这一问题,推出资源概览功能,目前支持实例和存储两种云资源的统计和分析功…

工作流实战_03_flowable 流程模板部署

由于群里有些朋友对这个flowable还不是 很熟悉,在群里的小伙伴的建议下,师傅(小学生05101)制作一个开源的项目源码,一共大家学习和交流,希望对有帮助,少走弯路 如果有不懂的问题可以入群:633168411 里面都是…