【蒸馏】目标检测蒸馏的不完全整理和个人笔记

其实仔细想想模型蒸馏的监督信号无非来自原先损失函数(分类,bbox)或者是相关组件(backbone,FPN),在这里我不太想用传统的logit蒸馏和feature map蒸馏来表示上面两种蒸馏方式, 主要是现在的目标检测的蒸馏大多数是围绕相关组件和分类,对于bbox这一目标检测的重要组成部分的论文相对较少

(1)Improved Knowledge Distillation via Teacher Assistant

模型能力差距比较大时表现的原因:

1. 教师的表现提高了,因此它作为一个更好的预测者,为学生提供了更好的监督。

2. 老师变得如此复杂,以至于学生没有足够的能力或机制来模仿她的行为,尽管收到了提示。

3. 教师对数据的确定性增加,从而使其逻辑(软目标)不那么软。这削弱了通过软目标匹配实现的知识转移

(2)Localization Distillation for Object Detection

这篇是为数不多对Bbox进行蒸馏的论文!!!同时这一篇也是结合了feature map的蒸馏和logit的蒸馏,算是为logit蒸馏挽回一丝颜面

  • 主干结构中同时传递了定位和分类信息,监督特征图和模仿特征图有重要意义,
  • 但其中只有部分区域能使两个任务同时收益,此时引入VLR (Valuable Localization Region)概念来表示同时增益的区域。

图片来自知乎作者解读

按照我的理解,LD可以理解为:

  1. bbox使用DFL,此时模型学习的是左上角点或右下角点目标相对大小关系,具体地说对于一个模糊的坐标可能的目标范围在[y_{min},y_{max}]之间,通过采样16个(按照论文)点得到大致的坐标响应的情况(此时可能是单峰,多峰,...)
  2. 使用KL损失函数监督teacher和student的对于坐标的响应差异,即通过取logsoftmax或softmax转化为相对大小关系进行监督。

bbox损失使用的时DFL,因此对于教师信息,可以和分类损失一样使用KLloss

(3)探究了feature map,bbox logit和class logit之间的关系

同样来自《Localization Distillation for Object Detection》的作者扩展解读

其中分别代表了两种主流方案:一种是对feature map 进行蒸馏的 feature imitation,一种是对输出logit进行蒸馏的 logit mimicking(其中就包括bbox logit和class logit)

  • bbox logit和class logit的降低不一定带来feature map误差的降低

作者的观点是:

教师学出来的feature,本身就是一个大模型在更多epoch训练下得到的,它是一个high-level的具有抽象语义的特征。因此,强制学生学习这样的high-level特征未必是一件好事,它限制了学生特征所能学习的空间大小。

内容冗长但是非常值得一看,这里先挂一下结论:

(4)PKD: General Distillation Framework for Object Detectors via Pearson Correlation Coefficient

  • FPN的知识在不同检测器(GFL、FCOS和RetinaNet)上是可以传递的

表中上面两组实验证明可以传递,下面两组证明特征值的大小差异会干扰两个异质检测器之间的知识蒸馏(FCOS头部进行了组归一化操作) 。

  • FPN特征模仿可以成功地提取异构的检测器的知识

建议首先对教师和学生的特征进行归一化,使其均值和单位方差为零,并最小化归一化特征之间的MSE。此外,我们希望归一化遵循卷积属性——这样相同特征映射的不同元素,在不同的位置,以相同的方式归一化。

归一化操作:

    def norm(self, feat: torch.Tensor) -> torch.Tensor:"""Normalize the feature maps to have zero mean and unit variances.Args:feat (torch.Tensor): The original feature map with shape(N, C, H, W)."""assert len(feat.shape) == 4N, C, H, W = feat.shapefeat = feat.permute(1, 0, 2, 3).reshape(C, -1)mean = feat.mean(dim=-1, keepdim=True)std = feat.std(dim=-1, keepdim=True)feat = (feat - mean) / (std + 1e-6)return feat.reshape(C, N, H, W).permute(1, 0, 2, 3)

特征蒸馏:

            norm_S, norm_T = self.norm(pred_S), self.norm(pred_T)loss += F.mse_loss(norm_S, norm_T) / 2

按照论文的意思:首先进行特征归一化,然后计算MSE损失。从数学上讲,它相当于首先计算两个特征之间的Pearson相关系数(r)向量,然后使用1-r作为新的特征模仿损失。

(5)《Channel-wise Knowledge Distillation for Dense Prediction》

视角从原先的空间转到空间

对每个通道的激活图进行归一化处理,loss函数中的前传函数本质上还是最小化KL散度

    def forward(self, preds_S, preds_T):"""Forward computation.Args:preds_S (torch.Tensor): The student model prediction withshape (N, C, H, W).preds_T (torch.Tensor): The teacher model prediction withshape (N, C, H, W).Return:torch.Tensor: The calculated loss value."""assert preds_S.shape[-2:] == preds_T.shape[-2:]N, C, H, W = preds_S.shapesoftmax_pred_T = F.softmax(preds_T.view(-1, W * H) / self.tau, dim=1)logsoftmax = torch.nn.LogSoftmax(dim=1)loss = torch.sum(softmax_pred_T *logsoftmax(preds_T.view(-1, W * H) / self.tau) -softmax_pred_T *logsoftmax(preds_S.view(-1, W * H) / self.tau)) * (self.tau**2)loss = self.loss_weight * loss / (C * N)return loss

(6)《Decoupled Knowledge Distillation》

  • 将经典的Knowledge Distillation(KD)损失分解为两个部分:目标类知识蒸馏(TCKD)和非目标类知识蒸馏(NCKD)

​建议推导可以看一下论文

  • 单独使用TCKD可能对学生没有帮助甚至有害,然而,NCKD的蒸馏性能与经典KD相当,甚至更好

因此作者的观点是:TCKD不应该单独使用

论文附录中也给TCKD导致性能下降做出了解释:

高温(T=4)会导致很大的梯度增加非目标类的logits,这会损害学生预测的正确性。如果没有NCKD,类的相似度(或者突出的暗知识)的信息是不可用的,所以TCKD的梯度不能起到很好的作用,反而会导致性能下降(因为TCKD可以在易拟合的训练数据上带来边际性能增益)。

  • TCKD传递了关于训练样本“难度”的知识。

基于上面的观点,上述实验同时也包含NCKD,分别设计了三个实验:

①数据增强后的CIFAR-100 :如果应用强增强,TCKD可以获得显着的性能增益

②噪声CIFAR-100 : TCKD在噪声较大的训练数据上取得了更多的绩效提升

③困难数据集Imagenet

  • NCKD是logit蒸馏的核心,如下式所示,NCKD受教师对训练样本的置信度的限制,更有信心的预测导致更小的NCKD权重。

这个同时作为解释更大模型蒸馏小模型表现不佳的解释:NCKD的抑制,随着置信度变大,权重变小,不再用模型容量这种模糊的概念来表达。

  •  论文中重新设计了两个超参数来克服置信度的限制,同时平衡TCKD和NCKD的重要性

教师越自信,NCKD越有价值,β值越大,但是过大又会损害模型梯度,最后使用目标类的logit(即zt,其中z表示输出logit, t表示目标类)与非目标类之间的最大logit之间的差距可以作为调优β的可靠指导,并发现其可能包含正相关

def dkd_loss(logits_student, logits_teacher, target, alpha, beta, temperature):gt_mask = _get_gt_mask(logits_student, target)other_mask = _get_other_mask(logits_student, target)pred_student = F.softmax(logits_student / temperature, dim=1)pred_teacher = F.softmax(logits_teacher / temperature, dim=1)pred_student = cat_mask(pred_student, gt_mask, other_mask)pred_teacher = cat_mask(pred_teacher, gt_mask, other_mask)log_pred_student = torch.log(pred_student)tckd_loss = (F.kl_div(log_pred_student, pred_teacher, size_average=False)* (temperature**2)/ target.shape[0])pred_teacher_part2 = F.softmax(logits_teacher / temperature - 1000.0 * gt_mask, dim=1)log_pred_student_part2 = F.log_softmax(logits_student / temperature - 1000.0 * gt_mask, dim=1)nckd_loss = (F.kl_div(log_pred_student_part2, pred_teacher_part2, size_average=False)* (temperature**2)/ target.shape[0])return alpha * tckd_loss + beta * nckd_lossdef _get_gt_mask(logits, target):target = target.reshape(-1)mask = torch.zeros_like(logits).scatter_(1, target.unsqueeze(1), 1).bool()return maskdef _get_other_mask(logits, target):target = target.reshape(-1)mask = torch.ones_like(logits).scatter_(1, target.unsqueeze(1), 0).bool()return maskdef cat_mask(t, mask1, mask2):t1 = (t * mask1).sum(dim=1, keepdims=True)t2 = (t * mask2).sum(1, keepdims=True)rt = torch.cat([t1, t2], dim=1)return rt

(7)Masked Generative Distillation 

对于中间层的蒸馏,通过两层卷积恢复掩码部分

  self.generation = nn.Sequential(nn.Conv2d(teacher_channels, teacher_channels, kernel_size=3, padding=1),nn.ReLU(inplace=True), nn.Conv2d(teacher_channels, teacher_channels, kernel_size=3, padding=1))  def get_dis_loss(self, preds_S, preds_T):loss_mse = nn.MSELoss(reduction='sum')N, C, H, W = preds_T.shapedevice = preds_S.devicemat = torch.rand((N,1,H,W)).to(device)mat = torch.where(mat>1-self.lambda_mgd, 0, 1).to(device)masked_fea = torch.mul(preds_S, mat)new_fea = self.generation(masked_fea)dis_loss = loss_mse(new_fea, preds_T)/Nreturn dis_loss

(8)《Distilling Object Detectors with Fine-grained Feature Imitation》

"""
sup_feature :教师模型backbone出来之后的特征
stu_feature_adap:学生模型backbone出来之后和教师模型channel对齐之后的特征
mask_batch:存储每个图像中与锚框(anchors)具有重叠的区域的二进制掩码(mask)的列表,最后会将列表堆叠
"""
sup_loss = (torch.pow(sup_feature - stu_feature_adap, 2) * mask_batch).sum() / norms

同时也有和yoloV5结合的工作:yolov5使用知识蒸馏 

(9)《Distilling Object Detectors with Task Adaptive Regularization》

(10)《General Instance Distillation for Object Detection》

  • 物体附近的特征区域,甚至来自背景区域的判别补丁也有意义的知识都有利于知识的提炼。提出两个衡量指标:GI score 和 GI box

将分类分数的L1距离计算为GI score ,选择分数较高的盒子作为GI box,使用NMS去除重复的区域

  • feature map蒸馏

根据每个GI盒的不同尺寸从匹配的FPN层中裁剪特征

  • 关系蒸馏

使用欧几里得距离来度量实例的相关性,使用L1距离来迁移知识

  • 使用基于响应的蒸馏来处理头部

由于没有开源,后面两个我也是一笔带过,可以去阅读原论文;虽然官方没有提供代码,可以参考LD中对基于特征图的复现。

(11)《IMPROVE OBJECT DETECTION WITH FEATURE-BASED KNOWLEDGE DISTILLATION: TOWARDS》

  • 使用空间和通道注意力作为模仿方式
  • 引入非局部蒸馏捕获全局关系

(12)《Distilling Object Detectors via Decoupled Features》

(12)《Focal and Global Knowledge Distillation for Detectors》

这一篇论文被诟病的点就在于和上面这篇思想基本上是相同的,只不过上面这篇工作使用NonLocal模块捕获全局关系,而这一篇使用GCBlock。

  • 使用空间注意力分割前景和后景,从而得到前景和后景损失
  • 同样是空间和通道注意力损失
  • 同样是池化捕获关系损失

fgd.py - yzd-v/FGD - GitHub1s

《Distilling Object Detectors with Feature Richness 》

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/651601.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

深入浅出 diffusion(4):pytorch 实现简单 diffusion

1. 训练和采样流程 2. 无条件实现 import torch, time, os import numpy as np import torch.nn as nn import torch.optim as optim from torchvision.datasets import MNIST from torchvision import transforms from torch.utils.data import DataLoader from torchvision.…

LayoutInflater.inflate全面解读

方法解析 LayoutInflater.inflate() 是 Android 系统中用于将 XML 布局文件转换成相应的 View 的方法。在 Android 开发中,我们经常使用此方法来动态创建和填充布局。 public View inflate(LayoutRes int resource, Nullable ViewGroup root, boolean attachToRoo…

LVGL v9学习笔记 | 12 - 弧形控件的使用方法(arc)

一、arc控件 arc控件的API在lvgl/src/widgets/arc/lv_arc.h 中声明,以lv_arc_xxx命名。 arc控件由背景圆弧和前景圆弧组成,前景圆弧的末端有一个旋钮,前景圆弧可以被触摸调节。 1. 创建arc对象 /*** Create an arc object* @param parent pointer to an object, it w…

Pyecharts 风采:从基础到高级,打造炫酷象形柱状图的完整指南【第40篇—python:象形柱状图】

文章目录 引言安装PyechartsPyecharts象形柱状图参数详解1. Bar 类的基本参数2. 自定义图表样式3. 添加标签和提示框 代码实战:绘制多种炫酷象形柱状图进阶技巧:动态数据更新与交互性1. 动态数据更新2. 交互性设计 拓展应用:结合其他图表类型…

深度学习-使用Labelimg数据标注

数据标注是计算机视觉和机器学习项目中至关重要的一步,而使用工具进行标注是提高效率的关键。本文介绍了LabelImg,一款常用的开源图像标注工具。用户可以在图像中方便而准确地标注目标区域,为训练机器学习模型提供高质量的标注数据。LabelImg…

Unity中URP下逐顶点光照

文章目录 前言一、之前额外灯逐像素光照的数据准备好后,还有最后的处理二、额外灯的逐顶点光照1、逐顶点额外灯的光照颜色2、inputData.vertexLighting3、surfaceData.albedo 前言 在上篇文章中,我们分析了Unity中URP下额外灯,逐像素光照中聚…

vue3 codemirror关于 sql 和 json格式化的使用以及深入了解codemirror 使用json格式化提示错误的关键代码

文章目录 需求说明0、安装1. 导入js脚本2.配置3.html处使用4.js处理数据(1)json格式化处理(2)sql格式化处理 5. 解决问题1:json格式化错误提示报错(1)打开官网(2)打开官网&#xff0…

【机器学习笔记】1 线性回归

回归的概念 二分类问题可以用1和0来表示 线性回归(Linear Regression)的概念 是一种通过属性的线性组合来进行预测的线性模型,其目的是找到一条直线或者一个平面或者更高维的超平面,使得预测值与真实值之间的误差最小化&#x…

ppt背景图片怎么设置?让你的演示更加出彩!

PowerPoint是一款广泛应用于演示文稿制作的软件,而背景图片是演示文稿中不可或缺的一部分。一个好的背景图片能够提升演示文稿的整体效果,使观众更加关注你的演示内容。可是ppt背景图片怎么设置呢?本文将介绍ppt背景图片设置的三个方法&#…

数据库 sql select *from account where name=‘张三‘ 执行过程

select *from account where name张三分析上面语句的执行过程 用到了索引 由于是根据 1.name字段进行查询,所以先根据name张三’到name字段的二级索引中进行匹配查 找。但是在二级索引中只能查找到 Arm 对应的主键值 10。 2.由于查询返回的数据是*&#xff0c…

5.Hive表修改Location,一次讲明白

Hive表修改Loction 一、Hive中修改Location语句二、方案1 删表重建1. 创建表,写错误的Location2. 查看Location3. 删表4. 创建表,写正确的Location5. 查看Location 三、方案2 直接修改Location并恢复数据1.建表,指定错误的Location&#xff0…

【CSS】实现鼠标悬停图片放大的几种方法

1.背景图片放大 使用css设置背景图片大小100%&#xff0c;同时设置位置和过渡效果&#xff0c;然后使用&#xff1a;hover设置当鼠标悬停时修改图片大小&#xff0c;实现悬停放大效果。 <!DOCTYPE html> <html lang"en"> <head><meta charset…

###C语言程序设计-----C语言学习(4)#

前言&#xff1a;感谢老铁的浏览&#xff0c;希望老铁可以一键三连加个关注&#xff0c;您的支持和鼓励是我前进的动力&#xff0c;后续会分享更多学习编程的内容。现在开始今天的内容&#xff1a; 一. 主干知识的学习 1.字符型数据 &#xff08;1&#xff09;字符型常量 字…

Leetcode541反转字符串Ⅱ(java实现)

我们今天分享的题目是字符串反转的进阶版反转字符串Ⅱ。 我们首先来看题目描述&#xff1a; 乍一看题目&#xff0c;有种懵逼的感觉&#xff0c;不要慌&#xff0c;博主来带着你分析题目&#xff0c;题目要求&#xff1a; 1. 每隔2k个字符&#xff0c;就对2k字符中的前k个字符…

C++设计模式介绍:优雅编程的艺术

物以类聚 人以群分 文章目录 简介为什么有设计模式&#xff1f; 设计模式七大原则单一职责原则&#xff08;Single Responsibility Principle - SRP&#xff09;开放封闭原则&#xff08;Open/Closed Principle - OCP&#xff09;里氏替换原则&#xff08;Liskov Substitution …

MongoDB:从容器使用到 Mongosh、Python/Node.js 数据操作

文章目录 1. 容器与应用之间的关系介绍2. 使用 Docker 容器安装 MongoDB3. Mongosh 操作3.1 Mongosh 连接到 MongoDB3.2 基础操作与 CRUD 4. Python 操作 MongoDB5. Nodejs 操作 MongoDB参考文献 1. 容器与应用之间的关系介绍 MongoDB 的安装有时候并不是那么容易的&#xff0…

《HelloGitHub》第 94 期

兴趣是最好的老师&#xff0c;HelloGitHub 让你对编程感兴趣&#xff01; 简介 HelloGitHub 分享 GitHub 上有趣、入门级的开源项目。 https://github.com/521xueweihan/HelloGitHub 这里有实战项目、入门教程、黑科技、开源书籍、大厂开源项目等&#xff0c;涵盖多种编程语言 …

Redis6基础知识梳理~

初识NOSQL&#xff1a; NOSQL是为了解决性能问题而产生的技术&#xff0c;在最初&#xff0c;我们都是使用单体服务器架构&#xff0c;如下所示&#xff1a; 随着用户访问量大幅度提升&#xff0c;同时产生了大量的用户数据&#xff0c;单体服务器架构面对着巨大的压力 NOSQL解…

openssl3.2 - 测试程序的学习 - test\acvp_test.c

文章目录 openssl3.2 - 测试程序的学习 - test\acvp_test.c概述笔记要单步学习的测试函数备注END openssl3.2 - 测试程序的学习 - test\acvp_test.c 概述 openssl3.2 - 测试程序的学习 将test*.c 收集起来后, 就不准备看makefile和make test的日志参考了. 按照收集的.c, 按照…

换个思维方式快速上手UML和 plantUML——类图

和大多数朋友一样&#xff0c;Jeffrey 在一开始的时候也十分的厌烦软件工程的一系列东西&#xff0c;对工程化工具十分厌恶&#xff0c;觉得它繁琐&#xff0c;需要记忆很多没有意思的东西。 但是之所以&#xff0c;肯定有是因为。对工程化工具的不理解和不认可主要是基于两个逻…