卷积神经网络——LeNet——FashionMNIST

目录

  • 一、文件结构
  • 二、model.py
  • 三、model_train.py
  • 四、model_test.py

一、文件结构

在这里插入图片描述

二、model.py

import torch
from torch import nn
from torchsummary import summaryclass LeNet(nn.Module):def __init__(self):super(LeNet,self).__init__()self.c1 = nn.Conv2d(in_channels=1,out_channels=6,kernel_size=5,padding=2)self.sig = nn.Sigmoid()self.s2 = nn.AvgPool2d(kernel_size=2,stride=2)self.c3 = nn.Conv2d(in_channels=6,out_channels=16,kernel_size=5)self.s4 = nn.AvgPool2d(kernel_size=2,stride=2)self.flatten = nn.Flatten()self.f5 = nn.Linear(in_features=5*5*16,out_features=120)self.f6 = nn.Linear(in_features=120,out_features=84)self.f7 = nn.Linear(in_features=84,out_features=10)def forward(self,x):x = self.sig(self.c1(x))x = self.s2(x)x = self.sig(self.c3(x))x = self.s4(x)x = self.flatten(x)x = self.f5(x)x = self.f6(x)x = self.f7(x)return x# if __name__ =="__main__":
#     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#
#     model = LeNet().to(device)
#
#     print(summary(model,input_size=(1,28,28)))

三、model_train.py

# 导入所需的Python库
from torchvision.datasets import FashionMNIST
from torchvision import transforms
import torch.utils.data as Data
import torch
from torch import nn
import time
import copy
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from model import LeNet  # model.py中定义了LeNet模型
from tqdm import tqdm  # 导入tqdm库,用于显示进度条# 定义数据加载和处理函数
def train_val_data_process():# 加载FashionMNIST数据集,Resize到28x28尺寸,并转换为Tensortrain_data = FashionMNIST(root="./data",train=True,transform=transforms.Compose([transforms.Resize(size=28), transforms.ToTensor()]),download=True)# 将加载的数据集分为80%的训练数据和20%的验证数据train_data, val_data = Data.random_split(train_data, lengths=[round(0.8 * len(train_data)), round(0.2 * len(train_data))])# 为训练数据和验证数据创建DataLoader,设置批量大小为32,洗牌,2个进程加载数据train_dataloader = Data.DataLoader(dataset=train_data,batch_size=32,shuffle=True,num_workers=2)val_dataloader = Data.DataLoader(dataset=val_data,batch_size=32,shuffle=True,num_workers=2)# 返回训练和验证的DataLoaderreturn train_dataloader, val_dataloader# 定义模型训练和验证过程的函数
def train_model_process(model, train_dataloader, val_dataloader, num_epochs):# 设置使用CUDA如果可用device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 打印使用的设备dev = "cuda" if torch.cuda.is_available() else "cpu"print(f'当前模型训练设备为: {dev}')# 初始化Adam优化器和交叉熵损失函数optimizer = torch.optim.Adam(model.parameters(), lr=0.001)criterion = nn.CrossEntropyLoss()# 将模型移动到选定的设备上model = model.to(device)# 复制模型权重用于后续更新最佳模型best_model_wts = copy.deepcopy(model.state_dict())best_acc = 0.0  # 初始化最佳准确度# 初始化用于记录训练和验证过程中损失和准确度的列表train_loss_all = []val_loss_all = []train_acc_all = []val_acc_all = []# 记录训练开始时间start_time = time.time()# 迭代指定的训练轮数for epoch in range(1, num_epochs + 1):# 记录每个epoch开始的时间since = time.time()# 打印分隔符和当前epoch信息print("-" * 10)print(f"Epoch: {epoch}/{num_epochs}")# 初始化训练和验证过程中的损失和正确预测数量train_loss = 0.0train_corrects = 0val_loss = 0.0val_corrects = 0# 初始化批次计数器train_num = 0val_num = 0# 创建训练进度条progress_train_bar = tqdm(total=len(train_dataloader), desc=f'Training {epoch}', unit='batch')# 训练数据集的遍历for step, (b_x, b_y) in enumerate(train_dataloader):# 将数据移动到相应的设备上b_x = b_x.to(device)b_y = b_y.to(device)# 训练模型model.train()# 前向传播output = model(b_x)# 计算预测标签pre_label = torch.argmax(output, dim=1)# 计算损失loss = criterion(output, b_y)# 清空梯度optimizer.zero_grad()# 反向传播loss.backward()# 更新权重optimizer.step()# 累加损失和正确预测数量train_loss += loss.item() * b_x.size(0)train_corrects += torch.sum(pre_label == b_y.data)# 更新批次计数器train_num += b_x.size(0)# 更新训练进度条progress_train_bar.update(1)# 关闭训练进度条progress_train_bar.close()# 创建验证进度条progress_val_bar = tqdm(total=len(val_dataloader), desc=f'Validation {epoch}', unit='batch')# 验证数据集的遍历for step, (b_x, b_y) in enumerate(val_dataloader):# 将数据移动到相应的设备上b_x = b_x.to(device)b_y = b_y.to(device)# 评估模型model.eval()# 前向传播output = model(b_x)# 计算预测标签pre_label = torch.argmax(output, dim=1)# 计算损失loss = criterion(output, b_y)# 累加损失和正确预测数量val_loss += loss.item() * b_x.size(0)val_corrects += torch.sum(pre_label == b_y.data)# 更新批次计数器val_num += b_x.size(0)# 更新验证进度条progress_val_bar.update(1)# 关闭验证进度条progress_val_bar.close()# 计算并记录epoch的平均损失和准确度train_loss_all.append(train_loss / train_num)train_acc_all.append(train_corrects.double().item() / train_num)val_loss_all.append(val_loss / val_num)val_acc_all.append(val_corrects.double().item() / val_num)# 打印训练和验证的损失与准确度print(f'{epoch} Train Loss: {train_loss_all[-1]:.4f} Train Acc: {train_acc_all[-1]:.4f}')print(f'{epoch} Val Loss: {val_loss_all[-1]:.4f} Val Acc: {val_acc_all[-1]:.4f}')# 计算并打印epoch训练耗费的时间time_use = time.time() - sinceprint(f'第 {epoch} 个 epoch 训练耗费时间: {time_use // 60:.0f}m {time_use % 60:.0f}s')# 若当前epoch的验证准确度为最佳,则更新最佳模型权重if val_acc_all[-1] > best_acc:best_acc = val_acc_all[-1]best_model_wts = copy.deepcopy(model.state_dict())# 训练结束,保存最佳模型权重torch.save(best_model_wts, 'D:/Pycharm/deepl/LeNet/weight/best_model.pth')# 如果当前epoch为总epoch数,则保存最终模型权重if epoch == num_epochs:torch.save(model.state_dict(), f'D:/Pycharm/deepl/LeNet/weight/{num_epochs}_model.pth')# 将训练过程中的统计数据整理成DataFrametrain_process = pd.DataFrame(data={"epoch": range(1, num_epochs + 1),"train_loss_all": train_loss_all,"val_loss_all": val_loss_all,"train_acc_all": train_acc_all,"val_acc_all": val_acc_all})# 打印总训练时间consume_time = time.time() - start_timeprint(f'总耗时:{consume_time // 60:.0f}m {consume_time % 60:.0f}s')# 返回包含训练过程统计数据的DataFramereturn train_process# 定义绘制训练和验证过程中损失与准确度的函数
def matplot_acc_loss(train_process):# 创建图形和子图plt.figure(figsize=(12, 4))# 绘制训练和验证损失plt.subplot(1, 2, 1)plt.plot(train_process["epoch"], train_process["train_loss_all"], 'ro-', label="train_loss")plt.plot(train_process["epoch"], train_process["val_loss_all"], 'bs-', label="val_loss")plt.legend()plt.xlabel("epoch")plt.ylabel("loss")# 保存损失图像plt.savefig('./result_picture/training_loss_accuracy.png', bbox_inches='tight')# 绘制训练和验证准确度plt.subplot(1, 2, 2)plt.plot(train_process["epoch"], train_process["train_acc_all"], 'ro-', label="train_acc")plt.plot(train_process["epoch"], train_process["val_acc_all"], 'bs-', label="val_acc")plt.legend()plt.xlabel("epoch")plt.ylabel("accuracy")# 保存准确率曲线图plt.savefig('./result_picture/training_accuracy.png', bbox_inches='tight')plt.show()if __name__ == "__main__":model = LeNet()train_dataloader, val_dataloader = train_val_data_process()train_process = train_model_process(model, train_dataloader, val_dataloader, num_epochs=20)matplot_acc_loss(train_process)

四、model_test.py

import torch
import torch.utils.data as Data
from torchvision import transforms
from torchvision.datasets import FashionMNIST
from model import LeNet
from sklearn.metrics import confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt
# t代表testdef t_data_process():test_data = FashionMNIST(root="./data",train=False,transform=transforms.Compose([transforms.Resize(size=28), transforms.ToTensor()]),download=True)test_dataloader = Data.DataLoader(dataset=test_data,batch_size=1,shuffle=True,num_workers=0)return test_dataloaderdef t_model_process(model, test_dataloader):if model is not None:print('Successfully loaded the model.')device = "cuda" if torch.cuda.is_available() else "cpu"model = model.to(device)# 初始化参数test_corrects = 0.0test_num = 0all_preds = []  # 存储所有预测标签all_labels = []  # 存储所有实际标签# 只进行前向传播,不计算梯度with torch.no_grad():for test_x, test_y in test_dataloader:test_x = test_x.to(device)test_y = test_y.to(device)# 设置模型为验证模式model.eval()# 前向传播得到一个batch的结果output = model(test_x)# 查找最大值对应的行标pre_lab = torch.argmax(output, dim=1)# 收集预测和实际标签all_preds.extend(pre_lab.tolist())all_labels.extend(test_y.tolist())# 计算准确率test_corrects += torch.sum(pre_lab == test_y.data)# 将所有的测试样本进行累加test_num += test_x.size(0)# 计算准确率test_acc = test_corrects.double().item() / test_numprint(f'测试的准确率:{test_acc}')# 绘制混淆矩阵conf_matrix = confusion_matrix(all_labels, all_preds)sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues')plt.xlabel('Predicted labels')plt.ylabel('True labels')plt.title('Confusion Matrix')plt.show()plt.savefig('./result_picture/Confusion_Matrix.png', bbox_inches='tight')if __name__=="__main__":# 加载模型model = LeNet()print('loading model')# 加载权重model.load_state_dict(torch.load('D:/Pycharm/deepl/LeNet/weight/best_model.pth'))# 加载测试数据test_dataloader = t_data_process()# 加载模型测试的函数t_model_process(model,test_dataloader)device = "cuda" if torch.cuda.is_available() else "cpu"model = model.to(device)classes = ['T-shirt/top','Trouser','Pullover','Dress','coat','Sandal','Shirt','Sneaker','Bag','Ankle boot']with torch.no_grad():for b_x,b_y in test_dataloader:b_x = b_x.to(device)b_y = b_y.to(device)model.eval()output = model(b_x)pre_lab = torch.argmax(output,dim=1)result = pre_lab.item()label = b_y.item()print(f'预测值:{classes[result]}',"-----------",f'真实值:{classes[label]}')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/45831.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Autosar Dcm配置-0x28服务ComControl-基于ETAS软件

文章目录 前言DcmDcmDsdDcmDspBswMBswMModeRequestPortBswMModeConditionBswMLogicalExpressionBswMActionBswMActionListBswMRule总结前言 0x28服务主要用来控制非诊断报文的通讯,一般在刷写预编程过程中,用来禁止APP的通信报文,可以减少总线负载率,提高刷写成功率。本文…

[C++] STL :stackqueue详解 及 模拟实现

标题:[C] STL :stack&&queue详解 水墨不写bug 目录 (一)stack简介 (二)queue简介 (三)容器适配器 (四)stack和queue的模拟实现 /*** …

数据结构(初阶1.复杂度)

文章目录 一、复杂度概念 二、时间复杂度 2.1 大O的渐进表示法 2.2 时间复杂度计算示例 2.2.1. // 计算Func2的时间复杂度? 2.2.2.// 计算Func3的时间复杂度? 2.2.3.// 计算Func4的时间复杂度? 2.2.4.// 计算strchr的时间复杂度? …

构造者模式的实现

引言——构造复杂对象的艺术 软件工程中,构造复杂对象的艺术被巧妙地封装在构造者模式(Builder Pattern)中。这种设计模式不仅提供了一种清晰且灵活的方式来构建复杂对象,还使得代码更具可读性和可维护性。构造者模式的核心思想是…

Unity最新第三方开源插件《Stateful Component》管理中大型项目MonoBehaviour各种序列化字段 ,的高级解决方案

上文提到了UIState, ObjectRefactor等,还提到了远古的NGUI, KBEngine-UI等 这个算是比较新的解决方法吧,但是抽象出来,问题还是这些个问题 所以你就说做游戏是不是先要解决这些问题? 而不是高大上的UiImage,DoozyUI等 Mono管理引用基本用法 ① 添加Stateful Component …

安全测试理论

安全测试理论 什么是安全测试? 安全测试:发现系统安全隐患的过程安全测试与传统测试区别 传统测试:发现bug为目的 安全测试:发现系统安全隐患什么是渗透测试 渗透测试:已成功入侵系统为目标的的攻击过程渗透测试与安全…

“好物”推荐+Xshell连接实例+使用Conda创建独立的Python环境

目录 主题:好易智算平台推荐RTX 4090DGPU实例租用演示安装配置torch1.9.1cuda11.1.1环境引言:算力的新时代平台介绍:技术与信任的结晶使用案例:实际使用展示创建实例开始使用连接实例(下文演示使用Xshell连接&#xff…

昇思25天学习打卡营第二十天|基于MobileNetv2的垃圾分类

打卡营第二十天,今天学习的内容是MobileNet垃圾分类,记录一下学习内容: 学习内容 本文档主要介绍垃圾分类代码开发的方法。通过读取本地图像数据作为输入,对图像中的垃圾物体进行检测,并且将检测结果图片保存到文件中…

【ARM】CCI集成指导整理

目录 1.CCI集成流程 2.CCI功能集成指导 2.1CCI结构框图解释 Request concentrator Transaction tracker Read-data Network Write-data Network B-response Network 2.2 接口注意项 记录一下CCI500的ACE slave interface不支持的功能: 对于ACE-Lite slav…

基于信号处理的PPG信号滤波降噪方法(MATLAB)

光电容积脉搏波PPG信号结合相关算法可以用于人体生理参数检测,如血压、血氧饱和度等,但采集过程中极易受到噪声干扰,对于血压、血氧饱和度测量的准确性造成影响。随着当今社会医疗保健技术的发展,可穿戴监测设备对于PPG信号的质量…

简单的SQL字符型注入

目录 注入类型 判断字段数 确定回显点 查找数据库名 查找数据库表名 查询字段名 获取想要的数据 以sqli-labs靶场上的简单SQL注入为例 注入类型 判断是数字类型还是字符类型 常见的闭合方式 ?id1、?id1"、?id1)、?id1")等,大多都是单引号…

【ASTGCN】模型调试学习笔记--数据生成详解(超详细)

利用滑动窗口生成时间序列 原理图示: 以PEMS04数据集为例。 该数据集维度为:(16992,307,3),16992表示时间序列的长度,307为探测器个数,即图的顶点个数,3为特征数,即流量,速度、平…

期权专题12:期权保证金和期权盈亏

目录 1. 期权保证金 1.1 计算逻辑 1.2 代码复现 1.3 实际案例 2. 期权盈亏 2.1 价格走势 2.2 计算公式 2.2.1 卖出期权 2.2.2 买入期权 免责声明:本文由作者参考相关资料,并结合自身实践和思考独立完成,对全文内容的准确性、完整性或…

[CISCN 2023 华北]normal_snake

[CISCN 2023 华北]normal_snake 源码和依赖 算了直接说吧,不想截图了,就多了一个C3P0和yaml的依赖 然后read路由可以反序列化yaml的Str 我们看到waf 那个String是可以二次反序列化绕过的,然后CUSTOM_STRING1解码后是"BadAttributeValuePairExcept…

【java】力扣 反转链表

力扣 206 链表反转 题目介绍 解法讲解 先定义两个游标indexnull,prenull,反转之后链表应该是5,4,3,2,1,我们先进行2->1的反转,然后再循坏即可 让定义的游标index去存储head.n…

MySQL设置白名单限制

白名单(Whitelist)是一种机制,用于限制哪些主机可以连接到服务器,而阻止其他主机的访问。通过配置白名单,可以增加服务器的安全性,防止未授权的访问。 在MySQL数据库中直接设置白名单访问(即限制…

【触摸屏】【地震知识宣传系统】功能模块:视频 + 知识问答

项目背景 鉴于地震知识的普及对于提升公众防灾减灾意识的重要性,客户希望开发一套互动性强、易于理解的地震学习系统,面向公众、学生及专业人员进行地震知识教育与应急技能培训。 产品功能 系统风格:严谨的设计风格和准确的信息呈现&#…

红酒的艺术之旅:品味、鉴赏与生活的整合

在繁忙的都市生活中,红酒如同一道不同的风景线,将品味、鉴赏与日常生活巧妙地整合在一起。它不仅仅是一种饮品,更是一种艺术,一种生活的态度。今天,就让我们一起踏上这趟红酒的艺术之旅,探寻雷盛红酒如何以…

【qt】如何读取文件并拆分信息?

需要用到QTextStream类 还有QFile类 对于文件的读取操作我们可以统一记下如下操作: 就这三板斧 获取到文件名用文件名初始化文件对象用文件对象初始化文本流 接下来就是打开文件了 用open()来打开文件 用readLine()来读取行数据 用atEnd()来判断是否读到结尾 用split()来获取…

02. Hibernate 初体验之持久化对象

1. 前言 本节课程让我们一起体验 Hibernate 的魅力!编写第一个基于 Hibernate 的实例程序。 在本节课程中,你将学到 : Hibernate 的版本发展史;持久化对象的特点。 为了更好地讲解这个内容,这个初体验案例分上下 2…