pytorch学习(十二)c++调用minist训练的onnx模型

在实际使用过程中,使用python速度不够快,并且不太好嵌入到c++程序中,因此可以把pytorch训练的模型转成onnx模型,然后使用opencv进行调用。

所需要用到的库有:

opencv

1.完整的程序如下

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
import numpy as np
import torchvision
from torch.utils.tensorboard import SummaryWriter
from torch.optim.lr_scheduler import LambdaLR
import os
import re
from PIL import Imagecur_pwd_path = os.getcwd()def getBestModuleFilename(browser):file_name = browser             #"tf_logs/save_module"filenames = os.listdir(file_name)pattern = r"d+"result = []for i in range(len(filenames)):rst = int(filenames[i][10:-4])result.append(rst)val = max(result)index = result.index(val)file_best = filenames[index]print(file_best)return file_besttensor = torch.randn(3,3)
bTensor = type(tensor) == torch.Tensor
print(bTensor)
print("tensor is on ", tensor.device)
#数据转到GPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)
if torch.cuda.is_available():tensor = tensor.to(device)print("tensor is on ",tensor.device)
#数据转到CPU
if tensor.device == 'cuda:0':tensor = tensor.to(torch.device("cpu"))print("tensor is on", tensor.device)
if tensor.device == "cpu":tensor = tensor.to(torch.device("cuda:0"))print("tensor is on", tensor.device)trainning_data = datasets.MNIST(root="data",train=True,transform=ToTensor(),download=True)
print(len(trainning_data))
test_data = datasets.MNIST(root="data",train=True,transform=ToTensor(),download=False)train_loader = DataLoader(trainning_data, batch_size=64,shuffle=True)
test_loader = DataLoader(test_data, batch_size=64,shuffle=True)print(len(train_loader)) #分成了多少个batch
print(len(trainning_data)) #总共多少个图像
# for x, y in train_loader:
#     print(x.shape)
#     print(y.shape)class MinistNet(nn.Module):def __init__(self):super().__init__()# self.flat = nn.Flatten()self.conv1 = nn.Conv2d(1,1,3,1,1)self.hideLayer1 = nn.Linear(28*28,256)self.hideLayer2 = nn.Linear(256,10)def forward(self,x):x= self.conv1(x)x = x.view(-1,28*28)x = self.hideLayer1(x)x = torch.sigmoid(x)x = self.hideLayer2(x)# x = nn.Sigmoid(x)return xmodel_path = "E:\\TOOLE\\slam_evo\\pythonProject\\tf_logs\\save_module\\ckpt_best_10.pth"
img_path = "E:\\TOOLE\\slam_evo\\pythonProject\\2.jpg"
img = Image.open(img_path)
test_model = MinistNet()
test_model1 = torch.load(model_path)
test_model.load_state_dict(test_model1["net"])test_model.eval()
test_model.to("cuda")transform =torchvision.transforms.Compose([
torchvision.transforms.Grayscale(),
torchvision.transforms.ToTensor()
])img = transform(img)
img = torch.unsqueeze(img, 0)
img = img.to("cuda")result = test_model(img)
result = result.to("cpu")
val,index = torch.max(result,dim=1)
print(index)model = MinistNet()
model = model.to(device)
cuda = next(model.parameters()).device
print(model)
criterion = nn.CrossEntropyLoss()
optimer = torch.optim.RMSprop(model.parameters(),lr= 0.001)scheduler_1 = LambdaLR(optimer, lr_lambda=lambda epoch: 1/(epoch+1))num_epoches =10
min_loss_val = 100000
Resume = Falsedef train():global min_loss_valstart_epoch = -1if Resume == False:start_epoch = 0else:#找到数字最大的pth文件path_checkpoint = r'tf_logs/'+"save_module"best_path_checkpoint = getBestModuleFilename(path_checkpoint)if(best_path_checkpoint == ""):returnelse:checkpointResume = torch.load(path_checkpoint)start_epoch = checkpointResume["epoch"]model.load_state_dict(checkpointResume["net"])optimer.load_state_dict(checkpointResume["optimizer"])scheduler_1.load_state_dict(checkpointResume["lr_schedule"])train_losses = []train_acces = []eval_losses = []eval_acces = []#训练model.train()tensorboard_ind =0;for epoch in range(num_epoches):batchsizeNum = 0train_loss = 0train_acc = 0train_correct = 0for x,y in train_loader:# print(epoch)# print(x.shape)# print(y.shape)x = x.to('cuda')y = y.to('cuda')bte = type(x)==torch.Tensorbte1 = type(y)==torch.TensorA = x.deviceB = y.devicepred_y = model(x)loss = criterion(pred_y,y)optimer.zero_grad()loss.backward()optimer.step()loss_val = loss.item()batchsizeNum = batchsizeNum +1train_acc += (pred_y.argmax(1) == y).type(torch.float).sum().item()train_loss += loss.item()tensorboard_ind += 1train_losses.append(train_loss / len(trainning_data))train_acces.append(train_acc / len(trainning_data))#测试test_loss_value = 0model.eval()with torch.no_grad():num_batch = len(test_data)numSize = len(test_data)test_loss, test_correct = 0,0for x,y in test_loader:x = x.to(device)y = y.to(device)pred_y = model(x)test_loss += criterion(pred_y, y).item()test_correct += (pred_y.argmax(1) == y).type(torch.float).sum().item()test_loss /= num_batchtest_correct /= numSizeeval_losses.append(test_loss)eval_acces.append(test_correct)test_loss_value = test_lossprint("test result:",100 * test_correct,"%  avg loss:",test_loss)scheduler_1.step()#设置checkpointif epoch > int(num_epoches/3) and test_loss_value < min_loss_val:min_loss_val = test_loss_valuecheckpoint = {"epoch": epoch,"net": model.state_dict(),"optimizer":optimer.state_dict(),"lr_schedule":scheduler_1.state_dict()}if not os.path.isdir(r'tf_logs/' + "save_module"):os.makedirs("tf_logs/" + "save_module")PATH = r'tf_logs/'+"save_module" + "/ckpt_best_%s.pth"%(str(epoch+1))torch.save(checkpoint, PATH)def test_singleFrame():model_path = "E:\\TOOLE\\slam_evo\\pythonProject\\tf_logs\\save_module\\ckpt_best_10.pth"img_path = "E:\\TOOLE\\slam_evo\\pythonProject\\1.jpg"img =Image.open(img_path)test_model = MinistNet()test_model = torch.load(model_path)test_model.to("cuda")transform=ToTensor()img = transform(img)img.to("cuda")result = test_model(img)val, index = torch.max(result)print(index)# Press the green button in the gutter to run the script.if __name__ == '__main__':train()#保存onnxmodel.cpu()model.eval()x= torch.randn(1,1,28,28)torch.onnx.export(model,x,"model.onnx")

2.训练并保存模型

        if epoch > int(num_epoches/3) and test_loss_value < min_loss_val:min_loss_val = test_loss_valuecheckpoint = {"epoch": epoch,"net": model.state_dict(),"optimizer":optimer.state_dict(),"lr_schedule":scheduler_1.state_dict()}if not os.path.isdir(r'tf_logs/' + "save_module"):os.makedirs("tf_logs/" + "save_module")PATH = r'tf_logs/'+"save_module" + "/ckpt_best_%s.pth"%(str(epoch+1))torch.save(checkpoint, PATH)

3.加载并测试模型

model_path = "E:\\TOOLE\\slam_evo\\pythonProject\\tf_logs\\save_module\\ckpt_best_10.pth"
img_path = "E:\\TOOLE\\slam_evo\\pythonProject\\2.jpg"
img = Image.open(img_path)
test_model = MinistNet()
test_model1 = torch.load(model_path)
test_model.load_state_dict(test_model1["net"])test_model.eval()
test_model.to("cuda")transform =torchvision.transforms.Compose([
torchvision.transforms.Grayscale(),
torchvision.transforms.ToTensor()
])img = transform(img)
img = torch.unsqueeze(img, 0)
img = img.to("cuda")result = test_model(img)
result = result.to("cpu")
val,index = torch.max(result,dim=1)
print(index)

结果如下:

按照第0,1来数数,tensor([1])刚好就是2.

4.保存onnx模型

if __name__ == '__main__':train()#保存onnxmodel.cpu()model.eval()x= torch.randn(1,1,28,28)torch.onnx.export(model,x,"model.onnx")

5.使用C++加opencv实现minist手写数字的识别

// test_onnm.cpp : 此文件包含 "main" 函数。程序执行将在此处开始并结束。
//#include<ostream>
#include<opencv2/opencv.hpp>
#include<opencv2/dnn.hpp>#include <iostream>using namespace std;
using namespace cv;
using namespace dnn;int main()
{std::cout << "Hello World!\n";//cv::dnn::Net net = cv::dnn::readTensorFromONNX();cv::dnn::Net net = cv::dnn::readNetFromONNX("E:\\TOOLE\\slam_evo\\pythonProject\\model.onnx");if (net.empty()){std::cout << "加载onnx模型失败" << std::endl;return -1;}net.setPreferableBackend(DNN_BACKEND_OPENCV);net.setPreferableTarget(DNN_TARGET_CPU);cv::Mat img = cv::imread("E:\\TOOLE\\slam_evo\\pythonProject\\1.jpg",cv::IMREAD_GRAYSCALE);if(img.cols != 28 || img.rows != 28){return -1;}cv::Mat blob;float scaleFactor = 1 / 255.0;blobFromImage(img, blob, scaleFactor, Size(), Scalar(), true, false, CV_32F);net.setInput(blob);cv::Mat predict = net.forward();for (int i = 0; i < predict.total(); i++){std::cout << predict.at<float>(i) << "  ";}std::cout << std::endl;double minVal, maxVal;Point minLoc, maxLoc;// 查找最大值和最小值及其位置minMaxLoc(predict, &minVal, &maxVal, &minLoc, &maxLoc);cout << maxVal << "    " << maxLoc.x<<"   "<< maxLoc.y << "\n";return 0;}

结果展示:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/47374.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

零基础STM32单片机编程入门(十七)SPI总线详解及RC522-NFC刷卡模块实战含源码

文章目录 一.概要二.SPI总线基本概念1.SPI总线内部框图2.总体特征3.通讯时序 三.RC522介绍1.NFC基本介绍2.RC522模块基本特点3.RC522模块原理图4.RC522模块SPI通讯时序 四.RC522模块读卡实验五.CubeMX工程源代码下载六.小结 一.概要 SPI总线是由Motorola公司提出&#xff0c;是…

05_解封装和解码

1. 基本概念 容器就是一种文件格式&#xff0c;比如flv、mkv、mp4等。包含下面5种流以及文件头信息。 流是一种视频数据信息的传输方式&#xff0c;5种流&#xff1a;音频&#xff0c;视频&#xff0c;字幕&#xff0c;附件&#xff0c;数据。 包在ffmpeg中代表已经编码好的一…

FPGA实验3:D触发器设计

一、实验目的及要求 熟悉Quartus II 的 VHDL 文本设计简单时序电路的方法&#xff1b; 掌握时序电路的描述方法、波形仿真和测试&#xff0c;特别是时钟信号的特性。 二、实验原理 运用Quartus II 集成环境下的VHDL文本设计方法设计简单时序电路——D触发器&#xff0c;依据…

三相PWM整流器滞环电流控制仿真matlab simulink

1、内容简介 略 88-可以交流、咨询、答疑 2、内容说明 略 三相&#xff30;&#xff37;&#xff2d;整流器已广泛应用工业与电气控制领域电流控制技术决定着三相&#xff30;&#xff37;&#xff2d;整流器系统的控制性能。综合比 较了各种电流控制方法应用较多的滞环比较…

C++ 类和对象 构造函数(下)

一 初始化列表&#xff1a; 1.1 构造函数体赋值&#xff1a; 在C中&#xff0c;构造函数用于创建对象并赋予其初始值。通常&#xff0c;我们可以在构造函数体内对成员变量进行赋值&#xff1a; class Date { public:Date(int year, int month, int day) {_year year;_mont…

golang 解压带密码的zip包

目录 Zip文件详解ZIP 文件格式主要特性常用算法Zip格式结构图总览Zip文件结构详解数据区本地文件头文件数据文件描述 中央目录记录区&#xff08;核心目录记录区 &#xff09;中央目录记录尾部区 压缩包解压过程方式1 通过解析中央目录区来解压方式2 通过读取本地文件头来解压两…

[言简意赅] Matlab生成FPGA端rom初始化文件.coe

&#x1f38e;Matlab生成FPGA端rom初始化文件.coe 本文主打言简意赅。 函数源码 function gencoeInitialROM(width, depth, signal, filepath)% gencoeInitialROM - 生成 Xilinx ROM 初始化格式的 COE 文件%% 输入参数:% width - ROM 数据位宽% depth - ROM 数据深度% s…

heic文件怎么转换成jpg?上百份文件转换3秒就能搞定(办公必备)

heic和jpg是两种不同的图片格式&#xff0c;平时整理图片素材时&#xff0c;如果需要将heic转为jpg格式&#xff0c;那么可以使用相关的heic图片转换工具。 ​ 为什么要将heic文件转换成jpg&#xff1f;虽然HEIC格式具有很多优点&#xff0c;但是目前并不是所有设备和应用程序…

好玩模拟游戏推荐:缺氧:眼冒金星 单机游戏分享

《缺氧》 是一款太空殖民模拟游戏。 在外太空岩深处&#xff0c;你手下的勤劳开拓者们需要熟练掌握科技&#xff0c;战胜新的陌生生命形式&#xff0c;以及利用难以置信的太空技术来生存。甚至&#xff0c;还有可能繁荣起来。 建立广阔的基地以及探索生存所需的资源&#xff1…

服务攻防_01数据库安全RedisCouchdbH2database

一、数据库-Redis-未授权RCE&CVE 1、未授权访问&#xff1a;CNVD-2015-07557 &#xff08;1&#xff09;漏洞描述 Redis默认情况下会绑定在6379端口 如果没有采取相关策略&#xff08;如添加防火墙规则阻止非信任来源IP访问&#xff09;&#xff0c;会将Redis暴露在公网…

HTML5实现好看的天气预报网站源码

文章目录 1.设计来源1.1 获取天气接口1.2 PC端页面设计1.3 手机端页面设计 2.效果和源码2.1 动态效果2.2 源代码 源码下载万套模板&#xff0c;程序开发&#xff0c;在线开发&#xff0c;在线沟通 作者&#xff1a;xcLeigh 文章地址&#xff1a;https://blog.csdn.net/weixin_4…

揭秘电子画册制作流程,打造独一无二的作品

在这个数字化的时代&#xff0c;电子画册已经成为了展示个人创意和品牌形象的重要工具。它不仅能够呈现出丰富多彩的内容&#xff0c;还能够实现互动性和传播性&#xff0c;吸引众多观众的目光。然而&#xff0c;许多人对于电子画册的制作流程仍然感到陌生。本文将揭秘电子画册…

企业VR展厅如何提升品牌形象,生动展示产品和企业文化?

一、提升产品展示效果 1、全方位展示产品细节 企业VR展厅可以通过3D建模和虚拟现实技术&#xff0c;将产品的每一个细节清晰地展示出来。客户可以全方位查看产品的外观、结构和功能。这种身临其境的体验远比传统的平面展示更加生动和详细。 细节展示&#xff1a;客户可以通过…

Ubuntu22 Qt6.6 ROS 环境搭建

Ubuntu22.04; Qt6.6; Qt Creator 13.01; ROS2 1. 安装 Qt ROS 插件 1.下载地址&#xff1a; https://github.com/ros-industrial/ros_qtc_plugin/releases 选择对应 Qt Creator 版本的安装包。 2. Qt Creator中&#xff0c;“Help - 关于插件”–>“install Plugin…

一个模板实现的工厂的编译问题的解决。牵扯到重载、特化等

简介 在一个项目里&#xff0c;调用了第三封的库&#xff0c;这个库里面有个类用的很多&#xff0c;而且其构造函数至少有6个&#xff0c;并且个人感觉还不够多。根据实际使用&#xff0c;还得增加一些。 需求 1、增加构造函数&#xff0c;比如除了下面的&#xff0c;还增加…

Gateway源码分析:路由Route、断言Predicate、Filter

文章目录 源码总流程图说明GateWayAutoConfigurationDispatcherHandlergetHandler()handleRequestWith()RouteToRequestUrlFilterReactiveLoadBalancerClientFilterNettyRoutingFilter 补充知识适配器模式 详细流程图 源码总流程图 在线总流程图 说明 Gateway的版本使用的是…

01常见控件

文章目录 控件各种响应事件获取控件类型CButton/CheckBox&#xff08;多选&#xff09;/RadioButton&#xff08;单选&#xff09;EditControl&#xff08;文本编辑框&#xff09;/ ListBox&#xff08;列表文本框&#xff09;/ComboBox&#xff08;可下拉列表&#xff09;Prog…

【Ubuntu】Ubuntu系统镜像

清华镜像源 Index of /ubuntu-releases/ | 清华大学开源软件镜像站 | Tsinghua Open Source MirrorIndex of /ubuntu-releases/ | 清华大学开源软件镜像站&#xff0c;致力于为国内和校内用户提供高质量的开源软件镜像、Linux 镜像源服务&#xff0c;帮助用户更方便地获取开源软…

stm32学习:(寄存器2)GPIO总体说明

目录 GPIO的主要特点 GPIO的8种工作模式 GPIO电路结构 GPIO输出模式 输出流程 复用输出模式 GPIO输入模式 输入流程 模拟输入流程 GPIO相关的7个寄存器 GPIOx_CRL GPIOx_CRH GPIOx_IDR GPIOx_ODR GPIOx_BSRR GPIOx_BRR GPIOx_LCKR 实例 三个灯流水灯 main.…

C语言基础 9. 指针

C语言基础 9. 指针 文章目录 C语言基础 9. 指针9.1. &9.2. 指针9.3. 指针的使用9.4. 指针与数组9.5. 指针与const9.6. 指针运算9.7. 动态内存分配 9.1. & 运算符&: scanf(“%d”, &i);里的& 获得变量的地址, 它的操作数必须是变量 int i;printf(“%x”, &…