【知识蒸馏】多任务模型 feature-based 知识蒸馏实战

一、实现流程

(1)定义学生和教师模型
(2)定义特征蒸馏损失

  • Mimic Loss
  • CWD Loss
  • MGD Loss
  • Feature Loss

(3)使用hook获取需要蒸馏的特征层

  • 定义回调函数
  • 使用hook函数
  • 获取需要蒸馏的挺特征层

(4)计算特征蒸馏损失
(5)计算总损失,反向传播

  • 计算总损失
  • 反向传播

(6)保存蒸馏模型

  • 移除hook
  • 保存蒸馏模型

二、代码实现

(1)定义学生和教师模型

# 学生模型
model = torch.load(args.student_model, map_location=device)
# 教师模型
teacher_model = YoloBody(num_det=config.DET_NUM_CLASSES, num_seg=config.SEG_NUM_CLASSES, phi=args.phi, task="multi", use_aspp=False)
teacher_model.load_state_dict(torch.load(args.teacher_model, map_location=device)['model'])

(2)定义特征蒸馏损失

  • Mimic Loss
class MimicLoss(nn.Module):def __init__(self, channels_s, channels_t):super(MimicLoss, self).__init__()device = 'cuda' if torch.cuda.is_available() else 'cpu'self.mse = nn.MSELoss()def forward(self, y_s, y_t):"""Forward computation.Args:y_s (list): The student model prediction withshape (N, C, H, W) in list.y_t (list): The teacher model prediction withshape (N, C, H, W) in list.Return:torch.Tensor: The calculated loss value of all stages."""assert len(y_s) == len(y_t)losses = []for idx, (s, t) in enumerate(zip(y_s, y_t)):assert s.shape == t.shapelosses.append(self.mse(s, t))loss = sum(losses)return loss
  • CWD Loss
  • 参考:【知识蒸馏】feature-based 知识蒸馏 - - CWD(channel-wise knowledge dissillation)
class CWDLoss(nn.Module):"""PyTorch version of `Channel-wise Distillation for Semantic Segmentation.<https://arxiv.org/abs/2011.13256>`_."""def __init__(self, channels_s, channels_t,tau=1.0):super(CWDLoss, self).__init__()self.tau = taudef forward(self, y_s, y_t):"""Forward computation.Args:y_s (list): The student model prediction withshape (N, C, H, W) in list.y_t (list): The teacher model prediction withshape (N, C, H, W) in list.Return:torch.Tensor: The calculated loss value of all stages."""assert len(y_s) == len(y_t)losses = []for idx, (s, t) in enumerate(zip(y_s, y_t)):assert s.shape == t.shapeN, C, H, W = s.shape# normalize in channel diemensionsoftmax_pred_T = F.softmax(t.view(-1, W * H) / self.tau, dim=1)  # [N*C, H*W]logsoftmax = torch.nn.LogSoftmax(dim=1)cost = torch.sum(softmax_pred_T * logsoftmax(t.view(-1, W * H) / self.tau) - softmax_pred_T * logsoftmax(s.view(-1, W * H) / self.tau)) * (self.tau ** 2)losses.append(cost / (C * N))loss = sum(losses)return loss
  • MGD Loss
    参考:【知识蒸馏】feature-based 知识蒸馏 - - MGD(mask generative dissillation)
class MGDLoss(nn.Module):def __init__(self, channels_s, channels_t, alpha_mgd=0.00002, lambda_mgd=0.65):super(MGDLoss, self).__init__()device = 'cuda' if torch.cuda.is_available() else 'cpu'self.alpha_mgd = alpha_mgdself.lambda_mgd = lambda_mgdself.generation = [nn.Sequential(nn.Conv2d(channel, channel, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(channel, channel, kernel_size=3, padding=1)).to(device) for channel in channels_t]def forward(self, y_s, y_t):"""Forward computation.Args:y_s (list): The student model prediction withshape (N, C, H, W) in list.y_t (list): The teacher model prediction withshape (N, C, H, W) in list.Return:torch.Tensor: The calculated loss value of all stages."""assert len(y_s) == len(y_t)losses = []for idx, (s, t) in enumerate(zip(y_s, y_t)):assert s.shape == t.shapelosses.append(self.get_dis_loss(s, t, idx) * self.alpha_mgd)loss = sum(losses)return lossdef get_dis_loss(self, preds_S, preds_T, idx):loss_mse = nn.MSELoss(reduction='sum')N, C, H, W = preds_T.shapedevice = preds_S.devicemat = torch.rand((N, 1, H, W)).to(device)mat = torch.where(mat > 1 - self.lambda_mgd, 0, 1).to(device)masked_fea = torch.mul(preds_S, mat)new_fea = self.generation[idx](masked_fea)dis_loss = loss_mse(new_fea, preds_T) / Nreturn dis_loss
  • Feature Loss
 class FeatureLoss(nn.Module):def __init__(self, channels_s, channels_t, distiller='cwd', loss_weight=1.0):super(FeatureLoss, self).__init__()self.loss_weight = loss_weightdevice = 'cuda' if torch.cuda.is_available() else 'cpu'self.align_module = nn.ModuleList([nn.Conv2d(channel, tea_channel, kernel_size=1, stride=1, padding=0).to(device)for channel, tea_channel in zip(channels_s, channels_t)])self.norm = [nn.BatchNorm2d(tea_channel, affine=False).to(device)for tea_channel in channels_t]if distiller == 'mimic':self.feature_loss = MimicLoss(channels_s, channels_t)elif distiller == 'mgd':self.feature_loss = MGDLoss(channels_s, channels_t)elif distiller == 'cwd':self.feature_loss = CWDLoss(channels_s, channels_t)else:raise NotImplementedErrordef forward(self, y_s, y_t):assert len(y_s) == len(y_t)tea_feats = []stu_feats = []for idx, (s, t) in enumerate(zip(y_s, y_t)):s = self.align_module[idx](s)s = self.norm[idx](s)t = self.norm[idx](t)tea_feats.append(t)stu_feats.append(s)loss = self.feature_loss(stu_feats, tea_feats)return self.loss_weight * loss

(3)使用hook获取需要蒸馏的特征层

  • 定义回调函数
activation = {}
def get_activation(name):def hook(model, inputs, outputs):activation[name] = outputsreturn hook
  • 使用hook函数
def get_hooks():hooks = []# S-model#for k, v in teacher_model._modules.items():#     print(f"tmodel._modules_k: {k}; v: {v}")hooks.append(model._modules['backbone'].stem.register_forward_hook(get_activation("s_stem")))hooks.append(model._modules['backbone'].dark2.register_forward_hook(get_activation("s_dark2")))hooks.append(model._modules['backbone'].dark3.register_forward_hook(get_activation("s_dark3")))hooks.append(model._modules['backbone'].dark4.register_forward_hook(get_activation("s_dark4")))hooks.append(model._modules['backbone'].dark5.register_forward_hook(get_activation("s_dark5")))# T-modelhooks.append(teacher_model._modules['module'].backbone.stem.register_forward_hook(get_activation("t_stem")))hooks.append(teacher_model._modules['module'].backbone.dark2.register_forward_hook(get_activation("t_dark2")))hooks.append(teacher_model._modules['module'].backbone.dark3.register_forward_hook(get_activation("t_dark3")))hooks.append(teacher_model._modules['module'].backbone.dark4.register_forward_hook(get_activation("t_dark4")))hooks.append(teacher_model._modules['module'].backbone.dark5.register_forward_hook(get_activation("t_dark5")))return hooks
  • 获取需要蒸馏的挺特征层
stu_features = [activation["s_stem"], activation["s_dark2"], activation["s_dark3"],activation["s_dark4"], activation["s_dark5"]]
tea_features = [activation["t_stem"],activation["t_dark2"],activation["t_dark3"],activation["t_dark4"],activation["t_dark5"]]

(4)计算特征蒸馏损失

# 实例化特征蒸馏损失类
channels_s = [16,32,64,128,256]
channels_t = [32, 64, 128, 256,512]
distill_feat_type = 'mimic'
distill_loss = FeatureLoss(channels_s=channels_s, channels_t=channels_t,distiller=distill_feat_type)# 计算蒸馏损失
distill_weight = 1
dfea_loss = distill_loss(stu_features,tea_features)*distill_weight
print('---------dfea_loss---------- :', dfea_loss)

(5)计算总损失,反向传播

  • 计算总损失
Bev_loss = (Bev_det_distill_loss * config.Ratio_det + Bev_seg_distill_loss * config.Ratio_seg)*(1-distill_.mtl_feature_alpha) +dfea_loss* distill_.mtl_feature_alpha
  • 反向传播
Bev_loss.backward()

(6)保存蒸馏模型

  • 移除hook
 # -------- 移除hook,不然保存模型会报错 ---------#
for hook in hooks:hook.remove()    
  • 保存蒸馏模型
torch.save(model, distill_path)
print('save distill model')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/16927.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

太狠了,凌晨5点面试。。

(关注数据结构和算法&#xff0c;了解更多新知识) 网上看到一网友发文说收到面试邀请&#xff0c;面试时间竟然是早晨5点&#xff0c;这是要猝死的节奏。有的网友说应该是下午 5 点&#xff0c;如果是下午 5 点直接写下午 5 点就行了&#xff0c;或者写 17 点也行&#xff0c;直…

[企业级高效系统工具]短视频矩阵系统 ,一站式管理新媒体账号,短视频精准获客,一键管理上千个短视频账。

一、做项目为什么要搭建一款属于自己的系统&#xff1f; 在讲这个短视频矩阵系统前&#xff0c;我们聊聊做项目的时候为什么要搭建一款属于自己的系统。 1.拥有自己的系统&#xff0c;就意味着你可以随时随地进行源码部署和更新。你的项目就能紧跟时代潮流&#xff0c;始终保持…

具身人工智能:人工智能机器人如何感知世界

什么是具身人工智能 虽然近年来机器人在智能城市、工厂和家庭中大量出现,但我们大部分时间都在与由传统手工算法控制的机器人互动。这些机器人的目标很狭隘,很少从周围环境中学习。相比之下,能够与物理环境互动并从中学习的人工智能 (AI) 代理(机器人、虚拟助手或其他智能系…

强化学习——学习笔记

一、什么是强化学习&#xff1f; 强化学习 (Reinforcement Learning, RL) 是一种通过与环境交互来学习决策策略的机器学习方法。它的核心思想是让智能体 (Agent) 在执行动作 (Action)、观察环境 (Environment) 反馈的状态 (State) 和奖励 (Reward) 的过程中&#xff0c;学习到…

【每日随笔】小人畏威不怀德 , 君子畏德不畏威 ( 先礼后兵 )

文章目录 一、小人畏威不怀德1、小人畏威不怀德2、小人场景一3、小人场景二 二、君子畏德不畏威三、先礼后兵 一、小人畏威不怀德 1、小人畏威不怀德 如果 友善 的对待 小人 , 这种人 认知低 且 素质差 , 小人 会将你的 " 友善 " 理解为 " 屈服 " , 他会认…

单片机方案开发个性定制

酷得智能是玩具企业合作方案商&#xff0c;致力于为玩具企业提供一站式的智能化解决方案。我们拥有丰富的行业经验和技术实力&#xff0c;能够根据客户的需求和市场趋势&#xff0c;为其量身定制最适合的智能玩具产品和解决方案。 主营业务&#xff1a; 东莞市酷得智能科技有限…

Sping源码(九)—— Bean的初始化(非懒加载)— ConversionService

序言 经过前面一系列的加载、解析等准备工作&#xff0c;此刻refresh方法的执行已经来到了尾声&#xff0c;接下来我们用几篇文章着重的介绍一下Bean的初始化 代码 着重看refresh()主流程中的finishBeanFactoryInitialization()方法。 finishBeanFactoryInitialization 方法…

JAVA开发 利用代码生成奖状

通过java实现用模板生成奖状 1、图片模板2、实现代码3、生成模板 1、图片模板 2、实现代码 import javax.imageio.ImageIO; import java.awt.*; import java.awt.font.TextAttribute; import java.awt.image.BufferedImage; import java.io.File; import java.io.IOException;…

CompositeDisposable作用

CompositeDisposable 是一个在 RxJava 中常用的类&#xff0c;它用于管理多个 Disposable 对象。Disposable 是 RxJava 中用于管理订阅&#xff08;subscription&#xff09;的接口&#xff0c;它允许我们取消订阅以避免内存泄漏和不必要的资源消耗。 CompositeDisposable 的主…

三坐标测量机在汽车零部件质量控制中的应用

高质量的零部件能够确保汽车的性能达到设计标准&#xff0c;包括动力性能、燃油效率、操控稳定性等&#xff0c;从而提供更好的驾驶体验&#xff0c;建立消费者对汽车品牌的信任&#xff1b;也推动了汽车行业的技术创新&#xff0c;制造商不断研发新材料、新工艺&#xff0c;以…

Java 登录错误次数限制,用户禁登1小时

手机号验证码登录&#xff0c;验证码输入错误次数超5次封禁 Overridepublic boolean checkCaptcha(String phoneNum, String captcha) {String codeNum (String) redisTemplate.opsForValue().get(UserCacheNames.USER_CAPTCHA phoneNum);if (codeNum null) {throw new Wan…

怎么图片转excel表格免费?介绍三个方法

怎么图片转excel表格免费&#xff1f;在日常工作中&#xff0c;我们经常需要将图片中的表格数据转化为可编辑的Excel格式。幸运的是&#xff0c;市面上有多款软件支持这一功能&#xff0c;并且部分软件还提供免费使用的选项。本文将为您详细介绍几款可以免费将图片转换为Excel表…

Java 异步编程——Java内置线程调度器(Executor 框架)

文章目录 Java多线程的两级调度模型Executor 框架Executor 框架的组成概念Executor 框架中任务执行的两个阶段&#xff1a;任务提交和任务执行 在 Java1.5 以前&#xff0c;开发者必须手动实现自己的线程池&#xff1b;从 Java1.5 开始&#xff0c;Java 内部提供了线程池。 在J…

Python代码:十九、列表的长度

1、题目 描述&#xff1a; 牛牛学会了使用list函数与split函数将输入的连续字符串封装成列表&#xff0c;你能够帮他使用len函数统计一些公输入了多少字符串&#xff0c;列表中有多少元素吗&#xff1f; 输入描述&#xff1a; 输入一行多个字符串&#xff0c;字符串之间通过…

基于Java+SpringBoot+Mybaties-plus+Vue+elememt + uniapp 驾校预约平台 的设计与实现

一.项目介绍 系统角色&#xff1a;管理员、教练、学员 小程序(仅限于学员注册、登录)&#xff1a; 查看管理员发布的公告信息 查看管理员发布的驾校信息 查看所有教练信息、预约(需教练审核)、评论、收藏喜欢的教练 查看管理员发布的考试信息、预约考试(需管理…

流媒体内网穿透/组网/视频协议转换EasyNTS上云网关如何更改密码?

EasyNTS上云网关的主要作用是解决异地视频共享/组网/上云的需求&#xff0c;网页对域名进行添加映射时&#xff0c;添加成功后会生成一个外网访问地址&#xff0c;在浏览器中输入外网访问地址&#xff0c;即可查看内网应用。无需开放端口&#xff0c;EasyNTS上云网关平台会向Ea…

什么是逆向抓包?是通过什么进行操作的?怎么解决?

前言 逆向抓包是一个技术过程&#xff0c;它涉及使用特定的网络分析工具来捕获和分析网络通信中的数据包。这个过程的目的通常是研究和理解网络协议、应用程序或系统的内部工作原理。以下是关于逆向抓包的详细解释及操作方式&#xff1a; 什么是逆向抓包&#xff1f; 逆向抓…

【linux】深入了解线程池:基本概念与代码实例(C++)

文章目录 1. 前言1.1 概念1.2 应用场景1.3 线程池的种类1.4 线程池的通常组成 2. 代码示例2.1 log.hpp2.2 lockGuard.hpp① pthread_mutex_t 2.3 Task.hpp2.4 thread.hpp2.5 threadPool.hpp① 基本框架② 成员变量③ 构造函数④ 其余功能函数&#xff1a; main.cc结果演示 完整…

数组与指针声明小问题

1、int *p &a; 是 C 语言中的一条语句&#xff0c;它涉及指针的声明和初始化。让我们逐步解释这一行代码的含义&#xff1a; int *p&#xff1a;这是一个指针声明。它声明了一个名为 p 的变量&#xff0c;该变量是一个指向 int 类型数据的指针。 &a&#xff1a;这是取…

动态规划-似包非包问题

组合总和 Ⅳ&#xff08;377&#xff09; 题目描述&#xff1a; 状态表示&#xff1a; 我们看到这题发现有一个限制条件就是目标整数target并且此时数组中的数字是可以重复选择的&#xff0c;这时候不难联想到前面学习的完全背包问题&#xff0c;这题好像符合完全背包问题的…