pytorch学习(十二)c++调用minist训练的onnx模型

在实际使用过程中,使用python速度不够快,并且不太好嵌入到c++程序中,因此可以把pytorch训练的模型转成onnx模型,然后使用opencv进行调用。

所需要用到的库有:

opencv

1.完整的程序如下

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
import numpy as np
import torchvision
from torch.utils.tensorboard import SummaryWriter
from torch.optim.lr_scheduler import LambdaLR
import os
import re
from PIL import Imagecur_pwd_path = os.getcwd()def getBestModuleFilename(browser):file_name = browser             #"tf_logs/save_module"filenames = os.listdir(file_name)pattern = r"d+"result = []for i in range(len(filenames)):rst = int(filenames[i][10:-4])result.append(rst)val = max(result)index = result.index(val)file_best = filenames[index]print(file_best)return file_besttensor = torch.randn(3,3)
bTensor = type(tensor) == torch.Tensor
print(bTensor)
print("tensor is on ", tensor.device)
#数据转到GPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)
if torch.cuda.is_available():tensor = tensor.to(device)print("tensor is on ",tensor.device)
#数据转到CPU
if tensor.device == 'cuda:0':tensor = tensor.to(torch.device("cpu"))print("tensor is on", tensor.device)
if tensor.device == "cpu":tensor = tensor.to(torch.device("cuda:0"))print("tensor is on", tensor.device)trainning_data = datasets.MNIST(root="data",train=True,transform=ToTensor(),download=True)
print(len(trainning_data))
test_data = datasets.MNIST(root="data",train=True,transform=ToTensor(),download=False)train_loader = DataLoader(trainning_data, batch_size=64,shuffle=True)
test_loader = DataLoader(test_data, batch_size=64,shuffle=True)print(len(train_loader)) #分成了多少个batch
print(len(trainning_data)) #总共多少个图像
# for x, y in train_loader:
#     print(x.shape)
#     print(y.shape)class MinistNet(nn.Module):def __init__(self):super().__init__()# self.flat = nn.Flatten()self.conv1 = nn.Conv2d(1,1,3,1,1)self.hideLayer1 = nn.Linear(28*28,256)self.hideLayer2 = nn.Linear(256,10)def forward(self,x):x= self.conv1(x)x = x.view(-1,28*28)x = self.hideLayer1(x)x = torch.sigmoid(x)x = self.hideLayer2(x)# x = nn.Sigmoid(x)return xmodel_path = "E:\\TOOLE\\slam_evo\\pythonProject\\tf_logs\\save_module\\ckpt_best_10.pth"
img_path = "E:\\TOOLE\\slam_evo\\pythonProject\\2.jpg"
img = Image.open(img_path)
test_model = MinistNet()
test_model1 = torch.load(model_path)
test_model.load_state_dict(test_model1["net"])test_model.eval()
test_model.to("cuda")transform =torchvision.transforms.Compose([
torchvision.transforms.Grayscale(),
torchvision.transforms.ToTensor()
])img = transform(img)
img = torch.unsqueeze(img, 0)
img = img.to("cuda")result = test_model(img)
result = result.to("cpu")
val,index = torch.max(result,dim=1)
print(index)model = MinistNet()
model = model.to(device)
cuda = next(model.parameters()).device
print(model)
criterion = nn.CrossEntropyLoss()
optimer = torch.optim.RMSprop(model.parameters(),lr= 0.001)scheduler_1 = LambdaLR(optimer, lr_lambda=lambda epoch: 1/(epoch+1))num_epoches =10
min_loss_val = 100000
Resume = Falsedef train():global min_loss_valstart_epoch = -1if Resume == False:start_epoch = 0else:#找到数字最大的pth文件path_checkpoint = r'tf_logs/'+"save_module"best_path_checkpoint = getBestModuleFilename(path_checkpoint)if(best_path_checkpoint == ""):returnelse:checkpointResume = torch.load(path_checkpoint)start_epoch = checkpointResume["epoch"]model.load_state_dict(checkpointResume["net"])optimer.load_state_dict(checkpointResume["optimizer"])scheduler_1.load_state_dict(checkpointResume["lr_schedule"])train_losses = []train_acces = []eval_losses = []eval_acces = []#训练model.train()tensorboard_ind =0;for epoch in range(num_epoches):batchsizeNum = 0train_loss = 0train_acc = 0train_correct = 0for x,y in train_loader:# print(epoch)# print(x.shape)# print(y.shape)x = x.to('cuda')y = y.to('cuda')bte = type(x)==torch.Tensorbte1 = type(y)==torch.TensorA = x.deviceB = y.devicepred_y = model(x)loss = criterion(pred_y,y)optimer.zero_grad()loss.backward()optimer.step()loss_val = loss.item()batchsizeNum = batchsizeNum +1train_acc += (pred_y.argmax(1) == y).type(torch.float).sum().item()train_loss += loss.item()tensorboard_ind += 1train_losses.append(train_loss / len(trainning_data))train_acces.append(train_acc / len(trainning_data))#测试test_loss_value = 0model.eval()with torch.no_grad():num_batch = len(test_data)numSize = len(test_data)test_loss, test_correct = 0,0for x,y in test_loader:x = x.to(device)y = y.to(device)pred_y = model(x)test_loss += criterion(pred_y, y).item()test_correct += (pred_y.argmax(1) == y).type(torch.float).sum().item()test_loss /= num_batchtest_correct /= numSizeeval_losses.append(test_loss)eval_acces.append(test_correct)test_loss_value = test_lossprint("test result:",100 * test_correct,"%  avg loss:",test_loss)scheduler_1.step()#设置checkpointif epoch > int(num_epoches/3) and test_loss_value < min_loss_val:min_loss_val = test_loss_valuecheckpoint = {"epoch": epoch,"net": model.state_dict(),"optimizer":optimer.state_dict(),"lr_schedule":scheduler_1.state_dict()}if not os.path.isdir(r'tf_logs/' + "save_module"):os.makedirs("tf_logs/" + "save_module")PATH = r'tf_logs/'+"save_module" + "/ckpt_best_%s.pth"%(str(epoch+1))torch.save(checkpoint, PATH)def test_singleFrame():model_path = "E:\\TOOLE\\slam_evo\\pythonProject\\tf_logs\\save_module\\ckpt_best_10.pth"img_path = "E:\\TOOLE\\slam_evo\\pythonProject\\1.jpg"img =Image.open(img_path)test_model = MinistNet()test_model = torch.load(model_path)test_model.to("cuda")transform=ToTensor()img = transform(img)img.to("cuda")result = test_model(img)val, index = torch.max(result)print(index)# Press the green button in the gutter to run the script.if __name__ == '__main__':train()#保存onnxmodel.cpu()model.eval()x= torch.randn(1,1,28,28)torch.onnx.export(model,x,"model.onnx")

2.训练并保存模型

        if epoch > int(num_epoches/3) and test_loss_value < min_loss_val:min_loss_val = test_loss_valuecheckpoint = {"epoch": epoch,"net": model.state_dict(),"optimizer":optimer.state_dict(),"lr_schedule":scheduler_1.state_dict()}if not os.path.isdir(r'tf_logs/' + "save_module"):os.makedirs("tf_logs/" + "save_module")PATH = r'tf_logs/'+"save_module" + "/ckpt_best_%s.pth"%(str(epoch+1))torch.save(checkpoint, PATH)

3.加载并测试模型

model_path = "E:\\TOOLE\\slam_evo\\pythonProject\\tf_logs\\save_module\\ckpt_best_10.pth"
img_path = "E:\\TOOLE\\slam_evo\\pythonProject\\2.jpg"
img = Image.open(img_path)
test_model = MinistNet()
test_model1 = torch.load(model_path)
test_model.load_state_dict(test_model1["net"])test_model.eval()
test_model.to("cuda")transform =torchvision.transforms.Compose([
torchvision.transforms.Grayscale(),
torchvision.transforms.ToTensor()
])img = transform(img)
img = torch.unsqueeze(img, 0)
img = img.to("cuda")result = test_model(img)
result = result.to("cpu")
val,index = torch.max(result,dim=1)
print(index)

结果如下:

按照第0,1来数数,tensor([1])刚好就是2.

4.保存onnx模型

if __name__ == '__main__':train()#保存onnxmodel.cpu()model.eval()x= torch.randn(1,1,28,28)torch.onnx.export(model,x,"model.onnx")

5.使用C++加opencv实现minist手写数字的识别

// test_onnm.cpp : 此文件包含 "main" 函数。程序执行将在此处开始并结束。
//#include<ostream>
#include<opencv2/opencv.hpp>
#include<opencv2/dnn.hpp>#include <iostream>using namespace std;
using namespace cv;
using namespace dnn;int main()
{std::cout << "Hello World!\n";//cv::dnn::Net net = cv::dnn::readTensorFromONNX();cv::dnn::Net net = cv::dnn::readNetFromONNX("E:\\TOOLE\\slam_evo\\pythonProject\\model.onnx");if (net.empty()){std::cout << "加载onnx模型失败" << std::endl;return -1;}net.setPreferableBackend(DNN_BACKEND_OPENCV);net.setPreferableTarget(DNN_TARGET_CPU);cv::Mat img = cv::imread("E:\\TOOLE\\slam_evo\\pythonProject\\1.jpg",cv::IMREAD_GRAYSCALE);if(img.cols != 28 || img.rows != 28){return -1;}cv::Mat blob;float scaleFactor = 1 / 255.0;blobFromImage(img, blob, scaleFactor, Size(), Scalar(), true, false, CV_32F);net.setInput(blob);cv::Mat predict = net.forward();for (int i = 0; i < predict.total(); i++){std::cout << predict.at<float>(i) << "  ";}std::cout << std::endl;double minVal, maxVal;Point minLoc, maxLoc;// 查找最大值和最小值及其位置minMaxLoc(predict, &minVal, &maxVal, &minLoc, &maxLoc);cout << maxVal << "    " << maxLoc.x<<"   "<< maxLoc.y << "\n";return 0;}

结果展示:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/47374.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

零基础STM32单片机编程入门(十七)SPI总线详解及RC522-NFC刷卡模块实战含源码

文章目录 一.概要二.SPI总线基本概念1.SPI总线内部框图2.总体特征3.通讯时序 三.RC522介绍1.NFC基本介绍2.RC522模块基本特点3.RC522模块原理图4.RC522模块SPI通讯时序 四.RC522模块读卡实验五.CubeMX工程源代码下载六.小结 一.概要 SPI总线是由Motorola公司提出&#xff0c;是…

05_解封装和解码

1. 基本概念 容器就是一种文件格式&#xff0c;比如flv、mkv、mp4等。包含下面5种流以及文件头信息。 流是一种视频数据信息的传输方式&#xff0c;5种流&#xff1a;音频&#xff0c;视频&#xff0c;字幕&#xff0c;附件&#xff0c;数据。 包在ffmpeg中代表已经编码好的一…

FPGA实验3:D触发器设计

一、实验目的及要求 熟悉Quartus II 的 VHDL 文本设计简单时序电路的方法&#xff1b; 掌握时序电路的描述方法、波形仿真和测试&#xff0c;特别是时钟信号的特性。 二、实验原理 运用Quartus II 集成环境下的VHDL文本设计方法设计简单时序电路——D触发器&#xff0c;依据…

ABC363 题解

ABC363 题解 A - Piling Up (模拟) 题意&#xff1a; 输入一个数字&#xff0c;数字介于 1 1 1- 99 99 99显示了一次^, 100 100 100- 199 199 199显示了^两次…增加显示的所需的最小的评分增幅 分析&#xff1a; 算比次数字大且为 100 100 100的倍数的最小值减此数字 代码…

三相PWM整流器滞环电流控制仿真matlab simulink

1、内容简介 略 88-可以交流、咨询、答疑 2、内容说明 略 三相&#xff30;&#xff37;&#xff2d;整流器已广泛应用工业与电气控制领域电流控制技术决定着三相&#xff30;&#xff37;&#xff2d;整流器系统的控制性能。综合比 较了各种电流控制方法应用较多的滞环比较…

C++ 类和对象 构造函数(下)

一 初始化列表&#xff1a; 1.1 构造函数体赋值&#xff1a; 在C中&#xff0c;构造函数用于创建对象并赋予其初始值。通常&#xff0c;我们可以在构造函数体内对成员变量进行赋值&#xff1a; class Date { public:Date(int year, int month, int day) {_year year;_mont…

golang 解压带密码的zip包

目录 Zip文件详解ZIP 文件格式主要特性常用算法Zip格式结构图总览Zip文件结构详解数据区本地文件头文件数据文件描述 中央目录记录区&#xff08;核心目录记录区 &#xff09;中央目录记录尾部区 压缩包解压过程方式1 通过解析中央目录区来解压方式2 通过读取本地文件头来解压两…

Java 环境配置——Java 语言的安装、配置、编译与运行

引言 Java 作为全球最广泛使用的编程语言之一&#xff0c;其强大的跨平台特性和丰富的生态系统&#xff0c;使其在企业级应用、移动开发、大数据处理等领域具有重要地位。正确配置 Java 开发环境是每一个 Java 开发者的必备技能。本文将详细介绍如何在不同操作系统上安装、配置…

CentOS7中的yum命令不可用,网络不可达

前言 我也搜了大量的文章&#xff0c;基本上都是 输入 vi /etc/sysconfig/network-scripts/ifcfg-ens33 (这个ens33 是上面图片对应的以太网卡的名称&#xff0c;有的可能是ifcfg-eth0) 将 ONBOOTno 改为 ONBOOTyes以及其他方法&#xff0c;但是都没用。 解决 具体原因我也…

Wpf和Winform使用devpress控件库导出Excel并调整报表样式

Wpf和Winform使用devpress控件库导出Excel并调整报表样式 背景 客户需求经常需要出各种报表&#xff0c;部分客户对报表的样式有要求。包括颜色、字体、分页等等。 代码 使用Datagridview导出excel调整样式 DevExpress.XtraGrid.Views.Grid.GridView gdv #region GridView…

2024“钉耙编程”杭电多校1006 序列立方(思维+前缀和优化dp)

来源 题目 Problem Description 给定长度为 N 的序列 a。 一个序列有很多个子序列&#xff0c;每个子序列在序列中出现了若干次。 小马想请你输出序列 a 每个非空子序列出现次数的立方值的和&#xff0c;答案对 998244353 取模。 你可以通过样例解释来辅助理解题意。 Input 第…

[言简意赅] Matlab生成FPGA端rom初始化文件.coe

&#x1f38e;Matlab生成FPGA端rom初始化文件.coe 本文主打言简意赅。 函数源码 function gencoeInitialROM(width, depth, signal, filepath)% gencoeInitialROM - 生成 Xilinx ROM 初始化格式的 COE 文件%% 输入参数:% width - ROM 数据位宽% depth - ROM 数据深度% s…

heic文件怎么转换成jpg?上百份文件转换3秒就能搞定(办公必备)

heic和jpg是两种不同的图片格式&#xff0c;平时整理图片素材时&#xff0c;如果需要将heic转为jpg格式&#xff0c;那么可以使用相关的heic图片转换工具。 ​ 为什么要将heic文件转换成jpg&#xff1f;虽然HEIC格式具有很多优点&#xff0c;但是目前并不是所有设备和应用程序…

好玩模拟游戏推荐:缺氧:眼冒金星 单机游戏分享

《缺氧》 是一款太空殖民模拟游戏。 在外太空岩深处&#xff0c;你手下的勤劳开拓者们需要熟练掌握科技&#xff0c;战胜新的陌生生命形式&#xff0c;以及利用难以置信的太空技术来生存。甚至&#xff0c;还有可能繁荣起来。 建立广阔的基地以及探索生存所需的资源&#xff1…

服务攻防_01数据库安全RedisCouchdbH2database

一、数据库-Redis-未授权RCE&CVE 1、未授权访问&#xff1a;CNVD-2015-07557 &#xff08;1&#xff09;漏洞描述 Redis默认情况下会绑定在6379端口 如果没有采取相关策略&#xff08;如添加防火墙规则阻止非信任来源IP访问&#xff09;&#xff0c;会将Redis暴露在公网…

设计模式(工厂模式,模板方法模式,单例模式)

单例模式&#xff1a; 确保一个类只有一个实例&#xff0c;并提供全局访问点。 单例模式举例&#xff1a; 配置信息类&#xff1a;用于存储应用程序的全局配置信息&#xff0c;例如数据库连接信息、日志配置等。 日志类&#xff1a;用于记录应用程序运行时的日志信息&#x…

HTML5实现好看的天气预报网站源码

文章目录 1.设计来源1.1 获取天气接口1.2 PC端页面设计1.3 手机端页面设计 2.效果和源码2.1 动态效果2.2 源代码 源码下载万套模板&#xff0c;程序开发&#xff0c;在线开发&#xff0c;在线沟通 作者&#xff1a;xcLeigh 文章地址&#xff1a;https://blog.csdn.net/weixin_4…

揭秘电子画册制作流程,打造独一无二的作品

在这个数字化的时代&#xff0c;电子画册已经成为了展示个人创意和品牌形象的重要工具。它不仅能够呈现出丰富多彩的内容&#xff0c;还能够实现互动性和传播性&#xff0c;吸引众多观众的目光。然而&#xff0c;许多人对于电子画册的制作流程仍然感到陌生。本文将揭秘电子画册…

企业VR展厅如何提升品牌形象,生动展示产品和企业文化?

一、提升产品展示效果 1、全方位展示产品细节 企业VR展厅可以通过3D建模和虚拟现实技术&#xff0c;将产品的每一个细节清晰地展示出来。客户可以全方位查看产品的外观、结构和功能。这种身临其境的体验远比传统的平面展示更加生动和详细。 细节展示&#xff1a;客户可以通过…

Ubuntu22 Qt6.6 ROS 环境搭建

Ubuntu22.04; Qt6.6; Qt Creator 13.01; ROS2 1. 安装 Qt ROS 插件 1.下载地址&#xff1a; https://github.com/ros-industrial/ros_qtc_plugin/releases 选择对应 Qt Creator 版本的安装包。 2. Qt Creator中&#xff0c;“Help - 关于插件”–>“install Plugin…