Decoupled Knowledge Distillation解耦知识蒸馏

Decoupled Knowledge Distillation解耦知识蒸馏

现有的蒸馏方法主要是基于从中间层提取深层特征,而忽略了Logit蒸馏的重要性为了给logit蒸馏研究提供一个新的视角,我们将经典的KD损失重新表述为两部分,即目标类知识蒸馏(TCKD)和非目标类知识蒸馏(NCKD)。我们实证研究并证明了两部分的效果:TCKD转移了关于训练样本“难度”的知识而NCKD是logit蒸馏有效的突出原因。更重要的是,我们揭示了经典KD损失是一个耦合公式,它(1)抑制了NCKD的有效性,(2)限制了平衡这两个部分的灵活性。为了解决这些问题,我们提出了解耦知识蒸馏(DKD),使TCKD和NCKD更有效和灵活地发挥其作用。

介绍

在过去的几十年里,计算机视觉领域已经被深度神经网络(DNN)彻底改变,它成功地促进了各种真实场景的任务,如图像分类、目标检测和语义分割。然而,大的网络通常受益于大的模型容量,引入了高计算和存储成本。在广泛部署轻量级模型的工业应用中,这样的成本并不可取。在文献中,降低成本的一个潜在方向是知识蒸馏(KD)。KD代表了一系列专注于将知识从重模型(教师)——转移到轻模型(学生)的方法,这可以在不引入额外成本的情况下提高轻模型的性能。

KD的概念在[12]中首次提出,通过最小化教师和学生预测logit之间的KL-Divergence来转移知识(图1a)。

image-20240303132727112

自[28]以来,大部分的研究注意力都集中在从中间层的深层特征中提取知识。与基于logit的方法相比,特征蒸馏的性能在各种任务上是否表现出色,因此,对logit蒸馏的研究很少涉及。然而,基于特征方法的训练成本并不令人满意,因为在训练期间引入了额外的计算和存储使用(例如,网络模块和复杂的操作)来提取深度特征。

Logit蒸馏需要边际的计算和存储成本,但性能较差。直观的说,logit蒸馏应该达到与特征蒸馏相当的性能,因为logit比深度特征处于更高的语义层。假设logit蒸馏的潜力收到未知原因的限制,导致性能不理想。为了振兴基于Logit的方法,我们通过深入研究KD的机制开始这项工作。首先,我们将分类预测分为两个层次(1)对目标类和所有非目标类进行二值预测;(2)对每个非目标类进行多类预测。在此基础上,我们将经典KD损失[12]重新表述为两部分,如图1b所示。一种是针对目标类的二元logit蒸馏另一种是针对非目标类的多类别logit蒸馏。为了简化期间,我们将其分别命名为目标分类和知识蒸馏(TCKD)和非目标分类知识蒸馏(NCKD)。重新配方使我们能够独立地研究这两部分的效果。

TCKD通过二元logit蒸馏传递知识,这意味这只提供目标类的预测,而每个非目标类的具体预测是未知的。一个合理的假设是,TCKD传递了关于训练样本“难易度”的知识,即知识描述了识别每个训练样本的难易程度。为了验证这一点,我们从三个方面设计实验来提高训练数据的“难度”,即更强的增强、更嘈杂的标签和具有固有挑战性的数据集。

NCKD只考虑非目标logit之间的知识。有趣的是,我们通过经验证明,仅应用NCKD就可以获得与经典KD相当甚至更好的结果,这表明非目标logit中包含的知识至关重要,这可能是突出的“暗知识”。

更重要的是,我们的重新表述表明,经典KD损失是一个高度耦合的表述(如图1b所示),这可能是logit蒸馏潜力有限的原因。首先,NCKD损失项被一个与教师对目标类别的预测置信度负相关的系数加权。因此较大的预测分数将导致较小的权重。这种耦合显著抑制了NCKD对良好预测训练样本的影响。这种抑制并不可取,因为教师对训练样本越有信息,可提供的知识越可靠越有价值。其次,TCKD和NCKD的意义是耦合的,即不允许分别对TCKD和NCKD进行加权。这种限制是不可取的,因为TCKD和NCKD应该分开考虑,因为它们的贡献来自不同的方面。

为了解决这些问题,我们提出了一种灵活高效的logit蒸馏方法,称为解耦知识蒸馏(DKD,图1b)DKD将NCKD损失从与教师置信度负相关的系数中解耦,将其替换为恒定值,从而提高了对预测良好的样本的蒸馏效率。同时,对NCKD和TCKD也进行了解耦,通过调整各部分权重,可以分别考虑NCKD和TCKD的重要性。

总的来说,我们的贡献总结如下:

(1)将经典的logit蒸馏分为TCKD和NCKD,为Logit蒸馏的研究提供了新的思路。

(2)我们揭示了由其高耦合公式引起的经典KD损失的局限性。NCKD与教师信心的耦合抑制了知识转移的有效性。TCKD与NCKD的耦合限制了平衡两部分的灵活性。

(3)为了克服这些局限性,我们提供了一种有效的logit蒸馏方法DKD。

重新思考知识蒸馏

在本节中,我们深入探讨知识蒸馏的机制。我们将KD损失重新表述为两部分的加权和,一部分与目标类相关,另一部分与目标类无关。我们探讨了知识蒸馏框架中每个部分的作用,并揭示了经典KD的一些局限性。受此启发,我们进一步提出了一种新的logit蒸馏方法,在各种任务上取得了显著的性能。

回顾KD

Notation对于第t类的训练样本,分类概率可以表示为P=image-20240303150007961,其中pi是第i类的概率,C是类的个数。p中的每个元素都可以通过softmax函数得到:
p i = e x p ( z i ) ∑ j = 1 C e x p ( z j ) p_i = \frac{exp(z_i)}{\sum_{j=1}^Cexp(z_j)} pi=j=1Cexp(zj)exp(zi)
其中zi代表第i类的对数。

为了区分于目标类相关和不相关的预测,我们定义了以下符号。b = image-20240303150447331表示目标类(pt)和其他所有非目标类(p\t)的二值概率,其计算公式为:

image-20240303150539198

同时,我们声明image-20240303150715008独立建模非目标类之间的概率(即,不考虑第t类)。每个元素的计算方法为:image-20240303150736308

Reformulation 在第一部分中,我们尝试用二元概率b和非目标类之间的概率p来重新表述KD。T和S分别表示老师和学生。经典KD使用kl散度作为损失函数,也可以写成2:

image-20240303151314489

根据等式(1)和等式(2)我们有image-20240303151721273,所以我们可以把等式(3)改写为:

image-20240303151806455

等式(4)可以改写为:

image-20240303151918101

如公式(5)所示,KD损失被重新表述为两项的加权和。image-20240303152823063表示目标类别的教师和学生的二元概率之间的相似度。因此,我们将其命名为目标类知识蒸馏(TCKD)。同时,image-20240303153038652表示非目标类中教师和学生概率的相似度,称为非目标类知识蒸馏(NVKD)。式(5)可以改写为:

image-20240303153129634

显然,NCKD的重建与image-20240303153158532是耦合的。

上述重新表述启发了我们对TCKD和NCKD的个体效应进行研究,这将揭示经典耦合表述的局限性。

TCKD和NCKD的影响

各部件的性能增益。我们分别研究了TCKD和NCKD对CIFAR-100的影响。选择ResNet、WideResNet(WRN)和ShuffleNet作为训练模型,其中考虑了相同和不同的架构。实验结果如表1,对于每个师生对,我们报告了(1)学生基线,(2)经典KD(其中同时使用TCKD和NCKD),(3)单一TCKD和(4)单一NCKD的结果。每个损失的权重设置为1.0(包括默认的交叉熵损失)。其它实现细节与第4节相同。

image-20240303155626665

直观地说,TCKD集中于与目标类相关的知识,因为相应的损失函数只考虑二进制概率。相反,NCKD侧重于非目标类别的知识。我们注意到单独使用TCKD对学生来说可能没有帮助(例如在ShufflerNet-V1上增加0.02%和0.12%)甚至是有害的(例如,在WRN-16-2上下降2.3%,在ResNet8-4上下降3.87%)。然而,NCKD的蒸馏性能与经典KD相当,甚至更好(例如,在ResNet8/4上,1.76% vs 1.13%)。消融结果表明靶类相关知识不如非靶类知识重要,为了深度研究这一现象,我们提供如下进一步的分析。

TCKD传递了关于训练样本“难度”的知识

根据等式(5),TCKD通过二值分类任务传递“暗知识”,这可能与样本的“难度“有关。例如,与image-20240303155803478的训练样本相比,image-20240303155813762的训练样本可能”更容易“让学生学习。由于TCKD传达了训练样本的“难度”,我们假设当训练数据变得具有挑战性时,有效性将被解释。然而,CIFRA-100训练集很容易过拟合。因此,教师提供的“难度”知识并不是信息性的。在这一部分中,我们从三个角度进行实验验证:训练数据越难,TCKD提供的好处越多。

(1)应用强增强是增加训练数据难度的一种直接方法。我们在CIFAR-100上使用AutoAugment训练ResNet32×4模型作为教师,获得了81.29%的top-1验证精度。对于学生,我们训练带/不带TCKD的ResNet8、4和ShufflerNetv1模型。表2的结果表明,如果应用强增强,TCKD可以获得显著的性能增益。

image-20240303161609391

(2)噪声标签也会增加训练数据的难度。我们在CIFAR-100上以{0.1,0.2,0.3}对称噪声比训练ResNet32×4模型作为教师,ResNet8×4模型作为学生,如下[7,35]。如表3所示,结果表明TCKD在噪声较大的训练数据上取得了更多的绩效提升。

image-20240303161939762

(3)挑战性的数据集(例如,ImageNet也被考虑。表4显示,TCKD可以在ImageNet上带来+0.32%的性能增益。

image-20240303162009395

最后,我们通过实验各种策略来增加训练数据的难度(如强增强、噪声标签、困难任务),证明了TCKD的有效性。结果证明,在提取更具挑战性的训练数据时,有关训练样本“难度“的知识可能更有用。

NCKD是logit蒸馏工作的重要原因,但受到很大的抑制。有趣的是,我们在表1中注意到,当仅应用NCKD时,性能与经典KD相当甚至更好。结果表明,非目标类的知识对logit蒸馏至关重要,可以成为突出的“暗知识”。然而,通过回顾方程(5),我们注意到NCKD损失与image-20240303162635731相耦合。其中,image-20240303162731869代表教师对目标类别的置信度。因此,更有置信度的预测会导致更小的NCKD权重。我们假设教师对训练样本越有信心,它提供的知识就越可靠,越有价值。然而,这种自信的预测高度抑制了损失权重。我们假设这一事实会限制知识蒸馏的有效性,这首先是由于我们在等式(5)中对KD的重新表述而研究的。

我们设计了一个消融实验来验证预测良好的样本确实比其他样本更好地传递知识。首先,我们根据image-20240303163021532对训练样本进行排序,并将其平均分成两个子集。为了清晰起见,一个子集包括image-20240303163212148前50%的样本,而其余样本在另一个子集中。然后,我们在每个子集上使用NCKD训练学生网络,以比较性能增益(而交叉熵损失仍然在整个集合上)。表5显示,在前50%的样本上使用NCKD可以获得更好的性能,这表明预测良好的样本的知识比其他样本更丰富。然而,预测良好的样本的损失权重被教师的高置信度所抑制。

image-20240303163329872

解耦知识蒸馏

至此,我们将经典KD损失重新表述为两个独立部分的加权和,进一步验证了TCKD的有效性,揭示了NCKD的抑制作用。具体来说,TCKD传递了关于训练样本“难度”的知识。TCKD可以在更具挑战性的训练数据上获得更显著的改进。NCKD在非目标类之间进行知识转移。当权重image-20240303163557024较小时,知识转移受到抑制。

本能地,TCKD和NCKD都是必不可少的,至关重要的。然而,在经典KD公式中,TCKD和NCKD从以下几个方面耦合。

(1)首先,NCKD与image-20240303163724801耦合,这可以抑制预测良好的样本上的NCKD。由于表5的结果表明,预测良好的样本可以带来更多的性能增益,因此耦合形式可能会限制NCKD的有效性。

(2)另一方面,在经典KD框架下,NCKD与TCKD的权重是耦合的。不允许为了平衡重要性而改变每个词的权重。我们认为TCKD和NCKD应该考虑他们的贡献来自不同的方面而分离。

基于我们对KD的重新表述,我们提出了一种新的logit蒸馏方法——解耦知识蒸馏(DKD)。我们提出的DKD在解耦公式中独立考虑了TCKD和NCKD。具体来说,我们分别引入了两个超参数作为TCKD和NCKD的权重,DKD的损失函数为:

image-20240303164149479

在DKD中,image-20240303164236706会抑制NCKD的有效性,使用image-20240303164247291代替。此外,还允许调整两个超参数以平衡TCKD和NCKD的重要性。DKD通过解耦NCKD和TCKD,为logit蒸馏提供了高效、灵活的方法。算法1提供了DKD的伪代码。

image-20240303164406442

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/716848.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

c++之旅——第四弹

大家好啊,这里是c之旅第三弹,跟随我的步伐来开始这一篇的学习吧! 如果有知识性错误,欢迎各位指正!!一起加油!! 创作不易,希望大家多多支持哦! 本篇文章的主…

如何对比 MySQL 主备数据的一致性?

随着业务范围的扩大,很多企业为了保障核心业务的高可用性,选择了 MySQL 主从架构,这一套方案通常具备主备数据同步、数据备份与恢复、读写分离、高可用切换等特性,是一种相当成熟可靠的数据库架构方案。然而这套方案在特定情况下可…

Redis小白入门教程

Redis入门教程 1. Redis入门1.1 Redis简介1.2 Redis服务启动与停止1.2.1 Redis下载1.2.2 服务启动命令1.2.3 客户端连接命令1.2.4 修改Redis配置文件 2. Redis数据类型2.1 五种常用数据类型介绍2.1.1 字符串操作命令2.1.2 哈希操作命令2.1.3 列表操作命令2.1.4 集合操作命令2.1…

双周回顾#006 - 这三个月

断更啦~~ 上次更新时间 2023/11/23, 断更近三个月的时间。 先狡辩下,因为忙、着实忙。因为忙,心安理得给断更找了个借口,批评下自己~~ 这三个月在做啥?跨部门援助,支援公司互联网的 ToC 项目,一言难尽。 …

【C语言】InfiniBand 驱动mlx4_ib_init和mlx4_ib_cleanup

一、中文讲解 这两个函数是Linux内核模块中对于Mellanox InfiniBand 驱动程序初始化和清理的函数。 mlx4_ib_init()函数是模块初始化函数,使用__init宏标注,表示该函数只在模块加载时运行一次。 函数执行的步骤如下: 1. 通过alloc_ordered_w…

数据结构——lesson5栈和队列详解

hellohello~这里是土土数据结构学习笔记🥳🥳 💥个人主页:大耳朵土土垚的博客 💥 所属专栏:数据结构学习笔记 💥对于顺序表链表有疑问的都可以在上面数据结构的专栏进行学习哦~感谢大家的观看与…

ElasticSearch开篇

1.ElasticSearch简介 1.1 ElasticSearch(简称ES) Elasticsearch是用Java开发并且是当前最流行的开源的企业级搜索引擎。能够达到实时搜索,稳定,可靠,快速,安装使用方便。 1.2 ElasticSearch与Lucene的关…

模拟器抓HTTP/S的包时如何绕过单向证书校验(XP框架)

模拟器抓HTTP/S的包时如何绕过单向证书校验(XP框架) 逍遥模拟器无法激活XP框架来绕过单向的证书校验,如下图: ​​ 解决办法: 安装JustMePlush.apk安装Just Trust Me.apk安装RE管理器.apk安装Xposedinstaller_逍遥64位…

智能边缘小站 CloudPond(低延迟、高带宽和更好的数据隐私保护)

智能边缘小站 CloudPond(低延迟、高带宽和更好的数据隐私保护) 边缘小站的主要功能是管理用户在线下部署的整机柜设施,一个边缘小站关联一个华为云指定的区域和一个用户指定的场地,相关的资源运行状况监控等。 边缘计算 迈入5G和AI时代,新…

利用redis实现秒杀功能

6、秒杀优化 这个是 图灵 的redis实战里面的一个案例 6.1 秒杀优化-异步秒杀思路 我们来回顾一下下单流程 当用户发起请求,此时会请求nginx,nginx会访问到tomcat,而tomcat中的程序,会进行串行操作,分成如下几个步骤…

基于单片机的红外遥控解码程序设计与实现

摘要:该文介绍基于士兰半导体芯片(SC6122)的红外发射遥控器,通过单片机解码程序,实现红外遥控信号的解码和接收。红外接收头与单片机特定的引脚连接,通过设置单片机定时计数器,采样来自红外接收头的高、低电平宽度解码遥控信号。该解码程序设计主要应用在LED数码显示控制…

电机的极数和槽数,机械角度和电角度,霍尔IC,内外转子

什么是电机的极数和槽数? 【第7集】② 正弦波驱动的转矩脉动、正弦电流的时序和相位变化、超前角控制(超前角调整)、正弦波驱动的各种波形 - 电源设计电子电路基础电源技术信息网站_罗姆电源设计R课堂 (rohm.com.cn) 下面为您介绍表示电机…

Java虚拟机(JVM)从入门到实战【上】

Java虚拟机(JVM)从入门到实战【上】,涵盖类加载,双亲委派机制,垃圾回收器及算法等知识点,全系列6万字。 一、基础篇 P1 Java虚拟机导学课程 P2 初识JVM 什么是JVM Java Virtual Machine 是Java虚拟机。…

3.2日-线性模型,基础优化方法,线性回归从零开始实现

3.2日-线性模型,基础优化方法,线性回归从零开始实现 1线性模型衡量预估质量训练数据总结2基础优化方法3 线性回归从零开始实现 1线性模型 衡量预估质量 训练数据 总结 2基础优化方法 梯度下降是一种优化算法,常用于机器学习和深度学习中&…

进程的信号

目录 信号(signal)入门 技术应用角度的信号 注意 用kill -l命令可以察看系统定义的信号列表 信号处理常见方式概览 产生信号 1.通过终端(键盘)按键产生信号 signal函数 2. 调用系统函数向进程发信号 kill 函数 raise 函数 3.由软件条件产生的信号 alarm 函数 4.硬…

(学习日记)2024.03.01:UCOSIII第三节 + 函数指针 (持续更新文件结构)

写在前面: 由于时间的不足与学习的碎片化,写博客变得有些奢侈。 但是对于记录学习(忘了以后能快速复习)的渴望一天天变得强烈。 既然如此 不如以天为单位,以时间为顺序,仅仅将博客当做一个知识学习的目录&a…

Kubernetes: 本地部署dashboard

本篇文章主要是介绍如何在本地部署kubernetes dashboard, 部署环境是mac m2 下载dashboard.yaml 官网release地址: kubernetes/dashboard/releases 本篇文章下载的是kubernetes-dashboard-v2.7.0的版本,通过wget命令下载到本地: wget https://raw.githubusercont…

【Python】进阶学习:pandas--isin()用法详解

【Python】进阶学习:pandas–isin()用法详解 🌈 个人主页:高斯小哥 🔥 高质量专栏:Matplotlib之旅:零基础精通数据可视化、Python基础【高质量合集】、PyTorch零基础入门教程👈 希望得到您的订阅…

【NDK系列】Android tombstone文件分析

文件位置 data/tombstone/tombstone_xx.txt 获取tombstone文件命令: adb shell cp /data/tombstones ./tombstones 触发时机 NDK程序在发生崩溃时,它会在路径/data/tombstones/下产生导致程序crash的文件tombstone_xx,记录了死亡了进程的…

单细胞Seurat - 细胞聚类(3)

本系列持续更新Seurat单细胞分析教程,欢迎关注! 维度确定 为了克服 scRNA-seq 数据的任何单个特征中广泛的技术噪音,Seurat 根据 PCA 分数对细胞进行聚类,每个 PC 本质上代表一个“元特征”,它结合了相关特征集的信息。…