文章目录
- Tailoring Instructions to Student’s Learning Levels Boosts Knowledge Distillation
- 一、PPT内容
- P1 Background
- P2 L2T--online distillation
- P3 L2T--Meta distillation
- P4 Approach--LGTM
- P5 Distillation influence
- P6 Finite difference approximation
- P7 Teacher's auxiliary loss
- P8 results
- P9 Analysis of Distillation Influence
- 二、论文泛读
- 2.1 论文要解决什么问题?
- 2.2 论文采用了什么方法?
- 2.4 论文达到什么效果?
- 三、论文精读
- 3.1 模型精讲
- 3.1.1 Revisiting Learning to Teach
- 3.1.1.1 Vanilla distillation
- 3.1.1.2 Online distillation
- 3.1.1.3 Meta distillation
- 3.1.1.4 LGTM
- distillation influence
- Finite difference approximation
- Teacher’s auxiliary loss
- Relationship with other L2T methods
- 3.2 实验分析和讨论
- 3.2.1 实验设置
- 3.2.2 与Meta蒸馏的对比
- 3.2.3 主要结果
- 3.2.4 蒸馏影响的分析
- 3.2.5 消融实验
- 四、总结
- 4.1 关键点
- 4.2 创新点
- 4.3 启发点
- 五、十问十答
- 总结
- 读者可能感兴趣的问题:
- **Q1**论文试图解决什么问题?
- **Q2**这是否是一个新的问题?
- **Q3**这篇文章要验证一个什么科学假设?
- **Q4**有哪些相关研究?如何归类?谁是这一课题在领域内值得关注的研究员?
- **Q5**论文中提到的解决方案之关键是什么?
- **Q6**论文中的实验是如何设计的?
- **Q7**用于定量评估的数据集是什么?代码有没有开源?
- **Q8**论文中的实验及结果有没有很好地支持需要验证的科学假设?
- **Q9**这篇论文到底有什么贡献?
- **Q10**下一步呢?有什么工作可以继续深入?
- 六、相关知识点
- 6.1 Learning to teach(学习到教)
- 6.2 Online Distillation
- 6.3 Meta Distillation
- 6.4 influence function
- 6.5 有限差分
- 6.6 Jacobian Matrix
Tailoring Instructions to Student’s Learning Levels Boosts Knowledge Distillation
一、PPT内容
P1 Background
KD是解决将大模型部署到下游任务时的高昂计算和存储成本的一种方案。
问题:
一些文献表明,一个效果更好的教师模型,不一定能教出更好的学生模型。
这是因为在知识蒸馏的过程中容易产生优化困难的问题,从而导致教师学到的知识不能有效传递给学生。
一种解决这种问题的方式是L2T,通过学生的反馈来调整教师的输出。
P2 L2T–online distillation
在线蒸馏和Meta蒸馏是两种代表性的L2T方法。
Online distillation and meta distillation are two representative L2T methods.
在线蒸馏同时训练教师和学生模型并且增强它们在训练集上输出的相似性。但是,在线蒸馏专注于在训练集上将教师的知识转移给学生,没有明确考虑验证集上的表现。
P3 L2T–Meta distillation
meta distillation考虑学生在保留验证集上的泛化能力,指导教师学习过程,使泛化能力最大化,但是由于教师模型只接受学生模型的监督,其优化目标可能导致教师模型退化。
P4 Approach–LGTM
因此,我们提出了 LGTM(Learning Good Teacher Matters)模型,导出了 distillation influence 的概念,即通过学生在验证集上的输出,评估每个训练样本对其泛化能力的影响,从而动态地分配权重给不同的训练样本。
具体地说:
学生难以泛化的样本,会被给予更低的权重。而教师通过学生的反馈,并结合自身在训练集上学习到的知识,能够动态地调整自身输出,从而给予学生更合适的监督信号。
P5 Distillation influence
Influence function,用来估计每个训练样本对模型预测结果的影响。
在知识蒸馏的场景下,我们可以通过计算每个训练样本和验证集 batch 的梯度相似度,来量化每个训练样本对模型泛化能力的影响。
为了将 distillation influence 引入教师的训练过程,作者提出了 influence loss:
P6 Finite difference approximation
在计算 L i n f l u e n c e L_{influence} Linfluence中的 distillation influence 时,需要逐一地对训练 batch 里的每个样本计算
这一项的梯度,计算效率受限于训练 batch 的大小 B r B^r Br。因此,可以利用 finite difference 技巧,对 influence loss 进行近似,从而提升计算效率:
P7 Teacher’s auxiliary loss
meta distillation 的一个缺陷是忽略了教师自身对训练样本的学习。因此,作者引入了 auxiliary loss。 L t L_t Lt即为最终训练教师的目标函数。 L i n f l u e n c e h a t L_{influence}hat Linfluencehat和 L a u x L_{aux} Laux的结合,表示教师能兼顾学生的反馈和自身的学习:
P8 results
LGTM 模型,在 6 个文本分类数据集达到了 SOTA 效果,证明了本文方法的有效性:
P9 Analysis of Distillation Influence
作者选取了 MRPC 数据集中的两个典型样本,可视化了 distillation influence 在训练过程中的变化:
左图样本的 ground truth 标签是 0,然而教师和学生在一开始一直分类错误该样本,说明这个样本是难样本,如果过于关注对该样本的学习,可能会削弱学生的泛化能力。因此,该样本被赋予了负权重。
右图样本的 ground truth 标签是 1,教师和学生都能分对该样本,说明该样本是较为简单的样本,有助于帮助学生建立决策边界,因此被赋予了正向权重。
二、论文泛读
快速浏览、把握概要
重点:读标题、摘要、结论、所有小标题和图表
2.1 论文要解决什么问题?
为了增强知识蒸馏过程中对教师训练过程的指导。
2.2 论文采用了什么方法?
引入了蒸馏影响来决定在学生泛化能力上每个训练样本的蒸馏影响,提出了LGTM(Learning Good Teachers Matters),一种将蒸馏影响融入到教师训练过程的有效技术。
2.4 论文达到什么效果?
在GLUE的6种文本分类任务上比10种常用蒸馏方法的性能更好。
三、论文精读
选出精华,仔细阅读
目标及效果自测:所读段落是否详细掌握
3.1 模型精讲
研究表明,一个表现更好的老师并不会教出一个表现更好的学生,有研究者将其归因于蒸馏过程中对于优化的挑战,如随着教师和学生模型容量差距的增长,优化过程可能陷入局部最优。
解决KD过程中性能损失的一种方法是通过学生表现的反馈来更新教师模型(被称作L2T, 学习到教)。
在L2T的算法中,在线蒸馏同时训练教师和学生模型并且增强它们在训练集上输出的相似性。但是,在线蒸馏专注于在训练集上将教师的知识转移给学生,没有明确考虑验证集上的表现。
meta distillation考虑学生在保留验证集上的泛化能力,指导教师学习过程,使泛化能力最大化,但是由于教师模型只接受学生模型的监督,其优化目标可能导致教师模型退化。
“Held-out validation set” 是指在模型训练过程中,将一部分数据保留作为验证集,用于衡量模型在未见过的数据上的性能。它是验证模型泛化能力的一种方式。
教师模型应该优先采样在训练过程中能够提高学生泛化能力的样本,从而使学生在被保留的验证集上表现的更好。
本文受影响函数的启发,提出了蒸馏影响来评估在每个训练样本上的蒸馏如何影响学生在验证集上的表现。并从影响函数的角度出发,解释现存的L2T的方法的深层局限性。因为给小批量中的所有训练样本分配相同的权重,现有的L2T的方法经常受到离群值的影响。本文提出LGTM,根据训练样本的蒸馏影响分配训练样本的损失权重。
符号含义:
教师模型: T ( ⋅ , θ t ) T(·,θ_t) T(⋅,θt),学生模型: S ( ⋅ , θ s ) S(·,θ_s) S(⋅,θs), θ t 和 θ s θ_t和θ_s θt和θs 分别为相应的模型参数, η t 和 η s η_t 和 η_s ηt和ηs 为相应的学习率, ∣ t ∣ 和 ∣ s ∣ |t| 和 |s| ∣t∣和∣s∣ 是 θ t 和 θ s θt 和 θs θt和θs的维度:
参数更新前和更新后的时间步分别记为m和m+1,用于跟踪训练过程中参数的变化。
D t r a i n D_{train} Dtrain 代表训练集,一批 B r B^r Br训练样本及其相应的标签称为 z r = ( x r , y r ) z^r = (x^r, y^r) zr=(xr,yr),其中r代表训练,一批当中的一个样本为 z i r z^r_i zir。
D e v a l D_{eval} Deval 代表验证集,一批样本及其相应的标签称为 z e = ( x e , y e ) z^e = (x^e, y^e) ze=(xe,ye),其中e代表验证。
记 f : R k → R n f : R^k → R^n f:Rk→Rn为一个可微函数, v ∈ R k v∈R^k v∈Rk是一个向量,使用 ∂ f / ∂ v ∈ R k × n ∂f/∂v ∈ R^{k×n} ∂f/∂v∈Rk×n表示函数f的Jacobian矩阵,记 ∂ f / ∂ v ∂f/∂v ∂f/∂v为 ▽ v ▽_v ▽v,记 X T X^T XT为X的转置矩阵。
3.1.1 Revisiting Learning to Teach
本文专注于特定任务的蒸馏,故老师模型以无监督方式进行了预训练,学生模型要么是源自部分教师模型,要么也以无监督方式进行了预训练。
3.1.1.1 Vanilla distillation
典型的知识蒸馏是以两阶段过程进行蒸馏。首先,微调老师模型以最大化在特定下游任务上的表现。当老师模型收敛以后,让学生模型在训练集上密切模仿老师模型的输出。每一个小的batch上优化函数为:
学生模型的参数更新:
局限:由于老师模型的参数在蒸馏阶段固定,所以原始蒸馏方法不允许老师根据学生的反馈调整自己的行为。
3.1.1.2 Online distillation
在线蒸馏能够实现在一个阶段同时对老师和学生模型进行微调。
除最小化Ground Truth标签的交叉熵损失外,还通过最小化教师和学生输出之间的交叉熵损失,使教师模型和学生模型的目标分布接近:
训练过程中需要同时迭代教师和学生模型的参数:
通过迭代更新,学生模型能够学习教师模型的学习曲线,提升了学生模型在给定任务上的表现。
局限:在线蒸馏专注于在训练集上将老师的知识传递给学生,没有明确考虑学生模型在看不见的测试数据上的表现。这可能会导致学生模型只记住了训练的例子,但是泛化能力不好。
3.1.1.3 Meta distillation
meta蒸馏考虑了学生的反馈,指导教师模型的优化,最大化学生的泛化能力。
学生的泛化误差通过Ground Truth标签与学生模型对验证集的预测之间的交叉熵损失来衡量:
meta蒸馏将模型的学习过程分为两个阶段:第一个阶段与vanilla 蒸馏相似,在特定任务数据上微调教师模型;第二个阶段为迭代更新教师和学生模型。值得一提的是,与在线蒸馏相比,meta蒸馏在验证集上获取了学生模型的反馈,而不是在训练数据上。
在第二个阶段,学生模型首先通过标准蒸馏过程的最小化蒸馏损失进行更新:
然后通过最小化学生在保留验证集上的损失来优化教师模型,确保学生有更好泛化能力。在这个过程中,教师知识为了知识迁移而训练。
学生模型的参数更新:
教师模型的参数更新:
局限:meta蒸馏的优化会导致教师模型的退化,因为它只接受学生模型的监督。它阻止了教师模型在第二阶段的持续学习和提升,从而阻碍了其泛化能力。
3.1.1.4 LGTM
为了克服上述限制,作者提出他们的L2T框架–LGTM。首先介绍了蒸馏影响,然后介绍基于有限差分近似的有效训练方法。
distillation influence
蒸馏影响用于估计将训练样本加入知识蒸馏过程学生模型在测试数据上表现的变化,即测量训练样本在模型预测上的影响。(影响函数能够衡量样本对模型参数的影响程度,也就是样本的重要性)
通过计算一个特定例子的影响函数,就有可能估计模型的预测会在多大程度上因对该样本的操作而改变。
在原始蒸馏中,对于学生模型,我们将 z i r z^r_i zir的蒸馏影响作为训练样本 z i r z^r_i zir和验证batch z e z^e ze之间的梯度相似度:
这种影响反映了从特定样本中获得的知识的概括程度。因此,教师应该专注于教导学生捕捉具有最高蒸馏影响的训练样本。
为了将每个样本的影响纳入知识蒸馏中,基于蒸馏影响调整每个样本的损失权重。这样,我们能够确定每个样本的相对重要性,并且有助于控制每个样本对学习过程的贡献。被认为对学生的泛化更有利的样本被赋予更高的权重。用如下的损失函数对教师进行训练:
其中 w i = I d i s t i l l ( z i r , z e ) w_i = I_{distill}(z^r_i, z^e) wi=Idistill(zir,ze) 。通过将影响纳入知识蒸馏的损失函数,我们就可以调整训练过程,以更好的适应目标任务的特点。
Finite difference approximation
对于标准的神经网络,经常对一个小批量的 B r B^r Br 训练样本计算一个合并梯度,以提高计算效率。但是在确定每个样本的蒸馏影响时,计算每个样本的梯度 L c e ( T ( x i r ; θ t m ) , S ( x i r ; θ s m ) ) L_{ce}(T(x^r_i;θ^m_t),S(x^r_i;θ^m_s)) Lce(T(xir;θtm),S(xir;θsm)) 会因为因素 B r B^r Br 减慢训练(需要逐一地对batch里每个样本计算 L c e ( T ( x i r ; θ t m ) , S ( x i r ; θ s m ) ) L_{ce}(T(x^r_i;θ^m_t),S(x^r_i;θ^m_s)) Lce(T(xir;θtm),S(xir;θsm))这一项的梯度,对 θ s θ_s θs需要计算 B r B^r Br 次forward和backward,计算效率受限于训练batch的大小 B r B^r Br )。除此之外,一个简单的实现是内存密集型的,因为需要保存
的副本。
为了解决上述难题,作者引入了有限差分,这是一种在数学分析中常用的方法,用来逼近函数在给定点上的导数。
用以下方式对 L i n f l u e n c e L_{influence} Linfluence进行近似:
其中
,ɛ是一个小的标量。所提出的计算有限差分的方法在计算上是高效的,因为对于单个批次,只需要两次 θ s θ_s θs前向传播和一次 θ t θ_t θt的后向传播。
Teacher’s auxiliary loss
为了平衡教师模型自我进化与可转移性之间的权衡,将与Ground Truth有关的损失 L a u x L_{aux} Laux纳入最终的损失(α是损失率):
总的来说,作者所提出的方法能够使教师适应学生的能力,在提高学生泛化能力的同时提供更加个性化的指导。
Relationship with other L2T methods
对于在线蒸馏,它假设所有的训练样本都具有等效的蒸馏影响,教师模型负责降低所有训练样本的迁移难度。
相比之下,meta蒸馏和在线蒸馏之间的关键不同因素在于动态损失权重。这个权重是衡量当前训练批次 z r z^r zr对学生模型泛化能力的蒸馏影响。具体来说,它反映了训练和验证批次梯度之间的梯度的相似性,表明了当前训练批 z r z^r zr 对验证批 z e z^e ze 的影响。应该注意的是,该权值主要作为自适应学习率,根据梯度中的相似程度调整梯度步长。
3.2 实验分析和讨论
3.2.1 实验设置
数据集:GLUE(文本分类任务):MRPC、RTE、SST-2、MNLI、QNLI、QQP。对于MRPC和QQP,作者展示了F1和准确度(Acc),其他数据集展示了Acc。
Baselines:KD, PKD, SKD, TAKD, RCO, DML, ProKT, PESE-KD, meta Distill
训练设置:将BERT-Base蒸馏为一个6层的BERT模型。对于所有的两阶段基线,在每个任务上进行微调。为了公平对比,Meta蒸馏和LGTM都利用了来自验证集的反馈计算蒸馏损失。
3.2.2 与Meta蒸馏的对比
对于图a和图b中的Meta蒸馏,学生模型的验证损失在后面的迭代中逐渐增加,验证精度不断提高,直至稳定。表明学生模型过拟合。一种可能的解释是,过度强调某些训练样本,产生了较高的损失(比如硬样本或者离群值),这对学生模型的泛化能力有负面影响,从而导致了过拟合。
Meta蒸馏和LGTM的关键区别在于,LGTM计算每个样本的蒸馏影响,但是Meta蒸馏对一批样本中的所有样本一视同仁。这样就可以过滤对学生模型的泛化性能有不利影响的样本,从而稳步降低验证损失(图a),并提高验证的准确度(图b)。
对于教师模型,不仅要给学生传授它们现有的知识,而且要积极寻找新的信息和观点来提升自己的理解。LGTM通过引入教师辅助损失,实现了教师模式知识的有效转移。LGTM中教师模型的验证精度不断提高,而Meta蒸馏中教师模型的验证精度下降较快。
3.2.3 主要结果
LGTM表现优于10条基线,其中包括最近的一些比较强的KD方法,证明了本文方法的有效性。
与其他精心设计的训练方法或者损失函数相比,本文提出的方法实现了SOTA的结果。PKD提出了两种蒸馏方案使得学生模型能够从教师模型的多个中间层进行学习,进行增量式知识提取。SKD和DIST都修改了KL散度loss的形式,缩小了教师和学生模型之间的差距。和TAKD和RCO等相比,LGTM也不需要一系列的教师助手模型。
与其他在线蒸馏方法相比,LGTM更优。证明了在训练过程中纳入学生模型的反馈的重要性。过分强调训练集中的知识迁移可能导致学生模型的过拟合,导致其泛化能力下降。
此外,与meta蒸馏方法(如meta蒸馏(Zhou et al., 2022))不同,我们的方法允许计算单个训练样本的蒸馏影响,从而可以过滤掉可能损害学生泛化能力的样本。因此,LGTM能够帮助学生发展对整体任务的一般理解,同时减轻过拟合问题。
3.2.4 蒸馏影响的分析
本文进一步探索了在实际训练过程中样本蒸馏影响的趋势。
本文对MRPC数据集进行了实验。任务是预测句子对中的句子在语义上是否等价(Wang et al., 2018)。
首先,我们选择两个具有代表性的样本,如图所示,将蒸馏影响的趋势和师生预测之间的关系可视化。
在图的左侧,可以看到,在训练的初始阶段,教师(绿色)和学生(橙色)都做出了错误的预测。这可能表明,这个样本对两个模型的学习都提出了重大挑战。在本例中,我们不希望学生模型过多地模仿教师模型的输出,因为这个示例中的教师模型也是错误的。本文的方法能够逐步将loss weight调整为负值,这说明我们暂时会过滤掉这个误导的训练样本,使两个模型学习更快。因此,学生模型首先摆脱了这种困境。然后,通过学生对验证集的反馈,教师模型也学会了做出正确的预测
。最后,随着训练的进行,观察到学生和教师都能够正确地对该样本进行分类,从而使蒸馏影响稳定在接近零的值。
在图中展示了另一个例子,在这个例子中,学生和老师都能够准确地预测一个给定的样本。这可能意味着这个例子对于老师和学生来说都太简单了。在这种情况下,我们想给这个样本一个高的正权重,以形成一个学生友好的决策边界。这类似于课程学习中从简单样本到困难样本的课程设计(Soviany et al., 2022)。
基于随机从MRPC中选择的64个样本,我们还在图中可视化了蒸馏影响的平均趋势。我们观察到,蒸馏的影响通常在训练的开始和结束时不显著,但在训练的中间有波动。这是合理的,因为本文的方法是在训练过程中为每个样本分配不同的权重,目的是过滤困难的样本,并更好地集中于样本进行泛化。
3.2.5 消融实验
有限差分近似:本文引入有限差分近似(FDA)来估计每个样品的蒸馏影响。它的设计是为了解决计算每个样本梯度的缓慢。如表所示,本文在MRPC数据集上进行了消融实验,以评估其有效性。我们发现,在有FDA的情况下,本文的方法只需要11分钟就可以完成培训,而没有FDA的初级培训需要117分钟。如此显著地减少了训练时间(即超过10倍的加速),突显了FDA技术的计算效率。此外,本文评估了MRPC数据集验证集的性能,并观察到使用FDA的训练结果F1得分为90.4,而不使用FDA的训练结果为90.7。在近似的情况下,性能只有轻微的下降。
蒸馏损失:在知识蒸馏的背景下还有其他的蒸馏损失。在这里,本文想评估LGTM是否能够适应这些目标。特别地,本文考虑了DIST (Huang et al., 2022)中使用的修正损失和公共均方误差(MSE)。从表2中可以看出,本文的LGTM始终优于利用这些蒸馏目标的原始方法,这验证了LGTM与不同蒸馏目标的兼容性。
学生模型大小:本文进行了实验,以评估所提出的方法在教师和学生模型之间存在较大容量差异的场景下的性能。具体来说,我们将BERT- base模型(Devlin et al., 2019)的知识蒸馏到4层BERT模型。从表4可以看出,除了在SST-2上的竞争结果外,LGTM在大多数任务上的表现始终优于其他基线。这表明了本文的方法的稳健性,表明了它在各种知识蒸馏设置中的广泛应用。
四、总结
总览全文、归纳总结
总结文中的创新点、关键点、启发点等重要信息
4.1 关键点
- 提出蒸馏影响来评估每个训练样本对学生泛化能力的影响。
- 基于蒸馏影响为每个样本分配损失权重,以调整老师模型的训练过程。
4.2 创新点
1.提出了蒸馏影响来评估每个训练样本对学生泛化能力的影响。
2.引入有限差分近似来有效将蒸馏影响纳入教师的学习过程。
4.3 启发点
- 在知识蒸馏过程中,老师模型不仅要追求自身性能,更要考虑如何有效地传授知识给学生模型。这与人类教学中老师要关注学生的学习需求是类似的。
- 不仅要考虑训练数据的整体分布,也要注意单个样本对模型训练的影响。有些样本作为异常点或错误标记可能会对模型泛化能力产生不利影响。
- 老师模型可以通过动态调整每个样本的损失权重,来更有针对性地指导学生模型的学习过程。这类似于人类教学中根据学生的情况设计有效的教学策略。
- 学生模型的反馈对优化老师模型也很重要。老师模型在自我进步的同时,也要积极吸收学生模型的反馈,进行自我调整。
- 相比直接优化老师模型性能,优化老师模型的“教学效果”可能更有助于知识的有效传递。这使我联想到人类教育中“教学效果”也很重要。
五、十问十答
总结
本文主要介绍了一种新的知识蒸馏方法——Learning Good Teacher Matters (LGTM),该方法通过引入蒸馏影响来确定每个训练样本对学生模型泛化能力的影响,从而让教师模型在训练过程中更加注重那些对学生模型泛化有益的样本。实验证明,LGTM方法在GLUE数据集中的6个文本分类任务中比其他10种常见的知识蒸馏方法表现更好。
读者可能感兴趣的问题:
- 什么是知识蒸馏,为什么要进行知识蒸馏?
- LGTM方法是如何确定蒸馏影响的,与其他知识蒸馏方法相比有什么优势?
- LGTM方法在GLUE数据集中的6个文本分类任务中表现更好,那么在其他任务中表现如何?是否有进一步的优化空间?
Q1论文试图解决什么问题?
这篇论文试图解决知识蒸馏过程中,老师模型的训练与有效的知识传递之间的不匹配问题。即使老师模型性能更好,也不一定能训练出更强的学生模型,这突出了当前老师模型训练方式与有效知识传递之间的差异。
Q2这是否是一个新的问题?
这不是一个全新的问题。之前的研究已经观察到了老师性能更好不一定带来更好学生模型的问题。这篇论文提供了新的视角来解决这个已有的问题。
但是又可以说是一个新的问题,在传统的知识传递方法中,通常是先进行模型的训练,然后再通过掩模或其他方式将模型的知识传递给下一个模型。而在LGTM中,同时进行教师模型和学生模型的训练,使得教师模型可以根据学生模型的反馈来更好地传递知识。因此,这个问题可以说是是一个新的问题。
Q3这篇文章要验证一个什么科学假设?
这篇文章要验证通过考虑每个训练样本对学生模型泛化能力的影响来指导老师模型的训练过程,可以增强知识蒸馏的效果。
Q4有哪些相关研究?如何归类?谁是这一课题在领域内值得关注的研究员?
相关研究包括:
(1) 在线蒸馏:允许老师模型结合学生反馈进行调整。
(2) 元蒸馏:利用验证集上的学生反馈来指导老师模型,以最大化学生的泛化能力。
(3) 影响函数:测量训练样本对模型预测的影响。
online distillation 和 meta distillation 是 learning to teach 两种有代表性的方法。然而,这两种方法都有不足之处。
online distillation聚焦于学生在训练集上的反馈,而忽略了学生在验证集上的反馈,可能会削弱学生的泛化能力;
meta distillation虽然引入了学生在验证集上的反馈,但却忽略了教师自身在训练集上的学习,仅依靠学生的反馈调整教师的输出,容易导致教师的性能变差。
在线蒸馏和元蒸馏的工作属于学习到教(learning to teach) Paradigm。这篇论文的方法建立在影响函数的思想上。
Q5论文中提到的解决方案之关键是什么?
这篇论文的关键解决方案是:
- 提出蒸馏影响来评估每个训练样本对学生泛化能力的影响。
- 基于蒸馏影响为每个样本分配损失权重,以调整老师模型的训练过程。
Q6论文中的实验是如何设计的?
这篇论文的实验是这样设计的:
- 在GLUE基准测试集上的6个文本分类任务上,与10种常见的知识蒸馏基线进行比较。
- 分析样本的蒸馏影响在训练过程中的变化趋势。
- 对方法的关键组成部分进行消融实验。
Q7用于定量评估的数据集是什么?代码有没有开源?
用于评估的是GLUE基准测试集上的文本分类任务,包括MRPC, RTE, SST-2, MNLI, QNLI和QQP。
代码可以在https://github.com/yannvictor/LGTM 上找到。
Q8论文中的实验及结果有没有很好地支持需要验证的科学假设?
实验结果显示该方法在所有任务上均优于各种基线,说明分配基于蒸馏影响的样本损失权重确实能增强知识蒸馏的效果。实验结果与需要验证的假设是一致的。
Q9这篇论文到底有什么贡献?
这篇论文的主要贡献是:
- 提出蒸馏影响的概念来衡量每个样本的重要性。
- 基于蒸馏影响指导老师模型的训练,实现更有效的知识传递。
Q10下一步呢?有什么工作可以继续深入?
后续的工作可以考虑:
- 在预训练知识蒸馏中应用该方法。
- 在更复杂的文本生成任务中验证该方法。
- 探索其他方式利用蒸馏影响指导老师模型。
- 可以进行更多的实验研究,例如更多的学习策略和更多的基线方法的比较,以更好地理解LGTM的优点和局限。
六、相关知识点
6.1 Learning to teach(学习到教)
出处:Learning to Teach. Yang Fan et. al. ICLR 2018
Learning to Teach 从人类社会的教育系统启发而来,在L2T的框架中存在两个agent: Teacher model & Student model。其中student model即传统ML方法所关注的模型,teacher model则是一个单独的agent,负责向目标模型提供合适的输入来指导其训练过程。
“合适的输入”的评价标准是能使student model在整个task上达到更好的performance / 更高的efficiency。(achieves lower risk R(ω) or progresses as fast as possible)
形式化定义:
f w ( x ) f_w(x) fw(x) 是参数为 w w w 的决策函数, M ( , ) M(,) M(,) 是对该任务选择的评价标准, P ( x , y ) = P ( x ) P ( y ∣ x ) P(x, y)=P(x)P(y|x) P(x,y)=P(x)P(y∣x) 是联合概率分布。
μ为训练所用的学习算法,L为Loss函数,D为训练集,Ω为假设空间。即根据输入(D,L,Ω)使用学习算法μ来优化参数 w w w 使得决策函数在整个训练集D上的Loss最小。
由此可以定义Teacher Model的训练目标φ:
Teacher Model不断优化其向student Model提供的训练集/损失函数/假设空间,使得student model在训练集上表现最优。
6.2 Online Distillation
出处:LARGE SCALE DISTRIBUTED NEURAL NETWORK TRAINING THROUGH ONLINE DISTILLATION.Geoffrey E. Hinton.2020
图中,教师模型和学生模型都是to be trained的状态,即教师模型没有被预训练。
在大容量教师网络没有现成模型的时候,可以考虑使用online distillation。使用在线蒸馏的时候,教师网络和学生网络的参数会同时更新,整个知识蒸馏框架是端到端训练的。
为了克服离线蒸馏的局限性,提出了在线蒸馏以进一步改善学生模型的性能,特别是在没有大容量高性能教师模型的情况下。在在线蒸馏中,教师模型和学生模型同时更新,并且整个知识蒸馏框架是端到端可训练的。
在最近三年中,已经提出了多种在线知识蒸馏方法。具体来说,在深度相互学习中(Zhang等人,2018b),多个神经网络以协作方式工作。在训练过程中,任何一个网络都可以作为学生模型,其他模型可以作为老师。为了提高泛化能力,通过使用 soft Logits 的集合来扩展深度相互学习(Guo等,2020)。 Chen等。 (2020a)进一步将辅助同伴(auxiliary peers)和小组负责人(group leader)引入深度相互学习中,以形成一套多样化的同伴模型。为了降低计算成本,Zhu和Gong(2018)提出了一种多分支架构,其中每个分支表示一个学生模型,不同分支共享相同的骨干网络。 Kim等人(2019b)没有使用Logits,引入了特征融合模块来构建教师分类器。谢等。 (2019)用便宜的卷积运算代替了卷积层以形成学生模型。 Anil等。 (2018)使用在线蒸馏来训练大规模分布式神经网络,并提出了在线蒸馏的一种变体,称为共蒸馏。并行共蒸馏以相同的架构训练多个模型,并且通过从其他模型转移知识来训练任何一个模型。最近,提出了一种在线对抗知识蒸馏方法,以利用来自类别概率和特征图的知识,同时由鉴别者训练多个网络(Chung等,2020)。
在线蒸馏是一种具有高效并行计算功能的单阶段端到端训练方案。然而,现有的在线方法(例如,相互学习)通常不能解决在线设置中的高能力教师,这使得在在线设置中进一步探索教师与学生模型之间的关系成为一个有趣的话题。
离线蒸馏(Offline Distillation)
大多数以前的知识蒸馏方法都可以脱机工作。在常见的知识蒸馏中,知识从预先训练的教师模型转移到学生模型。因此,整个训练过程有两个阶段,即:1.大型教师模型是在蒸馏之前首先在一组训练样本上训练的。
2.教师模型用于提取logit或中间特征形式的知识,然后用于指导蒸馏过程中学生模型的训练。
离线蒸馏的第一阶段通常不作为知识蒸馏的一部分进行讨论,即,假定教师模型是预先定义的。很少关注教师模型结构及其与学生模型的关系。因此,离线方法主要集中于改进知识转移的不同部分,包括知识的设计以及用于匹配特征或分布匹配的损失函数。离线方法的主要优点在于它们简单易行。例如,教师模型可以包含使用可能位于不同机器上的不同软件包训练的一组模型。可以提取知识并将其存储在缓存中。
离线蒸馏方法通常采用单向知识转移和两阶段训练程序。然而,不可避免的是,复杂的高容量教师模型具有很长的训练时间,而离线蒸馏中对学生模型的训练通常在教师模型的指导下是有效的。此外,大型教师和小型学生之间的能力差距始终存在,而学生在很大程度上依赖于教师。
6.3 Meta Distillation
出处:Meta-KD: A Meta Knowledge Distillation Framework for Language Model Compression across Domains. Haojie Pan, Chengyu Wang, Minghui Qiu, Yichang Zhang, Yaliang Li, Jun Huang, CoRR abs/2012.01266
PAI团队提出了MetaKD,即Meta Knowledge Distillation,用元知识蒸馏,可以将跨领域的可迁移知识学出,在蒸馏阶段额外对可迁移的知识进行蒸馏。这样的做法使得学习到的Student模型在相应的领域的效果显著提升,我们在多个跨领域的任务上都蒸馏出了较好的学生模型,逼近教师模型的效果。
如下图所示,MetaKD包括meta teacher learning和meta distillation两个阶段,前者训练出一个跨领域的meta teacher,后者在目标任务里用meta teacher指导学生网络学习。
作者提出了 pilot update 方法来解决这一问题:首先仍然按照 meta learning 的方式进行每一轮的内学习器和元学习器的更新,只不过内学习器的更新要在元学习器更新后撤销,再根据更新的元学习器进行本轮的更新,这样可以使两个学习器的更新进行同步和匹配。在实现撤销这一步时,为了方便起见,作者复制了一个内学习器的副本,在 meta learning 结束后,直接将该副本删除即可。
对于学生模型,其损失函数包括在下游任务上的损失 以及蒸馏损失 :
而对于教师模型,作者希望其能根据学生模型的表现反馈进行调整,而学生模型在下游任务上的表现也包含教师模型参数,所以作者采用下游任务损失作为教师模型的损失函数。特别地,为了防止过拟合,作者从训练集中分离出了一部分样本组成 quiz set,将学生模型在该数据集上的表现计算损失函数。
6.4 influence function
6.5 有限差分
有限差分法(Finite Difference Method,FDM)是一种求解微分方程数值解的近似方法,其主要原理是对微分方程中的微分项进行直接差分近似,从而将微分方程转化为代数方程组求解。
有限差分法是一种求解微分方程的数值方法,其面对的对象是微分方程,包括常微分方程和偏微分方程。此外,有限差分法需要对微分进行近似,这里的近似采取的是离散近似,使用某一点周围点的函数值近似表示该点的微分。
6.6 Jacobian Matrix
文献来源:Tailoring Instructions to Student’s Learning Levels Boosts Knowledge Distillation
参考资料:ACL 2023 | 为学生模型的学习水平量身定制指导,促进知识蒸馏的效果