基于Pytorch的CNN手写数字识别

作为深度学习小白,我想把自己学习的过程记录下来,作为实践部分,我会写一个通用框架,并会不断完善这个框架,作为自己的入门学习。因此略过环境搭建和基础知识的步骤,直接从代码实战开始。

一.下载数据集并加载

在这里使用MINST开源数字识别数据集。

首先导入必要的库,设置训练的设备(gpu或cpu),设置训练的轮次(epoch),然后设置数据集train_data、test_data,并使用torchvision的datasets来读取,下载的MINSt数据集被保存在当前路径的dataset文件夹下,对于训练集和测试集分别设置train的参数,最后把它转成tensor张量。

接着对设置好的数据集进行读取,调用了torch.utils.data下的DataLoader,分别读取训练集和测试集,同时设置batch_size,即为每一次读取多少张图片,然后对训练集数据进行展平(通常测试集不需要)。

# 搭建CNN卷积神经网络对MNIST数据集实现数字识别import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import StepLR
import cv2
import matplotlib.pyplot as plt
import numpy as npdevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")
epoch = 10train_data = datasets.MNIST("./dataset", train=True,download=True,transform=transforms.ToTensor())
test_data = datasets.MNIST("./dataset", train=False, download=True,transform=transforms.ToTensor())train_loader = DataLoader(train_data, batch_size=16, shuffle=True)
test_loader = DataLoader(test_data, batch_size=16, shuffle=False)

二.定义训练网络

其中super().__init__()允许我们调用父类(nn.Module)的方法,

对于卷积操作nn.Conv2d(输入通道数,输出通道数,卷积核尺寸,步长,padding大小)参数如此,因为输入为灰度图,则对于第一个卷积的输入通道数等于1,最后线性层会输出一个包含10个数据的变量,分别代表10个数字(类别)的概率。

然后,我们实例化model为网络的对象,定义损失函数为交叉熵损失函数,使用Adam优化器对参数(model.parameters())进行优化,初始化学习率为0.001,并调用学习率更新器。

class Dight(nn.Module):def __init__(self):super().__init__()self.model = nn.Sequential(nn.Conv2d(1, 10, 5),  #输入:batch*1*28*28  输出:batch*10*24*24(28 -5 + 1)nn.ReLU(),  #保持shape不变  输出:batch*10*24*24(28 -5 + 1)nn.MaxPool2d(2),   #输入:batch*10*24*24(28 -5 + 1) 输出:batch*10*12*12nn.Conv2d(10, 20, 3),   #输入:batch*10*12*12  输出:batch*20*10*10(12 - 3 + 1)nn.ReLU(),nn.Flatten(),nn.Linear(20*10*10, 500),   #输入:batch2000   输出:batch 500nn.ReLU(),    #保持shape不变nn.Linear(500, 10)  #输入:batch 500  输出:batch 10)def forward(self, x):return self.model(x)model = Dight()
model = model.to(device)
loss_fn = nn.CrossEntropyLoss()
loss_fn =  loss_fn.to(device)
optimizer = optim.Adam(model.parameters(), lr = 0.001)
scheduler = StepLR(optimizer, step_size = 5, gamma = 0.5)

三.开始训练

使用model.train()开始训练,使用for循环遍历数据集中的数据(imgs)和标签(targets),对梯度初始化,将数据传入model进行前向传播,并输出前向传播结果(outputs),根据outputs和给定的标签targets计算交叉熵损失loss,根据loss进行反向传播,根据反向传播更新模型参数。

同时,每1000步打印一下当前的步数和loss,用于观察训练进度和效果。

#定义训练方法
def train():#模型训练model.train()train_step = 0for batch_index, (imgs, targets) in enumerate(train_loader):#部署到device上imgs, targets = imgs.to(device), targets.to(device)#梯度初始化为0optimizer.zero_grad()#训练后的结果outputs = model(imgs)#计算损失loss = loss_fn(outputs, targets)   #交叉熵损失,适用于多分类任务,二分类适用于sigmoid#反向传播loss.backward()#参数更新optimizer.step()train_step += 1if train_step % 1000 == 0:print(f"train Epoch: {train_step} , Loss: {loss.item()}")

四.测试方法

我们会使用测试集对网络进行验证,通过model.eval()对模型进行验证,因为验证时不会计算梯度也不算反向传播,所以与训练不同的是需要使用语句with torch.no_grad(),同样的对测试集进行遍历(这里也可以仿照训练时的写法),之后,同样的计算outputs和loss,还会对test_loss和accuracy进行累计,观察网络在测试集的效果

#定义测试方法
def test():#模型验证model.eval()#正确率accuracy = 0.0#测试损失test_loss = 0.0with torch.no_grad():  #不会计算梯度也不会反向传播for imgs, targets in test_loader:#部署到device上imgs, targets = imgs.to(device), targets.to(device)#测试数据outputs = model(imgs)#计算测试损失loss = loss_fn(outputs, targets)test_loss += loss.item()#累计正确的值accuracy += (outputs.argmax(1) == targets).sum().item()test_loss /= len(test_loader)accuracy /= len(test_data)print(f"整体测试集上的损失: {test_loss},准确率 : {accuracy}")

 五.模型保存

调用

torch.save(model, "my_CNN.pth")

print("模型已保存")

即可

整合上面代码

if __name__ == "__main__":#调用方法for epoch in range(1, epoch + 1):print(f"-------------------第{epoch}轮训练开始------------------")train()# 调整学习率scheduler.step()test()torch.save(model, "my_CNN.pth")print("模型已保存")

六.结果测试

创建另一个py文件,输入任意一张数字图片,对图片的数字进行预测(多分类)。

打开image,并将它resize为28*28,如这里使用的3.jpg为

 用torch.load()加载模型

from PIL import Image
import torchvision
import torch
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequentialimg_path = "/home/lm/数字识别/picture/3.jpg"
image = Image.open(img_path)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")transform = torchvision.transforms.Compose([torchvision.transforms.Resize((28, 28)),torchvision.transforms.ToTensor()])image = transform(image)class Dight(nn.Module):def __init__(self):super().__init__()self.model = nn.Sequential(nn.Conv2d(1, 10, 5),  #输入:batch*1*28*28  输出:batch*10*24*24(28 -5 + 1)nn.ReLU(),  #保持shape不变  输出:batch*10*24*24(28 -5 + 1)nn.MaxPool2d(2),   #输入:batch*10*24*24(28 -5 + 1) 输出:batch*10*12*12nn.Conv2d(10, 20, 3),   #输入:batch*10*12*12  输出:batch*20*10*10(12 - 3 + 1)nn.ReLU(),nn.Flatten(),nn.Linear(20*10*10, 500),   #输入:batch2000   输出:batch 500nn.ReLU(),    #保持shape不变nn.Linear(500, 10)  #输入:batch 500  输出:batch 10)def forward(self, x):return self.model(x)model = torch.load("/home/lm/数字识别/my_CNN.pth")image = torch.reshape(image, (1,1,28,28)).to(device)
model.eval()
with torch.no_grad():output = model(image)
print(output)print(output.argmax(1))

最终输出为

tensor([[-14.0138,  -4.8722,  -7.2821, -11.5329,   6.1589,  -8.7089,  -7.8535,
          -6.8521,  -5.4265,  -7.6144]], device='cuda:0')
tensor([4], device='cuda:0')

可以看出模型可以正确预测出图片类别

七.数据集转换

问题

在上一步加载图片时,我们使用了MINST数据集的图片,但是我们下载的MINST数据集的格式是这样的

 数据集介绍

MNIST数据集来自美国国家标准与技术研究所, National Institute of Standards and Technology (NIST)。训练集(training set)由来自250个不同人手写的
数字构成,其中50%是高中学生,50%来自人口普查局(the Census Bureau)的工作人员。测试集(test set)也是同样比例的手写数字数据,但保证了测试集和训练集
的作者集不相交。

  MNIST数据集一共有7万张图片,其中6万张是训练集,1万张是测试集。每张图片是28 × 28 28\times 2828×28的0 − 9 0-90−9的手写数字图片组成。每个图片是黑底
白字的形式,黑底用0表示,白字用0-1之间的浮点数表示,越接近1,颜色越白。每个元素表示图片对应的数字出现的概率,显然,该向量标签表示的是数字5。

  MNIST数据集下载地址是http://yann.lecun.com/exdb/mnist/,它包含了4 44个部分:

    (1)训练数据集:train-images-idx3-ubyte.gz (9.45 MB,包含60,000个样本)。
    (2)训练数据集标签:train-labels-idx1-ubyte.gz(28.2 KB,包含60,000个标签)。
    (3)测试数据集:t10k-images-idx3-ubyte.gz(1.57 MB ,包含10,000个样本)。
    (4)测试数据集标签:t10k-labels-idx1-ubyte.gz(4.43 KB,包含10,000个样本的标签)。

数据集转换

编写一个脚本把原二进制格式的数据转换成jpg格式,这里先转换100张

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import cv2
import numpy as npwith open("./dataset/MNIST/raw/train-images-idx3-ubyte", "rb") as f:file = f.read()for i in range(1,100):image1 = [int(str(item).encode('ascii'), 16) for item in file[16+784*(i-1) : 16+784*i]]print(image1)image1_np = np.array(image1, dtype = np.uint8).reshape(28, 28, 1)cv2.imwrite(f"./picture/{i}.jpg", image1_np)

最后,可在picture文件夹下找到转换完成的jpg数据,再用它进行结果测试即可

八.总结

本文介绍了一个通用简单的pytorch框架,还有很多不足和缺点,后续会在本系列继续完善框架

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/112820.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

关于使用 vxe-table 时设置了 show-overflow tooltip 不展示的问题(Dialog 组件和 table 同时使用)

众所周知,vxe-table 是可以支撑万级数据渲染的表格组件,本质上还是用了虚拟滚动的实现。之前一直知道vxe-table, 但是基本没有机会用的上这个组件,最近在开发埋点数据的统计,后端一次性返回了上千条数据,elementui 的 …

【JavaEE】Java的多线程编程基础知识 -- 多线程篇(2)

Java多线程编程基础知识 一、多线程的创建二、Thread类常用的方法和API2.1 Thread 的几个常见的属性2.2 start 启动一个线程2.3 终止一个线程2.4 等待一个线程-join()2.5 线程休眠函数 -sleep() 三、线程状态3.1 观察所有线程的状态3.2 线程状态和线程转移的意义 四、线程安全&…

Redis设计与实现笔记 - 数据结构篇

Redis设计与实现笔记 - 数据结构篇 相信在我们日常使用中,会经常跟 Redis 打交道。数据结构 String、Hash、List、Set 和 ZSet 都是常用的数据类型。对于使用场景,我们可以滔滔不绝地说很多,但是我们从来就没有关心过它们的底层实现&#xf…

【软考-中级】系统集成项目管理工程师-人力资源管理历年案例

持续更新。。。。。。。。。。。。。。。 目录 2019 下 试题三(20分)背诵整理1. 冲突管理的6种方法2. 获取项目人力资源的依据 系列文章 2019 下 试题三(20分) 阅读下列说明,回答问题 1至问题 3,将解答填入答题纸的对应栏内     某公司承接了一个软件…

力扣第37题 解数独 c++ 难~ 回溯

题目 37. 解数独 困难 相关标签 数组 哈希表 回溯 矩阵 编写一个程序,通过填充空格来解决数独问题。 数独的解法需 遵循如下规则: 数字 1-9 在每一行只能出现一次。数字 1-9 在每一列只能出现一次。数字 1-9 在每一个以粗实线分隔的 3x3 宫…

小谈设计模式(29)—访问者模式

小谈设计模式(29)—访问者模式 专栏介绍专栏地址专栏介绍 访问者模式角色分析访问者被访问者 优缺点分析优点将数据结构与算法分离增加新的操作很容易增加新的数据结构很困难4 缺点增加新的数据结构比较困难增加新的操作会导致访问者类的数量增加34 总结…

解决Github Markdown图片显示残缺的问题

title: 解决Github Markdown图片显示残缺的问题 tags: 个人成长 categories:杂谈 在Github存放Markdown文档,如果图片没有存放在Github服务器上,github会尝试生成Github图片缓存,使用Github图片缓存,进行实际的展示。但比较蛋疼的…

2023年中国火焰切割机分类、产业链及市场规模分析[图]

火焰切割机是一种工业设备,用于利用高温火焰对金属材料进行切割和切割加工的过程。这种技术通常在金属切割、切割、焊接和熔化等领域中使用,通过将氧气和燃料混合产生的火焰来加热金属至高温,然后通过氧化反应将金属氧化物吹散,从…

嵌入式mqtt总线架构方案mosquitto+paho

一 mqtt通信模型 MQTT 协议提供一对多的消息发布&#xff0c;可以降低应用程序的耦合性&#xff0c;用户只需要编写极少量的应用代码就能完成一对多的消息发布与订阅&#xff0c;该协议是基于<客户端-服务器>模型&#xff0c;在协议中主要有三种身份&#xff1a;发布者&…

推荐一种更高效的打字输入法——双拼输入法

简介 双拼&#xff08;也称双打&#xff09;是一种建立在拼音输入法基础之上的文字输入方法&#xff0c;可视为全拼的一种改进。它通过将每个汉字拼音的声母和韵母各自映射到某个按键上&#xff0c;使得每个汉字最多用两个按键表示&#xff0c;从而极大地提高了拼音输入法的输…

LLM ReAct: 将推理和行为相结合的通用范式 学习记录

LLM ReAct 什么是ReAct? LLM ReAct 是一种将推理和行为相结合的通用范式,可以让大型语言模型(LLM)根据逻辑推理(Reason),构建完整系列行动(Act),从而达成期望目标。LLM ReAct 可以应用于多种语言和决策任务,例如问答、事实验证、交互式决策等,提高了 LLM 的效率、…

2022年亚太杯APMCM数学建模大赛B题高速列车的优化设计求解全过程文档及程序

2022年亚太杯APMCM数学建模大赛 B题 高速列车的优化设计 原题再现&#xff1a; 2022年4月12日&#xff0c;中国高铁复兴号CR450动车组在开放线上成功实现单车时速435公里&#xff0c;相对速度870公里&#xff0c;创造了高铁动车组列车穿越开放线和隧道速度的世界纪录。新一代…

用python写一个贪吃蛇的程序能运行能用键盘控制

用python写一个贪吃蛇的程序能运行能用键盘控制 1.源码2.运行效果 1.源码 开发库使用&#xff1a;pygame random 直接在终端运行&#xff1a;pip install pygame pycharm安装库&#xff1a;文件-设置-项目-Python 解释器 import pygame import random# 初始化pygame pygame…

2023年中国轮胎模具需求量、竞争格局及行业市场规模分析[图]

轮胎模具是轮胎生产线中的硫化成形装备&#xff0c;是高技术含量、高精度及高附加值的个性化模具产品&#xff0c;尤其是轮胎的花纹、图案、字体以及其他外观特征的成形都依赖于轮胎模具&#xff0c;因此其制造技术难度较高。其主要功能是通过所成型材料&#xff08;主要是橡塑…

最优化:建模、算法与理论(最优性理论2

5.7 约束优化最优性理论应用实例 5.7.1 仿射空间的投影问题 考虑优化问题 min ⁡ x ∈ R n 1 2 ∣ ∣ x − y ∣ ∣ 2 2 , s . t . A x b \min_{x{\in}R^n}\frac{1}{2}||x-y||_2^2,\\ s.t.{\quad}Axb x∈Rnmin​21​∣∣x−y∣∣22​,s.t.Axb 其中 A ∈ R m n , b ∈ R m …

2024免费的苹果电脑杀毒软件cleanmymac X

苹果电脑怎么杀毒&#xff1f;这个问题自从苹果电脑变得越来越普及&#xff0c;苹果电脑的安全性问题也逐渐成为我们关注的焦点。虽然苹果电脑的安全性相对较高&#xff0c;但仍然存在着一些潜在的威胁&#xff0c;比如流氓软件窥探隐私和恶意软件等。那么&#xff0c;苹果电脑…

【LeetCode:2316. 统计无向图中无法互相到达点对数 | BFS + 乘法原理】

&#x1f680; 算法题 &#x1f680; &#x1f332; 算法刷题专栏 | 面试必备算法 | 面试高频算法 &#x1f340; &#x1f332; 越难的东西,越要努力坚持&#xff0c;因为它具有很高的价值&#xff0c;算法就是这样✨ &#x1f332; 作者简介&#xff1a;硕风和炜&#xff0c;…

uniapp 小程序优惠劵样式

先看效果图 上代码 <view class"coupon"><view class"tickets" v-for"(item,index) in 10" :key"item"><view class"l-tickets"><view class"name">10元优惠劵</view><view cl…

基于Java的图书商城管理系统设计与实现(源码+lw+部署文档+讲解等)

文章目录 前言具体实现截图论文参考详细视频演示为什么选择我自己的网站自己的小程序&#xff08;小蔡coding&#xff09; 代码参考数据库参考源码获取 前言 &#x1f497;博主介绍&#xff1a;✌全网粉丝10W,CSDN特邀作者、博客专家、CSDN新星计划导师、全栈领域优质创作者&am…

Linux之I2C应用编程

I2C-Tools的交叉编译 tar xvf i2c-tools-4.2.tar.xz 首先解压下压缩包 cd i2c-tools-4.2 进入 i2c-tools-4.2目录 make USE_STATIC_LIB1 执行 make 将i2cset ,i2cget ,i2cdump,i2cdetect,i2ctransfer放到板子上 命令直接操作IIC设备 命令行直接操作iic向AP3216C传感器获取数据…