手写数字识别案例分析(torch,深度学习入门)

在人工智能和机器学习的广阔领域中,手写数字识别是一个经典的入门级问题,它不仅能够帮助我们理解深度学习的基本原理,还能作为实践编程和模型训练的良好起点。本文将带您踏上手写数字识别的深度学习之旅,从数据集介绍、模型构建到训练与评估,一步步深入探索。

一、引言

手写数字识别(Handwritten Digit Recognition)是指通过计算机程序自动识别手写数字的过程。最著名的手写数字数据集之一是MNIST(Modified National Institute of Standards and Technology database),它包含了大量的手写数字图片,每张图片都被标记了对应的数字(0-9)。这个数据集成为了初学者学习深度学习,尤其是卷积神经网络(CNN)的首选。

二、MNIST数据集简介

MNIST数据集由60,000个训练样本和10,000个测试样本组成,每个样本都是一张28x28像素的灰度图像,代表了一个手写数字。这些图像已经被归一化并居中在图像中心,使得数字不会受到位置变化的影响。

 PyTorch 和 torchvision 库来下载并准备 MNIST 数据集,包括训练集和测试集

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor'''下载训练数据集(图片+标签)'''
training_data = datasets.MNIST(root="data",train=True,download=True,transform=ToTensor()
)
test_data = datasets.MNIST(root="data",train=False,download=True,transform=ToTensor()
)
  1. 打印设备信息:您的代码已经很好地检查了CUDA和MPS(针对Apple M系列芯片)的可用性,并设置了相应的设备。但是,在打印设备信息时,有一个小错误在字符串格式化上。您需要确保在字符串中正确地包含变量名。

  2. 打印数据形状:您已经正确地设置了DataLoader并打印了测试数据集中的一个批次的数据和标签的形状。这是一个很好的实践,可以帮助您了解数据的维度。

train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)  # 通常训练时会打乱数据  
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=False)  # 测试时不需要打乱数据  # 打印测试数据集的一个批次的数据和标签的形状  
for x, y in test_dataloader:  print(f"Shape of x [N,C,H,W]: {x.shape}")  # 注意这里的x是图像,但MNIST是灰度图,所以C=1  print(f"Shape of y: {y.shape}, {y.dtype}")  # y是标签,通常是一维的,且为long类型  break  # 判断当前设备是否支持GPU,其中mps是苹果m系列芯片的GPU  
device = "cuda" if torch.cuda.is_available() else ('mps' if torch.backends.mps.is_available() else "cpu")  
print(f"Using {device} device")  # 确保在字符串中正确地包含了变量名  

三、训练模型选择

一、创建一个具有多个隐藏层的神经网络,这些层都使用了nn.Linear来定义全连接层,并使用torch.sigmoid作为激活函数。

import torch  
import torch.nn as nn  class NeuralNetwork(nn.Module):  def __init__(self):  super().__init__()  self.flatten = nn.Flatten()  self.hidden1 = nn.Linear(28 * 28, 256)  self.relu1 = nn.ReLU()  self.hidden2 = nn.Linear(256, 128)  self.relu2 = nn.ReLU()  self.hidden3 = nn.Linear(128, 64)  self.relu3 = nn.ReLU()  self.hidden4 = nn.Linear(64, 32)  self.relu4 = nn.ReLU()  self.out = nn.Linear(32, 10)  # 输出层对应于10个类别的得分  def forward(self, x):x = self.flatten(x)x = self.hidden1(x)x = torch.sigmoid(x)x = self.hidden2(x)x = torch.sigmoid(x)x = self.hidden3(x)x = torch.sigmoid(x)x = self.hidden4(x)x = torch.sigmoid(x)x = self.out(x)return x model = NeuralNetwork().to(device)  
print(model)  

二、定义了一个具有三个卷积层的CNN,每个卷积层后面都跟着ReLU激活函数,前两个卷积层后面还跟着最大池化层。最后,通过一个全连接层将卷积层的输出转换为10个类别的得分。

import torch  
import torch.nn as nn  class CNN(nn.Module):  def __init__(self):  super(CNN, self).__init__()  self.conv1 = nn.Sequential(  nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, stride=1, padding=2),  nn.ReLU(),  nn.MaxPool2d(kernel_size=2),  )  self.conv2 = nn.Sequential(  nn.Conv2d(16, 32, 5, 1, 2),  nn.ReLU(),  nn.Conv2d(32, 32, 5, 1, 2),  nn.ReLU(),  nn.MaxPool2d(2),  )  self.conv3 = nn.Sequential(  nn.Conv2d(32, 64, 5, 1, 2),  nn.ReLU(),  )  self.out = nn.Linear(64 * 7 * 7, 10)  # 确保这里的输入特征数与卷积层输出后的特征数相匹配  def forward(self, x):  x = self.conv1(x)  x = self.conv2(x)  x = self.conv3(x)  # 输出应为(batch_size, 64, 7, 7)  x = x.view(x.size(0), -1)  # 展平操作,输出为(batch_size, 64*7*7)  output = self.out(x)  return output  model = CNN().to(device)  
print(model)
  • in_channels=1:这指定了输入图像的通道数。

  • out_channels=16:这指定了卷积操作后输出的通道数,也就是卷积核(或称为滤波器)的数量。

  • kernel_size=5:这定义了卷积核的大小。

  • stride=1:这指定了卷积核在输入数据上滑动的步长。

  • padding=2:这定义了要在输入数据周围添加的零填充(zero-padding)的数量。

四、处理数据集和测试集

训练集处理:

def train(dataloader, model, loss_fn, optimizer):  model.train()  # 将模型设置为训练模式  batch_size_num = 1  # 这不是标准的用法,但在这里用作计数已处理批次的数量  for x, y in dataloader:  # 遍历数据加载器中的每个批次  x, y = x.to(device), y.to(device)  # 将数据和标签移动到指定的设备(如GPU)  pred = model(x)  # 通过模型进行前向传播  loss = loss_fn(pred, y)  # 计算预测和真实标签之间的损失  optimizer.zero_grad()  # 清除之前的梯度  loss.backward()  # 反向传播,计算当前梯度  optimizer.step()  # 更新模型的权重  loss_value = loss.item()if batch_size_num % 200 == 0:print(f"{loss_value:>7f}[number:{batch_size_num}]")#打印结果batch_size_num += 1  # 增加已处理批次的数量

测试集处理:

def test(dataloader, model, loss_fn):size = len(dataloader.dataset)num_batches = len(dataloader)model.eval()test_loss, correct = 0, 0with torch.no_grad():for x, y in dataloader:x, y = x.to(device), y.to(device)pred = model(x)test_loss += loss_fn(pred, y).item()correct += (pred.argmax(1) == y).type(torch.float).sum().item()a = (pred.argmax(1) == y)b = (pred.argmax(1) == y).type(torch.float)test_loss /= num_batchescorrect /= sizeprint(f'Test result: \n Accuracy: {(100 * correct)}%, Avg loss: {test_loss}')

模型训练:

loss_fn = nn.CrossEntropyLoss()optimizer = torch.optim.Adam(model.parameters(), lr=0.001)epochs = 10
for t in range(epochs):print(f"-----------------------------------------------\nepcho{t+1}")train(train_dataloader, model, loss_fn, optimizer)
print("Done!")
test(test_dataloader, model, loss_fn)train(train_dataloader,model,loss_fn,optimizer)
test(test_dataloader,model, loss_fn)

结果:

神经网络:

cnn:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/54723.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

vue Echart使用

一、在vue中使用Echarts 1.安装Echarts npm install echarts --save2.准备一个呈现图表的盒子 给盒子起名字是建议使用id选择器 这个盒子通常来说就是我们熟悉的 div &#xff0c;这个 div 决定了图表显示在哪里&#xff0c;盒子一定要指定宽和高 <div id"main&quo…

性能测试问题诊断-接口耗时高

问题现象&#xff1a;近期发现每晚跑的定时压测任务&#xff0c;压测结果中&#xff0c;各接口耗时变高&#xff08;原来99耗时<3秒&#xff0c;当前99耗时>20秒)。 问题排查&#xff1a; 查看jmeter 生成的结果图&#xff0c;发现压测2分钟后出现接口耗时变高情况。如下…

【ShuQiHere】 探索数据挖掘的世界:从概念到应用

&#x1f310; 【ShuQiHere】 数据挖掘&#xff08;Data Mining, DM&#xff09; 是一种从大型数据集中提取有用信息的技术&#xff0c;无论是在商业分析、金融预测&#xff0c;还是医学研究中&#xff0c;数据挖掘都扮演着至关重要的角色。本文将带您深入了解数据挖掘的核心概…

如何快速上手一个Github的开源项目

程序研发领域正是有一些热衷开源的小伙伴&#xff0c;技能迭代才能如此的迅速&#xff0c;因此&#xff0c;快速上手一个GitHub上的开源项目&#xff0c;基本上已经变成很个程序员小伙伴必须掌握的技能&#xff0c;因为终究你会应用到其中的一个或多个项目&#xff0c;帮助自己…

【计算机网络篇】电路交换,报文交换,分组交换

本文主要介绍计算机网络中的电路交换&#xff0c;报文交换&#xff0c;分组交换&#xff0c;文中的内容是我认为的重点内容&#xff0c;并非所有。参考的教材是谢希仁老师编著的《计算机网络》第8版。跟学视频课为河南科技大学郑瑞娟老师所讲计网。 目录 &#x1f3af;一.划分…

【Windows 同时安装 MySQL5 和 MySQL8 - 详细图文教程】

卸载 MySQL 参考文章&#xff1a; 完美解决Mysql彻底删除并重装_怎么找到mysql并卸载-CSDN博客使用命令卸载mysql_卸载mysql服务命令-CSDN博客 先管理员方式打开 cmd &#xff0c;切换到 MySQL 安装目录的 bin 文件夹下&#xff0c;执行如下命令&#xff0c;删除 MySQL 服务mys…

Blender软件三大渲染器Eevee、Cycles、Workbench对比解析

Blender 是一款强大的开源3D制作平台&#xff0c;提供了从建模、雕刻、动画到渲染、后期制作的一整套工具&#xff0c;广泛应用于电影、游戏、建筑、艺术等领域。 渲染101云渲染云渲6666 相比于其他平台&#xff0c;如 Autodesk Maya、3ds Max 或 Cinema 4D&#xff0c;Blende…

neo4j关系的创建删除 图的删除

关系的创建和删除 关系创建 CREATE (:Person {name:"jack"})-[:LOVE]->(:Person {name:"Rose"})已有这个关系时&#xff0c;merge不起效果 MERGE (:Person {name:"Jack" })-[:LOVE]->(:Person {name:"Rose"})关系兼顾节点和关…

C++ 进阶之路:非类型模板参数、模板特化与分离编译详解

目录 非类型模版参数 类型模板参数 非类型模板参数 非类型模板参数的使用 模板的特化 函数模板的特化 类模板的特化 全特化与偏特化 偏特化的其它情况 模板的分离编译 什么是分离编译 为什么要分离编译 为什么模板不能分离编译 普通的类和函数都是可以分离编译的…

【学习笔记】数据结构(六 ①)

树和二叉树 &#xff08;一&#xff09; 文章目录 树和二叉树 &#xff08;一&#xff09;6.1 树(Tree)的定义和基本术语6.2 二叉树6.2.1 二叉树的定义1、斜树2、满二叉树3、完全二叉树4、二叉排序树5、平衡二叉树&#xff08;AVL树&#xff09;6、红黑树 6.2.2 二叉树的性质6.…

通用大模型 vs 垂直大模型:谁将赢得AI战场?

引言 在人工智能领域&#xff0c;大模型的快速发展引发了广泛的关注和讨论。大模型&#xff0c;尤其是基于深度学习和海量数据训练的模型&#xff0c;已经在多个领域展现出强大的能力。从自然语言处理、图像识别到自动驾驶和医疗诊断&#xff0c;AI大模型正深刻改变着我们的生…

基于二自由度汽车模型的汽车质心侧偏角估计

一、质心侧偏角介绍 在车辆坐标系中&#xff0c;质心侧偏角通常定义为质心速度方向与车辆前进方向的夹角。如下图所示&#xff0c;u为车辆前进方向&#xff0c;v为质心速度方向&#xff0c;u和v之间的夹角便是质心侧偏角。 质心侧偏角的作用有如下三点&#xff1a; 1、稳定性…

SVN笔记-SVN安装

SVN笔记-SVN安装 1、在windows下安装 SVN 1、准备svn的安装文件 下载地址&#xff1a;https://sourceforge.net/projects/win32svn/ 2、下载完成后&#xff0c;在相应的盘符中会有一个Setup-Subversion-1.8.17.msi的文件&#xff0c;目前最新的版本是1.8.17&#xff0c; 这里…

opencv4.5.5 GPU版本编译

一、安装环境 1、opencv4.5.5 下载地址&#xff1a;https://github.com/opencv/opencv/archive/refs/tags/4.5.5.ziphttps://gitee.com/mirrors/opencv/tree/4.5.0 2、opencv-contrib4.5.5 下载地址&#xff1a;https://github.com/opencv/opencv_contrib/archive/refs/tags/4…

Python中的数据可视化:从基础图表到高级可视化

数据可视化是数据分析和科学计算中不可或缺的一部分。它通过图形化的方式呈现数据&#xff0c;使复杂的统计信息变得直观易懂。Python提供了多种强大的库来支持数据可视化&#xff0c;如Matplotlib、Seaborn、Plotly等。本文将从基础图表入手&#xff0c;逐步介绍如何使用这些库…

【第十一章:Sentosa_DSML社区版-机器学习之分类】

目录 11.1 逻辑回归分类 11.2 决策树分类 11.3 梯度提升决策树分类 11.4 XGBoost分类 11.5 随机森林分类 11.6 朴素贝叶斯分类 11.7 支持向量机分类 11.8 多层感知机分类 11.9 LightGBM分类 11.10 因子分解机分类 11.11 AdaBoost分类 11.12 KNN分类 【第十一章&…

Java语言程序设计基础篇_编程练习题***18.33 (游戏:骑士旅途的动画)

目录 ***18.33 (游戏:骑士旅途的动画) 习题思路 代码示例 动画演示 ***18.33 (游戏:骑士旅途的动画) 为骑士旅途的问题编写一个程序&#xff0c;该程序应该允许用户将骑士放到任何一个起始正方形&#xff0c;并单击Solve按钮&#xff0c;用动画展示骑士沿着路径的移动&…

深度学习之表示学习 - 贪心逐层无监督预训练篇

引言 在人工智能的浩瀚星空中&#xff0c;深度学习以其强大的数据处理与模式识别能力&#xff0c;成为了一颗璀璨的明星。而表示学习&#xff0c;作为深度学习的核心基石之一&#xff0c;正引领着这一领域不断突破边界。表示学习旨在将原始数据转换为更加抽象、更有意义的特征…

leetcode第二十六题:删去有序数组的重复项

给你一个 非严格递增排列 的数组 nums &#xff0c;请你 原地 删除重复出现的元素&#xff0c;使每个元素 只出现一次 &#xff0c;返回删除后数组的新长度。元素的 相对顺序 应该保持 一致 。然后返回 nums 中唯一元素的个数。 考虑 nums 的唯一元素的数量为 k &#xff0c;你…

Rasa对话模型——做一个语言助手

1、Rasa模型 1.1 模型介绍 Rasa是一个用于构建对话 AI 的开源框架&#xff0c;主要用于开发聊天机器人和语音助手。Rasa 提供了自然语言理解&#xff08;NLU&#xff09;和对话管理&#xff08;DM&#xff09;功能&#xff0c;使开发者能够创建智能、交互式的对话系统。 1.2…