「AI模型瘦身术」——知识蒸馏技术综述

使用KD原因

遇到问题:从产业发展的角度来看工业化将逐渐过渡到智能化,边缘计算逐渐兴起预示着 AI 将逐渐与小型化智能化的设备深度融合,这也要求模型更加的便捷、高效、轻量以适应这些设备的部署。

解决方案:知识蒸馏技术

知识蒸馏的关键点

如果回归机器学习最最基础的理论,我们可以很清楚地意识到一点(而这一点往往在我们深入研究机器学习之后被忽略): 机器学习最根本的目的在于训练出在某个问题上泛化能力强的模型。

泛化能力强: 在某问题的所有数据上都能很好地反应输入和输出之间的关系,无论是训练数据,还是测试数据,还是任何属于该问题的未知数据。

而现实中,由于我们不可能收集到某问题的所有数据来作为训练数据,并且新数据总是在源源不断的产生,因此我们只能退而求其次,训练目标变成在已有的训练数据集上建模输入和输出之间的关系。由于训练数据集是对真实数据分布情况的采样,训练数据集上的最优解往往会多少偏离真正的最优解(这里的讨论不考虑模型容量)。

而在知识蒸馏时,由于我们已经有了一个泛化能力较强的Net-T,我们在利用Net-T来蒸馏训练Net-S时,可以直接让Net-S去学习Net-T的泛化能力。

一个很直白且高效的迁移泛化能力的方法就是:使用softmax层输出的类别的概率来作为“soft target”。

KD的训练过程和传统的训练过程的对比

传统training过程(hard targets): 对ground truth求极大似然

KD的training过程(soft targets): 用large model的class probabilities作为soft targets

KD的训练过程为什么更有效?

softmax层的输出,除了正例之外,负标签也带有大量的信息,比如某些负标签对应的概率远远大于其他负标签。而在传统的训练过程(hard target)中,所有负标签都被统一对待。也就是说,KD的训练方式使得每个样本给Net-S带来的信息量大于传统的训练方式。

【举个例子】

在手写体数字识别任务MNIST中,输出类别有10个。

假设某个输入的“2”更加形似"3",softmax的输出值中"3"对应的概率为0.1,而其他负标签对应的值都很小,而另一个"2"更加形似"7","7"对应的概率为0.1。这两个"2"对应的hard target的值是相同的,但是它们的soft target却是不同的,由此我们可见soft target蕴含着比hard target多的信息。并且soft target分布的熵相对高时,其soft target蕴含的知识就更丰富。

这就解释了为什么通过蒸馏的方法训练出的Net-S相比使用完全相同的模型结构和训练数据只使用hard target的训练方法得到的模型,拥有更好的泛化能力。 下图为知识蒸馏的通用形式。

知识传递形式

原始知识蒸馏(Vanilla Knowledge Distillation)仅仅是从教师模型输出的软目标中学习出轻量级的学生模型。

然而,当教师模型变得更深时,仅仅学习软目标是不够的。

因此,我们不仅需要获取教师模型输出的知识,还需要学习隐含在教师模型中的其它知识,比如有输出特征知识、中间特征知识、关系特征知识和结构特征知识。

标签知识是神经网络对样本数据最终的预测输出中包含的潜在信息,这也是目前蒸馏过程中最简单、应用最多的方式。

标签知识(输出特征知识)通常指的是教师模型的最后一层特征,主要包括逻辑单元和软目标的知识。标签知识(输出特征知识)知识蒸馏的主要思想是促使学生能够学习到教师模型的最终预测,以达到和教师模型一样的预测性能。

原始知识蒸馏是针对分类任务来提出的仅包含类间相似性的软目标知识,然而其它任务(如目标检测)网络最后一层特征输出中还可能包含有目标定位的信息。

换句话说,不同任务教师模型的最后一层输出特征是不一样的。因此,本文根据任务的不同对输 出特征知识分别进行归纳和分析,如表 1 所示。

Hinton 等人最早提出的知识蒸馏方法就属于目标分类的标签知识(输出特征知识)。由于经过“蒸馏温度”调节后的软标签中具有很多不确定信息,通常的研究认为这其中反映了样本间的相似度或干扰性、样本预测的难度,因此标签知识又被称为“暗知识”。

  • 为了有效地解决基于聚类的算法中的伪标签噪声的问题,Ge等人[45]利用“同步平均教学”的蒸馏框架进行伪标签优化,核心思想是利用更为鲁棒的“软”标签对伪标签进行在线优化。

  • MLP[46]提出了基于元学习(Meta - learning)自适应生成目标分布的方法,用于教师和学生模型的伪标签学习过程.利用一个筛选网络从目标检测模型预测的伪标签中区分出正例和负例,将正例用于下一阶段的半监督自训练过程,可以有效提升标签数据的利用率[43]。

  • Xie等人[4]利用有监督训练学生模型自身,在自蒸馏训练中额外地引入无标签噪声数据产生伪标签,将ImageNet的Top-1识别结果提高了约1%.对于标签知识蒸馏方法本身,已经有非常多的变体和应用,主要是从改进蒸馏过程、挖掘标签信息、去除干扰等方面,提升学生模型的性能.

  • Gao等人[47]实现了一种简单的逐阶段的标签蒸馏训练过程,在梯度下降训练过程中,每次只更新学生网络的一个模块,从前至后直到全部更新完成。

  • 根据Mirzadeh等人[48]的研究发现,并不是教师模型性能越高对于学生模型的学习越有利,当教师-学生模型之间的差距过大时,会导致学生难以从教师模型获得提升.为此,他们提出使用辅助教师策略来逐渐缩小教师和学生之间的学习差距,取得更好的蒸馏效果.

  • 同样是为了缩小教师 - 学生之间的学习差距,Yang等人[49]则提出利用教师模型在每个训练周期更新的中间模型产生的标签知识指导学生模型.为了充分挖掘标签信息、去除干扰,Müller等人[50]采用了子类别蒸馏方法,将原标签分组合并参与软标签蒸馏学习;

  • 文献[51]则研究了蒸馏损失函数对犔2范数和归一化的软标签的作用,提出使用球面空间度量蒸馏的方法去除范数的影响;

  • Zhang等人[52]关注了样本权重的影响,通过预测不确定性自适应分配样本权重,改善蒸馏过程;

  • Wu等人[53]提出了同伴协同蒸馏,通过训练多个分支网络并将其他训练较强教师的 logits 知识转移给同伴,有利于模型的稳定和提高蒸馏的质量。

最早使用教师模型中间特征知识的是 FitNets[27],其主要思想是促使学生的隐含层能预测出与教师隐含层相近的输出。

知识传递方式中有同构蒸馏和异构蒸馏,主要就是区分 是否:教师和学生模型的架构相似或属于同一系列的、层与层(Layer -to - Layer)或块与块(Block - to - Block)之间一一对应;不过通过这几年的实验来看,这并没有什么区别

不同知识传递形式的效果

如图所示,不同的知识传递形式,相比是有差异的,使用经典的KD标签知识是还不错的;使用特征间的,有较多都不如开山鼻祖KD;不过近期又有更多优化,比如使用互信息与对比学习的方法;

温度的特点

在回答这个问题之前,先讨论一下温度T的特点

  1. 原始的softmax函数是 𝑇=1 时的特例, 𝑇<1 时,概率分布比原始更“陡峭”, 𝑇1 时,概率分布比原始更“平缓”。

  2. 温度越高,softmax上各个值的分布就越平均(思考极端情况: (i) 𝑇=∞ , 此时softmax的值是平均分布的;(ii) 𝑇→0,此时softmax的值就相当于 𝑎𝑟𝑔𝑚𝑎𝑥 , 即最大的概率处的值趋近于1,而其他值趋近于0)

  3. 不管温度T怎么取值,Soft target都有忽略相对较小的 𝑝𝑖 携带的信息的倾向

温度代表了什么,如何选取合适的温度?

温度的高低改变的是Net-S训练过程中对负标签的关注程度: 温度较低时,对负标签的关注,尤其是那些显著低于平均值的负标签的关注较少;而温度较高时,负标签相关的值会相对增大,Net-S会相对多地关注到负标签。

实际上,负标签中包含一定的信息,尤其是那些值显著高于平均值的负标签。但由于Net-T的训练过程决定了负标签部分比较noisy,并且负标签的值越低,其信息就越不可靠。因此温度的选取比较empirical,本质上就是在下面两件事之中取舍:

  1. 从有部分信息量的负标签中学习 --> 温度要高一些

  2. 防止受负标签中噪声的影响 -->温度要低一些

总的来说,T的选择和Net-S的大小有关,Net-S参数量比较小的时候,相对比较低的温度就可以了(因为参数量小的模型不能capture all knowledge,所以可以适当忽略掉一些负标签的信息)

CRD 对比学习

首先 CRD是2020年提出的新模式的蒸馏方法,使用对比学习,在这年对比了12个KD方法都是最好的,其中,CRD+KD两个方法合在一起更好,相当于两个维度的知识传递的监督,在2023年有基于CRD实现的CRCD,效果好一点,方案是差不多的;

知识提炼(KD)将知识从一个深度学习模型(教师)转移到另一个深度学习模型(学生)。Hinton等人(2015)最初提出的目标是将教师和学生输出之间的KL差异最小化。当输出是一个分布,例如类上的概率质量函数时,该公式具有直观意义。然而,我们通常希望传递有关representation的知识。例如,在“跨模态蒸馏”问题中,我们可能希望将图像处理网络的表示转移到声音(Aytar等人,2016)或深度(Gupta等人,2016)处理网络,这样图像的深度特征和相关的声音或深度特征高度相关。在这种情况下,KL发散是不确定的。

表征知识是结构化的——维度表现出复杂的相互依赖性。最初的KD目标(Hinton等人,2015年)将所有维度视为独立的,以输入为条件。让yT成为老师的输出,yS成为学生的输出。那么原始的KD目标函数ψ具有全因子形式:. 这种带因素的目标不足以传递结构知识,即输出维度i和j之间的依赖关系。这与图像生成中的情况类似,在图像生成中,由于输出维度之间的独立性假设,L2目标会产生模糊的结果。

为了克服这个问题,我们想要一个目标,捕捉相关性和高阶输出依赖性。为了实现这一点,在本文中,我们利用了对比目标家族(Gutmann&Hyvärinen,2010;Oord等人,2018;Arora等人,2019;Hjelm等人,2018)。近年来,这些目标函数已成功地用于密度估计和表征学习,尤其是在自我监督环境中。在这里,我们让他们适应从一个深层网络到另一个深层网络的知识蒸馏任务。我们表明,致力于研究表现空间很重要,类似于最近的工作,如Zagoruyko和Komodakis(2016a);Remero等人(2014年)。然而,请注意,这些工作中使用的损失函数并没有明确尝试捕捉表征空间中的相关性或高阶相关性。

图1:我们考虑的三种提取设置:(a)压缩模型,(b)将知识从一种模式(例如RGB)转移到另一种模式(例如深度),(c)将网络集合提取到单个网络中。对比目标鼓励教师和学生将相同的输入映射到接近的表示(在某些度量空间中),并将不同的输入映射到遥远的表示,如阴影圈所示。

我们的目标是最大化教师和学生之间的互信息的下限。我们发现,这会在多个知识转移任务中产生更好的表现。我们推测,这是因为对比目标能更好地传递教师表征中的所有信息,而不仅仅是传递关于条件独立输出类概率的知识。有些令人惊讶的是,对比目标甚至改善了最初提出的提取类概率知识的任务的结果,例如,将大型CIFAR100网络压缩为性能几乎相同的较小网络。我们认为这是因为不同类别概率之间的相关性包含有用的信息,可以规范学习问题。我们的论文在两个主要独立发展的文献之间建立了联系:知识蒸馏和表征学习。这种联系使我们能够利用表征学习的强大方法,显著提高知识蒸馏的SOTA。

我们的贡献是:

1.基于对比的目标,用于在深度网络之间传递知识。

2.模型压缩、跨模态传输和整体蒸馏的应用。

3.对标12种最新蒸馏方法;CRD优于所有其他方法,例如,与原始KD相比,平均相对改善57%(Hinton等人,2015),令人惊讶的是,后者的表现次之。

这是近几年的得分,有使用crd结合其他损失的,可以在一些任务中得到较好表现,不同任务表现不一致,

多教师蒸馏

多教师蒸馏(Multi-Teacher Distillation)是一种知识蒸馏的方法,它通过同时蒸馏多个教师网络的知识来提升学生网络的性能。相比于传统的单一教师蒸馏,多教师蒸馏可以利用不同教师网络的多样性和丰富性,从而获得更全面的知识传递。

在多教师蒸馏中,通常包括一个学生网络(Student Network)和多个教师网络(Teacher Networks)。每个教师网络都是一个独立的模型,具有不同的架构或参数初始化。学生网络通过同时学习多个教师网络的知识来提高自己的性能。

多教师蒸馏的核心思想是将不同教师网络的预测结果作为辅助目标来训练学生网络。具体而言,多教师蒸馏包括以下步骤:

1、教师网络的训练:针对不同的教师网络,使用标准的监督学习方法进行训练,以获得具有丰富知识的教师模型。

2、教师网络的预测:使用已训练好的教师网络对输入样本进行预测,得到多个教师网络的预测结果。

3、学生网络的训练:将教师网络的预测结果作为辅助目标,与真实标签一起用于训练学生网络。通过最小化学生网络的预测与教师网络预测之间的差异,将教师网络的知识传递给学生网络。

4、蒸馏损失函数的定义:通常使用交叉熵损失函数来衡量学生网络的分类性能。同时,为了传递教师网络的知识,可以定义额外的辅助目标损失,如平均软标签损失(Mean Soft Labels Loss)或特定的蒸馏损失函数。

通过多教师蒸馏,学生网络能够从多个教师网络中获得更丰富的知识,并综合各个教师网络的预测结果来提高自己的性能。多教师蒸馏可以增强模型的泛化能力,减少过拟合问题,并在复杂任务中取得更好的性能表现。

好,接下来我们从源码分析;

蒸馏算法源码分析

KD

链接:https://arxiv.org/pdf/1503.02531.pd3f

发表:NIPS14

class DistillKL(nn.Module):"""Distilling the Knowledge in a Neural Network"""def __init__(self, T):super(DistillKL, self).__init__()self.T = T #教师模型指导学生模型的程度(蒸馏温度),值越大,指导程度越高def forward(self, y_s, y_t):p_s = F.log_softmax(y_s/self.T, dim=1)p_t = F.softmax(y_t/self.T, dim=1)#下面就是对两个模型的预测值,做KL散度的分布分析,如果偏差越大,则kl散度算出来的值越大。#p_t表示教师模型的目标值#p_s表示学生模型的预测值loss = F.kl_div(p_s, p_t, size_average=False) * (self.T**2) / y_s.shape[0]return loss

核心就是一个kl_div函数,用于计算学生网络和教师网络的分布差异。输入为学生和教师模型的分类输出,经过温度可控的软化之后进行KL散度计算,简单直接粗暴有效;

FitNet

全称:Fitnets: hints for thin deep nets

链接:https://arxiv.org/pdf/1412.6550.pdf

发表:ICLR 15 Poster

很容易理解,方法使用特征间信息,对中间层进行蒸馏的开山之作,通过将学生网络的feature map扩展到与教师网络的feature map相同尺寸以后,使用均方误差MSE Loss来衡量两者差异

(1)大模型训练,小模型随机初始化

(2)将大模型特征提取器的第H层作为hint,从第一层到第H层的参数对应图(a)中Whint,,选择小模型特征提取器的第G层作为guided,从第一层到第G层对应图(a)中Wguided

(3)两者feature map大小可能不匹配,引入卷积层调整器(Wr)对guided层进行调整,对应图(b)

(4)优化均方损失函数

(5)对预训练好的小模型进行进一步知识蒸馏,对应图

 
class HintLoss(nn.Module):"""Fitnets: hints for thin deep nets, ICLR 2015"""def __init__(self):super(HintLoss, self).__init__()self.crit = nn.MSELoss()  # 在这个类中,初始化函数中使用了nn.MSELoss(),即均方误差损失函数,
用于度量学生网络和教师网络之间的均方误差'''
在前向传播函数中,接收学生网络的中间层表示f_s和教师网络的中间层表示f_t作为输入。
然后使用均方误差损失函数计算它们之间的差异,得到"hint"损失。
'''def forward(self, f_s, f_t):loss = self.crit(f_s, f_t)return loss
class ConvReg(nn.Module):"""Convolutional regression for FitNet 用来对齐T-S某层feature map的特征尺寸 可学"""def __init__(self, s_shape, t_shape, use_relu=True):super(ConvReg, self).__init__()self.use_relu = use_relus_N, s_C, s_H, s_W = s_shapet_N, t_C, t_H, t_W = t_shapeif s_H == 2 * t_H:self.conv = nn.Conv2d(s_C, t_C, kernel_size=3, stride=2, padding=1)elif s_H * 2 == t_H:self.conv = nn.ConvTranspose2d(s_C, t_C, kernel_size=4, stride=2, padding=1)elif s_H >= t_H:self.conv = nn.Conv2d(s_C, t_C, kernel_size=(1+s_H-t_H, 1+s_W-t_W))else:raise NotImplemented('student size {}, teacher size {}'.format(s_H, t_H))self.bn = nn.BatchNorm2d(t_C)self.relu = nn.ReLU(inplace=True)def forward(self, x):x = self.conv(x)if self.use_relu:return self.relu(self.bn(x))else:return self.bn(x)

损失计算时,就先使用guided 网络处理完,送进fitloss算一次mse即可;

Fitloss 使用的特征维度做监督,效果没有kd好,可能是由于mse或者特征的提取选择不好,可以考虑多使用几个维度的特征监督;

PKT:Probabilistic Knowledge Transfer

全称:Probabilistic Knowledge Transfer for deep representation learning

链接:https://arxiv.org/abs/1803.10837

发表:CoRR18

提出一种概率知识转移方法,引入了互信息来进行建模。该方法具有可跨模态知识转移、无需考虑任务类型、可将手工特征融入网络等的优点。

 

class PKT(nn.Module):"""Probabilistic Knowledge Transfer for deep representation learningCode from author: https://github.com/passalis/probabilistic_kt"""def __init__(self):super(PKT, self).__init__()def forward(self, f_s, f_t):return self.cosine_similarity_loss(f_s, f_t)@staticmethoddef cosine_similarity_loss(output_net, target_net, eps=0.0000001):# Normalize each vector by its normoutput_net_norm = torch.sqrt(torch.sum(output_net ** 2, dim=1, keepdim=True))output_net = output_net / (output_net_norm + eps)output_net[output_net != output_net] = 0target_net_norm = torch.sqrt(torch.sum(target_net ** 2, dim=1, keepdim=True))target_net = target_net / (target_net_norm + eps)target_net[target_net != target_net] = 0# Calculate the cosine similaritymodel_similarity = torch.mm(output_net, output_net.transpose(0, 1))target_similarity = torch.mm(target_net, target_net.transpose(0, 1))# Scale cosine similarity to 0..1model_similarity = (model_similarity + 1.0) / 2.0target_similarity = (target_similarity + 1.0) / 2.0# Transform them into probabilitiesmodel_similarity = model_similarity / torch.sum(model_similarity, dim=1, keepdim=True)target_similarity = target_similarity / torch.sum(target_similarity, dim=1, keepdim=True)# Calculate the KL-divergenceloss = torch.mean(target_similarity * torch.log((target_similarity + eps) / (model_similarity + eps)))return loss

这和PKT方法效果比KD好一些,主要是使用了概率传递学习先将教师和学生的网络输出进行标准化,再将输出的特征信息使用矩阵乘法、概率化方法映射到另一个空间,最后进行KL散度计算,就是在KD的基础上,将网络输出进行非线性映射成一个更简单的空间,监督这个空间下的S-T KL散度

CRD: Contrastive Representation Distillation

全称:Contrastive Representation Distillation

链接:https://arxiv.org/abs/1910.10699v2

发表:ICLR20

将对比学习引入知识蒸馏中,其目标修正为:学习一个表征,让正样本对的教师网络与学生网络尽可能接近,负样本对教师网络与学生网络尽可能远离。

构建的对比学习问题表示如下:

整体的蒸馏Loss表示如下:

实现如下:https://github.com/HobbitLong/RepDistiller

class ContrastLoss(nn.Module):"""contrastive loss, corresponding to Eq (18)"""def __init__(self, n_data):super(ContrastLoss, self).__init__()self.n_data = n_datadef forward(self, x):bsz = x.shape[0]m = x.size(1) - 1# noise distributionPn = 1 / float(self.n_data)# loss for positive pairP_pos = x.select(1, 0)log_D1 = torch.div(P_pos, P_pos.add(m * Pn + eps)).log_()# loss for K negative pairP_neg = x.narrow(1, 1, m)log_D0 = torch.div(P_neg.clone().fill_(m * Pn), P_neg.add(m * Pn + eps)).log_()loss = - (log_D1.sum(0) + log_D0.view(-1, 1).sum(0)) / bszreturn lossclass CRDLoss(nn.Module):"""CRD Loss functionincludes two symmetric parts:(a) using teacher as anchor, choose positive and negatives over the student side(b) using student as anchor, choose positive and negatives over the teacher sideArgs:opt.s_dim: the dimension of student's featureopt.t_dim: the dimension of teacher's featureopt.feat_dim: the dimension of the projection spaceopt.nce_k: number of negatives paired with each positiveopt.nce_t: the temperatureopt.nce_m: the momentum for updating the memory bufferopt.n_data: the number of samples in the training set, therefor the memory buffer is: opt.n_data x opt.feat_dim"""def __init__(self, opt):super(CRDLoss, self).__init__()self.embed_s = Embed(opt.s_dim, opt.feat_dim)self.embed_t = Embed(opt.t_dim, opt.feat_dim)self.contrast = ContrastMemory(opt.feat_dim, opt.n_data, opt.nce_k, opt.nce_t, opt.nce_m)self.criterion_t = ContrastLoss(opt.n_data)self.criterion_s = ContrastLoss(opt.n_data)def forward(self, f_s, f_t, idx, contrast_idx=None):"""Args:f_s: the feature of student network, size [batch_size, s_dim]f_t: the feature of teacher network, size [batch_size, t_dim]idx: the indices of these positive samples in the dataset, size [batch_size]contrast_idx: the indices of negative samples, size [batch_size, nce_k]Returns:The contrastive loss"""f_s = self.embed_s(f_s)f_t = self.embed_t(f_t)out_s, out_t = self.contrast(f_s, f_t, idx, contrast_idx)s_loss = self.criterion_s(out_s)t_loss = self.criterion_t(out_t)loss = s_loss + t_lossreturn loss
 

他会在训练过程中,使用contrast-memory 来记忆网络的负样本,在网络训练中互信息监督;效果不错;

超分等生成任务与蒸馏

众所周知,图像/视频超分 (SR) 是工业界非常具有应用场景的应用,但能够生产具有良好视觉效果的重建图像的SR模型的参数量和运算量都非常巨大,比如业界公认的优秀baseline模型EDSR,EDVR等的算力需求高达几百,几千GFLOPs。而业界真正需求的轻量化模型,尤其是可以部署于移动端设备的实时模型,其算力限制可能严苛到小于10GFlops。

在high-level CV tasks上得到广泛应用和验证的模型剪枝、c馏方法应用到超分任务上,即将一个训练好的大模型进行裁剪,或者用性能较强的教师大模型蒸馏原本较弱的学生小模型,使裁剪/蒸馏后的小模型能够取得相比普通训练方式更好,甚至接近原先大模型的性能。这里的challenge在于,直接的迁移应用这些算法,在超分任务上无法得到有效的性能提升,甚至可能导致非常严重的performance degradation.

  • SRKD:它将最基本的知识蒸馏直接应用到图像超分中,整体思想分类网络中的蒸馏方式基本一致,整体来看属于应用形式;

  • FAKD:它在常规知识蒸馏的基础上引入了特征关联机制,进一步提升被蒸馏所得学生网络的性能,相比直接应用有了一定程度的提升;

  • PISR:它则是利用了广义蒸馏的思想进行超分网络的蒸馏,通过充分利用训练过程中HR信息的可获取性进一步提升学生网络的性能。

上图给出了SRKD的蒸馏示意图,它采用了最基本的知识蒸馏思想对老师网络与学生网络的不同阶段特征进行蒸馏。考虑到老师网络与学生网络的通道数可能是不相同的,SRKD则是对中间特征的统计信息进行监督。该文考虑了如下四种统计信息:

owards Compact Single Image Super-Resolution via Contrastive Self-distillation

链接:

code:GitHub - Booooooooooo/CSD: Towards Compact Single Image Super-Resolution via Contrastive Self-distillation, IJCAI21

发表:IJCAI21

团队:Yonsei University

1.背景

卷积神经网络在超分任务上取得了很好的成果,但是依然存在着参数繁重、显存占用大、计算量大的问题,为了解决这些问题,作者提出利用对比自蒸馏实现超分模型的压缩和加速。

我们的目标是同时压缩和加速SR模型。我们提出了一个简单的自蒸馏框架,其中学生网络通过在每层使用教师的部分通道从教师(目标)网络中分离出来。我们将这种学生网络称为信道分割超分辨率网络(CSSRNet)。教师网络和学生网络共同训练,形成两个计算方式不同的SR模型。根据设备中计算资源的不同,我们可以动态分配这两种模型,即在资源有限的设备中,如果超过所需的计算开销,则选择CSSR-Net,否则选择教师模型.

主要贡献

作者提出的对比自蒸馏(CSD)框架可以作为一种通用的方法来同时压缩和加速超分网络,在落地应用中的运行时间也十分友好。

自蒸馏被引用进超分领域来实现模型的加速和压缩,同时作者提出利用对比学习进行有效的知识迁移,从而 进一步的提高学生网络的模型性能。

在Urban100数据集上,加速后的EDSR+可以实现4倍的压缩比例和1.77倍的速度提高,带来的性能损失仅为0.13 dB PSNR。

2.方法

我们的CSD包括两个部分:CSSR-Net和对比损失(CL)。首先,我们描述了CSSR-Net。然后,我们给出了构造CSSR-Net的上界和下界的正则表达式。

最后,给出了CSD方案的总体损失函数,并用一种新的优化策略对其进行了求解。

总结

回顾

近年来,知识蒸馏(Knowledge Distillation)方法在深度学习领域中备受关注,它是一种模型压缩技术,旨在将一个复杂的模型(通常被称为教师模型)的知识转移到一个简化的模型(通常被称为学生模型)中,从而使学生模型能够在保持性能的同时具有更小的模型尺寸和计算成本。

一些近年来的知识蒸馏方法和拓展包括:

  1. Teacher-Student Architecture: 最常见的知识蒸馏方法之一是使用教师模型和学生模型之间的监督信号。教师模型通常是一个大型、复杂的模型,而学生模型则是一个较小、简化的模型。通过让学生模型学习教师模型的输出,学生模型可以在学习到教师模型的知识的同时获得更好的泛化性能。

  2. Soft Target Training: 传统的监督学习使用的是硬标签(one-hot编码),即只有正确类别的概率为1,其余为0。而软目标训练则使用教师模型的输出概率分布作为目标。这种方法能够提供更丰富的信息,使得学生模型可以学习到更多的知识。

  3. Attention Mechanisms: 在知识蒸馏中引入注意力机制可以帮助学生模型更好地关注教师模型的重要信息,从而提高模型性能。

  4. Self-Distillation: 自蒸馏是一种方法,其中学生模型在训练过程中不仅要学习来自教师模型的知识,还要学习自身的输出。这种方法可以进一步提高学生模型的性能,同时减少对教师模型的依赖。

  5. Multi-Teacher Distillation: 多教师蒸馏是一种将多个教师模型的知识融合到学生模型中的方法。每个教师模型可能具有不同的视角或专长,通过结合它们的知识,学生模型可以获得更全面和鲁棒的学习。

未来

随着深度学习模型的不断发展和复杂化,未来的知识蒸馏方法可能会涉及更复杂的模型结构。这可能包括对于更深、更宽的神经网络架构的探索,以及对于更复杂的模型组合和蒸馏技术的研究。例如,结合Transformer模型的自注意力机制与知识蒸馏技术可能会带来更加高效的模型压缩和知识传递方式。

其次,未来的知识蒸馏方法可能会更加注重模型的智能化和个性化。这意味着,蒸馏过程将更加关注于学生模型的个性化需求和特征提取,以及对于不同学习任务和场景的适应性。这可能会涉及到更加精细的目标函数设计、更加智能化的蒸馏策略以及更加灵活的模型结构。

目前有的蒸馏方法效果提升不大,知识蒸馏还有很大提升空间,因为网络中有大量的参数,而实际使用到的很少,所以可以在蒸馏方法上优化,将特征提取和知识传递做得更通用,或者更准确,甚至像大模型的预训练与微调一样,或者是自监督蒸馏,或者是自动地结合上剪枝量化,感知量化等等方法。

reference

1、crd https://arxiv.org/abs/1910.1069

2、crd code https://github.com/HobbitLong/RepDistiller

3、cls kd https://blog.csdn.net/akaweige/article/details/131520764

4、sr kd https://zhuanlan.zhihu.com/p/346422123

5、cls kd https://zhuanlan.zhihu.com/p/102038521

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/13373.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Logic Pro X for Mac v11.0.0激活版:专业音频制作软件

对于音乐创作者来说&#xff0c;一个稳定、高效的工作流程至关重要。Logic Pro X for Mac提供了一系列工作流程优化功能&#xff0c;让你能够更快捷、高效地完成音乐创作。从添加音轨、录制音频&#xff0c;到混音和编曲&#xff0c;每一个步骤都如丝般顺滑。同时&#xff0c;L…

Maven 依赖排查

先从项目去看显而易见&#xff0c;假如我们有一个项目&#xff0c;父工程中包含一些子工程&#xff0c;如下&#xff1a; 我们想看一下samples-account中的依赖关系&#xff0c;那么我们可以打开 samples-account的pom文件&#xff0c;查看其maven依赖关系图。 我们可以看到此项…

ARM 交叉编译搭建SSH

一、源码下载 zlib&#xff1a;zlib-1.3.1.tar.xz openssl&#xff1a;openssl-0.9.8d.tar.gz openssh&#xff1a;openssh-4.6p1.tar.gz 二、交叉编译 1、zlib 编译参考这里 2、openssl tar -xf openssl-0.9.8d.tar.gz ./Configure --prefix/opt/ssh/openssl os/compile…

2024年抖店保证金交多少?保证金常见问题解答,一文解决你所有疑惑

大家好&#xff0c;我是电商花花 新手如果想要开抖音小店&#xff0c;有一个大坑是必须要避开的。 就是我们店铺开通之后&#xff0c;我们一定要交保证金&#xff0c;如果不交&#xff0c;那就是0元开店。 很多新手听别人说做抖音小店可以0元开店&#xff0c;不用缴纳保证金就…

开箱机选型“避坑”指南:风险识别与应对策略一网打尽

在现代化生产线上&#xff0c;开箱机作为关键设备之一&#xff0c;其选型过程的成功与否直接关系到生产效率与成本控制。然而&#xff0c;在选型过程中&#xff0c;往往会面临诸多风险&#xff0c;如何有效识别并应对这些风险&#xff0c;成为企业关注的焦点。星派将为您详细解…

JETBRAINS IDES 分享一个2099通用试用码!DataGrip 2024 版 ,支持一键升级

文章目录 废话不多说上教程&#xff1a;&#xff08;动画教程 图文教程&#xff09;一、动画教程激活 与 升级&#xff08;至最新版本&#xff09; 二、图文教程 &#xff08;推荐&#xff09;Stage 1.下载安装 toolbox-app&#xff08;全家桶管理工具&#xff09;Stage 2 : 下…

百度Comate插件领50京东E卡

给你分享一个AI编码助手——百度Comate&#xff01;扫码参与抽红包活动&#xff0c;520宠粉&#xff01;送京东卡&#xff01;https://url.xffjs.com/sMsP7m 流程如下 点击&#xff1a;点我传送 验证码登录账户 点击个人中心 复制License 去idea或者vscode安装插件 询问一…

【Redis】Redis 主从集群(二)

1.哨兵机制原理 1.1.三个定时任务 Sentinel 维护着三个定时任务以监测 Redis 节点及其它 Sentinel 节点的状态 1&#xff09;info 任务&#xff1a;每个 Sentinel 节点每 10 秒就会向 Redis 集群中的每个节点发送 info 命令&#xff0c;以获得最新的 Redis 拓扑结构 2&#xff…

RabbitMQ的基本组件有哪些?

RabbitMQ的基本组件有哪些&#xff1f; RabbitMQ介绍、解耦、提速、削峰、分发 详解、RabbitMQ安装 可视化界面讲解 RabbitMQ 不生产消息&#xff0c;他是消息的搬运工。 1. Producer: 消息的发布者。 2. Connection:producer/comsumer 和 Message Broker 之间的 TCP 连接。 3…

JavaGUI---JavaFX---未完结

一、Java事件处理机制的应用 JavaFX&#xff1a;JavaFX是Java平台上的一个GUI工具包&#xff0c;它提供了一些内置的事件处理机制。 Swing&#xff1a;Swing是Java平台上的另一个GUI工具包&#xff0c;它也提供了一些内置的事件处理机制。 二、JavaFX和Swing的关键区别&…

20232906 2023-2024-2 《网络与系统攻防技术》第十次作业

20232906 2023-2024-2 《网络与系统攻防技术》第十次作业 1.实验内容 一、SEED SQL注入攻击与防御实验 我们已经创建了一个Web应用程序&#xff0c;并将其托管在http://www.seedlabsqlinjection.com/&#xff08;仅在SEED Ubuntu中可访问&#xff09;。该Web应用程序是一个简…

算法day08

第一题 1. 两数之和 由上述题意所知&#xff0c;本题要采用二分法的解题思路&#xff0c;二分法主要是面向有序的数组且也满足二段性的数组&#xff0c;所谓二段性就是在一定的规则下能把该数组分成两个部分&#xff1b; 本题注意要点&#xff1a; 1、循环结束的条件&#xff…

【Leetcode每日一题】 综合练习 - 括号生成(难度⭐⭐)(76)

1. 题目解析 题目链接&#xff1a;22. 括号生成 这个问题的理解其实相当简单&#xff0c;只需看一下示例&#xff0c;基本就能明白其含义了。 2.算法原理 问题描述 我们需要找出所有可能的、有效的括号序列。一个有效的括号序列指的是一个仅由(和)组成的字符串&#xff0c;…

ssm132医院住院综合服务管理系统设计与开发+vue

医院住院综合服务管理系统的设计与实现 摘 要 互联网发展至今&#xff0c;无论是其理论还是技术都已经成熟&#xff0c;而且它广泛参与在社会中的方方面面。它让信息都可以通过网络传播&#xff0c;搭配信息管理工具可以很好地为人们提供服务。针对医院住院信息管理混乱&…

【高阶数据结构(四)】图的最短路径问题

&#x1f493;博主CSDN主页:杭电码农-NEO&#x1f493;   ⏩专栏分类:高阶数据结构专栏⏪   &#x1f69a;代码仓库:NEO的学习日记&#x1f69a;   &#x1f339;关注我&#x1faf5;带你学习更多数据结构   &#x1f51d;&#x1f51d; 高阶数据结构 1. 前言2. 单源最短…

第八篇 Asciidoc 输出 All In One HTML 解决图片无法显示问题

问题:我的图片显示不出来了 小明使用 Asciidoc 来记笔记,他将笔记输出为 HTML 文件。小丽向小明借笔记。小明将 Asciidoc 笔记输出为 HTML文件,并拷贝给了小丽。 但是,小丽发现,图片都显示不出来了。 小丽:小明,你给我的笔记,图片都显示不出来啊。 小明:是我给你的…

析构函数详解

目录 析构函数概念特性对象的销毁顺序 感谢各位大佬对我的支持,如果我的文章对你有用,欢迎点击以下链接 &#x1f412;&#x1f412;&#x1f412; 个人主页 &#x1f978;&#x1f978;&#x1f978; C语言 &#x1f43f;️&#x1f43f;️&#x1f43f;️ C语言例题 &…

yolov8实战之 .pt 转. tensorRT

1 yolo 训练 1.1修改自己的数据集合 我是有3个类别&#xff0c;差不多这么些数据 1.2 训练 from ultralytics import YOLO # Load a model model YOLO("yolov8m.yaml") # build a new model from scratch #model YOLO(E:/pythonCode/pythonProject1/runs/detec…

风电功率预测 | 基于PSO-BP神经网络实现风电功率预测(附matlab完整源码)

风电功率预测 风电功率预测完整代码风电功率预测 基于粒子群优化算法(Particle Swarm Optimization, PSO)的BP神经网络是一种常见的方法,用于实现风电功率预测。下面是一个基于PSO-BP神经网络实现风电功率预测的一般步骤: 数据准备:收集与风电场发电功率相关的数据,包括…