【保姆级教程】使用 PyTorch 自定义卷积神经网络(CNN) 实现图像分类、训练验证、预测全流程【附数据集与源码】

《博主简介》

小伙伴们好,我是阿旭。专注于人工智能、AIGC、python、计算机视觉相关分享研究。
👍感谢小伙伴们点赞、关注!

《------往期经典推荐------》

一、AI应用软件开发实战专栏【链接】

项目名称项目名称
1.【人脸识别与管理系统开发】2.【车牌识别与自动收费管理系统开发】
3.【手势识别系统开发】4.【人脸面部活体检测系统开发】
5.【图片风格快速迁移软件开发】6.【人脸表表情识别系统】
7.【YOLOv8多目标识别与自动标注软件开发】8.【基于YOLOv8深度学习的行人跌倒检测系统】
9.【基于YOLOv8深度学习的PCB板缺陷检测系统】10.【基于YOLOv8深度学习的生活垃圾分类目标检测系统】
11.【基于YOLOv8深度学习的安全帽目标检测系统】12.【基于YOLOv8深度学习的120种犬类检测与识别系统】
13.【基于YOLOv8深度学习的路面坑洞检测系统】14.【基于YOLOv8深度学习的火焰烟雾检测系统】
15.【基于YOLOv8深度学习的钢材表面缺陷检测系统】16.【基于YOLOv8深度学习的舰船目标分类检测系统】
17.【基于YOLOv8深度学习的西红柿成熟度检测系统】18.【基于YOLOv8深度学习的血细胞检测与计数系统】
19.【基于YOLOv8深度学习的吸烟/抽烟行为检测系统】20.【基于YOLOv8深度学习的水稻害虫检测与识别系统】
21.【基于YOLOv8深度学习的高精度车辆行人检测与计数系统】22.【基于YOLOv8深度学习的路面标志线检测与识别系统】
23.【基于YOLOv8深度学习的智能小麦害虫检测识别系统】24.【基于YOLOv8深度学习的智能玉米害虫检测识别系统】
25.【基于YOLOv8深度学习的200种鸟类智能检测与识别系统】26.【基于YOLOv8深度学习的45种交通标志智能检测与识别系统】
27.【基于YOLOv8深度学习的人脸面部表情识别系统】28.【基于YOLOv8深度学习的苹果叶片病害智能诊断系统】
29.【基于YOLOv8深度学习的智能肺炎诊断系统】30.【基于YOLOv8深度学习的葡萄簇目标检测系统】
31.【基于YOLOv8深度学习的100种中草药智能识别系统】32.【基于YOLOv8深度学习的102种花卉智能识别系统】
33.【基于YOLOv8深度学习的100种蝴蝶智能识别系统】34.【基于YOLOv8深度学习的水稻叶片病害智能诊断系统】
35.【基于YOLOv8与ByteTrack的车辆行人多目标检测与追踪系统】36.【基于YOLOv8深度学习的智能草莓病害检测与分割系统】
37.【基于YOLOv8深度学习的复杂场景下船舶目标检测系统】38.【基于YOLOv8深度学习的农作物幼苗与杂草检测系统】
39.【基于YOLOv8深度学习的智能道路裂缝检测与分析系统】40.【基于YOLOv8深度学习的葡萄病害智能诊断与防治系统】
41.【基于YOLOv8深度学习的遥感地理空间物体检测系统】42.【基于YOLOv8深度学习的无人机视角地面物体检测系统】
43.【基于YOLOv8深度学习的木薯病害智能诊断与防治系统】44.【基于YOLOv8深度学习的野外火焰烟雾检测系统】
45.【基于YOLOv8深度学习的脑肿瘤智能检测系统】46.【基于YOLOv8深度学习的玉米叶片病害智能诊断与防治系统】
47.【基于YOLOv8深度学习的橙子病害智能诊断与防治系统】48.【车辆检测追踪与流量计数系统】
49.【行人检测追踪与双向流量计数系统】50.【基于YOLOv8深度学习的反光衣检测与预警系统】
51.【危险区域人员闯入检测与报警系统】52.【高密度人脸智能检测与统计系统】
53.【CT扫描图像肾结石智能检测系统】54.【水果智能检测系统】
55.【水果质量好坏智能检测系统】56.【蔬菜目标检测与识别系统】
57.【非机动车驾驶员头盔检测系统】58.【太阳能电池板检测与分析系统】
59.【工业螺栓螺母检测】60.【金属焊缝缺陷检测系统】
61.【链条缺陷检测与识别系统】62.【交通信号灯检测识别】

二、机器学习实战专栏【链接】,已更新31期,欢迎关注,持续更新中~~
三、深度学习【Pytorch】专栏【链接】
四、【Stable Diffusion绘画系列】专栏【链接】
五、YOLOv8改进专栏【链接】持续更新中~~
六、YOLO性能对比专栏【链接】,持续更新中~

《------正文------》

目录

    • 引言
    • 环境配置
    • 详细步骤
    • 结论:

引言

卷积神经网络(也称为 CNN 或 ConvNet)由 Yann LeCun 于 20 世纪 80 年代提出,至今已取得长足进步。基于 CNN 的架构不仅用于简单的数字分类任务,还被广泛用于许多深度学习和计算机视觉相关任务,例如对象检测、图像分割、注视跟踪等。本文将使用 PyTorch 框架在流行的 CIFAR-10 数据集上实现基于 CNN 的图像分类器。

环境配置

为了实现 CNN 并下载 CIFAR-10 数据集,我们需要torchtorchvision模块。除此之外,我们还将使用 numpy 和 matplotlib 进行数据分析和绘图。可以使用以下命令使用 pip 包管理器安装所需的库:

pip install torch torchvision torchaudio numpy matplotlib

详细步骤

步骤1:下载数据并从训练集中打印一些示例图像。

  • 在开始进行 模型训练 之前,我们首先需要将数据集下载到本地机器上,我们将在该机器上训练我们的模型。为此,我们将使用 torchvision库,将 CIFAR-10 数据集分别下载到目录“./CIFAR10/train”和“./CIFAR10/test ”中的训练和测试集中。我们还应用了图像变换,该过程在所有图像的三个通道上完成。
  • 现在,我们有一个训练数据集和一个测试数据集,分别包含 50000 和 10000 张图像,尺寸为 32x32x3。之后,我们将这些数据集转换为批大小为 128 的数据加载器,以实现更好的泛化和更快的训练过程。
  • 最后,我们从第一批训练数据中绘制出一些样本图像,以了解我们使用torchvision的**make_grid实用程序处理的图像。

具体代码如下:

import torch 
import torchvision 
import matplotlib.pyplot as plt 
import numpy as np # The below two lines are optional and are just there to avoid any SSL 
# related errors while downloading the CIFAR-10 dataset 
import ssl 
ssl._create_default_https_context = ssl._create_unverified_context #Defining plotting settings 
plt.rcParams['figure.figsize'] = 14, 6#Initializing normalizing transform for the dataset 
normalize_transform = torchvision.transforms.Compose([ torchvision.transforms.ToTensor(), torchvision.transforms.Normalize(mean = (0.5, 0.5, 0.5), std = (0.5, 0.5, 0.5))]) #Downloading the CIFAR10 dataset into train and test sets 
train_dataset = torchvision.datasets.CIFAR10( root="./CIFAR10/train", train=True, transform=normalize_transform, download=True) test_dataset = torchvision.datasets.CIFAR10( root="./CIFAR10/test", train=False, transform=normalize_transform, download=True) #Generating data loaders from the corresponding datasets 
batch_size = 128
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size) 
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size) #Plotting 25 images from the 1st batch 
dataiter = iter(train_loader) 
images, labels = dataiter.next() 
plt.imshow(np.transpose(torchvision.utils.make_grid( 
images[:25], normalize=True, padding=1, nrow=5).numpy(), (1, 2, 0))) 
plt.axis('off')

输出:

img

图 1:来自训练数据集的一些示例图像

步骤 2:绘制数据集的类别分布

绘制训练集的类别分布通常是一个好主意。**这有助于检查所提供的数据集是否平衡。**为此,我们分批迭代整个训练集并收集每个实例的相应类别。最后,我们计算唯一类别的数量并绘制它们。

代码如下:

#Iterating over the training dataset and storing the target class for each sample 
classes = [] 
for batch_idx, data in enumerate(train_loader, 0): x, y = data classes.extend(y.tolist()) #Calculating the unique classes and the respective counts and plotting them 
unique, counts = np.unique(classes, return_counts=True) 
names = list(test_dataset.class_to_idx.keys()) 
plt.bar(names, counts) 
plt.xlabel("Target Classes") 
plt.ylabel("Number of training instances")

输出:

img

图 2:训练集的类别分布

如图 2 所示,这十个类别中的每一个类别都有几乎相同数量的训练样本。因此,我们不需要采取额外的步骤来重新平衡数据集。

步骤3:实现CNN架构

在架构方面,我们将使用一个简单的模型,该模型采用三个深度分别为32、64 和 64 的卷积层,然后是两个完全连接层以执行分类。

  • 每个卷积层都涉及一个包含3×3 卷积滤波器的卷积操作,然后是 ReLU 激活操作以将非线性引入系统,以及带有 2×2 滤波器的最大池化操作以降低特征图的维数。
  • 在卷积块结束后,我们将多维层扁平化为低维结构,以开始我们的分类块。在第一个线性层之后,最后一个输出层(也是线性层)针对我们数据集中的十个唯一类别中的每一个都有十个神经元。

架构如下:

img

图 3:CNN 的架构图

为了构建我们的模型,我们将创建一个从torch.nn.Module类继承的CNN 类,以利用 Pytorch 实用程序。除此之外,我们将使用torch.nn.Sequential容器将我们的层一个接一个地组合起来。

  • Conv2D ()、ReLU()和MaxPool2D()层执行卷积、激活和池化操作。我们使用 1 的填充来为内核提供足够的学习空间,因为填充为图像提供了更多的覆盖区域,尤其是外框中的像素。
  • 卷积块之后,***Linear()***全连接层执行分类。

代码:

class CNN(torch.nn.Module): def __init__(self): super().__init__() self.model = torch.nn.Sequential( #Input = 3 x 32 x 32, Output = 32 x 32 x 32 torch.nn.Conv2d(in_channels = 3, out_channels = 32, kernel_size = 3, padding = 1), torch.nn.ReLU(), #Input = 32 x 32 x 32, Output = 32 x 16 x 16 torch.nn.MaxPool2d(kernel_size=2), #Input = 32 x 16 x 16, Output = 64 x 16 x 16 torch.nn.Conv2d(in_channels = 32, out_channels = 64, kernel_size = 3, padding = 1), torch.nn.ReLU(), #Input = 64 x 16 x 16, Output = 64 x 8 x 8 torch.nn.MaxPool2d(kernel_size=2), #Input = 64 x 8 x 8, Output = 64 x 8 x 8 torch.nn.Conv2d(in_channels = 64, out_channels = 64, kernel_size = 3, padding = 1), torch.nn.ReLU(), #Input = 64 x 8 x 8, Output = 64 x 4 x 4 torch.nn.MaxPool2d(kernel_size=2), torch.nn.Flatten(), torch.nn.Linear(64*4*4, 512), torch.nn.ReLU(), torch.nn.Linear(512, 10) ) def forward(self, x): return self.model(x) 

步骤 4:定义训练参数并开始训练

我们通过选择训练模型的设备(即 CPU 或 GPU)来开始训练过程。然后,我们定义模型超参数,如下所示:

  • 我们对模型进行了50 轮训练,由于我们有一个多类问题,我们使用交叉熵损失作为我们的目标函数
  • 我们采用了流行的Adam 优化器学习率为 0.001,**weight_decay 为 0.01,**通过正则化来优化目标函数,防止过度拟合。

最后,我们开始训练循环,其中包括计算每个批次的输出,并通过将预测标签与真实标签进行比较来计算损失。最后,我们绘制了每个时期的训练损失,以确保训练过程按计划进行。

代码:

#Selecting the appropriate training device 
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = CNN().to(device) #Defining the model hyper parameters 
num_epochs = 50
learning_rate = 0.001
weight_decay = 0.01
criterion = torch.nn.CrossEntropyLoss() 
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay) #Training process begins 
train_loss_list = [] 
for epoch in range(num_epochs): print(f'Epoch {epoch+1}/{num_epochs}:', end = ' ') train_loss = 0#Iterating over the training dataset in batches model.train() for i, (images, labels) in enumerate(train_loader): #Extracting images and target labels for the batch being iterated images = images.to(device) labels = labels.to(device) #Calculating the model output and the cross entropy loss outputs = model(images) loss = criterion(outputs, labels) #Updating weights according to calculated loss optimizer.zero_grad() loss.backward() optimizer.step() train_loss += loss.item() #Printing loss for each epoch train_loss_list.append(train_loss/len(train_loader)) print(f"Training loss = {train_loss_list[-1]}") #Plotting loss for all epochs 
plt.plot(range(1,num_epochs+1), train_loss_list) 
plt.xlabel("Number of epochs") 
plt.ylabel("Training loss") 

输出:

img

图 4:训练损失与周期数的关系图

从图 4 中我们可以看出,随着训练次数的增加,损失逐渐减少,这表明训练过程成功。

步骤 5:计算模型在测试集上的准确率

现在我们的模型已经训练完毕,我们需要检查它在测试集上的表现。为此,我们分批迭代整个测试集,并通过比较每个批次的真实标签和预测标签来计算准确率。

代码:

test_acc=0
model.eval() with torch.no_grad(): #Iterating over the training dataset in batches for i, (images, labels) in enumerate(test_loader): images = images.to(device) y_true = labels.to(device) #Calculating outputs for the batch being iterated outputs = model(images) #Calculated prediction labels from models _, y_pred = torch.max(outputs.data, 1) #Comparing predicted and true labels test_acc += (y_pred == y_true).sum().item() print(f"Test set accuracy = {100 * test_acc / len(test_dataset)} %")

输出:

img

图 5:测试集上的准确率

步骤 6:生成测试集中样本图像的预测

如图 5 所示,我们的模型已达到近 72% 的准确率。为了验证其性能,我们可以为一些样本图像生成一些预测。为此,我们获取测试集最后一批的前五张图像,并使用torchvision 的make_grid实用程序绘制它们。然后,我们从模型中收集它们的真实标签和预测,并将它们显示在图的标题中。

代码:

#Generating predictions for 'num_images' amount of images from the last batch of test set 
num_images = 5
y_true_name = [names[y_true[idx]] for idx in range(num_images)] 
y_pred_name = [names[y_pred[idx]] for idx in range(num_images)] #Generating the title for the plot 
title = f"Actual labels: {y_true_name}, Predicted labels: {y_pred_name}"#Finally plotting the images with their actual and predicted labels in the title 
plt.imshow(np.transpose(torchvision.utils.make_grid(images[:num_images].cpu(), normalize=True, padding=1).numpy(), (1, 2, 0))) 
plt.title(title) 
plt.axis("off")

输出:

img

图 6:测试集中 5 个样本图像的实际标签与预测标签。请注意,标签的顺序与相应图像的顺序相同,从左到右。

从图 6 可以看出,除了第二张图像外,该模型对所有图像都产生了正确的预测,因为它将狗错误地归类为猫!

结论:

本文介绍了在流行的 CIFAR-10 数据集上用 PyTorch 实现的简单 CNN。我们还可以尝试网络架构和模型超参数,以进一步提高模型准确性!


关注文末名片G-Z-H:【阿旭算法与机器学习】,发送【开源】可获取更多学习资源

在这里插入图片描述

好了,这篇文章就介绍到这里,喜欢的小伙伴感谢给点个赞和关注,更多精彩内容持续更新~~
关于本篇文章大家有任何建议或意见,欢迎在评论区留言交流!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/53012.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Android Camera系列(二):TextureView+Camera

两岸猿声啼不住,轻舟已过万重山—李白 Android Camera系列(一):SurfaceViewCamera Android Camera系列(二):TextureViewCamera Android Camera系列(三):GLS…

设计模式学习-简单的命令模式例子

上一章节介绍过了命令模式,这一篇文章就简单的做一个小案例来巩固学习 搭建场景 简单的搭建一个场景就行 ,随便准备一个物体放在场景中位置Reset一下即可。 代码编写 定义接口(或者抽象类)ICommand 用来规范Command的行为。注意…

SpringCloud开发实战(三):集成Eureka注册中心

目录 SpringCloud开发实战(一):搭建SpringCloud框架 SpringCloud开发实战(二):通过RestTemplate实现远程调用 Eureka简介 Eureka 是一个基于 Java 的开源技术,最广为人知的是作为 Netflix 开发…

Java 输入与输出之 NIO.2【AIO】【Path、Paths、Files】【walkFileTree接口】探索之【三】

在JDK 1.7 版本中对NIO进行了完善,推出了NIO.2,也称为AIO(异步IO),在处理大量并发请求时具有优势,特别是在网络编程和高并发场景下,表现得更为出色。 对于输出流和输入流而言,操作的…

【GIS开发小课堂】vue3+Cesium.js三维WebGIS项目实战(一)

随着市场对数字孪生的需求日益增多,对于前端从业者的能力从对框架vue、react的要求,逐步扩展到2D、3D空间的交互,为用户提供更紧密的立体交互。近年来前端对GIS的需求日益增多。 本文档详细介绍了使用Vue3和Cesium.js构建三维WebGIS项目的步骤…

024集—— 正则表达式、replace、DateTime日期的用法——C#学习笔记

DateTime 是一个struct结构体。 代码如下: using System; using System.Collections.Generic; using System.Linq; using System.Text; using System.Threading.Tasks;namespace ConsoleApp1 {internal class Program{static void Main(string[] args){args new s…

神策SDK不支持Windows客户端全埋点,怎么实现用户统计分析?

本文将介绍,ClkLog针对神策不支持全埋点的客户端实现用户访问基础统计分析 1。 客户遇到的问题 ClkLog的用户访问基础统计分析功能是基于神策SDK的全埋点来实现的。 我们遇到有些客户是使用C、C#等语言来开发的客户端,然而神策此类SDK(如C, C…

psql常见报错解决

问题 解决 要在管理员模式下启动 pg_ctl start -D "D:\Program\PostgreSQL\data" 注册成服务 D:\Program\PostgreSQL\bin\pg_ctl.exe register -N "postgresql" -D "D:\Program\PostgreSQL\data" -U "postgres" -P "postgre…

守护夏日清凉:EasyCVR+AI视频智能管理方案为水上乐园安全保驾护航

随着夏季的来临,水上乐园成为了人们避暑消夏、亲子互动的理想去处。然而,随着游客量的激增,如何确保水上乐园的安全与秩序,提升游客体验,成为了管理者亟待解决的问题。为此,引入一套高效、智能的视频监控方…

workman和GateWay学习笔记

前言 workman支持Udp GateWay是基于workman的二次封装,更适合长链接场景 windows安装workman composer create-project workerman/webman windows运行workman cd webman php windows.php windows访问 http://ip地址:8787 将workman引入thinkphp框架理念

OZON户外运动产品有哪些好卖的

Top1 运动水壶 Спортивная бутылка 780 мл 商品id:1613789852 月销量:819 OZON热销文具产品:m6z.cn/5H6fQR (复制浏览器打开) 780毫升的容量设计,既不会过于笨重,也能满足用户在运动或户外活…

【自动驾驶】决策规划算法 | 数学基础(二)凸优化与非凸优化

写在前面: 🌟 欢迎光临 清流君 的博客小天地,这里是我分享技术与心得的温馨角落。📝 个人主页:清流君_CSDN博客,期待与您一同探索 移动机器人 领域的无限可能。 🔍 本文系 清流君 原创之作&…

初识Linux · 有关makefile

目录 前言: 1 makefile的简单使用 2 makefile介绍 前言: 我们上文介绍了gcc和g的基本使用,带了许多的子指令,但是有的时候啊,一个一个敲指令确实有点麻烦了,此时,一个工具就能派上用场&…

Linux学习-虚拟化平台安装和使用

注:系统使用Rock8.6 下载链接 通过百度网盘分享的文件:cirros.qcow2,node_base.xml等2个文件 链接:https://pan.baidu.com/s/1hupGQsMjrXMgngCy3lQLhw?pwdhlr6 提取码:hlr6[rootharbor ~]# cat /etc/redhat-releas…

Django+Vue宠物服务管理系统的设计与实现

目录 1 项目介绍2 项目截图3 核心代码3.1 需要的环境3.2 Django接口层3.3 实体类3.4 config.ini3.5 启动类3.5 Vue 4 数据库表设计5 文档参考6 计算机毕设选题推荐7 源码获取 1 项目介绍 博主个人介绍:CSDN认证博客专家,CSDN平台Java领域优质创作者&…

2024-09-03作业

作业结果 作业代码 #include <iostream> using namespace std; class RMB { friend const RMB operator-(const RMB &L,const RMB &R); friend const RMB operator--(RMB &O,int); private: int yuan; int jiao; int fen; static…

JMeter 安装使用

JMeter 安装使用 a.安装 下载链接:Apache JMeter - Download Apache JMeter 环境变量 打开 cmd 输入 jmeter&#xff0c;即可启动 b.使用 http请求接口 300 个线程设置 1 s 的预热时间 右键 start

如何读懂以太坊源代码

以下是一个学习以太坊源代码的清晰思路&#xff1a; 一、前期准备 基础知识储备&#xff1a; 熟悉区块链的基本概念&#xff0c;如区块、交易、共识机制、哈希函数等。了解 Go 语言&#xff08;以太坊主要使用 Go 语言编写&#xff09;的基本语法和编程概念。 工具准备&#x…

Unity(2022.3.41LTS) - UI详细介绍-Scroll View(滚动视图)

目录 零.简介 一、基本功能与用途 二、主要组件 Rect Transform&#xff08;矩形变换&#xff09;&#xff1a; Scroll Rect&#xff08;滚动矩形&#xff09;组件&#xff1a; Scrollbar&#xff08;滚动条&#xff09;组件&#xff1a; Mask&#xff08;遮罩&#xff…

今天来聊一聊前端框架有哪些呢? 主流Vue和React

使用工具&#xff1a; 联网搜索 前端框架主要包括React.js、Vue.js、Angular等。在现代网络技术的快速发展中&#xff0c;前端框架成为了实现界面美观、交互性强、用户体验佳的网页和应用不可或缺的工具。下面将具体介绍几款目前主流的前端框架&#xff1a; React.js 简介&…