知识蒸馏详解及pytorch官网demo案例

知识蒸馏Knowledge Distillation(KD)

1、简介

一种模型压缩方法

知识蒸馏的一般框架(如下图)
三部分:知识、蒸馏算法、师生架构。
知识蒸馏的师生架构

知识

将知识分为三种形式:基于响应的(response-based)、基于特征的(feature-based)、基于关系的(relation-based)。
在这里插入图片描述

①基于响应的知识(response-based)【常用】
学习的知识是教师模型最后一个输出层logits。由于logits实际上是类别概率分布,因此基于响应的知识蒸馏限制在监督学习
在这里插入图片描述

最流行的基于响应的图像分类知识被称为软目标(soft target)

基于响应的知识蒸馏具体架构如下图。后面具体介绍该类知识蒸馏。
在这里插入图片描述
②基于特征的知识(feature-based)
学习的知识是教师模型中间层的基于特征的知识。下图为基于特征的知识蒸馏模型的通常架构。
在这里插入图片描述

③基于关系的知识(relation-based)
基于响应和基于特征的知识都使用了教师模型中特定层的输出,基于关系的知识进一步探索了不同层或数据样本的关系。下图为实例关系的知识蒸馏架构。

在这里插入图片描述

蒸馏机制

根据教师模型是否与学生模型同时更新,知识蒸馏的学习方案可分为离线(offline)蒸馏、在线(online)蒸馏、自蒸馏(self-distillation)

离线蒸馏(常用)
在离线蒸馏中,学生模型仅使用知识进行训练,而不与教师模型同时更新。学生模型独立地使用知识进行训练,目标是使学生模型的输出尽可能接近教师模型的输出。
大多数之前的知识蒸馏方法都是离线的。最初的知识蒸馏中,知识从预训练的教师模型转移到学生模型中,整个训练过程包括两个阶段:1)大型教师模型蒸馏前在训练样本训练;2)教师模型以logits(基于响应,生成软目标(soft target))或中间特征(基于特征)的形式提取知识,将其在蒸馏过程中指导学生模型的训练。

在线蒸馏
在线蒸馏时,教师模型和学生模型同步更新,而整个知识蒸馏框架都是端到端可训练的。
在线蒸馏是一种具有高效并行计算的单阶段端到端训练方案。然而,现有的在线方法(如相互学习)通常无法解决在线环境中的高容量教师,这使进一步探索在线环境中教师和学生模式之间的关系成为一个有趣的话题。

自蒸馏
在自蒸馏中,教师和学生模型使用相同的网络,这可以看作是在线蒸馏的一个特例。
在这里插入图片描述
从人类师生学习的角度可以直观地理解离线、在线和自蒸馏。
离线蒸馏是指知识渊博的教师教授学生知识;
在线蒸馏是指教师和学生一起学习;
自我蒸馏是指学生自己学习知识。

师生架构

教师模型(cumbersome model):已经训练好的,较为笨重的模型。
学生模型:通过蒸馏,将教师模型中已经学习到的知识迁移到的新的轻量级的模型。


2、学生模型的训练(基于响应的离线知识蒸馏)

hard target(硬目标)与 soft target(软目标)

hard target仅包含正样本信息
soft target具有更多信息,不仅包含正样本信息,还有相似负样本信息,比如左图的正样本标签为2,但由于写法与3相像,因此对标签3也给予一定的关注通过增大概率值;而右图的正样本标签2写法与7相像,因此对标签7也给予一定的关注。
具体到代码中就是加入蒸馏温度T。

在这里插入图片描述

蒸馏温度 T T T

原来的softmax 将多分类的输出结果映射为概率值。 q i = e z i ∑ j = 1 n e z j q_i=\frac{e^{z_i}}{\sum_{j=1}^n{e^{z_j}}} qi=j=1nezjezi,其中 z i z_i zi是模型的softmax层输出logits。

在进行知识蒸馏时,如果将教师模型的softmax输出,作为学生模型的 s o f t − t a r g e t soft-target softtarget,那么负标签的值接近于0,对学生模型的损失函数贡献非常小,使得模型难以利用教师模型学到的知识。因此,提出蒸馏温度T的概念,使得softmax是输出更加平滑。

加入蒸馏温度 T T T后的softmax
q i = e ( z i / T ) ∑ j = 1 n e ( z j / T ) q_i=\frac{e^{(z_i/T)}}{\sum_{j=1}^n{e^{(z_j/T)}}} qi=j=1ne(zj/T)e(zi/T)

实验:当温度 T T T越高时,负标签的概率值的变化。

在这里插入图片描述正标签为第1个元素,当温度 T T T越高时,负标签的概率值相对被放得越大。在训练时,由于损失函数的惩罚,模型需要对负标签给予一定的关注;从而达到在学习老师模型时,一次训练不仅仅可以学到正样本的特征,也可以学到相似负样本的特征。

import numpy as npdef softmax(x):x_exp = np.exp(x)return x_exp/x_exp.sum()def softmax_t(x, T):# T是蒸馏温度x_exp = np.exp(x/T)return x_exp/x_exp.sum()output = np.array([5, 1.3, 2])print('temperature is 5: ', softmax_t(output, 5))
print('temperature is 10: ', softmax_t(output, 10))
print('temperature is 100: ', softmax_t(output, 100))

在这里插入图片描述

知识蒸馏训练的具体步骤

①训练好Teacher模型
②利用高温 T h i g h T_{high} Thigh产生 s o f t − t a r g e t soft-target softtarget
③使用{ s o f t − t a r g e t , T h i g h soft-target, T_{high} softtarget,Thigh}和{ h a r d − t a r g e t , T = 1 hard-target, T=1 hardtarget,T=1},同时训练 Student 模型
④设置蒸馏温度 T = 1 T=1 T=1,Student模型线上做推理

高温蒸馏过程的损失函数

学生损失函数student loss即, L h a r d = − ∑ j = 1 n l j l o g ( q j ) , q i = e z i ∑ j = 1 n e z j L_{hard}=-\sum_{j=1}^nl_jlog(q_j),q_i=\frac{e^{z_i}}{\sum_{j=1}^n{e^{z_j}}} Lhard=j=1nljlog(qj)qi=j=1nezjezi
蒸馏损失函数distillation loss即, L s o f t = − ∑ j = 1 n p j T l o g ( q j T ) , p i T = e ( v i / T ) ∑ j = 1 n e ( v j / T ) , q i T = e ( z i / T ) ∑ j = 1 n e ( z j / T ) L_{soft}=-\sum_{j=1}^np_j^Tlog(q_j^T),p_i^T=\frac{e^{(v_i/T)}}{\sum_{j=1}^n{e^{(v_j/T)}}},q_i^T=\frac{e^{(z_i/T)}}{\sum_{j=1}^n{e^{(z_j/T)}}} Lsoft=j=1npjTlog(qjT)piT=j=1ne(vj/T)e(vi/T)qiT=j=1ne(zj/T)e(zi/T)

高温蒸馏过程的损失函数定义为: L = α L s o f t + β L h a r d L=\alpha L_{soft}+\beta L_{hard} L=αLsoft+βLhard
其中, l i l_i li为第i个ground truth值, z i z_i zi为学生模型的第i个输出logits值, v i v_i vi为老师模型的第i个输出logits值, α \alpha α β \beta β为超参数。

在这里插入图片描述

pytorch官网 知识蒸馏demo

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/783037.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

数字乡村发展蓝图:科技赋能农村实现全面振兴

目录 一、数字乡村发展蓝图的内涵与目标 二、科技赋能农村:数字乡村发展的动力与路径 (一)加强农业科技创新,提升农业生产效率 (二)推进农村电商发展,拓宽农民增收渠道 (三&…

【华为OD机试C++】质数因子

《最新华为OD机试题目带答案解析》:最新华为OD机试题目带答案解析,语言包括C、C++、Python、Java、JavaScript等。订阅专栏,获取专栏内所有文章阅读权限,持续同步更新! 文章目录 描述输入描述输出描述示例输出代码描述 功能:输入一个正整数,按照从小到大的顺序输出它的所…

MHA高可用配置与故障切换

前言: MHA高可用故障就是单点故障,那么我们如何解决单点故障MHA中Master如何将故障的机器停止,使用备用的Slave服务器 一 MHA定义 MHA(MasterHigh Availablity)是一套优秀的Mysql高可用环境下故障切换和主从复制的…

Queue的多线程爬虫和multiprocessing多进程

Queue的模块里面提供了同步的、线程安全的队列类,包括FIFO(先入后出)队列Queue、FIFO(后入先出)LifoQueue和优先队列PriorityQueue。(在上个文件创建了爬取文件)我们使用这个方法来获取&#xf…

【Linux】进程程序替换 做一个简易的shell

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 目录 文章目录 前言 进程程序替换 替换原理 先看代码和现象 替换函数 第一个execl(): 第二个execv(): 第三个execvp(): 第四个execvpe()&a…

编程语言|C语言——C语言操作符的详细解释

这篇文章主要详细介绍了C语言的操作符,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧 一、基础 1.1 算数操作符 - * / % - * / 这些操作符是我们…

QT初识(1)

QT初识 桌面开发什么是QT下载QT安装好之后的工具AssisantDesignerQT Creator 创建一个简单的项目 我们今天来认识一下QT。 桌面开发 在了解QT,我们得了解一下桌面开发: 桌面开发指的是编写和构建在个人计算机或其他桌面操作系统(如Windows、…

关系网络c++

题目&#xff1a; 代码&#xff1a; #include<bits/stdc.h>using namespace std;int n,x,y;struct node{int num;//编号 int t;//步数 node(){}node(int sum,int tt){numsum;ttt;} }; int mp[101][101];//图 bool flag[101];//标记 queue<node> q; void bfs() {q…

鸿蒙(HarmonyOS)ArkTs语言基础教程(大纲)

鸿蒙&#xff08;HarmonyOS&#xff09;ArkTs语言基础教程 简介 ArkTS 是鸿蒙生态的应用开发语言。它在保持 TypeScript&#xff08;简称 TS&#xff09;基本语法风格的基础上&#xff0c;对 TS 的动态类型特性施加更严格的约束&#xff0c;引入静态类型。同时&#xff0c;提…

Spring面试题系列-5

Spring框架是由于软件开发的复杂性而创建的。Spring使用的是基本的JavaBean来完成以前只可能由EJB完成的事情。然而&#xff0c;Spring的用途不仅仅限于服务器端的开发。从简单性、可测试性和松耦合性角度而言&#xff0c;绝大部分Java应用都可以从Spring中受益。 为什么要使用…

Python教程:(Sweetviz)仅三行代码就能实现数据可视化

weetviz 的主要功能 自动化报告生成 Sweetviz 能够自动分析数据集的特征和属性&#xff0c;并生成详细的 EDA 报告。用户无需手动编写复杂的代码&#xff0c;只需简单调用 Sweetviz 函数即可生成完整的报告。 多种可视化图表 Sweetviz 提供了多种可视化图表&#xff0c;包括直…

Hive常用函数_16个时间日期处理

在Hive中&#xff0c;常用的时间处理函数包括但不限于以下几种&#xff1a; 1. current_date(): 返回当前日期&#xff0c;不包含时间部分 SELECT current_date(); -- Output: 2024-09-152. current_timestamp(): 返回当前时间戳&#xff0c;包含日期和时间部分 SELECT curr…

【Docker】Windows中打包dockerfile镜像导入到Linux

【Docker】Windows中打包dockerfile镜像导入到Linux 大家好 我是寸铁&#x1f44a; 总结了一篇【Docker】Windows中打包dockerfile镜像导入到Linux✨ 喜欢的小伙伴可以点点关注 &#x1f49d; 前言 今天遇到一个新需求&#xff0c;如何将Windows中打包好的dockerfile镜像给迁移…

Autodesk Maya 2025---智能建模与动画创新,重塑创意工作流程

Autodesk Maya 2025是一款顶尖的三维动画软件&#xff0c;广泛应用于影视广告、角色动画、电影特技等领域。新版本在功能上进行了全面升级&#xff0c;新增了对Apple芯片的支持&#xff0c;建模、绑定和角色动画等方面的功能也更加出色。 在功能特色方面&#xff0c;Maya 2025…

equals()和hashcode()的区别【大白话Java面试题】

equals()和hashcode()的区别 大白话 1.equals()&#xff1a;反应的是对象或变量具体的值&#xff0c;及两个对象包含的具体的值&#xff08;可能是对象的引用&#xff0c;也可能是值类型的值&#xff09; 2.hashcode():计算两个对象的哈希值&#xff0c;并返回哈希码&#xff…

逆向分析之antibot

现在太卷了&#xff0c;没资源&#xff0c;很难接到好活&#xff0c;今天群里看到个单子&#xff0c;分析了下能做&#xff0c;结果忙活了一小会&#xff0c;幸好问了下&#xff0c;人家同时有多个人再做&#xff0c;直接就拒绝再继续了。就这次忘了收定金了&#xff0c;所以原…

数据可视化之多表显示

多表显示subplot(),subplots() # 使用 pyplot 中的 subplot() 和 subplots() 方法来绘制多个子图# 导入库&#xff0c;和调用中文import matplotlib.pyplot as pltimport numpy as np# 作用&#xff1a;解决坐标轴为负时 负号显示为方框的问题# axes&#xff1a;坐标轴# Unicod…

使用python实现i茅台自动预约

使用python实现i茅台自动预约[仅限于学习,不可商用] 运行: 直接运行 imtApi.py 打包:切换到imt脚本目录,执行打包命令: pyinstaller --onefile imtApi.py这个应用程序可以帮助你进行茅台自动化配置。以下是一些使用说明: 平台注册账号(可用i茅台)不用登录,你可以进行…

Linux的VirtualBox中USB设备无法选择USB3.0怎么办?

在VirtualBox中&#xff0c;如果遇到USB设备无法选择 USB 3.0 的问题&#xff0c;可以尝试按照以下步骤来解决&#xff1a; 确保VirtualBox版本支持USB 3.0&#xff1a;首先&#xff0c;你需要确认你的VirtualBox版本是否支持USB 3.0。一些较旧的版本可能不支持&#xff0c;因此…

一篇搞定AVL树+旋转【附图详解旋转思想】

&#x1f389;个人名片&#xff1a; &#x1f43c;作者简介&#xff1a;一名乐于分享在学习道路上收获的大二在校生 &#x1f648;个人主页&#x1f389;&#xff1a;GOTXX &#x1f43c;个人WeChat&#xff1a;ILXOXVJE &#x1f43c;本文由GOTXX原创&#xff0c;首发CSDN&…