1. Implementation Workflow
(1) Define the student and teacher models
(2) Define the feature distillation losses
- Mimic Loss
- CWD Loss
- MGD Loss
- Feature Loss
(3) Use hooks to capture the feature layers to distill
- Define the hook callback
- Register the hooks
- Collect the feature layers to distill
(4) Compute the feature distillation loss
(5) Compute the total loss and backpropagate
- Compute the total loss
- Backpropagate
(6) Save the distilled model
- Remove the hooks
- Save the distilled model
2. Code Implementation
(1) Define the student and teacher models
# Student model
model = torch.load(args.student_model, map_location=device)
# Teacher model
teacher_model = YoloBody(num_det=config.DET_NUM_CLASSES, num_seg=config.SEG_NUM_CLASSES, phi=args.phi, task="multi", use_aspp=False)
teacher_model.load_state_dict(torch.load(args.teacher_model, map_location=device)['model'])
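During distillation only the student is updated, so it helps to freeze the teacher explicitly. A minimal sketch, assuming the usual PyTorch setup (`device` already defined as above):

model = model.to(device)
teacher_model = teacher_model.to(device)
teacher_model.eval()                        # fix BN statistics / disable dropout
for p in teacher_model.parameters():
    p.requires_grad_(False)                 # no gradients flow into the teacher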
(2) Define the feature distillation losses
- Mimic Loss
import torch
import torch.nn as nn
import torch.nn.functional as F


class MimicLoss(nn.Module):
    """Plain MSE between student and teacher feature maps."""

    def __init__(self, channels_s, channels_t):
        super(MimicLoss, self).__init__()
        self.mse = nn.MSELoss()

    def forward(self, y_s, y_t):
        """Forward computation.

        Args:
            y_s (list): The student model predictions, each with shape (N, C, H, W).
            y_t (list): The teacher model predictions, each with shape (N, C, H, W).

        Return:
            torch.Tensor: The loss summed over all stages.
        """
        assert len(y_s) == len(y_t)
        losses = []
        for s, t in zip(y_s, y_t):
            assert s.shape == t.shape
            losses.append(self.mse(s, t))
        return sum(losses)
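A quick sanity check on dummy tensors (shapes are illustrative only; MimicLoss requires the student and teacher features to already have matching shapes):

feats_s = [torch.randn(2, 32, 40, 40), torch.randn(2, 64, 20, 20)]
feats_t = [torch.randn(2, 32, 40, 40), torch.randn(2, 64, 20, 20)]
mimic = MimicLoss(channels_s=[32, 64], channels_t=[32, 64])
print(mimic(feats_s, feats_t))  # scalar tensor: sum of per-stage MSE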
- CWD Loss
- Reference: [Knowledge Distillation] feature-based knowledge distillation - CWD (channel-wise knowledge distillation)
class CWDLoss(nn.Module):
    """PyTorch version of `Channel-wise Distillation for Semantic Segmentation
    <https://arxiv.org/abs/2011.13256>`_.
    """

    def __init__(self, channels_s, channels_t, tau=1.0):
        super(CWDLoss, self).__init__()
        self.tau = tau

    def forward(self, y_s, y_t):
        """Forward computation.

        Args:
            y_s (list): The student model predictions, each with shape (N, C, H, W).
            y_t (list): The teacher model predictions, each with shape (N, C, H, W).

        Return:
            torch.Tensor: The loss summed over all stages.
        """
        assert len(y_s) == len(y_t)
        losses = []
        for s, t in zip(y_s, y_t):
            assert s.shape == t.shape
            N, C, H, W = s.shape
            # Softmax over the spatial dimension: every channel of every sample
            # becomes a probability distribution over its H*W locations.
            softmax_pred_T = F.softmax(t.view(-1, W * H) / self.tau, dim=1)  # [N*C, H*W]
            logsoftmax = torch.nn.LogSoftmax(dim=1)
            # KL divergence between teacher and student channel distributions.
            cost = torch.sum(
                softmax_pred_T * logsoftmax(t.view(-1, W * H) / self.tau) -
                softmax_pred_T * logsoftmax(s.view(-1, W * H) / self.tau)) * (self.tau ** 2)
            losses.append(cost / (C * N))
        return sum(losses)
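In formula form, each stage of the loop above computes the channel-wise KL divergence from the CWD paper, with $\phi$ the spatial softmax and $\tau$ the temperature:

$$\phi\left(y_{c,i}\right)=\frac{\exp\left(y_{c,i}/\tau\right)}{\sum_{j=1}^{H\cdot W}\exp\left(y_{c,j}/\tau\right)},\qquad \mathcal{L}_{\text{CWD}}=\frac{\tau^{2}}{N\cdot C}\sum_{n=1}^{N}\sum_{c=1}^{C}\sum_{i=1}^{H\cdot W}\phi\left(y^{T}_{n,c,i}\right)\log\frac{\phi\left(y^{T}_{n,c,i}\right)}{\phi\left(y^{S}_{n,c,i}\right)}$$

The student is pushed to match the teacher's per-channel spatial distributions rather than raw activation values.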
- MGD Loss
- Reference: [Knowledge Distillation] feature-based knowledge distillation - MGD (masked generative distillation)
class MGDLoss(nn.Module):
    def __init__(self, channels_s, channels_t, alpha_mgd=0.00002, lambda_mgd=0.65):
        super(MGDLoss, self).__init__()
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.alpha_mgd = alpha_mgd
        self.lambda_mgd = lambda_mgd
        # One generator per stage. nn.ModuleList (instead of a plain Python list)
        # registers the parameters so they are trained and saved with the module.
        self.generation = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(channel, channel, kernel_size=3, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(channel, channel, kernel_size=3, padding=1))
            for channel in channels_t]).to(device)

    def forward(self, y_s, y_t):
        """Forward computation.

        Args:
            y_s (list): The student model predictions, each with shape (N, C, H, W).
            y_t (list): The teacher model predictions, each with shape (N, C, H, W).

        Return:
            torch.Tensor: The loss summed over all stages.
        """
        assert len(y_s) == len(y_t)
        losses = []
        for idx, (s, t) in enumerate(zip(y_s, y_t)):
            assert s.shape == t.shape
            losses.append(self.get_dis_loss(s, t, idx) * self.alpha_mgd)
        return sum(losses)

    def get_dis_loss(self, preds_S, preds_T, idx):
        loss_mse = nn.MSELoss(reduction='sum')
        N, C, H, W = preds_T.shape
        device = preds_S.device
        # Random spatial mask: each location is zeroed with probability lambda_mgd.
        mat = torch.rand((N, 1, H, W), device=device)
        mat = torch.where(mat > 1 - self.lambda_mgd, 0, 1)
        masked_fea = torch.mul(preds_S, mat)
        # The generator must reconstruct the full teacher feature map from the
        # masked student feature map.
        new_fea = self.generation[idx](masked_fea)
        return loss_mse(new_fea, preds_T) / N
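A minimal sketch exercising MGDLoss on dummy, already channel-aligned features (shapes illustrative; in the full pipeline the channel alignment is done by FeatureLoss below):

device = 'cuda' if torch.cuda.is_available() else 'cpu'
mgd = MGDLoss(channels_s=[32], channels_t=[32], lambda_mgd=0.65)
s = torch.randn(2, 32, 40, 40, device=device)
t = torch.randn(2, 32, 40, 40, device=device)
print(mgd([s], [t]))  # scalar; on average 65% of spatial positions are masked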
- Feature Loss
class FeatureLoss(nn.Module):
    def __init__(self, channels_s, channels_t, distiller='cwd', loss_weight=1.0):
        super(FeatureLoss, self).__init__()
        self.loss_weight = loss_weight
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        # 1x1 convolutions that project student channels onto teacher channels.
        self.align_module = nn.ModuleList([
            nn.Conv2d(channel, tea_channel, kernel_size=1, stride=1, padding=0)
            for channel, tea_channel in zip(channels_s, channels_t)]).to(device)
        # Affine-free BatchNorm applied to both feature sets before the loss;
        # nn.ModuleList registers the running statistics as buffers.
        self.norm = nn.ModuleList([
            nn.BatchNorm2d(tea_channel, affine=False)
            for tea_channel in channels_t]).to(device)

        if distiller == 'mimic':
            self.feature_loss = MimicLoss(channels_s, channels_t)
        elif distiller == 'mgd':
            self.feature_loss = MGDLoss(channels_s, channels_t)
        elif distiller == 'cwd':
            self.feature_loss = CWDLoss(channels_s, channels_t)
        else:
            raise NotImplementedError

    def forward(self, y_s, y_t):
        assert len(y_s) == len(y_t)
        tea_feats = []
        stu_feats = []
        for idx, (s, t) in enumerate(zip(y_s, y_t)):
            s = self.align_module[idx](s)  # match the teacher's channel count
            s = self.norm[idx](s)
            t = self.norm[idx](t)
            tea_feats.append(t)
            stu_feats.append(s)
        loss = self.feature_loss(stu_feats, tea_feats)
        return self.loss_weight * loss
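FeatureLoss is the class used directly in training: it first aligns the narrower student channels to the teacher's, then delegates to the chosen distiller. A minimal sketch on dummy two-stage features:

device = 'cuda' if torch.cuda.is_available() else 'cpu'
floss = FeatureLoss(channels_s=[16, 32], channels_t=[32, 64], distiller='cwd')
y_s = [torch.randn(2, 16, 40, 40, device=device), torch.randn(2, 32, 20, 20, device=device)]
y_t = [torch.randn(2, 32, 40, 40, device=device), torch.randn(2, 64, 20, 20, device=device)]
print(floss(y_s, y_t))  # scalar tensor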
(3) Use hooks to capture the feature layers to distill
- Define the hook callback
# Intermediate feature maps captured by the forward hooks, keyed by layer name
activation = {}

def get_activation(name):
    def hook(model, inputs, outputs):
        activation[name] = outputs
    return hook
- Register the hooks

def get_hooks():
    hooks = []
    # Student model
    # Uncomment to inspect the available module names:
    # for k, v in teacher_model._modules.items():
    #     print(f"tmodel._modules_k: {k}; v: {v}")
    hooks.append(model._modules['backbone'].stem.register_forward_hook(get_activation("s_stem")))
    hooks.append(model._modules['backbone'].dark2.register_forward_hook(get_activation("s_dark2")))
    hooks.append(model._modules['backbone'].dark3.register_forward_hook(get_activation("s_dark3")))
    hooks.append(model._modules['backbone'].dark4.register_forward_hook(get_activation("s_dark4")))
    hooks.append(model._modules['backbone'].dark5.register_forward_hook(get_activation("s_dark5")))
    # Teacher model (reached through an extra 'module' level, e.g. a DataParallel wrapper)
    hooks.append(teacher_model._modules['module'].backbone.stem.register_forward_hook(get_activation("t_stem")))
    hooks.append(teacher_model._modules['module'].backbone.dark2.register_forward_hook(get_activation("t_dark2")))
    hooks.append(teacher_model._modules['module'].backbone.dark3.register_forward_hook(get_activation("t_dark3")))
    hooks.append(teacher_model._modules['module'].backbone.dark4.register_forward_hook(get_activation("t_dark4")))
    hooks.append(teacher_model._modules['module'].backbone.dark5.register_forward_hook(get_activation("t_dark5")))
    return hooks
- Collect the feature layers to distill

stu_features = [activation["s_stem"], activation["s_dark2"], activation["s_dark3"],
                activation["s_dark4"], activation["s_dark5"]]
tea_features = [activation["t_stem"], activation["t_dark2"], activation["t_dark3"],
                activation["t_dark4"], activation["t_dark5"]]
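Note that `activation` is only filled during a forward pass, so both models must process the current batch before the two lists above are built. A sketch of the order inside one training step, assuming an `images` batch already on the device:

hooks = get_hooks()
with torch.no_grad():
    teacher_outputs = teacher_model(images)  # fills activation["t_*"], no teacher gradients
student_outputs = model(images)              # fills activation["s_*"]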
(4) Compute the feature distillation loss
# Instantiate the feature distillation loss
channels_s = [16, 32, 64, 128, 256]   # student channel widths of the five hooked stages
channels_t = [32, 64, 128, 256, 512]  # teacher channel widths of the five hooked stages
distill_feat_type = 'mimic'
distill_loss = FeatureLoss(channels_s=channels_s, channels_t=channels_t, distiller=distill_feat_type)

# Compute the distillation loss
distill_weight = 1
dfea_loss = distill_loss(stu_features, tea_features) * distill_weight
print('---------dfea_loss---------- :', dfea_loss)
(5) Compute the total loss and backpropagate
- Compute the total loss
Bev_loss = (Bev_det_distill_loss * config.Ratio_det + Bev_seg_distill_loss * config.Ratio_seg) \
           * (1 - distill_.mtl_feature_alpha) + dfea_loss * distill_.mtl_feature_alpha
- Backpropagate
Bev_loss.backward()
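Because FeatureLoss carries trainable parameters of its own (the 1x1 alignment convolutions, plus the MGD generators when distiller='mgd'), they should be handed to the optimizer together with the student's parameters; otherwise the alignment layers never learn. A hedged sketch of where these lines sit in the step (optimizer type and hyperparameters are illustrative):

optimizer = torch.optim.SGD(
    list(model.parameters()) + list(distill_loss.parameters()),
    lr=0.01, momentum=0.937)

optimizer.zero_grad()   # before computing Bev_loss
Bev_loss.backward()
optimizer.step()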
(6) Save the distilled model
- Remove the hooks
# -------- Remove the hooks first, otherwise saving the model will fail ---------#
for hook in hooks:
    hook.remove()
- Save the distilled model
torch.save(model, distill_path)
print('save distill model')