知识蒸馏的蒸馏损失方法代码总结(包括:基于logits的方法:KLDiv,dist,dkd等,基于中间层提示的方法:)

有两种知识蒸馏方法:一种利用教师模型的输出概率(基于logits的方法)[15,14,11],另一种利用教师模型的中间表示(基于提示的方法)[12,13,18,17]。基于logits的方法利用教师的输出作为辅助信号来训练一个较小的模型,即学生模型:

利用教师模型的输出概率(基于logits的方法)

该类方法损失函数为:
在这里插入图片描述

DIST

Tao Huang,Shan You,Fei Wang,Chen Qian,and Chang Xu.Knowledge distillation from a strongerteacher.In Advances in Neural Information Processing Systems,2022.

import torch.nn as nndef cosine_similarity(a, b, eps=1e-8):return (a * b).sum(1) / (a.norm(dim=1) * b.norm(dim=1) + eps)def pearson_correlation(a, b, eps=1e-8):return cosine_similarity(a - a.mean(1).unsqueeze(1),b - b.mean(1).unsqueeze(1), eps)def inter_class_relation(soft_student_outputs, soft_teacher_outputs):return 1 - pearson_correlation(soft_student_outputs, soft_teacher_outputs).mean()def intra_class_relation(soft_student_outputs, soft_teacher_outputs):return inter_class_relation(soft_student_outputs.transpose(0, 1), soft_teacher_outputs.transpose(0, 1))class DIST(nn.Module):def __init__(self, beta=1.0, gamma=1.0, temp=1.0):super(DIST, self).__init__()self.beta = betaself.gamma = gammaself.temp = tempdef forward(self, student_preds, teacher_preds, **kwargs):soft_student_outputs = (student_preds / self.temp).softmax(dim=1)soft_teacher_outputs = (teacher_preds / self.temp).softmax(dim=1)inter_loss = self.temp ** 2 * inter_class_relation(soft_student_outputs, soft_teacher_outputs)intra_loss = self.temp ** 2 * intra_class_relation(soft_student_outputs, soft_teacher_outputs)kd_loss = self.beta * inter_loss + self.gamma * intra_lossreturn kd_loss

KLDiv (2015年的原始方法)

import torch.nn as nn
import torch.nn.functional as F# loss = alpha * hard_loss + (1-alpha) * kd_loss,此处是单单的kd_loss
class KLDiv(nn.Module):def __init__(self, temp=1.0):super(KLDiv, self).__init__()self.temp = tempdef forward(self, student_preds, teacher_preds, **kwargs):soft_student_outputs = F.log_softmax(student_preds / self.temp, dim=1)soft_teacher_outputs = F.softmax(teacher_preds / self.temp, dim=1)kd_loss = F.kl_div(soft_student_outputs, soft_teacher_outputs, reduction="none").sum(1).mean()kd_loss *= self.temp ** 2return kd_loss

dkd (Decoupled KD(CVPR 2022) )

Borui Zhao,Quan Cui,Renjie Song,Yiyu Qiu,and Jiajun Liang.Decoupled knowledge distillation.InIEEE/CVF Conference on Computer Vision and Pattern Recognition,2022.

import torch
import torch.nn as nn
import torch.nn.functional as Fdef dkd_loss(logits_student, logits_teacher, target, alpha, beta, temperature):gt_mask = _get_gt_mask(logits_student, target)other_mask = _get_other_mask(logits_student, target)pred_student = F.softmax(logits_student / temperature, dim=1)pred_teacher = F.softmax(logits_teacher / temperature, dim=1)pred_student = cat_mask(pred_student, gt_mask, other_mask)pred_teacher = cat_mask(pred_teacher, gt_mask, other_mask)log_pred_student = torch.log(pred_student)tckd_loss = (F.kl_div(log_pred_student, pred_teacher, reduction='batchmean')* (temperature ** 2))pred_teacher_part2 = F.softmax(logits_teacher / temperature - 1000.0 * gt_mask, dim=1)log_pred_student_part2 = F.log_softmax(logits_student / temperature - 1000.0 * gt_mask, dim=1)nckd_loss = (F.kl_div(log_pred_student_part2, pred_teacher_part2, reduction='batchmean')* (temperature ** 2))return alpha * tckd_loss + beta * nckd_lossdef _get_gt_mask(logits, target):target = target.reshape(-1)mask = torch.zeros_like(logits).scatter_(1, target.unsqueeze(1), 1).bool()return maskdef _get_other_mask(logits, target):target = target.reshape(-1)mask = torch.ones_like(logits).scatter_(1, target.unsqueeze(1), 0).bool()return maskdef cat_mask(t, mask1, mask2):t1 = (t * mask1).sum(dim=1, keepdims=True)t2 = (t * mask2).sum(1, keepdims=True)rt = torch.cat([t1, t2], dim=1)return rtclass DKD(nn.Module):def __init__(self, alpha=1., beta=2., temperature=1.):super(DKD, self).__init__()self.alpha = alphaself.beta = betaself.temperature = temperaturedef forward(self, z_s, z_t, **kwargs):target = kwargs['target']if len(target.shape) == 2:  # mixup / smoothingtarget = target.max(1)[1]kd_loss = dkd_loss(z_s, z_t, target, self.alpha, self.beta, self.temperature)return kd_loss

利用教师模型的中间表示(基于提示的方法)

该类方法损失函数为:
[ L_{hint} = D_{hint}(T_s(F_s), T_t(F_t)) ]

ReviewKD (CVPR2021)

论文:

Pengguang Chen,Shu Liu,Hengshuang Zhao,and Jiaya Jia.Distilling knowledge via knowledge review.In IEEE/CVF Conference on Computer Vision and Pattern Recognition,2021.

代码:

https://github.com/dvlab-research/ReviewKD

Adriana Romero,Nicolas Ballas,Samira Ebrahimi Kahou,Antoine Chassang,Carlo Gatta,and YoshuaBengio.Fitnets:Hints for thin deep nets.arXiv preprint arXiv:1412.6550,2014.

Yonglong Tian,Dilip Krishnan,and Phillip Isola.Contrastive representation distillation.In IEEE/CVFInternational Conference on Learning Representations,2020.

Baoyun Peng,Xiao Jin,Jiaheng Liu,Dongsheng Li,Yichao Wu,Yu Liu,Shunfeng Zhou,and ZhaoningZhang.Correlation congruence for knowledge distillation.In International Conference on ComputerVision,2019.

关于知识蒸馏损失函数的文章

FitNet(ICLR 2015)、Attention(ICLR 2017)、Relational KD(CVPR 2019)、ICKD (ICCV 2021)、Decoupled KD(CVPR 2022) 、ReviewKD(CVPR 2021)等方法的介绍:

https://zhuanlan.zhihu.com/p/603748226?utm_id=0

待更新

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/206637.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【STM32】TIM定时器输出比较

1 输出比较 1.1 输出比较简介 OC(Output Compare)输出比较;IC(Input Capture)输入捕获;CC(Capture/Compare)输入捕获和输出比较的单元输出比较可以通过比较CNT与CCR寄存器值&#…

JavaWeb-HTTP协议

1. 什么是HTTP协议 HTTP超文本传输协(Hyper Text transfer protocol),是一种用于用于分布式、协作式和超媒体信息系统的应用层协议。它于1990年提出,经过十几年的使用与发展,得到不断地完善和扩展。HTTP 是为 Web 浏览器与 Web 服务器之间的…

AI自动生成代码工具

AI自动生成代码工具是一种利用人工智能技术来辅助或自动化软件开发过程中的编码任务的工具。这些工具使用机器学习和自然语言处理等技术,根据开发者的需求生成相应的源代码。以下是一些常见的AI自动生成代码工具,希望对大家有所帮助。北京木奇移动技术有…

HCIP —— BGP 基础 (上)

BGP --- 边界网关协议 (路径矢量协议) IGP --- 内部网关协议 --- OSPF RIP ISIS EGP --- 外部网关协议 --- EGP BGP AS --- 自治系统 由单一的组织或者机构独立维护的网络设备以及网络资源的集合。 因 网络范围太大 需 自治 。 为区分不同的AS&#…

vim常见操作

vim常见操作 文章目录 vim常见操作1. 回退/前进2. 搜索3. 删除4. 定位到50行5. 显示行号6. 复制粘贴7. 剪贴8. 替换9. vim打开文件的时候出现 1. 回退/前进 1.esc进入命令模式 2.ctrlr 前进 u 回退2. 搜索 1) esc进入命令模式 2) /text  查找text&am…

【STM32】TIM定时器输入捕获

1 输入捕获 1.1 输入捕获简介 IC(Input Capture)输入捕获 输入捕获模式下,当通道输入引脚出现指定电平跳变时(上升沿/下降沿),当前CNT的值将被锁存到CCR中(把CNT的值读出来,写入到…

类风湿性关节炎口腔黏膜破裂引发抗瓜氨酸细菌和人蛋白抗体反应

今天给同学们分享一篇实验文章“Oral mucosal breaks trigger anti-citrullinated bacterial and human protein antibody responses in rheumatoid arthritis”,这篇文章发表在Sci Transl Med期刊上,影响因子为17.1。 结果解读: 口腔黏膜破…

Redis主从复制的配置和实现原理

Redis的持久化功能在一定程度上保证了数据的安全性,即便是服务器宕机的情况下,也可以保证数据的丢失非常少。通常,为了避免服务的单点故障,会把数据复制到多个副本放在不同的服务器上,且这些拥有数据副本的服务器可以用…

如何快速构建知识服务平台,打造个人或企业私域流量

随着互联网的快速发展,传统的知识付费平台已经不能满足用户的需求。而SaaS知识付费小程序平台则是一种新型的知识付费方式,具有灵活、便捷、高效等特点,为用户提供了更加优质的付费知识服务。本文将介绍如何搭建自己的SaaS知识付费小程序平台…

如何掌握构建 LMS 网站的艺术

目录 什么是学习管理系统 (LMS) 在线课程和 LMS 网站的好处 为什么 WordPress 对于 LMS 网站很重要 统一学习中心 多功能性和可扩展性 提高教育参与度 简化管理和监控 节省时间和费用 技能评估和绩效监督 持续学习和技能提升 使用 WordPress 插件构建成功的 LMS 课程 专注于您的…

sparkc程序idea调试提示内存不足

报错如下: Exception in thread "main" java.lang.IllegalArgumentException: System memory 259522560 must be at least 471859200. Please increase heap size using the --driver-memory option or spark.driver.memory in Spark configuration. 测…

自动驾驶:传感器初始标定

手眼标定 机器人手眼标定AxxB(eye to hand和eye in hand)及平面九点法标定 Ax xB问题求解,旋转和平移分步求解法 手眼标定AXXB求解方法(文献总结) 基于靶的方法 相机标定 (1) ApriTag (2) 棋盘格:cv::f…

富时中国A50指数暴跌

近年来,中国股市的波动一直备受关注,而富时中国A50指数更是其中一项备受瞩目的指标之一。然而,近期却出现了一场引人瞩目的暴跌,引发了广泛的关注和讨论。 富时中国A50指数简介 富时中国A50指数,作为富时罗素指数系列…

全新UI彩虹外链网盘系统源码V5.5/支持批量封禁+优化加载速度+用户系统与分块上传

源码简介: 全新UI彩虹外链网盘系统源码V5.5,它可以支持批量封禁优化加载速度。新增用户系统与分块上传。 彩虹外链网盘,作为一款PHP网盘与外链分享程序,具备广泛的文件格式支持能力。它不仅能够实现各种格式文件的上传&#xff…

CLASS60 DM蓝牙5.2双模热插拔PCB

键盘使用说明索引(均为出厂默认值) 软件支持(驱动的详细使用帮助)一些常见问题解答(FAQ)首次使用步骤蓝牙配对规则(重要)蓝牙和USB切换键盘默认层默认触发层0的FN键配置的功能默认功…

使用word中的VBA 批量设置Word中所有图片大小

在VBA编辑器中,你可以创建、编辑和运行VBA宏代码,以实现自动化任务和自定义Word 功能。如果你是VBA编程初学者,可以在VBA编辑器中查看Word VBA宏代码示例,以便更好地了解如何使用VBA编写代码。 要打开VBA编辑器,你可以…

【Vue】修改组件样式并动态添加样式

文章目录 目标修改样式动态添加/删除样式样式不生效 注意:类似效果el-step也可以实现,可以不用手动实现。这里只是练习。 目标 使用组件库中的组件,修改它的样式并动态添加/删除样式。 修改样式 组件中的一些类可能添加样式无法生效。如Ele…

[java学习日记]反射、动态代理

目录 一.反射的简单解释与获取字节码文件对象 二.获取构造方法对象Constructor 三.反射获取字节码文件中的成员变量Field 四.反射获取字节码文件中的成员方法:Method 五.反射练习:保存信息 六.反射练习:利用配置文件(存储类名…

第21章:网络通信

21.1 网络程序设计基础 21.1.1 局域网与互联网 为了实现两台计算机的通信,必须用一个网络线路连接两台计算机。如下图所示 21.1.2 网络协议 1.IP协议 IP是Internet Protocol的简称,是一种网络协议。Internet 网络采用的协议是TCP/IP协议&#xff0…

Google Bard vs. ChatGPT 4.0:文献检索、文献推荐功能对比

在这篇博客中,我们将探讨和比较四个不同的人工智能模型——ChatGPT 3.5、ChatGPT 4.0、ChatGPT 4.0插件和Google Bard。我们将通过三个问题的测试结果来评估它们在处理特定任务时的效能和响应速度。 导航 问题 1: 统计自Vehicle Routing Problem (VRP)第一篇文章发…