卷积神经网络——LeNet——FashionMNIST

目录

  • 一、文件结构
  • 二、model.py
  • 三、model_train.py
  • 四、model_test.py

一、文件结构

在这里插入图片描述

二、model.py

import torch
from torch import nn
from torchsummary import summaryclass LeNet(nn.Module):def __init__(self):super(LeNet,self).__init__()self.c1 = nn.Conv2d(in_channels=1,out_channels=6,kernel_size=5,padding=2)self.sig = nn.Sigmoid()self.s2 = nn.AvgPool2d(kernel_size=2,stride=2)self.c3 = nn.Conv2d(in_channels=6,out_channels=16,kernel_size=5)self.s4 = nn.AvgPool2d(kernel_size=2,stride=2)self.flatten = nn.Flatten()self.f5 = nn.Linear(in_features=5*5*16,out_features=120)self.f6 = nn.Linear(in_features=120,out_features=84)self.f7 = nn.Linear(in_features=84,out_features=10)def forward(self,x):x = self.sig(self.c1(x))x = self.s2(x)x = self.sig(self.c3(x))x = self.s4(x)x = self.flatten(x)x = self.f5(x)x = self.f6(x)x = self.f7(x)return x# if __name__ =="__main__":
#     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#
#     model = LeNet().to(device)
#
#     print(summary(model,input_size=(1,28,28)))

三、model_train.py

# 导入所需的Python库
from torchvision.datasets import FashionMNIST
from torchvision import transforms
import torch.utils.data as Data
import torch
from torch import nn
import time
import copy
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from model import LeNet  # model.py中定义了LeNet模型
from tqdm import tqdm  # 导入tqdm库,用于显示进度条# 定义数据加载和处理函数
def train_val_data_process():# 加载FashionMNIST数据集,Resize到28x28尺寸,并转换为Tensortrain_data = FashionMNIST(root="./data",train=True,transform=transforms.Compose([transforms.Resize(size=28), transforms.ToTensor()]),download=True)# 将加载的数据集分为80%的训练数据和20%的验证数据train_data, val_data = Data.random_split(train_data, lengths=[round(0.8 * len(train_data)), round(0.2 * len(train_data))])# 为训练数据和验证数据创建DataLoader,设置批量大小为32,洗牌,2个进程加载数据train_dataloader = Data.DataLoader(dataset=train_data,batch_size=32,shuffle=True,num_workers=2)val_dataloader = Data.DataLoader(dataset=val_data,batch_size=32,shuffle=True,num_workers=2)# 返回训练和验证的DataLoaderreturn train_dataloader, val_dataloader# 定义模型训练和验证过程的函数
def train_model_process(model, train_dataloader, val_dataloader, num_epochs):# 设置使用CUDA如果可用device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 打印使用的设备dev = "cuda" if torch.cuda.is_available() else "cpu"print(f'当前模型训练设备为: {dev}')# 初始化Adam优化器和交叉熵损失函数optimizer = torch.optim.Adam(model.parameters(), lr=0.001)criterion = nn.CrossEntropyLoss()# 将模型移动到选定的设备上model = model.to(device)# 复制模型权重用于后续更新最佳模型best_model_wts = copy.deepcopy(model.state_dict())best_acc = 0.0  # 初始化最佳准确度# 初始化用于记录训练和验证过程中损失和准确度的列表train_loss_all = []val_loss_all = []train_acc_all = []val_acc_all = []# 记录训练开始时间start_time = time.time()# 迭代指定的训练轮数for epoch in range(1, num_epochs + 1):# 记录每个epoch开始的时间since = time.time()# 打印分隔符和当前epoch信息print("-" * 10)print(f"Epoch: {epoch}/{num_epochs}")# 初始化训练和验证过程中的损失和正确预测数量train_loss = 0.0train_corrects = 0val_loss = 0.0val_corrects = 0# 初始化批次计数器train_num = 0val_num = 0# 创建训练进度条progress_train_bar = tqdm(total=len(train_dataloader), desc=f'Training {epoch}', unit='batch')# 训练数据集的遍历for step, (b_x, b_y) in enumerate(train_dataloader):# 将数据移动到相应的设备上b_x = b_x.to(device)b_y = b_y.to(device)# 训练模型model.train()# 前向传播output = model(b_x)# 计算预测标签pre_label = torch.argmax(output, dim=1)# 计算损失loss = criterion(output, b_y)# 清空梯度optimizer.zero_grad()# 反向传播loss.backward()# 更新权重optimizer.step()# 累加损失和正确预测数量train_loss += loss.item() * b_x.size(0)train_corrects += torch.sum(pre_label == b_y.data)# 更新批次计数器train_num += b_x.size(0)# 更新训练进度条progress_train_bar.update(1)# 关闭训练进度条progress_train_bar.close()# 创建验证进度条progress_val_bar = tqdm(total=len(val_dataloader), desc=f'Validation {epoch}', unit='batch')# 验证数据集的遍历for step, (b_x, b_y) in enumerate(val_dataloader):# 将数据移动到相应的设备上b_x = b_x.to(device)b_y = b_y.to(device)# 评估模型model.eval()# 前向传播output = model(b_x)# 计算预测标签pre_label = torch.argmax(output, dim=1)# 计算损失loss = criterion(output, b_y)# 累加损失和正确预测数量val_loss += loss.item() * b_x.size(0)val_corrects += torch.sum(pre_label == b_y.data)# 更新批次计数器val_num += b_x.size(0)# 更新验证进度条progress_val_bar.update(1)# 关闭验证进度条progress_val_bar.close()# 计算并记录epoch的平均损失和准确度train_loss_all.append(train_loss / train_num)train_acc_all.append(train_corrects.double().item() / train_num)val_loss_all.append(val_loss / val_num)val_acc_all.append(val_corrects.double().item() / val_num)# 打印训练和验证的损失与准确度print(f'{epoch} Train Loss: {train_loss_all[-1]:.4f} Train Acc: {train_acc_all[-1]:.4f}')print(f'{epoch} Val Loss: {val_loss_all[-1]:.4f} Val Acc: {val_acc_all[-1]:.4f}')# 计算并打印epoch训练耗费的时间time_use = time.time() - sinceprint(f'第 {epoch} 个 epoch 训练耗费时间: {time_use // 60:.0f}m {time_use % 60:.0f}s')# 若当前epoch的验证准确度为最佳,则更新最佳模型权重if val_acc_all[-1] > best_acc:best_acc = val_acc_all[-1]best_model_wts = copy.deepcopy(model.state_dict())# 训练结束,保存最佳模型权重torch.save(best_model_wts, 'D:/Pycharm/deepl/LeNet/weight/best_model.pth')# 如果当前epoch为总epoch数,则保存最终模型权重if epoch == num_epochs:torch.save(model.state_dict(), f'D:/Pycharm/deepl/LeNet/weight/{num_epochs}_model.pth')# 将训练过程中的统计数据整理成DataFrametrain_process = pd.DataFrame(data={"epoch": range(1, num_epochs + 1),"train_loss_all": train_loss_all,"val_loss_all": val_loss_all,"train_acc_all": train_acc_all,"val_acc_all": val_acc_all})# 打印总训练时间consume_time = time.time() - start_timeprint(f'总耗时:{consume_time // 60:.0f}m {consume_time % 60:.0f}s')# 返回包含训练过程统计数据的DataFramereturn train_process# 定义绘制训练和验证过程中损失与准确度的函数
def matplot_acc_loss(train_process):# 创建图形和子图plt.figure(figsize=(12, 4))# 绘制训练和验证损失plt.subplot(1, 2, 1)plt.plot(train_process["epoch"], train_process["train_loss_all"], 'ro-', label="train_loss")plt.plot(train_process["epoch"], train_process["val_loss_all"], 'bs-', label="val_loss")plt.legend()plt.xlabel("epoch")plt.ylabel("loss")# 保存损失图像plt.savefig('./result_picture/training_loss_accuracy.png', bbox_inches='tight')# 绘制训练和验证准确度plt.subplot(1, 2, 2)plt.plot(train_process["epoch"], train_process["train_acc_all"], 'ro-', label="train_acc")plt.plot(train_process["epoch"], train_process["val_acc_all"], 'bs-', label="val_acc")plt.legend()plt.xlabel("epoch")plt.ylabel("accuracy")# 保存准确率曲线图plt.savefig('./result_picture/training_accuracy.png', bbox_inches='tight')plt.show()if __name__ == "__main__":model = LeNet()train_dataloader, val_dataloader = train_val_data_process()train_process = train_model_process(model, train_dataloader, val_dataloader, num_epochs=20)matplot_acc_loss(train_process)

四、model_test.py

import torch
import torch.utils.data as Data
from torchvision import transforms
from torchvision.datasets import FashionMNIST
from model import LeNet
from sklearn.metrics import confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt
# t代表testdef t_data_process():test_data = FashionMNIST(root="./data",train=False,transform=transforms.Compose([transforms.Resize(size=28), transforms.ToTensor()]),download=True)test_dataloader = Data.DataLoader(dataset=test_data,batch_size=1,shuffle=True,num_workers=0)return test_dataloaderdef t_model_process(model, test_dataloader):if model is not None:print('Successfully loaded the model.')device = "cuda" if torch.cuda.is_available() else "cpu"model = model.to(device)# 初始化参数test_corrects = 0.0test_num = 0all_preds = []  # 存储所有预测标签all_labels = []  # 存储所有实际标签# 只进行前向传播,不计算梯度with torch.no_grad():for test_x, test_y in test_dataloader:test_x = test_x.to(device)test_y = test_y.to(device)# 设置模型为验证模式model.eval()# 前向传播得到一个batch的结果output = model(test_x)# 查找最大值对应的行标pre_lab = torch.argmax(output, dim=1)# 收集预测和实际标签all_preds.extend(pre_lab.tolist())all_labels.extend(test_y.tolist())# 计算准确率test_corrects += torch.sum(pre_lab == test_y.data)# 将所有的测试样本进行累加test_num += test_x.size(0)# 计算准确率test_acc = test_corrects.double().item() / test_numprint(f'测试的准确率:{test_acc}')# 绘制混淆矩阵conf_matrix = confusion_matrix(all_labels, all_preds)sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues')plt.xlabel('Predicted labels')plt.ylabel('True labels')plt.title('Confusion Matrix')plt.show()plt.savefig('./result_picture/Confusion_Matrix.png', bbox_inches='tight')if __name__=="__main__":# 加载模型model = LeNet()print('loading model')# 加载权重model.load_state_dict(torch.load('D:/Pycharm/deepl/LeNet/weight/best_model.pth'))# 加载测试数据test_dataloader = t_data_process()# 加载模型测试的函数t_model_process(model,test_dataloader)device = "cuda" if torch.cuda.is_available() else "cpu"model = model.to(device)classes = ['T-shirt/top','Trouser','Pullover','Dress','coat','Sandal','Shirt','Sneaker','Bag','Ankle boot']with torch.no_grad():for b_x,b_y in test_dataloader:b_x = b_x.to(device)b_y = b_y.to(device)model.eval()output = model(b_x)pre_lab = torch.argmax(output,dim=1)result = pre_lab.item()label = b_y.item()print(f'预测值:{classes[result]}',"-----------",f'真实值:{classes[label]}')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/45831.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Autosar Dcm配置-0x28服务ComControl-基于ETAS软件

文章目录 前言DcmDcmDsdDcmDspBswMBswMModeRequestPortBswMModeConditionBswMLogicalExpressionBswMActionBswMActionListBswMRule总结前言 0x28服务主要用来控制非诊断报文的通讯,一般在刷写预编程过程中,用来禁止APP的通信报文,可以减少总线负载率,提高刷写成功率。本文…

[C++] STL :stackqueue详解 及 模拟实现

标题:[C] STL :stack&&queue详解 水墨不写bug 目录 (一)stack简介 (二)queue简介 (三)容器适配器 (四)stack和queue的模拟实现 /*** …

liunx作业笔记1

一、选择题(每小题2分,共20分) 1、下列变量命名为Shell中无效变量名的是( D ) A、v_ar1 B、var1 C、_var D、*var 变量名以字母开头,包含下划线和数字。 2、关于expr命令的使用下列命令中得数不等于…

论文创新点总结

TOC 双分支模型 今天读的这篇文章中提到了一种以前没有接触的模型,这个模型使用了双分支的网络来处理图像增强的问题(将图像增强问题分解为亮度调整和色度恢复两个子问题),其中一个分支为亮度调整网络(LAN&#xff0…

数据结构(初阶1.复杂度)

文章目录 一、复杂度概念 二、时间复杂度 2.1 大O的渐进表示法 2.2 时间复杂度计算示例 2.2.1. // 计算Func2的时间复杂度? 2.2.2.// 计算Func3的时间复杂度? 2.2.3.// 计算Func4的时间复杂度? 2.2.4.// 计算strchr的时间复杂度? …

构造者模式的实现

引言——构造复杂对象的艺术 软件工程中,构造复杂对象的艺术被巧妙地封装在构造者模式(Builder Pattern)中。这种设计模式不仅提供了一种清晰且灵活的方式来构建复杂对象,还使得代码更具可读性和可维护性。构造者模式的核心思想是…

Unity最新第三方开源插件《Stateful Component》管理中大型项目MonoBehaviour各种序列化字段 ,的高级解决方案

上文提到了UIState, ObjectRefactor等,还提到了远古的NGUI, KBEngine-UI等 这个算是比较新的解决方法吧,但是抽象出来,问题还是这些个问题 所以你就说做游戏是不是先要解决这些问题? 而不是高大上的UiImage,DoozyUI等 Mono管理引用基本用法 ① 添加Stateful Component …

python的变量与赋值

变量 定义:变量是内存中存储数据的标识符。命名规则:变量名可以包含字母、数字、下划线,但不能以数字开头。例如:name, _var, var2 都是合法的变量名,但 2var 是不合法的。 赋值 操作符:使用 符号进行赋…

安全测试理论

安全测试理论 什么是安全测试? 安全测试:发现系统安全隐患的过程安全测试与传统测试区别 传统测试:发现bug为目的 安全测试:发现系统安全隐患什么是渗透测试 渗透测试:已成功入侵系统为目标的的攻击过程渗透测试与安全…

ES6 Generator函数的异步应用 (八)

ES6 Generator 函数的异步应用主要通过与 Promise 配合使用来实现。这种模式被称为 “thunk” 模式,它允许你编写看起来是同步的异步代码。 特性: 暂停执行:当 Generator 函数遇到 yield 表达式时,它会暂停执行,等待 …

“好物”推荐+Xshell连接实例+使用Conda创建独立的Python环境

目录 主题:好易智算平台推荐RTX 4090DGPU实例租用演示安装配置torch1.9.1cuda11.1.1环境引言:算力的新时代平台介绍:技术与信任的结晶使用案例:实际使用展示创建实例开始使用连接实例(下文演示使用Xshell连接&#xff…

Android Studio下载与安装

Android Studio下载与安装_android studio下载安装-CSDN博客

昇思25天学习打卡营第二十天|基于MobileNetv2的垃圾分类

打卡营第二十天,今天学习的内容是MobileNet垃圾分类,记录一下学习内容: 学习内容 本文档主要介绍垃圾分类代码开发的方法。通过读取本地图像数据作为输入,对图像中的垃圾物体进行检测,并且将检测结果图片保存到文件中…

Jupyter Notebook 使用教程

Jupyter Notebook 使用教程 目录 概述启动Jupyter Notebook创建新的NotebookNotebook界面介绍使用代码单元格使用Markdown单元格Notebook的基本操作保存和导出Notebook扩展功能和技巧 1. 概述 Jupyter Notebook是一个开源的Web应用程序,允许您创建和共享包含代码…

VMM、VMI、VIM的简介

VMM 指的是虚拟机监控器(Virtual Machine Monitor),也被称为虚拟化管理程序(Hypervisor)。VMM 是一种软件、固件或硬件的组合,它在物理硬件和虚拟机之间充当中介。其主要功能是创建、管理和运行虚拟机&…

【ARM】CCI集成指导整理

目录 1.CCI集成流程 2.CCI功能集成指导 2.1CCI结构框图解释 Request concentrator Transaction tracker Read-data Network Write-data Network B-response Network 2.2 接口注意项 记录一下CCI500的ACE slave interface不支持的功能: 对于ACE-Lite slav…

基于信号处理的PPG信号滤波降噪方法(MATLAB)

光电容积脉搏波PPG信号结合相关算法可以用于人体生理参数检测,如血压、血氧饱和度等,但采集过程中极易受到噪声干扰,对于血压、血氧饱和度测量的准确性造成影响。随着当今社会医疗保健技术的发展,可穿戴监测设备对于PPG信号的质量…

简单的SQL字符型注入

目录 注入类型 判断字段数 确定回显点 查找数据库名 查找数据库表名 查询字段名 获取想要的数据 以sqli-labs靶场上的简单SQL注入为例 注入类型 判断是数字类型还是字符类型 常见的闭合方式 ?id1、?id1"、?id1)、?id1")等,大多都是单引号…

【ASTGCN】模型调试学习笔记--数据生成详解(超详细)

利用滑动窗口生成时间序列 原理图示: 以PEMS04数据集为例。 该数据集维度为:(16992,307,3),16992表示时间序列的长度,307为探测器个数,即图的顶点个数,3为特征数,即流量,速度、平…

期权专题12:期权保证金和期权盈亏

目录 1. 期权保证金 1.1 计算逻辑 1.2 代码复现 1.3 实际案例 2. 期权盈亏 2.1 价格走势 2.2 计算公式 2.2.1 卖出期权 2.2.2 买入期权 免责声明:本文由作者参考相关资料,并结合自身实践和思考独立完成,对全文内容的准确性、完整性或…