知识蒸馏(paper翻译)

paper:Distilling the Knowledge in a Neural Network

摘要:

提高几乎所有机器学习算法性能的一个非常简单的方法是在相同的数据上训练许多不同的模型,然后对它们的预测进行平均[3]。
不幸的是,使用整个模型集合进行预测非常麻烦,并且计算成本可能太高,无法部署到大量用户,尤其是在单个模型是大型神经网络的情况下。
Caruana 和他的合作者 [1] 已经证明,可以将集成中的知识压缩到单个模型中,该模型更容易部署,并且我们使用不同的压缩技术进一步开发了这种方法。
我们在 MNIST 上取得了一些令人惊讶的结果,并且表明我们可以通过将模型集合中的知识提炼为单个模型来显着改进频繁使用的商业系统的声学模型。
我们还引入了一种由一个或多个完整模型和许多专业模型组成的新型集成,这些模型学习区分完整模型混淆的细粒度类别。 与专家的混合不同,这些专业模型可以快速并行地进行训练。

Introduction

许多昆虫都有幼虫形态和完全不同的成虫形态,幼虫形态经过优化,可以从环境中获取能量和营养,而成虫形态则可以满足不同的旅行和繁殖要求。

在大规模机器学习中,我们通常在训练阶段和部署阶段使用非常相似的模型,尽管它们的要求非常不同:对于语音和对象识别等任务,训练必须从非常大、高度冗余的数据集中提取结构,但它并不需要实时操作,因此可以使用大量的计算。

然而,部署到大量用户对延迟和计算资源有更严格的要求。 与昆虫的类比表明,如果可以更轻松地从数据中提取结构,我们应该愿意训练非常繁琐的模型(后面称为大模型)。
大模型可能是单独训练的模型的集合,也可能是使用非常强大的正则化器(例如 dropout)训练的单个非常大的模型[9]。

一旦繁琐的模型训练出来,我们就可以使用不同类型的训练,我们称之为“蒸馏”,将知识从繁琐的模型转移到更适合部署的小模型。 Rich Caruana 及其合作者已经率先提出了该策略的一个版本 [1]。 在他们的重要论文中,他们令人信服地证明,通过大型模型集合获得的知识可以转移到单个小型模型中。

可能阻止对这种非常有前途的方法进行更多研究的一个概念障碍是,我们倾向于使用学习到的参数值来识别经过训练的模型中的知识,这使得我们很难看到如何改变模型的形式但保持相同的知识。

知识的一个更抽象的观点是,它是从输入向量到输出向量的学习映射,将其从任何特定的实例化中解放出来。
对于学习区分大量类别的繁琐模型,正常的训练目标是最大化正确答案的平均对数概率,但学习的副作用是训练后的模型将概率分配给所有不正确的答案,即使这些概率非常小,其中一些也比其他概率大得多。

错误答案的相对概率告诉我们很多关于大模型如何泛化的信息。 例如,BMW的图像可能只有很小的机会被误认为是垃圾车,但这种错误的可能性仍然比将其误认为是胡萝卜高很多倍。

一般认为,用于训练的目标函数应尽可能地反映用户的真实目标。尽管如此,模型通常被训练为优化训练数据上的性能,而真正的目标是要对新数据具有良好的泛化能力。
显然,更好的做法是训练模型以便它们能够很好地泛化,但这需要关于正确泛化方式的信息,而这些信息通常是不可用的。

然而,当我们将大模型的知识提炼到小模型时,可以训练小模型与大型模型相同的方式进行泛化。
如果大模型泛化得好,例如,因为它是多个不同模型大型集合的平均,那么训练小模型以相同方式泛化,在测试数据上通常会比按照常规方式在同一个训练集上训练的小模型表现更好,训练集就是训练大模型的集合的。

将大模型的泛化能力转移到小模型的一个明显方法是使用大模型产生的class probability作为训练小模型的“soft targets”。
对于这个转移阶段,我们可以使用相同的训练集或单独的“转移”集。 当大模型是简单模型的大型集合时,我们可以使用它们各自的预测分布的算术或几何平均值作为soft targets。

当soft targets具有高熵时,它们在每个训练case中提供的信息比hard targets多得多,并且训练case之间梯度的方差要小得多,因此小模型可以用更少的数据,更大的learning rate进行训练。

对于像MNIST这样的任务,大模型几乎总以很高的置信度得出正确答案,大量关于学习function的信息寄存在soft targets中非常小概率的比率里。例如,一个版本中,2可能以10-6的概率被认为是3,10-9的概率被认为是7,而另一个版本可能恰好相反。这是有用的信息,因为它定义了数据的丰富的类似结构(即它指出哪些2看起来像3,哪些看起来像7),但在transfer阶段它对交叉熵损失函数的影响非常小,因为这些概率接近于零。

Caruana及其合作者通过使用logits(最后的softmax层的input)而不是用由softmax产生的概率作为学习小模型的target 来避开这个问题,并且他们最小化大模型和小模型产生的logits之间的平方差。更通用的解决方案,称为“蒸馏”,是将最后的softmax层的温度提高,直到大模型产出一套合适的soft target。然后训练小模型时用相同的高温,以匹配这些soft targets。我们稍后将展示,匹配大模型的logits实际上是蒸馏的一个特殊case。
这里的 “温度” 在后面的公式中体现

用于训练小模型的转移集可以完全由未标记数据组成[1],或者我们可以使用原始训练集。我们发现使用原始训练集效果很好,尤其是如果我们在目标函数中增加一个小项,鼓励小模型预测真实的target, 并且匹配由大模型提供的soft target。

通常,小模型无法完全匹配soft target,而在正确答案的方向上犯错被证明是有帮助的。

蒸馏

softmax的input称为logits, 用 z i z_{i} zi表示,
softmax的output称为概率,用 q i q_{i} qi表示。
神经网络通常用一个softmax层把logits转为概率,通过把 z i z_{i} zi与其他概率作比较。
在这里插入图片描述
公式里面的T就是上面说的蒸馏的温度。T通常是1. 更高的T产生更加soft的概率分布。

如何设置温度T?

在最简单的蒸馏形式中,准备一个transfer set数据集,它的label是大模型通过调高T产生的soft target,训练蒸馏模型时也要用同样的T,训练完成后T=1.
通过在transfet set上训练蒸馏模型,知识就被转移到了蒸馏模型。

同时使用label和soft target

当所有或部分transfer set的正确label已知时,还可以通过训练蒸馏模型来生成正确的标签来显着改进该方法。
一种方法是使用正确的label来修改soft target,但我们发现更好的方法是简单地使用两个不同目标函数的加权平均值。
第一个目标函数是与soft target的交叉熵,并且该交叉熵是让蒸馏模型和产生soft target的 大模型用相同的温度T(softmax中)。softmax 中与用于从繁琐模型生成软目标相同的高温来计算的。
第二个目标函数是和正确label的交叉熵。 这是在蒸馏模型的 softmax 中还是用完全相同的logits计算,但T= 1。
我们发现,通常通过在第二个目标函数上使用相当低的权重来获得最佳结果。

由于soft target产生的梯度幅度 相当于缩放了 1/T2 ,因此在同时使用hard 和 soft targets时将其乘以 T 2 非常重要。 这确保了如果在元参数实验时用于蒸馏的温度T发生变化,hard和soft target的相对贡献保持大致不变。

Matching logits是蒸馏的一种特殊形式

PS: 前面introduction部分提到过,用softmax的input, 也叫logits, 代替softmax输出的概率作为学习小模型的target,来避开概率过小的问题,通过最小化大模型和小模型产生的logits之间的平方差。
现在说明这种方法为什么是蒸馏的一种形式。

transfer set中每个case都对蒸馏模型的每个logits z i z_{i} zi贡献出cross-entropy梯度 d C / d z i dC/dz_{i} dC/dzi.
如果大模型有logits v i v_{i} vi, 产生了soft target概率 p i p_{i} pi, 训练在温度T下完成.
那么梯度为:
在这里插入图片描述
如果温度比logits的幅度大,那么可以近似为:
在这里插入图片描述

假设每个transfer case的logits都是0均值的,即在这里插入图片描述
那么(3)可以简化为:

在这里插入图片描述

所以在温度T高时,如果logits对每个tranfer case都是0 均值,那么蒸馏等同于最小化 1 / 2 ( z i − v i ) 1/2(z_{i} - v_{i}) 1/2(zivi).
在T比较低时,蒸馏在matching logits上的attetion就少很多,因为它们比平均值负很多。
这是潜在的优势,因为这些logits几乎完全不受大模型的cost function的约束,因此它们可能非常noisy。
另一方面,非常负的logits可能会传达有关通过大模型获得的知识的有用信息。 这些影响中哪一个占主导地位是一个经验问题。 我们表明,当蒸馏模型太小而无法捕获大模型中的所有知识时,中间温度效果最好,这强烈表明忽略大的负logtis可能会有所帮助。

MNIST实验

为了了解蒸馏的效果如何,我们在所有 60,000 个训练案例上训练了一个大型神经网络,该神经网络具有两个隐藏层,每个隐藏层包含 1200 个校正线性隐藏单元。 该网络使用 dropout 和权重约束进行了强烈正则化,如 [5] 中所述。 Dropout 可以被视为训练共享权重的指数级大模型集合的一种方法。 此外,输入图像在任何方向上抖动最多两个像素。 该网络出现了 67 个测试错误,而具有两个隐藏层(由 800 个校正线性隐藏单元且无正则化)的较小网络出现了 146 个错误。 但是,如果仅通过添加在 20 ℃ 的温度下匹配大网络产生的软目标的附加任务来对较小的网络进行正则化,则它会出现 74 个测试错误。 这表明soft target可以将大量知识转移到蒸馏模型中,包括如何泛化从translated训练数据中学到的知识,即使转移集不包含任何translations。

当蒸馏网络的两个隐藏层中每个都有 300 个或更多units时,所有高于 8 的温度都会给出相当相似的结果。 但当这从根本上减少到每层 30 个units时,2.5 至 4 范围内的温度明显优于更高或更低的温度。

然后,我们尝试从传输集中省略数字 3 的所有示例。 所以从蒸馏模型的角度来看,3是一个它从未见过的神话数字。 尽管如此,蒸馏模型仅出现 206 个测试错误,其中 133 个位于测试集中的 1010 个三元组上。

大多数错误是由于3这个类别的学习bias太低而引起的。 如果此偏差增加 3.5(这会优化测试集的整体性能),则蒸馏模型会出现 109 个错误,其中 14 个错误位于 3 上。 因此,在正确的偏差下,尽管在训练期间从未见过 3,但蒸馏模型在测试 3 中的正确率达到 98.6%。 如果传输集仅包含训练集中的 7 和 8,则蒸馏模型的测试误差为 47.3%,但当 7 和 8 的偏差减少 7.6 以优化测试性能时,测试误差将降至 13.2%。

Discussion

我们已经证明,蒸馏对于将知识从集成或从大型高度正则化模型转移到较小的蒸馏模型非常有效。 在 MNIST 上,即使用于训练蒸馏模型的传输集缺少一个或多个类的任何示例,蒸馏也能表现得非常好。 对于 Android 语音搜索所使用的深度声学模型版本,我们已经证明,通过训练深度神经网络集合所实现的几乎所有改进都可以被提炼为相同大小的单个神经网络, 部署起来要容易得多。
对于非常大的神经网络,甚至训练一个完整的集合也是不可行的,但是我们已经证明,经过很长时间训练的单个非常大的网络的performance 可以通过学习大量的专家网络来显着提高 ,每个专家网络都学会区分高度混乱的集群中的类别(通过大量专家网络进一步区分类别,是帮助的性质,并不是蒸馏)。 我们还没有证明我们可以将专家的知识蒸馏回单一的大网络中。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/656611.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

操作系统A-第四和五章(存储器)作业解析

目录 1、在请求分页系统中,某用户程序的逻辑地址空间为 16 页,每页 1KB,分配的内存空间为 8KB。假定某时刻该用户的页表如下表所示。 试问:(1)逻辑地址 184BH 对应的物理地址是多少?(用十六进制表示&…

基于SSM的二手车交易网站设计与实现(有报告)。Javaee项目。ssm项目。

演示视频: 基于SSM的二手车交易网站设计与实现(有报告)。Javaee项目。ssm项目。 项目介绍: 采用M(model)V(view)C(controller)三层体系结构,通过…

全国网络安全行业职业技能大赛WP

word_sercet 文档被加密 查看图片的属性 在备注可以看到解压密码 解密成功 在选项里面把隐藏的文本显示出来 可以看到ffag easy_encode 得到一个bmp二维码 使用qr research 得到的密文直接放瑞士军刀 base32解码base64解码hex解码 dir_pcap 直接搜索flag 发现flag…

mysql之基本查询

基本查询 一、SELECT 查询语句 一、SELECT 查询语句 查询所有列 1 SELECT *FORM emp;查询指定字段 SELECT empno,ename,job FROM emp;给字段取别名 SELECT empno 员工编号 FROM emp; SELECT empno 员工编号,ename 姓名,job 岗位 FROM emp; SELECT empno AS 员工编号,ename …

Codeforces Round 799 (Div. 4)

目录 A. Marathon B. All Distinct C. Where’s the Bishop? D. The Clock E. Binary Deque F. 3SUM G. 2^Sort H. Gambling A. Marathon 直接模拟 void solve() {int ans0;for(int i1;i<4;i) {cin>>a[i];if(i>1&&a[i]>a[1]) ans;}cout<&l…

离线生成双语字幕,一键生成中英双语字幕,基于AI大模型,ModelScope

离线生成双语字幕整合包,一键生成中英双语字幕,基于AI大模型 制作双语字幕的方案网上有很多&#xff0c;林林总总&#xff0c;不一而足。制作双语字幕的原理也极其简单&#xff0c;无非就是人声背景音分离、语音转文字、文字翻译&#xff0c;最后就是字幕文件的合并&#xff0c…

鸿蒙南向开发——GN快速入门指南

运行GN(Generate Ninja) 运行gn&#xff0c;你只需从命令行运行gn&#xff0c;对于大型项目&#xff0c;GN是与源码一起的。 对于Chromium和基于Chromium的项目&#xff0c;有一个在depot_tools中的脚本&#xff0c;它需要加入到你的PATH环境变量中。该脚本将在包含当前目录的…

常用芯片学习——CD4094芯片

CD4094 8位移位寄存器/3态输出缓冲器 使用说明 CD4094是由一个 8 位串行移位寄存器和一个 3 态输出缓冲器组成的 CMOS 集成电路。寄存器带有存储锁存功能&#xff0c;集成电路根据 STROBE 信号确定锁存器是否接收移位寄存器各位数据&#xff0c;数据是否由锁存器传输到 3 态输…

【教学类-35-23】20240130“红豆空心黑体”不能显示的汉字

作品展示&#xff1a; 背景需求 使用红豆空心黑体制作幼儿字帖&#xff08;涂色版&#xff09; 【教学类-35-22】正式版 20240129名字字卡3.0&#xff08;15CM正方形手工纸、先男后女&#xff0c;页眉是黑体包含全名&#xff0c;名字是红豆空心黑体&#xff09;-CSDN博客文章…

线性代数---------学习总结

线性代数之行列式 行列式的几条重要的性质 1.某两行某两列交换位置之后&#xff0c;值变号 2.行列式转置&#xff0c;值不变 3.范德蒙德行列式&#xff0c;用不同行的公比做一系列的累乘运算 4.把某一行的行列式加到另一行上&#xff0c;利用他们之间的倍数关系&#xff0…

Could not resolve host: github.com问题解决

git clone的时候发现机器无法解析github.com&#xff0c;其实应该改用ssh协议去clone&#xff0c;但是我用的是公用的机器&#xff0c;密钥对一直没配置好&#xff0c;所以也就堵死了。那么如果想让机器能解析github.com&#xff0c;&#xff08;机器本身没有ping命令&#xff…

Python XPath解析html出现⋆解决方法 html出现#123;解决方法

前言 爬网页又遇到一个坑&#xff0c;老是出现乱码&#xff0c;查看html出现的是&#数字;这样的。 网上相关的“Python字符中出现&#的解决办法”又没有很好的解决&#xff0c;自己继续冲浪&#xff0c;费了一番功夫解决了。 这算是又加深了一下我对这些iso、Unicode编…

MySQL原理(二)存储引擎(3)InnoDB

目录 一、概况&#xff1a; 1、介绍&#xff1a; 2、特点&#xff1a; 二、体系架构 1、后台线程 2、内存池&#xff08;缓冲池&#xff09; 三、物理结构 1、数据文件&#xff08;表数据和索引数据&#xff09; 1.1、作用&#xff1a; 1.2、共享表空间与独立表空间 …

计算机网络——静态路由的配置实验

1.实验题目 实验四&#xff1a;静态路由的配置 2.实验目的 1.了解路由器的基本配置。 2.实现对路由器的静态配置。 3.了解Ping命令和trace的原理和使用 3.实验任务 &#xff08;1&#xff09;路由器的基本配置&#xff1a;关闭域名解释&#xff1b;设置路由器接口 IP 地…

网络地址相关函数一网打尽

这块的函数又多又乱&#xff0c;今天写篇日志&#xff0c;以后慢慢补充 1. 网络地址介绍 1.1 ipv4 1.1.1 点、分十进制的ipv4 你对这个地址熟悉吗&#xff1f; 192.168.10.100&#xff0c;这可以当做一个字符串。被十进制数字、 “ . ”分开。IP地址的知识就不再多讲…

一文速学-selenium高阶操作连接已存在浏览器

前言 不得不说selenium不仅在自动化测试作为不可或缺的工具&#xff0c;在数据获取方面也是十分好用&#xff0c;能够十分快速的见到效果&#xff0c;这都取决于selenium框架的足够的灵活性&#xff0c;甚至在一些基于web端的自动化办公都十分有效。 通过selenium连接已经存在…

【解决】No match for argument: gflags-devel

背景 在centos-8中安装gflags-devel&#xff0c;直接dnf安装&#xff0c;失败了。 [rootpcs2 ~]# sudo dnf -y install gflags-devel Extra Packages for Enterprise Linux 8 - x86_64 Extra Packages…

什么是Vue Vue入门案例

一、什么是Vue 概念&#xff1a;Vue (读音 /vjuː/&#xff0c;类似于 view) 是一套 构建用户界面 的 渐进式 框架 Vue2官网&#xff1a;Vue.js 1.什么是构建用户界面 基于数据渲染出用户可以看到的界面 2.什么是渐进式 所谓渐进式就是循序渐进&#xff0c;不一定非得把V…

华为radius认证

组网需求 如图1所示&#xff0c;用户同处于huawei域&#xff0c;Router作为目的网络接入服务器。用户需要通过服务器的远端认证才能通过Router访问目的网络。在Router上的远端认证方式如下&#xff1a; Router对接入用户先用RADIUS服务器进行认证&#xff0c;如果认证没有响应…

(M)UNITY三段攻击制作

三段攻击逻辑 基本逻辑&#xff1a; 人物点击攻击按钮进入攻击状态&#xff08;bool isAttack&#xff09; 在攻击状态下&#xff0c; 一旦设置的触发器&#xff08;trigger attack&#xff09;被触发&#xff0c;设置的计数器&#xff08;int combo&#xff09;查看目前攻击…