卷积神经网络实现MNIST手写数字识别 - P1

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍦 参考文章:365天深度学习训练营-第P1周:实现mnist手写数字识别
  • 🍖 原作者:K同学啊 | 接辅导、项目定制
  • 🚀 文章来源:K同学的学习圈子

目录

  • 环境
  • 步骤
    • 环境设置
      • 引用需要的包
      • 设置GPU
    • 数据准备
      • 下载数据集
      • 数据集预览
      • 数据集准备
    • 模型设计
    • 模型训练
      • 超参数设置
      • helper函数
      • 正式训练
    • 结果呈现
  • 总结与心得体会


环境

  • 系统:Linux
  • 语言: Python 3.8.10
  • 深度学习框架:PyTorch 2.0.0+cu118

步骤

环境设置

引用需要的包

Python写程序都需要做的事

import torch # 有些API直接在模块下
import torch.nn as nn # 大部分和模型相关的API
import torch.optim as optim # 优化器相关API
# 一些可以直接调用的函数封装(和nn下的很多方法是一样的效果不同的形式)
import torch.nn.functional as F from torch.utils.data import DataLoader # 数据集做分批,随机排序
from torchvision import datasets, transforms # 预置数据集下载,数据增强import matplotlib.pyplot as plt # 图表库
import numpy as np # 用来操作numpy数组,图像展示用from torchinfo import summary # 打开模型结构

设置GPU

首先用一个全局的对象设置一下当前的设备,是使用CPU还是CPU

# 有显卡就用显卡,没有就用CPU
device = torch.device('cuda'if torch.cuda.is_available() else 'cpu')

数据准备

下载数据集

调用torchvision包预置的API可以一键下载MNIST数据集

train_dataset = datasets.MNIST(root='data',  # 数据存放位置train=True, # 加载训练集还是验证集download=True,  # 本地没有是否从远程下载transform=transforms.ToTensor()) # 载入后将图像转换成pytorch的tensor对象
test_dataset = datasets.MNIST(root='data',  train=False,  # False说明是验证集download=True,transform=transforms.ToTensor())

数据集预览

先看看数据集中图像的样子,比如是单通道还是三通道,长宽是多少,然后就可以设置缩放以及模型的一些参数

image, label = train_dataset[0]
image.shape

图片信息
结果表明数据集中的图片应该是单通道的高28宽28的图像

打印里面20个图看看是什么样的

plt.figure(figsize=(20, 4)) # 设置一个plt图表画板的宽和高,单位是英寸。。
for i in range(20):image, label = train_dataset[i]plt.subplot(2, 10, i+1) # 以2行10列的形式展示图片# 先把tensor转为了numpy数组,然后把(1, 28, 28)第0维用squeeze去掉# cmap=plt.cm.binary说明是一个单通道的灰度图plt.imshow(np.squeeze(image.numpy()), cmap=plt.cm.binary)plt.title(label) # 打印一下对应的标签plt.axis('off') # 不显示坐标轴

图像预览

数据集准备

设置一下数据的批次大小

batch_size = 32
# 训练集上将数据的顺序打乱一下
train_loader = DataLoader(train_dataset, batch_size=batch_size,shuffle=True)
test_loader= DataLoader(test_dataset, batch_size=batch_size)

模型设计

采用一个类似于LeNet的小型卷积网络

class Model(nn.Module):def __init__(self, num_classes):super().__init__()# 定义两个卷积层,核都是3x3的,通道数递增self.conv1 = nn.Conv2d(1, 16, kernel_size=3)self.conv2 = nn.Conv2d(16, 32, kernel_size=3)# 池化层没有参数需要学习,可以复用一个self.maxpool = nn.MaxPool2d(2)# 全连接层的输入维度要结果计算,可以在forward的时候算一下self.fc1 = nn.Linear(5*5*32128)# 最后一层的输出得是分类的数量self.fc2 = nn.Linear(128, num_classes)def forward(self, x):# 28x28 -> conv1 -> 26x26 -> maxpool -> 13x13x = self.maxpool(F.relu(self.conv1(x)))# 13x13 -> conv2 -> 11x11 -> maxpool -> 5x5x = self.maxpool(F.relu(self.conv2(x)))# 这里要进全连接层了,需要把数据压平,保留第0维,从第1维开始压x = torch.Flatten(start_dim=1)x = F.relu(self.fc1(1))# 最后一层就不加激活函数了x = self.fc2()
# 将模型创建后,设备设置为上面定义的设备对象
model = Model(10).to(device)
# 一定要加input_size,不然打印的就不是实际执行的样子,而是按self中定义的顺序,复用的组件也展示不出来
summary(model, input_size(1, 1, 28, 28))

模型结构

模型训练

接下来就到了训练模型的环节了

超参数设置

需要设置的超参数有训练的轮次epoch和学习率learning_rate

# 轮次
epochs = 10
# 学习率
larning_rate = 0.001
# 创建优化器,将模型参数进去,并设置学习率
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
# 分类问题,无脑使用交叉熵损失
loss_fn = nn.CrossEntropyLoss()

helper函数

编写两个函数用来封装模型训练和模型验证的过程

  1. 模型训练
def train(train_loader, model, loss_fn, optimizer):size = len(train_loader.dataset) # 训练总数据量num_batches = len(train_loader) # 批次数量train_loss, train_acc = 0, 0 # 记录并返回本次训练过程的状态数据for x, y in train_loader:x, y = x.to(device), y.to(device) # 将数据加载到和模型相同的设备中,不然取不到值preds = model(x) # 这样模型会自动调用forward并进行一些参数的跟踪操作等loss = loss_fn(preds, y) # 计算当前批次的损失optimizer.zero_grad() # 清空之前训练时产生的梯度loss.backward() # 在损失函数上对参数执行反向传播计算梯度optimizer.step() # 执行参数更新操作# 累加当前数据train_loss += loss.item()# 计算正确数需要使用argmax求概率最大的一个分类然后和ground truth比较train_acc += (preds.argmax(1) == y).type(torch.float).sum().item()train_loss /= num_batches # 因为一个批次只计算一次损失,求平均值train_acc /= size # 正确率是在总数上计算的return train_loss, train_loss # 返回数据
  1. 模型验证
# 基本上就是train函数的简化
def test(test_loader, model, loss_fn):size = len(test_loader.dataset)num_batches = len(test_loader)test_loss, test_acc = 0, 0for x, y in test_loader:x, y = x.to(device), y.to(device)preds = model(x)loss = loss_fn(preds, y)test_loss += loss.item()test_acc += (preds.argmax(1) == y).type(torch.float).sum().item()test_loss /= num_batchestest_acc /= sizereturn test_loss, test_acc

正式训练

开始正式训练,其实也可以封装成一个helper

# 记录训练过程的数据
train_loss, train_acc = [],[]
test_loss, test_acc = [],[]for epoch in range(epochs):model.train() # 切换模型为训练模式epoch_train_loss, epoch_train_acc = train(train_loader, model, loss_fn, optimizer)model.eval() # 切换模型为评估模式epoch_test_loss, epoch_test_acc = test(test_loader, model, loss_fn)# 记录本轮次数据train_loss.append(epoch_train_loss)train_acc.append(epoch_train_acc)test_loss.append(epoch_test_loss)test_acc.append(epoch_test_acc)# 打印本轮次的数据信息print(f"Epoch:{epoch+1}, Train loss: {epoch_train_loss:.3f}, Train accuracy: {epoch_train_loss*100:.1f}, Validation loss: {epoch_test_loss:.3f}, Validation accuracy: {epoch_test_acc*100:.1f}")

训练过程

结果呈现

上面打印的结果不够直观我们可以用折线图打印一下

plt.figure(figsize=(16, 4))
series = range(epochs)
plt.subplot(1, 2, 1) # 一排两个图表
plt.plot(series, train_loss, label='train loss')
plt.plot(series, test_loss, label='validation loss')
plt.legend(loc='upper right')
plt.title('Loss')
plt.subplot(1, 2, 2)
plt.plot(series, train_acc, label='train accuracy')
plt.plot(series, test_acc, label='validation accuracy')
plt.legend(loc='lower right')
plt.title('Accuracy')

训练结果


总结与心得体会

通过整个过程可以发现,手写数字的识别还是非常简单的,训练的效率比较快,结果也不错。非常适合拿来练手,学习一些基本概念、深度学习框架和分类任务实践过程等。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/29303.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Hystrix技术指南】(1)基本使用和配置说明

这世间许多事物皆因相信而存在,所以人们亲手捏出了泥菩萨,却选择坚定的去信仰它。 分布式系统的规模和复杂度不断增加,随着而来的是对分布式系统可用性的要求越来越高。在各种高可用设计模式中,【熔断、隔离、降级、限流】是经常被…

一个好的人力资源管理系统包括哪些部分

阅读本文,您将具体详细了解:一个好的人力资源管理系统应该包括哪些部分。 人事部门是一家公司重要的职能部门之一,为公司的持续性、健康性发展提供人力保障。 然而,目前传统的人事管理方式在应对一些问题时存在着一些挑战。 例…

Shell脚本学习-循环的控制命令

break continue exit对比&#xff1a; 示例1&#xff1a;break命令跳出整个循环。 [rootabc scripts]# cat break1.sh #!/bin/bashfor((i0;i<5;i)) doif [ $i -eq 3 ]thenbreakfiecho $i done echo "ok"[rootabc scripts]# sh break1.sh 0 1 2 ok可以看到i等于3及…

jupyter lab环境配置

1.jupyterlab 使用虚拟环境 conda install ipykernelpython -m ipykernel install --user --name tf --display-name "tf" #例&#xff1a;环境名称tf2. jupyter lab kernel管理 show kernel list jupyter kernelspec listremove kernel jupyter kernelspec re…

微信小程序--原生

1&#xff1a;数据绑定 1&#xff1a;数据绑定的基本原则 2&#xff1a;在data中定义页面的数据 3&#xff1a;Mustache语法 4&#xff1a;Mustache的应用场景 1&#xff1a;常见的几种场景 2&#xff1a;动态绑定内容 3&#xff1a;动态绑定属性 4&#xff1a;三元运算 4&am…

C语言:打开调用堆栈

第一步&#xff1a;打断点 第二步&#xff1a;FnF5 第三步&#xff1a;按如图找到调用堆栈

C 语言高级3--函数指针回调函数,预处理,动态库的封装

目录 1.函数指针和回调函数 1.1 函数指针 1.1.1 函数类型 1.1.2 函数指针(指向函数的指针) 1.1.3 函数指针数组 1.1.4 函数指针做函数参数(回调函数) 2.预处理 2.1 预处理的基本概念 2.2 文件包含指令(#include) 2.2.1 文件包含处理 2.2.2 #incude<>和#include&q…

C++ 线性群体的概念

线性群体中的元素次序与其位置关系是对应的。 在线性群体中&#xff0c;可以按照访问元素的不同方法分为直接访问、顺序访问和索引访问。 &#xff08;1&#xff09;直接访问 对可直接访问的线性群体&#xff0c;我们可以直接访问群体中的任何一个元素&#xff0c;而不必首先访…

npm 报错 cb() never called!

不知道有没有跟我一样的情况&#xff0c;在使用npm i的时候一直报错&#xff1a;cb() never called! 换了很多个node版本&#xff0c;还是不行&#xff0c;无法解决这个问题 百度也只是让降低node版本请缓存&#xff0c;gpt给出的解决方案也是同样的 但是缓存清过很多次了&a…

Python中enumerate用法详解

目录 1.简介 2.语法 3.参数 4.返回值 5.详解 6.实例 7.补充 1.简介 enumerate() 函数用于将一个可遍历的数据对象(如列表、元组或字符串)组合为一个索引序列&#xff0c;同时列出数据和数据下标&#xff0c;一般用在 for 循环当中。 2.语法 以下是 enumerate() 方法的语…

Linux 匿名页的生命周期

目录 匿名页的生成 匿名页生成时的状态 do_anonymous_page缺页中断源码 从匿名页加入Inactive lru引出 一个非常重要内核patch 匿名页何时回收 本文以Linux5.9源码讲述 匿名页的生成 用户空间malloc/mmap(非映射文件时&#xff09;来分配内存&#xff0c;在内核空间发生…

【小梦C嘎嘎——启航篇】类和对象(中篇)

【小梦C嘎嘎——启航篇】类和对象&#xff08;中篇&#xff09;&#x1f60e; 前言&#x1f64c;类的6个默认成员函数构造函数析构函数拷贝构造函数拷贝构造函数的特性有哪些&#xff1f;既然编译器可以自动生成一个拷贝构造函数&#xff0c;为什么我们还要自己设计实现呢&…

【腾讯云 Cloud Studio 实战训练营】使用Cloud Studio构建SpringSecurity权限框架

1.Cloud Studio&#xff08;云端 IDE&#xff09;简介 Cloud Studio 是基于浏览器的集成式开发环境&#xff08;IDE&#xff09;&#xff0c;为开发者提供了一个永不间断的云端工作站。用户在使用 Cloud Studio 时无需安装&#xff0c;随时随地打开浏览器就能在线编程。 Clou…

Spring 知识点

Spring 1.1 Spring 简介 1.1.1 Spring 概念 Spring是一个轻量级Java开发框架&#xff0c;最早有Rod Johnson创建为了解决企业级应用开发的业务逻辑层和其他各层的耦合问题Spring最根本的使命是解决企业级应用开发的复杂性&#xff0c;即简化Java开发。使现有的技术更加容易使…

Linux下进程的特点与环境变量

目录 进程的特点 进程特点的介绍 进程时如何实现并发性的 进程间如何切换 概念铺设 PC指针 上下文 环境变量 PATH 修改PATH HOME SHELL env 命令行参数 什么是命令行参数&#xff1f; 打印命令行参数 通过函数获得环境变量 getenv 命令行参数 env 修改环境变…

SpringBoot 项目使用 Redis 对用户 IP 进行接口限流

一、思路 使用接口限流的主要目的在于提高系统的稳定性&#xff0c;防止接口被恶意打击&#xff08;短时间内大量请求&#xff09;。 比如要求某接口在1分钟内请求次数不超过1000次&#xff0c;那么应该如何设计代码呢&#xff1f; 下面讲两种思路&#xff0c;如果想看代码可…

MySql用户管理、权限管理

用户管理 1. 查看系统用户&#xff08;查询mysql系统数据库中的user表&#xff09; select * from mysql.user; 2. 创建用户 CREATE USER 用户名主机名 identified by 密码 -- 创建用户zhonghua,只能在当前主句localhost访问,密码为123456 create user zhonghualocalhost i…

springCache-缓存

SpringCache 简介&#xff1a;是一个框架&#xff0c;实现了基于注解的缓存功能&#xff0c;底层可以切换不同的cache的实现&#xff0c;具体是通过CacheManager接口实现 使用springcache,根据实现的缓存技术&#xff0c;如使用的redis,需要导入redis的依赖包 基于map缓存 …

MySQL 查询语句大全

目录 基础查询 直接查询 AS起别名 去重&#xff08;复&#xff09;查询 条件查询 算术运算符查询 逻辑运算符查询 正则表达式查询⭐ 模糊查询 范围查询 是否非空判断查询 排序查询 限制查询&#xff08;分页查询&#xff09; 随机查询 分组查询 HAVING 高级查询…

EtherCAT转EtherCAT网关FX5U有EtherCAT功能吗两个ETHERCAT设备互联

1.1 产品功能 捷米JM-ECT-ECT是自主研发的一款ETHERCAT从站功能的通讯网关。该产品主要功能是将2个ETHERCAT网络连接起来。 本网关连接到ETHERCAT总线中做为从站使用。 1.2 技术参数 1.2.1 捷米JM-ECT-ECT技术参数 ● 网关做为ETHERCAT网络的从站&#xff0c;可以连接倍福、…