MNIST手写字符分类-卷积

MNIST手写字符分类-卷积

文章目录

  • MNIST手写字符分类-卷积
    • 1 模型构造
    • 2 训练
    • 3 推理
    • 4 导出
    • 5 onnx测试
    • 6 opencv部署
    • 7 总结

  在上一篇中,我们介绍了如何在pytorch中使用线性层+ReLU非线性层堆叠的网络进行手写字符识别的网络构建、训练、模型保存、导出和推理测试。本篇文章中,我们将要使用卷积层进行网络构建,并完成后续的训练、保存、导出,并使用opencv在C++中推理我们的模型,将结果可视化。

1 模型构造

  在pytorch中,卷积层的使用比较方便,需要注意的是卷积层的输入通道数、输出通道数、卷积核的大小等参数。这里直接放出构建的网络结构:

import torch
from torch import nn
from torch.utils.data import DataLoader
class ZKNNNet_Conv(nn.Module):def __init__(self):super(ZKNNNet_Conv, self).__init__()self.conv_stack = nn.Sequential(nn.Conv2d(1, 32, kernel_size=3),nn.ReLU(),nn.Conv2d(32, 64, kernel_size=3),nn.ReLU(),nn.MaxPool2d(kernel_size=2),nn.Flatten(),nn.Linear(12*12*64, 128),nn.ReLU(),nn.Linear(128, 10))def forward(self, x):logits = self.conv_stack(x)return logits

在这里插入图片描述

从图中可以看出,该模型先堆叠了两个卷积层与ReLU单元,经过最大池化之后,展开并进行后续的全连接层训练。

2 训练

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda, Compose
from ZKNNNet import ZKNNNet_Conv
import os
# Download training data from open datasets.
training_data = datasets.MNIST(root="data",train=True,download=True,transform=ToTensor(),
)# Download test data from open datasets.
test_data = datasets.MNIST(root="data",train=False,download=True,transform=ToTensor(),
)# Create data loaders.
train_dataloader = DataLoader(training_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)# Get cpu or gpu device for training.
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using {} device".format(device))model = ZKNNNet_Conv()
if os.path.exists("./model/model_conv.pth"):model.load_state_dict(torch.load("./model/model_conv.pth"))
model = model.to(device)
print(model)# Optimizer
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)# Loss function
loss_fn = nn.CrossEntropyLoss()# Train
def train(dataloader, model, loss_fn, optimizer):size = len(dataloader.dataset)model.train()for batch, (X, y) in enumerate(dataloader):X, y = X.to(device), y.to(device)# Compute prediction errorpred = model(X)loss = loss_fn(pred, y)# Backpropagationoptimizer.zero_grad()loss.backward()optimizer.step()if batch % 100 == 0:loss, current = loss.item(), batch * len(X)print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")# Test
def test(dataloader, model):size = len(dataloader.dataset)model.eval()test_loss, correct = 0, 0with torch.no_grad():for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model(X)test_loss += loss_fn(pred, y).item()correct += (pred.argmax(1) == y).type(torch.float).sum().item()test_loss /= sizecorrect /= sizeprint(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")return correctepochs = 200
maxAcc = 0
for t in range(epochs):print(f"Epoch {t+1}\n-------------------------------")train(train_dataloader, model, loss_fn, optimizer)currentAcc = test(test_dataloader, model)if maxAcc < currentAcc:maxAcc = currentAcctorch.save(model.state_dict(), "./model/model_conv.pth")
print("Done!")

模型的训练代码与上一篇中的线性连接训练代码是一样的。
训练过程来看,使用卷积层,在相同数据集上训练,模型收敛速度比用线性层快很多。最终精度达到97.8%。

3 推理

模型训练完成之后,推理过程与上一篇一致,这里简单放一下推理代码。

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
from torchvision import datasets
from ZKNNNet import ZKNNNet_Convimport matplotlib.pyplot as plt# Get cpu or gpu device for inference.
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using {} device for inference".format(device))# Load the trained model
model = ZKNNNet_Conv()
model.load_state_dict(torch.load("./model/model_conv.pth"))
model.to(device)
model.eval()# Download test data from open datasets.
test_data = datasets.MNIST(root="data",train=False,download=True,transform=ToTensor(),
)# Create data loader.
test_dataloader = DataLoader(test_data, batch_size=64)# Perform inference
with torch.no_grad():correct = 0total = 0for images, labels in test_dataloader:images = images.to(device)labels = labels.to(device)outputs = model(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()# Visualize the image and its predicted resultfor i in range(len(images)):image = images[i].cpu()label = labels[i].cpu()prediction = predicted[i].cpu()plt.imshow(image.squeeze(), cmap='gray')plt.title(f"Label: {label}, Predicted: {prediction}")plt.show()accuracy = 100 * correct / totalprint("Accuracy on test set: {:.2f}%".format(accuracy))

4 导出

模型导出方式与上一篇一致。

import torch
import torch.utils
import os
from ZKNNNet import ZKNNNet_3Layer,ZKNNNet_5Layer,ZKNNNet_Conv
model_conv = ZKNNNet_Conv()
if os.path.exists('./model/model_conv.pth'):model_conv.load_state_dict(torch.load('./model/model_conv.pth'))
model_conv = model_conv.to(device)
model_conv.eval()
torch.onnx.export(model_conv,torch.randn(1,1,28,28),'./model/model_conv.onnx',verbose=True)

5 onnx测试

import onnxruntime as rt
import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
from torchvision import datasetsimport matplotlib.pyplot as pltfrom PIL import Imagesess = rt.InferenceSession("model/model_conv.onnx")
input_name = sess.get_inputs()[0].name
print(input_name)image = Image.open('./data/test/2.png')
image_data = np.array(image)
image_data = image_data.astype(np.float32)/255.0
image_data = image_data[None,None,:,:]
print(image_data.shape)outputs = sess.run(None,{input_name:image_data})
outputs = np.array(outputs).flatten()prediction = np.argmax(outputs)
plt.imshow(image, cmap='gray')
plt.title(f"Predicted: {prediction}")
plt.show()# Download test data from open datasets.
test_data = datasets.MNIST(root="data",train=False,download=True,transform=ToTensor(),
)# Create data loader.
test_dataloader = DataLoader(test_data, batch_size=1)with torch.no_grad():correct = 0total = 0for images, labels in test_dataloader:images = images.numpy()labels = labels.numpy()outputs = sess.run(None,{input_name:images})[0]outputs = np.array(outputs).flatten()prediction = np.argmax(outputs)# Visualize the image and its predicted resultfor i in range(len(images)):image = images[i]label = labels[i]plt.imshow(image.squeeze(), cmap='gray')plt.title(f"Label: {label}, Predicted: {prediction}")plt.show()

至此,模型已经成功的转换成onnx模型,可以用于后续各种部署环境的部署。

6 opencv部署

本例中,使用C++/opencv来尝试部署刚才训练的模型。输入为在之前的博文中提到的将MNIST测试集导出成png图片保存。

#include "opencv2/opencv.hpp"#include <iostream>
#include <filesystem>
#include <string>
#include <vector>int main(int argc, char** argv)
{if (argc != 3){std::cerr << "Usage: MNISTClassifier_onnx_opencv <onnx_model_path> <image_path>" << std::endl;return 1;}cv::dnn::Net net = cv::dnn::readNetFromONNX(argv[1]);if (net.empty()){std::cout << "Error: Failed to load ONNX file." << std::endl;return 1;}std::filesystem::path srcPath(argv[2]);for (auto& imgPath : std::filesystem::recursive_directory_iterator(srcPath)){if(!std::filesystem::is_regular_file(imgPath))continue;const cv::Mat image = cv::imread(imgPath.path().string(), cv::IMREAD_GRAYSCALE);if (image.empty()){std::cerr << "Error: Failed to read image file." << std::endl;continue;}const cv::Size size(28, 28);cv::Mat resized_image;cv::resize(image, resized_image, size);cv::Mat float_image;resized_image.convertTo(float_image, CV_32F, 1.0 / 255.0);cv::Mat input_blob = cv::dnn::blobFromImage(float_image);net.setInput(input_blob);cv::Mat output = net.forward();cv::Point classIdPoint;double confidence;cv::minMaxLoc(output.reshape(1, 1), nullptr, &confidence, nullptr, &classIdPoint);const int class_id = classIdPoint.x;std::cout << "Class ID: " << class_id << std::endl;std::cout << "Confidence: " << confidence << std::endl;cv::Mat bigImg;cv::resize(image,bigImg,cv::Size(128,128));auto parentPath = imgPath.path().parent_path();auto label = parentPath.filename().string()+std::string("<->")+std::to_string(class_id);cv::putText(bigImg, label, cv::Point(10, 20), cv::FONT_HERSHEY_SIMPLEX, 0.5, cv::Scalar(255, 255, 255), 1);cv::imshow("img",bigImg);cv::waitKey();}return 0;
}

本部署方式需要依赖opencv dnn模块。试验中使用的是opencv4.8版本。

7 总结

使用卷积神经网络进行MNIST手写字符识别,在模型结构无明显复杂的情况下,模型收敛速度较全连接层构建的网络收敛速度快。

按照相同的套路导出成onnx模型之后,直接通过opencv可以部署,简化深度学习算法部署的难度。

本部署方式需要依赖opencv dnn模块。试验中使用的是opencv4.8版本。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/26402.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

北京Web前端薪资揭秘:从行业趋势到个人成长

北京Web前端薪资揭秘&#xff1a;从行业趋势到个人成长 在科技飞速发展的今天&#xff0c;Web前端作为互联网行业的核心岗位之一&#xff0c;其薪资水平一直备受关注。尤其是在北京这样的一线城市&#xff0c;Web前端工程师的薪资更是成为了人们热议的话题。那么&#xff0c;北…

Redis高性能原理:Redis为什么这么快?

目录 前言&#xff1a; 一、Redis知识系统观 二、Redis为什么这么快&#xff1f; 三、Redis 唯快不破的原理总结 四、Redis6.x的多线程 前言&#xff1a; Redis 为了高性能&#xff0c;从各方各面都进行了优化。学习一门技术&#xff0c;通常只接触了零散的技术点&#xff…

解决linux jenkins要求JDK版本与项目版本JDK不一致问题

背景–问题描述&#xff1a; 新入职公司&#xff0c;交接人说jenkins运行有问题&#xff0c;现在都是手动发布&#xff0c;具体原因让我自己看&#xff08;笑哭&#xff09;。我人都蒙了&#xff0c;测试环境都手动发布&#xff0c;那不是麻烦的要死&#xff01; 接手后&am…

代码随想录算法训练营第36天 [860.柠檬水找零 406.根据身高重建队列 452. 用最少数量的箭引爆气球 ]

代码随想录算法训练营第36天 [860.柠檬水找零 406.根据身高重建队列 452. 用最少数量的箭引爆气球 ] 一、860.柠檬水找零 链接: 代码随想录. 思路&#xff1a;十块只能找五块&#xff0c;二十能找十块五块和三个五块&#xff0c;优先消耗十块 做题状态&#xff1a;看解析后做出…

推荐几款短链接工具系统软件

1、C1N短网址(c1n.cn) 为了提升你的品牌并吸引新的受众&#xff0c;C1N短网址可以帮助你以最简单的方式进行科学分析、决策和促进变革。帮助您真正了解客户并促进转型&#xff0c;C1N短网址&#xff0c;它不仅是一种工具&#xff0c;也是一种专业服务。该品牌成立于2018年&…

引入tinyMCE富文本框在vue3中的使用

实现效果&#xff1a; 官网地址&#xff1a;TinyMCE 7 Documentation | TinyMCE Documentation 1.下载依赖&#xff08;我使用的版本是5.0 目前最新版本到7了&#xff09; pnpm/npm install tinymce5.0.0 -S pnpm/npm install tinymce/tinymce-vue -S 2.在public文件夹下…

数字化制造案例分享以及数字化制造能力评估(34页PPT)

资料介绍&#xff1a; 通过全面的数字化企业平台和智能制造技术的应用&#xff0c;制造型企业不仅提升了自身的竞争力&#xff0c;也为整个制造业的数字化转型提供了借鉴。同时&#xff0c;数字化制造能力的评估是企业实现数字化转型的关键环节&#xff0c;需要从技术变革、组…

Javaweb05-会话技术(cookie,session)

会话及会话技术 **概念&#xff1a;**在web开发中&#xff0c;服务器跟踪用户的技术为会话技术 Cookie对象 1.Cookie的工作流程 cookie可以将会话中的数据保存在浏览器中&#xff0c;通过在响应中添加Set-Cookie头字段将数据保存在自身的缓存中去cookie由浏览器创建cookie在…

pydantic 生成 json-schema,导入yapi

python 去除allOf def replace_anyof(data):if isinstance(data, dict):if "allOf" in data:data.update(data.pop("allOf")[0])for _, v in data.items():replace_anyof(v)elif isinstance(data, list):for v in data:replace_anyof(v)else:returnschema…

Axios 二次封装详解

Axios 二次封装详解 在前端开发中&#xff0c;经常需要对网络请求进行一些定制化的处理。以下是关于一个 Axios 二次封装的详细介绍。 我一般把以下代码放到utils目录下创建request.js文件&#xff0c;以下是部分关键代码示例&#xff1a; //引用axios实例 import axios fro…

【学习笔记】C++每日一记[20240612]

给定两个有序的数组&#xff0c;计算两者的交集 给定两个有序整型数组&#xff0c;数组中 的元素是递增的&#xff0c;且各数组中没有重复元素。 第一时间解法&#xff1a;通过一个循环扫描array_1中的每一个元素&#xff0c;然后利用该元素去比较array_2中的每一个元素&…

采用沙普利值(Shapley value)实现了数据供给方报酬分配的公平性.

目录 采用沙普利值(Shapley value)实现了数据供给方报酬分配的公平性. 采用沙普利值(Shapley value)实现了数据供给方报酬分配的公平性. 采用沙普利值(Shapley value)实现数据供给方报酬分配的公平性,在交易模型中考虑参与个体的异质性与隐私保护,主要体现在以下几个方面…

LeetCode-day11-2813. 子序列最大优雅度

LeetCode-day11-2813. 子序列最大优雅度 题目描述示例示例1&#xff1a;示例2&#xff1a;示例3&#xff1a; 思路代码 题目描述 给你一个长度为 n 的二维整数数组 items 和一个整数 k 。 items[i] [profiti, categoryi]&#xff0c;其中 profiti 和 categoryi 分别表示第 i…

<Python><opencv><TesseractOCR>基于python和opencv,使用ocr识别图片中的文本并进行替换

前言 本文是在python中,利用opencv处理图片,利用tesseractOCR来识别图片中的文本并进行替换的一种实现方法。 环境配置 系统:windows 平台:visual studio code 语言:python 库:pyqt5、opencv、tesseractOCR 代码介绍 本文程序功能实现,主要依赖于tesseractOCR这个库,…

遥控器无法点击AOSP Settings 的管理存储按钮 MANAGE STORAGE

前言 这里是遇到了MANAGE STORAGE的按钮使用遥控器移动的时候无法聚焦到这个按钮&#xff0c;自然也就无法点击。它只能聚焦到这一整个整体&#xff0c;因此我就设置当点击到这一整个整体时&#xff0c;就相应MANAGE STORAGE按钮的点击事件。 图片 代码 packages/apps/Setti…

探索在线问诊系统的安全性与隐私保护

随着远程医疗的普及&#xff0c;在线问诊系统成为医疗服务的重要组成部分。然而&#xff0c;随着医疗数据的在线传输和存储&#xff0c;患者的隐私保护和数据安全面临巨大挑战。本文将探讨在线问诊系统的安全性与隐私保护&#xff0c;介绍常见的安全措施和技术实现&#xff0c;…

图片转Excel表格:提升数据处理效率的利器

在日常工作和生活中&#xff0c;我们经常遇到各种数据和信息以图片的形式存在。有时&#xff0c;这些数据图片中包含了重要的表格信息&#xff0c;例如财务报告、统计数据或调研结果。为了对这些数据进行进一步的分析和处理&#xff0c;我们需要将其转换为可编辑的电子表格格式…

利用python进行批量TIF转NC并进行像元尺度的MK检验

批量TIF转NC并进行MK检验 这里主要记录一个批量进行tif文件转nc,并且将长序列数据进行mk检验的python代码。有问题随时联系:jia5678912。 import os import numpy as np import xarray as xr from osgeo import gdal, osrdef Search_File(dirname,suffix):This function ca…

node 版本控制

官网下载 nvm 包 查看node和npm版本&#xff1a;https://github.com/coreybutler/nvm-windows/releases 2、查看nvm是否安装成功 nvm3、基本使用 1、查看当前node可用版本 nvm ls2、查看当前使用的node版本 nvm current3、安装指定node版本 nvm install 19.9.04、切换版…

基于springboot的人力资源管理系统源码数据库

传统信息的管理大部分依赖于管理人员的手工登记与管理&#xff0c;然而&#xff0c;随着近些年信息技术的迅猛发展&#xff0c;让许多比较老套的信息管理模式进行了更新迭代&#xff0c;员工信息因为其管理内容繁杂&#xff0c;管理数量繁多导致手工进行处理不能满足广大用户的…