Contents
- Introduction
- Method
- Rethinking Knowledge Distillation for Autoregressive LMs
- Improving Knowledge Distillation with Adaptive Teaching Modes
- Experiments
- References
Introduction
- The authors propose Autoregressive KD with Adaptive Teaching Modes (ATKD), which applies different learning strategies to easy and hard tokens in order to address the problem that "larger teachers might dramatically result in a poorer student, especially when the model capability gap is large". ATKD can also serve as a general strategy for improving the accuracy of various existing KD algorithms.
Method
Rethinking Knowledge Distillation for Autoregressive LMs
- Reformulation of $\mathcal{L}_{\mathrm{KL}}$. The KL divergence can be decomposed into a binary KL loss on the ground-truth class, $\mathrm{KL}(\mathbf{p}_b^t\,\|\,\mathbf{q}_b^t)$, and a KL loss on the non-ground-truth classes, $\mathrm{KL}(\hat{\mathbf{p}}^t\,\|\,\hat{\mathbf{q}}^t)$. The former helps the student learn target-related information and is called target-oriented knowledge distillation (TKD); the latter helps the student learn the knowledge contained in the non-target classes and is called diversity-oriented knowledge distillation (DKD). In addition, the DKD term carries a weight $p_{\setminus g_t}^t$, which reflects the teacher's uncertainty and is called the uncertainty coefficient (UNC). A code sketch of this decomposition is given at the end of this subsection.
$$
\begin{aligned}
\mathcal{L}_{\mathrm{KL}} &= \sum_{t=1}^{T}\Big(p_{g_t}^{t}\log\frac{p_{g_t}^{t}}{q_{g_t}^{t}}+\sum_{j=1,\,j\neq g_t}^{C}p_{j}^{t}\log\frac{p_{j}^{t}}{q_{j}^{t}}\Big) \\
&= \sum_{t=1}^{T}\Big(p_{g_t}^{t}\log\frac{p_{g_t}^{t}}{q_{g_t}^{t}}+p_{\setminus g_t}^{t}\sum_{j=1,\,j\neq g_t}^{C}\hat{p}_{j}^{t}\Big(\log\frac{\hat{p}_{j}^{t}}{\hat{q}_{j}^{t}}+\log\frac{p_{\setminus g_t}^{t}}{q_{\setminus g_t}^{t}}\Big)\Big) \\
&= \sum_{t=1}^{T}\Big(p_{g_t}^{t}\log\frac{p_{g_t}^{t}}{q_{g_t}^{t}}+p_{\setminus g_t}^{t}\log\frac{p_{\setminus g_t}^{t}}{q_{\setminus g_t}^{t}}+p_{\setminus g_t}^{t}\sum_{j=1,\,j\neq g_t}^{C}\hat{p}_{j}^{t}\log\frac{\hat{p}_{j}^{t}}{\hat{q}_{j}^{t}}\Big) \\
&= \sum_{t=1}^{T}\Big(\mathrm{KL}(\mathbf{p}_b^{t}\,\|\,\mathbf{q}_b^{t})+p_{\setminus g_t}^{t}\,\mathrm{KL}(\hat{\mathbf{p}}^{t}\,\|\,\hat{\mathbf{q}}^{t})\Big)
\end{aligned}
$$

where $T$ is the sequence length, $p$ and $q$ are the probability distributions of the teacher and the student respectively, $g_t$ is the ground-truth class at step $t$, and

$$
p_{g_t}^{t}=\frac{\exp(z_{g_t}^{t})}{\sum_{j=1}^{C}\exp(z_{j}^{t})},\quad
p_{\setminus g_t}^{t}=\frac{\sum_{k=1,\,k\neq g_t}^{C}\exp(z_{k}^{t})}{\sum_{j=1}^{C}\exp(z_{j}^{t})},\quad
\hat{p}_{i}^{t}=\frac{\exp(z_{i}^{t})}{\sum_{j=1,\,j\neq g_t}^{C}\exp(z_{j}^{t})},
$$

so that $p_{i}^{t}=p_{\setminus g_t}^{t}\cdot\hat{p}_{i}^{t}$ and $\mathbf{p}_b^{t}=[\,p_{g_t}^{t},\,p_{\setminus g_t}^{t}\,]$.
- Empirical Analyses. (1) UNC measures the learning difficulties of tokens, where the hard-to-learn ones are more important for KD. Based on the magnitude of $p_{\setminus g_t}^t$, tokens can be split into hard tokens (top-50% uncertainty) and easy tokens. Experiments show that the hard tokens matter more for the student's learning, especially when the gap between student and teacher is large, possibly because hard tokens let the student learn rich inter-class information while avoiding overfitting.
(2) DKD contributes more (than TKD) but is greatly suppressed, especially for the larger teachers. The authors decouple TKD and DKD by removing the weight $p_{\setminus g_t}^t$ to examine their individual contributions. They find that DKD clearly outperforms TKD, yet in the KL loss the factor $p_{\setminus g_t}^t$ down-weights DKD, and this suppression is especially pronounced for larger-scale models; the authors regard this as the reason why larger teachers might dramatically result in a poorer student. (3) TKD plays different roles in tokens with different learning difficulties. On easy tokens, TKD may cause the student to overfit and thus hurt generalization; on hard tokens, it lowers the learning difficulty and improves student accuracy.
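To make the decomposition concrete, below is a minimal PyTorch-style sketch (not the authors' code; the `decompose_kd` helper, tensor names, and per-sequence shapes are assumptions for illustration) that computes the per-token TKD term, DKD term, and UNC weight $p_{\setminus g_t}^t$ from teacher and student logits.

```python
import torch


def decompose_kd(teacher_logits, student_logits, targets):
    """Decompose the token-level KL into TKD, DKD and the UNC weight.

    teacher_logits, student_logits: (T, C) logits for one sequence.
    targets: (T,) ground-truth token ids g_t.
    Returns per-token TKD, DKD and UNC tensors of shape (T,).
    """
    p = teacher_logits.softmax(dim=-1)          # teacher distribution p^t
    q = student_logits.softmax(dim=-1)          # student distribution q^t

    idx = targets.unsqueeze(-1)                 # (T, 1)
    p_g = p.gather(-1, idx).squeeze(-1)         # p_{g_t}^t
    q_g = q.gather(-1, idx).squeeze(-1)         # q_{g_t}^t
    p_ng = 1.0 - p_g                            # p_{\g_t}^t, the UNC weight
    q_ng = 1.0 - q_g

    # TKD: binary KL between [p_g, p_ng] and [q_g, q_ng]
    tkd = p_g * (p_g / q_g).log() + p_ng * (p_ng / q_ng).log()

    # DKD: KL between the renormalized non-target distributions p_hat, q_hat
    mask = torch.ones_like(p).scatter(-1, idx, 0.0)   # zero out the target class
    p_hat = (p * mask) / p_ng.unsqueeze(-1)
    q_hat = (q * mask) / q_ng.unsqueeze(-1)
    dkd = (p_hat * ((p_hat + 1e-9) / (q_hat + 1e-9)).log() * mask).sum(-1)

    # The vanilla KL loss would then be: (tkd + p_ng * dkd).sum()
    return tkd, dkd, p_ng
```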
Improving Knowledge Distillation with Adaptive Teaching Modes
- Autoregressive KD with Adaptive Teaching Modes (ATKD). The observations above naturally suggest that tokens should be taught with different strategies according to their learning difficulty: easy tokens use DKD only, while hard tokens (top-50% uncertainty) use DKD + TKD.
$$
\begin{aligned}
\mathcal{L}_{\mathrm{KL}}^{e} &= -\sum_{t\in\mathcal{D}_e}\mathrm{KL}(\hat{\mathbf{p}}^{t}\,\|\,\hat{\mathbf{q}}^{t}), \\
\mathcal{L}_{\mathrm{KL}}^{h} &= -\sum_{t\in\mathcal{D}_h}\Big(\mathrm{KL}(\mathbf{p}_b^{t}\,\|\,\mathbf{q}_b^{t})+\mathrm{KL}(\hat{\mathbf{p}}^{t}\,\|\,\hat{\mathbf{q}}^{t})\Big)
\end{aligned}
$$

The final loss is the weighted sum of the losses on easy and hard tokens,

$$
\mathcal{L}_{\mathrm{KL}}^{all}=\lambda\,\mathcal{L}_{\mathrm{KL}}^{e}+(1-\lambda)\,\mathcal{L}_{\mathrm{KL}}^{h},
$$

where $\lambda=0.2$ and $\mathcal{D}_e$, $\mathcal{D}_h$ denote the sets of easy and hard tokens, respectively.
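Putting the pieces together, here is a hedged sketch (an illustration under assumptions, not the authors' implementation) of how the adaptive teaching modes could be combined. It reuses the hypothetical `decompose_kd` helper from the earlier sketch, splits tokens into easy/hard by the UNC weight, and assumes the usual "minimize the KL terms" sign convention.

```python
import torch


def atkd_loss(teacher_logits, student_logits, targets, k=0.5, lam=0.2):
    """Sketch of the ATKD objective: easy tokens get DKD only, hard tokens
    (top-k uncertainty) get TKD + DKD, combined with the weight lam."""
    tkd, dkd, unc = decompose_kd(teacher_logits, student_logits, targets)

    num_tokens = unc.numel()
    n_hard = max(1, int(k * num_tokens))
    hard_idx = unc.topk(n_hard).indices              # top-k uncertainty -> hard tokens
    hard = torch.zeros(num_tokens, dtype=torch.bool, device=unc.device)
    hard[hard_idx] = True

    loss_easy = dkd[~hard].sum()                     # L_KL^e: DKD only
    loss_hard = (tkd[hard] + dkd[hard]).sum()        # L_KL^h: TKD + DKD (no UNC weight)
    return lam * loss_easy + (1.0 - lam) * loss_hard # L_KL^all
```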
Experiments
- Compared Results. $\mathcal{S}_{\textrm{NLG}}$ refers to natural-language-generation tasks, scored by GPT-4; $\mathcal{S}_{\textrm{NLU}}$ refers to natural-language-understanding tasks, reported as benchmark scores.
- Ablation Study. (1) Impact of ratio $k$: $k$ determines which tokens (the top-$k$ by uncertainty) are treated as hard tokens. (2) Impact of coefficient $\lambda$: it balances the losses on easy and hard tokens.
References
- Zhong, Qihuang, et al. “Revisiting knowledge distillation for autoregressive language models.” arXiv preprint arXiv:2402.11890 (2024).