本文使用RKD实现对deeplabv3+模型的蒸馏;与上一篇KD蒸馏的方法有所不同,RKD是对展平层的特征做蒸馏,蒸馏的loss分为二阶的距离损失Distance-wise Loss和三阶的角度损失Angle-wise Loss。
一、RKD简介
RKD算法的核心是以教师模型的多个输出为结构单元,取代传统蒸馏学习中以教师模型的单个输出的方式,利用多输出组合成结构单元,更能体现出教师模型的结构化特征,使得学生模型得到更好的指导。
关系型蒸馏学习的损失函数如下,其中t1,t2…tn表示教师模型的多个输出,s1,s2…sn表示学生模型的多个输出,L表示计算两者之间的距离。与传统的蒸馏学习不同,关系型蒸馏学习的损失函数中还有一个构件结构信息的函数。可以使得学生模型学到教师模型中更加高效的信息表征能力。本文提出了两种表征结构信息的损失:距离蒸馏损失和角度蒸馏损失。
距离蒸馏损失:
通过对每个batch中的样本进行两两距离计算,最终形成一个batch*batch大小的关系型结构输出。最终学生模型通过学习教师模型的结构输出,实现蒸馏学习。
角度蒸馏损失:
基于角度的蒸馏损失,通过对每个batch中的样本三三样本,计算两个角度,最终形成一个batchbatchbatch大小的关系型结构输出。最终学生模