c++通过tensorRT调用模型进行推理

模型来源
算法工程师训练得到的onnx模型

c++对模型的转换
拿到onnx模型后,通过tensorRT将onnx模型转换为对应的engine模型,注意:训练用的tensorRT版本和c++调用的tensorRT版本必须一致。

如何转换:

  1. 算法工程师直接转换为.engine文件进行交付。
  2. 自己转换,进入tensorRT安装目录\bin目录下,将onnx模型拷贝到bin目录,地址栏中输入cmd回车弹出控制台窗口,然后输入转换命令,如:

trtexec --onnx=model.onnx --saveEngine=model.engine --workspace=1024 --optShapes=input:1x13x512x640 --fp16

然后回车,等待转换完成,完成后如图所示:
在这里插入图片描述
并且在bin目录下生成.engine模型文件。

c++对.engine模型文件的调用和推理
首先将tensorRT对模型的加载及推理进行封装,命名为CTensorRT.cpp,老样子贴代码:

//CTensorRT.cpp
class Logger : public nvinfer1::ILogger {void log(Severity severity, const char* msg) noexcept override {if (severity <= Severity::kWARNING)std::cout << msg << std::endl;}
};Logger logger;
class CtensorRT
{
public:CtensorRT() {}~CtensorRT() {}private:std::shared_ptr<nvinfer1::IExecutionContext> _context;std::shared_ptr<nvinfer1::ICudaEngine> _engine;nvinfer1::Dims _inputDims;nvinfer1::Dims _outputDims;
public:void cudaCheck(cudaError_t ret, std::ostream& err = std::cerr){if (ret != cudaSuccess){err << "Cuda failure: " << cudaGetErrorString(ret) << std::endl;abort();}}bool loadOnnxModel(const std::string& filepath){auto builder = std::unique_ptr<nvinfer1::IBuilder>(nvinfer1::createInferBuilder(logger));if (!builder){return false;}const auto explicitBatch = 1U << static_cast<uint32_t>(NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);auto network = std::unique_ptr<nvinfer1::INetworkDefinition>(builder->createNetworkV2(explicitBatch));if (!network){return false;}auto config = std::unique_ptr<nvinfer1::IBuilderConfig>(builder->createBuilderConfig());if (!config){return false;}auto parser = std::unique_ptr<nvonnxparser::IParser>(nvonnxparser::createParser(*network, logger));if (!parser){return false;}parser->parseFromFile(filepath.c_str(), static_cast<int32_t>(nvinfer1::ILogger::Severity::kWARNING));std::unique_ptr<IHostMemory> plan{ builder->buildSerializedNetwork(*network, *config) };if (!plan){return false;}std::unique_ptr<IRuntime> runtime{ createInferRuntime(logger) };if (!runtime){return false;}_engine = std::shared_ptr<nvinfer1::ICudaEngine>(runtime->deserializeCudaEngine(plan->data(), plan->size()));if (!_engine){return false;}_context = std::shared_ptr<nvinfer1::IExecutionContext>(_engine->createExecutionContext());if (!_context){return false;}int nbBindings = _engine->getNbBindings();assert(nbBindings == 2); // 输入和输出,一共是2个for (int i = 0; i < nbBindings; i++){if (_engine->bindingIsInput(i))_inputDims = _engine->getBindingDimensions(i);    // (1,3,752,752)else_outputDims = _engine->getBindingDimensions(i);}return true;}bool loadEngineModel(const std::string& filepath){std::ifstream file(filepath, std::ios::binary);if (!file.good()){return false;}std::vector<char> data;try{file.seekg(0, file.end);const auto size = file.tellg();file.seekg(0, file.beg);data.resize(size);file.read(data.data(), size);}catch (const std::exception& e){file.close();return false;}file.close();auto runtime = std::unique_ptr<nvinfer1::IRuntime>(nvinfer1::createInferRuntime(logger));_engine = std::shared_ptr<nvinfer1::ICudaEngine>(runtime->deserializeCudaEngine(data.data(), data.size()));if (!_engine){return false;}_context = std::shared_ptr<nvinfer1::IExecutionContext>(_engine->createExecutionContext());if (!_context){return false;}int nbBindings = _engine->getNbBindings();assert(nbBindings == 2); // 输入和输出,一共是2个// 为输入和输出创建空间for (int i = 0; i < nbBindings; i++){if (_engine->bindingIsInput(i))_inputDims = _engine->getBindingDimensions(i);    //得到输入结构else_outputDims = _engine->getBindingDimensions(i);//得到输出结构}return true;}void ONNX2TensorRT(const char* ONNX_file, std::string save_ngine){// 1.创建构建器的实例nvinfer1::IBuilder* builder = nvinfer1::createInferBuilder(logger);// 2.创建网络定义uint32_t flag = 1U << static_cast<uint32_t>(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);nvinfer1::INetworkDefinition* network = builder->createNetworkV2(flag);// 3.创建一个 ONNX 解析器来填充网络nvonnxparser::IParser* parser = nvonnxparser::createParser(*network, logger);// 4.读取模型文件并处理任何错误parser->parseFromFile(ONNX_file, static_cast<int32_t>(nvinfer1::ILogger::Severity::kWARNING));for (int32_t i = 0; i < parser->getNbErrors(); ++i){std::cout << parser->getError(i)->desc() << std::endl;}// 5.创建一个构建配置,指定 TensorRT 应该如何优化模型nvinfer1::IBuilderConfig* config = builder->createBuilderConfig();// 7.指定配置后,构建引擎nvinfer1::IHostMemory* serializedModel = builder->buildSerializedNetwork(*network, *config);// 8.保存TensorRT模型std::ofstream p(save_ngine, std::ios::binary);p.write(reinterpret_cast<const char*>(serializedModel->data()), serializedModel->size());// 9.序列化引擎包含权重的必要副本,因此不再需要解析器、网络定义、构建器配置和构建器,可以安全地删除delete parser;delete network;delete config;delete builder;// 10.将引擎保存到磁盘,并且可以删除它被序列化到的缓冲区delete serializedModel;}uint32_t getElementSize(nvinfer1::DataType t) noexcept{switch (t){case nvinfer1::DataType::kINT32: return 4;case nvinfer1::DataType::kFLOAT: return 4;case nvinfer1::DataType::kHALF: return 2;case nvinfer1::DataType::kBOOL:case nvinfer1::DataType::kINT8: return 1;}return 0;}int64_t volume(const nvinfer1::Dims& d){return std::accumulate(d.d, d.d + d.nbDims, 1, std::multiplies<int64_t>());}bool infer(unsigned char* input, int real_input_size, cv::Mat& out_mat){tensor_custom::BufferManager buffer(_engine);cudaStream_t stream;cudaStreamCreate(&stream); // 创建异步cuda流int binds = _engine->getNbBindings();for (int i = 0; i < binds; i++){if (_engine->bindingIsInput(i)){size_t input_size;float* host_buf = static_cast<float*>(buffer.getHostBufferData(i, input_size));memcpy(host_buf, input, real_input_size);break;}}// 将输入传递到GPUbuffer.copyInputToDeviceAsync(stream);// 异步执行bool status = _context->enqueueV2(buffer.getDeviceBindngs().data(), stream, nullptr);if (!status)return false;buffer.copyOutputToHostAsync(stream);for (int i = 0; i < binds; i++){if (!_engine->bindingIsInput(i)){size_t output_size;float* tmp_out = static_cast<float*>(buffer.getHostBufferData(i, output_size));//do your something herebreak;}}cudaStreamSynchronize(stream);cudaStreamDestroy(stream);return true;}
};

调用方式

int main()
{vector<int> dims = { 1,13,512,640 };vector<float> vall;for (int i=0;i<13;i++){string file = "D:\\xxx\\" + to_string(i) + ".png";cv::Mat mt = imread(file, IMREAD_GRAYSCALE);cv::resize(mt, mt, cv::Size(640,512));mt.convertTo(mt, CV_32F, 1.0 / 255);cv::Mat shape_xr = mt.reshape(1, mt.total() * mt.channels());std::vector<float> vec_xr = mt.isContinuous() ? shape_xr : shape_xr.clone();vall.insert(vall.end(), vec_xr.begin(), vec_xr.end());}cv::Mat mt_4d(4, &dims[0], CV_32F, vall.data());string engine_model_file = "model.engine";CtensorRT cTensor;if (cTensor.loadEngineModel(engine_model_file)){cv::Mat out_mat;if (!cTensor.infer(mt_4d.data, vall.size() * 4, out_mat))std::cout << "infer error!" << endl;elsecv::imshow("out", out_mat);}elsestd::cout << "load model file failed!" << endl;cv::waitKey(0);return 0;
}

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/75080.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

2020年12月 C/C++(二级)真题解析#中国电子学会#全国青少年软件编程等级考试

C/C++编程(1~8级)全部真题・点这里 第1题:数组指定部分逆序重放 将一个数组中的前k项按逆序重新存放。例如,将数组8,6,5,4,1前3项逆序重放得到5,6,8,4,1。 时间限制:1000 内存限制:65536 输入 输入为两行: 第一行两个整数,以空格分隔,分别为数组元素的个数n(1 < n…

在Qt5中SQLite3的使用

一、SQLite简要介绍 什么是SQLite SQLite是一个进程内的库&#xff0c;实现了自给自足的、无服务器的、零配置的、事务性的 SQL 数据库引擎。它是一个零配置的数据库&#xff0c;这意味着与其他数据库不一样&#xff0c;您不需要在系统中配置。 就像其他数据库&#xff0c;S…

基于javaweb的CT图像管理系统(servlet+jsp)

系统简介 本项目采用eclipse工具开发&#xff0c;jspservletjquery技术编写&#xff0c;数据库采用的是mysql&#xff0c;navicat开发工具。 三个角色&#xff1a;管理员&#xff0c;普通用户&#xff0c;医生 模块简介 管理员&#xff1a; 1、登录 2、用户管理 3、医生管…

ARM DIY(十)LRADC 按键

前言 ARM SOC 有别于单片机 MCU 的一点就是&#xff0c;ARM SOC 的 GPIO 比较少&#xff0c;基本上引脚都有专用的功能&#xff0c;因为它很少去接矩阵键盘、众多继电器、众多 LED。 但有时 ARM SOC 又需要三五个按键&#xff0c;这时候 LRADC 就是一个不错的选择&#xff0c;…

C# OpenVino Yolov8 Detect 目标检测

效果 项目 代码 using OpenCvSharp; using System; using System.Collections.Generic; using System.ComponentModel; using System.Data; using System.Drawing; using System.Linq; using System.Text; using System.Windows.Forms; using static System.Net.Mime.MediaT…

python趣味编程-数独游戏

数独游戏是一个用Python编程语言编写的应用程序。该项目包含可以显示实际应用程序的基本功能。该项目可以让修读 IT 相关课程并希望开发简单应用程序的学生受益。这个Python 数独游戏是一个简单的项目,可用于学习tkinter库的实践。这个数独游戏可以提供Python编程的基本编码技…

黑马JVM总结(三)

&#xff08;1&#xff09;栈内存溢出 方法的递归调用&#xff0c;没有设置正确的结束条件&#xff0c;栈会有用完的一天&#xff0c;导致栈内存溢出 可以修改栈的大小&#xff1a; 再次运行&#xff1a;减少了次数 案例二&#xff1a; 两个类的循环应用问题&#xff0c;导致Js…

linux-进程-execl族函数

exec函数的作用&#xff1a; 我们用fork函数创建新进程后&#xff0c;经常会在新进程中调用exec函数去执行另外一个程序。当进程调用exec函数时&#xff0c;该进程被完全替换为新程序。因为调用exec函数并不创建新进程&#xff0c;所以前后进程的ID并没有改变。 简单来说就是&…

如何使用聊天GPT自定义说明

推荐&#xff1a;使用 NSDT场景编辑器 快速搭建3D应用场景 OpenAI ChatGPT正在席卷全球。一周又一周&#xff0c;更新不断提高您可以使用这种最先进的语言模型做什么的标准。 在这里&#xff0c;我们深入研究了OpenAI最近在ChatGPT自定义指令上发布的公告。此功能最初以测试版…

第11篇:ESP32vscode_platformio_idf框架helloworld点亮LED

第1篇:Arduino与ESP32开发板的安装方法 第2篇:ESP32 helloword第一个程序示范点亮板载LED 第3篇:vscode搭建esp32 arduino开发环境 第4篇:vscodeplatformio搭建esp32 arduino开发环境 ​​​​​​第5篇:doit_esp32_devkit_v1使用pmw呼吸灯实验 第6篇:ESP32连接无源喇叭播…

智慧公厕是对智慧城市“神经末梢”的有效激活,公共厕所实现可感知、可视化、可管理、可控制

在当今科技迅速发展的时代&#xff0c;智慧城市已经成为人们关注的热点话题。作为城市基础设施的重要组成部分&#xff0c;公共厕所也逐渐融入到智慧城市的建设中&#xff0c;成为城市管理的焦点之一。智慧公厕作为智慧城市的“神经末梢”&#xff0c;通过可感知、可视化、可管…

【文末送书】Matlab科学计算

欢迎关注博主 Mindtechnist 或加入【智能科技社区】一起学习和分享Linux、C、C、Python、Matlab&#xff0c;机器人运动控制、多机器人协作&#xff0c;智能优化算法&#xff0c;滤波估计、多传感器信息融合&#xff0c;机器学习&#xff0c;人工智能等相关领域的知识和技术。关…

遥感图像应用:在低分辨率图像上实现洪水损害检测(迁移学习)

本文是上一篇关于“在低分辨率图像上实现洪水损害检测”的博客的延申。 代码来源&#xff1a;https://github.com/weining20000/Flooding-Damage-Detection-from-Post-Hurricane-Satellite-Imagery-Based-on-CNN/tree/master 数据储存地址&#xff1a;https://github.com/Jef…

CSS宽度问题

一、魔法 为 DOM 设置宽度有哪些方式呢&#xff1f;最常用的是配置width属性&#xff0c;width属性在配置时&#xff0c;也有多种方式&#xff1a; widthmin-widthmax-width 通常当配置了 width 时&#xff0c;不会再配置min-width max-width&#xff0c;如果将这三者混合使…

MySql 变量

1.系统变量 1.1 系统变量分类 变量由系统定义&#xff0c;不是用户定义&#xff0c;属于 服务器 层面。系统变量分为全局系统变量&#xff08;需要添加 global 关键字&#xff09;以及会话系统变量&#xff08;需要添加 session 关键字&#xff09;&#xff0c;有时也把全局系…

Web安全——Web安全漏洞与利用上篇(仅供学习)

SQL注入 一、SQL 注入漏洞1、与 mysql 注入的相关知识2、SQL 注入原理3、判断是否存在注入回显是指页面有数据信息返回id 1 and 114、三种 sql 注释符5、注入流程6、SQL 注入分类7、接受请求类型区分8、注入数据类型的区分9、SQL 注入常规利用思路&#xff1a;10、手工注入常规…

MySQL的权限管理与远程访问

MySQL的权限管理 1、授予权限 授权命令&#xff1a; grant 权限1,权限2,…权限n on 数据库名称.表名称 to 用户名用户地址 identified by ‘连接口令’; 该权限如果发现没有该用户&#xff0c;则会直接新建一个用户。 比如 grant select,insert,delete,drop on atguigudb.…

驱动开发,stm32mp157a开发板的led灯控制实验

1.实验目的 编写LED灯的驱动&#xff0c;在应用程序中编写控制LED灯亮灭的代码逻辑实现LED灯功能的控制&#xff1b; 2.LED灯相关寄存器分析 LED1->PE10 LED1亮灭&#xff1a; RCC寄存器[4]->1 0X50000A28 GPIOE_MODER[21:20]->01 (输出) 0X50006000 GPIOE_ODR[10]-&g…

文件操作(个人学习笔记黑马学习)

C中对文件操作需要包含头文件<fstream > 文件类型分为两种: 1.文本文件&#xff1a;文件以文本的ASCII码形式存储在计算机中 2.二进制文件&#xff1a;文件以文本的二进制形式存储在计算机中&#xff0c;用户一般不能直接读懂它们 操作文件的三大类: 1.ofstream: 写操作 …

SpringBoot项目启动时预加载

SpringBoot项目启动时预加载 Spring Boot是一种流行的Java开发框架&#xff0c;它提供了许多方便的功能来简化应用程序的开发和部署。其中一个常见的需求是在Spring Boot应用程序启动时预加载一些数据或执行一些初始化操作。 1. CommandLineRunner 和 ApplicationRunner Spri…