利用卷积神经网络进行手写数字的识别

数据集介绍

MNIST(Modified National Institute of Standards and Technology)数据集是一个广泛使用的手写数字识别数据集,常用于机器学习和计算机视觉领域中的分类任务。它包含了从0到9的手写数字样本,常用于训练和测试各种图像分类算法。

数据集概况

MNIST数据集由60,000个训练样本和10,000个测试样本组成,每个样本是一张28×28像素的灰度图像,表示一个手写数字。每个图像是一个二维矩阵,像素值范围从0(黑色)到255(白色),灰度值表示不同的颜色深度。数据集中的标签是这些图像对应的数字(0-9)。

数据集格式

  • 训练集:60,000个图像,每个图像有一个对应的标签(0到9之间的数字)。
  • 测试集:10,000个图像,也有对应的标签。

使用场景

  1. 图像分类任务:由于数据集较小且标准化,MNIST是机器学习算法(尤其是深度学习模型)测试和比较性能的一个标准数据集。
  2. 模型性能评估:MNIST被广泛用于评估各种机器学习模型的效果,尤其是在图像处理领域。
  3. 教学:由于其简单性,MNIST常作为入门学习机器学习和神经网络的教学材料。

特点

  • 图像尺寸固定:28×28像素,适合用作标准输入。
  • 图像内容简单:大多数手写数字都是规范且易于分辨的。
  • 数据集较小,适合于快速实验和初步的模型验证。

数据集获取

MNIST数据集可以通过多个平台获取,例如:

  • 通过TensorFlow、PyTorch等框架的内建API加载。
  • 从MNIST官网下载。

数据预处理及参数选择

数据处理

transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
# softmax归一化指数函数(https://blog.csdn.net/lz_peter/article/details/84574716),其中0.1307是mean均值和0.3081是std标准差train_dataset = datasets.MNIST(root='./data/mnist', train=True, transform=transform)  # 本地没有就加上download=True
test_dataset = datasets.MNIST(root='./data/mnist', train=False, transform=transform)  # train=True训练集,=False测试集
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

参数的选择

batch_size = 64                #每个批次大小中有64个样本
learning_rate = 0.01           #学习率
momentum = 0.5                 #梯度下降冲量
epochs = 10                    #训练轮数
  • batch_size = 64:每次训练时使用64个样本来计算梯度并更新权重。
  • learning_rate = 0.01:每次权重更新时,步长为0.01,影响训练速度和稳定性。
  • momentum = 0.5:通过加权平均过去的梯度,帮助加速收敛并减少梯度更新的震荡。
  • epochs = 10:模型将在训练数据上进行10次完整的迭代,通常可以在这个范围内找到适合的训练状态。

网络模型

class Net(torch.nn.Module):def __init__(self):super(Net, self).__init__()self.conv1 = torch.nn.Sequential(torch.nn.Conv2d(1, 10, kernel_size=5),torch.nn.ReLU(),torch.nn.MaxPool2d(kernel_size=2),)self.conv2 = torch.nn.Sequential(torch.nn.Conv2d(10, 20, kernel_size=5),torch.nn.ReLU(),torch.nn.MaxPool2d(kernel_size=2),)self.conv3 = torch.nn.Sequential(torch.nn.Flatten(),torch.nn.Linear(320, 50),torch.nn.Linear(50, 10),)def forward(self, x):batch_size = x.size(0)x = self.conv1(x)  # 一层卷积层,一层池化层,一层激活层(图是先卷积后激活再池化,差别不大)x = self.conv2(x)  # 再来一次x = self.conv3(x)return x  # 最后输出的是维度为10的,也就是(对应数学符号的0~9)
  • 输入层:

    • 输入尺寸:每张输入图像是28x28像素的灰度图,单通道。输入张量的形状为 (batch_size, 1, 28, 28),其中 batch_size 是一次处理的图像数量,1 是表示单通道的灰度图像,28x28 是图像的尺寸。
  • 第一层卷积层(conv1):

    • 卷积层:使用一个大小为 5x5 的卷积核,将输入图像的1个通道(灰度)转换为10个通道。卷积核的步幅为1,填充为0(即没有边缘扩展)。这会产生一个大小为 24x24 的特征图(由于没有填充,尺寸会减少)。
    • 激活函数:ReLU(Rectified Linear Unit),它会对每个像素值进行非线性转换(ReLU(x) = max(0, x)),有效地引入了非线性特性。
    • 池化层:最大池化层使用 2x2 的池化窗口和步幅为2。池化操作减少了特征图的尺寸,将每个 2x2 的区域映射为最大值。池化操作将图像尺寸减半,从 24x24 减小为 12x12,同时减少计算量。
  • 第二层卷积层(conv2):

    • 卷积层:卷积核的大小为 5x5,将前一层输出的10个通道转换为20个通道。同样,步幅为1,没有填充。这个操作将特征图的大小从 12x12 减少到 8x8
    • 激活函数:使用ReLU激活函数。
    • 池化层:再次使用最大池化,池化窗口为 2x2,步幅为2。此操作将尺寸从 8x8 减小为 4x4
  • 全连接层(conv3):

    • 展平操作(Flatten):经过两层卷积和池化操作后,输出特征图的大小为 20x4x4。在传入全连接层之前,需要将这个多维的张量展平成一维向量。展平后的尺寸是 320(即 20 * 4 * 4)。
    • 第一个全连接层:将展平后的320个元素映射到50个神经元。该层的作用是通过加权和偏置的线性变换对输入进行处理,并通过激活函数进行非线性转换。
    • 第二个全连接层:将50个神经元映射到10个神经元,输出的每个神经元代表一个数字类别(0到9)。
  • 输出层:

    • 输出尺寸:最终输出为一个10维的向量,其中每个值表示输入图像属于每个类别的“分数”。这个分数可以通过softmax层转化为概率,用于多类分类任务。

模型训练

# Construct loss and optimizer ------------------------------------------------------------------------------
loss_f = torch.nn.CrossEntropyLoss()  # 交叉熵损失
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum)  # lr学习率,momentum冲量# Train and Test CLASS --------------------------------------------------------------------------------------
# 把单独的一轮一环封装在函数类里
def train(epoch):running_loss = 0.0  # 这整个epoch的loss清零running_total = 0running_correct = 0for batch_idx, data in enumerate(train_loader, 0):  #第一个代表训练的批次,data中包括数据和标签,第一个数据代表输入即inputs,第二个数据代表标签labelsinputs, target = dataoptimizer.zero_grad()   #将之前的梯度清零# forward + backward + updateoutputs = model(inputs)loss = loss_f(outputs, target)#反向传播loss.backward()#参数更新optimizer.step()# 把运行中的loss累加起来,为了下面300次一除running_loss += loss.item()# 把运行中的准确率acc算出来_, predicted = torch.max(outputs.data, dim=1)running_total += inputs.shape[0]running_correct += (predicted == target).sum().item()if batch_idx % 300 == 299:  # 不想要每一次都出loss,浪费时间,选择每300次出一个平均损失,和准确率print('[%d, %5d]: loss: %.3f , acc: %.2f %%'% (epoch + 1, batch_idx + 1, running_loss / 300, 100 * running_correct / running_total))running_loss = 0.0  # 这小批300的loss清零running_total = 0running_correct = 0  ## 这小批300的acc清零# torch.save(model.state_dict(), './model_Mnist.pth')# torch.save(optimizer.state_dict(), './optimizer_Mnist.pth')def test():correct = 0total = 0with torch.no_grad():  # 测试集不用算梯度for data in test_loader:images, labels = dataoutputs = model(images)_, predicted = torch.max(outputs.data, dim=1)  # dim = 1 列是第0个维度,行是第1个维度,沿着行(第1个维度)去找1.最大值和2.最大值的下标total += labels.size(0)  # 张量之间的比较运算correct += (predicted == labels).sum().item()acc = correct / totalprint('[%d / %d]: Accuracy on test set: %.1f %% ' % (epoch + 1, epochs, 100 * acc))  # 求测试的准确率,正确数/总数return acc# Start train and Test --------------------------------------------------------------------------------------
if __name__ == '__main__':acc_list_test = []for epoch in range(epochs):train(epoch)# if epoch % 10 == 9:  #每训练10轮 测试1次acc_test = test()acc_list_test.append(acc_test)plt.plot(acc_list_test)plt.xlabel('Epoch')plt.ylabel('Accuracy On TestSet')plt.show()

训练结果

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/889527.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Transformer入门(6)Transformer编码器的前馈网络、加法和归一化模块

文章目录 7.前馈网络8.加法和归一化组件9.组合所有编码器组件构成完整编码器 7.前馈网络 编码器块中的前馈网络子层如下图所示: 图1.32 – 编码器块 前馈网络由两个带有ReLU激活函数的全连接层组成。全连接层(Fully Connected Layer)有时也…

前端(async 和await)

1 async async 将 function 变为成为 async 函数 ●async 内部可以使用 await,也可以不使用,因此执行这个函数时,可以使用 then 和 catch 方法 ●async 函数的返回值是一个 Promise 对象 ●Promise 对象的结果由 async 函数执行的返回值决…

Java-25 深入浅出 Spring - 实现简易Ioc-01 Servlet介绍 基本代码编写

点一下关注吧!!!非常感谢!!持续更新!!! 大数据篇正在更新!https://blog.csdn.net/w776341482/category_12713819.html 目前已经更新到了: MyBatis&#xff…

H.323音视频协议

概述 H.323是国际电信联盟(ITU)的一个标准协议栈,该协议栈是一个有机的整体,根据功能可以将其分为四类协议,也就是说该协议从系统的总体框架(H.323)、视频编解码(H.263)、…

WPF+MVVM案例实战与特效(四十)- 一个动态流水边框的实现

文章目录 1、运行效果2、案例实现1、PointAnimationUsingKeyFrames 关键帧动画2、矩形流水边框案例2、运行效果3、关键技术点3、案例拓展:其他形状实现1、圆形流水边框2、心形流水边3、完整页面代码4、运行效果5、总结1、运行效果 2、案例实现 1、PointAnimationUsingKeyFram…

微信小程序--创建一个日历组件

微信小程序–创建一个日历组件 可以创建一个日历组件&#xff0c;来展示当前月份的日期&#xff0c;并支持切换月份的功能。 一、目录结构 /pages/calendarcalendar.wxmlcalendar.scsscalendar.jscalendar.json二、calendar.wxml <view class"calendar"><…

【Linux-ubuntu通过USB传输程序点亮LED灯】

Linux-ubuntu通过USB传输程序点亮LED灯 一,初始化GPIO配置1.使能时钟2.其他寄存器配置 二&#xff0c;程序编译三&#xff0c;USB传输程序 一,初始化GPIO配置 1.使能时钟 使能就是一个控制信号&#xff0c;用于决定时钟信号是否能够有效的传递或者被使用&#xff0c;就像一个…

Rust之抽空学习系列(三)—— 编程通用概念(中)

Rust之抽空学习系列&#xff08;三&#xff09;—— 编程通用概念&#xff08;中&#xff09; 1、变量&可变性 在Rust中&#xff0c;变量默认是不可变的 fn main() {let x 5;println!("x is {}", x); }使用let来声明一个变量&#xff0c;此时变量默认是不可变…

Mybatis---事务

目录 引入 一、事务存在的意义 1.事务是什么&#xff1f; 2.Mybatis关于事务的管理 程序员自己控制处理的提交和回滚 引入 一、事务存在的意义 1.事务是什么&#xff1f; 多个操作同时进行,那么同时成功&#xff0c;那么同时失败。这就是事务。 事务有四个特性&#xf…

<项目代码>YOLOv8 车牌识别<目标检测>

项目代码下载链接 &#xff1c;项目代码&#xff1e;YOLOv8 车牌识别&#xff1c;目标检测&#xff1e;https://download.csdn.net/download/qq_53332949/90121387YOLOv8是一种单阶段&#xff08;one-stage&#xff09;检测算法&#xff0c;它将目标检测问题转化为一个回归问题…

跨平台开发技术的探索:从 JavaScript 到 Flutter

随着多平台支持和用户体验一致性在应用程序开发中变得越来越重要,开发者面临的挑战是如何在不同平台上保持代码的可维护性和高效性。本文将探讨如何利用现代技术栈,包括 Flutter、JavaScript、HTML5、WebAssembly、TypeScript 和 Svelte,在统一的平台上进行高效的跨平台开发…

华为eNSP:VRRP

一、VRRP背景概述 在现代网络环境中&#xff0c;主机通常通过默认网关进行网络通信。当默认网关出现故障时&#xff0c;网络通信会中断&#xff0c;影响业务连续性和稳定性。为了提高网络的可靠性和冗余性&#xff0c;采用虚拟路由冗余协议&#xff08;VRRP&#xff09;是一种…

Referer头部在网站反爬虫技术中的运用

网站数据的安全性和完整性至关重要。爬虫技术&#xff0c;虽然在数据收集和分析中发挥着重要作用&#xff0c;但也给网站管理员带来了挑战。为了保护网站数据不被恶意爬取&#xff0c;反爬虫技术应运而生。本文将探讨HTTP头部中的Referer字段在反爬虫技术中的应用&#xff0c;并…

【ArcGIS微课1000例】0135:自动生成标识码(长度不变,前面自动加0)

文章目录 一、加载实验数据二、BSM计算方法一、加载实验数据 加载专栏《ArcGIS微课实验1000例(附数据)》配套数据中0135.rar中的建筑物数据,如下图所示: 打开属性表,BSM为数据库中要求的字段:以TD_T 1066-2021《不动产登记数据库标准》为例: 计算出来的BSM如下图: 二、B…

NVR小程序接入平台/设备EasyNVR深度解析H.265与H.264编码视频接入的区别

随着科技的飞速发展和社会的不断进步&#xff0c;视频压缩编码技术已经成为视频传输和存储中不可或缺的一部分。在众多编码标准中&#xff0c;H.265和H.264是最为重要的两种。今天我们来将深入分析H.265与H.264编码的区别。 一、H.265与H.264编码的区别 1、比特率与分辨率 H.…

华硕奥创软件在线安装和离线安装方法

华硕奥创软件在线安装和离线安装方法 1. 华硕奥创软件介绍2. 华硕奥创软件在线安装2.1 第一种2.2 第二种 3. 华硕奥创软件离线安装3.1 概述3.2 华硕奥创软件离线包下载方式 4. 卸载华硕奥创软件4.1 概述4.2 华硕奥创卸载软件下载与使用方式 结束语 1. 华硕奥创软件介绍 华硕奥…

minio 分布式文件管理

一、minio 是什么&#xff1f; MinIO构建分布式文件系统&#xff0c;MinIO 是一个非常轻量的服务,可以很简单的和其他应用的结合使用&#xff0c;它兼容亚马逊 S3 云存储服务接口&#xff0c;非常适合于存储大容量非结构化的数据&#xff0c;例如图片、视频、日志文件、备份数…

A6688 JSP+MYSQL+LW+二手物品网上交易系统

二手物品网上交易系统的设计与实现 1.摘要2.开发目的和意义3.系统功能设计4.系统界面截图5.源码获取 1.摘要 摘 要 随着社会经济快速发展&#xff0c;互联网推动了电子商务业的迅速崛起。越来越多的人们喜欢在线进行商品的交易&#xff0c;尤其是对于二手物品的处理&#xff0…

算法分析与设计之分治算法

文章目录 前言一、分治算法divide and conquer1.1 分治定义1.2 分治法的复杂性分析&#xff1a;递归方程1.2.1 主定理1.2.2 递归树法1.2.3 迭代法 二、典型例题2.1 Mergesort2.2 Counting Inversions2.3 棋盘覆盖2.4 最大和数组2.5 Closest Pair of Points2.6 Karatsuba算法&am…

Ubuntu 安装 Samba Server

在 Mac 上如何能够与Ubuntu 服务器共享文件夹&#xff0c;需要在 Ubuntu 上安装 Samba 文件服务器。本文将介绍如何在 Ubuntu 上安装 Samba 服务器从而达到以下目的&#xff1a; Mac 与 Ubuntu 共享文件通过用户名密码访问 安装 Samba 服务 sudo apt install samba修改配置文…