知识蒸馏Matching logits与RocketQAv2

知识蒸馏Matching logits

公式推导

刚开始的\frac{\partial L}{\partial z_i}=q_i-p_i怎么来,可以转看下面证明梯度等于输出值-标签y

C是一个交叉熵,我们要求解的是这个交叉熵对z_i的这个梯度。z_i就是你可以理解成第i个类别的得分。z_i就是student model,被蒸馏的模型,它所输出的logits。

p_i是什么?是target probability对吧。q_i是什么?q_i认为就是这个distilled model的输出的那个probability。所以就是说这两个概率相减,再乘以这个T分之一T是什么?T是一个温度。

我们现在假定是说我们是用teacher model输出的这个label,然后去训练student model,或者说去训练distilled model。我们对这个第i个类别的梯度,就等于\frac{1}{T}{(q_i-p_i)},然后呢,q_ip_i可以做一个化简。

q_ip_i进行展开,概率都是用softmax算出来的,就可以得到这个式子。

通过e^x\approx 1+x来进行化简,这个式子在x比较小的时候是成立的。

在这里,当T足够大的时(相比z的logits,即z),\frac{z_i}{T}就足够的小,接近于0,此时e^{\frac{z_i}{T}}\approx 1+\frac{z_i}{T}

\sum_j e^{\frac{z_i}{T}}\approx \sum_j{1+\frac{z_i}{T}}=N+\sum_j{\frac{z_i}{T}} 

z_j的这个累加,它就等于零。这个v_j的这个累加也等于零,即\sum_j z_j=\sum_j v_j=0,所以这两个分母直接就变成了N。

\frac{1}{T}({\frac{1+z_i/T}{N}}-{\frac{1+v_i/T}{N}})=\frac{1}{TN}{\frac{z_i-v_i}{T}}

则所求梯度

想说明的事情

它其实就想说明这样一个事情。我们试图用一个teacher model,或者说我们想用VI对应的那个概率叫p_iz_i对应的概率叫q_i。如果我们想用这个p_i作为label去用交叉商去训练q_i去用这个soft label的交叉商去训练q_i,那么其实我们可能不需要套用交叉商这个东西了,我们也不需要什么softmax的label的交叉商,然后去做这个事了。因为这个东西在我们的这样一通推导下就会发现,其实就等于均方误差,右边这一项其实就是什么均方误差的求导,它就是均方误差求导之后的结果,你可以这样认为。

我们就会发现说,原来对于交叉商对于这个知识蒸馏的这个交叉商,然后我们对他求导求出来的梯度其实是近似等同于我们直接用MSE去训练,然后得到的梯度的。那么既然这样,我们为什么不直接用MSE?

它的推导就告诉我们说我们对于两个模型,两个多分类模型来说,我们要用a模型去交B模型做蒸馏。我们没有必要让这两个模型生成分别生成什么label,然后再生成预测的概率,然后再加上去优化了。

我们直接让这两个多分类模型的这个logic,然后直接做MSE就可以了,就可以做到一种就是一种这种MSE就是一种什么蒸馏的特殊形式。就是蒸馏的一个最早期的雏形,其实在这个时候都还没有考虑用这个什么KL散度来做,就只是提出最简单的一个思想是什么,就是用MSE来做就够了。

我们一直即便到今天,我们做很多知识正溜的实验,我们依然会发现MIC可能有的时候都会比K要好。虽然大家都说自己用什么KL散度用什么JS散度,但是就是否现在就最优,还真不一定有的时候就是MSE效果好。

注:MSE = \frac{1}{N} \sum_{i=1}^{N} (y_i - \hat{y}_i)^2

需要注意的事

公式的推导基于两个假设:

1.T得足够大的(相比z的logits,即z_i)

2.模型输出的logic是零均值的(即均值为0),因为模型输出的logic是零均值的,这个z_j的这个累加,它就等于零。这个v_j的这个累加也等于零,即\sum_j z_j=\sum_j v_j=0

 证明梯度等于输出值-标签y

softmax函数

归一化,使其输出的概率和为1

S_i=\frac{e^{z_i}}{\sum_ke^{z_k}}

S_i代表的是第i个神经元的输出。

神经元的输出,一个神经元如下图:

z_i=\sum_jw_{ij}x_{ij}+b

其中w_{ij}是第i个神经元的第j个权重,b是偏移值。z_i表示该网络的第i个输出。

给这个输出加上一个softmax函数,得a_i=\frac{e^{z_i}}{\sum_ke^{z_k}}

a_i代表softmax的第i个输出值

交叉熵损失函数 loss function

L=-\sum_i{y_i}{lna_i}

其中y_i表示真实的分类结果。

证明梯度等于输出值-标签y

loss对于神经元输出z_i的梯度为\frac{\partial L}{\partial z_i}=\frac{\partial L}{\partial a_j}\frac{\partial a_j}{\partial z_i}

由于softmax公式的特性,它的分母包含了所有神经元的输出,对于不等于i的其他输出里面,也包含着z_i,所有的a都要纳入到计算范围中,并且后面的计算可以看到需要分为i=ji \ne j两种情况求导。

由于\frac{\partial (-\sum_{k\ne j}y_{k}ln a_k)}{\partial a_j}=0

\frac{\partial C}{\partial a_j}=\frac{\partial (-\sum_jy_jln a_j)}{\partial a_j}=-\sum_jy_j\frac{1}{a_J}

如果i=j

\frac{\partial a_i}{\partial z_i}=\frac{\partial (\frac{e^{z_i}}{\sum_ke^{z_k}})}{\partial z_i}=\frac{\partial (\frac{e^{z_i}}{\sum_{k\ne i}e^{z_k}+e^{z_i}})}{\partial z_i}=\frac{\sum_ke^{z_k}e^{z_i}-(e^{z_i})^2}{\sum_k(e^{z_k})^2}
=(\frac{e^{z_i}}{\sum_ke^{z_k}})(1-\frac{e^{z_i}}{\sum_ke^{z_k}})=a_i(1-a_i)

这里\sum_ke^{z_k}=\sum_{k\ne i}e^{z_k}+e^{z_i}

如果i \ne j

这里\sum_ke^{z_k}=\sum_{k\ne j}e^{z_k}+e^{z_j}

\frac{\partial a_i}{\partial z_i}=\frac{\partial (\frac{e^{z_i}}{\sum_ke^{z_k}})}{\partial z_i}=-e^{z_j}(\frac{1}{\sum_ke^{z_k}})e^{z_i}=-a_ia_j

综上

\frac{\partial L}{\partial z_i}=\frac{\partial L}{\partial a_j}\frac{\partial a_j}{\partial z_i}=(-\sum_jy_j\frac{1}{a_j})\frac{\partial a_j}{\partial z_i}=-\frac{y_i}{a_i}a_i(1-a_i)+\sum_{j\neq i}\frac{y_i}{a_j}a_ia_j
=-y_i+y_ia_i+\sum_{j\neq i}{y_ia_i}=-y_i+a_i\sum_{j}y_j

最后,针对分类问题,我们给定的结果y_i最终只会有一个类别是1,其他非标签类别都是0,因此,对于分类问题,这个梯度等于

\frac{\partial L}{\partial z_i}=a_i-y_i

知识蒸馏RocketQAv2

https://arxiv.org/pdf/2110.07367.pdf

这个模型有两部分组成一个retriever和一个ranker。这个做的事情就是说用label去监督re-ranker,然后用ranker去监督retriever。用KL散度去约束它约束,用这个K散路去让这个re-ranker的分布和retriever的分布对齐。

要注意就是说。这里就是他们就没有用MSE,就是说如果用MSE怎么做,就是说对应的这个直接相减,就对应位置直接相减,然后分MSE就行。这里用的是KL散度。

KL散度的定义,你可以认为是这样的,让这两个概率分别相除,除完了之后都要再取对数,然后再乘以这个概率。

DE,这个teacher model的概率乘以teacher model的概率乘以log,teacher model的概率除以student model的概率。然后把这么多概率给它都累加起来。

在这里,假定这里的是retriever给出来的一个概率分布假如说是十个候选,ranker也给了这样一个概率分布,那么就是十个概率分布对应的一项一项的去算这个KL度,即概率除概率,然后再取对数,然后再乘上ranker这个概率。

然后再把这十项给它累加起来,然后就是一个KL散度,这样的话,这个K散度其实是现在就是接受最多的一种损失函数。

因为KL散度就是天生的,可以捕获这个分布和分布之间的距离。像MSE缺点是什么?MSE的缺点是它没有整体的那种距离衡量的能力。MSE其实是对于细节的这种距离的衡量很强。如果MSE来的话,每一个每一项,这十项每一项的重要性对于MIC来说都是一样的。但是这个KL散度可能就会更在乎一个整体的一个分布上的一个区别了,就而不是说就在乎一些细节上的一些差别,因为有可能就是说。你某一些细节差距虽然大一些,但是你整体差距不大,所以KL散度也可以比较小。

实际上一切可以衡量两个分布之间距离的指标都可以用来做知识蒸馏,所以其实wasserstein距离也可以用来作为蒸馏的损失函数:

https://openaccess.thecvf.com/content/CVPR2021/papers/Chen_Wasserstein_Contrastive_Representation_Distillation_CVPR_2021_paper.pdf

为什么知识蒸馏会有效?

1. teacher model可以生成soft label,相比于原始数据的hard label,包含了更多信息量。

所以很多时候你与其说直接用一个数据集去训练一个模型,你还不如用这个数据集先训练一个大a模型比a模型要大的模型。再让大a模型去教会a模型去做,有可能效果就更好。就是因为大a模型这个teacher model可以生成soft label相比于原始数据的hard label,可以包含更多的信息量,从而就天然的有一种去燥的一种功能。

2. teacher model可以为大量的无标签数据打上label,然后为student提供一个大规模的训练集。然后从而可以给student提供一个更大尺度的训练集,然后防止student的一个过拟合,然后提高student model的一个泛化能力。也就是说,teacher model可以把自己的泛化能力交给student model

在这个知识蒸馏的过程当中,这也是为什么说很多大公司里边现在线上的模型都是蒸馏出来的小模型就是因为我们与其说直接训练小模型。还不如说就用这个蒸馏去蒸馏一个小模型反而泛化能力会更强一些

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/748705.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

RTT——stm32f103的can总线通信

1.创建工程 2.配置时钟和引脚 引脚配置使能CAN 时钟配置,采用外部高速时钟 生成MDK工程后复制相关初始化函数到RTT-studio中 将void HAL_CAN_MspInit(CAN_HandleTypeDef* canHandle)函数复制至broad.c文件中 将时钟配置函数复制到drv_clk.c中,只复制函数…

C语言—打印如图矩阵

输出矩阵 在一个二维数组中形成并输出如下矩阵: #include <stdio.h> main() { int i,j,a[5][5];for(i0;i<4;i)for(j0;j<4;j)if(i<j) a[i][j]1;else a[i][j]i-j1;for(i0;i<4;i){ for(j0;j<4;j)printf("%d ",a[i][j]);printf("…

Xilinx FPGA模式配置

Xilinx FPGA模式配置 Xilinx UltraScale FPGA有7种配置模式&#xff0c;由模式输入引脚M[2:0]决定。七种模式如图1所示。 图1 7种配置模式 7种模式可分为3大类&#xff0c; 1、JTAG模式&#xff08;可归为从模式&#xff09;&#xff1b; 2、主模式&#xff1b; 3、从模式…

影响交易收益的因素有哪些?

在尝试做交易时&#xff0c;你可能会问自己一个问题&#xff1a;交易一天能赚多少钱&#xff1f;“如果我全职投入交易&#xff0c;一天能赚多少&#xff1f;”或者更广泛地说&#xff0c;“交易能为我带来怎样的财富&#xff1f;”这些问题本质上都充满了不确定性&#xff0c;…

Spring Cloud Alibaba微服务从入门到进阶(一)(SpringBoot三板斧、SpringBoot Actuator)

Springboot三板斧 1、加依赖 2、写注解 3、写配置 Spring Boot Actuator Spring Boot Actuator 是 Spring Boot 提供的一系列用于监控和管理应用程序的工具和服务。 SpringBoot导航端点 其中localhost:8080/actuator/health是健康检查端点&#xff0c;加上以下配置&#xf…

pytorch之诗词生成--2

先上代码: # -*- coding: utf-8 -*- # File : dataset.py # Author : AaronJny # Time : 2019/12/30 # Desc : 构建数据集 from collections import Counter import math import numpy as np import tensorflow as tf import settingsclass Tokenizer:""&…

邮件自动化:简化Workplace中的操作

电子邮件在职场中的使用对于企业和组织的日常活动起着重要的作用。电子邮件不再仅仅是一种通信方式&#xff0c;已经成为现代企业和组织实施日常运营的关键要素。 除了通信&#xff0c;电子邮件对于需求生成、流程工作流、交易审批以及各种其他与业务相关的活动至关重要。在当…

springboot高校门诊管理系统

摘 要 相比于以前的传统手工管理方式&#xff0c;智能化的管理方式可以大幅降低高校门诊的运营人员成本&#xff0c;实现了高校门诊管理的标准化、制度化、程序化的管理&#xff0c;有效地防止了高校门诊管理的随意管理&#xff0c;提高了信息的处理速度和精确度&#xff0c;能…

MySQL中的索引失效情况介绍

MySQL中的索引是提高查询性能的重要工具。然而&#xff0c;在某些情况下&#xff0c;索引可能无法发挥作用&#xff0c;甚至导致查询性能下降。在本教程中&#xff0c;我们将探讨MySQL中常见的索引失效情况&#xff0c;以及它们的特点和简单的例子。 1. **索引失效的情况** …

Linux:深入文件系统

一、Inode 我们使用ls -l的时候看到的除了看到文件名&#xff0c;还看到了文件元数据。 [rootlocalhost linux]# ls -l 总用量 12 -rwxr-xr-x. 1 root root 7438 "9月 13 14:56" a.out -rw-r--r--. 1 root root 654 "9月 13 14:56" test.c 每行包含7列&…

【JavaEE初阶系列】——多线程 之 创建线程

目录 &#x1f388;认识Thread类 &#x1f388;Sleep &#x1f388;创建线程 &#x1f6a9;继承Thread&#xff0c;重写run方法 &#x1f6a9;实现Runnable接口&#xff0c;重写run方法 &#x1f6a9;使用匿名内部类创建 Thread 子类对象 &#x1f6a9;使用匿名内部类&…

非空约束

oracle从入门到总裁:​​​​​​https://blog.csdn.net/weixin_67859959/article/details/135209645 非空约束 所谓的非空约束&#xff0c;指的是表中的某一个字段的内容不允许为空。如果要使用非空约束&#xff0c;只需要在每个列的后面利用“NOT NULL”声明即可 -- 删除数…

【Preprocessing数据预处理】之Scaler

在机器学习中&#xff0c;特征缩放是训练模型前数据预处理阶段的一个关键步骤。不同的缩放器被用来规范化或标准化特征。这里简要概述了您提到的几种缩放器&#xff1a; StandardScaler StandardScaler 通过去除均值并缩放至单位方差来标准化特征。这种缩放器假设特征分布是正…

PFA烧杯透明聚四氟乙烯刻度量杯

PFA烧杯&#xff0c;刻度清晰&#xff0c;耐酸碱&#xff0c;和有机溶剂。

腾讯春招后端一面(八股篇)

前言 前几天在网上发了腾讯面试官问的一些问题&#xff0c;好多小伙伴关注&#xff0c;今天对这些问题写个具体答案&#xff0c;博主好久没看八股了&#xff0c;正好复习一下。 面试手撕了三道算法&#xff0c;这部分之后更&#xff0c;喜欢的小伙伴可以留意一下我的账号。 1…

ElementUI Message 消息提示,多个显示被覆盖的问题

现象截图&#xff1a; 代码&#xff1a;主要是在this.$message 方法外层加上 setTimeout 方法 <script> export default {name: "HelloWorld",props: {msg: String,},methods: {showMessage() {for (let i 0; i < 10; i) {setTimeout(() > {this.$mess…

《荒野大镖客》等优秀的国产游戏能成为国产3a的标杆吗

中国或许不需要3A&#xff0c;但对于一些玩家来说&#xff0c;国产3A更多的是一个梦想&#xff0c;就像动画爱好者期待的优秀国产2D动画一样。 提问者所说的“玩家众多”&#xff0c;其实非核心玩家占比很高。 其中有一些是《王者荣耀》、《和平精英》等轻手游玩家或者国内二次…

yolov8 分割 模型 网络 模块图

下图是使用yolov8n-seg-p6.yaml imgsz1472 类别数2的情况下训练得到的静态导出的onnx文件使用netron工具可视化的结果 简单标注了yolov8n-seg-p6.yaml配置文件中各层和netron工具可视化的结果的对应关系

图解缓存淘汰算法 LRU、LFU | 最近最少使用、最不经常使用算法 | go语言实现

写在前面 无论是什么系统&#xff0c;在研发的过程中不可避免的会使用到缓存&#xff0c;而缓存一般来说我们不会永久存储&#xff0c;但是缓存的内容是有限的&#xff0c;那么我们如何在有限的内存空间中&#xff0c;尽可能的保留有效的缓存信息呢&#xff1f; 那么我们就可以…

前端基础——HTML傻瓜式入门(2)

该文章Github地址&#xff1a;https://github.com/AntonyCheng/html-notes 在此介绍一下作者开源的SpringBoot项目初始化模板&#xff08;Github仓库地址&#xff1a;https://github.com/AntonyCheng/spring-boot-init-template & CSDN文章地址&#xff1a;https://blog.c…