知识蒸馏代码实现(以MNIST手写数字体为例,自定义MLP网络做为教师和学生网络)

dataloader_tools.py

import torchvision
from torchvision import transforms
from torch.utils.data import DataLoaderdef load_data():# 载入MNIST训练集train_dataset = torchvision.datasets.MNIST(root = "../datasets/",train=True,transform=transforms.ToTensor(),download=True)# 载入MNIST测试集test_dataset = torchvision.datasets.MNIST(root = "../datasets/",train=False,transform=transforms.ToTensor(),download=True)# 生成训练集和测试集的dataloadertrain_dataloader = DataLoader(dataset=train_dataset,batch_size=12,shuffle=True)test_dataloader = DataLoader(dataset=test_dataset,batch_size=12,shuffle=False)return train_dataloader,test_dataloader

models.py

import torch
from torch import nn
# 教师模型
class TeacherModel(nn.Module):def __init__(self,in_channels=1,num_classes=10):super(TeacherModel,self).__init__()self.relu = nn.ReLU()self.fc1 = nn.Linear(784,1200)self.fc2 = nn.Linear(1200,1200)self.fc3 = nn.Linear(1200,num_classes)self.dropout = nn.Dropout(p=0.5) #p=0.5是丢弃该层一半的神经元.def forward(self,x):x = x.view(-1,784)x = self.fc1(x)x = self.dropout(x)x = self.relu(x)x = self.fc2(x)x = self.dropout(x)x = self.relu(x)x = self.fc3(x)return xclass StudentModel(nn.Module):def __init__(self,in_channels=1,num_classes=10):super(StudentModel,self).__init__()self.relu = nn.ReLU()self.fc1 = nn.Linear(784,20)self.fc2 = nn.Linear(20,20)self.fc3 = nn.Linear(20,num_classes)def forward(self,x):x = x.view(-1,784)x = self.fc1(x)x = self.relu(x)x = self.fc2(x)x = self.relu(x)x = self.fc3(x)return x

train_tools.py

from torch import nn
import time
import torch
import tqdm
import torch.nn.functional as Fdef train(epochs, model, model_name, lr,train_dataloader,test_dataloader,device):# ----------------------开始计时-----------------------------------start_time = time.time()# 设置参数开始训练best_acc, best_epoch = 0, 0criterion = nn.CrossEntropyLoss()optimizer = torch.optim.Adam(model.parameters(), lr=lr)for epoch in range(epochs):model.train()# 训练集上训练模型权重for data, targets in tqdm.tqdm(train_dataloader):# 把数据加载到GPU上data = data.to(device)targets = targets.to(device)# 前向传播preds = model(data)loss = criterion(preds, targets)# 反向传播optimizer.zero_grad()loss.backward()optimizer.step()# 测试集上评估模型性能model.eval()num_correct = 0num_samples = 0with torch.no_grad():for x, y in test_dataloader:x = x.to(device)y = y.to(device)preds = model(x)predictions = preds.max(1).indices  # 返回每一行的最大值和该最大值在该行的列索引num_correct += (predictions == y).sum()num_samples += predictions.size(0)acc = (num_correct / num_samples).item()if acc > best_acc:best_acc = accbest_epoch = epoch# 保存模型最优准确率的参数torch.save(model.state_dict(), f"../weights/{model_name}_best_acc_params.pth")model.train()print('Epoch:{}\t Accuracy:{:.4f}'.format(epoch + 1, acc),f'loss={loss}')print(f'最优准确率的epoch为{best_epoch},值为:{best_acc},最优参数已经保存到:weights/{model_name}_best_acc_params.pth')# -------------------------结束计时------------------------------------end_time = time.time()run_time = end_time - start_time# 将输出的秒数保留两位小数if int(run_time) < 60:print(f'训练用时为:{round(run_time, 2)}s')else:print(f'训练用时为:{round(run_time / 60, 2)}minutes')def distill_train(epochs,teacher_model,student_model,model_name,train_dataloader,test_dataloader,alpha,lr,temp,device):# -------------------------------------开始计时--------------------------------start_time = time.time()# 定以损失函数hard_loss = nn.CrossEntropyLoss()soft_loss = nn.KLDivLoss(reduction="batchmean")# 定义优化器optimizer = torch.optim.Adam(student_model.parameters(), lr=lr)best_acc,best_epoch = 0,0for epoch in range(epochs):student_model.train()# 训练集上训练模型权重for data,targets in tqdm.tqdm(train_dataloader):# 把数据加载到GPU上data = data.to(device)targets = targets.to(device)# 教师模型预测with torch.no_grad():teacher_preds = teacher_model(data)# 学生模型预测student_preds = student_model(data)# 计算hard_lossstudent_hard_loss = hard_loss(student_preds,targets)# 计算蒸馏后的预测结果及soft_lossditillation_loss = soft_loss(F.softmax(student_preds/temp,dim=1),F.softmax(teacher_preds/temp,dim=1))# 将hard_loss和soft_loss加权求和loss = temp * temp * alpha * student_hard_loss + (1-alpha)*ditillation_loss# 反向传播,优化权重optimizer.zero_grad()loss.backward()optimizer.step()#测试集上评估模型性能student_model.eval()num_correct = 0num_samples = 0with torch.no_grad():for x,y in test_dataloader:x = x.to(device)y = y.to(device)preds = student_model(x)predictions = preds.max(1).indices #返回每一行的最大值和该最大值在该行的列索引num_correct += (predictions ==y).sum()num_samples += predictions.size(0)acc = (num_correct/num_samples).item()if acc>best_acc:best_acc = accbest_epoch = epoch# 保存模型最优准确率的参数torch.save(student_model.state_dict(),f"../weights/{model_name}_best_acc_params.pth")student_model.train()print('Epoch:{}\t Accuracy:{:.4f}'.format(epoch+1,acc))print(f'student_hard_loss={student_hard_loss},ditillation_loss={ditillation_loss},loss={loss}')print(f'最优准确率的epoch为{best_epoch},值为:{best_acc},')# --------------------------------结束计时----------------------------------end_time = time.time()run_time = end_time - start_time# 将输出的秒数保留两位小数if int(run_time) < 60:print(f'训练用时为:{round(run_time, 2)}s')else:print(f'训练用时为:{round(run_time / 60, 2)}minutes')

训练教师网络

import torch
from torchinfo import summary #用来可视化的
import models
import dataloader_tools
import train_tools# 设置随机数种子
torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 使用cuDNN加速卷积运算
torch.backends.cudnn.benchmark = True# 载入MNIST训练集和测试集
train_dataloader,test_dataloader = dataloader_tools.load_data()# 定义教师模型
model = models.TeacherModel()
model = model.to(device)
# 打印模型的参数
summary(model)# 定义参数并开始训练
epochs = 10
lr = 1e-4
model_name = 'teacher'
train_tools.train(epochs,model,model_name,lr,train_dataloader,test_dataloader,device)
最优准确率的epoch为9,值为:0.9868999719619751

用非蒸馏的方法训练学生网络

import torch
from torchinfo import summary #用来可视化的
import dataloader_tools
import models
import train_tools# 设置随机数种子
torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 使用cuDNN加速卷积运算
torch.backends.cudnn.benchmark = True# 生成训练集和测试集的dataloader
train_dataloader,test_dataloader = dataloader_tools.load_data()# 从头训练学生模型
model = models.StudentModel()
model = model.to(device)
# 查看模型参数
print(summary(model))# 定义参数并开始训练
epochs = 10
lr = 1e-4
model_name = 'student'
train_tools.train(epochs, model, model_name, lr,train_dataloader,test_dataloader,device)
最优准确率的epoch为9,准确率为:0.9382999539375305,最优参数已经保存到:weights/student_best_acc_params.pth
训练用时为:1.74minutes

用知识蒸馏的方法训练student model

import torch
import train_tools
import models
import dataloader_tools# 设置随机数种子
torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 使用cuDNN加速卷积运算
torch.backends.cudnn.benchmark = True# 加载数据
train_dataloader,test_dataloader = dataloader_tools.load_data()# 加载训练好的teacher model
teacher_model = models.TeacherModel()
teacher_model = teacher_model.to(device)
teacher_model.load_state_dict(torch.load('../weights/teacher_best_acc_params.pth'))
teacher_model.eval()# 准备新的学生模型
student_model = models.StudentModel()
student_model = student_model.to(device)
student_model.train()# 开始训练
lr = 0.0001
epochs = 20
alpha = 0.3 # hard_loss权重
temp = 7 # 蒸馏温度
model_name = 'distill_student_loss'
# 调用train_tools中的
train_tools.distill_train(epochs,teacher_model,student_model,model_name,train_dataloader,test_dataloader,alpha,lr,temp,device)
最优准确率的epoch为9,值为:0.9204999804496765,
训练用时为:2.14minutes

在这里插入图片描述

loss改为:

# temp的平方乘在student_hard_loss
loss = temp * temp * alpha * student_hard_loss + (1 - alpha) * ditillation_loss
最优准确率的epoch为9,值为:0.9336999654769897,
训练用时为:2.12minutes

loss改为:

# temp的平方乘ditillation_loss
loss = alpha * student_hard_loss + temp * temp * (1 - alpha) * ditillation_loss
最优准确率的epoch为9,值为:0.9176999926567078,
训练用时为:2.09minutes

上面的几种loss,蒸馏损失都出现了负数的情况。不太对劲。
在这里插入图片描述

其它开源的知识蒸馏算法如下:

open-mmlab开源的工具箱包含知识蒸馏算法

mmrazor

github.com/open-mmlab/mmrazor

在这里插入图片描述

NAS:神经架构搜索
剪枝:Pruning
KD: 知识蒸馏
Quantization: 量化

自定义知识蒸馏算法:
在这里插入图片描述

mmdeploy

可以把算法部署到一些厂商支持的中间格式,如ONNX,tensorRT等。

在这里插入图片描述

HobbitLong的RepDistiller

github.com/HobbitLong/RepDistiller

在这里插入图片描述
在这里插入图片描述
里面有12种最新的知识蒸馏算法。

蒸馏网络可以应用于同一种模型,将大的学习的知识蒸馏到小的上面。
如下将resnet100做教师网络,resnet32做学生网络。

在这里插入图片描述

将一种模型迁移到另一种模型上。如vgg13做教师网络,mobilNetv2做学生网络:

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/186568.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

ES6中的类

1、Class 类是一种可选&#xff08;而不是必须&#xff09;的设计模式&#xff0c;而且在 JavaScript 这样的 [[Prototype]] 语言中实现类是很别扭的。大致解决了以下问题&#xff1a; 不再引用杂乱的 .prototype 了Button 声 明 时 直 接“ 继 承 ” 了 Widget&#xff0c; …

Docker容器化平台及其优势和应用场景介绍

Docker是一种开源的容器化平台&#xff0c;它基于操作系统级别虚拟化技术&#xff0c;可以将应用程序及其依赖项打包成一个独立的容器&#xff0c;提供轻量级、一致性、可移植性的应用环境。Docker的基本概念和优势如下&#xff1a; 镜像(Image)&#xff1a;Docker容器的基础&…

JAVA对象转HashMap如何快速强转

小编我最近看到了些资料,之前JAVA对象转DTO都是另外写类&#xff0c;进行强转&#xff0c;里面有些Long类型&#xff0c;日期类型&#xff0c;都是转成String类型&#xff0c;现在有快速优雅得解决方式 我们需要得Map结构是 Map<String,object> 简单说明一下这个类&…

QT Day01 qt概述,创建项目,窗口属性,按钮,信号与槽

1.qt概述 1.什么是qt Qt 是一个跨平台的 C 图形用户界面应用程序框架。它为应用程序开发者提供建立艺 术级图形界面所需的所有功能。它是完全面向对象的&#xff0c;很容易扩展&#xff0c;并且允许真正的组 件编程。 2.支持的平台 Windows – XP 、 Vista 、 Win7 、 Win8…

Java(119):ExcelUtil工具类(org.apache.poi读取和写入Excel)

ExcelUtil工具类(XSSFWorkbook读取和写入Excel),入参和出参都是:List<Map<String,Object>> 一、读取Excel testdata.xlsx 1、new XSSFWorkbook对象 File file = new File(filePath); FileInputStream fis = new FileInputStream(file);…

8.二维数组——将一个二维数组行和列的元素互换,存到另一个二维数组中。

文章目录 前言一、题目描述 二、题目分析 三、解题 程序运行代码 前言 本系列为二维数组编程题&#xff0c;点滴成长&#xff0c;一起逆袭。 一、题目描述 将一个二维数组行和列的元素互换&#xff0c;存到另一个二维数组中。 二、题目分析 三、解题 程序运行代码 #incl…

玄学调参实践篇 | 深度学习模型 + 预训练模型 + 大模型LLM

&#x1f60d; 这篇主要简单记录一些调参实践&#xff0c;无聊时会不定期更新~ 文章目录 0、学习率与batch_size判断1、Epoch数判断2、判断模型架构是否有问题3、大模型 - 计算量、模型、和数据大小的关系4、大模型调参相关论文经验总结5、训练时模型的保存 0、学习率与batch_s…

python中几次方怎么打,三种内置方法

Python中几次方的三种内置方法 Python中至少内置的三种可以用于求取某个底数的几次方的方法&#xff0c;如下&#xff1a; 第一种方法&#xff0c;通过Python内置的幂次方运算符“**”&#xff1b;使用math模块的pow()方法&#xff0c;可以用于求取幂次方&#xff0c;即pow()…

压力测试+接口测试

jmeter是apache公司基于java开发的一款开源压力测试工具&#xff0c;体积小&#xff0c;功能全&#xff0c;使用方便&#xff0c;是一个比较轻量级的测试工具&#xff0c;使用起来非常简单。因 为jmeter是java开发的&#xff0c;所以运行的时候必须先要安装jdk才可以。jmeter是…

鸿蒙系统开发手册 - HarmonyOS内核驱动层源码分析

众所周知系统定义HarmonyOS是一款“面向未来”、面向全场景&#xff08;移动办公、运动健康、社交通信、媒体娱乐等&#xff09;的分布式操作系统。在传统的单设备系统能力的基础上&#xff0c;HarmonyOS提出了基于同一套系统能力、适配多种终端形态的分布式理念&#xff0c;能…

Arrays.asList() 与 Collections.singletonList()的恩怨情仇

1. 概述 列表是我们使用 Java 时常用的集合类型。 众所周知&#xff0c;我们可以轻松地用一行初始化一个List。例如&#xff0c;当我们想要初始化一个只有一个元素的List时&#xff0c;我们可以使用Arrays.asList()方法或Collections.singletonList()方法。 在本文中&#x…

【Linux】基础IO--文件基础知识/文件操作/文件描述符

文章目录 一、文件相关基础知识二、文件操作1.C语言文件操作2.操作系统文件操作2.1 比特位传递选项2.2 文件相关系统调用2.3 文件操作接口的使用 三、文件描述符fd1.什么是文件描述符2.文件描述符的分配规则 一、文件相关基础知识 我们对文件有如下的认识&#xff1a; 1.文件 …

用最少数量的箭引爆气球[中等]

优质博文&#xff1a;IT-BLOG-CN 一、题目 有一些球形气球贴在一堵用XY平面表示的墙面上。墙面上的气球记录在整数数组points&#xff0c;其中points[i] [xstart, xend]表示水平直径在xstart和xend之间的气球。你不知道气球的确切y坐标。一支弓箭可以沿着x轴从不同点完全垂直…

vue-动态组件、keep-alive

vue-动态组件、keep-alive 如果我们想写一个tabbar导航栏&#xff0c;我能想到的两种方式 通过if条件判断的方式实现&#xff08;不赘述&#xff09;动态组件 接下来我们就看看动态组件如何创建&#xff0c;废话不多少直接上代码&#xff08;代码中有备注&#xff09; 首先…

Panalog 日志审计系统 前台RCE漏洞复现

0x01 产品简介 Panalog是一款日志审计系统&#xff0c;方便用户统一集中监控、管理在网的海量设备。 0x02 漏洞概述 Panalog日志审计系统 sy_query.php接口处存在远程命令执行漏洞&#xff0c;攻击者可执行任意命令&#xff0c;接管服务器权限。 0x03 复现环境 FOFA&#xf…

谭巍主任专业角度解读:疣体脱落前的症状是什么?

我们时常会发现身体各个部位长出一些赘生物&#xff0c;有些属于皮肤良性改变&#xff0c;而有些则是病毒引起的&#xff0c;称之为疣体。然而在疣体脱落之前&#xff0c;通常会出现一些症状&#xff0c;这些症状可能因人而异&#xff0c;但以下是一些常见的迹象&#xff1a; 1…

如何隐藏选择选项值并用新值替换2个选项?

要隐藏选择选项值并用新值替换2个选项&#xff0c;可以使用JavaScript来实现。 首先&#xff0c;使用JavaScript获取两个选项的值&#xff0c;然后将这两个值设置为新的值&#xff0c;最后将这两个选项的可见性设置为false&#xff0c;以隐藏它们。 例如&#xff1a; <se…

笔记61:注意力提示

本地笔记地址&#xff1a;D:\work_file\&#xff08;4&#xff09;DeepLearning_Learning\03_个人笔记\3.循环神经网络\第10章&#xff1a;动手学深度学习~注意力机制 a a a a a a a a

MySQL索引使用总结

索引(index) 官方定义&#xff1a;一种提高MySQL查询效率的数据结构 优点&#xff1a;加快查询速度 缺点&#xff1a; 1.维护索引需要消耗数据库资源 2.索引需要占用磁盘空间 3.增删改的时候会影响性能 索引分类 索引和数据库表的存储引擎有关&#xff0c;不同的存储引擎&am…

springboot中如何用stream流的方式把mysql取出来的值给实体类中的多个字段赋值代码实例?

在 Spring Boot 中使用 Stream 流的方式将从 MySQL 数据库取出的值赋给实体类中的多个字段&#xff0c;你可以结合使用 JDBC&#xff08;Java Database Connectivity&#xff09;和 Stream API 来实现。以下是一个示例代码&#xff1a; import org.springframework.boot.Sprin…