PyTorch 实现手写数字识别

PyTorch 实现手写数字识别

在本教程中,我们将使用 PyTorch 实现经典的手写数字识别任务。我们将使用 MNIST 数据集,这是一个包含手写数字的图像数据集。我们将介绍如何使用 PyTorch 构建、训练和评估一个简单的卷积神经网络(CNN)模型来进行手写数字识别。

1. 项目概述

手写数字识别任务是通过训练模型,让其能够识别手写数字图像并输出正确的数字类别(0-9)。MNIST 数据集包含 28x28 像素的灰度图像,每个图像代表一个手写数字。

我们将使用以下步骤:

  1. 加载 MNIST 数据集
  2. 构建一个卷积神经网络(CNN)
  3. 训练模型
  4. 评估模型性能
  5. 进行测试预测

2. 官方文档链接

  • PyTorch 官方文档
  • MNIST 数据集链接

3. 安装 PyTorch 和依赖库

首先,确保您已经安装了 PyTorch 和相关依赖库。如果没有安装,可以运行以下命令:

pip install torch torchvision matplotlib

4. 加载 MNIST 数据集

我们将使用 torchvision 提供的 MNIST 数据集。它包含 60,000 个训练样本和 10,000 个测试样本。

import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt# 数据预处理:将图像转换为张量,并进行标准化
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])# 下载并加载 MNIST 训练集和测试集
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)# 加载数据集
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)# 查看数据集的大小
print(f"训练集大小: {len(train_dataset)}")
print(f"测试集大小: {len(test_dataset)}")# 可视化部分样本
examples = enumerate(train_loader)
batch_idx, (example_data, example_targets) = next(examples)
plt.figure(figsize=(10, 3))
for i in range(6):plt.subplot(1, 6, i + 1)plt.imshow(example_data[i][0], cmap='gray')plt.title(f"Label: {example_targets[i]}")plt.axis('off')
plt.show()

说明

  • transforms.Compose:我们将图像转换为 PyTorch 张量,并将像素值标准化为 [-1, 1] 的范围。
  • DataLoader:用于将数据集加载为批次,并打乱数据顺序以便训练时使用。

5. 构建卷积神经网络(CNN)

我们将构建一个简单的 CNN 模型,用于手写数字识别。该模型将包含两个卷积层和两个全连接层。

import torch.nn as nn
import torch.nn.functional as F# 定义 CNN 模型
class CNN(nn.Module):def __init__(self):super(CNN, self).__init__()# 卷积层1: 输入通道为1(灰度图),输出通道为16,卷积核大小为3x3self.conv1 = nn.Conv2d(1, 16, kernel_size=3)# 卷积层2: 输入通道为16,输出通道为32,卷积核大小为3x3self.conv2 = nn.Conv2d(16, 32, kernel_size=3)# 全连接层1: 输入为32*5*5(展平后的特征图),输出为128self.fc1 = nn.Linear(32 * 5 * 5, 128)# 全连接层2: 输入为128,输出为10(10个类别)self.fc2 = nn.Linear(128, 10)def forward(self, x):# 卷积层 + ReLU + 最大池化层x = F.relu(F.max_pool2d(self.conv1(x), 2))x = F.relu(F.max_pool2d(self.conv2(x), 2))# 展平成一维向量x = x.view(-1, 32 * 5 * 5)# 全连接层 + ReLUx = F.relu(self.fc1(x))# 输出层x = self.fc2(x)return x# 实例化模型
model = CNN()
print(model)

说明

  • conv1conv2:卷积层用于提取图像特征。第一个卷积层从 1 个输入通道(灰度图像)转换为 16 个特征图,第二个卷积层将 16 个特征图转换为 32 个特征图。
  • max_pool2d:最大池化层,用于下采样特征图,将特征图尺寸减半。
  • fc1fc2:全连接层,用于将卷积层提取到的特征进行分类。

6. 训练模型

我们将定义损失函数和优化器,然后在训练数据集上训练模型。

import torch.optim as optim# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)# 将模型移动到 GPU(如果可用)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)# 训练模型
epochs = 5
for epoch in range(epochs):running_loss = 0.0for images, labels in train_loader:images, labels = images.to(device), labels.to(device)# 前向传播outputs = model(images)loss = criterion(outputs, labels)# 反向传播和优化optimizer.zero_grad()loss.backward()optimizer.step()running_loss += loss.item()print(f"Epoch [{epoch+1}/{epochs}], Loss: {running_loss / len(train_loader):.4f}")print("训练完成!")

说明

  • CrossEntropyLoss:用于分类任务的损失函数,适用于多分类问题。
  • optimizer:使用 Adam 优化器,能够自动调整学习率并加快收敛速度。
  • 训练过程包括前向传播、损失计算、反向传播和参数更新。

7. 评估模型性能

在训练完成后,我们将使用测试数据集来评估模型的性能,计算模型在测试集上的准确率。

# 测试模型
model.eval()  # 切换到评估模式
correct = 0
total = 0with torch.no_grad():  # 关闭梯度计算for images, labels in test_loader:images, labels = images.to(device), labels.to(device)outputs = model(images)_, predicted = torch.max(outputs, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print(f'测试集上的准确率: {100 * correct / total:.2f}%')

说明

  • model.eval():在评估模型时关闭 dropout 和 batch normalization。
  • torch.no_grad():关闭梯度计算以提高测试阶段的效率。

8. 进行预测

最后,我们可以使用训练好的模型对手写数字图像进行预测。

# 从测试集中取出一个样本
example_data, example_target = next(iter(test_loader))
example_data = example_data.to(device)# 使用模型进行预测
model.eval()
with torch.no_grad():output = model(example_data)# 可视化预测结果
plt.figure(figsize=(10, 3))
for i in range(6):plt.subplot(1, 6, i + 1)plt.imshow(example_data[i][0].cpu(), cmap='gray')plt.title(f"预测: {torch.argmax(output[i]).item()}")plt.axis('off')
plt.show()

说明

  • 取出测试集中的一批样本进行预测,并可视化模型的预测结果。

9. 总结

在本教程中,我们使用 PyTorch 实现了手写数字识别任务,构建了一个简单的卷积神经网络(CNN),并在 MNIST 数据集上进行了训练和评估。通过此项目,您可以了解如何加载数据、构建模型、训练、评估和测试 PyTorch 模型。

10. 改进方向

  • 增加网络深度:可以增加卷积层和全连接层的

数量,提高模型的表现。

  • 使用数据增强:通过数据增强技术(旋转、缩放等),可以提高模型的泛化能力。
  • 应用在其他数据集:除了 MNIST,还可以将模型应用到其他数据集,如 FashionMNIST、CIFAR-10 等。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/54391.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【linux】kill命令

kill 命令在 Linux 和类 Unix 系统中用于向进程发送信号,默认情况下是发送 SIGTERM(信号 15),请求程序终止运行。如果程序没有响应 SIGTERM 信号,可以使用 SIGKILL(信号 9)强制终止进程&#xf…

java之斗地主部分功能的实现

今天我们要实现斗地主中发牌和洗牌这两个功能,该如何去实现呢? 1.创建牌类:52张牌每一张牌包含两个属性:牌的大小和牌的花色。 故我们优先创建一个牌的类(Card):包含大小和花色。 public class Card { //单张牌的大小及类型/…

无人机+自组网:中继通信增强技术详解

无人机与自组网技术的结合,特别是通过中继通信增强技术,为无人机在复杂环境中的通信提供了稳定、高效、可靠的解决方案。以下是对该技术的详细解析: 一、无人机自组网技术概述 无人机自组网技术是一种利用无人机作为节点,通过无…

proteus仿真学习(1)

一,创建工程 一般选择默认模式,不配置pcb文件 可以选用芯片型号也可以不选 不选则从零开始布局,没有初始最小系统。选用则有初始最小系统以及基础的main函数 本次学习使用从零开始,不配置固件 二,上手软件 1.在元件…

6--SpringBootWeb案例(详解)

目录 环境搭建 部门管理 查询部门 接口文档 代码 删除部门 接口文档 代码 新增部门 接口文档 代码 已有前端,根据接口文档完成后端功能的开发 成品如下: 环境搭建 1. 准备数据库表 (dept 、 emp) -- 部门管理 create table dept( id int un…

深度学习自编码器 - 正则自编码器篇

序言 深度学习领域中,自编码器( Autoencoder \text{Autoencoder} Autoencoder)作为一种无监督学习技术,凭借其独特的结构在数据降维、特征提取、异常检测及数据去噪等方面展现出强大的能力。正则自编码器,作为自编码器…

ES5 在 Web 上的现状

最后一个支持 ES5 的浏览器 IE 11 在 2022 年被微软停止支持,那么今天 Web 上的 ES5 现状如何?在构建生产代码时,Web 开发者的最佳实践是什么? 本文将通过数据来回答这些问题,并基于这些数据为网站开发者和库作者提供一…

Delta Lake如何使用

1. 安装 Java 确保你的系统上安装了 Java 8 或更高版本。可以通过以下命令检查 Java 是否已安装: java -version2. 安装 Apache Spark 下载 Spark: 从 Apache Spark 官方网站 下载适合的版本,建议下载预编译的版本(例如&#xf…

如何有效检测住宅IP真伪?

在当今的互联网时代,住宅IP(即家庭用户通过宽带服务提供商获得的IP地址)在跨境电商、广告投放、网络安全等多个领域扮演着重要角色。然而,随着网络环境的复杂化和欺诈行为的增多,如何有效检测和辨别住宅IP的真伪成为了…

Spring:统一结果私有属性造成的前端无法访问异常报错问题

用户未填写任何评价 1.问题复现 (1)看一段代码 controller: import lombok.extern.slf4j.Slf4j; import org.ljy.testdemo.common.Result; import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.w…

深入解析 SQLSugar:从基础 CRUD 到读写分离与高级特性详解

SQLSugar 使用指南:从入门到进阶及高级特性详解 SQLSugar 是一款功能丰富的 .NET ORM 框架,它支持多种数据库、简洁的 API 和优雅的编程体验。相较于其他 ORM,SQLSugar 提供了很多开发者友好的功能,比如自动创建表结构、灵活的查…

在HTML中添加图片

在HTML中添加图片&#xff0c;你需要使用<img>标签。这个标签用于在网页上嵌入图像。<img>是一个空元素&#xff0c;它只包含属性&#xff0c;并且没有闭合标签。要在<img>标签中指定要显示的图像&#xff0c;你需要使用src&#xff08;source的缩写&#xf…

Centos中关闭swap分区,关闭内存交换

概述&#xff1a; Swap 分区是 Linux 系统中扩展物理内存的一种机制。Swap的主要功能是当全部的RAM被占用并需要更多内存时&#xff0c;用磁盘空间代理RAM内存。Swap对虚拟化技术资源损耗非常大&#xff0c;一般虚拟化是不允许开启交换空间的&#xff0c;如果不关闭Swap&…

【Linux课程学习】make/Makefile:Linux项目自动化构建工具

&#x1f381;个人主页&#xff1a;我们的五年 &#x1f50d;系列专栏&#xff1a;Linux课程学习 &#x1f337;追光的人&#xff0c;终会万丈光芒 &#x1f389;欢迎大家点赞&#x1f44d;评论&#x1f4dd;收藏⭐文章 &#x1f349;一.make/Makefile的理解&#xff1a; …

关于STM32项目面试题02:ADC与DAC篇(输入部分NTC、AV:0-5V、AI:4-20mA和DAC的两个引脚)

博客的风格是&#xff1a;答案一定不能在问题的后面&#xff0c;要自己想、自己背&#xff1b;回答都是最精简、最精简、最精简&#xff0c;可能就几个字&#xff0c;你要自己自信的展开。 面试官01&#xff1a;什么是模数转换/ADC&#xff1f;说说模数转换的流程&#xff1f; …

mysql5.7常用操作命令手册

文章目录 前言一、关闭mysql服务1.mha节点,关闭MHA高可用2.主节点&#xff0c;摘掉vip&#xff0c;停掉mysql服务3.从节点&#xff0c;停掉mysql服务 二、启动mysql1.启动数据库顺序2.主节点&#xff0c;登陆数据库检查主库状态,将主库改成读写状态3.从节点启动配置数据库&…

基于SpringBoot+Vue+MySQL的手机销售管理系统

系统展示 用户前台界面 管理员后台界面 商家后台界面 系统背景 随着智能手机的普及和市场竞争的日益激烈&#xff0c;手机销售行业面临着前所未有的挑战与机遇。传统的手工记录和简单的电子表格管理方式已难以满足现代手机销售业务的需求&#xff0c;销售数据的混乱和管理效率低…

Python基础知识——字典排序(不断补充)

目录 专栏导读代码1&#xff1a;value是多个字符拼接(含拼接符号)(升序)代码2&#xff1a;value是单个值(升序)代码3&#xff1a;按值排序(升序)代码4&#xff1a;按值排序(降序)总结 专栏导读 &#x1f338; 欢迎来到Python办公自动化专栏—Python处理办公问题&#xff0c;解放…

技术成神之路:设计模式(十四)享元模式

介绍 享元模式&#xff08;Flyweight Pattern&#xff09;是一种结构性设计模式&#xff0c;旨在通过共享对象来有效地支持大量细粒度的对象。 1.定义 享元模式通过将对象状态分为内部状态&#xff08;可以共享&#xff09;和外部状态&#xff08;不可共享&#xff09;&#xf…

C语言-文件操作-一些我想到的、见到的奇怪的问题

博客主页&#xff1a;【夜泉_ly】 本文专栏&#xff1a;【C语言】 欢迎点赞&#x1f44d;收藏⭐关注❤️ C语言-文件操作-一些我想到的、见到的奇怪的问题 前言1.在不关闭文件的情况下&#xff0c;连续多次调用 fopen() 打开同一个文件&#xff0c;会发生什么&#xff1f;1.1过…