1 缘起
最近要准备升级材料,里面有一骨碌是介绍LLM相关技术的,知识蒸馏就是其中一个点,
不过,只分享了蒸馏过程,没有讲述来龙去脉,比如没有讲解Softmax为什么引入T、损失函数为什么使用KL散度,想再进一步整理细节部分分享出来,这是其一。
其二是,近3个月没有写文章了,重拾笔头。
今年已经制定写作计划,后面会陆续分享出来。
2 原理
2.1 简介
知识蒸馏是一种模型压缩方法,其中一个小模型被训练来模仿一个预训练的大模型(或模型集合)。这种训练设置有时被称为“教师-学生”模式,其中大模型是教师,小模型是学生。
该方法最早由Bucila等人在2006年提出,并由Hinton等人在2015年进行了推广。Distiller中的实现基于后者的论文。在这里,我们将提供该方法的概述。更多信息,可以参考该论文https://arxiv.org/abs/1503.02531。
2.2 为什么引入T
在蒸馏过程中,直接计算教师模型概率分布,即教师模型上的softmax函数的输出,然而,在许多情况下,这个概率分布中正确类别的概率非常高,而其他类别的概率非常接近于0,导致学生模型学习到的信息并没有提供比数据集中已经提供的真实标签更多的信息。Softmax如下:
p i = e z i ∑ j N e z j p_{i}= \frac{e^{z_{i} } }{\sum_{j}^{N}e^{z_{j} } } pi=∑jNezjezi
为了解决这个问题,Hinton等人在2015年引入了“softmax温度”的概念。类别i的概率pi从logits z计算得出,公式如下:
p i = e z i T ∑ j N e z j T p_{i}= \frac{e^{\frac{z_{i} }{T} } }{\sum_{j}^{N}e^{\frac{z_{j} }{T}} } pi=∑jNeTzjeTzi
其中,T是温度参数,用于控制概率分布的平滑程度。当T较高时,概率分布会更加平滑,从而提供更多的信息,有助于学生模型更好地学习教师模型的知识。
2.3 蒸馏过程
为提升学生模型的性能,学些到更多的信息,引入T,最终的蒸馏过程如下图所示,知识蒸馏有两个Loss,即学生模型与教师模型的 L o s s d i s t i l l a t i o n Loss_{distillation} Lossdistillation,学生模型与真实值的 L o s s s t u d e n t Loss_{student} Lossstudent,其中,教师模型的预测值为软标签,学生模型温度t的的预测值为软预测值,学生模型T=1的预测值为硬预测值。
学生损失函数:
L o s s s t u d e n t = − ∑ i = 1 N y i l o g q i ( 1 ) Loss_{student}=-\sum_{i=1}^{N}y_{i}logq_{i}^{(1)} Lossstudent=−i=1∑Nyilogqi(1)
蒸馏损失函数:
L o s s d i s t i l l a t i o n = − t 2 ∑ i = 1 N p i ( t ) l o g q i ( t ) Loss_{distillation}=-t^{2} \sum_{i=1}^{N}p_{i}^{(t)}logq_{i}^{(t)} Lossdistillation=−t2i=1∑Npi(t)logqi(t)
最终损失函数:
L o s s t o t a l = ( 1 − α ) L o s s s t u d e n t + α L o s s d i s i l l a t i o n Loss_{total}=(1-\alpha )Loss_{student}+\alpha Loss_{disillation} Losstotal=(1−α)Lossstudent+αLossdisillation
2.4 LLM蒸馏损失函数
LLM蒸馏损失函数使用KL散度。
L o s s d i s t i l l a t i o n − L L M = − t 2 K L ( p ( t ) ∣ ∣ q ( t ) ) = − t 2 ∑ i = 1 N p i ( t ) l o g p i ( t ) q i ( t ) Loss_{distillation-LLM}=-t^{2}KL(p^{(t)}||q^{(t)})=-t^{2}\sum_{i=1}^{N}p_{i}^{(t)}log\frac{p_{i}^{(t)} }{q_{i}^{(t)} } Lossdistillation−LLM=−t2KL(p(t)∣∣q(t))=−t2i=1∑Npi(t)logqi(t)pi(t)
2.4.1 为什么使用KL散度
KL散度的概念来源于概率论和信息论中。KL散度又被称为:相对熵、互熵、鉴别信息、Kullback熵、Kullback-Leible散度(即KL散度的简写)。KL 散度比交叉熵更适合作为蒸馏损失,因为当学生模型完美匹配教师模型时,蒸馏损失会为零,而交叉熵却不为零,直接使用交叉熵作为蒸馏损失可能会导致损失随着 batch 波动,所以使用KL散度作为蒸馏损失函数。
3 小结
(1)Softmax引入T用于计算学生和老师预测值概率分布。
(2)使用KL散度计算学生与老师损失;
(3)T=1计算学生模型预测值概率分布。
4 参考
https://zhuanlan.zhihu.com/p/692216196
https://blog.csdn.net/keeppractice/article/details/145419077
https://intellabs.github.io/distiller/knowledge_distillation.html
https://hsinjhao.github.io/2019/05/22/KL-DivergenceIntroduction/