LeNet5实战——衣服分类

  • 搭建模型
  • 训练代码(数据处理、模型训练、性能指标)——> 产生权重w ——>模型结构c、w
  • 测试

配置环境

Pycharm刚配置的环境找不到了-CSDN博客

model.py

导入库

import torch  
from torch import nn  
from torchsummary import summary

模型搭建

 note:

  • stride 步幅为1,和默认值一样,不用写
  • padding=0,和默认一样不用写

代码

import torch  
from torch import nn  
from torchsummary import summary  class LeNet(nn.Module):  #初始化  def __init__(self):  super(LeNet,self).__init__()  self.c1=nn.Conv2d(in_channels=1,out_channels=6,kernel_size=5,padding=2)  self.sig=nn.Sigmoid()  self.s2=nn.AvgPool2d(kernel_size=2,stride=2)  self.c3=nn.Conv2d(in_channels=6,out_channels=16,kernel_size=5)  self.s4=nn.AvgPool2d(kernel_size=2,stride=2)  self.flatten=nn.Flatten()  self.f5 = nn.Linear(in_features=400,out_features=120)  self.f6 = nn.Linear(in_features=120, out_features=84)  self.f7 = nn.Linear(in_features=84, out_features=10)  def forward(self,x):  x = self.sig(self.c1(x))#经过卷积和激活  x=self.s2(x)  x=self.sig(self.c3(x))  x=self.s4(x)  x=self.flatten(x)  x=self.f5(x)  x=self.f6(x)  x=self.f7(x)  return x  if __name__=="__main__":  device = torch.device("cuda" if torch.cuda.is_available()else "cpu")  print(device)  model = LeNet().to(device)#实例化  print(summary(model,input_size=(1,28,28)))

前向传播结果

plot.py

模型加载

下载数据集

打包数据

为什么要移除一维? 

因为之前将数据打包成64一组,数据格式为64 *28 * 28 * 1,把64移除,剩下的28* 28 * 1就是图片格式

 获取图片数据

 可视化数据(图片)

代码

from torchvision.datasets import FashionMNIST  
from torchvision import transforms#处理数据集  
import torch.utils.data as Data  
import numpy as np  
import matplotlib.pyplot as plt  
from model import LeNet # 导入模型(没有训练的模型)  def train_val_data_process():  train_data = FashionMNIST(root='./data',  train=True,  transform=transforms.Compose([transforms.Resize(size=28), transforms.ToTensor()]),  # 转换成张量形式方便应用  download=True)  train_data,val_data = Data.random_split(train_data,lengths=(round(0.8*len(train_data)),round(0.2*len(train_data))))#随机划分数据  train_dateloader = Data.DataLoader(dataset=train_data,  batch_size=128,  shuffle=True,  num_workers=8)#进程  val_dateloader = Data.DataLoader(dataset=val_data,  batch_size=128,  shuffle=True,  num_workers=8)  return train_dateloader,val_dateloader  

可视化结果

一批次的图片(64张)

model_train.py

导入库

import copy  
import time  import torch  
from torchvision.datasets import FashionMNIST  
from torchvision import transforms  # 处理数据集  
import torch.utils.data as Data  
import numpy as np  
import matplotlib.pyplot as plt  
from model import LeNet  # 导入模型(没有训练的模型)  
import torch.nn as nn  
import pandas as pd
  • FashionMNIST数据集由Zalando研究团队创建,包含了10个不同类别的灰度图像。每个图像的尺寸为28x28像素,共有训练集和测试集两部分。(衣服分类数据集)
  • transforms模块提供了一种方便的方式来对图像数据进行常见的预处理操作,如缩放、裁剪、旋转、翻转、标准化等。它还可以用于将图像数据转换为张量(Tensor)格式,并根据需要进行其他转换操作。
  • torch.utils.data是PyTorch中的一个模块,提供了用于数据加载和预处理的工具类和函数。它提供了一种方便的方式来处理和准备数据,以供机器学习模型的训练和评估使用。torch.utils.data模块中的两个重要类是DatasetDataLoader
  • torch.nn模块包含了许多常用的神经网络层类,提供了各种损失函数。
  • pandas是一个功能强大且灵活的数据处理和分析库,它提供了高性能、易于使用的数据结构和数据分析工具

train_val_data_process()

代码

def train_val_data_process():  train_data = FashionMNIST(root='./data',  train=True,  transform=transforms.Compose([transforms.Resize(size=28), transforms.ToTensor()]),  # 转换成张量形式方便应用  download=True)  train_data, val_data = Data.random_split(train_data, lengths=(  round(0.8 * len(train_data)), round(0.2 * len(train_data))))  # 随机划分数据  train_dataloader = Data.DataLoader(dataset=train_data,  batch_size=32,  shuffle=True,  num_workers=2)  # 进程  val_dataloader = Data.DataLoader(dataset=val_data,  shuffle=True,  num_workers=2)  return train_dataloader, val_dataloader

FashinMNIST

FashionMNIST是一个用于图像分类的数据集,包含了10个类别的服装图像。 指定root参数为'./data'train参数为Truetransform参数为一个transforms.Compose对象,以及download参数为True,可以下载并加载FashionMNIST数据集。

transforms.Compose对象是一个数据预处理的组合,这里使用了transforms.Resize将图像大小调整为28×28,并使用transforms.ToTensor将图像转换为张量形式。

Data.random_split

将train_data按照8|2的比例随机划分给train_data和val_data

Data.DataLoader

  • dataset:指定要加载的数据集,这里是train_data,即训练数据集。
  • batch_size:指定每个批次中的样本数量,这里是32,表示每次加载32个样本。
  • shuffle:指定是否在每个迭代周期前打乱数据顺序,这里设置为True,表示在每个迭代周期前打乱数据顺序。
  • num_workers:指定用于数据加载的线程数,这里设置为2,表示使用2个进程进行数据加载。

train_model_process

代码

def train_model_process(model, train_dataloader, val_dataloader, num_epochs):  # 设定训练所用到的设备,有GPU用GPU,没有则用CPU  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # 使用Adam优化器,学习率为0.001(adam——优化的梯度下降法)  optimizer = torch.optim.Adam(model.parameters(), lr=0.001)  # 损失函数为交叉熵函数  criterion = nn.CrossEntropyLoss()  # 将模型放到训练设备中  model = model.to(device)  # 赋值当前模型的参数  best_model_wts = copy.deepcopy(model.state_dict())  # 初始化参数  # 最高精准度  best_acc = 0.0  # 训练集损失函数列表  train_loss_all = []  # 验证集损失函数列表  val_loss_all = []  # 训练集精度列表  train_acc_all = []  # 验证集精度列表  val_acc_all = []  # 当前时间  since = time.time()  for epoch in range(num_epochs):  print("Epoch {}/{}".format(epoch, num_epochs - 1))  print("-" * 10)  # 初始化参数  # 训练集损失函数  train_loss = 0.0  # 训练集准确度  train_corrects = 0  # 验证集损失函数  val_loss = 0.0  # 验证集准确度  val_corrects = 0  # 训练集样本数量  train_num = 0  # 验证集样本数量  val_num = 0  # 对每一个mini-batch训练和计算  for step, (b_x, b_y) in enumerate(train_dataloader):  # 将特征放入到训练设备中  b_x = b_x.to(device)  # 将标签放入到训练设备中  b_y = b_y.to(device)  # 设置模型为训练模式  model.train()  # 前向传播过程,输入为一个batch,输出为一个batch中对应的预测  output = model(b_x)  # 查找每一行中最大值对应的行标  pre_lab = torch.argmax(output, dim=1)  # 模型的输出和标签计算损失函数  loss = criterion(output, b_y)  # 将梯度初始化为0  optimizer.zero_grad()  # 反向传播计算  loss.backward()  # 根据网络反向传播的梯度信息来更新网络的参数,以起到降低loss函数计算值的作用  optimizer.step()  # 对损失函数进行累加  train_loss += loss.item() * b_x.size(0)  # 如果预测正确,则准确度train_corrects+1  train_corrects += torch.sum(pre_lab == b_y.data)  # 当前用于训练的样本数量  train_num += b_x.size(0)  for step, (b_x, b_y) in enumerate(val_dataloader):  b_x = b_x.to(device)  b_y = b_y.to(device)  # 设置模型为验证模式  model.eval()  output = model(b_x)  pre_lab = torch.argmax(output, dim=1)  loss = criterion(output, b_y)  val_loss += loss.item() * b_x.size(0)  val_corrects += torch.sum(pre_lab == b_y.data)  val_num += b_x.size(0)  # 计算并保存每一次迭代的loss值  train_loss_all.append(train_loss / train_num)  # 计算并保存训练集的准确率  train_acc_all.append(train_corrects.double().item() / train_num)  val_loss_all.append(val_loss / val_num)  val_acc_all.append(val_corrects.double().item() / val_num)  print('{} Train Loss:{:.4f} Train Acc:{:.4f}'.format(epoch, train_loss_all[-1], train_acc_all[-1]))  print('{} Val Loss:{:.4f} Val Acc: {:.4f}'.format(epoch, val_loss_all[-1], val_acc_all[-1]))  # 寻找最高准确度的权重  if val_acc_all[-1] > best_acc:  best_acc = val_acc_all[-1]  best_model_wts = copy.deepcopy(model.state_dict())  # 训练耗时  time_use = time.time() - since  print("训练耗费的时间:{:0f}m{:0f}s".format(time_use // 60, time_use % 60))  # 选择最优参数  # 加载最高准确率下的模型参数  torch.save(best_model_wts, 'E:/CODE/python/LeNet5/best_model.pth')  train_process = pd.DataFrame(data={"epoch": range(num_epochs),  "train_loss_all": train_loss_all,  "val_loss_all": val_loss_all,  "train_acc_all": train_acc_all,  "val_acc_all": val_acc_all})  return train_process

 准备

 一个迭代周期

初始化参数

对一批次的数据进行训练
遍历数据

for循环

for step, (b_x, b_y) in enumerate(train_dataloader): 是一个 for 循环语句的语法结构,用于迭代遍历一个可迭代对象 train_dataloader。 在每次循环迭代中,enumerate(train_dataloader) 将返回一个 (step, (b_x, b_y)) 的元组,其中: step 是当前迭代的索引值,表示当前是第几个迭代步骤。 (b_x, b_y) 是从 train_dataloader 中获取的一个批次的数据。

前向传播

模型的输出和标签计算损失函数

损失函数-----评估模型输出与真实标签之间的差异的函数

反向传播

更新网络并预测判断

 对一批次数据进行验证

注意

验证没有反向传播过程,因为验证数据在训练过程中主要用于评估模型的性能,而不是用于参数更新。在验证阶段,参数更新可能会导致模型在验证集上过拟合,并且会增加计算开销。因此,验证阶段只需要进行前向传播和损失计算,以获取模型在验证集上的性能指标,而不需要进行反向传播和参数更新。

一批次结束,计算并保存损失值和准确率

寻找最高准确度的权重

选择最优参数并返回

matplot_acc_lost

代码

def matplot_acc_lost(train_process):  plt.figure(figsize=(12, 4))  plt.subplot(1, 2, 1)  # 一行两列第一幅图  plt.plot(train_process["epoch"], train_process.train_loss_all, 'ro-', label="train loss")  plt.plot(train_process["epoch"], train_process.val_loss_all, 'bs', label="val loss")  plt.legend()  plt.xlabel("epoch")  plt.ylabel("loss")  plt.subplot(1, 2, 2)  # 一行两列第二幅图  plt.plot(train_process["epoch"], train_process.train_loss_all, 'ro-', label="train loss")  plt.plot(train_process["epoch"], train_process.val_loss_all, 'bs-', label="val loss")  plt.xlabel("epoch")  plt.ylabel("acc")  plt.legend()  plt.show()

 结果

modemodel_test.py

test_data_process

def test_data_process():  test_data = FashionMNIST(root='./data',  train=False,  transform=transforms.Compose([transforms.Resize(size=28), transforms.ToTensor()]),  # 转换成张量形式方便应用  download=True)  test_dataloader = Data.DataLoader(dataset=test_data,  batch_size=1,  shuffle=True,  num_workers=0)  return test_dataloader

test_model_process

def test_model_process(model, test_dataloader):  device = "cuda" if torch.cuda.is_available() else 'cpu'  model = model.to(device)  test_corrects=0.0  test_num=0  #只进行前向传播计算,不计算梯度,从而节省内存,加快运行速度  with torch.no_grad():  for test_data_x,test_data_y in test_dataloader:  test_data_x=test_data_x.to(device)  test_data_y=test_data_y.to(device)  model.eval()  #前向传播过程,输入为测试数据集,输出为对每个样本的预测值  output=model(test_data_x)  #查找每一行中最大值对应的行标  pre_lab=torch.argmax(output,dim=1)  test_corrects += torch.sum(pre_lab==test_data_y.data)  test_num += test_data_x.size(0)  #计算测试准确率  test_acc=test_corrects.double().item() / test_num  print("测试的准确率为:",test_acc)

 torch.no_grad

torch.no_grad()是一个上下文管理器,用于在代码块中禁用梯度计算和参数更新。当进入torch.no_grad()的上下文中时,PyTorch会自动将requires_grad属性设置为False,从而禁止梯度的计算和参数的更新。

torch.no_grad()常用于评估模型或进行推断过程,不需要计算梯度的情况下,可以提高代码的执行效率并减少内存消耗。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/731635.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

河北省光伏展

光伏展是指光伏行业的展览会,也被称为太阳能展。光伏展一般是由光伏企业、科研机构、行业协会和专业展览公司等共同举办的。展会内容包括光伏产品、技术、设备、材料、应用等方面的展示和交流。 光伏展通常是光伏行业的重要盛事,吸引了全球范围内的光伏企…

npm镜像源地址

镜像源地址替换问题(重要) 2024 年 1 月 22 日 ,registry.npm.taobao.org 的 SSL 证书正式过期。 2022 年 5 月 淘宝源发布了公告: (大家应该没有太多关注哦,也包括我,哈哈) &am…

144.乐理基础-根三五音、大三和弦、小三和弦

内容参考于: 三分钟音乐社 上一个内容:143.乐理基础-和弦是什么?和声是什么?三和弦-CSDN博客 必须先看上一个内容,了解什么是和弦、什么是和声,以及三和弦的定义 上一个内容最后写了三和弦的定义&#x…

【C++ 学习】构造函数详解!!!

1. 类的6个默认成员函数的引入 ① 如果一个类中什么成员都没有,简称为空类。 ② 空类中真的什么都没有吗?并不是,任何类在什么都不写时,编译器会自动生成以下6个默认成员函数。 ③ 默认成员函数:用户没有显式实现&…

嵌入式学习第二十五天!(网络的概念、UDP编程)

网络: 可以用来:数据传输、数据共享 1. 网络协议模型: 1. OSI协议模型: 应用层实际收发的数据表示层发送的数据是否加密会话层是否建立会话连接传输层数据传输的方式(数据包,流式)网络层数据的…

基于YOLOv8深度学习的智能道路裂缝检测与分析系统【python源码+Pyqt5界面+数据集+训练代码】深度学习实战、目标检测、目标分割

《博主简介》 小伙伴们好,我是阿旭。专注于人工智能、AIGC、python、计算机视觉相关分享研究。 ✌更多学习资源,可关注公-仲-hao:【阿旭算法与机器学习】,共同学习交流~ 👍感谢小伙伴们点赞、关注! 《------往期经典推…

外包干了5天,技术退步明显。。。。。

在湖南的一个安静角落,我,一个普通的大专生,开始了我的软件测试之旅。四年的外包生涯,让我在舒适区里逐渐失去了锐气,技术停滞不前,仿佛被时间遗忘。然而,生活的转机总是在不经意间降临。 与女…

家庭关怀视角下对待患病亲人态度的重要性——评析母对病父大声呵斥的行为现象

在家庭生活中,尤其是面对家人身体不适或疾病困扰的时候,我们的态度和行为方式显得尤为重要。近期,社会上存在一种令人忧虑的现象,即某些家庭中,母亲因压力或其他原因对生病的父亲表现出不耐烦甚至大吼大叫的态度。这种…

警用移动执法远程视频监控方案:安防视频监控系统EasyCVR+4G/5G移动执法仪

一、背景需求 在现代城市管理中,移动执法仪视频监控方案正逐渐成为一种高效、便捷的管理工具。该方案通过结合移动执法仪和视频监控技术,实现了对城市管理现场的实时监控和取证,有效提升了城市管理水平和效率。 移动执法仪作为现场执法的重…

TypeScript 哲学 - Object Types

readonly 修饰对象和数组的 双向可分配性是不同的 只有有一个可选属性不是意味着必须 不能传空对象,:这个例子(两个属性可选)而是如果对象有额外属性,那么必须至少加一个 可选属性。只要你在传递的值和目标类型有一个…

大模型概念解析 | Prompt Engineering

注1:本文系"概念解析"系列之一,致力于简洁清晰地解释、辨析复杂而专业的概念。本次辨析的概念是:大模型中的Prompt Engineering 大模型概念解析 | Prompt Engineering 第一部分 通俗解释 在人工智能的世界里,有一群被称为大模型的巨无霸。它们就像是知识的海绵…

关于STM32G070RBTx单片机使用HAL库往flash写数据的过程中死机问题

1.单片机型号:STM32G070RBTx 2.出现的问题 根据库函数FLASH_If_Write()的使用,我们分析往flash写数据的过程是把uint8_t 类型的数据(p_data)以地址的形式强转成uint64类型的,在一包128字节的数据时一次存储8位,存16次(packet_size/8)&#x…

Java项目:基于SSM框架实现的二手车交易平台【源码+开题报告+任务书+毕业论文+答辩ppt】

一、项目简介 本项目是一套基于SSM框架实现的二手车交易平台 包含:项目源码、数据库脚本等,该项目附带全部源码可作为毕设使用。 项目都经过严格调试,eclipse或者idea 确保可以运行! 该系统功能完善、界面美观、操作简单、功能齐…

MySQL底层原理

1. 请解释MySQL的逻辑架构和物理架构。 MySQL的逻辑架构和物理架构涉及到多个层面,包括网络连接、服务处理、存储引擎以及数据存储等部分。具体如下: 逻辑架构: 连接层(Connection Layer):客户端通过TCP…

瑞芯微 | I2S-音频基础 -1

最近调试音频驱动,顺便整理学习了一下i2s、alsa相关知识,整理成了几篇文章,后续会陆续更新。 喜欢嵌入式、Li怒晓得老铁可以关注一口君账号。 1. 音频常用术语 名称含义ADC(Analog to Digit Conversion)模拟信号转换…

Android中Fragment的onResume方法的介绍、执行时机,以及不执行回调的异常情况分析

onResume()是Fragment生命周期中的一个重要方法,表示Fragment已经获取焦点并开始与用户交互。在onResume()方法中,Fragment通常完成与用户界面交互的准备工作,比如开始执行一些动画、加载数据或注册监听器等。 1. 回调时机: onRe…

stm32普通定时器脉冲计数(发送固定脉冲个数),控制步进电机驱动器

拨码开关设置驱动器,细分 方法思路:用通用定时器TIM2,1ms产生一次中断;在中断里做IO反转; 发送10个脉冲信号

系统架构设计师考试大纲

一、系统架构设计综合知识 1. 计算机系统基本知识 1.1 计算机系统概述 1.2 计算机硬件 1.2.1 计算机硬件组成 1.2.2 处理器 1.2.3 存储器 1.2.4 总线 1.2.5 接口 1.2.6 外部设备 1.3 计算机软件 1.3.1 计算机软件概述 1.3.2 操作系统 1.3.3 数据库系统 1.3.4 文件系统 1.3.5 网…

搬家微信小程序:便捷预约,轻松解决搬家难题

在快节奏的现代生活中,搬家成为许多人不得不面对的一项繁琐任务。从整理物品、联系搬家公司,到现场协调,每一个环节都让人倍感压力。然而,如今随着科技的不断发展,搬家微信小程序的出现,为这一难题带来了便…

示波器探头的使用

无源探头(Tektronix P2220) 阻抗:1Mhz 衰减:10:1/1:1(与探头上的档位X10/X1相关,如果探头没有档位默认为10:1) 探头型号:电压 高压差分探头(Tektronix P5200A) 阻抗:1Mhz 衰减:50:1/500:1(…