实战:基于卷积的MNIST手写体分类

前面实现了基于多层感知机的MNIST手写体识别,本章将实现以卷积神经网络完成的MNIST手写体识别。

1.  数据的准备

在本例中,依旧使用MNIST数据集,对这个数据集的数据和标签介绍,前面的章节已详细说明过了,相对于前面章节直接对数据进行“折叠”处理,这里需要显式地标注出数据的通道,代码如下:

import numpy as npimport einops.layers.torch as elt#载入数据x_train = np.load("../dataset/mnist/x_train.npy")y_train_label = np.load("../dataset/mnist/y_train_label.npy")x_train = np.expand_dims(x_train,axis=1)   #在指定维度上进行扩充print(x_train.shape)

这里是对数据的修正,np.expand_dims的作用是在指定维度上进行扩充,这里在第二维(也就是PyTorch的通道维度)进行扩充,结果如下:

(60000, 1, 28, 28)

2.  模型的设计

下面使用PyTorch 2.0框架对模型进行设计,在本例中将使用卷积层对数据进行处理,完整的模型如下:

import torch
import torch.nn as nn
import numpy as np
import einops.layers.torch as elt
class MnistNetword(nn.Module):def __init__(self):super(MnistNetword, self).__init__()#前置的特征提取模块self.convs_stack = nn.Sequential(nn.Conv2d(1,12,kernel_size=7),  	#第一个卷积层nn.ReLU(),nn.Conv2d(12,24,kernel_size=5), 	#第二个卷积层nn.ReLU(),nn.Conv2d(24,6,kernel_size=3)  	#第三个卷积层)#最终分类器层self.logits_layer = nn.Linear(in_features=1536,out_features=10)def forward(self,inputs):image = inputsx = self.convs_stack(image)        #elt.Rearrange的作用是对输入数据的维度进行调整,读者可以使用torch.nn.Flatten函数完成此工作x = elt.Rearrange("b c h w -> b (c h w)")(x)logits = self.logits_layer(x)return logits
model = MnistNetword()
torch.save(model,"model.pth")

这里首先设定了3个卷积层作为前置的特征提取层,最后一个全连接层作为分类器层,需要注意的是,对于分类器的全连接层,输入维度需要手动计算,当然读者可以一步一步尝试打印特征提取层的结果,依次将结果作为下一层的输入维度。最后对模型进行保存。

3.  基于卷积的MNIST分类模型

下面进入本章的最后示例部分,也就是MNIST手写体的分类。完整的训练代码如下:

import torch
import torch.nn as nn
import numpy as np
import einops.layers.torch as elt
#载入数据
x_train = np.load("../dataset/mnist/x_train.npy")
y_train_label = np.load("../dataset/mnist/y_train_label.npy")
x_train = np.expand_dims(x_train,axis=1)
print(x_train.shape)
class MnistNetword(nn.Module):def __init__(self):super(MnistNetword, self).__init__()self.convs_stack = nn.Sequential(nn.Conv2d(1,12,kernel_size=7),nn.ReLU(),nn.Conv2d(12,24,kernel_size=5),nn.ReLU(),nn.Conv2d(24,6,kernel_size=3))self.logits_layer = nn.Linear(in_features=1536,out_features=10)def forward(self,inputs):image = inputsx = self.convs_stack(image)x = elt.Rearrange("b c h w -> b (c h w)")(x)logits = self.logits_layer(x)return logits
device = "cuda" if torch.cuda.is_available() else "cpu"
#注意记得将model发送到GPU计算
model = MnistNetword().to(device)
model = torch.compile(model)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)
batch_size = 128
for epoch in range(42):train_num = len(x_train)//128train_loss = 0.for i in range(train_num):start = i * batch_sizeend = (i + 1) * batch_sizex_batch = torch.tensor(x_train[start:end]).to(device)y_batch = torch.tensor(y_train_label[start:end]).to(device)pred = model(x_batch)loss = loss_fn(pred, y_batch)optimizer.zero_grad()loss.backward()optimizer.step()train_loss += loss.item()  # 记录每个批次的损失值# 计算并打印损失值train_loss /= train_numaccuracy = (pred.argmax(1) == y_batch).type(torch.float32).sum().item() / batch_sizeprint("epoch:",epoch,"train_loss:", round(train_loss,2),"accuracy:",round(accuracy,2))

在这里,我们使用了本章新定义的卷积神经网络模块作为局部特征抽取,而对于其他的损失函数以及优化函数,只使用了与前期一样的模式进行模型训练。最终结果如下所示,请读者自行验证。

(60000, 1, 28, 28)
epoch: 0 train_loss: 2.3 accuracy: 0.11
epoch: 1 train_loss: 2.3 accuracy: 0.13
epoch: 2 train_loss: 2.3 accuracy: 0.2
epoch: 3 train_loss: 2.3 accuracy: 0.18
…
epoch: 58 train_loss: 0.5 accuracy: 0.98
epoch: 59 train_loss: 0.49 accuracy: 0.98
epoch: 60 train_loss: 0.49 accuracy: 0.98
epoch: 61 train_loss: 0.48 accuracy: 0.98
epoch: 62 train_loss: 0.48 accuracy: 0.98Process finished with exit code 0

本文节选自《PyTorch 2.0深度学习从零开始学》,本书实战案例丰富,可带领读者快速掌握深度学习算法及其常见案例。

   

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/66143.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

vmware 16增加硬盘容量并在Ubuntu 18.04上边格式化并挂载

参考了《增加 VM虚拟机硬盘容量》 《Linux学习之分区挂载》中有给VMWare 16虚拟机添加一块硬盘的内容,需要先参考添加硬盘。 sudo mkfs.ext4 /dev/sda4给/dev/sda4进行ext4格式化。 sudo mkdir /mountsda4新建一个挂载目录。 sudo mount -t ext4 /dev/sda4 /mo…

AMBEO 双声道空间音频现已迈进直播制作领域

图片来源:Unsplash,作者:Bence Balla-Schottner AMBEO 双声道空间音频现已迈进直播制作领域 为所有观众解锁更加身临其境的听觉体验 森海塞尔将功能强大的 AMBEO 双声道空间音频技术引入了广播电视直播应用领域,对所有体育赛事广…

AD16 基础应用技巧(一些 “偏好“ 设置)

1. 修改铺铜后自动更新铺铜 AD16 铺铜 复制 自动变形 偏好设置 将【DXP】中的【参数选择】。 将【PCB Editor】中的【General】,然后勾选上【Repour Polygons After Modification】。 2. PCB直角走线处理与T型滴泪 一些没用的AD技巧——AD PCB直角走线处理与…

seq2seq与引入注意力机制的seq2seq

1、什么是 seq2seq? 就是字面意思,“句子 到 句子”。比如翻译。 2、seq2seq 有一些特点 seq2seq 的整体架构是 “编码器-解码器”。 其中,编码器是 RNN,并将 最后一个hidden state(隐藏状态)【即&…

【自用】西门子s7-200连接显示屏和物联网盒子完整配置过程

总览 1.PLC配置 2.显示屏配置 3.物联网盒子配置 一、PLC配置 1.连接PLC软件 STEP-7MicroWIN V4.0 SP9完整版 链接:https://pan.baidu.com/s/17LMEXnbkQZMPI8Bte24Eug?pwdjsi3 提取码:jsi3 2.PLC配置 打开 PLC 上面的小盖子,把红色按钮…

前端:html实现页面切换、顶部标签栏,类似于浏览器的顶部标签栏(完整版)

效果 代码 <!DOCTYPE html> <html><head><style>/* 左侧超链接列表 */.link {display: block;padding: 8px;background-color: #f2f2f2;cursor: pointer;}/* 顶部标签栏 */#tabsContainer {width:98%;display: flex;align-items: center;overflow-x: …

简易虚拟培训系统-UI控件的应用4

目录 Slider组件的常用参数 示例-使用Slider控制主轴 示例-Slider控制溜板箱的移动 本文以操作面板为例&#xff0c;介绍使用Slider控件控制开关和速度。 Slider组件的常用参数 Slider组件下面包含了3个子节点&#xff0c;都是Image组件&#xff0c;负责Slider的背景、填充区…

Ubuntu下的QT开发

ubuntu安装QT的组件如下&#xff1a; 若要在ubuntu下启动QT有两种方案&#xff0c;一种是在菜单栏搜索qt双QT Create&#xff1b;另一种则是使用命令&#xff1a;/opt/Qt5.12.9/Tools/QtCreator/bin/qtcreator.sh

pytest自动化测试两种执行环境切换的解决方案

目录 一、痛点分析 方法一&#xff1a;Hook方法pytest_addoption注册命令行参数 1、Hook方法注解 2、使用方法 方法二&#xff1a;使用插件pytest-base-url进行命令行传参 一、痛点分析 在实际企业的项目中&#xff0c;自动化测试的代码往往需要在不同的环境中进行切换&am…

C语言每日一练--------Day(11)

本专栏为c语言练习专栏&#xff0c;适合刚刚学完c语言的初学者。本专栏每天会不定时更新&#xff0c;通过每天练习&#xff0c;进一步对c语言的重难点知识进行更深入的学习。 今日练习题关键字&#xff1a;找到数组中消失的数字 哈希表 &#x1f493;博主csdn个人主页&#xff…

Kao框架学习

中间件&#xff1a;洋葱模型 这是官网上给出的示例&#xff0c;从logger依次往下执行&#xff0c;执行到最底层的response往回退&#xff0c;结构很像同心圆的洋葱从外层向内层再由内层向外层。 next表示暂停当前层的代码进入下一层&#xff0c; 当最后一层执行完毕开始回溯&a…

Jenkins清理构建(自动)

需求背景实现方法 Dashboard-->Project-->配置-->General-->Discard old builds # 注意&#xff1a;自动清理构建历史将在下次构建时进行

【Java 动态数据统计图】动态X轴二级数据统计图思路Demo(动态,排序,动态数组(重点推荐:难)九(131)

需求&#xff1a; 1.有一组数据集合&#xff0c;数据集合中的数据为动态&#xff1b; 举例如下&#xff1a; [{province陕西省, city西安市}, {province陕西省, city咸阳市}, {province陕西省, city宝鸡市}, {province陕西省, city延安市}, {province陕西省, city汉中市}, {pr…

【网络安全带你练爬虫-100练】第19练:使用python打开exe文件

目录 一、目标1&#xff1a;调用exe文件 二、目标2&#xff1a;调用exe打开文件 一、目标1&#xff1a;调用exe文件 1、subprocess 模块允许在 Python 中启动一个新的进程&#xff0c;并与其进行交互 2、subprocess.run() 函数来启动exe文件 3、subprocess.run(["文件路…

企业应用系统 PHP项目支持管理系统Dreamweaver开发mysql数据库web结构php编程计算机网页

一、源码特点 PHP 项目支持管理系统是一套完善的web设计系统 应用于企业项目管理&#xff0c;从企业内部的各个业务环境总体掌握&#xff0c;系统具有完整的源代码和数据库&#xff0c;系统主要采用B/S模式开发。 php项目支撑管理系统2 二、功能介绍 (1)权限管理&#xff1…

手写RPC——数据序列化工具protobuf

手写RPC——数据序列化工具protobuf Protocol Buffers&#xff08;protobuf&#xff09;是一种用于结构化数据序列化的开源库和协议。下面是 protobuf 的一些优点和缺点&#xff1a; 优点&#xff1a; 高效的序列化和反序列化&#xff1a;protobuf 使用二进制编码&#xff0c…

力扣:随即指针138. 复制带随机指针的链表

复制带随机指针的链表 OJ链接 分析&#xff1a; 该题的大致题意就是有一个带随机指针的链表&#xff0c;复制这个链表但是不能指向原链表的节点&#xff0c;所以每一个节点都要复制一遍 大神思路&#xff1a; ps:我是学来的 上代码&#xff1a; struct Node* copyRandomList(s…

TuyaOS开发学习笔记(1)——NB-IoT开发搭建环境、编译烧写(MT2625)

一、搭建环境 1.1 官方资料 TuyaOS 1.2 安装VMware 官网下载&#xff1a;https://customerconnect.vmware.com/en/downloads/info/slug/desktop_end_user_computing/vmware_workstation_pro/16_0 百度网盘&#xff1a;https://pan.baidu.com/s/1oN7H81GV0g6cD9zsydg6vg 提取…

Android 中SettingsActivity(PreferenceFragmentCompat)的简单使用

如果你需要一个简单的APP设置&#xff0c;可以使用sharedPreferences进行存储&#xff0c;我们可以借助AndroidStudio快速创建一个用于设置的Activity&#xff0c;其实它是继承PreferenceFragmentCompat&#xff0c;存储方式用的就是sharedPreferences&#xff0c;只是帮我们节…

科技资讯|微软获得AI双肩包专利,Find My防丢背包大火

根据美国商标和专利局&#xff08;USPTO&#xff09;近日公示的清单&#xff0c;微软于今年 5 月提交了一项智能双肩包专利&#xff0c;其亮点在于整合了 AI 技术&#xff0c;可以识别佩戴者周围环境、自动响应用户聊天请求、访问基于云端的信息、以及和其它设备交互。 在此附…