知识蒸馏代码实现(以MNIST手写数字体为例,自定义MLP网络做为教师和学生网络)

dataloader_tools.py

import torchvision
from torchvision import transforms
from torch.utils.data import DataLoaderdef load_data():# 载入MNIST训练集train_dataset = torchvision.datasets.MNIST(root = "../datasets/",train=True,transform=transforms.ToTensor(),download=True)# 载入MNIST测试集test_dataset = torchvision.datasets.MNIST(root = "../datasets/",train=False,transform=transforms.ToTensor(),download=True)# 生成训练集和测试集的dataloadertrain_dataloader = DataLoader(dataset=train_dataset,batch_size=12,shuffle=True)test_dataloader = DataLoader(dataset=test_dataset,batch_size=12,shuffle=False)return train_dataloader,test_dataloader

models.py

import torch
from torch import nn
# 教师模型
class TeacherModel(nn.Module):def __init__(self,in_channels=1,num_classes=10):super(TeacherModel,self).__init__()self.relu = nn.ReLU()self.fc1 = nn.Linear(784,1200)self.fc2 = nn.Linear(1200,1200)self.fc3 = nn.Linear(1200,num_classes)self.dropout = nn.Dropout(p=0.5) #p=0.5是丢弃该层一半的神经元.def forward(self,x):x = x.view(-1,784)x = self.fc1(x)x = self.dropout(x)x = self.relu(x)x = self.fc2(x)x = self.dropout(x)x = self.relu(x)x = self.fc3(x)return xclass StudentModel(nn.Module):def __init__(self,in_channels=1,num_classes=10):super(StudentModel,self).__init__()self.relu = nn.ReLU()self.fc1 = nn.Linear(784,20)self.fc2 = nn.Linear(20,20)self.fc3 = nn.Linear(20,num_classes)def forward(self,x):x = x.view(-1,784)x = self.fc1(x)x = self.relu(x)x = self.fc2(x)x = self.relu(x)x = self.fc3(x)return x

train_tools.py

from torch import nn
import time
import torch
import tqdm
import torch.nn.functional as Fdef train(epochs, model, model_name, lr,train_dataloader,test_dataloader,device):# ----------------------开始计时-----------------------------------start_time = time.time()# 设置参数开始训练best_acc, best_epoch = 0, 0criterion = nn.CrossEntropyLoss()optimizer = torch.optim.Adam(model.parameters(), lr=lr)for epoch in range(epochs):model.train()# 训练集上训练模型权重for data, targets in tqdm.tqdm(train_dataloader):# 把数据加载到GPU上data = data.to(device)targets = targets.to(device)# 前向传播preds = model(data)loss = criterion(preds, targets)# 反向传播optimizer.zero_grad()loss.backward()optimizer.step()# 测试集上评估模型性能model.eval()num_correct = 0num_samples = 0with torch.no_grad():for x, y in test_dataloader:x = x.to(device)y = y.to(device)preds = model(x)predictions = preds.max(1).indices  # 返回每一行的最大值和该最大值在该行的列索引num_correct += (predictions == y).sum()num_samples += predictions.size(0)acc = (num_correct / num_samples).item()if acc > best_acc:best_acc = accbest_epoch = epoch# 保存模型最优准确率的参数torch.save(model.state_dict(), f"../weights/{model_name}_best_acc_params.pth")model.train()print('Epoch:{}\t Accuracy:{:.4f}'.format(epoch + 1, acc),f'loss={loss}')print(f'最优准确率的epoch为{best_epoch},值为:{best_acc},最优参数已经保存到:weights/{model_name}_best_acc_params.pth')# -------------------------结束计时------------------------------------end_time = time.time()run_time = end_time - start_time# 将输出的秒数保留两位小数if int(run_time) < 60:print(f'训练用时为:{round(run_time, 2)}s')else:print(f'训练用时为:{round(run_time / 60, 2)}minutes')def distill_train(epochs,teacher_model,student_model,model_name,train_dataloader,test_dataloader,alpha,lr,temp,device):# -------------------------------------开始计时--------------------------------start_time = time.time()# 定以损失函数hard_loss = nn.CrossEntropyLoss()soft_loss = nn.KLDivLoss(reduction="batchmean")# 定义优化器optimizer = torch.optim.Adam(student_model.parameters(), lr=lr)best_acc,best_epoch = 0,0for epoch in range(epochs):student_model.train()# 训练集上训练模型权重for data,targets in tqdm.tqdm(train_dataloader):# 把数据加载到GPU上data = data.to(device)targets = targets.to(device)# 教师模型预测with torch.no_grad():teacher_preds = teacher_model(data)# 学生模型预测student_preds = student_model(data)# 计算hard_lossstudent_hard_loss = hard_loss(student_preds,targets)# 计算蒸馏后的预测结果及soft_lossditillation_loss = soft_loss(F.softmax(student_preds/temp,dim=1),F.softmax(teacher_preds/temp,dim=1))# 将hard_loss和soft_loss加权求和loss = temp * temp * alpha * student_hard_loss + (1-alpha)*ditillation_loss# 反向传播,优化权重optimizer.zero_grad()loss.backward()optimizer.step()#测试集上评估模型性能student_model.eval()num_correct = 0num_samples = 0with torch.no_grad():for x,y in test_dataloader:x = x.to(device)y = y.to(device)preds = student_model(x)predictions = preds.max(1).indices #返回每一行的最大值和该最大值在该行的列索引num_correct += (predictions ==y).sum()num_samples += predictions.size(0)acc = (num_correct/num_samples).item()if acc>best_acc:best_acc = accbest_epoch = epoch# 保存模型最优准确率的参数torch.save(student_model.state_dict(),f"../weights/{model_name}_best_acc_params.pth")student_model.train()print('Epoch:{}\t Accuracy:{:.4f}'.format(epoch+1,acc))print(f'student_hard_loss={student_hard_loss},ditillation_loss={ditillation_loss},loss={loss}')print(f'最优准确率的epoch为{best_epoch},值为:{best_acc},')# --------------------------------结束计时----------------------------------end_time = time.time()run_time = end_time - start_time# 将输出的秒数保留两位小数if int(run_time) < 60:print(f'训练用时为:{round(run_time, 2)}s')else:print(f'训练用时为:{round(run_time / 60, 2)}minutes')

训练教师网络

import torch
from torchinfo import summary #用来可视化的
import models
import dataloader_tools
import train_tools# 设置随机数种子
torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 使用cuDNN加速卷积运算
torch.backends.cudnn.benchmark = True# 载入MNIST训练集和测试集
train_dataloader,test_dataloader = dataloader_tools.load_data()# 定义教师模型
model = models.TeacherModel()
model = model.to(device)
# 打印模型的参数
summary(model)# 定义参数并开始训练
epochs = 10
lr = 1e-4
model_name = 'teacher'
train_tools.train(epochs,model,model_name,lr,train_dataloader,test_dataloader,device)
最优准确率的epoch为9,值为:0.9868999719619751

用非蒸馏的方法训练学生网络

import torch
from torchinfo import summary #用来可视化的
import dataloader_tools
import models
import train_tools# 设置随机数种子
torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 使用cuDNN加速卷积运算
torch.backends.cudnn.benchmark = True# 生成训练集和测试集的dataloader
train_dataloader,test_dataloader = dataloader_tools.load_data()# 从头训练学生模型
model = models.StudentModel()
model = model.to(device)
# 查看模型参数
print(summary(model))# 定义参数并开始训练
epochs = 10
lr = 1e-4
model_name = 'student'
train_tools.train(epochs, model, model_name, lr,train_dataloader,test_dataloader,device)
最优准确率的epoch为9,准确率为:0.9382999539375305,最优参数已经保存到:weights/student_best_acc_params.pth
训练用时为:1.74minutes

用知识蒸馏的方法训练student model

import torch
import train_tools
import models
import dataloader_tools# 设置随机数种子
torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 使用cuDNN加速卷积运算
torch.backends.cudnn.benchmark = True# 加载数据
train_dataloader,test_dataloader = dataloader_tools.load_data()# 加载训练好的teacher model
teacher_model = models.TeacherModel()
teacher_model = teacher_model.to(device)
teacher_model.load_state_dict(torch.load('../weights/teacher_best_acc_params.pth'))
teacher_model.eval()# 准备新的学生模型
student_model = models.StudentModel()
student_model = student_model.to(device)
student_model.train()# 开始训练
lr = 0.0001
epochs = 20
alpha = 0.3 # hard_loss权重
temp = 7 # 蒸馏温度
model_name = 'distill_student_loss'
# 调用train_tools中的
train_tools.distill_train(epochs,teacher_model,student_model,model_name,train_dataloader,test_dataloader,alpha,lr,temp,device)
最优准确率的epoch为9,值为:0.9204999804496765,
训练用时为:2.14minutes

在这里插入图片描述

loss改为:

# temp的平方乘在student_hard_loss
loss = temp * temp * alpha * student_hard_loss + (1 - alpha) * ditillation_loss
最优准确率的epoch为9,值为:0.9336999654769897,
训练用时为:2.12minutes

loss改为:

# temp的平方乘ditillation_loss
loss = alpha * student_hard_loss + temp * temp * (1 - alpha) * ditillation_loss
最优准确率的epoch为9,值为:0.9176999926567078,
训练用时为:2.09minutes

上面的几种loss,蒸馏损失都出现了负数的情况。不太对劲。
在这里插入图片描述

其它开源的知识蒸馏算法如下:

open-mmlab开源的工具箱包含知识蒸馏算法

mmrazor

github.com/open-mmlab/mmrazor

在这里插入图片描述

NAS:神经架构搜索
剪枝:Pruning
KD: 知识蒸馏
Quantization: 量化

自定义知识蒸馏算法:
在这里插入图片描述

mmdeploy

可以把算法部署到一些厂商支持的中间格式,如ONNX,tensorRT等。

在这里插入图片描述

HobbitLong的RepDistiller

github.com/HobbitLong/RepDistiller

在这里插入图片描述
在这里插入图片描述
里面有12种最新的知识蒸馏算法。

蒸馏网络可以应用于同一种模型,将大的学习的知识蒸馏到小的上面。
如下将resnet100做教师网络,resnet32做学生网络。

在这里插入图片描述

将一种模型迁移到另一种模型上。如vgg13做教师网络,mobilNetv2做学生网络:

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/186568.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

QT Day01 qt概述,创建项目,窗口属性,按钮,信号与槽

1.qt概述 1.什么是qt Qt 是一个跨平台的 C 图形用户界面应用程序框架。它为应用程序开发者提供建立艺 术级图形界面所需的所有功能。它是完全面向对象的&#xff0c;很容易扩展&#xff0c;并且允许真正的组 件编程。 2.支持的平台 Windows – XP 、 Vista 、 Win7 、 Win8…

Java(119):ExcelUtil工具类(org.apache.poi读取和写入Excel)

ExcelUtil工具类(XSSFWorkbook读取和写入Excel),入参和出参都是:List<Map<String,Object>> 一、读取Excel testdata.xlsx 1、new XSSFWorkbook对象 File file = new File(filePath); FileInputStream fis = new FileInputStream(file);…

8.二维数组——将一个二维数组行和列的元素互换,存到另一个二维数组中。

文章目录 前言一、题目描述 二、题目分析 三、解题 程序运行代码 前言 本系列为二维数组编程题&#xff0c;点滴成长&#xff0c;一起逆袭。 一、题目描述 将一个二维数组行和列的元素互换&#xff0c;存到另一个二维数组中。 二、题目分析 三、解题 程序运行代码 #incl…

玄学调参实践篇 | 深度学习模型 + 预训练模型 + 大模型LLM

&#x1f60d; 这篇主要简单记录一些调参实践&#xff0c;无聊时会不定期更新~ 文章目录 0、学习率与batch_size判断1、Epoch数判断2、判断模型架构是否有问题3、大模型 - 计算量、模型、和数据大小的关系4、大模型调参相关论文经验总结5、训练时模型的保存 0、学习率与batch_s…

压力测试+接口测试

jmeter是apache公司基于java开发的一款开源压力测试工具&#xff0c;体积小&#xff0c;功能全&#xff0c;使用方便&#xff0c;是一个比较轻量级的测试工具&#xff0c;使用起来非常简单。因 为jmeter是java开发的&#xff0c;所以运行的时候必须先要安装jdk才可以。jmeter是…

鸿蒙系统开发手册 - HarmonyOS内核驱动层源码分析

众所周知系统定义HarmonyOS是一款“面向未来”、面向全场景&#xff08;移动办公、运动健康、社交通信、媒体娱乐等&#xff09;的分布式操作系统。在传统的单设备系统能力的基础上&#xff0c;HarmonyOS提出了基于同一套系统能力、适配多种终端形态的分布式理念&#xff0c;能…

Arrays.asList() 与 Collections.singletonList()的恩怨情仇

1. 概述 列表是我们使用 Java 时常用的集合类型。 众所周知&#xff0c;我们可以轻松地用一行初始化一个List。例如&#xff0c;当我们想要初始化一个只有一个元素的List时&#xff0c;我们可以使用Arrays.asList()方法或Collections.singletonList()方法。 在本文中&#x…

【Linux】基础IO--文件基础知识/文件操作/文件描述符

文章目录 一、文件相关基础知识二、文件操作1.C语言文件操作2.操作系统文件操作2.1 比特位传递选项2.2 文件相关系统调用2.3 文件操作接口的使用 三、文件描述符fd1.什么是文件描述符2.文件描述符的分配规则 一、文件相关基础知识 我们对文件有如下的认识&#xff1a; 1.文件 …

用最少数量的箭引爆气球[中等]

优质博文&#xff1a;IT-BLOG-CN 一、题目 有一些球形气球贴在一堵用XY平面表示的墙面上。墙面上的气球记录在整数数组points&#xff0c;其中points[i] [xstart, xend]表示水平直径在xstart和xend之间的气球。你不知道气球的确切y坐标。一支弓箭可以沿着x轴从不同点完全垂直…

Panalog 日志审计系统 前台RCE漏洞复现

0x01 产品简介 Panalog是一款日志审计系统&#xff0c;方便用户统一集中监控、管理在网的海量设备。 0x02 漏洞概述 Panalog日志审计系统 sy_query.php接口处存在远程命令执行漏洞&#xff0c;攻击者可执行任意命令&#xff0c;接管服务器权限。 0x03 复现环境 FOFA&#xf…

谭巍主任专业角度解读:疣体脱落前的症状是什么?

我们时常会发现身体各个部位长出一些赘生物&#xff0c;有些属于皮肤良性改变&#xff0c;而有些则是病毒引起的&#xff0c;称之为疣体。然而在疣体脱落之前&#xff0c;通常会出现一些症状&#xff0c;这些症状可能因人而异&#xff0c;但以下是一些常见的迹象&#xff1a; 1…

笔记61:注意力提示

本地笔记地址&#xff1a;D:\work_file\&#xff08;4&#xff09;DeepLearning_Learning\03_个人笔记\3.循环神经网络\第10章&#xff1a;动手学深度学习~注意力机制 a a a a a a a a

MySQL索引使用总结

索引(index) 官方定义&#xff1a;一种提高MySQL查询效率的数据结构 优点&#xff1a;加快查询速度 缺点&#xff1a; 1.维护索引需要消耗数据库资源 2.索引需要占用磁盘空间 3.增删改的时候会影响性能 索引分类 索引和数据库表的存储引擎有关&#xff0c;不同的存储引擎&am…

AndroidStudio - 新版本 Logcat 使用详解

最近这俩天正好有时间给自己做一下减法&#xff0c;忘记是去年还是今年&#xff0c;在升级 AndroidStudio 后使用 Logcat查看日志的方式也发生了一些变化&#xff0c;虽然一直在使用&#xff0c;但每当看到之前还未关闭 Logcat 命令行工具额昂也&#xff0c;就感觉可能还存在知…

Multi-head attention机制

多头&#xff1a;多个相同结构的线性变换层&#xff08;方阵&#xff09;&#xff0c;要求分别线性变换 B站教学视频参考&#xff1a;https://www.bilibili.com/video/BV1eG4y1N7Jp/?p17&spm_id_frompageDriver&vd_sourcef4c7dcac0ad5ae8189bd414a3b23020d 什么是多头…

冒泡排序算法是对已知的数列进行从小到大的递增排序。

题目描述冒泡排序算法是对已知的数列进行从小到大的递增排序每个实例输出两行&#xff0c;第一行输出第1轮结果, 第二行输出最终结果 它的排序方法如下: 1.对数列从头开始扫描&#xff0c;比较两个相邻的元素,如果前者大于后者,则交换两者位置 2.重复步骤1&#xff0c;直到没有…

RocketMQ源码剖析之createUniqID方法

目录 版本信息&#xff1a; 写在前面&#xff1a; 源码剖析&#xff1a; 总计&#xff1a; 版本信息&#xff1a; RocketMQ-5.1.3 源码地址&#xff1a;https://github.com/apache/rocketmq 写在前面&#xff1a; 首先&#xff0c;笔者先吐槽一下RocketMQ的官方&#xff0…

attention中Q,K,V的理解

第一种 1.首先定义三个线性变换矩阵&#xff0c;query&#xff0c;key&#xff0c;value&#xff1a; class BertSelfAttention(nn.Module):self.query nn.Linear(config.hidden_size, self.all_head_size) # 输入768&#xff0c; 输出768self.key nn.Linear(config.hidde…

上海线下活动 | LLM 时代的 AI 编译器实践与创新

今年 3 月份&#xff0c; 2023 Meet TVM 系列首次线下活动从上海出发&#xff0c;跨越多个城市&#xff0c;致力于为各地关注 AI 编译器的工程师提供一个学习、交流的平台。 12 月 16 日 2023 Meet TVM 年终聚会将重返上海&#xff0c;这一次我们不仅邀请了 4 位资深的 AI 编…

自动伸缩:解密HPA、VPA、CA和CPA智能调整应用大小和数量

关注【云原生百宝箱】公众号&#xff0c;快速掌握云原生 Kubernetes提供了多种自动伸缩机制&#xff0c;例如HPA&#xff08;Horizontal Pod Autoscaling&#xff09;&#xff0c;可以根据不同情况动态调整Pod副本数量。此功能使 Pod 能够有效地处理当前流量&#xff0c;而无需…