第P1周:Pytorch实现mnist手写数字识别

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊

目标

1. 实现pytorch环境配置
2. 实现mnist手写数字识别
3. 自己写几个数字识别试试

具体实现

(一)环境

语言环境:Python 3.10
编 译 器: PyCharm
框 架:

(二)具体步骤
**1.**配置Pytorch环境

打开官网PyTorch,Get started:
image.png
接下来是选择安装版本,最难的就是确定Compute Platform的版本,是否要使用GPU。所以先要确定CUDA的版本。
image.png
会发现,pytorch官网根本没有对应12.7的版本,先安装最新的试试呗,选择12.4:
image.png
安装命令:pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124
image.png
image.png
安装完成,我们建立python文件,输入如下代码:

import torch  
x = torch.rand(5, 3)  
print(x)  print(torch.cuda.is_available())---------output---------------
tensor([[0.3952, 0.6351, 0.3107],[0.8780, 0.6469, 0.6714],[0.4380, 0.0236, 0.5976],[0.4132, 0.9663, 0.7576],[0.4047, 0.4636, 0.2858]])
True

从输出来看,成功了。下面开始正式的mnist手写数字识别

2. 下载数据并加载数据
import torch  
import torch.nn as nn  
# import matplotlib.pyplot as plt  
import torchvision  # 第一步:设置硬件设备,有GPU就使用GPU,没有就使用GPU  
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  
print(device)  # 第二步:导入数据  
# MNIST数据在torchvision.datasets中,自带的,可以通过代码在线下载数据。  
train_ds = torchvision.datasets.MNIST(root='./data',    # 下载的数据所存储的本地目录  train=True,       # True为训练集,False为测试集  transform=torchvision.transforms.ToTensor(),  # 将下载的数据直接转换成张量格式  download=True     # True直接在线下载,且下载到root指定的目录中,注意已经下载了,第二次以后就不会再下载了  )  
test_ds = torchvision.datasets.MNIST(root='./data',  train=False,  transform=torchvision.transforms.ToTensor(),  download=True  )  # 第三步:加载数据  
# Pytorch使用torch.utils.data.DataLoader进行数据加载  
batch_size = 32  
train_dl = torch.utils.data.DataLoader(dataset=train_ds, # 要加载的数据集  batch_size=batch_size, # 批次的大小  shuffle=True,     # 每个epoch重新排列数据  # 以下的参数有默认值可以不写  num_workers=0, # 用于加载的子进程数,默认值为0.注意在windows中如果设置非0,有可能会报错  pin_memory=True, # True-数据加载器将在返回之前将张量复制到设备/CUDA 固定内存中。 如果数据元素是自定义类型,或者collate_fn返回一个自定义类型的批次。  drop_last=False, #如果数据集大小不能被批次大小整除,则设置为 True 以删除最后一个不完整的批次。 如果 False 并且数据集的大小不能被批大小整除,则最后一批将保留。 (默认值:False)  timeout=0, # 设置数据读取的超时时间 , 超过这个时间还没读取到数据的话就会报错。(默认值:0)  worker_init_fn=None # 如果不是 None,这将在步长之后和数据加载之前在每个工作子进程上调用,并使用工作 id([0,num_workers - 1] 中的一个 int)的顺序逐个导入。(默认:None)  )  # 取一个批次看一下数据格式,数据的shape为[batch_size, channel, height, weight]  
# batch_size是已经设定的32,channel, height和weight分别是图片的通道数,高度和宽度  
images, labels = next(iter(train_dl))  
print(images.shape)

image.png
image.png
看这个图片的shape是torch.size([32, 1, 28, 28]),可以看图MNIST的数据集里的图像我猜应该是单色的(channel=1),28 * 28大小的图片(height=28, weight=28)。
将图片可视化展示出来看看:

# 数据可视化  
plt.figure(figsize=(20, 5)) # 指定图片大小 ,图像大小为20宽,高5的绘图(单位为英寸)  
for i , images in enumerate(images[:20]):  # 维度缩减,npimg = np.squeeze(images.numpy())  # 将整个figure分成2行10列,绘制第i+1个子图  plt.subplot(2, 10, i+1)  plt.imshow(npimg, cmap=plt.cm.binary)  plt.axis('off')  
plt.show()

image.png

**3.**构建CNN网络
num_classes = 10 # MNIST数据集中是识别0-9这10个数字,因此是10个类别。class Model(nn.Module):def __init__(self):super(Model, self).__init__()# 特征提取网络self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1) # 第一层卷积,卷积核大小3*3self.pool1 = nn.MaxPool2d(2)    # 池化层,池化核大小为2*2self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1) # 第二层卷积,卷积核大小3*3self.pool2 = nn.MaxPool2d(2)# 分类网络self.fc1 = nn.Linear(1600, 64)self.fc2 = nn.Linear(64, num_classes)def forward(self, x):x = self.pool1(F.relu(self.conv1(x)))x = self.pool2(F.relu(self.conv2(x)))x = torch.flatten(x, start_dim=1)x = F.relu(self.fc1(x))x = self.fc2(x)return x# 第四步:加载并打印模型
# 将模型转移到GPU中
model = Model().to(device)
summary(model)>)

image.png

4.训练模型
# 第五步:训练模型  
loss_fn = nn.CrossEntropyLoss() # 创建损失函数  
learn_rate = 1e-2   # 设置学习率  
opt = torch.optim.SGD(model.parameters(), lr=learn_rate)  # 循环训练  
def train(dataloader, model, loss_fn, optimizer):  size = len(dataloader.dataset) # 训练集的大小  num_batches = len(dataloader) # 批次数目  train_loss, train_acc = 0, 0  # 初始化训练损失率和正确率都为0  for X, y in dataloader: # 获取图片及标签  X, y = X.to(device), y.to(device)   # 将图片和标准转换到GPU中  # 计算预测误差  pred = model(X) # 使用CNN网络预测输出pred  loss = loss_fn(pred, y) # 计算预测输出的pred和真实值y之间的差距  # 反向传播  optimizer.zero_grad()   # grad属性归零  loss.backward() # 反向传播  optimizer.step()    # 第一步自动更新  # 记录acc与loss  train_acc += (pred.argmax(1) == y).type(torch.float).sum().item()  train_loss += loss.item()  train_acc /= size  train_loss /= num_batches  return train_acc, train_loss  # 测试函数,注意测试函数不需要进行梯度下降,不进行网络权重更新,所以不需要传入优化器  
def test(dataloader, model, loss_fn):  size = len(dataloader.dataset)  num_batches = len(dataloader)  test_loss, test_acc = 0, 0  # 当不进行训练时,停止梯度更新,节省计算内存消耗  with torch.no_grad():  for imgs, targets in dataloader:  imgs, target = imgs.to(device), targets.to(device)  # 计算 loss            target_pred = model(imgs)  loss = loss_fn(target_pred, target)  test_loss += loss.item()  test_acc += (target_pred.argmax(1) == target).type(torch.float).sum().item()  test_acc /= size  test_loss /= num_batches  return test_acc, test_loss  # 正式训练  
epochs = 5  
train_loss, train_acc, test_loss, test_acc = [], [], [], []  for epoch in range(epochs):  model.train()  epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, opt)  model.eval()  epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn)  train_acc.append(epoch_train_acc)  test_acc.append(epoch_test_acc)  train_loss.append(epoch_train_loss)  test_loss.append(epoch_test_loss)  template = 'Epoch: {:2d}, Train_acc:{:.1f}%, Train_loss: {:.3f}%, Test_acc: {:.1f}%, Test_loss: {:.3f}%'  print(template.format(epoch+1, epoch_train_acc*100, epoch_train_loss, epoch_test_acc*100, epoch_test_loss))  
print('Done')

image.png

# 可见化一下训练结果  
warnings.filterwarnings("ignore")  
plt.rcParams['font.sans-serif'] = ['SimHei']    # 显示中文不标签,不设置会显示中文乱码  
plt.rcParams['axes.unicode_minus'] = False      # 显示负号  
plt.rcParams['figure.dpi'] = 100                # 分辨率  epochs_range = range(epochs)  plt.figure(figsize=(12, 3))  
plt.subplot(1, 2, 1)  plt.plot(epochs_range, train_acc, label='训练正确率')  
plt.plot(epochs_range, test_acc, label='测试正确率')  
plt.legend(loc='lower right')  
plt.title('训练与测试正确率')  plt.subplot(1, 2, 2)  
plt.plot(epochs_range, train_loss, label='训练损失率')  
plt.plot(epochs_range, test_loss, label='测试损失率')  
plt.legend(loc='upper right')  
plt.title('训练与测试损失率')  plt.show()

image.png

四:预测一下自己手写的数字

准备数据:
image.png
再手动将每个数字切割成单独的一个文件:
image.png
注意,这里并没有将每个图片的大小切割成一致,理论上切割成要求的28*28是最好。我这里用代码来重新生成28 * 28大小的图片。

import torch  
import numpy as np  
from PIL import Image  
from torchvision import transforms  
import torch.nn as nn  
import torch.nn.functional as F  
import matplotlib.pyplot as plt  
import os, pathlib  # 第一步:设置硬件设备,有GPU就使用GPU,没有就使用GPU  
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  
print(device)  # 定义模型,要把模型搞过来嘛,不然加载模型会出错。  
class Model(nn.Module):  def __init__(self):  super().__init__()  # 特征提取网络  self.conv1 = nn.Conv2d(1, 32, kernel_size=3 ) # 第一层卷积,卷积核大小3*3  self.pool1 = nn.MaxPool2d(2)    # 池化层,池化核大小为2*2  self.conv2 = nn.Conv2d(32, 64, kernel_size=3) # 第二层卷积,卷积核大小3*3  self.pool2 = nn.MaxPool2d(2)  # 分类网络  self.fc1 = nn.Linear(1600, 64)  self.fc2 = nn.Linear(64, 10)  def forward(self, x):  x = self.pool1(F.relu(self.conv1(x)))  x = self.pool2(F.relu(self.conv2(x)))  x = torch.flatten(x, start_dim=1)  x = F.relu(self.fc1(x))  x = self.fc2(x)  return x  # 加载模型  
model = torch.load('./models/cnn.pth')   
model.eval()  transform = transforms.Compose([  transforms.ToTensor(),  transforms.Normalize((0.1307,), (0.3081,))  
])  # 导入数据  
data_dir = "./mydata/handwrite"  
data_dir = pathlib.Path(data_dir)  
image_count = len(list(data_dir.glob('*.jpg')))  
print("图片总数量为:", image_count)  plt.rcParams['font.sans-serif'] = ['SimHei']    # 显示中文不标签,不设置会显示中文乱码  
plt.rcParams['axes.unicode_minus'] = False      # 显示负号  
plt.rcParams['figure.dpi'] = 100                # 分辨率  
plt.figure(figsize=(10, 10))  
i = 0  
for input_file in list(data_dir.glob('*.jpg')):  image = Image.open(input_file)  image_resize = image.resize((28, 28))   # 将图片转换成 28*28  image = image_resize.convert('L')  # 转换成灰度图  image_array = np.array(image)  # print(image_array.shape)    # (high, weight)  image = Image.fromarray(image_array)  image = transform(image)  image = torch.unsqueeze(image, 0)   # 返回维度为1的张量  image = image.to(device)  output = model(image)  pred = torch.argmax(output, dim=1)  image = torch.squeeze(image, 0)     # 返回一个张量,其中删除了大小为1的输入的所有指定维度  image = transforms.ToPILImage()(image)  plt.subplot(10, 4, i+1)  plt.tight_layout()  plt.imshow(image, cmap='gray', interpolation='none')  plt.title("实际值:{},预测值:{}".format(input_file.stem[:1], pred.item()))  plt.xticks([])  plt.yticks([])  i += 1  
plt.show()

image.png

准确性很低,40张图片预测准确数量:6,占比:15.0%.。看图片,感觉resize成28*28和转换成灰度图后,图片本身已经失真比较严重了。先把图片像素翻转一下,其实就是反色处理,加上这段代码:
image.png
image.png
准确率上了一个台阶(40张图片预测准确数量:30,占比:75.0%).。但是看图片,还是不清晰。

(三)总结
  1. epochs=5,预测的准确性达到97%,如果增加迭代的次数到10,准确性提升接近到99%。迭代20次则达到99.3,提升不明显。
    image.png
    image.png
  2. batch_size如何从32调整到64,准确性差不太多
    image.png
    image.png
  3. 后续研究图片增强

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/62232.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Seq2Seq模型的发展历史;深层RNN结构为什么出现梯度消失/爆炸问题,Transformer为什么不会;Seq2Seq模型存在问题

目录 Seq2Seq模型的发展历史 改进不足的地方 深层RNN结构为什么出现梯度消失/爆炸问题,Transformer为什么不会 深层RNN结构为什么出现梯度消失/爆炸问题: Transformer为什么不会出现梯度消失/爆炸问题: Seq2Seq模型存在问题 T5模型介绍 Seq2Seq模型的发展历史 序列到…

网络安全技术详解:虚拟专用网络(VPN) 安全信息与事件管理(SIEM)

虚拟专用网络(VPN)详细介绍 虚拟专用网络(VPN)通过在公共网络上创建加密连接来保护数据传输的安全性和隐私性。 工作原理 VPN的工作原理涉及建立安全隧道和数据加密: 隧道协议:使用协议如PPTP、L2TP/IP…

Hive 窗口函数与分析函数深度解析:开启大数据分析的新维度

Hive 窗口函数与分析函数深度解析:开启大数据分析的新维度 在当今大数据蓬勃发展的时代,Hive 作为一款强大的数据仓库工具,其窗口函数和分析函数犹如一把把精巧的手术刀,助力数据分析师们精准地剖析海量数据,挖掘出深…

SCAU期末笔记 - 数据库系统概念

我校使用Database System Concepts,9-12章不考所以跳过,因为课都逃了所以复习很仓促,只准备过一下每一章最后的概念辨析,我也不知道有没有用 第1章 引言 数据库管理系统(DBMS) 由一个互相关联的数据的集合…

Android 12系统源码_窗口管理(九)深浅主题切换流程源码分析

前言 上一篇我们简单介绍了应用的窗口属性WindowConfiguration这个类,该类存储了当前窗口的显示区域、屏幕的旋转方向、窗口模式等参数,当设备屏幕发生旋转的时候就是通过该类将具体的旋转数据传递给应用的、而应用在加载资源文件的时候也会结合该类的A…

河南省的教育部科技查新工作站有哪些?

郑州大学图书馆(Z12):2007年1月被批准设立“教育部综合类科技查新工作站”,同年12月被河南省科技厅认定为河南省省级科技查新机构。主要面向河南省的高校、科研机构、企业提供科技查新、查收查引等服务。 河南大学图书馆&#xf…

Leetcode经典题6--买卖股票的最佳时机

买卖股票的最佳时机 题目描述: 给定一个数组 prices ,它的第 i 个元素 prices[i] 表示一支给定股票第 i 天的价格。 你只能选择 某一天 买入这只股票,并选择在 未来的某一个不同的日子 卖出该股票。设计一个算法来计算你所能获取的最大利润。…

MCPTT 与BTC

MCPTT(Mission Critical Push-to-Talk)和B-TrunC(宽带集群)是两种关键通信标准,它们分别由不同的组织制定和推广。 MCPTT(Mission Critical Push-to-Talk)标准由3GPP(第三代合作伙伴…

去除账号密码自动赋值时的输入框背景色

问题描述: 前端使用账号密码登录,若在网页保存过当前页面的密码和账号,那么当再次进入该页面,网页会自动的把账号和密码赋到输入框中,而此时输入框是带有背景色的,与周边的白色背景显得很不协调&#xff1…

【Pytorch】torch.reshape与torch.Tensor.reshape区别

问题引入: 在Pytorch文档中,有torch.reshape与torch.Tensor.reshape两个reshape操作,他们的区别是什么呢? 我们先来看一下官方文档的定义: torch.reshape: torch.Tensor.reshape: 解释: 在p…

扫码与短信验证码登录JS逆向分析与Python纯算法还原

文章目录 1. 写在前面2. 扫码接口分析2. 短信接口分析3. 加密算法还原【🏠作者主页】:吴秋霖 【💼作者介绍】:擅长爬虫与JS加密逆向分析!Python领域优质创作者、CSDN博客专家、阿里云博客专家、华为云享专家。一路走来长期坚守并致力于Python与爬虫领域研究与开发工作!…

spring6:3容器:IoC

spring6:3容器:IoC 目录 spring6:3容器:IoC3、容器:IoC3.1、IoC容器3.1.1、控制反转(IoC)3.1.2、依赖注入3.1.3、IoC容器在Spring的实现 3.2、基于XML管理Bean3.2.1、搭建子模块spring6-ioc-xml…

【认证法规】安全隔离变压器

文章目录 定义反激电源变压器 定义 安全隔离变压器(safety isolating transformer),通过至少相当于双重绝缘或加强绝缘的绝缘使输入绕组与输出绕组在电气上分开的变压器。这种变压器是为以安全特低电压向配电电路、电器或其它设备供电而设计…

车机端同步outlook日历

最近在开发一个车机上的日历助手,其中一个需求就是要实现手机端日历和车机端日历数据的同步。然而这种需求似乎没办法实现,毕竟手机日历是手机厂商自己带的系统应用,根本不能和车机端实现数据同步的。 那么只能去其他公共的平台寻求一些机会&…

OpenCV-图像阈值

简单阈值法 此方法是直截了当的。如果像素值大于阈值,则会被赋为一个值(可能为白色),否则会赋为另一个值(可能为黑色)。使用的函数是 cv.threshold。第一个参数是源图像,它应该是灰度图像。第二…

力扣300.最长递增子序列

题目描述 题目链接300. 最长递增子序列 给你一个整数数组 nums ,找到其中最长严格递增子序列的长度。 子序列 是由数组派生而来的序列,删除(或不删除)数组中的元素而不改变其余元素的顺序。例如,[3,6,2,7] 是数组 […

Vue CLI的作用

Vue CLI(Command Line Interface)是一个基于Vue.js的官方脚手架工具,其主要作用是帮助开发者快速搭建Vue项目的基础结构和开发环境。以下是Vue CLI的具体作用: 1、项目模板与快速生成 Vue CLI提供了一系列预设的项目模板&#x…

【蓝桥杯每日一题】扫雷

扫雷 知识点 2024-12-3 蓝桥杯每日一题 扫雷 dfs (bfs也是可行的) 题目大意 在一个二维平面上放置这N个炸雷,每个炸雷的信息有$(x_i,y_i,r_i) $,前两个是坐标信息,第三个是爆炸半径。然后会输入M个排雷火箭&#xff0…

【大数据学习 | 面经】Spark 3.x 中的AQE(自适应查询执行)

Spark 3.x 中的自适应查询执行(Adaptive Query Execution,简称 AQE)通过多种方式提升性能,主要包括以下几个方面: 动态合并 Shuffle 分区(Coalescing Post Shuffle Partitions): 当 …

城电科技 | 光伏景观长廊 打造美丽乡村绿色低碳示范区 光伏景观设计方案

光伏景观长廊是一种结合了光伏发电技术和零碳景观设计的新型公共公共设施,光伏景观长廊顶上的光伏板不仅可以为周边用电设备提供清洁电能,而且还能作为遮阳设施使用,为人们提供一个美丽又实用的休闲娱乐空间。 光伏景观长廊建设对打造美丽乡…