基于知识蒸馏的两阶段去雨去雪去雾模型学习记录(二)之知识收集阶段

前面学习了模型的构建与训练过程,然而在实验过程中,博主依旧对数据集与模型之间的关系有些疑惑,首先是论文说这是一个混合数据集,但事实上博主在实验时是将三个数据集分开的,那么在数据读取时是如何混合的呢,是每个epoch使用同一个数据集,下一个epoch再换数据集,还是再epoch中随机取数据集中的一部分。
此外,教师模型总共有三个,其模型构造是完全相同的,不同之处在于三个教师模型是在不同的数据集训练得到的,即其权重参数是固定的,那么在训练过程中,从代码来看,原始的教师网络权重是不改变的,那么说如何更新学生网络呢?带着这些疑问,开始今天的学习。

数据集加载

首先需要明确的是数据集加载时是将三个数据集进行了合并,只不过会按照三个数据集进行区别,即生成list形式。train_loader的相关参数设置如下:

在这里插入图片描述

模型训练

模型的训练分为两个阶段,分别是知识收集阶段与知识检验阶段,即knowlwdge collect(kc)knowledge exam(ke)两阶段。

在这里插入图片描述

在开始前,需要声明必须要将batch-size设置为3以上,否则会无法加载数据集
首先是知识收集阶段:
声明损失函数,这里的损失函数有两个,分别是L1损失与通过VGG网络计算的软损失(SCRLoss)

criterion_l1, criterion_scr, _ = criterions

在这里插入图片描述

模型开启traineval,关于两者的区别:

model.train()的作用是启用 Batch NormalizationDropout。在train模式,Dropout层会按照设定的参数p设置保留激活单元的概率,如keep_prob=0.8,Batch Normalization层会继续计算数据的meanvar并进行更新。
model.eval()的作用是不启用 Batch NormalizationDropout。在eval模式下,Dropout层会让所有的激活单元都通过,而Batch Normalization层会停止计算和更新meanvar,直接使用在训练阶段已经学出的meanvar值。在使用model.eval()时就是将模型切换到测试模式,在这里,模型就不会像在训练模式下一样去更新权重。
但是需要注意的是model.eval()不会影响各层的梯度计算行为,即会和训练模式一样进行梯度计算和存储,只是不进行反向传播。

model.train()#  model开启train
ckt_modules.train()
for teacher_network in teacher_networks:#为教师网络开启eval()teacher_network.eval()

随后便进入核心代码模块了:这里包含模型运算,特征映射,损失计算等过程
这里我们对应论文的创新点来看代码。
首先是进度条加载,这里是对数据集加载train_load的封装

pBar = tqdm(train_loader, desc='Training')

遍历数据,判断数据是否为空,这里曾经困扰过博主一段时间,因为每次遍历时target_image都为空,只要将batch-size设置为3以上即可。

for target_images, input_images in pBar:if target_images is None: continuetarget_images = target_images.cuda()input_images = [images.cuda() for images in input_images]preds_from_teachers = []

可以看到,此时已经将输入图像,目标图像转换为tensor格式,其中input_imageslist形式,每张图像为torch.Size([1, 3, 224, 224])

在这里插入图片描述

而target_images为完全为tensor格式,shape为torch.Size([3, 3, 224, 224])

在这里插入图片描述

简要描述知识收集阶段

teacher_networks即为教师网络列表,单个的教师网络模型与学生网络是相同的,将数据输入教师网络时,由于需要使用教师网络的中间特征,因此return_feat为True,最终的输出结果为预测结果图与中间特征图,预测结果图会作为 “真值” 来训练学生网络,并计算软损失,中间特征图会与学生网络进行映射到同一特征域来进行特征转移,并将教师网络的预测结果与学生网络的预测结果求SCRLoss。

preds_from_teachers = []
features_from_each_teachers = []
with torch.no_grad():
for i in range(len(teacher_networks)):preds, features = teacher_networks[i](input_images[i], return_feat=True)preds_from_teachers.append(preds)features_from_each_teachers.append(features)		

随后将图像输入教师模型,教师模型不更新权重,只是用模型输出的特征来帮助学生网络来训练,称为软损失。核心代码如下:

preds, features = teacher_networks[i](input_images[i], return_feat=True)

将图像 i 输入对应的教师网络 i,这里的i指的是教师网络的索引,这里博主开始曾经有过疑惑,此时的batch_size为3,刚好与教师网络数量对应,因此可以使用该网络,那如果batch_size为6,9时呢,后面的岂不是都无法输入模型了吗,随后博主将batch_size改为6,发现此时的input_image依旧是list形式,但每个list中的内容已经发生了改变,可以看到其是按照不同的数据集类型做了区分,这就是为何input_image要使用listtarget_imagetensor的原因了。现在之前的疑惑也就消失了。
在这里插入图片描述

随后获得输出结果pred,即预测结果,也就是恢复后的图像。可以看到其与输入图像的维度是一致的,对于第一个网络的第一组输入图像,都为:torch.Size([3, 3, 224, 224])

在这里插入图片描述
而返回的中间特征图像如图所示,可以看到输出的不同大小的特征图,总共有4组,即4组不同大小的特征图,每组3张图像,通道数,宽高则不相同。
第 1 组数据集(教师网络)的中间特征图:

在这里插入图片描述
第 3 组数据集(教师网络)的中间特征图:

在这里插入图片描述

随后经过三个网络模型的运算,将结果加入列表:

preds_from_teachers.append(preds)
features_from_each_teachers.append(features)

在这里插入图片描述
在这里插入图片描述

随后将教师网络的预测值转换为tensor格式,因为在最终学生网络的输出是tensor

preds_from_teachers = torch.cat(preds_from_teachers)

原本list变为tensor
在这里插入图片描述
接下来这段是对feature按照特征图大小进行分组,现在的特征图是按照数据集划分为3组,为方便后面做特征映射,将其按照特征图大小分为四组。

for layer in range(len(features_from_each_teachers[0])):features_from_teachers.append([features_from_each_teachers[i][layer] for i in range(len(teacher_networks))])

在这里插入图片描述

随后便是将输入图像输入学生网络输出结果与中间特征图,这里是不区分数据集的,完全是混合的

preds_from_student, features_from_student = model(torch.cat(input_images), return_feat=True)

由于博主将batch设置为6会报显存溢出,因此这里改为4,可以看到中间特征图依旧是四组,不过每组的第一个值由6变为了4,其余都没有改变。
可以看到list为4组,代表4组不同尺度特征图,每组里面又有一个list,每个list中包含不同数据集(教师网络的特征图)分别是2,1,1。

在这里插入图片描述
同理输出结果也是由6变4。

在这里插入图片描述

CKT模块(特征转移)

随后便是中间特征图映射了,其过程其实也很简单,即将教师网络特征如与学生网络特征图同时输入CKT模型中,并获得输出结果,将输出结果做损失即可。
在这里插入图片描述

PFE_loss, PFV_loss = 0., 0.
for i, (s_features, t_features) in enumerate(zip(features_from_student, features_from_teachers)):t_proj_features, t_recons_features, s_proj_features = ckt_modules[i](t_features, s_features)PFE_loss += criterion_l1(s_proj_features, torch.cat(t_proj_features))PFV_loss += 0.05 * criterion_l1(torch.cat(t_recons_features), torch.cat(t_features))

可以看到输入的教师网络特征与学生网络特征也不是相同格式的:

在这里插入图片描述
输入值:
经过遍历后,学生网络的特征图分为四组,分别对应不同尺度的特征图,但没有区分数据集,因为本身学生网络就是不区分数据集的。
在这里插入图片描述
而教师网络却是list形式,每个数据集分别对应2,1,1个图像数量
在这里插入图片描述
CKT网络定义:

class CKTModule(nn.Module):def __init__(self, channel_t, channel_s, channel_h, n_teachers):super().__init__()self.teacher_projectors = TeacherProjectors(channel_t, channel_h, n_teachers)self.student_projector = StudentProjector(channel_s, channel_h)def forward(self, teacher_features, student_feature):teacher_projected_feature, teacher_reconstructed_feature = self.teacher_projectors(teacher_features)student_projected_feature = self.student_projector(student_feature)return teacher_projected_feature, teacher_reconstructed_feature, student_projected_feature

具体结构如下,CKT模块共有4个,即对应不同尺度的特征图,注意功能便是进行一系列的特征映射与转换。

CKTModule((teacher_projectors): TeacherProjectors((PFPs): ModuleList((0): Sequential((0): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(1): ReLU(inplace=True)(2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False))(1): Sequential((0): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(1): ReLU(inplace=True)(2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False))(2): Sequential((0): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(1): ReLU(inplace=True)(2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)))(IPFPs): ModuleList((0): Sequential((0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(1): ReLU(inplace=True)(2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False))(1): Sequential((0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(1): ReLU(inplace=True)(2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False))(2): Sequential((0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(1): ReLU(inplace=True)(2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False))))(student_projector): StudentProjector((PFP): Sequential((0): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(1): ReLU(inplace=True)(2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False))))

特征转移实际上也是通过损失函数来进行的,即通过一个网学习特征,从而达到特征转移的效果。
最终获得三个结果,分别是教师网络结构特征,教师网络重构特征,学生网络结构特征。核心代码如下:

teacher_projected_feature, teacher_reconstructed_feature = self.teacher_projectors(teacher_features)
student_projected_feature = self.student_projector(student_feature)
return teacher_projected_feature, teacher_reconstructed_feature, student_projected_feature

输出值:
与输入值一样,学生网络结构特征的输出值为tensor形式

在这里插入图片描述
而教师网络特征与教师网络重构特征的输出值依旧为list形式。

在这里插入图片描述

在这里插入图片描述

随后求特征损失与重构损失即可。

PFE_loss += criterion_l1(s_proj_features, torch.cat(t_proj_features))
PFV_loss += 0.05 * criterion_l1(torch.cat(t_recons_features), torch.cat(t_features))

在这里插入图片描述

在这里插入图片描述

最终求总损失与SCR损失即可,值得注意的是SCR损失需要使用VGG网络做特征变换后再计算。
L1损失较为简单,输入为学生网络预测值与教师网络预测值

T_loss = criterion_l1(preds_from_student, preds_from_teachers)
SCR_loss = 0.1 * criterion_scr(preds_from_student, target_images, torch.cat(input_images))

关于criterion_l1函数,其实际上是首先使用VGG网络进行特征变换,其输入数据分别是学生网络预测值,目标图像以及输入图像。
SCRLoss定义如下:根据在forward中的代码可知,其首先将输入值分别输入VGG网络进行特征变换,随后在将输出值计算L1损失。
其中,detch方法是返回一个新的tensor,从当前计算图中分离下来的,但是仍指向原变量的存放位置,不同之处只是requires_gradfalse,得到的这个tensor永远不需要计算其梯度,不具有grad。即使之后重新将它的requires_grad置为true,它也不会具有梯度grad
这样我们就会继续使用这个新的tensor进行计算,后面当我们进行反向传播时,到该调用detach()tensor就会停止,不能再继续向前进行传播。
最终乘以对应的权重,返回最后的损失。

class SCRLoss(nn.Module):def __init__(self):super().__init__()self.vgg = Vgg19().cuda()self.l1 = nn.L1Loss()self.weights = [1.0/32, 1.0/16, 1.0/8, 1.0/4, 1.0]def forward(self, a, p, n):a_vgg, p_vgg, n_vgg = self.vgg(a), self.vgg(p), self.vgg(n)loss = 0d_ap, d_an = 0, 0for i in range(len(a_vgg)):d_ap = self.l1(a_vgg[i], p_vgg[i].detach())d_an = self.l1(a_vgg[i], n_vgg[i].detach())contrastive = d_ap / (d_an + 1e-7)loss += self.weights[i] * contrastivereturn loss

可以看到最后的损失值是Tensor形式的。

在这里插入图片描述

至此,知识收集阶段便完成了。接下来便是知识测试阶段。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/95594.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

林沛满-Wireshark的提示

本文整理自:《Wireshark网络分析的艺术 第1版》 作者:林沛满 著 出版时间:2016-02 最近有不少同事开始学习 Wireshark,他们遇到的第一个困难就是理解不了主界面上的提示信息,于是跑来问我。问的人多了,我也…

如何一步步优化负载均衡策略

发展到一定阶段后,Web 应用程序就会增长到单服务器部署无法承受的地步。这时候企业要么提升可用性,要么提升可扩展性,甚至两者兼而有之。为此,他们会将应用程序部署在多台服务器上,并在服务器之前使用负载均衡器来分配…

华为云云耀云服务器L实例评测|SpringCloud相关组件——nacos和sentinel的安装和配置 运行内存情况 服务器被非法登陆尝试的解决

前言 最近华为云云耀云服务器L实例上新,也搞了一台来玩,期间遇到各种问题,在解决问题的过程中学到不少和运维相关的知识。 本篇博客介绍SpringCloud相关组件——nacos和sentinel的安装和配置,并分析了运行内存情况,此…

数据在内存中的存储(1)

文章目录 目录1. 数据类型介绍1.1 类型的基本归类 2. 整形在内存中的存储2.1 原码、反码、补码2.2 大小端介绍2.3 练习 附: 目录 数据类型介绍整形在内存中的存储大小端字节序介绍及判断浮点型在内存中的存储 1. 数据类型介绍 前面我们已经学习了基本的内置类型以…

Ubuntu20配置Mysql常用操作

文章目录 版权声明ubuntu更换软件源Ubuntu设置静态ipUbuntu防火墙ubuntu安装ssh服务Ubuntu安装vmtoolsUbuntu安装mysql5.7Ubuntu安装mysql8.0Ubuntu卸载mysql 版权声明 本博客的内容基于我个人学习黑马程序员课程的学习笔记整理而成。我特此声明,所有版权属于黑马程…

回归预测 | MATLAB实现PSO-SVR粒子群优化支持向量机回归多输入单输出预测

回归预测 | MATLAB实现PSO-SVR粒子群优化支持向量机回归多输入单输出预测 目录 回归预测 | MATLAB实现PSO-SVR粒子群优化支持向量机回归多输入单输出预测预测效果基本介绍模型描述程序设计预测效果 <

lv7 嵌入式开发-网络编程开发 11 TCP管理与UDP协议

目录 1 TCP管理 1.1 三次握手 1.2 四次挥手 1.3 保活计时器 2 wireshark安装及实验 3.1 icmp协议抓包演示 3.2 tcp协议抓包演示 3 UDP协议 3.1 UDP 的主要特点&#xff1a; 4 练习 1 TCP管理 1.1 三次握手 TCP 建立连接的过程叫做握手。 采用三报文握手&#xff1…

Fiddler抓取手机https包的步骤

做接口测试时&#xff0c;有时我们需要使用fiddler进行抓包分析&#xff0c;那么如何抓取https包。主要分为以下七步&#xff1a; 1.设置fiddler选项&#xff1a;Tools->Options,按如下图勾选 2.下载并安装Fiddler证书生成器 下载地址&#xff1a;http://www.telerik.com/…

C/C++——内存管理

1.为什么存在动态内存分配 灵活性 静态内存分配是在编译时确定的&#xff0c;程序执行过程中无法改变所分配的内存大小&#xff1b;动态内存分配可以根本程序的运行环境来动态分配和释放空间&#xff0c;提供了更大的灵活性 动态数据结构 有些数据结构的大小和结构在编译时…

[计算机入门] Windows附件程序介绍(工具类)

3.14 Windows附件程序介绍(工具类) 3.14.1 计算器 Windows系统中的计算器是一个内置的应用程序&#xff0c;提供了基本的数学计算功能。它被设计为一个方便、易于使用的工具&#xff0c;可以满足用户日常生活和工作中的基本计算需求。 以下是计算器程序的主要功能&#xff1a…

【算法基础】基础算法(二)--(高精度、前缀和与差分)

一、高精度 当一个数很大&#xff0c;大到 int 无法存下时&#xff0c;我们可以考虑用数组来进行存储&#xff0c;即数组中一个位置存放一位数。 但是对于数组而言&#xff0c;一个数顺序存入数组后&#xff0c;对其相加减是很简单的。但是当需要进位时&#xff0c;还是很麻烦的…

华为云云耀云服务器L实例评测|部署个人音乐流媒体服务器 navidrome

华为云云耀云服务器L实例评测&#xff5c;部署个人音乐流媒体服务器 navidrome 一、云耀云服务器L实例介绍1.1 云服务器介绍1.2 产品规格1.3 产品优势1.4 支持镜像 二、云耀云服务器L实例配置2.1 重置密码2.2 服务器连接2.3 安全组配置 三、部署 navidrome3.1 navidrome 介绍3.…

【动手学深度学习-Pytorch版】Transformer代码总结

本文是纯纯的撸代码讲解&#xff0c;没有任何Transformer的基础内容~ 是从0榨干Transformer代码系列&#xff0c;借用的是李沐老师上课时讲解的代码。 本文是根据每个模块的实现过程来进行讲解的。如果您想获取关于Transformer具体的实现细节&#xff08;不含代码&#xff09;可…

ElasticSearch - 基于 拼音分词器 和 IK分词器 模拟实现“百度”搜索框自动补全功能

目录 一、自动补全 1.1、效果说明 1.2、安装拼音分词器 1.3、自定义分词器 1.3.1、为什么要自定义分词器 1.3.2、分词器的构成 1.3.3、自定义分词器 1.3.4、面临的问题和解决办法 问题 解决方案 1.4、completion suggester 查询 1.4.1、基本概念和语法 1.4.2、示例…

Ubuntu Server CLI专业提示

基础 网络 获取所有接口的IP地址 networkctl status 显示主机的所有IP地址 hostname -I 启用/禁用接口 ip link set <interface> up ip link set <interface> down 显示路线 ip route 将使用哪条路线到达主机 ip route get <IP> 安全 显示已登录的用户 w…

PLL锁相环倍频原理

晶振8MHz&#xff0c;但是处理器输入可以达到72MHz&#xff0c;是因为PLL锁相环提供了72MHz。 锁相环由PD&#xff08;鉴相器&#xff09;、LP&#xff08;滤波器&#xff09;、VCO&#xff08;压控振荡器&#xff09;组成。 处理器获得的72MHz并非晶振提供&#xff0c;而是锁…

好工具分享:阿里云价格计算器_一键计算精准报价

阿里云服务器价格计算器&#xff0c;鼠标选择云服务器ECS实例规格、地域、系统盘、带宽及购买时长即可一键计算出精准报价&#xff0c;阿里云服务器网分享阿里云服务器价格计算器链接地址&#xff1a; 阿里云服务器价格计算器 先打开阿里云服务器ECS页面 aliyunfuwuqi.com/go…

生成Release版本的.pdb文件

软件分为Debug版本、Release版本这2种版本&#xff0c;其中Debug版本是带有.pdb调试信息文件&#xff0c;而Release版本不带.pdb调试信息文件。软件发布时&#xff0c;一般采用Release版本&#xff0c;若因内存泄漏、数组访问越界、除零错误、磁盘读写错误等异常&#xff0c;造…

计算机毕设 大数据房价预测分析与可视

文章目录 0 前言1 课题背景2 导入相关的数据 3 观察各项主要特征与房屋售价的关系4 最后 0 前言 &#x1f525; 这两年开始毕业设计和毕业答辩的要求和难度不断提升&#xff0c;传统的毕设题目缺少创新和亮点&#xff0c;往往达不到毕业答辩的要求&#xff0c;这两年不断有学弟…

CleanMyMac X4.14.1最新版本下载

CleanMyMac X是一个功能强大的Mac清理软件&#xff0c;它的设计理念是提供多个模块&#xff0c;包括垃圾清理、安全保护、速度优化、应用程序管理和文档管理粉碎等&#xff0c;以满足用户的不同需求。软件的界面简洁直观&#xff0c;让用户能够轻松进行日常的清理操作。 使用C…