深度学习三大谜团:集成、知识蒸馏和自蒸馏

深度学习三大谜团:集成、知识蒸馏和自蒸馏

转自:https://mp.weixin.qq.com/s/DdgjJ-j6jHHleGtq8DlNSA

原文(英):https://www.microsoft.com/en-us/research/blog/three-mysteries-in-deep-learning-ensemble-knowledge-distillation-and-self-distillation/

集成(Ensemble,又称模型平均)是一种"古老"而强大的方法。只需要对同一个训练数据集上,几个独立训练的神经网络的输出,简单地求平均,便可以获得比原有模型更高的性能。甚至只要这些模型初始化条件不同,即使拥有相同的架构,集成方法依然能够将性能显著提升。

但是,为什么只是简单的"集成”,便能提升性能呢?

在这里插入图片描述

目前已有的理论解释大多只能适用于以下几种情况:

(1)boosting:模型之间的组合系数是训练出来的,而不能简单地取平均;

(2)Bootstrap aggregation:每个模型的训练数据集都不相同;

(3)每个模型的类型和体系架构都不相同;

(4)随机特征或决策树的集合。

但正如上面提到,在(1)模型系数只是简单的求平均;(2)训练数据集完全相同;(3)每个模型架构完全相同 下,集成的方法都能够做到性能提升。

在这里插入图片描述

论文链接:https://arxiv.org/pdf/2012.09816.pdf

来自微软研究院机器学习与优化组的高级研究员朱泽园博士,以及卡内基梅隆大学机器学习系助理教授李远志针对这一现象,在最新发表的论文**《在深度学习中理解集成,知识蒸馏和自蒸馏》**(Towards Understanding Ensemble, Knowledge Distillation, and Self-Distillation in Deep Learning)中,提出了一个理论问题:

在这里插入图片描述

当我们简单地对几个独立训练的神经网络求平均值时,“集成”是如何改善深度学习的测试性能的?尤其是当所有神经网络具有相同的体系结构,使用相同的标准训练算法(即具有相同学习率和样本正则化的随机梯度下降),在相同数据集上进行训练时,即使所有单个模型都已经进行了100%训练准确性?随后,将集合的这种优越性能“蒸馏”到相同架构的单个神经网络,为何能够保持性能基本不变?

两位作者分别从理论和实验的角度给出了分析结果:

原因在于数据集中“多视图”(Multi-view)数据的存在。

1、深度学习的三大谜团

谜团 1:集成

观察结果显示,使用不同随机种子的学习网络 F1,…F10F_1,\dots F_{10}F1,F10(尽管具有非常相似的测试性能)相关联的函数非常不同。在这种情况下,使用“集成”的技术,仅需要获取这些经过独立训练的网络输出的未加权平均值,就可以在许多深度学习应用中极大地提高测试时间的性能。(参见图1)这意味着各个函数F1,…F10F_1,\dots F_{10}F1,F10一定是不同的。但是,为什么集成可以大幅提升性能呢?

如果直接训练 (F1+⋯+F10)/10(F_1+\dots+F_{10})/10(F1++F10)/10,为什么性能提升就消失了?

在这里插入图片描述

图1:集成(Ensemble)提升了深度学习应用中的测试准确性,但是这种准确性的提高则无法通过直接训练模型的平均值来实现。

谜团2:知识蒸馏

虽然集成可以极大地提升测试时间性能,但在推理时间(即测试时间)方面它变慢了10倍:我们需要计算10个神经网络的输出,而不是1个。当我们在低能耗的移动环境中部署此类模型时,这是一个严重的问题。

为了解决这个问题,研究者提出了一种叫做知识蒸馏的开创性技术。知识蒸馏指的是训练另一个单独的模型来匹配集成的输出。在这里,一张猫的图像上的集成(也称为隐藏知识)输出可能看起来像“ 80%猫+ 10%狗+ 10%汽车”,而真正的训练标签是“ 100%猫”。(请参见下面的图2)

事实证明,经过这样训练的单个模型可以在很大程度上匹配10倍以上集成模型的测试时间性能。但是,这导致了更多的问题。

与匹配真实标签相比,为什么匹配集成模型的输出可以为我们提供更好的测试准确性?此外,我们在知识蒸馏后对模型进行集成学习可以进一步提高测试准确性吗?

在这里插入图片描述

图2:知识蒸馏和自蒸馏也能够提升深度学习的性能。

谜团3: 自蒸馏。

注意,知识蒸馏至少在直观上是有意义的:teacher model有84.8% 的测试准确率,那么student model 可以达到83.8% 。

但接下来这个现象就让人难以理解了,使用自蒸馏技术,也即老师的学生就是它自己:通过对具有相同架构的单个模型进行知识蒸馏,竟然可以提高测试准确率。想象一下: 训练出一个测试准确率为81.5% 的单个模型,结果使用相同结构的模型进行自蒸馏一下,测试准确率竟然提高到了83.5%,这不是很奇怪么?

2、神经网络集成与特征映射集成

大多数现有的集成理论只适用于单个模型之间存在根本性差异的情况(例如,决策树支持不同变量的子集),或者在不同的数据集上训练的情况(例如自举)。

但这些理论显然不能解释前面提到的现象。上面提到,集成的模型,其训练的框架是相同的,训练的数据也是相同的 —— 唯一的区别只是训练期间的随机性。

或许,与深度学习中的集成最为相近的理论,应该是“随机特征映射集成”(ensemble in random feature mappings)。这表现在两个方面:一方面,将多个随机特征的线性模型进行组合,可以提升测试时的性能,这很显然,因为它增加了特征的数量;另一方面,在特定的参数区域中,神经网络的权值可以非常接近它们的初始化(称为神经正切核区域,或 NTK区域) ,结果网络只在规定的特征映射上学习一个线性函数,这些特征映射完全由随机初始化决定。

将这两者结合起来,可以推测深度学习集成与随机特征映射集成,在原理上是一致的。

这就引出了另外一个问题:集成和知识蒸馏在深度学习上,与在随机特征映射(即NTK特征映射)上,是否会有相同的表现呢?

答案是否定的。

如下图3所示,该图比较了在深度学习/随机特征映射中的集成和知识蒸馏的性能。

在这里插入图片描述

图3: 集成在随机特征映射上有效(但是出于与深度学习完全不同的原因) ,而知识蒸馏在随机特征映射中不起作用。

可以看出,通过集成的方式,无论是在深度学习中,还是在随机特征映射中,都能够得到较好的性能;而在随机特征映射中,知识蒸馏的性能显然要比单个模型的性能还要差。

这就很明显地说明:集成和蒸馏,原理上并不相同。

具体来说:与在深度学习情况不同,在随机特征映射中,集成的优越性能不能蒸馏到单个模型上。

在图3中,神经正切核(NTK)模型的集成,在 CIFAR-10数据集上达到了70.54%的准确率,但经过知识蒸馏后,它下降到了66.01% ,甚至比单个模型的66.68% 的测试准确率还要低。

在深度学习中,直接训练模型的平均值 (𝐹1+⋯+𝐹10)/10(𝐹_1+\dots+𝐹_{10})/10(F1++F10)/10 与训练单个模型 𝐹𝑖𝐹_𝑖Fi 相比没有任何优势;而在随机特征映射中,训练平均值的效果优于单个模型及其集成。

在图3中,NTK 模型的集成的准确率为 70.54% ,而直接训练10个模型的平均值准确率为72.86%。

为什么会这样呢?

主要原因在于,神经网络是使用分层特征学习,尽管每个模型 𝐹𝑖𝐹_𝑖Fi 有不同的初始化,但在每一层它们都拥有相同的特征集合。因此,与单个模型相比,多个模型的平均模型,并没有增加其特征集合的大小。

在随机特征映射中,每个 𝐹𝑖𝐹_𝑖Fi 都使用了一组完全不同的规定特征。因此,无论是使用集成的方式,还是直接求平均的方式,都能够带来一些性能优势,但由于特征的稀缺性,在蒸馏后,性能必然会有一定下降。

3、集成与减少单个模型的方差

除了随机特征的集成外,还有人推测认为,由于神经网络的高度复杂性,每个单独的模型 𝐹𝑖𝐹_𝑖Fi 可能学习到一个函数 𝐹𝑖(𝑥)=𝑦+ξ𝑖𝐹_𝑖(𝑥)=𝑦+ξ_𝑖Fi(x)=y+ξiξ𝑖ξ_𝑖ξi 是某种噪声,这种噪声取决于训练过程中使用的随机性。

经典的统计学认为,如果所有的 ξ𝑖ξ_𝑖ξi 是大致独立的,那么求取他们的平均值能够大大减少噪音量。

因此,“集成能够减少方差”真的是集成能提高提高性能的原因吗?

证据表明,在深度学习的背景下,这种减少方差来提升性能的假设是值得怀疑的:

1. 集成并不能无限制地提高测试的准确性。

集成超过100个单个模型通常与集成10个单个模型基本没有差别。因此,100 ξ𝑖ξ_𝑖ξi 的平均值与10 ξ𝑖ξ_𝑖ξi 的平均值相比,方差不再减小,表明 ξ𝑖ξ_𝑖ξi 可能是不独立的,而且有可能存在偏差,因此均值不为零。在 ξ𝑖ξ_𝑖ξi 不独立的情况下,很难讨论求得这些 ξ𝑖ξ_𝑖ξi 的平均值能够减少多少偏差。

2. 即使理想情况下,我们认为ξ𝑖ξ_𝑖ξi是相互独立的,那么这就表明ξ𝑖ξ_𝑖ξi是有偏或异号的。

于是我们可以将 𝐹𝑖𝐹_𝑖Fi 写成:
𝐹𝑖(x)=𝑦+ξ+ξ𝑖𝐹_𝑖(x)=𝑦+ξ+ξ_𝑖 Fi(x)=y+ξ+ξi
ξξξ 是一个固定误差,ξ𝑖ξ_𝑖ξi 则指每个模型的独立误差。于是在集成之后,期望的网络输出将接近 y+ξy + ξy+ξ,这会有一个固定的偏差 ξξξ

在这种情况下,为什么知识蒸馏会有效呢?那么,为什么这个带有偏差 ξξξ (也被称为隐藏知识)的输出会优于原来的训练呢?

3. 集成学习并不总是能够提高准确性

在图4中,我们可以看到神经网络的集成学习并不总是能够提高测试的准确性,至少在输入类似高斯分布的情况下是这样。换句话说,在这些网络中,求平均值不会带来任何准确性的增益。

综上来看,我们需要更深入地理解深度学习中的集成,而不只是认为“集成能够减少方差”这么简单。

在这里插入图片描述

图4: 当输入类似高斯分布时,实验表明集成并不能提高测试的准确性。

4、多视图数据:深度学习中集成的一种新方法

图4表明,在非结构化随机输入的情况下,集成并不凑效。在我们最新的工作中,我们从数据中找到了集成之所以能够在深度学习中有效的原因所在。

通常,在一个数据集中(以视觉数据集为例),一个对象通常会有多个视角(muti-view)的数据。以“car”为例,一个汽车的数据集中,通常会有从各个角度拍摄的车辆的照片,通常我们仅需要通过车头灯、车轮或车窗等其中的一个特征,便可以对汽车进行分类了;即使在图片中有些特征因为拍摄角度的原因而缺失了,也没有太大的关系。例如从正前方拍摄的汽车,图像中便没有车轮,但这并不妨碍我们识别出“car"。

在这里插入图片描述

图5: 在CIFAR-10数据集上进行训练的 ResNet-34第23层的一些通道的可视化

这种现象在多数数据中都会存在,其中每类数据都具有多个视角的特征,这种结构被称为“多视图”(multi-view)。

在大多数数据中,几乎所有的视图特征都会显示出来;但在某些数据中,却可能缺少一些视图特征。

更广泛地说,这种“多视图”结构事实上,不仅在原始数据中存在,在中间层抽取的特征集合中也会存在。

在这种“多视图”结构下进行训练,网络会:

1)根据学习过程中的随机性,快速学习这些视图特征的一个子集;

2)会使用这些视图特征,记下剩余那些少量不能正确分类的数据。

第一点意味着,如果将不同网络进行集成,将能够把学习到的视图特征聚合起来,从而达到更高的测试精度

第二点意味着,单个模型不能学习所有的视图特性,不是因为它们没有足够的容量,而是因为没有足够的训练数据;大多数数据已经被现有的视图特征正确分类,因此在训练阶段,它们基本上不提供梯度。

5、知识蒸馏: 强制单个模型学习多个视图

基于上述视角,我们可以再来分析知识蒸馏是如何工作的。

在现实生活的场景中,一些汽车图像可能看起来“更像一只猫”:例如,一些汽车图像的前灯可能看起来像猫眼。当这种情况发生时,集成模型可以提供有意义的隐藏知识,例如**“汽车图像 X 有10% 像一只猫。”**

这里是个关键点。在训练单个神经网络模型时,如果没有学习“前灯”视图,剩下的视图或许仍然有可能根据别的视图将图像 x 标记为汽车,但它却无法匹配隐藏知识“图像 X 有10% 像猫”。

而在知识蒸馏的过程中,蒸馏模型会学习每一个可能的视图特征,来匹配集成的性能。需要注意的是,深度学习中知识蒸馏的关键是,作为一个神经网络,单个模型在特征学习中能够学习到集成的所有特征。这与实验中观察到的情况是一致的。(见图6)

在这里插入图片描述

图6: 知识蒸馏已经从集成中学习了大部分视图特性,因此在知识蒸馏之后对模型进行集成学习不会带来更多的性能提升。

6、自蒸馏: 集成与知识蒸馏的隐性结合

这个解释也可以用到知识自蒸馏中——训练一个模型来匹配另一个相同的架构的模型(但使用不同的随机种子)的输出,在某种程度上也能提高性能。

简单来理解,自蒸馏是知识蒸馏的一种特殊情况。

假设我们使用模型 𝐹2𝐹_2F2 从一个随机的初始化开始,来匹配另外一个模型 𝐹1𝐹_1F1 的输出。在这个过程中 𝐹2𝐹_2F2 一方面会学习 𝐹1𝐹_1F1 已经学习到特征子集,另一方面其能够学习到的特征子集也会受其随机初始化的影响。

这个过程,可以看做是:首先对两个单独的模型 𝐹1𝐹_1F1𝐹2𝐹_2F2 进行集成学习,然后蒸馏成 𝐹2𝐹_2F2

最终的 𝐹2𝐹_2F2 可能不一定涵盖数据集中所有可学习的视图,但它至少有学习所有视图(通过两个单个模型的集成学习数据库来覆盖)的潜力。这就是自蒸馏模型测试时性能提升的来源!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/532361.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

在墙上找垂直线_墙上如何快速找水平线

在装修房子的时候,墙面的面积一般都很大,所以在施工的时候要找准水平线很重要,那么一般施工人员是如何在墙上快速找水平线的呢?今天小编就来告诉大家几种找水平线的方法。一、如何快速找水平线1、用一根透明的软管,长度…

Vision Transformer(ViT)PyTorch代码全解析(附图解)

Vision Transformer(ViT)PyTorch代码全解析 最近CV领域的Vision Transformer将在NLP领域的Transormer结果借鉴过来,屠杀了各大CV榜单。本文将根据最原始的Vision Transformer论文,及其PyTorch实现,将整个ViT的代码做一…

Linux下的ELF文件、链接、加载与库(含大量图文解析及例程)

Linux下的ELF文件、链接、加载与库 链接是将将各种代码和数据片段收集并组合为一个单一文件的过程,这个文件可以被加载到内存并执行。链接可以执行与编译时,也就是在源代码被翻译成机器代码时;也可以执行于加载时,也就是被加载器加…

java 按钮 监听_Button的四种监听方式

Button按钮设置点击的四种监听方式注:加粗放大的都是改变的代码1.使用匿名内部类的形式进行设置使用匿名内部类的形式,直接将需要设置的onClickListener接口对象初始化,内部的onClick方法会在按钮被点击的时候执行第一个活动的java代码&#…

linux查看java虚拟机内存_深入理解java虚拟机(linux与jvm内存关系)

本文转载自美团技术团队发表的同名文章https://tech.meituan.com/linux-jvm-memory.html一, linux与进程内存模型要理解jvm最重要的一点是要知道jvm只是linux的一个进程,把jvm的视野放大,就能很好的理解JVM细分的一些概念下图给出了硬件系统进程三个层面内存之间的关系.从硬件上…

java function void_Java8中你可能不知道的一些地方之函数式接口实战

什么时候可以使用 Lambda?通常 Lambda 表达式是用在函数式接口上使用的。从 Java8 开始引入了函数式接口,其说明比较简单:函数式接口(Functional Interface)就是一个有且仅有一个抽象方法,但是可以有多个非抽象方法的接口。 java8…

java jvm内存地址_JVM--Java内存区域

Java虚拟机在执行Java程序的过程中会把它所管理的内存划分为若干个不同的数据区域,如图:1.程序计数器可以看作是当前线程所执行的字节码的行号指示器,通俗的讲就是用来指示执行哪条指令的。为了线程切换后能恢复到正确的执行位置Java多线程是…

java情人节_情人节写给女朋友Java Swing代码程序

马上又要到情人节了,再不解风情的人也得向女友表示表示。作为一个程序员,示爱的时候自然也要用我们自己的方式。这里给大家上传一段我在今年情人节的时候写给女朋友的一段简单的Java Swing代码,主要定义了一个对话框,让女友选择是…

java web filter链_filter过滤链:Filter链是如何构建的?

在一个Web应用程序中可以注册多个Filter程序,每个Filter程序都可以针对某一个URL进行拦截。如果多个Filter程序都对同一个URL进行拦截,那么这些Filter就会组成一个Filter链(也叫过滤器链)。Filter链用FilterChain对象来表示,FilterChain对象中…

java final static_Java基础之final、static关键字

一、前言关于这两个关键字,应该是在开发工作中比较常见的,使用频率上来说也比较高。接口中、常量、静态方法等等。但是,使用频繁却不代表一定是能够清晰明白的了解,能说出个子丑演卯来。下面,对这两个关键字的常见用法…

java语言错误的是解释运行的_Java基础知识测试__A卷_答案

考试宣言:同学们, 考试考多少分不是我们的目的! 排在班级多少的名次也不是我们的初衷!我的考试的目的是要通过考试中的题目,检查大家在这段时间的学习中,是否已经把需要掌握的知识掌握住了,如果哪道题目你不会做,又或者做错了, 那么不用怕, 考完试后, 导师讲解的时候你要注意听…

java 持续集成工具_Jenkins-Jenkins(持续集成工具)下载 v2.249.2官方版--pc6下载站

Jenkins是一款基于java开发的持续集成工具,是一款开源软件,主要用于监控持续重复的工作,为开发者提供一个开发易用的软件平台,使软件的持续集成变成可能。。相关软件软件大小版本说明下载地址Jenkins是一款基于java开发的持续集成…

java中线程调度遵循的原则_深入理解Java多线程核心知识:跳槽面试必备

多线程相对于其他 Java 知识点来讲,有一定的学习门槛,并且了解起来比较费劲。在平时工作中如若使用不当会出现数据错乱、执行效率低(还不如单线程去运行)或者死锁程序挂掉等等问题,所以掌握了解多线程至关重要。本文从基础概念开始到最后的并…

java类构造方法成员方法练习_面向对象方法论总结 练习(一)

原标题:面向对象方法论总结 & 练习(一)学习目标1.面向对象与面向过程2.类与对象的概念3.类的定义,对象的创建和使用4.封装5.构造方法6.方法的重载内容1.面向对象与面向过程为什么会出现面向对象反分析方法?因为现实世界太复杂多变&#x…

mysql 统计查询不充电_MySql查询语句介绍,单表查询,来充电吧

mysql在网站开发中,越来越多人使用了,方便部署,方便使用。我们要掌握mysql,首先要学习查询语句。查询单个表的数据,和多个表的联合查询。下面以一些例子来先简单介绍下单表查询。操作方法01首先看下我们例子用到的数据表&#xff…

MySQL线上优化_线上MySQL千万级大表,如何优化?

前段时间应急群有客服反馈,会员管理功能无法按到店时间、到店次数、消费金额进行排序。经过排查发现是 SQL 执行效率低,并且索引效率低下。图片来自 Pexels应急问题商户反馈会员管理功能无法按到店时间、到店次数、消费金额进行排序,一直转圈…

php创建表设置编码,教您在Zend Framework里如何设置数据库编码以及怎样给数据表设定前缀!...

当我们在开发项目时..大家都会遇到一个问题就是:数据库的编码问题.当然我们不用Zend Framework做为项目开发的框架时..我们可以很快,很容易搞定这个小问题..但是当我们要使用Zend Framewok开发项目时..我们可能一时会不知道如何解决这个小问题..比如我就是这样的人..在开发这个…

python 怎么将数组转为列表_怎么将视频转为GIF动态图 表情包怎么制作

说到GIF,大家应该都不陌生了吧!尤其是在聊天中使用较多,似乎一言不合就开启了斗图模式,但是我们平时使用的GIF一般都是软件中自带的,其实自己制作也是很方便的,而且会发现很有趣,不但可以直接录…

proteus里面没有stm32怎么办_嵌入式单片机之stm32串口你懂了多少!!

stm32作为现在嵌入式物联网单片机行业中经常要用多的技术,相信大家都有所接触,今天这篇就给大家详细的分析下有关于stm32的出口,还不是很清楚的朋友要注意看看了哦,在最后还会为大家分享有些关于stm32的视频资料便于学习参考。点击…

tomcat不能解析php,tomcat不支持php怎么办

tomcat不支持php的解决办法:首先将“PHP/Java Bridge”下的相关文件复制到tomcat的lib目录下;然后修改tomcat安装目录下conf文件夹里的“web.xml”文件;最后重启tomcat即可。java开发者都知道,tomcat是用来部署java web项目的。这…