【PyTorch】使用PyTorch创建卷积神经网络并在CIFAR-10数据集上进行分类

前言

在深度学习的世界中,图像分类任务是一个经典的问题,它涉及到识别给定图像中的对象类别。CIFAR-10数据集是一个常用的基准数据集,包含了10个类别的60000张32x32彩色图像。在本博客中,我们将探讨如何使用PyTorch框架创建一个简单的卷积神经网络(CNN)来对CIFAR-10数据集中的图像进行分类。

在下一篇博客中,我们将尝试不断优化模型结构和训练过程,以达到更高的准确率和性能。

引用

关于卷积神经网络的原理,感兴趣的请参阅我的另一篇博客,里面只使用numpy和基础函数组建了一个卷积神经网络模型,并完成训练和测试
【手搓深度学习算法】从头创建卷积神经网络

背景

卷积神经网络是深度学习中用于图像识别和分类的一种强大工具。它们能够自动从图像中提取特征,并通过一系列卷积层、池化层和全连接层来学习图像的复杂模式。

CIFAR-10数据集包含了飞机、汽车、鸟类、猫、鹿、狗、青蛙、马、船和卡车等10个类别的图像。每个类别有6000张图像,其中50000张用于训练,10000张用于测试。
请添加图片描述

代码解析

我们的目标是构建一个能够处理CIFAR-10数据集的CNN模型。以下是我们的模型结构和数据处理流程的简要概述:

数据预处理

我们首先定义了unpickle函数来加载CIFAR-10数据集的批次文件。read_data函数用于读取数据,将其转换为适合卷积网络输入的格式,并进行归一化处理。我们还提供了一个选项来将图像转换为灰度。

def unpickle(file):import picklewith open(file, 'rb') as fo:dict = pickle.load(fo, encoding='bytes')return dictdef read_data(file_path, gray = False, percent = 0, normalize = True):data_src = unpickle(file_path)np_data = np.array(data_src["data".encode()]).astype("float32")np_labels = np.array(data_src["labels".encode()]).astype("float32").reshape(-1,1)single_data_length = 32*32 image_ret = Noneif (gray):np_data = (np_data[:, :single_data_length] + np_data[:, single_data_length:(2*single_data_length)] + np_data[:, 2*single_data_length : 3*single_data_length])/3image_ret = np_data.reshape(len(np_data),32,32)else:image_ret = np_data.reshape(len(np_data),32,32,3)if(normalize):mean = np.mean(np_data)std = np.std(np_data)np_data = (np_data - mean) / stdif (percent != 0):np_data = np_data[:int(len(np_data)*percent)]np_labels = np_labels[:int(len(np_labels)*percent)]image_ret = image_ret[:int(len(image_ret)*percent)]num_classes = len(np.unique(np_labels))np_data, np_labels = convert_to_conv_input(np_data, np_labels)return np_data, np_labels, num_classes, image_ret 

网络结构

Conv类定义了我们的CNN模型,它包含一个卷积层、一个最大池化层、一个ReLU激活函数和一个全连接层。在forward方法中,我们指定了数据通过网络的流程。

class Conv(th.nn.Module):def __init__(self, *args, **kwargs) -> None:super(Conv, self).__init__()self.conv = th.nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3)self.pool = th.nn.MaxPool2d(kernel_size=2,stride=2)self.relu = th.nn.ReLU()self.linear = th.nn.Linear(16*15*15, 10)self.softmax = th.nn.Softmax(dim=1)def forward(self, x):x = self.conv(x) #32,16,30,30x = self.pool(x) #32,16,15,15x = self.relu(x)x = x.view(x.size(0), -1)x = self.linear(x)return x# 在predict函数中,额外调用了softmax,将线性层的10个特征值转化为概率,在前向传播中不用是因为pytorch中交叉熵函数自带了softmaxdef predict(self,x):x = self.forward(x)x = self.softmax(x)return x
卷积层、池化层、线性层的输入特征数量的计算方法

线性层的输入特征个数取决于前面层的输出。
具体来说,线性层的输入特征个数是卷积层和池化层处理后的输出特征图的总元素数量。

卷积层定义如下:

self.conv = th.nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3)

这里,in_channels=3 表示输入图像有3个颜色通道(RGB),out_channels=16 表示卷积层将输出16个特征图。

接下来是池化层:

self.pool = th.nn.MaxPool2d(kernel_size=2, stride=2)

kernel_size=2,表示池化窗口的大小是2x2。stride=2 表示池化操作的步长是2。

为了计算线性层的输入特征个数,我们需要知道卷积层和池化层之后的输出特征图的大小。这可以通过计算公式得到,或者通过在实际数据上运行网络的前向传播来确定。

计算公式如下:

对于卷积层,输出特征图的大小可以通过以下公式计算:

H_out = (H_in + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1
W_out = (W_in + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1

对于池化层,输出特征图的大小也可以通过类似的公式计算。

由于没有指定paddingdilation,查看函数定义可知它们的默认值分别是0和1。因此,如果输入图像的大小是32x32,卷积层之后的大小将是:

H_out = (32 - 1 * (3 - 1) - 1) / 1 + 1 = 30
W_out = (32 - 1 * (3 - 1) - 1) / 1 + 1 = 30

因此,卷积层的输出将有16个30x30的特征图。

然后,池化层将这些特征图的大小减半(因为kernel_size=2stride=2),所以输出将是16个15x15的特征图。

最后,线性层的输入特征个数将是这些特征图的总元素数量:

num_features = out_channels * H_out_pool * W_out_pool = 16 * 15 * 15 = 3600

因此,线性层的正确定义应该是:

self.linear = th.nn.Linear(3600, num_classes)

训练过程

main函数中,我们初始化了模型、损失函数和优化器。我们使用随机梯度下降(SGD)作为优化算法,并设置了学习率。接着,我们进入了训练循环,其中包括前向传播、损失计算、反向传播和权重更新。

loss_function = th.nn.CrossEntropyLoss()
optimizer = th.optim.SGD(conv_model.parameters(), lr = lr)

测试和评估

训练完成后,我们使用训练好的模型对测试数据进行评估,并计算准确率。我们还提供了一个predict方法,它在给定输入数据后返回模型的预测概率。

def predict(self,x):x = self.forward(x)x = self.softmax(x)return x
softmax激活函数

Softmax 激活函数是一种广泛使用的函数,它将一个实数向量转换为概率分布。在深度学习中,它常常用于多类别分类问题的输出层。

Softmax 函数的定义如下:

softmax ( z ) i = e z i ∑ j e z j \text{softmax}(z)_i = \frac{e^{z_i}}{\sum_{j} e^{z_j}} softmax(z)i=jezjezi

其中 z z z 是输入向量, z i z_i zi z z z 的第 i i i 个元素, softmax ( z ) i \text{softmax}(z)_i softmax(z)i 是输出向量的第 i i i 个元素。

Softmax 函数的主要特性是它的输出是一个概率分布,即所有输出元素的值都在 ( 0 , 1 ) (0, 1) (0,1) 区间内,且所有输出元素的值之和为 1。这使得 Softmax 函数非常适合用于表示概率。

Softmax 函数的一个重要性质是它是连续的,且其导数容易计算。这使得 Softmax 函数在深度学习中的反向传播过程中非常有用。

Softmax 函数的导数如下:

∂ ∂ z i softmax ( z ) i = softmax ( z ) i ( 1 − softmax ( z ) i ) \frac{\partial}{\partial z_i}\text{softmax}(z)_i = \text{softmax}(z)_i(1 - \text{softmax}(z)_i) zisoftmax(z)i=softmax(z)i(1softmax(z)i)

这个导数表达式表明,对于 Softmax 函数的输出 y i y_i yi,其对输入 z i z_i zi 的导数等于 y i ( 1 − y i ) y_i(1 - y_i) yi(1yi)。这个导数表达式在反向传播过程中非常有用,因为它可以直接用于计算梯度。

训练过程中没有使用softmax层,是应为torch的交叉熵损失函数已经包含了softmax的操作,如果叠加使用,可能得到错误的结果。

运行结果

作为一个简单的卷积模型,在测试集上得到了60%的准确率
请添加图片描述

完整代码

本文不提供完整代码,因为随着我的微调优化过程,已经没有这个版本的基线代码了,想要最终代码的欢迎阅读下一篇博客 “记一次卷积网络调优的过程”
在这里插入图片描述

注意点

  • 数据预处理:确保数据被正确地加载和归一化,这对模型的训练效果至关重要。
  • 模型结构:模型的层数和参数需要根据任务的复杂性来调整。过于简单的模型可能无法捕捉到数据中的复杂特征,而过于复杂的模型可能会导致过拟合。
  • 损失函数:我们使用交叉熵损失函数,它适用于多类别分类问题。
  • 优化器:在每次迭代前,记得清除累积的梯度,以避免错误的梯度更新。

可能的优化点

  • 学习率调整:可以尝试使用学习率调度器来在训练过程中调整学习率,以改善模型的收敛速度和性能。
  • 权重初始化:尝试不同的权重初始化方法,以帮助模型更快地收敛。
  • 正则化技术:使用如Dropout、L2正则化等技术来减少过拟合。
  • 数据增强:通过对训练图像进行随机变换(如旋转、缩放、裁剪等),可以增加模型的泛化能力。
  • 更深的网络:考虑增加更多的卷积层和池化层来提取更复杂的特征。
  • 批量归一化:在卷积层之后添加批量归一化层,以稳定训练过程并加速收敛。

结论

通过本博客,我们展示了如何使用PyTorch框架构建一个简单的CNN模型,并在CIFAR-10数据集上进行训练和测试。虽然我们的模型结构相对简单,但它为理解深度学习和图像分类提供了一个很好的起点。在下一篇博客中,我们将尝试不断优化模型结构和训练过程,以达到更高的准确率和性能。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/649905.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

C#,打印漂亮杨辉三角形(帕斯卡三角形)的源代码

杨辉 Blaise Pascal 这是某些程序员看完会哭的代码。 杨辉三角形(Yanghui Triangle),是一种序列数值的三角形几何排列,最早出现于南宋数学家杨辉1261年所著的《详解九章算法》一书。 欧洲学者,最先由帕斯卡&#x…

Windows打开IE浏览器命令最简单的方法

问题场景: 许多插件或特定版本的系统需要使用ie浏览器来访问,window默认的ie浏览器是被禁用的如何快速打开ie浏览器解决问题 目录 问题场景: 测试环境: 检查环境是否支持: 问题解决: 方法一 方法二 方法…

03 SB实战 -微头条之首页门户模块(跳转某页面自动展示所有信息+根据hid查询文章全文并用乐观锁修改阅读量)

1.1 自动展示所有信息 需求描述: 进入新闻首页portal/findAllType, 自动返回所有栏目名称和id 接口描述 url地址:portal/findAllTypes 请求方式:get 请求参数:无 响应数据: 成功 {"code":"200","mes…

hex 尽然可以 设置透明度,透明度参数对比图 已解决

还不知道CSS Color Module Level 4标准早在2014年就推出8位hex和4位hex来支持设置alpha值,以实现hex和rgba的互转。这个办法可比6位HEX转RGBA简洁多了,先来简单解释一下: 8位hex是在6位hex基础上加后两位来表示alpha值,00表示完全…

Hadoop-MapReduce-MRAppMaster启动篇

一、源码下载 下面是hadoop官方源码下载地址&#xff0c;我下载的是hadoop-3.2.4&#xff0c;那就一起来看下吧 Index of /dist/hadoop/core 二、上下文 在上一篇<Hadoop-MapReduce-源码跟读-客户端篇>中已经将到&#xff1a;作业提交到ResourceManager&#xff0c;那…

数据结构——树的合集

目录 文章目录 前言 一.树的表达方式 1.树的概念 2.树的结点 3.树的存储结构 01.双亲表示法 顺序表示形式 优缺点说明 02.孩子表示法 03.孩子兄弟表示法 04.非类存储代码演示 二.二叉树 1.树的特点 2.二叉树 01.定义 02.二叉树的性质 03.满二叉树 04.完全二叉树…

uniapp封装公共的方法或者数据请求方法

仅供自己参考&#xff0c;不是每个页面都用到这个方法&#xff0c;所以我直接在用到的页面引用该公用方法&#xff1a; 1、新建一个util.js文件 export const address function(options){return new Promise((resolve,reject)>{uni.request({url:"https://x.cxniu.…

Istio-gateway

一. gateway 在 Kubernetes 环境中&#xff0c;Kubernetes Ingress用于配置需要在集群外部公开的服务。但是在 Istio 服务网格中&#xff0c;更好的方法是使用新的配置模型&#xff0c;即 Istio Gateway&#xff0c;Gateway 允许将 Istio 流量管理的功能应用于进入集群的流量&…

Android P 背光机制流程分析

在android 9.0中&#xff0c;相比android 8.1而言&#xff0c;背光部分逻辑有较大的调整&#xff0c;这里就对android P背光机制进行完整的分析。 1.手动调节亮度 1.1.在SystemUI、Settings中手动调节 在界面(SystemUI)和Settings中拖动进度条调节亮度时&#xff0c;调节入口…

Excel 2019 for Mac/Win:商务数据分析与处理的终极工具

在当今快节奏的商业环境中&#xff0c;数据分析已经成为一项至关重要的技能。从市场趋势预测到财务报告&#xff0c;再到项目管理&#xff0c;数据无处不在。而作为数据分析的基石&#xff0c;Microsoft Excel 2019 for Mac/Win正是一个强大的工具&#xff0c;帮助用户高效地处…

face_recognition和图像处理中left、top、right、bottom解释

face_recognition.face_locations 介绍 加载图像文件后直接调用face_recognition.face_locations(image)&#xff0c;能定位所有图像中识别出的人脸位置信息&#xff0c;返回值是列表形式&#xff0c;列表中每一行是一张人脸的位置信息&#xff0c;包括[top, right, bottom, l…

微服务-微服务Alibaba-Nacos注册中心实现

1. 系统架构的演变 俗话说&#xff0c; 没有最好的架构&#xff0c;只有最合适的架构。 微服务架构也是随着信息产业的发展而出现的最有普 遍适用性的一套架构模式。通常来说&#xff0c;我们认为架构发展历史经历了这样一个过程&#xff1a;单体架构——> 垂直架构 ——&g…

go实现生成html文件和html文件浏览服务

文章目录 本文章是为了解决 使用Jenkins执行TestNgSeleniumJsoup自动化测试和生成ExtentReport测试报告生成的测试报告&#xff0c;只能在jenkins里面访问&#xff0c;为了方便项目组内所有人员都能查看测试报&#xff0c;可以在jenkins构建时&#xff0c;把测试报告的html推送…

Leetcode—114. 二叉树展开为链表【中等】

2023每日刷题&#xff08;九十八&#xff09; Leetcode—114. 二叉树展开为链表 Morris-like算法思想 可以发现展开的顺序其实就是二叉树的先序遍历。算法和 94 题中序遍历的 Morris 算法有些神似&#xff0c;我们需要两步完成这道题。 将左子树插入到右子树的地方将原来的右…

PreNorm和PostNorm对比

要点总结 标准的Transformer使用的是PostNorm 在完全相同的训练设置下Pre Norm的效果要优于Post Norm&#xff0c;这只能显示出Pre Norm更容易训练&#xff0c;因为Post Norm要达到自己的最优效果&#xff0c;不能用跟Pre Norm一样的训练配置&#xff08;比如Pre Norm可以不加…

第14次修改了可删除可持久保存的前端html备忘录:增加一个翻牌钟,修改背景主题:现代深色

第14次修改了可删除可持久保存的前端html备忘录&#xff1a;增加一个翻牌钟&#xff0c;修改背景主题&#xff1a;现代深色 备忘录代码 <!DOCTYPE html> <html lang"zh"> <head><meta charset"UTF-8"><meta http-equiv"X…

网络安全防御保护实验(二)

一、登录进防火墙的web控制页面进行配置安全策略 登录到Web控制页面&#xff1a; 打开Web浏览器&#xff0c;输入防火墙的IP地址或主机名&#xff0c;然后使用正确的用户名和密码登录到防火墙的Web管理界面。通常&#xff0c;这些信息在防火墙设备的文档或设备上会有说明。 导…

鸿蒙ArkUI开发-应用添加弹窗

在我们日常使用应用的时候&#xff0c;可能会进行一些敏感的操作&#xff0c;比如删除联系人&#xff0c;这时候我们给应用添加弹窗来提示用户是否需要执行该操作&#xff0c;如下图所示&#xff1a; 弹窗是一种模态窗口&#xff0c;通常用来展示用户当前需要的或用户必须关注的…

C++知识点笔记

二维数组 定义方式&#xff1a; 1、数据类型 数组名[行数][列数]; 2、数据类型 数组名[行数][列数]{{数据1,数据2},{数据3,数据4}}; 3、数据类型 数组名[行数][列数]{数据1,数据2,数据3,数据4}; 4、数据类型 数组名[][列数]{数据1,数据2,数据3,数据4}; 建议&#xff1a;以…

React中使用LazyBuilder实现页面懒加载方法二

前言&#xff1a; 在一个表格中&#xff0c;需要展示100条数据&#xff0c;当每条数据里面需要承载的内容很多&#xff0c;需要渲染的元素也很多的时候&#xff0c;容易造成页面加载的速度很慢&#xff0c;不能给用户提供很好的体验时&#xff0c;懒加载是优化页面加载速度的方…