[ACL 2024] Revisiting Knowledge Distillation for Autoregressive Language Models

Contents

  • Introduction
  • Method
    • Rethinking Knowledge Distillation for Autoregressive LMs
    • Improving Knowledge Distillation with Adaptive Teaching Modes
  • Experiments
  • References

Introduction

  • 作者提出 Autoregressive KD with Adaptive Teaching Modes (ATKD),通过对难易样本采用不同的学习策略来解决 larger teachers might dramatically
    result in a poorer student
    , especially when the model capability gap is large 的问题,可以作为一种通用的学习策略提升不同的已有 KD 算法的精度
    在这里插入图片描述

Method

Rethinking Knowledge Distillation for Autoregressive LMs

  • Reformulation of L K L \mathcal L_{\mathbf {KL}} LKL. KL 散度可以被分解为 ground truth 类别上的 binary KL loss K L ( p b t ∣ ∣ q b t ) \mathrm{KL}(\mathrm{p}_\mathrm{b}^t||\mathrm{q}_\mathrm{b}^t) KL(pbt∣∣qbt) 和非 ground truth 类别上的 KL loss K L ( p ^ t ∣ ∣ q ^ t ) \mathrm{KL}(\hat{\mathrm{p}}^\mathrm{t}||\hat{\mathrm{q}}^\mathrm{t}) KL(p^t∣∣q^t),前者可以帮助 student 学习 target 相关的信息,被称为 target-oriented knowledge distillation (TKD),后者可以帮助 student 学习 non-target 中蕴含的知识,被称为 diversity-oriented knowledge distillation (DKD);此外,这两部分的蒸馏损失被加上了一个权值 p \ g t t p_{\backslash g_t}^t p\gtt,该项反映了 teacher 的 uncertainty,被称为 uncertainty coefficient (UNC)
    L K L = ∑ t = 1 T ( p g t t log ⁡ ( p g t t q g t t ) + ∑ j = 1 , j ≠ g t C p j t log ⁡ ( p j t q j t ) ) = ∑ t = 1 T ( p g t t log ⁡ ( p g t t q g t t ) + p \ g t t ∑ j = 1 , j ≠ g t C p ^ j t ( log ⁡ ( p ^ j t q ^ j t ) + log ⁡ ( p \ g t t q \ g t t ) ) = ∑ t = 1 T ( p g t t log ⁡ ( p g t t q g t t ) + p ∖ g t t log ⁡ ( p ∖ g t t q ∖ g t t ) + p ∖ g t t ∑ j = 1 , j ≠ g t C p ^ j t log ⁡ ( p ^ j t q ^ j t ) = ∑ t = 1 T ( K L ( p b t ∣ ∣ q b t ) + p \ g t t K L ( p ^ t ∣ ∣ q ^ t ) ) \begin{aligned} \mathcal{L}_{\mathrm{KL}}& =\sum_{t=1}^{T}(p_{g_{t}}^{t}\log(\frac{p_{g_{t}}^{t}}{q_{g_{t}}^{t}})+\sum_{j=1,j\neq g_{t}}^{C}p_{j}^{t}\log(\frac{p_{j}^{t}}{q_{j}^{t}})) \\&=\sum_{t=1}^T\left(p_{g_t}^t\log(\frac{p_{g_t}^t}{q_{g_t}^t})\right. \\ &\ \ \ \ \ +p_{\backslash g_{t}}^{t}\sum_{j=1,j\neq g_{t}}^{C}\hat{p}_{j}^{t}\left(\log(\frac{\hat{p}_{j}^{t}}{\hat{q}_{j}^{t}})+\log(\frac{p_{\backslash g_{t}}^{t}}{q_{\backslash g_{t}}^{t}})\right) \\ &=\sum_{t=1}^{T}\left(p_{g_{t}}^{t}\log(\frac{p_{g_{t}}^{t}}{q_{g_{t}}^{t}})+p_{\setminus g_{t}}^{t}\log(\frac{p_{\setminus g_{t}}^{t}}{q_{\setminus g_{t}}^{t}})\right. \\ &\ \ \ \ \ +p_{\setminus g_t}^t\sum_{j=1,j\neq g_t}^C\hat{p}_j^t\log(\frac{\hat{p}_j^t}{\hat{q}_j^t}) \\ &=\sum_{t=1}^T\left(\mathrm{KL}(\mathrm{p}_\mathrm{b}^t||\mathrm{q}_\mathrm{b}^t)+p_{\backslash g_t}^t\mathrm{KL}(\hat{\mathrm{p}}^\mathrm{t}||\hat{\mathrm{q}}^\mathrm{t})\right) \end{aligned} LKL=t=1T(pgttlog(qgttpgtt)+j=1,j=gtCpjtlog(qjtpjt))=t=1T(pgttlog(qgttpgtt)     +p\gttj=1,j=gtCp^jt(log(q^jtp^jt)+log(q\gttp\gtt))=t=1T(pgttlog(qgttpgtt)+pgttlog(qgttpgtt)     +pgttj=1,j=gtCp^jtlog(q^jtp^jt)=t=1T(KL(pbt∣∣qbt)+p\gttKL(p^t∣∣q^t))其中, T T T 为序列长度, p , q p,q p,q 分别为 teacher 和 student 的概率分布, g t gt gt 为 teacher 预测的 ground-truth 类别, p g t t = exp ⁡ ( z g t t ) ∑ j = 1 C exp ⁡ ( z j t ) , p ∖ g t t = ∑ k = 1 , k ≠ g t C exp ⁡ ( z k t ) ∑ j = 1 C exp ⁡ ( z j t ) , p ^ i t = exp ⁡ ( z i t ) ∑ j = 1 , j ≠ g t C exp ⁡ ( z j t ) p_{g_t}^t=\frac{\exp(z_{g_t}^t)}{\sum_{j=1}^C\exp(z_j^t)},p_{\setminus g_t}^t=\frac{\sum_{k=1,k\neq g_t}^C\exp(z_k^t)}{\sum_{j=1}^C\exp(z_j^t)},\hat{p}_i^t=\frac{\exp(z_i^t)}{\sum_{j=1,j\neq g_t}^C\exp(z_j^t)} pgtt=j=1Cexp(zjt)exp(zgtt),pgtt=j=1Cexp(zjt)k=1,k=gtCexp(zkt),p^it=j=1,j=gtCexp(zjt)exp(zit) p i t = p ∖ g t t ⋅ p ^ i t p_i^t=p_{\setminus g_t}^t\cdot \hat{p}_i^t pit=pgttp^it p b t = [ p g t t , p ∖ g t t ] \mathrm{p}_{\mathrm{b}}^t=[p_{g_t}^t,p_{\setminus g_t}^t] pbt=[pgtt,pgtt]
  • Empirical Analyses. (1) UNC measures the learning difficulties of tokens, where the hard-to-learn ones are more important for KD. 根据 p \ g t t p_{\backslash g_t}^t p\gtt 的大小可以把 tokens 分为难样本 (top-50% uncertainty) 和简单样本,实验发现难样本对 student 的学习更重要,尤其是 student 和 teacher 差距比较大的时候,这可能是因为难样本能让 student 学到丰富的类间信息,同时避免过拟合
    在这里插入图片描述(2) DKD contributes more (than TKD) but is greatly suppressed, especially for the larger teachers. 作者对 TKD 和 DKD 做了解耦,去除了权重 p \ g t t p_{\backslash g_t}^t p\gtt 来考察它们各自的作用,作者发现 DKD 显著优于 TKD,但在 KL loss 中,由于 p \ g t t p_{\backslash g_t}^t p\gtt 的存在,DKD 的权值被降低了,并且这一现象在更大规模的模型中尤为显著,这也是作者认为的导致 larger teachers might dramatically result in a poorer student 的原因在这里插入图片描述在这里插入图片描述(3) TKD plays different roles in tokens with different learning difficulties. TKD 在简单样本上可能会导致 student 过拟合,从而影响泛化性;在难样本上能降低难样本的学习难度,从而提升 student 精度
    在这里插入图片描述

Improving Knowledge Distillation with Adaptive Teaching Modes

  • Autoregressive KD with Adaptive Teaching Modes (ATKD). 基于上述观察很容易想到,不同的 tokens 根据其难易程度,应该有不同的学习策略;简单样本仅使用 DKD,难样本 (top-50% uncertainty) 使用 DKD + TKD
    L K L e = − ∑ t ∈ D e K L ( p ^ t ∣ ∣ q ^ t ) , L K L h = − ∑ t ∈ D h K L ( p b t ∣ ∣ q b t ) + K L ( p ^ t ∣ ∣ q ^ t ) \begin{aligned} &\mathcal{L}_\mathrm{KL}^{e} =-\sum_{t\in\mathcal{D}_e}\mathrm{KL}(\mathbf{\hat{p}^t}||\mathbf{\hat{q}^t}), \\ &\mathcal{L}_{\mathrm{KL}}^h =-\sum_{t\in\mathcal{D}_h}\mathrm{KL}(\mathbf{p_b^t}||\mathbf{q_b^t})+\mathrm{KL}(\mathbf{\hat{p}^t}||\mathbf{\hat{q}^t}) \end{aligned} LKLe=tDeKL(p^t∣∣q^t),LKLh=tDhKL(pbt∣∣qbt)+KL(p^t∣∣q^t)最终的损失函数为简单样本和难样本上损失的加权和 L K L a l l = λ ∗ L K L e + ( 1 − λ ) ∗ L K L h \mathcal{L}_{\mathrm{KL}}^{all}=\lambda*\mathcal{L}_{\mathrm{KL}}^e+(1-\lambda)*\mathcal{L}_{\mathrm{KL}}^h LKLall=λLKLe+(1λ)LKLh其中, λ = 0.2 \lambda=0.2 λ=0.2

Experiments

  • Compared Results. S NLG \mathcal S_{\textrm{NLG}} SNLG 为语言生成任务,由 GPT-4 打分; S NLU \mathcal S_{\textrm{NLU}} SNLU 为语言理解任务,为 benchmark 得分
    在这里插入图片描述在这里插入图片描述
  • Ablation Study. (1) Impact of ratio k k k. k k k 用于确定 top- k k k uncertainty 的 tokens 为难样本;(2) Impact of coefficient λ λ λ. 用于确定难易样本损失的权重
    在这里插入图片描述

References

  • Zhong, Qihuang, et al. “Revisiting knowledge distillation for autoregressive language models.” arXiv preprint arXiv:2402.11890 (2024).

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/877627.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

java之类和对象的介绍

1.面向对象和面向过程的概念: 面向对象:面向对象是解决问题的一种思想,主要依靠对象之间的交互完成一件事。 面向过程:注重完成一件事情的过程,后续代码维护扩展较为麻烦。 以洗衣服为例,面向对象为传统…

vue3基础ref,reactive,toRef ,toRefs 使用和理解

文章目录 一. ref基本用法在模板中使用ref 与 reactive 的区别使用场景 二. reactive基本用法在模板中使用reactive 与 ref 的区别使用场景性能优化 三. toRef基本用法示例在组件中的应用主要用途对比 ref 和 toRef 四. toRefs基本用法示例在组件中的应用主要用途对比 ref 和 t…

基于UE5和ROS2的激光雷达+深度RGBD相机小车的仿真指南(一)---UnrealCV获取深度+分割图像

前言 本系列教程旨在使用UE5配置一个具备激光雷达深度摄像机的仿真小车,并使用通过跨平台的方式进行ROS2和UE5仿真的通讯,达到小车自主导航的目的。本教程使用的环境: ubuntu 22.04 ros2 humblewindows11 UE5.4.3python8 本系列教程将涉及以…

二叉树中的奇偶树问题

目录 一题目: 二思路汇总: 1.二叉树层序遍历: 1.1题目介绍: 1.2 解答代码(c版): 1.3 解答代码(c版): 1.4 小结一下: 2.奇偶树分析&#xf…

推荐一个开源的kafka可视化客户端GUI工具(Kafka King)

大佬的博客地址: https://blog.ysboke.cn/posts/tools/kafka-king Github地址: https://github.com/Bronya0/Kafka-King Kafka-King功能清单 查看集群节点列表(完成)支持PLAINTEXT、SASL PLAINTEXT用户名密码认证(完…

Python 如何创建和解析 XML 文件

XML(可扩展标记语言)是一种广泛使用的标记语言,主要用于存储和传输数据。它具有结构化、层次化的特点,常被用作数据交换格式。Python 提供了多种工具和库来处理 XML 文件,包括创建、解析和操作 XML 文档。 一、XML 简…

qt-13 进度条(模态和非模态)

进度条-模态和非模态 progressdlg.hprogressdlg.cppmain.cpp运行图模态非模态 progressdlg.h #ifndef PROGRESSDLG_H #define PROGRESSDLG_H#include <QDialog> #include <QLabel> #include <QLineEdit> #include <QProgressBar> #include <QCombo…

人物形象设计:塑造独特角色的指南

引言 人物形象设计是一种创意过程&#xff0c;它利用强大的设计工具&#xff0c;通过视觉和叙述元素塑造角色的外在特征和内在性格。这种设计不仅赋予角色以生命&#xff0c;还帮助观众或读者在心理层面上与角色建立联系。人物形象设计的重要性在于它能够增强故事的吸引力和说…

p8 Run的流程和Docker原理

docker run的运行原理图 docker是怎么工作的&#xff1f; docker是一个cs的一个结构的系统docker的守护进程运行在宿主机上面通过socket进行访问 其实就是看下面的这个图&#xff0c;通过客户端的命令来操作docker的守护进程然后启动一些容器&#xff0c;默认容器是不启动的 …

网络基础概念【网络】

文章目录 网络协议协议分层 OSI七层模型TCP/IP五层&#xff08;或四层&#xff09;模型同局域网的两台主机通信数据包封装和解包分用&#xff08;数据段&#xff0c;数据报&#xff0c;数据帧&#xff09;网络中的地址管理 网络协议 协议分层 网络协议栈设计成层状结构&#…

【学习笔记】Day 20

一、进度概述 1、机器学习常识12-18&#xff0c;以及相关代码复现 二、详情 12、SVM&#xff08;support vector machines&#xff0c;支持向量机&#xff09; 实际上&#xff0c;支持向量机是一种二分类模型&#xff0c;它将实例的特征向量映射为空间中的一些点&#xff0c;…

如何将CSDN文章导出为pdf文件

第一步&#xff1a; 打开想要导出的页面&#xff0c;空白处点击鼠标右键⇒点击“检查”或“check”&#xff0c;或直接在页面按F12键。 第二步&#xff1a; 复制以下代码粘贴到控制台&#xff0c;并按回车。 若提示让输入“允许粘贴”或“allow pasting”&#xff0c;按提示…

百度地图路书实现历史轨迹回放、轨迹回放进度、聚合点、自定义弹框和实时监控视频、多路视频轮巡播放

前言 分享一个刚做完项目集成技术&#xff0c;一个车辆行驶轨迹监控、行车视频监控、对特种车辆安全监管平台&#xff0c;今年政府单位有很多监管平台项目&#xff0c;例如&#xff1a;渣土车监控、租出车监管、危害气体运输车监管等平台&#xff0c;这些平台都有车辆行驶轨迹…

Linux基础知识学习(五)

1. 用户组管理 每个用户都有一个用户组&#xff0c;系统可以对一个用户组中的所有用户进行集中管理&#xff08;开发、测试、运维、root&#xff09;。不同Linux 系统对用户组的规定有所不同&#xff0c;如Linux下的用户属于与它同名的用户组&#xff0c;这个用户组在创建用户…

QT聊天室基于Tcp

server.cpp #include "widget.h" #include "ui_widget.h"Widget::Widget(QWidget *parent): QWidget(parent), ui(new Ui::Widget),server(new QTcpServer(this)) // 给服务器指针对象实例化空间{ui->setupUi(this); }Widget::~Widget() {delete ui; }…

音频采集spring_ws_webrtc (html采集麦克风转gb711并发送广播播放)完整案例

下载地址&#xff1a;http://www.gxcode.top/code 项目说明 springbootwebscoektwebrtc 项目通过前端webrtc采集麦克风声音&#xff0c;通过websocket发送后台&#xff0c;然后处理成g711-alaw字节数据发生给广播UDP并播放。 后台处理项目使用线程池(5个线程)接受webrtc数据并…

mac如何恢复被同名替换掉的文件夹 mac文件被替换如何恢复

Mac系统一直以高性能遥遥领先其他的Windows系统&#xff0c;因此&#xff0c;Mac虽然价格远远高出其他的笔记本电脑&#xff0c;但是还是受到了一众用户的青睐。使用mac时&#xff0c;我们也经常会将一个文件命名为已经有了相同文件的文件名&#xff0c;且保存到同一个目标地址…

MATLAB-PSO-BiTCN-BiLSTM-Attention多变量分类

一、数据集 数据特征&#xff1a;12个多分类&#xff1a;4分类 ​ 二、PSO-BiTCN-BiLSTM-Attention网络 PSO-BiTCN-BiLSTM-Attention 网络是一种结合了多种深度学习技术和优化算法的复杂模型&#xff0c;用于处理时序数据任务&#xff0c;如时间序列预测、分类或其他相关问题…

【Linux】——进程概念(万字解读)

一 冯诺依曼体系结构 在此之前&#xff0c;我们先要理解我们计算机的冯诺依曼体系结构&#xff0c;因为是进程的基础 我们所有的操作其实都是基于这样一个模型&#xff0c;比如你在qq上&#xff0c;和别人发送消息&#xff0c;这个消息肯定是先通过输入设备进行输入&#xf…

一个注解轻松搞定审计日志服务!

【审计日志】&#xff0c;简单的说就是系统需要记录谁&#xff0c;在什么时间&#xff0c;对什么数据&#xff0c;做了什么样的更改&#xff01;任何一个 IT 系统&#xff0c;如果要过审&#xff0c;这项任务基本上也是必审项&#xff01; 实现【审计日志】这个需求&#xff0…