知识蒸馏的蒸馏损失方法代码总结(包括:基于logits的方法:KLDiv,dist,dkd等,基于中间层提示的方法:)

有两种知识蒸馏方法:一种利用教师模型的输出概率(基于logits的方法)[15,14,11],另一种利用教师模型的中间表示(基于提示的方法)[12,13,18,17]。基于logits的方法利用教师的输出作为辅助信号来训练一个较小的模型,即学生模型:

利用教师模型的输出概率(基于logits的方法)

该类方法损失函数为:
在这里插入图片描述

DIST

Tao Huang,Shan You,Fei Wang,Chen Qian,and Chang Xu.Knowledge distillation from a strongerteacher.In Advances in Neural Information Processing Systems,2022.

import torch.nn as nndef cosine_similarity(a, b, eps=1e-8):return (a * b).sum(1) / (a.norm(dim=1) * b.norm(dim=1) + eps)def pearson_correlation(a, b, eps=1e-8):return cosine_similarity(a - a.mean(1).unsqueeze(1),b - b.mean(1).unsqueeze(1), eps)def inter_class_relation(soft_student_outputs, soft_teacher_outputs):return 1 - pearson_correlation(soft_student_outputs, soft_teacher_outputs).mean()def intra_class_relation(soft_student_outputs, soft_teacher_outputs):return inter_class_relation(soft_student_outputs.transpose(0, 1), soft_teacher_outputs.transpose(0, 1))class DIST(nn.Module):def __init__(self, beta=1.0, gamma=1.0, temp=1.0):super(DIST, self).__init__()self.beta = betaself.gamma = gammaself.temp = tempdef forward(self, student_preds, teacher_preds, **kwargs):soft_student_outputs = (student_preds / self.temp).softmax(dim=1)soft_teacher_outputs = (teacher_preds / self.temp).softmax(dim=1)inter_loss = self.temp ** 2 * inter_class_relation(soft_student_outputs, soft_teacher_outputs)intra_loss = self.temp ** 2 * intra_class_relation(soft_student_outputs, soft_teacher_outputs)kd_loss = self.beta * inter_loss + self.gamma * intra_lossreturn kd_loss

KLDiv (2015年的原始方法)

import torch.nn as nn
import torch.nn.functional as F# loss = alpha * hard_loss + (1-alpha) * kd_loss,此处是单单的kd_loss
class KLDiv(nn.Module):def __init__(self, temp=1.0):super(KLDiv, self).__init__()self.temp = tempdef forward(self, student_preds, teacher_preds, **kwargs):soft_student_outputs = F.log_softmax(student_preds / self.temp, dim=1)soft_teacher_outputs = F.softmax(teacher_preds / self.temp, dim=1)kd_loss = F.kl_div(soft_student_outputs, soft_teacher_outputs, reduction="none").sum(1).mean()kd_loss *= self.temp ** 2return kd_loss

dkd (Decoupled KD(CVPR 2022) )

Borui Zhao,Quan Cui,Renjie Song,Yiyu Qiu,and Jiajun Liang.Decoupled knowledge distillation.InIEEE/CVF Conference on Computer Vision and Pattern Recognition,2022.

import torch
import torch.nn as nn
import torch.nn.functional as Fdef dkd_loss(logits_student, logits_teacher, target, alpha, beta, temperature):gt_mask = _get_gt_mask(logits_student, target)other_mask = _get_other_mask(logits_student, target)pred_student = F.softmax(logits_student / temperature, dim=1)pred_teacher = F.softmax(logits_teacher / temperature, dim=1)pred_student = cat_mask(pred_student, gt_mask, other_mask)pred_teacher = cat_mask(pred_teacher, gt_mask, other_mask)log_pred_student = torch.log(pred_student)tckd_loss = (F.kl_div(log_pred_student, pred_teacher, reduction='batchmean')* (temperature ** 2))pred_teacher_part2 = F.softmax(logits_teacher / temperature - 1000.0 * gt_mask, dim=1)log_pred_student_part2 = F.log_softmax(logits_student / temperature - 1000.0 * gt_mask, dim=1)nckd_loss = (F.kl_div(log_pred_student_part2, pred_teacher_part2, reduction='batchmean')* (temperature ** 2))return alpha * tckd_loss + beta * nckd_lossdef _get_gt_mask(logits, target):target = target.reshape(-1)mask = torch.zeros_like(logits).scatter_(1, target.unsqueeze(1), 1).bool()return maskdef _get_other_mask(logits, target):target = target.reshape(-1)mask = torch.ones_like(logits).scatter_(1, target.unsqueeze(1), 0).bool()return maskdef cat_mask(t, mask1, mask2):t1 = (t * mask1).sum(dim=1, keepdims=True)t2 = (t * mask2).sum(1, keepdims=True)rt = torch.cat([t1, t2], dim=1)return rtclass DKD(nn.Module):def __init__(self, alpha=1., beta=2., temperature=1.):super(DKD, self).__init__()self.alpha = alphaself.beta = betaself.temperature = temperaturedef forward(self, z_s, z_t, **kwargs):target = kwargs['target']if len(target.shape) == 2:  # mixup / smoothingtarget = target.max(1)[1]kd_loss = dkd_loss(z_s, z_t, target, self.alpha, self.beta, self.temperature)return kd_loss

利用教师模型的中间表示(基于提示的方法)

该类方法损失函数为:
[ L_{hint} = D_{hint}(T_s(F_s), T_t(F_t)) ]

ReviewKD (CVPR2021)

论文:

Pengguang Chen,Shu Liu,Hengshuang Zhao,and Jiaya Jia.Distilling knowledge via knowledge review.In IEEE/CVF Conference on Computer Vision and Pattern Recognition,2021.

代码:

https://github.com/dvlab-research/ReviewKD

Adriana Romero,Nicolas Ballas,Samira Ebrahimi Kahou,Antoine Chassang,Carlo Gatta,and YoshuaBengio.Fitnets:Hints for thin deep nets.arXiv preprint arXiv:1412.6550,2014.

Yonglong Tian,Dilip Krishnan,and Phillip Isola.Contrastive representation distillation.In IEEE/CVFInternational Conference on Learning Representations,2020.

Baoyun Peng,Xiao Jin,Jiaheng Liu,Dongsheng Li,Yichao Wu,Yu Liu,Shunfeng Zhou,and ZhaoningZhang.Correlation congruence for knowledge distillation.In International Conference on ComputerVision,2019.

关于知识蒸馏损失函数的文章

FitNet(ICLR 2015)、Attention(ICLR 2017)、Relational KD(CVPR 2019)、ICKD (ICCV 2021)、Decoupled KD(CVPR 2022) 、ReviewKD(CVPR 2021)等方法的介绍:

https://zhuanlan.zhihu.com/p/603748226?utm_id=0

待更新

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/206637.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

VBA语法结构及编程思想

VBA(Visual Basic for Applications)是一种编程语言,它被用于Microsoft Office应用程序的自动化,允许用户编写宏来执行常规任务。VBA是基于Microsoft的Visual Basic语言,但专为Office应用程序定制。 VBA语法格式 VBA的…

【STM32】TIM定时器输出比较

1 输出比较 1.1 输出比较简介 OC(Output Compare)输出比较;IC(Input Capture)输入捕获;CC(Capture/Compare)输入捕获和输出比较的单元输出比较可以通过比较CNT与CCR寄存器值&#…

JavaWeb-HTTP协议

1. 什么是HTTP协议 HTTP超文本传输协(Hyper Text transfer protocol),是一种用于用于分布式、协作式和超媒体信息系统的应用层协议。它于1990年提出,经过十几年的使用与发展,得到不断地完善和扩展。HTTP 是为 Web 浏览器与 Web 服务器之间的…

AI自动生成代码工具

AI自动生成代码工具是一种利用人工智能技术来辅助或自动化软件开发过程中的编码任务的工具。这些工具使用机器学习和自然语言处理等技术,根据开发者的需求生成相应的源代码。以下是一些常见的AI自动生成代码工具,希望对大家有所帮助。北京木奇移动技术有…

Redisson的基本使用

Redisson官网描述:Redisson 是一个在 Redis 的基础上实现的 Java 驻内存数据网格客户端(In-Memory Data Grid)。它不仅提供了一系列的 redis 常用数据结构命令服务,还提供了许多分布式服务,例如分布式锁、分布式对象、…

HCIP —— BGP 基础 (上)

BGP --- 边界网关协议 (路径矢量协议) IGP --- 内部网关协议 --- OSPF RIP ISIS EGP --- 外部网关协议 --- EGP BGP AS --- 自治系统 由单一的组织或者机构独立维护的网络设备以及网络资源的集合。 因 网络范围太大 需 自治 。 为区分不同的AS&#…

vim常见操作

vim常见操作 文章目录 vim常见操作1. 回退/前进2. 搜索3. 删除4. 定位到50行5. 显示行号6. 复制粘贴7. 剪贴8. 替换9. vim打开文件的时候出现 1. 回退/前进 1.esc进入命令模式 2.ctrlr 前进 u 回退2. 搜索 1) esc进入命令模式 2) /text  查找text&am…

Docker load 命令

docker load :导入使用docker save命令导出的镜像。 语法 docker load [OPTIONS]OPTIONS 说明: --input , -i :指定导入的文件,代替STDIN。 --quiet , -q :精简输出信息。 实例: 导入镜像&#xff1a…

【STM32】TIM定时器输入捕获

1 输入捕获 1.1 输入捕获简介 IC(Input Capture)输入捕获 输入捕获模式下,当通道输入引脚出现指定电平跳变时(上升沿/下降沿),当前CNT的值将被锁存到CCR中(把CNT的值读出来,写入到…

ubuntu16.04安装ROS+Gazebo

ubuntu16.04安装ROS参考文章 ros安装(一键最简安装,吹爆鱼香ROS,请叫我鱼吹) ROS篇——Ubuntu快速一键安装ROS或ROS2(通用) ubuntu安装ROS melodic(最新、超详细图文教程) 配置ubuntu以及安装ros2必要环…

类风湿性关节炎口腔黏膜破裂引发抗瓜氨酸细菌和人蛋白抗体反应

今天给同学们分享一篇实验文章“Oral mucosal breaks trigger anti-citrullinated bacterial and human protein antibody responses in rheumatoid arthritis”,这篇文章发表在Sci Transl Med期刊上,影响因子为17.1。 结果解读: 口腔黏膜破…

Redis主从复制的配置和实现原理

Redis的持久化功能在一定程度上保证了数据的安全性,即便是服务器宕机的情况下,也可以保证数据的丢失非常少。通常,为了避免服务的单点故障,会把数据复制到多个副本放在不同的服务器上,且这些拥有数据副本的服务器可以用…

如何快速构建知识服务平台,打造个人或企业私域流量

随着互联网的快速发展,传统的知识付费平台已经不能满足用户的需求。而SaaS知识付费小程序平台则是一种新型的知识付费方式,具有灵活、便捷、高效等特点,为用户提供了更加优质的付费知识服务。本文将介绍如何搭建自己的SaaS知识付费小程序平台…

如何掌握构建 LMS 网站的艺术

目录 什么是学习管理系统 (LMS) 在线课程和 LMS 网站的好处 为什么 WordPress 对于 LMS 网站很重要 统一学习中心 多功能性和可扩展性 提高教育参与度 简化管理和监控 节省时间和费用 技能评估和绩效监督 持续学习和技能提升 使用 WordPress 插件构建成功的 LMS 课程 专注于您的…

sparkc程序idea调试提示内存不足

报错如下: Exception in thread "main" java.lang.IllegalArgumentException: System memory 259522560 must be at least 471859200. Please increase heap size using the --driver-memory option or spark.driver.memory in Spark configuration. 测…

自动驾驶:传感器初始标定

手眼标定 机器人手眼标定AxxB(eye to hand和eye in hand)及平面九点法标定 Ax xB问题求解,旋转和平移分步求解法 手眼标定AXXB求解方法(文献总结) 基于靶的方法 相机标定 (1) ApriTag (2) 棋盘格:cv::f…

富时中国A50指数暴跌

近年来,中国股市的波动一直备受关注,而富时中国A50指数更是其中一项备受瞩目的指标之一。然而,近期却出现了一场引人瞩目的暴跌,引发了广泛的关注和讨论。 富时中国A50指数简介 富时中国A50指数,作为富时罗素指数系列…

【C/PTA】结构体专项练习

本文结合PTA专项练习带领读者掌握结构体,刷题为主注释为辅,在代码中理解思路,其它不做过多叙述。 目录 6-1 选队长6-2 按等级统计学生成绩6-3 学生成绩比高低6-4 综合成绩6-5 利用“选择排序算法“对结构体数组进行排序6-6 结构体的最值6-7 复…

香港商标注册申请所需资料及办理流程

作为东方明珠,自由港香港是世界上较自由的贸易通商口岸,再加上本身良好的基础设施和健全的法律制度,这给企业家提供了得天独厚的营商环境。在香港注册商标,可以迅速提高企业的知名度,提升企业不断成长的竞争力&#xf…

全新UI彩虹外链网盘系统源码V5.5/支持批量封禁+优化加载速度+用户系统与分块上传

源码简介: 全新UI彩虹外链网盘系统源码V5.5,它可以支持批量封禁优化加载速度。新增用户系统与分块上传。 彩虹外链网盘,作为一款PHP网盘与外链分享程序,具备广泛的文件格式支持能力。它不仅能够实现各种格式文件的上传&#xff…