# 手写数字识别:使用PyTorch构建MNIST分类器

手写数字识别:使用PyTorch构建MNIST分类器

在这篇文章中,我将引导你通过使用PyTorch框架构建一个简单的神经网络模型,用于识别MNIST数据集中的手写数字。MNIST数据集是一个经典的机器学习数据集,包含了60,000张训练图像和10,000张测试图像,每张图像都是28x28像素的灰度手写数字。
在这里插入图片描述

在这里插入图片描述

环境准备

首先,确保你的环境中安装了PyTorch和torchvision。可以通过以下命令安装:

pip install torch torchvision

数据加载与预处理

我们首先加载MNIST数据集,并将图像转换为PyTorch张量格式,以便模型可以处理。

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor'''下载训练数据集(包含训练图片+标签)'''
training_data = datasets.MNIST( #跳转到函数的内部源代码,pycharm 按下ctrl+鼠标点击 training_data:Datasetroot="data",#表示下载的手写数字 到哪个路径。60000train=True, #读取下载后的数据 中的 训练集download=True,#如果你之前已经下载过了,就不用再下载transform=ToTensor(), #张量,图片是不能直接传入神经网络模型
)   #对于pytorch库能够识别的数据一般是tensor张量。'''下载测试数据集(包含训练图片+标签)'''
test_data = datasets.MNIST(root="data",train=False,download=True,transform=ToTensor()
)
print(len(training_data))

数据可视化

为了更好地理解数据,我们可以展示一些手写数字图像。

''展示手写字图片,把训练数据集中的前59000张图片展示一下'''from matplotlib import pyplot as plt
figure = plt.figure()
for i in range(9):img, label = training_data[i+59000] #提取第59000张图片figure.add_subplot(3, 3, i+1) #图像窗口中创建多个小窗口,小窗口用于显示图片plt.title(label)plt.axis("off") # plt.show(I)#是示矢量,plt.imshow(img.squeeze(), cmap="gray")a = img.squeeze()
plt.show()

创建DataLoader

为了高效地加载数据,我们使用DataLoader来批量加载数据。

# '"创建数据DataLoader(数据加载器)开'
#  'batch_size:将数据集分成多份,每一份为batch_size个数据'
#  '优点:可以减少内存的使用,提高训练速度。train_dataloader = DataLoader(training_data, batch_size=64) #64张图片为一个包,train_dataloader:<torch
test_dataloader = DataLoader(test_data, batch_size=64)

模型定义

接下来,我们定义一个简单的神经网络模型,包含两个隐藏层和一个输出层。

'''定义神经网络类的继承这种方式'''
class NeuralNetwork(nn.Module):  #通过调用类的形式来使用神经网络,神经网络的模型,nn.moduledef __init__(self): #python基础关于类,self类自已本身super().__init__() #继承的父类初始化self.flatten = nn.Flatten() #展开,创建一个展开对象flattenself.hidden1 = nn.Linear(28*28, 128 ) #第1个参数:有多少个神经元传入进来,第2个参数:有多少个数据传出self.hidden2 = nn.Linear(128, 256)self.out = nn.Linear(256, 10)def forward(self, x):x = self.flatten(x) #图像进行展开x = self.hidden1(x)x = torch.relu(x) #激活函数,torch使用的relu函数 relu,tanhx = self.hidden2(x)x = torch.relu(x)x = self.out(x)return xmodel = NeuralNetwork().to(device) #把刚刚创建的模型传入到Gpu
print(model)

训练与测试

我们定义训练和测试函数,使用交叉熵损失函数和随机梯度下降优化器。

def train(dataloader, model, loss_fn, optimizer):model.train() #告诉模型,我要开始训练,模型中w进行随机化操作,已经更新w。在训练过程中,w会被修改的
# #pytorch提供2种方式来切换训练和测试的模式,分别是:model.train()和 model.eval()。# 一般用法是:在训练开始之前写上model.trian(),在测试时写上 model.eval()batch_size_num = 1for X, y in dataloader: #其中batch为每一个数据的编号X, y = X.to(device), y.to(device) #把训练数据集和标签传入cpu或GPUpred = model.forward(X) #.forward可以被省略,父类中已经对次功能进行了设置。自动初始化wloss= loss_fn(pred, y) #通过交叉熵损失函数计算损失值loss# Backpropagation 进来一个batch的数据,计算一次梯度,更新一次网络optimizer.zero_grad() #梯度值清零loss.backward() #反向传播计算得到每个参数的梯度值woptimizer.step() #根据梯度更新网络w参数loss_value = loss.item() #从tensor数据中提取数据出来,tensor获取损失值if batch_size_num % 100 ==0:print(f"loss: {loss_value:>7f} [number:{batch_size_num}]")batch_size_num += 1def test(dataloader, model, loss_fn):size = len(dataloader.dataset) #10000num_batches = len(dataloader) #打包的数量model.eval() #测试,w就不能再更新。test_loss, correct = 0, 0with torch.no_grad(): #一个上下文管理器,关闭梯度计算。当你确认不会调用Tensor.backward()的时候。这for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model.forward(X)test_loss += loss_fn(pred, y).item()  #test_loss是会自动累加每一个批次的损失值correct += (pred.argmax(1) == y).type(torch.float).sum().item()a = (pred.argmax(1) == y)   #dim=1表示每一行中的最大值对应的索引号,dim=0表示每一列中的最大值b = (pred.argmax(1) == y).type(torch.float)test_loss /= num_batches #能来衡量模型测试的好坏。correct /= size #平均的正确率print(f"Test result: \n Accuracy: {(100*correct)}%, Avg loss: {test_loss}")

训练模型

最后,我们训练模型并测试其性能。

loss_fn = nn.CrossEntropyLoss() #创建交叉熵损失函数对象,因为手写字识别中一共有10个数字,输出会有10个结果optimizer = torch.optim.SGD(model.parameters(), lr=0.01) #创建一个优化器,SGD为随机梯度下降算法
# #params:要训练的参数,一般我们传入的都是model.parameters()# #lr:learning_rate学习率,也就是步长#loss表示模型训练后的输出结果与,样本标签的差距。如果差距越小,就表示模型训练越好,越逼近干真实的模型。# train(train_dataloader, model, loss_fn, optimizer)
# test(test_dataloader, model, loss_fn)epochs = 30
for t in range(epochs):print(f"Epoch {t+1}\n")train(train_dataloader, model, loss_fn, optimizer)
print("Done!")
test(test_dataloader, model, loss_fn)

运行结果

在这里插入图片描述

结论

通过这篇文章,我们成功构建了一个简单的神经网络模型来识别MNIST数据集中的手写数字。这个模型展示了如何使用PyTorch进行数据处理、模型定义、训练和测试。希望这能帮助你开始自己的深度学习项目!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/902313.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

强化学习笔记(三)——表格型方法(蒙特卡洛、时序差分)

强化学习笔记&#xff08;三&#xff09;——表格型方法&#xff08;蒙特卡洛、时序差分&#xff09; 一、马尔可夫决策过程二、Q表格三、免模型预测1. 蒙特卡洛策略评估1) 动态规划方法和蒙特卡洛方法的差异 2. 时序差分2.1 时序差分误差2.2 时序差分方法的推广 3. 自举与采样…

c++_csp-j算法 (4)

迪克斯特拉() 介绍 迪克斯特拉算法(Dijkstra算法)是一种用于解决单源最短路径问题的经典算法,由荷兰计算机科学家艾兹赫尔迪克斯特拉(Edsger W. Dijkstra)于1956年提出。迪克斯特拉算法的基本思想是通过逐步扩展已经找到的最短路径集合,逐步更新节点到源节点的最短路…

(13)VTK C++开发示例 --- 透视变换

文章目录 1. 概述2. CMake链接VTK3. main.cpp文件4. 演示效果 更多精彩内容&#x1f449;内容导航 &#x1f448;&#x1f449;VTK开发 &#x1f448; 1. 概述 在VTK&#xff08;Visualization Toolkit&#xff09;中&#xff0c;vtkPerspectiveTransform 和 vtkTransform 都是…

深入探索Qt异步编程--从信号槽到Future

概述 在现代软件开发中,应用程序的响应速度和用户体验是至关重要的。尤其是在图形用户界面(GUI)应用中,长时间运行的任务如果直接在主线程执行会导致界面冻结,严重影响用户体验。 Qt提供了一系列工具和技术来帮助开发者实现异步编程,从而避免这些问题。本文将深入探讨Qt…

基于Python的图片/签名转CAD小工具开发方案

基于Python的图片/签名转CAD工具开发方案 一、项目背景 传统设计流程中&#xff0c;设计师常常需要将手写签名或扫描图纸转换为CAD格式。本文介绍如何利用Python快速开发图像矢量化工具&#xff0c;实现&#xff1a; &#x1f4f7; 图像自动预处理✏️ 轮廓精确提取⚙️ 参数…

【仓颉 + 鸿蒙 + AI Agent】CangjieMagic框架(17):PlanReactExecutor

CangjieMagic框架&#xff1a;使用华为仓颉编程语言编写&#xff0c;专门用于开发AI Agent&#xff0c;支持鸿蒙、Windows、macOS、Linux等系统。 这篇文章剖析一下 CangjieMagic 框架中的 PlanReactExecutor。 1 PlanReactExecutor的工作原理 #mermaid-svg-OqJUCSoxZkzylbDY…

一文了解相位阵列天线中的真时延

本文要点 真时延是宽带带相位阵列天线的关键元素之一。 真时延透过在整个信号频谱上应用可变相移来消除波束斜视现象。 在相位阵列中使用时延单元或电路板&#xff0c;以提供波束控制和相移。 市场越来越需要更快、更可靠的通讯网络&#xff0c;而宽带通信系统正在努力满…

Java中 关于编译(Compilation)、类加载(Class Loading) 和 运行(Execution)的详细区别解析

以下是Java中 编译&#xff08;Compilation&#xff09;、类加载&#xff08;Class Loading&#xff09; 和 运行&#xff08;Execution&#xff09; 的详细区别解析&#xff1a; 1. 编译&#xff08;Compilation&#xff09; 定义 将Java源代码&#xff08;.java文件&#x…

【KWDB 创作者计划】_深度学习篇---松科AI加速棒

文章目录 前言一、简介二、安装与配置硬件连接驱动安装软件环境配置三、使用步骤初始化设备调用SDK接口检测设备状态:集成到AI项目四、注意事项兼容性散热固件更新安全移除五、硬件架构与技术规格核心芯片专用AI处理器内存配置接口类型物理接口虚拟接口能效比散热设计六、软件…

如何清理Windows系统中已失效或已删除应用的默认打开方式设置

在使用Windows系统的过程中&#xff0c;我们可能会遇到一些问题&#xff1a;某些已卸载或失效的应用程序仍然出现在默认打开方式的列表中&#xff0c;这不仅显得杂乱&#xff0c;还可能影响我们快速找到正确的程序来打开文件。 如图&#xff0c;显示应用已经被geek强制删除&am…

NFC碰一碰发视频推广工具开发注意事项丨支持OEM搭建

随着线下门店短视频推广需求的爆发&#xff0c;基于NFC技术的“碰一碰发视频”推广工具成为商业热点。集星引擎在开发同类系统时&#xff0c;总结出六大核心开发注意事项&#xff0c;帮助技术团队与品牌方少走弯路&#xff0c;打造真正贴合商户需求的实用型工具&#xff1a; 一…

pgsql中使用jsonb的mybatis-plus和Spring Data JPA的配置

在pgsql中使用jsonb类型的数据时&#xff0c;实体对象要对其进行一些相关的配置&#xff0c;而mybatis和jpa中使用各不相同。 在项目中经常会结合 MyBatis-Plus 和 JPA 进行开发&#xff0c;MyBatis_plus对于操作数据更灵活&#xff0c;jpa可以自动建表&#xff0c;两者各取其…

kotlin + spirngboot3 + spring security6 配置登录与JWT

1. 导包 implementation("com.auth0:java-jwt:3.14.0") implementation("org.springframework.boot:spring-boot-starter-security")配置用户实体类 Entity Table(name "users") data class User(IdGeneratedValue(strategy GenerationType.I…

【JavaWeb后端开发03】MySQL入门

文章目录 1. 前言1.1 引言1.2 相关概念 2. MySQL概述2.1 安装2.2 连接2.2.1 介绍2.2.2 企业使用方式(了解) 2.3 数据模型2.3.1 **关系型数据库&#xff08;RDBMS&#xff09;**2.3.2 数据模型 3. SQL语句3.1 DDL语句3.1.1 数据库操作3.1.1.1 查询数据库3.1.1.2 创建数据库3.1.1…

人工智能在智能家居中的应用与发展

随着人工智能&#xff08;AI&#xff09;技术的飞速发展&#xff0c;智能家居逐渐成为现代生活的重要组成部分。从智能语音助手到智能家电&#xff0c;AI正在改变我们与家居环境的互动方式&#xff0c;让生活更加便捷、舒适和高效。本文将探讨人工智能在智能家居中的应用现状、…

【EasyPan】项目常见问题解答(自用持续更新中…)

EasyPan 网盘项目介绍 一、项目概述 EasyPan 是一个基于 Vue3 SpringBoot 的网盘系统&#xff0c;支持文件存储、在线预览、分享协作及后台管理&#xff0c;技术栈涵盖主流前后端框架及中间件&#xff08;MySQL、Redis、FFmpeg&#xff09;。 二、核心功能模块 用户认证 注册…

4.1腾讯校招简历优化与自我介绍攻略:公式化表达+结构化呈现

腾讯校招简历优化与自我介绍攻略&#xff1a;公式化表达结构化呈现 在腾讯校招中&#xff0c;简历是敲开面试大门的第一块砖&#xff0c;自我介绍则是展现个人魅力的黄金30秒。本文结合腾讯面试官偏好&#xff0c;拆解简历撰写公式、自我介绍黄金结构及分岗位避坑指南&#xf…

【Easylive】consumes = MediaType.MULTIPART_FORM_DATA_VALUE 与 @RequestPart

【Easylive】项目常见问题解答&#xff08;自用&持续更新中…&#xff09; 汇总版 consumes MediaType.MULTIPART_FORM_DATA_VALUE 的作用 1. 定义请求的数据格式 • 作用&#xff1a;告诉 Feign 和 HTTP 客户端&#xff0c;这个接口 接收的是 multipart/form-data 格式的…

OpenSSL1.1.1d windows安装包资源使用

环境&#xff1a; QT版本&#xff1a;5.14.2 用途: openssl1.1.1d版本 问题描述&#xff1a; 今天尝试用百度云人脸识别api搭载QT的人脸识别程序&#xff0c;需要用到 QNetworkManager 访问 https 开头的网址。 但是遇到了QT缺乏 openssl 的相关问题&#xff0c;找了大半天…

代码实战保险花销预测

文章目录 摘要项目地址实战代码&#xff08;初级版&#xff09;实战代码&#xff08;进阶版&#xff09; 摘要 本文介绍了一个完整的机器学习流程项目&#xff0c;重点涵盖了多元线性回归的建模与评估方法。项目详细讲解了特征工程中的多项实用技巧&#xff0c;包括&#xff1…