使用残差网络识别手写数字及MNIST 数据集介绍

MNIST 数据集已经是一个几乎每个初学者都会接触的数据集, 很多实验、很多模型都会以MNIST 数据集作为训练对象, 不过有些人可能对它还不是很了解, 那么今天我们一起来学习一下MNIST 数据集。

1.MNIST 介绍

MNIST 数据集来自美国国家标准与技术研究所, National Institute of Standards and Technology (NIST). 训练集 (training set) 由来自 250 个不同人手写的数字构成, 其中 50% 是高中学生, 50% 来自人口普查局 (the Census Bureau) 的工作人员. 测试集(test set) 也是同样比例的手写数字数据.

MNIST 数据集包含了四个部分:

  • Training set images: train-images-idx3-ubyte.gz (9.9 MB, 解压后 47 MB, 包含 60,000 个样本)
  • Training set labels: train-labels-idx1-ubyte.gz (29 KB, 解压后 60 KB, 包含 60,000 个标签)
  • Test set images: t10k-images-idx3-ubyte.gz (1.6 MB, 解压后 7.8 MB, 包含 10,000 个样本)
  • Test set labels: t10k-labels-idx1-ubyte.gz (5KB, 解压后 10 KB, 包含 10,000 个标签)

在这里插入图片描述

Size: 28×28 灰度手写数字图像
Num: 训练集 60000 和 测试集 10000,一共70000张图片
Classes: 0,1,2,3,4,5,6,7,8,9

在这里插入图片描述

2.数据集读取

2.1官网下载MNIST 数据集

MNIST 数据集可在 http://yann.lecun.com/exdb/mnist/ 获取
注意:不要直接点连接,复制连接粘贴到新的浏览器标签页搜索,就不需要账号密码了!!!

2.2 博主分享

同时:博主已将MNIST公开数据集上传至百度网盘,大家可以直接下载学习:
链接: https://pan.baidu.com/s/1-rurbkWdv_veQD8QcQWcRw 提取码: 0213

2.3 直接下载

如果数据集没有下载,修改参数:download=True,直接去下载数据集:

from torchvision import datasets, transformstrain_data = datasets.MNIST(root="./MNIST",train=True,transform=transforms.ToTensor(),download=True)test_data = datasets.MNIST(root="./MNIST",train=False,transform=transforms.ToTensor(),download=True)print(train_data)
print(test_data)

如果出现这种错误:

SyntaxError: Non-UTF-8 code starting with \xca in fileD:\PycharmProjects\model-fuxian\data set\MNIST t.py on line 2, but noencoding declared; see http://python.org/dev/peps/pep-0263/ fordetails

大概率是你没加:# coding:gbk,为什么呢?由于 Python 默认使用 ASCII 编码来解析源代码,因此如果源文件中包含了非 ASCII 编码的字符(比如中文字符),那么解释器就可能会抛出 SyntaxError 异常。加上# -- coding: gbk --这样的注释语句可以告诉解释器当前源文件的字符编码格式是 GBK,从而避免源文件中文字符被错误地解析。

如果成功运行会出现这种结果,表示已经开始下载了:

在这里插入图片描述
输出结果:

Dataset MNIST
Number of datapoints: 60000
Root location: ./MNIST
Split: Train
StandardTransform
Transform: ToTensor()
Dataset MNIST
Number of datapoints: 10000
Root location: ./MNIST
Split: Test
StandardTransform
Transform: ToTensor()

3.数据集可视化

import torchvision
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plttrain_data = datasets.MNIST(root="model-fuxian/data set/MNIST/MNIST/raw/MNIST",train=True,transform=transforms.ToTensor(),download=False)train_loader = DataLoader(dataset=train_data,batch_size=64,shuffle=True)for num, (image, label) in enumerate(train_loader):image_batch = torchvision.utils.make_grid(image, padding=2)plt.imshow(np.transpose(image_batch.numpy(), (1, 2, 0)), vmin=0, vmax=255)plt.show()print(label)

得到图片:
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

这是标签:

tensor([2, 1, 7, 7, 2, 4, 2, 2, 0, 1, 7, 1, 5, 7, 9, 0, 2, 7, 4, 7, 0, 2, 7, 1,6, 9, 1, 1, 1, 5, 4, 3, 8, 0, 1, 0, 1, 3, 8, 0, 1, 4, 5, 1, 8, 4, 7, 3,8, 3, 2, 2, 0, 0, 4, 0, 2, 9, 7, 1, 8, 3, 2, 3])
tensor([6, 6, 7, 2, 5, 4, 0, 3, 4, 6, 1, 4, 1, 9, 2, 2, 8, 7, 5, 7, 9, 6, 6, 7,1, 9, 9, 5, 5, 6, 9, 6, 8, 5, 5, 7, 8, 9, 8, 3, 1, 0, 1, 4, 6, 1, 8, 6,1, 4, 6, 7, 1, 9, 5, 4, 3, 4, 6, 1, 7, 3, 7, 6])
tensor([7, 1, 5, 1, 4, 0, 9, 2, 2, 0, 1, 5, 2, 3, 6, 4, 6, 9, 3, 3, 2, 8, 1, 5,8, 0, 1, 4, 5, 6, 2, 6, 4, 9, 2, 0, 7, 2, 0, 1, 2, 4, 4, 6, 5, 9, 1, 2,5, 3, 3, 8, 8, 3, 4, 5, 2, 6, 0, 0, 8, 7, 1, 7])

4.使用残差网络RESNET识别手写数字

没有GPU的可以使用CPU,不过速度会大打折扣:DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu"),最好是可以使用GPU,这样速度会快很多: torch.device("cuda")#使用GPU

# coding=gbk
# 1.加载必要的库
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets,transforms
import argparse# 2.超参数
BATCH_SIZE = 32#每批处理的数据 一次性多少个
DEVICE  = torch.device("cuda")#使用GPU
# DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")#使用GPU
EPOCHS =4 #训练数据集的轮次
# 3.图像处理
pipeline = transforms.Compose([transforms.ToTensor(), #将图片转换为Tensor])# 4.下载,加载数据
from torch.utils.data import DataLoader#下载
train_set = datasets.MNIST("data",train=True,download=True,transform=pipeline)
test_set = datasets.MNIST("data",train=False,download=True,transform=pipeline)#加载 一次性加载BATCH_SIZE个打乱顺序的数据
train_loader = DataLoader(train_set,batch_size=BATCH_SIZE,shuffle=True)
test_loader = DataLoader(test_set,batch_size=BATCH_SIZE,shuffle=True)# 5.构建网络模型
class ResBlk(nn.Module):  # 定义Resnet Block模块"""resnet block"""def __init__(self, ch_in, ch_out, stride=1):  # 进入网络前先得知道传入层数和传出层数的设定""":param ch_in::param ch_out:"""super(ResBlk, self).__init__()  # 初始化# we add stride support for resbok, which is distinct from tutorials.# 根据resnet网络结构构建2个(block)块结构 第一层卷积 卷积核大小3*3,步长为1,边缘加1self.conv1 = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=stride, padding=1)# 将第一层卷积处理的信息通过BatchNorm2dself.bn1 = nn.BatchNorm2d(ch_out)# 第二块卷积接收第一块的输出,操作一样self.conv2 = nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1)self.bn2 = nn.BatchNorm2d(ch_out)# 确保输入维度等于输出维度self.extra = nn.Sequential()  # 先建一个空的extraif ch_out != ch_in:# [b, ch_in, h, w] => [b, ch_out, h, w]self.extra = nn.Sequential(nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=stride),nn.BatchNorm2d(ch_out))def forward(self, x):  # 定义局部向前传播函数out = F.relu(self.bn1(self.conv1(x)))  # 对第一块卷积后的数据再经过relu操作out = self.bn2(self.conv2(out))  # 第二块卷积后的数据输出out = self.extra(x) + out  # 将x传入extra经过2块(block)输出后与原始值进行相加out = F.relu(out)  # 调用relureturn outclass ResNet18(nn.Module):  # 构建resnet18层def __init__(self):super(ResNet18, self).__init__()self.conv1 = nn.Sequential(  # 首先定义一个卷积层nn.Conv2d(1, 32, kernel_size=3, stride=3, padding=0),nn.BatchNorm2d(32))# followed 4 blocks 调用4次resnet网络结构,输出都是输入的2倍self.blk1 = ResBlk(32, 64, stride=1)self.blk2 = ResBlk(64, 128, stride=1)self.blk3 = ResBlk(128, 256, stride=1)self.blk4 = ResBlk(256, 256, stride=1)self.outlayer = nn.Linear(256 * 1 * 1, 10)  # 最后是全连接层def forward(self, x):  # 定义整个向前传播x = F.relu(self.conv1(x))  # 先经过第一层卷积x = self.blk1(x)  # 然后通过4次resnet网络结构x = self.blk2(x)x = self.blk3(x)x = self.blk4(x)x = F.adaptive_avg_pool2d(x, [1, 1])# print('after pool:', x.shape)x = x.view(x.size(0), -1)  # 平铺一维值x = self.outlayer(x)  # 全连接层return x
# 6.定义优化器
model = ResNet18().to(DEVICE)#创建模型并将模型加载到指定设备上optimizer = optim.Adam(model.parameters(),lr=0.001)#优化函数criterion = nn.CrossEntropyLoss()
# 7.训练
def train_model(model,device,train_loader,optimizer,epoch):# Training settingsparser = argparse.ArgumentParser(description='PyTorch MNIST Example')parser.add_argument('--batch-size', type=int, default=64, metavar='N',help='input batch size for training (default: 64)')parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',help='input batch size for testing (default: 1000)')parser.add_argument('--epochs', type=int, default=14, metavar='N',help='number of epochs to train (default: 14)')parser.add_argument('--lr', type=float, default=1.0, metavar='LR',help='learning rate (default: 1.0)')parser.add_argument('--gamma', type=float, default=0.7, metavar='M',help='Learning rate step gamma (default: 0.7)')parser.add_argument('--no-cuda', action='store_true', default=False,help='disables CUDA training')parser.add_argument('--dry-run', action='store_true', default=False,help='quickly check a single pass')parser.add_argument('--seed', type=int, default=1, metavar='S',help='random seed (default: 1)')parser.add_argument('--log-interval', type=int, default=10, metavar='N',help='how many batches to wait before logging training status')parser.add_argument('--save-model', action='store_true', default=False,help='For Saving the current Model')args = parser.parse_args()model.train()#模型训练for batch_index,(data ,target) in enumerate(train_loader):data,target = data.to(device),target.to(device)#部署到DEVICE上去optimizer.zero_grad()#梯度初始化为0output = model(data)#训练后的结果loss = criterion(output,target)#多分类计算损失loss.backward()#反向传播 得到参数的梯度值optimizer.step()#参数优化if batch_index % args.log_interval == 0:print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(epoch, batch_index * len(data), len(train_loader.dataset),100. * batch_index / len(train_loader), loss.item()))if args.dry_run:break
# 8.测试
def test_model(model,device,text_loader):model.eval()#模型验证correct = 0.0#正确率global Accuracytext_loss = 0.0with torch.no_grad():#不会计算梯度,也不会进行反向传播for data,target in text_loader:data,target = data.to(device),target.to(device)#部署到device上output = model(data)#处理后的结果text_loss += criterion(output,target).item()#计算测试损失pred = output.argmax(dim=1)#找到概率最大的下标correct += pred.eq(target.view_as(pred)).sum().item()#累计正确的值text_loss /= len(test_loader.dataset)#损失和/加载的数据集的总数Accuracy = 100.0*correct / len(text_loader.dataset)print("Test__Average loss: {:4f},Accuracy: {:.3f}\n".format(text_loss,Accuracy))
# 9.调用for epoch in range(1,EPOCHS+1):train_model(model,DEVICE,train_loader,optimizer,epoch)test_model(model,DEVICE,test_loader)torch.save(model.state_dict(),'model.ckpt')

在这里插入图片描述

精确度

Test__Average loss: 0.000808,Accuracy: 99.150
最后可以发现准确度达到了99%还高,可以看出来残差网络识别手写数字的准确性还是很高的。
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/638773.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

2. SpringBoot3 实战之用户模块接口开发

文章目录 开发模式和环境搭建开发模式环境搭建 1. 用户注册1.1 注册接口基本代码编写1.2 注册接口参数校验 2. 用户登录2.1 登录接口基本代码编写2.2 登录认证2.2.1 登录认证引入2.2.2 JWT 简介2.2.3 登录功能集成 JWT2.2.4 拦截器 3. 获取用户详细信息3.1 获取用户详细信息基本…

openEuler安装KVM

1、关闭防火墙和selinux [rootlocalhost ~]# systemctl stop firewalld[rootlocalhost ~]# setenforce 0 2、下载软件包 libvirt:用于管理虚拟化平台的开源的 API,后台程序和管理工具。 qemu:开源(模拟)软件&#…

【MySQL】InnoDB 什么情况下会产生死锁

🍎个人博客:个人主页 🏆个人专栏:数据库 ⛳️ 功不唐捐,玉汝于成 目录 前言 正文 结语 我的其他博客 前言 在数据库管理系统中,特别是使用 InnoDB 存储引擎的 MySQL 中,死锁是一个可能影响…

Win10下在Qt项目中配置SQlite3环境

资源下载 官网资源:SQLite Download Page 1、sqlite.h sqlite-amalgamation-3450000.zip (2.60 MiB) 2、sqlite3.def,sqlite3.dll sqlite-dll-win-x64-3450000.zip (1.25 MiB) 3、 win10下安装sqlite3所需要文件 sqlite-tools-win-x64-3450000.zipht…

万界星空科技MES系统的生产管理流程

对于生产型工厂来说,车间生产流程无疑是最重要的管理环节,繁琐的生产细节让企业很难找到合理的生产管理方法,导致人工效率低、错误多、成本高。如果想要解决这些问题,工厂就必须要有一套自己的生产管理系统,这样才能提…

【Leetcode】410. 分割数组的最大值

文章目录 题目思路1.max_element2.partial_sum3.upper_bound4.distance 代码运行结果 题目 题目链接 给定一个非负整数数组 nums 和一个整数 k ,你需要将这个数组分成 k 个非空的连续子数组。 设计一个算法使得这 k 个子数组各自和的最大值最小。 示例1&#xff1…

玩转 SpEL 表达式

本文概览 欢迎阅读本文,其中我们将深入探讨 Spring Expression Language(SpEL)的语法和实际应用。从基础概念到高级用法,我们将在本文中了解如何使用 SpEL 提高代码的灵活性和表达力。无论大家是初学者还是有经验的开发者&#x…

ACM题解Day2|1.台风, 2.式神考试,3.DNA,4.方程求解

学习目标: 博主介绍: 27dCnc 专题 : 数据结构帮助小白快速入门 👍👍👍👍👍👍👍👍👍👍👍👍 ☆*: .。. o(≧▽≦)…

Python 算法交易实验67 第一次迭代总结

说明 在这里对第一次迭代(2023.7~ 2024.1)进行一些回顾和总结: 回顾: 1 实现了0~1的变化2 在信息隔绝的条件下,无控制的操作,导致被套 总结: 思路可行,在春暖花开的时候&#x…

企业Oracle1 数据库管理

Oracle的安装 一、基础表的创建 1.1 切换到scott用户 用sys 账户 登录 解锁scott账户 alter user scott account unlock;conn scott/tiger;发现并不存在scott账户,自己创建一个? 查找资料后发现,scott用户的脚本需要自己执行一下 C:\ap…

三、MySQL库表操作

3.1 SQL语句基础(SQL命令) 3.1.1 SQL简介 SQL:结构化查询语言(Structured Query Language),在关系型数据库上执行数据操作,数据检索以及数据维护的标准化语言。使用SQL语句,程序员和数据库管理员可以完成…

Opncv模板匹配 单模板匹配 多模板匹配

目录 问题引入 单模板匹配 ①模板匹配函数: ②查找最值和极值的坐标和值: 整体流程原理介绍 实例代码介绍: 多模板匹配 ①定义阈值 ②zip函数 整体流程原理介绍 实例代码: 问题引入 下面有请我们的陶大郎登场 这张图片是我们的陶大郎,我们接下来将利用陶大郎来介绍…

stm32h7中RTC的BCD模式与BIN模式

RTC的BCD格式与BIN格式 BCD(Binary-Coded Decimal)和BIN(Binary)是两种不同的数字表示格式。 BCD格式: BCD是一种用二进制编码表示十进制数字的格式。在BCD格式中,每个十进制数位使用4位二进制数来表示&am…

c++程序的内存模型,new操作符详解

目录 内存四区 程序运行前 代码区 全局区 程序运行后 栈区 堆区 new操作符 创建一个数 创建一个数组 内存四区 不同区域存放不同的数据,赋予不同的生命周期,让我们更加灵活的编程 程序运行前 程序运行前就有代码区和全局区 代码区 程序编…

Windows系统下使用docker-compose安装mysql8和mysql5.7

windows环境搭建专栏🔗点击跳转 win系统环境搭建(十四)——Windows系统下使用docker安装mysql8和mysql5.7 文章目录 win系统环境搭建(十四)——Windows系统下使用docker安装mysql8和mysql5.7MySQL81.新建文件夹2.创建…

《Linux C编程实战》笔记:信号处理函数的返回

信号处理函数可以正常返回&#xff0c;也可以调用其他函数返回到程序的主函数中&#xff0c;而不是从处理程序返回。 setjmp/longjmp 使用longjmp可以跳转到setjmp设置的位置 这两个函数原型如下 #include<setjmp.h> int setjmp(jmp_buf env); void longjmp(jmp_buf …

QQ数据包解密

Windows版qq数据包格式&#xff1a; android版qq数据包格式&#xff1a; 密钥&#xff1a;16个0 算法&#xff1a;tea_crypt算法 pc版qq 0825数据包解密源码&#xff1a; #include "qq.h" #include "qqcrypt.h" #include <WinSock2.h> #include…

构建库函数雏形(以GPIO为例)

构建库函数雏形 进行外设结构体定义构建置位和复位函数进行库函数的自定义 step I&#xff1a; \textbf{step I&#xff1a;} step I&#xff1a; 对端口进行输出数据类型枚举 step II&#xff1a; \textbf{step II&#xff1a;} step II&#xff1a;对端口进行结构化描述 step…

线性代数的学习和整理23:用EXCEL和python 计算向量/矩阵的:内积/点积,外积/叉积

目录 1 乘法 1.1 标量乘法(中小学乘法) 1.1.1 乘法的定义 1.1.2 乘法符合的规律 1.2 向量乘法 1.2.1 向量&#xff1a;有方向和大小的对象 1.2.2 向量的标量乘法 1.2.3 常见的向量乘法及结果 1.2.4 向量的其他乘法及结果 1.2.5 向量的模长&#xff08;长度&#xff0…

第三篇【传奇开心果系列】Vant开发移动应用:财务管理应用

传奇开心果博文系列 系列博文目录Vant开发移动应用系列博文 博文目录一、项目目标二、编程思路三、初步实现示例代码四、扩展思路五、使用Firebase等后端服务来实现用户认证和数据存储示例代码六、用Vant组件库实现收入和支出分类管理的示例代码七、用Vant组件库实现收入和支出…