MNIST手写字符分类-卷积

MNIST手写字符分类-卷积

文章目录

  • MNIST手写字符分类-卷积
    • 1 模型构造
    • 2 训练
    • 3 推理
    • 4 导出
    • 5 onnx测试
    • 6 opencv部署
    • 7 总结

  在上一篇中,我们介绍了如何在pytorch中使用线性层+ReLU非线性层堆叠的网络进行手写字符识别的网络构建、训练、模型保存、导出和推理测试。本篇文章中,我们将要使用卷积层进行网络构建,并完成后续的训练、保存、导出,并使用opencv在C++中推理我们的模型,将结果可视化。

1 模型构造

  在pytorch中,卷积层的使用比较方便,需要注意的是卷积层的输入通道数、输出通道数、卷积核的大小等参数。这里直接放出构建的网络结构:

import torch
from torch import nn
from torch.utils.data import DataLoader
class ZKNNNet_Conv(nn.Module):def __init__(self):super(ZKNNNet_Conv, self).__init__()self.conv_stack = nn.Sequential(nn.Conv2d(1, 32, kernel_size=3),nn.ReLU(),nn.Conv2d(32, 64, kernel_size=3),nn.ReLU(),nn.MaxPool2d(kernel_size=2),nn.Flatten(),nn.Linear(12*12*64, 128),nn.ReLU(),nn.Linear(128, 10))def forward(self, x):logits = self.conv_stack(x)return logits

在这里插入图片描述

从图中可以看出,该模型先堆叠了两个卷积层与ReLU单元,经过最大池化之后,展开并进行后续的全连接层训练。

2 训练

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda, Compose
from ZKNNNet import ZKNNNet_Conv
import os
# Download training data from open datasets.
training_data = datasets.MNIST(root="data",train=True,download=True,transform=ToTensor(),
)# Download test data from open datasets.
test_data = datasets.MNIST(root="data",train=False,download=True,transform=ToTensor(),
)# Create data loaders.
train_dataloader = DataLoader(training_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)# Get cpu or gpu device for training.
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using {} device".format(device))model = ZKNNNet_Conv()
if os.path.exists("./model/model_conv.pth"):model.load_state_dict(torch.load("./model/model_conv.pth"))
model = model.to(device)
print(model)# Optimizer
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)# Loss function
loss_fn = nn.CrossEntropyLoss()# Train
def train(dataloader, model, loss_fn, optimizer):size = len(dataloader.dataset)model.train()for batch, (X, y) in enumerate(dataloader):X, y = X.to(device), y.to(device)# Compute prediction errorpred = model(X)loss = loss_fn(pred, y)# Backpropagationoptimizer.zero_grad()loss.backward()optimizer.step()if batch % 100 == 0:loss, current = loss.item(), batch * len(X)print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")# Test
def test(dataloader, model):size = len(dataloader.dataset)model.eval()test_loss, correct = 0, 0with torch.no_grad():for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model(X)test_loss += loss_fn(pred, y).item()correct += (pred.argmax(1) == y).type(torch.float).sum().item()test_loss /= sizecorrect /= sizeprint(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")return correctepochs = 200
maxAcc = 0
for t in range(epochs):print(f"Epoch {t+1}\n-------------------------------")train(train_dataloader, model, loss_fn, optimizer)currentAcc = test(test_dataloader, model)if maxAcc < currentAcc:maxAcc = currentAcctorch.save(model.state_dict(), "./model/model_conv.pth")
print("Done!")

模型的训练代码与上一篇中的线性连接训练代码是一样的。
训练过程来看,使用卷积层,在相同数据集上训练,模型收敛速度比用线性层快很多。最终精度达到97.8%。

3 推理

模型训练完成之后,推理过程与上一篇一致,这里简单放一下推理代码。

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
from torchvision import datasets
from ZKNNNet import ZKNNNet_Convimport matplotlib.pyplot as plt# Get cpu or gpu device for inference.
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using {} device for inference".format(device))# Load the trained model
model = ZKNNNet_Conv()
model.load_state_dict(torch.load("./model/model_conv.pth"))
model.to(device)
model.eval()# Download test data from open datasets.
test_data = datasets.MNIST(root="data",train=False,download=True,transform=ToTensor(),
)# Create data loader.
test_dataloader = DataLoader(test_data, batch_size=64)# Perform inference
with torch.no_grad():correct = 0total = 0for images, labels in test_dataloader:images = images.to(device)labels = labels.to(device)outputs = model(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()# Visualize the image and its predicted resultfor i in range(len(images)):image = images[i].cpu()label = labels[i].cpu()prediction = predicted[i].cpu()plt.imshow(image.squeeze(), cmap='gray')plt.title(f"Label: {label}, Predicted: {prediction}")plt.show()accuracy = 100 * correct / totalprint("Accuracy on test set: {:.2f}%".format(accuracy))

4 导出

模型导出方式与上一篇一致。

import torch
import torch.utils
import os
from ZKNNNet import ZKNNNet_3Layer,ZKNNNet_5Layer,ZKNNNet_Conv
model_conv = ZKNNNet_Conv()
if os.path.exists('./model/model_conv.pth'):model_conv.load_state_dict(torch.load('./model/model_conv.pth'))
model_conv = model_conv.to(device)
model_conv.eval()
torch.onnx.export(model_conv,torch.randn(1,1,28,28),'./model/model_conv.onnx',verbose=True)

5 onnx测试

import onnxruntime as rt
import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
from torchvision import datasetsimport matplotlib.pyplot as pltfrom PIL import Imagesess = rt.InferenceSession("model/model_conv.onnx")
input_name = sess.get_inputs()[0].name
print(input_name)image = Image.open('./data/test/2.png')
image_data = np.array(image)
image_data = image_data.astype(np.float32)/255.0
image_data = image_data[None,None,:,:]
print(image_data.shape)outputs = sess.run(None,{input_name:image_data})
outputs = np.array(outputs).flatten()prediction = np.argmax(outputs)
plt.imshow(image, cmap='gray')
plt.title(f"Predicted: {prediction}")
plt.show()# Download test data from open datasets.
test_data = datasets.MNIST(root="data",train=False,download=True,transform=ToTensor(),
)# Create data loader.
test_dataloader = DataLoader(test_data, batch_size=1)with torch.no_grad():correct = 0total = 0for images, labels in test_dataloader:images = images.numpy()labels = labels.numpy()outputs = sess.run(None,{input_name:images})[0]outputs = np.array(outputs).flatten()prediction = np.argmax(outputs)# Visualize the image and its predicted resultfor i in range(len(images)):image = images[i]label = labels[i]plt.imshow(image.squeeze(), cmap='gray')plt.title(f"Label: {label}, Predicted: {prediction}")plt.show()

至此,模型已经成功的转换成onnx模型,可以用于后续各种部署环境的部署。

6 opencv部署

本例中,使用C++/opencv来尝试部署刚才训练的模型。输入为在之前的博文中提到的将MNIST测试集导出成png图片保存。

#include "opencv2/opencv.hpp"#include <iostream>
#include <filesystem>
#include <string>
#include <vector>int main(int argc, char** argv)
{if (argc != 3){std::cerr << "Usage: MNISTClassifier_onnx_opencv <onnx_model_path> <image_path>" << std::endl;return 1;}cv::dnn::Net net = cv::dnn::readNetFromONNX(argv[1]);if (net.empty()){std::cout << "Error: Failed to load ONNX file." << std::endl;return 1;}std::filesystem::path srcPath(argv[2]);for (auto& imgPath : std::filesystem::recursive_directory_iterator(srcPath)){if(!std::filesystem::is_regular_file(imgPath))continue;const cv::Mat image = cv::imread(imgPath.path().string(), cv::IMREAD_GRAYSCALE);if (image.empty()){std::cerr << "Error: Failed to read image file." << std::endl;continue;}const cv::Size size(28, 28);cv::Mat resized_image;cv::resize(image, resized_image, size);cv::Mat float_image;resized_image.convertTo(float_image, CV_32F, 1.0 / 255.0);cv::Mat input_blob = cv::dnn::blobFromImage(float_image);net.setInput(input_blob);cv::Mat output = net.forward();cv::Point classIdPoint;double confidence;cv::minMaxLoc(output.reshape(1, 1), nullptr, &confidence, nullptr, &classIdPoint);const int class_id = classIdPoint.x;std::cout << "Class ID: " << class_id << std::endl;std::cout << "Confidence: " << confidence << std::endl;cv::Mat bigImg;cv::resize(image,bigImg,cv::Size(128,128));auto parentPath = imgPath.path().parent_path();auto label = parentPath.filename().string()+std::string("<->")+std::to_string(class_id);cv::putText(bigImg, label, cv::Point(10, 20), cv::FONT_HERSHEY_SIMPLEX, 0.5, cv::Scalar(255, 255, 255), 1);cv::imshow("img",bigImg);cv::waitKey();}return 0;
}

本部署方式需要依赖opencv dnn模块。试验中使用的是opencv4.8版本。

7 总结

使用卷积神经网络进行MNIST手写字符识别,在模型结构无明显复杂的情况下,模型收敛速度较全连接层构建的网络收敛速度快。

按照相同的套路导出成onnx模型之后,直接通过opencv可以部署,简化深度学习算法部署的难度。

本部署方式需要依赖opencv dnn模块。试验中使用的是opencv4.8版本。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/26402.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Redis高性能原理:Redis为什么这么快?

目录 前言&#xff1a; 一、Redis知识系统观 二、Redis为什么这么快&#xff1f; 三、Redis 唯快不破的原理总结 四、Redis6.x的多线程 前言&#xff1a; Redis 为了高性能&#xff0c;从各方各面都进行了优化。学习一门技术&#xff0c;通常只接触了零散的技术点&#xff…

解决linux jenkins要求JDK版本与项目版本JDK不一致问题

背景–问题描述&#xff1a; 新入职公司&#xff0c;交接人说jenkins运行有问题&#xff0c;现在都是手动发布&#xff0c;具体原因让我自己看&#xff08;笑哭&#xff09;。我人都蒙了&#xff0c;测试环境都手动发布&#xff0c;那不是麻烦的要死&#xff01; 接手后&am…

推荐几款短链接工具系统软件

1、C1N短网址(c1n.cn) 为了提升你的品牌并吸引新的受众&#xff0c;C1N短网址可以帮助你以最简单的方式进行科学分析、决策和促进变革。帮助您真正了解客户并促进转型&#xff0c;C1N短网址&#xff0c;它不仅是一种工具&#xff0c;也是一种专业服务。该品牌成立于2018年&…

引入tinyMCE富文本框在vue3中的使用

实现效果&#xff1a; 官网地址&#xff1a;TinyMCE 7 Documentation | TinyMCE Documentation 1.下载依赖&#xff08;我使用的版本是5.0 目前最新版本到7了&#xff09; pnpm/npm install tinymce5.0.0 -S pnpm/npm install tinymce/tinymce-vue -S 2.在public文件夹下…

数字化制造案例分享以及数字化制造能力评估(34页PPT)

资料介绍&#xff1a; 通过全面的数字化企业平台和智能制造技术的应用&#xff0c;制造型企业不仅提升了自身的竞争力&#xff0c;也为整个制造业的数字化转型提供了借鉴。同时&#xff0c;数字化制造能力的评估是企业实现数字化转型的关键环节&#xff0c;需要从技术变革、组…

Javaweb05-会话技术(cookie,session)

会话及会话技术 **概念&#xff1a;**在web开发中&#xff0c;服务器跟踪用户的技术为会话技术 Cookie对象 1.Cookie的工作流程 cookie可以将会话中的数据保存在浏览器中&#xff0c;通过在响应中添加Set-Cookie头字段将数据保存在自身的缓存中去cookie由浏览器创建cookie在…

【学习笔记】C++每日一记[20240612]

给定两个有序的数组&#xff0c;计算两者的交集 给定两个有序整型数组&#xff0c;数组中 的元素是递增的&#xff0c;且各数组中没有重复元素。 第一时间解法&#xff1a;通过一个循环扫描array_1中的每一个元素&#xff0c;然后利用该元素去比较array_2中的每一个元素&…

采用沙普利值(Shapley value)实现了数据供给方报酬分配的公平性.

目录 采用沙普利值(Shapley value)实现了数据供给方报酬分配的公平性. 采用沙普利值(Shapley value)实现了数据供给方报酬分配的公平性. 采用沙普利值(Shapley value)实现数据供给方报酬分配的公平性,在交易模型中考虑参与个体的异质性与隐私保护,主要体现在以下几个方面…

遥控器无法点击AOSP Settings 的管理存储按钮 MANAGE STORAGE

前言 这里是遇到了MANAGE STORAGE的按钮使用遥控器移动的时候无法聚焦到这个按钮&#xff0c;自然也就无法点击。它只能聚焦到这一整个整体&#xff0c;因此我就设置当点击到这一整个整体时&#xff0c;就相应MANAGE STORAGE按钮的点击事件。 图片 代码 packages/apps/Setti…

探索在线问诊系统的安全性与隐私保护

随着远程医疗的普及&#xff0c;在线问诊系统成为医疗服务的重要组成部分。然而&#xff0c;随着医疗数据的在线传输和存储&#xff0c;患者的隐私保护和数据安全面临巨大挑战。本文将探讨在线问诊系统的安全性与隐私保护&#xff0c;介绍常见的安全措施和技术实现&#xff0c;…

图片转Excel表格:提升数据处理效率的利器

在日常工作和生活中&#xff0c;我们经常遇到各种数据和信息以图片的形式存在。有时&#xff0c;这些数据图片中包含了重要的表格信息&#xff0c;例如财务报告、统计数据或调研结果。为了对这些数据进行进一步的分析和处理&#xff0c;我们需要将其转换为可编辑的电子表格格式…

node 版本控制

官网下载 nvm 包 查看node和npm版本&#xff1a;https://github.com/coreybutler/nvm-windows/releases 2、查看nvm是否安装成功 nvm3、基本使用 1、查看当前node可用版本 nvm ls2、查看当前使用的node版本 nvm current3、安装指定node版本 nvm install 19.9.04、切换版…

构建 deno/fresh 的 docker 镜像

众所周知, 最近 docker 镜像的使用又出现了新的困难. 但是不怕, 窝们可以使用曲线救国的方法: 自己制作容器镜像 ! 下面以 deno/fresh 举栗, 部署一个简单的应用. 目录 1 创建 deno/fresh 项目2 构建 docker 镜像3 部署和测试4 总结与展望 1 创建 deno/fresh 项目 执行命令…

LeetCode 算法: 旋转图像c++

原题链接&#x1f517;&#xff1a; 旋转图像 难度&#xff1a;中等⭐️⭐️ 题目 给定一个 n n 的二维矩阵 matrix 表示一个图像。请你将图像顺时针旋转 90 度。 你必须在 原地 旋转图像&#xff0c;这意味着你需要直接修改输入的二维矩阵。请不要 使用另一个矩阵来旋转图…

5.冒泡+选择+插入+希尔

一、排序算法 排序基础 1.排序算法的稳定性 2.何为原地排序算法 1.冒泡排序 从下面第一个元素开始往上冒泡&#xff0c;一轮冒泡下来&#xff0c;最大的元素就去了最上面了 步骤&#xff1a;无序数组 每次冒泡都可以将最大的元素放到最右边去 第一轮比较了5次&#xff1a;…

开箱机特点与操作因素:深入剖析影响效率的关键因素

在现代化物流和生产流程中&#xff0c;开箱机作为一种自动化、高效率的设备&#xff0c;正逐渐成为企业提升工作效率、降低人工成本的得力助手。然而&#xff0c;要想充分发挥开箱机的性能优势&#xff0c;就必须深入了解其特点与操作因素&#xff0c;并准确把握影响效率的关键…

Navicat for MySQL 11软件下载附加详细安装教程

根据使用者情况表明Navicat Premium 能使你快速地在各种数据库系统间传输数据&#xff0c;或传输到一份指定 SQL 格式和编码的纯文本文件&#xff0c;计划不同数据库的批处理作业并在指定的时间运行&#xff0c;其他功能包括导入向导、导出向导、查询创建工具、报表创建工具、数…

Idea | Idea提交.properties文件乱码问题

这里 Transparent natice-to-ascii conversion 自动转换ASCII码 千万别勾选

第 5 章:面向生产的 Spring Boot

在 4.1.2 节中&#xff0c;我们介绍了 Spring Boot 的四大核心组成部分&#xff0c;第 4 章主要介绍了其中的起步依赖与自动配置&#xff0c;本章将重点介绍 Spring Boot Actuator&#xff0c;包括如何通过 Actuator 提供的各种端点&#xff08;endpoint&#xff09;了解系统的…

优雅迷人的小程序 UI 风格

优雅迷人的小程序 UI 风格