PyTorch 实现手写数字识别

PyTorch 实现手写数字识别

在本教程中,我们将使用 PyTorch 实现经典的手写数字识别任务。我们将使用 MNIST 数据集,这是一个包含手写数字的图像数据集。我们将介绍如何使用 PyTorch 构建、训练和评估一个简单的卷积神经网络(CNN)模型来进行手写数字识别。

1. 项目概述

手写数字识别任务是通过训练模型,让其能够识别手写数字图像并输出正确的数字类别(0-9)。MNIST 数据集包含 28x28 像素的灰度图像,每个图像代表一个手写数字。

我们将使用以下步骤:

  1. 加载 MNIST 数据集
  2. 构建一个卷积神经网络(CNN)
  3. 训练模型
  4. 评估模型性能
  5. 进行测试预测

2. 官方文档链接

  • PyTorch 官方文档
  • MNIST 数据集链接

3. 安装 PyTorch 和依赖库

首先,确保您已经安装了 PyTorch 和相关依赖库。如果没有安装,可以运行以下命令:

pip install torch torchvision matplotlib

4. 加载 MNIST 数据集

我们将使用 torchvision 提供的 MNIST 数据集。它包含 60,000 个训练样本和 10,000 个测试样本。

import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt# 数据预处理:将图像转换为张量,并进行标准化
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])# 下载并加载 MNIST 训练集和测试集
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)# 加载数据集
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)# 查看数据集的大小
print(f"训练集大小: {len(train_dataset)}")
print(f"测试集大小: {len(test_dataset)}")# 可视化部分样本
examples = enumerate(train_loader)
batch_idx, (example_data, example_targets) = next(examples)
plt.figure(figsize=(10, 3))
for i in range(6):plt.subplot(1, 6, i + 1)plt.imshow(example_data[i][0], cmap='gray')plt.title(f"Label: {example_targets[i]}")plt.axis('off')
plt.show()

说明

  • transforms.Compose:我们将图像转换为 PyTorch 张量,并将像素值标准化为 [-1, 1] 的范围。
  • DataLoader:用于将数据集加载为批次,并打乱数据顺序以便训练时使用。

5. 构建卷积神经网络(CNN)

我们将构建一个简单的 CNN 模型,用于手写数字识别。该模型将包含两个卷积层和两个全连接层。

import torch.nn as nn
import torch.nn.functional as F# 定义 CNN 模型
class CNN(nn.Module):def __init__(self):super(CNN, self).__init__()# 卷积层1: 输入通道为1(灰度图),输出通道为16,卷积核大小为3x3self.conv1 = nn.Conv2d(1, 16, kernel_size=3)# 卷积层2: 输入通道为16,输出通道为32,卷积核大小为3x3self.conv2 = nn.Conv2d(16, 32, kernel_size=3)# 全连接层1: 输入为32*5*5(展平后的特征图),输出为128self.fc1 = nn.Linear(32 * 5 * 5, 128)# 全连接层2: 输入为128,输出为10(10个类别)self.fc2 = nn.Linear(128, 10)def forward(self, x):# 卷积层 + ReLU + 最大池化层x = F.relu(F.max_pool2d(self.conv1(x), 2))x = F.relu(F.max_pool2d(self.conv2(x), 2))# 展平成一维向量x = x.view(-1, 32 * 5 * 5)# 全连接层 + ReLUx = F.relu(self.fc1(x))# 输出层x = self.fc2(x)return x# 实例化模型
model = CNN()
print(model)

说明

  • conv1conv2:卷积层用于提取图像特征。第一个卷积层从 1 个输入通道(灰度图像)转换为 16 个特征图,第二个卷积层将 16 个特征图转换为 32 个特征图。
  • max_pool2d:最大池化层,用于下采样特征图,将特征图尺寸减半。
  • fc1fc2:全连接层,用于将卷积层提取到的特征进行分类。

6. 训练模型

我们将定义损失函数和优化器,然后在训练数据集上训练模型。

import torch.optim as optim# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)# 将模型移动到 GPU(如果可用)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)# 训练模型
epochs = 5
for epoch in range(epochs):running_loss = 0.0for images, labels in train_loader:images, labels = images.to(device), labels.to(device)# 前向传播outputs = model(images)loss = criterion(outputs, labels)# 反向传播和优化optimizer.zero_grad()loss.backward()optimizer.step()running_loss += loss.item()print(f"Epoch [{epoch+1}/{epochs}], Loss: {running_loss / len(train_loader):.4f}")print("训练完成!")

说明

  • CrossEntropyLoss:用于分类任务的损失函数,适用于多分类问题。
  • optimizer:使用 Adam 优化器,能够自动调整学习率并加快收敛速度。
  • 训练过程包括前向传播、损失计算、反向传播和参数更新。

7. 评估模型性能

在训练完成后,我们将使用测试数据集来评估模型的性能,计算模型在测试集上的准确率。

# 测试模型
model.eval()  # 切换到评估模式
correct = 0
total = 0with torch.no_grad():  # 关闭梯度计算for images, labels in test_loader:images, labels = images.to(device), labels.to(device)outputs = model(images)_, predicted = torch.max(outputs, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print(f'测试集上的准确率: {100 * correct / total:.2f}%')

说明

  • model.eval():在评估模型时关闭 dropout 和 batch normalization。
  • torch.no_grad():关闭梯度计算以提高测试阶段的效率。

8. 进行预测

最后,我们可以使用训练好的模型对手写数字图像进行预测。

# 从测试集中取出一个样本
example_data, example_target = next(iter(test_loader))
example_data = example_data.to(device)# 使用模型进行预测
model.eval()
with torch.no_grad():output = model(example_data)# 可视化预测结果
plt.figure(figsize=(10, 3))
for i in range(6):plt.subplot(1, 6, i + 1)plt.imshow(example_data[i][0].cpu(), cmap='gray')plt.title(f"预测: {torch.argmax(output[i]).item()}")plt.axis('off')
plt.show()

说明

  • 取出测试集中的一批样本进行预测,并可视化模型的预测结果。

9. 总结

在本教程中,我们使用 PyTorch 实现了手写数字识别任务,构建了一个简单的卷积神经网络(CNN),并在 MNIST 数据集上进行了训练和评估。通过此项目,您可以了解如何加载数据、构建模型、训练、评估和测试 PyTorch 模型。

10. 改进方向

  • 增加网络深度:可以增加卷积层和全连接层的

数量,提高模型的表现。

  • 使用数据增强:通过数据增强技术(旋转、缩放等),可以提高模型的泛化能力。
  • 应用在其他数据集:除了 MNIST,还可以将模型应用到其他数据集,如 FashionMNIST、CIFAR-10 等。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/54391.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

java之斗地主部分功能的实现

今天我们要实现斗地主中发牌和洗牌这两个功能,该如何去实现呢? 1.创建牌类:52张牌每一张牌包含两个属性:牌的大小和牌的花色。 故我们优先创建一个牌的类(Card):包含大小和花色。 public class Card { //单张牌的大小及类型/…

无人机+自组网:中继通信增强技术详解

无人机与自组网技术的结合,特别是通过中继通信增强技术,为无人机在复杂环境中的通信提供了稳定、高效、可靠的解决方案。以下是对该技术的详细解析: 一、无人机自组网技术概述 无人机自组网技术是一种利用无人机作为节点,通过无…

proteus仿真学习(1)

一,创建工程 一般选择默认模式,不配置pcb文件 可以选用芯片型号也可以不选 不选则从零开始布局,没有初始最小系统。选用则有初始最小系统以及基础的main函数 本次学习使用从零开始,不配置固件 二,上手软件 1.在元件…

6--SpringBootWeb案例(详解)

目录 环境搭建 部门管理 查询部门 接口文档 代码 删除部门 接口文档 代码 新增部门 接口文档 代码 已有前端,根据接口文档完成后端功能的开发 成品如下: 环境搭建 1. 准备数据库表 (dept 、 emp) -- 部门管理 create table dept( id int un…

如何有效检测住宅IP真伪?

在当今的互联网时代,住宅IP(即家庭用户通过宽带服务提供商获得的IP地址)在跨境电商、广告投放、网络安全等多个领域扮演着重要角色。然而,随着网络环境的复杂化和欺诈行为的增多,如何有效检测和辨别住宅IP的真伪成为了…

Spring:统一结果私有属性造成的前端无法访问异常报错问题

用户未填写任何评价 1.问题复现 (1)看一段代码 controller: import lombok.extern.slf4j.Slf4j; import org.ljy.testdemo.common.Result; import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.w…

Centos中关闭swap分区,关闭内存交换

概述: Swap 分区是 Linux 系统中扩展物理内存的一种机制。Swap的主要功能是当全部的RAM被占用并需要更多内存时,用磁盘空间代理RAM内存。Swap对虚拟化技术资源损耗非常大,一般虚拟化是不允许开启交换空间的,如果不关闭Swap&…

【Linux课程学习】make/Makefile:Linux项目自动化构建工具

🎁个人主页:我们的五年 🔍系列专栏:Linux课程学习 🌷追光的人,终会万丈光芒 🎉欢迎大家点赞👍评论📝收藏⭐文章 🍉一.make/Makefile的理解: …

关于STM32项目面试题02:ADC与DAC篇(输入部分NTC、AV:0-5V、AI:4-20mA和DAC的两个引脚)

博客的风格是:答案一定不能在问题的后面,要自己想、自己背;回答都是最精简、最精简、最精简,可能就几个字,你要自己自信的展开。 面试官01:什么是模数转换/ADC?说说模数转换的流程? …

基于SpringBoot+Vue+MySQL的手机销售管理系统

系统展示 用户前台界面 管理员后台界面 商家后台界面 系统背景 随着智能手机的普及和市场竞争的日益激烈,手机销售行业面临着前所未有的挑战与机遇。传统的手工记录和简单的电子表格管理方式已难以满足现代手机销售业务的需求,销售数据的混乱和管理效率低…

技术成神之路:设计模式(十四)享元模式

介绍 享元模式(Flyweight Pattern)是一种结构性设计模式,旨在通过共享对象来有效地支持大量细粒度的对象。 1.定义 享元模式通过将对象状态分为内部状态(可以共享)和外部状态(不可共享)&#xf…

C语言-文件操作-一些我想到的、见到的奇怪的问题

博客主页:【夜泉_ly】 本文专栏:【C语言】 欢迎点赞👍收藏⭐关注❤️ C语言-文件操作-一些我想到的、见到的奇怪的问题 前言1.在不关闭文件的情况下,连续多次调用 fopen() 打开同一个文件,会发生什么?1.1过…

Cursor火出圈,未来程序员还有出路吗?

大家好,我是凡人。 今天我表弟家邻居的阿姨,托他问问我目前程序员还有前景吗,希望我根据十几年的经验给出点建议,看看程序员这条路未来能不能走。 一下子不知道该怎么回复他了,如果是三年前问我,肯定毫不…

【React】React18.2.0核心源码解读

前言 本文使用 React18.2.0 的源码,如果想回退到某一版本执行git checkout tags/v18.2.0即可。如果打开源码发现js文件报ts类型错误请看本人另一篇文章:VsCode查看React源码全是类型报错如何解决。 阅读源码的过程: 下载源码 观察 package…

解决【WVP服务+ZLMediaKit媒体服务】加入海康摄像头后,能发现设备,播放/点播失败,提示推流超时!

环境介绍 每人搭建的环境不一样,情况不一样,但是原因都是下面几种: wvp配置不当网络端口未放开网络不通 我搭建的环境: WVP服务:windows下,用idea运行的源码 ZLM服务:虚拟机里 问题描述 1.…

【人工智能学习笔记】5 计算机视觉基础

计算机视觉概述 定义:计算机视觉(Computer Vision)是一门研究如何使机器“看”的科学,也可以看作是研究如何使人工系统从图像活多维数据中“感知”的科学终极目标:计算机视觉成为机器认知世界的基础,终极目…

superset 解决在 mac 电脑上发送 slack 通知的问题

参考文档: https://superset.apache.org/docs/configuration/alerts-reports/ 核心配置: FROM apache/superset:3.1.0USER rootRUN apt-get update && \apt-get install --no-install-recommends -y firefox-esrENV GECKODRIVER_VERSION0.29.0 RUN wget -q https://g…

【高级篇】ENC编码器如何挂载Windows共享目录进行录像

【高级篇】ENC编码器如何挂载Windows共享目录进行录像 Windows共享目录前提条件1、打开控制面板,点击 程序 菜单2、点击 启用或关闭Windows功能 菜单3、如下图,勾选SMB1.0/CIFS文件共享支持,并点击确认按钮,然后根据提示重启电脑 创建共享目录…

如何利用Samba跨平台分享Ubuntu文件夹

1.安装Samba 终端输入sudo apt install samba 2.配置Samba 终端输入sudo vim /etc/samba/smb.conf 打开配置文件 滑动文件到最底下 输入以下内容 [Share] # 要共享的文件夹路径 path /home/xxx/sambashare read only no browsable yes编辑完成后按一下Esc按键后输入:wq回…

ABAP-Swagger 一种公开 ABAP REST 服务的方法

ABAP-Swagger An approach to expose ABAP REST services 一种公开 ABAP REST 服务的方法 Usage 1: develop a class in ABAP with public methods 2: implement interface ZIF_SWAG_HANDLER, and register the public methods(example method zif_swag_handler~meta) 3: …