【PyTorch实战演练】使用Cifar10数据集训练LeNet5网络并实现图像分类(附代码)

文章目录

      • 0. 前言
      • 1. Cifar10数据集
        • 1.1 Cifar10数据集下载
        • 1.2 Cifar10数据集解析
      • 2. LeNet5网络
        • 2.1 LeNet5的网络结构
        • 2.2 基于PyTorch的LeNet5网络编码
      • 3. LeNet5网络训练及输出验证
        • 3.1 LeNet5网络训练
        • 3.2 LeNet5网络验证
      • 4. 完整代码
        • 4.1 训练代码
        • 4.1 验证代码

0. 前言

按照国际惯例,首先声明:本文只是我自己学习的理解,虽然参考了他人的宝贵见解,但是内容可能存在不准确的地方。如果发现文中错误,希望批评指正,共同进步。

本文是基于PyTorch框架使用LeNet5网络实现图像分类的实战演练,训练的数据集采用Cifar10,旨在通过实操强化对深度学习尤其是卷积神经元网络的理解。

本文是一个完整的保姆级学习指引,只要具备最基础的深度学习知识就可以通过本文的指引:使用PyTorch库从零搭建LeNet5网络,然后对其进行训练,最后能够识别实拍图像中的实物。

1. Cifar10数据集

Cifar10数据集由计算机科学家Geoffrey Hinton的学生Alex Krizhevsky、Ilya Sutskever 在1990年代创建。Cifar10是一个包含10个类别的图像分类数据集,每个类别包含6000张32x32像素的彩色图像,总计60000张图像,其中50000个图像用于训练网络模型(训练组),10000个图像用于验证网络模型(验证组)。

其名字Cifar10代表Canadian Institute for Advanced Research(加拿大高级研究所)做的10种分类的图像集,后面的Cifar100则是100种分类的图像集。

1.1 Cifar10数据集下载

使用torchvision直接下载Cifar10:

from torchvision import datasets
from torchvision import transformsdata_path = 'CIFAR10/IMG_file'
cifar10 = datasets.CIFAR10(root=data_path, train=True, download=True,transform=transforms.ToTensor())   #首次下载时download设为true

datasets.CIFAR10中的参数:

  • root:下载文件的路径
  • train:如果为True,则是下载训练组数据,总计50000张图像;如果为False,则是下载验证组数据,总计10000张图像
  • download:新下载时需要设定为True,如果已经下载好数据可以设定为False
  • transform:对图像数据进行变形,这里指定为transforms.ToTensor()图像数据会被转换为Tensor,数据范围调整到0~1,省得我们再写一行归一化代码了
1.2 Cifar10数据集解析

下载之后可以看一下Cifar10数据集的具体内容:

print(type(cifar10))
print(cifar10[0])
------------------------输出------------------------------------
<class 'torchvision.datasets.cifar.CIFAR10'>
(tensor([[[0.2314, 0.1686, 0.1961,  ..., 0.6196, 0.5961, 0.5804],[0.0627, 0.0000, 0.0706,  ..., 0.4824, 0.4667, 0.4784],[0.0980, 0.0627, 0.1922,  ..., 0.4627, 0.4706, 0.4275],...,[0.8157, 0.7882, 0.7765,  ..., 0.6275, 0.2196, 0.2078],[0.7059, 0.6784, 0.7294,  ..., 0.7216, 0.3804, 0.3255],[0.6941, 0.6588, 0.7020,  ..., 0.8471, 0.5922, 0.4824]],[[0.2431, 0.1804, 0.1882,  ..., 0.5176, 0.4902, 0.4863],[0.0784, 0.0000, 0.0314,  ..., 0.3451, 0.3255, 0.3412],[0.0941, 0.0275, 0.1059,  ..., 0.3294, 0.3294, 0.2863],...,[0.6667, 0.6000, 0.6314,  ..., 0.5216, 0.1216, 0.1333],[0.5451, 0.4824, 0.5647,  ..., 0.5804, 0.2431, 0.2078],[0.5647, 0.5059, 0.5569,  ..., 0.7216, 0.4627, 0.3608]],[[0.2471, 0.1765, 0.1686,  ..., 0.4235, 0.4000, 0.4039],[0.0784, 0.0000, 0.0000,  ..., 0.2157, 0.1961, 0.2235],[0.0824, 0.0000, 0.0314,  ..., 0.1961, 0.1961, 0.1647],...,[0.3765, 0.1333, 0.1020,  ..., 0.2745, 0.0275, 0.0784],[0.3765, 0.1647, 0.1176,  ..., 0.3686, 0.1333, 0.1333],[0.4549, 0.3686, 0.3412,  ..., 0.5490, 0.3294, 0.2824]]]), 6)Process finished with exit code 0

可以见到Cifar10有其单独的数据类型torchvision.datasets.cifar.CIFAR10,其结构类似list。

如果输出其中某一元素,例如第一个cifar10[0],其中包含:

  • 一个维度为[3,32,32]的tensor(因为上面Transform已经指定了ToTensor),这个就是RGB三通道的图像数据
  • 一个标量数据label,这里是6,这个数据代表图像的真实分类,其对应关系如下表:
    在这里插入图片描述

这里我们也可以用matplotlib把图像的tensor数据转回图像,看看这个label为6的图像究竟是什么样的:

from torchvision import datasets
import matplotlib.pyplot as plt
from torchvision import transformsdata_path = 'CIFAR10/IMG_file'
cifar10 = datasets.CIFAR10(root=data_path, train=True, download=False,transform=transforms.ToTensor())   #首次下载时download设为true# print(type(cifar10))
# print(cifar10[0])img,label = cifar10[0]
plt.imshow(img.permute(1,2,0))
plt.show()

输出为:
在这里插入图片描述
没错,这是一个label为6的Frog,32×32像素的图像就只能做到这个程度了。

这里使用了.permute()是因为原始数据的维度是[channel3, H32, W32],而.imshow()要求的输入维度应该是[H, W, channel],需要调整下原始数据的维度顺序。

2. LeNet5网络

LeNet5是由Yann LeCun在20世纪90年代初提出,是一个经典的卷积神经网络。LeNet5由7层神经网络组成,包括2个卷积层、2个池化层和3个全连接层。其(在当时的时代背景下)创造性地使用了卷积层和池化层对输入进行特征提取,减少了参数数量,同时增强了网络对输入图像的平移和旋转不变性。

LeNet5被广泛应用于手写数字识别,也可用于其他图像分类任务。虽然现在的深度卷积神经网络比LeNet5有更好的性能,但LeNet5对于学习卷积神经网络的基本原理和方法具有重要的教育意义

2.1 LeNet5的网络结构

LeNet5的网络结构如下图:
请添加图片描述

LeNet5的输入为32x32的图像:

  • 第一层为一个卷积层,包含6个5x5的卷积核,输出的特征图为28x28
  • 第二层为一个2x2的最大池化层,将特征图大小缩小一半14×14
  • 第三层为另一个卷积层,包含16个5x5的卷积核,输出的特征图为10x10
  • 第四层同第二层,将特征图大小缩小一半5×5
  • 第五层为一个全连接层,含有120个神经元
  • 第六层为另一个全连接层,含有84个神经元
  • 最后一层为输出层,包含10个神经元,每个神经元对应一个label
2.2 基于PyTorch的LeNet5网络编码

根据上文LeNet5的网络结构,编写代码如下:

import torch.nn as nnclass LeNet(nn.Module):def __init__(self):super().__init__()self.net = nn.Sequential(nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5),  # 由于图片为RGB彩图,channel_in = 3#输出张量为 Batch(1)*Channel(6)*H(28)*W(28)nn.Sigmoid(),nn.MaxPool2d(kernel_size=2, stride=2),# 输出张量为 Batch(1)*Channel(6)*H(14)*W(14)nn.Conv2d(in_channels=6,out_channels= 16,kernel_size= 5),# 输出张量为 Batch(1)*Channel(16)*H(10)*W(10)nn.Sigmoid(),nn.MaxPool2d(kernel_size=2, stride=2),# 输出张量为 Batch(1)*Channel(16)*H(5)*W(5)nn.Conv2d(in_channels=16, out_channels=120,kernel_size=5),# 输出张量为 Batch(1)*Channel(120)*H(1)*W(1)nn.Flatten(),# 将输出一维化,用于后面的全连接网络输入nn.Linear(120, 84),nn.Sigmoid(),nn.Linear(84, 10))def forward(self, x):return self.net(x)

3. LeNet5网络训练及输出验证

3.1 LeNet5网络训练

碍于我的电脑没有GPU,使用CPU版PyTorch数据训练非常慢,我只取了Cifar10的前2000个数据进行训练 (T_T)

small_cifar10 = []
for i in range(2000):small_cifar10.append(cifar10[i])

训练相关设置如下:

  • 损失函数:交叉熵损失函数nn.CrossEntropyLoss()
  • 优化方式:随机梯度下降torch.optim.SGD()
  • epoch与learning rate:这是比较头疼的地方,目前我没有探索出太好的方式能在初期就把epoch和lr设定的比较好,只能进行逐步尝试。为了不浪费每次训练,我们可以把每次训练的权重保存下来,下次训练基于上次的结果进行。保存和加载权重的方式可以参考往期博客:通过实例学习Pytorch加载权重.load_state_dict()与保存权重.save()。下图展示了我的探索过程:lr的取值大约从1e-5逐步降低到2e-7,epoch总计大概有3000左右,loss值由初始的10000左右下降到100内。

这一块的训练过程忘记完整记录每一步的详细参数(epoch和lr)了,如果你有需要可以留下邮箱,我把训练好的权重发给你。读者也可以探索更好的训练参数。

在这里插入图片描述

3.2 LeNet5网络验证

激动人心的时刻来了!现在来验证我们训练好的网络能否准确识别目标图像!

我选用的图像是小鹏汽车在2023年上市的G6车型进行验证,图像如下:
在这里插入图片描述
加载我们训练好的权重文件,把图像输入到模型中:

def img_totensor(img_file):img = Image.open(img_file)transform = transforms.Compose([transforms.ToTensor(), transforms.Resize((32, 32))])img_tensor = transform(img).unsqueeze(0)  #这里要升维,对应增加batch维度return img_tensortest_model = LeNet()
test_model.load_state_dict(torch.load('CIFAR10/small2000_8.pth'))img1 = img_totensor('1.jpg')
img2 = img_totensor('2.jpg')
img3 = img_totensor('3.jpg')
img4 = img_totensor('4.jpg')print(test_model(img1))
print(test_model(img2))
print(test_model(img3))
print(test_model(img4))

最终输出如下:

tensor([[ 8.4051, 12.0952, -7.9274,  0.3868, -3.0866, -4.7883, -1.6089, -3.6484,-1.1387,  4.7348]], grad_fn=<AddmmBackward0>)
tensor([[-1.1992, 17.4531, -2.7929, -6.0410, -1.7589, -2.6942, -3.6753, -2.6800,3.6378,  2.4267]], grad_fn=<AddmmBackward0>)
tensor([[ 1.7580, 10.6321, -5.3922, -0.4557, -2.0147, -0.5974, -0.5785, -4.7977,-1.2916,  5.4786]], grad_fn=<AddmmBackward0>)
tensor([[10.5689,  6.2413, -0.9554, -4.4162,  1.0807, -7.9541, -5.3185, -6.0609,5.1129,  4.2243]], grad_fn=<AddmmBackward0>)

我们来解读一下这个输出:

  • 第1、2、3个图像对应输出tensor最大值在第[1]个元素(从0开始计数),即对应label值为1,真实分类为Car,预测正确。
  • 第4个图像的输出预测错误,最大值在第[0]个元素,LeNet5认为这个图像是Airplane。

这个准确率虽然不算高,但是别忘了我仅仅使用了Cifar10的前2000个数据进行训练;而且LeNet5网络输入为32×32大小的图像,例如上面的青蛙,即使让人来分辨也是挺困难的任务。

4. 完整代码

4.1 训练代码
#文件命名为 CIFAR10_main.py 后面验证时需要调用
from torchvision import datasets
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torchvision import transforms
from tqdm import tqdmdata_path = 'CIFAR10/IMG_file'
cifar10 = datasets.CIFAR10(data_path, train=True, download=False,transform=transforms.ToTensor())   #首次下载时download设为trueclass LeNet(nn.Module):def __init__(self):super().__init__()self.net = nn.Sequential(nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5),  # 由于图片为RGB彩图,channel_in = 3#输出张量为 Batch(1)*Channel(6)*H(28)*W(28)nn.Sigmoid(),nn.MaxPool2d(kernel_size=2, stride=2),# 输出张量为 Batch(1)*Channel(6)*H(14)*W(14)nn.Conv2d(in_channels=6,out_channels= 16,kernel_size= 5),# 输出张量为 Batch(1)*Channel(16)*H(10)*W(10)nn.Sigmoid(),nn.MaxPool2d(kernel_size=2, stride=2),# 输出张量为 Batch(1)*Channel(16)*H(5)*W(5)nn.Conv2d(in_channels=16, out_channels=120,kernel_size=5),# 输出张量为 Batch(1)*Channel(120)*H(1)*W(1)nn.Flatten(),# 将输出一维化,用于后面的全连接网络输入nn.Linear(120, 84),nn.Sigmoid(),nn.Linear(84, 10))def forward(self, x):return self.net(x)if __name__ == '__main__':model = LeNet()model.load_state_dict(torch.load('CIFAR10/small2000_7.pth'))loss = nn.CrossEntropyLoss()opt = torch.optim.SGD(model.parameters(),lr=2e-7)small_cifar10 = []for i in range(2000):small_cifar10.append(cifar10[i])for epoch in range(1000):opt.zero_grad()total_loss = torch.tensor([0])for img,label in tqdm(small_cifar10):output = model(img.unsqueeze(0))label = torch.tensor([label])LeNet_loss = loss(output, label)total_loss = total_loss + LeNet_lossLeNet_loss.backward()opt.step()total_loss_numpy = total_loss.detach().numpy()plt.scatter(epoch,total_loss_numpy,c='b')print(total_loss)print("epoch=",epoch)torch.save(model.state_dict(),'CIFAR10/small2000_8.pth')plt.show()
4.1 验证代码
import torch
from torchvision import transforms
from PIL import Image
from CIFAR10_main import LeNetdef img_totensor(img_file):img = Image.open(img_file)transform = transforms.Compose([transforms.ToTensor(), transforms.Resize((32, 32))])img_tensor = transform(img).unsqueeze(0)  #这里要升维,对应增加batch维度return img_tensortest_model = LeNet()
test_model.load_state_dict(torch.load('CIFAR10/small2000_8.pth'))img1 = img_totensor('1.jpg')
img2 = img_totensor('2.jpg')
img3 = img_totensor('3.jpg')
img4 = img_totensor('4.jpg')print(test_model(img1))
print(test_model(img2))
print(test_model(img3))
print(test_model(img4))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/91491.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

紫光同创FPGA图像视频采集系统,基于OV7725实现,提供工程源码和技术支持

目录 1、前言免责声明 2、设计思路框架视频源选择OV7725摄像头配置及采集动态彩条HDMA图像缓存输入输出视频HDMA缓冲FIFOHDMA控制模块HDMI输出 3、PDS工程详解4、上板调试验证并演示准备工作静态演示动态演示 5、福利&#xff1a;工程源码获取 紫光同创FPGA图像视频采集系统&am…

mysql面试题7:MySQL事务原理是什么?MySQL事务的隔离级别有哪些?

该文章专注于面试,面试只要回答关键点即可,不需要对框架有非常深入的回答,如果你想应付面试,是足够了,抓住关键点 面试官:MySQL事务原理是什么? MySQL事务的原理是基于ACID(原子性、一致性、隔离性、持久性)特性来实现的,具体原理如下: Atomicity(原子性):事务…

Java使用Scanner类实现用户输入与交互

概述&#xff1a; Scanner类是Java中的一个重要工具类&#xff0c;用于读取用户的输入。它提供了一系列的方法&#xff0c;可以方便地读取不同类型的数据&#xff0c;如整数、浮点数、字符串等。在本文中&#xff0c;我们将详细介绍Scanner类的使用方法&#xff0c;并通过两个…

如何做一个基于 Python 的搜索引擎?

怎么做一个基于 python 的搜索引擎&#xff1f; 1、确定搜索引擎范围和目标用户 在决定做一个基于Python的搜索引擎之前&#xff0c;首先需要确定搜索引擎的范围和目标用户。搜索引擎的范围可以包括新闻、商品、音乐等&#xff0c;不同的领域需要不同的数据来源和处理方式。同…

给奶牛做直播之三

​一、前言 上一篇给牛奶做直播之二 主要讲用RTMP搭建点播服务器&#xff0c;整了半天直播还没上场&#xff0c;今天不讲太多理论的玩意&#xff0c;奶牛今天放假了也不出场&#xff0c;就由本人亲自上场来个直播首秀&#xff0c;见下图&#xff0c;如果有兴趣的话&#xff0…

YOLOV8-DET转ONNX和RKNN

目录 1. 前言 2.环境配置 (1) RK3588开发板Python环境 (2) PC转onnx和rknn的环境 3.PT模型转onnx 4. ONNX模型转RKNN 6.测试结果 1. 前言 yolov8就不介绍了&#xff0c;详细的请见YOLOV8详细对比&#xff0c;本文章注重实际的使用&#xff0c;从拿到yolov8的pt检测模型&…

施耐德电气:勾勒未来工业愿景,赋能中国市场

9月19日&#xff0c;第23届中国国际工业博览会&#xff08;简称“工博会”&#xff09;在上海隆重召开。作为全球能源管理和自动化领域的数字化转型专家&#xff0c;施耐德电气在工博会现场全方位展现了自身对未来工业的全新视野与深刻见解&#xff0c;不仅展示了其贯通企业设计…

字节一面:深拷贝浅拷贝的区别?如何实现一个深拷贝?

前言 最近博主在字节面试中遇到这样一个面试题&#xff0c;这个问题也是前端面试的高频问题&#xff0c;我们经常需要对后端返回的数据进行处理才能渲染到页面上&#xff0c;一般我们会讲数据进行拷贝&#xff0c;在副本对象里进行处理&#xff0c;以免玷污原始数据&#xff0c…

arduino - UNO-R3,mega2560-R3,NUCLEO-H723ZG的arduino引脚定义区别

文章目录 arduino - UNO-R3,mega2560-R3,NUCLEO-H723ZG的引脚定义区别概述笔记NUCLEO-H723ZGmega2560-R3UNO-R3经过比对, 这2个板子(NUCLEO-H723ZG, mega2560-R3)都是和UNO-R3的arduino引脚定义一样的.mega2560-r3和NUCLEO-H723ZG的区别补充arduino uno r3的纯数字IO和模拟IO作…

了解和使用MinIO

MinIO 文章目录 MinIOMinIO简介MinIO概述 开箱使用基本概念 快速入门封装MinIO为starter在项目中集成 MinIO简介 MinIO 是一个开源的对象存储服务器&#xff0c;可以帮助用户构建高度可扩展的存储基础架构。它采用分布式架构&#xff0c;可以在多个节点上部署&#xff0c;实现…

uniapp iOS离线打包——原生工程配置

uniapp iOS离线打包&#xff0c;如何配置项目工程&#xff1f; 文章目录 uniapp iOS离线打包&#xff0c;如何配置项目工程&#xff1f;工程配置效果图DebugRelease 配置工程配置 Appkey应用图标模块及三方SDK配置未配置模块错误配置模块TIP: App iOS 离线打包 前提&#xff1a…

Linux服务器安装Anaconda 配置远程jupyter lab使用虚拟环境

参考的博客&#xff1a; Linux服务器安装Anaconda 并配置远程jupyter lab anaconda配置远程访问jupyter&#xff0c;并创建虚拟环境 理解和创建&#xff1a;Anaconda、Jupyterlab、虚拟环境、Kernel 下边是正文了。 https://www.anaconda.com/download是官网网址&#xff0c;可…

RabbitMQ配置

centos7安装rabbitmq 官网教程&#xff1a;https://www.rabbitmq.com/install-rpm.html#downloads 官网介绍了两种安装方法&#xff1a; 安装使用yum库中的包&#xff08;强烈建议此选项&#xff09;上Cloudsmith.io或PackageCloud 下载软件包并使用rpm安装它。此选项将需要手…

华为云云耀云服务器L实例评测|云耀云服务器L实例部署Linux管理面板mdserver-web

华为云云耀云服务器L实例评测&#xff5c;云耀云服务器L实例部署Linux管理面板mdserver-webl 一、云耀云服务器L实例介绍1.1 云耀云服务器L实例简介1.2 云耀云服务器L实例特点 二、mdserver-web介绍2.1 mdserver-web简介2.2 mdserver-web特点2.3 主要插件介绍 三、本次实践介绍…

机器学习之单层神经网络的训练:增量规则(Delta Rule)

文章目录 权重的调整单层神经网络使用delta规则的训练过程 神经网络以权值的形式存储信息,根据给定的信息来修改权值的系统方法称为学习规则。由于训练是神经网络系统地存储信息的唯一途径&#xff0c;因此学习规则是神经网络研究中的一个重要组成部分 权重的调整 &#xff08…

iPhone数据丢失怎么办?9 佳免费 iPhone 数据恢复软件可收藏

您是否知道有多种原因可能导致 iPhone 上存储的数据永久丢失&#xff1f;然而&#xff0c;使用一些最好的免费 iPhone 数据恢复软件&#xff0c;您仍然可以恢复它。 由于我们几乎总是保存手机上的所有内容&#xff08;从联系人到媒体文件&#xff09;&#xff0c;因此 iPhone …

【Segment Anything Model】SAM做多类别分割,医疗语义分割

🍉 博主微信 cvxiayixiao 🍓 【Segment Anything Model】计算机视觉检测分割任务专栏。 链接 🍑 【公开数据集预处理】特别是医疗公开数据集的接受和预处理,提供代码讲解。链接 🍈 【opencv+图像处理】opencv代码库讲解,结合图像处理知识,不仅仅是调库。链接 文章目…

CDH 6.3.2升级Flink到1.17.1版本

CDH&#xff1a;6.3.2 原来的Flink&#xff1a;1.12 要升级的Flink&#xff1a;1.17.1 操作系统&#xff1a;CentOS Linux 7 一、Flink1.17编译 build.sh文件&#xff1a; #!/bin/bash set -x set -e set -vFLINK_URLsed /^FLINK_URL/!d;s/.*// flink-parcel.properties FLI…

龙迅LT9611UXC 2PORT MIPICSI/DSI转HDMI(2.0)转换器+音频,内置MCU

龙迅LT9611UXC 1.描述&#xff1a; LT9611UXC是一个高性能的MIPI DSI/CSI到HDMI2.0转换器。MIPI DSI/CSI输入具有可配置的单 端口或双端口&#xff0c;1高速时钟通道和1~4高速数据通道&#xff0c;最大2Gbps/通道&#xff0c;可支持高达16Gbps的总带 宽。LT9611UXC支持突发…

CISSP学习笔记:业务连续性计划

第三章 业务连续性计划 3.1 业务连续性计划 业务连续性计划(BCP): 对组织各种过程的风险评估&#xff0c;发生风险的情况下为了使风险对组织的影响降至最小而定制的各种计划BCP和DRP首先考虑的人不受伤害&#xff0c;然后再解决IT恢复和还原问题BCP的主要步骤&#xff1a; 项…