tensorRT C++使用pt转engine模型进行推理

目录

  • 1. 前言
  • 2. 模型转换
  • 3. 修改Binding
  • 4. 修改后处理

1. 前言

本文不讲tensorRT的推理流程,因为这种文章很多,这里着重讲从标准yolov5的tensort推理代码(模型转pt->wts->engine)改造成TPH-yolov5(pt->onnx->engine)的过程。

2. 模型转换

请查看上一篇文章https://blog.csdn.net/wyw0000/article/details/139737473?spm=1001.2014.3001.5502

3. 修改Binding

如果不修改Binding,会报下图中的错误。
在这里插入图片描述
该问题是由于Binding有多个,而代码中只申请了input和output,那么如何查看engine模型有几个Bingding呢?代码如下:

int get_model_info(const string& model_path) {// 创建 loggerLogger gLogger;// 从文件中读取 enginestd::ifstream engineFile(model_path, std::ios::binary);if (!engineFile) {std::cerr << "Failed to open engine file." << std::endl;return -1;}engineFile.seekg(0, engineFile.end);long int fsize = engineFile.tellg();engineFile.seekg(0, engineFile.beg);std::vector<char> engineData(fsize);engineFile.read(engineData.data(), fsize);if (!engineFile) {std::cerr << "Failed to read engine file." << std::endl;return -1;}// 反序列化 engineauto runtime = nvinfer1::createInferRuntime(gLogger);auto engine = runtime->deserializeCudaEngine(engineData.data(), fsize, nullptr);// 获取并打印输入和输出绑定信息for (int i = 0; i < engine->getNbBindings(); ++i) {nvinfer1::Dims dims = engine->getBindingDimensions(i);nvinfer1::DataType type = engine->getBindingDataType(i);std::cout << "Binding " << i << " (" << engine->getBindingName(i) << "):" << std::endl;std::cout << "  Type: " << (int)type << std::endl;std::cout << "  Dimensions: ";for (int j = 0; j < dims.nbDims; ++j) {std::cout << (j ? "x" : "") << dims.d[j];}std::cout << std::endl;std::cout << "  Is Input: " << (engine->bindingIsInput(i) ? "Yes" : "No") << std::endl;}// 清理资源engine->destroy();runtime->destroy();return 0;
}

下图是我的tph-yolov5的Binding,可以看到有5个Binding,因此在doInference推理之前,要给5个Binding都申请空间,同时要注意获取BindingIndex时,名称和dimension与查询出来的对应。
在这里插入图片描述

//for tph-yolov5int Sigmoid_921_index = trt->engine->getBindingIndex("onnx::Sigmoid_921");int Sigmoid_1183_index = trt->engine->getBindingIndex("onnx::Sigmoid_1183");int Sigmoid_1367_index = trt->engine->getBindingIndex("onnx::Sigmoid_1367");CUDA_CHECK(cudaMalloc(&trt->buffers[Sigmoid_921_index], BATCH_SIZE * 3 * 192 * 192 * 7 * sizeof(float)));CUDA_CHECK(cudaMalloc(&trt->buffers[Sigmoid_1183_index], BATCH_SIZE * 3 * 96 * 96 * 7 * sizeof(float)));CUDA_CHECK(cudaMalloc(&trt->buffers[Sigmoid_1367_index], BATCH_SIZE * 3 * 48 * 48 * 7 * sizeof(float)));trt->data = new float[BATCH_SIZE * 3 * INPUT_H * INPUT_W];trt->prob = new float[BATCH_SIZE * OUTPUT_SIZE];trt->inputIndex = trt->engine->getBindingIndex(INPUT_BLOB_NAME);trt->outputIndex = trt->engine->getBindingIndex(OUTPUT_BLOB_NAME);

还有推理的部分也要做修改,原来只有input和output两个Binding时,那么输出是buffers[1],而目前是有5个Binding那么输出就变成了buffers[4]

void doInference(IExecutionContext& context, cudaStream_t& stream, void **buffers, float* output, int batchSize) {// infer on the batch asynchronously, and DMA output back to hostcontext.enqueueV2(buffers, stream, nullptr);//CUDA_CHECK(cudaMemcpyAsync(output, buffers[1], batchSize * OUTPUT_SIZE * sizeof(float), cudaMemcpyDeviceToHost, stream));CUDA_CHECK(cudaMemcpyAsync(output, buffers[4], batchSize * OUTPUT_SIZE * sizeof(float), cudaMemcpyDeviceToHost, stream));cudaStreamSynchronize(stream);
}

4. 修改后处理

之前的yolov5推理代码是将pt模型转为wts再转为engine的,输出维度只有一维,而TPH输出维度为145152*7,因此要对原来的后处理代码进行修改。

struct BoundingBox {//bbox[0],bbox[1],bbox[2],bbox[3],conf, class_idfloat x1, y1, x2, y2, score, index;
};float iou(const BoundingBox&  box1, const BoundingBox& box2) {float max_x = max(box1.x1, box2.x1);  // 找出左上角坐标哪个大float min_x = min(box1.x2, box2.x2);  // 找出右上角坐标哪个小float max_y = max(box1.y1, box2.y1);float min_y = min(box1.y2, box2.y2);if (min_x <= max_x || min_y <= max_y) // 如果没有重叠return 0;float over_area = (min_x - max_x) * (min_y - max_y);  // 计算重叠面积float area_a = (box1.x2 - box1.x1) * (box1.y2 - box1.y1);float area_b = (box2.x2 - box2.x1) * (box2.y2 - box2.y1);float iou = over_area / (area_a + area_b - over_area);return iou;
}std::vector<BoundingBox> nonMaximumSuppression(std::vector<std::vector<float>>& boxes, float overlapThreshold) {std::vector<BoundingBox> convertedBoxes;// 将数据转换为BoundingBox结构体for (const auto&  box: boxes) {if (box.size() == 6) { // Assuming [x1, y1, x2, y2, score]BoundingBox bbox;bbox.x1 = box[0];bbox.y1 = box[1];bbox.x2 = box[2];bbox.y2 = box[3];bbox.score = box[4];bbox.index = box[5];convertedBoxes.push_back(bbox);}else {std::cerr << "Invalid box format!" << std::endl;}}// 对框按照分数降序排序std::sort(convertedBoxes.begin(), convertedBoxes.end(), [](const BoundingBox& a, const BoundingBox&  b) {return a.score > b.score;});// 非最大抑制std::vector<BoundingBox> result;std::vector<bool> isSuppressed(convertedBoxes.size(), false);for (size_t i = 0; i < convertedBoxes.size(); ++i) {if (!isSuppressed[i]) {result.push_back(convertedBoxes[i]);for (size_t j = i + 1; j < convertedBoxes.size(); ++j) {if (!isSuppressed[j]) {float overlap = iou(convertedBoxes[i], convertedBoxes[j]);if (overlap > overlapThreshold) {isSuppressed[j] = true;}}}}}
#if 0// 输出结果std::cout << "NMS Result:" << std::endl;for (const auto& box: result) {std::cout << "x1: " << box.x1 << ", y1: " << box.y1<< ", x2: " << box.x2 << ", y2: " << box.y2<< ", score: " << box.score << ",index:" << box.index << std::endl;}
#endif return result;
}void post_process(float *prob_model, float conf_thres, float overlapThreshold, std::vector<Yolo::Detection> & detResult)
{int cols = 7, rows = 145152;//  ========== 8. 获取推理结果 =========std::vector<std::vector<float>> prediction(rows, std::vector<float>(cols));int index = 0;for (int i = 0; i < rows; ++i) {for (int j = 0; j < cols; ++j) {prediction[i][j] = prob_model[index++];}}//  ========== 9. 大于conf_thres加入xc =========std::vector<std::vector<float>> xc;for (const auto& row : prediction) {if (row[4] > conf_thres) {xc.push_back(row);}}//  ========== 10. 置信度 = obj_conf * cls_conf =========//std::cout << xc[0].size() << std::endl;for (auto& row: xc) {for (int i = 5; i < xc[0].size(); i++) {row[i] *= row[4];}}// ========== 11. 切片取出xywh 转为xyxy=========std::vector<std::vector<float>> xywh;for (const auto& row: xc) {std::vector<float> sliced_row(row.begin(), row.begin() + 4);xywh.push_back(sliced_row);}std::vector<std::vector<float>> box(xywh.size(), std::vector<float>(4, 0.0));xywhtoxxyy(xywh, box);// ========== 12. 获取置信度最高的类别和索引=========std::size_t mi = xc[0].size();std::vector<float> conf(xc.size(), 0.0);std::vector<float> j(xc.size(), 0.0);for (std::size_t i = 0; i < xc.size(); ++i) {// 模拟切片操作 x[:, 5:mi]auto sliced_x = std::vector<float>(xc[i].begin() + 5, xc[i].begin() + mi);// 计算 maxauto max_it = std::max_element(sliced_x.begin(), sliced_x.end());// 获取 max 的索引std::size_t max_index = std::distance(sliced_x.begin(), max_it);// 将 max 的值和索引存储到相应的向量中conf[i] = *max_it;j[i] = max_index;  // 加上切片的起始索引}// ========== 13. concat x1, y1, x2, y2, score, index;======== =for (int i = 0; i < xc.size(); i++) {box[i].push_back(conf[i]);box[i].push_back(j[i]);}std::vector<std::vector<float>> output;for (int i = 0; i < xc.size(); i++) {output.push_back(box[i]); // 创建一个空的 float 向量并}// ==========14 应用非最大抑制 ==========std::vector<BoundingBox>  result = nonMaximumSuppression(output, overlapThreshold);for (const auto& r : result){Yolo::Detection det;det.bbox[0] = r.x1;det.bbox[1] = r.y1;det.bbox[2] = r.x2;det.bbox[3] = r.y2;det.conf = r.score;det.class_id = r.index;detResult.push_back(det);}}

代码参考:
https://blog.csdn.net/rooftopstars/article/details/136771496
https://blog.csdn.net/qq_73794703/article/details/132147879

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/32777.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

如何关闭软件开机自启,提升电脑开机速度?

如何关闭软件开机自启&#xff0c;提升电脑开机速度&#xff1f;大家知道&#xff0c;很多软件在安装时默认都会设置为开机自动启动。但是&#xff0c;有很多软件在我们开机之后并不是马上需要用到的&#xff0c;开机启动的软件过多会导致电脑开机变慢。那么&#xff0c;如何关…

【break】大头哥哥做题

【break】大头哥哥做题 时间限制: 1000 ms 内存限制: 65536 KB 【题目描述】 【参考代码】 #include <iostream> using namespace std; int main(){ int sum 0;//求和int day 0;//天数 while(1){int a;cin>>a;if(a-1){break;//结束当前循环 }sum sum a; …

HTTP基本概念介绍

HTTP概述 HTTP : 超文本传输协议&#xff0c;HTTP是浏览器端Web通信的基础。 一&#xff0c; 两种架构 B/S架构&#xff1a;Browser/Server&#xff0c;浏览器/服务器架构。 B: 浏览器&#xff0c;比如Firefox 、Google 、Internet; S: 服务器&#xff0c;Apache&#xff0c…

[stm32]温湿度采集与OLED显示

一、I2C总线协议 I2C&#xff08;Inter-integrated circuit &#xff09;是一种允许从不同的芯片或电路与不同的主芯片通信的协议。它仅用于短距离通信&#xff0c;是一种用于两个或多个设备之间进行数据传输的串行总线技术&#xff0c;它可以让你在微处理器、传感器、存储器、…

6月20日(周四)A股行情总结:A股险守3000点,恒生科技指数跌1.6%

A股三大股指走弱&#xff0c;科创板逆势上扬&#xff0c;半导体板块走强&#xff0c;多股20CM涨停。中芯国际港股涨超1%。恒生科技指数跌超1%。离岸人民币对美元汇率小幅走低&#xff0c;20日盘中最低跌至7.2874&#xff0c;创下2023年11月中旬以来的新低&#xff0c;随后收复部…

287 寻找重复数-类似于环形链表II

题目 给定一个包含 n 1 个整数的数组 nums &#xff0c;其数字都在 [1, n] 范围内&#xff08;包括 1 和 n&#xff09;&#xff0c;可知至少存在一个重复的整数。 假设 nums 只有 一个重复的整数 &#xff0c;返回 这个重复的数 。 你设计的解决方案必须 不修改 数组 nums…

理解堆排序

堆排序&#xff08;Heapsort&#xff09;是一种基于堆这种数据结构的排序算法&#xff0c;但在实际实现中&#xff0c;堆通常是用数组来表示的。这种方法充分利用了数组的特性&#xff0c;使得堆的操作更加高效。下面通过详细解释和举例说明来帮助理解这种排序方式。 堆的数组…

Linux应急响应——知攻善防应急靶场-Linux(1)

文章目录 查看history历史指令查看开机自启动项异常连接和端口异常进程定时任务异常服务日志分析账户排查总结 靶场出处是知攻善防 Linux应急响应靶机 1 前景需要&#xff1a; 小王急匆匆地找到小张&#xff0c;小王说"李哥&#xff0c;我dev服务器被黑了",快救救我&…

手持弹幕LED滚动字幕屏夜店表白手灯接机微信抖音小程序开源版开发

手持弹幕LED滚动字幕屏夜店表白手灯接机微信抖音小程序开源版开发 专业版 插件版 手持弹幕小程序通常提供多种功能&#xff0c;以便用户在不同的场合如夜店、表白、接机等使用。以下是一些常见的功能列表&#xff1a; 文本输入&#xff1a; 输入要显示的文字内容&#xff0c;…

强化学习算法复现记录

目录 1.多智能体强化学习MADDPG tensorflow2版本IMAC tensorflow2版本 2.单智能体强化学习DQN pytorch版本PPO pytorch版本 1.多智能体强化学习 MADDPG tensorflow2版本 文章链接&#xff1a;tensorflow2实现多智能体强化学习算法MADDPG IMAC tensorflow2版本 文章链接&…

如何利用AopContext.currentProxy()解决事务管理中的方法调用问题

在Spring应用开发中&#xff0c;使用AOP&#xff08;面向切面编程&#xff09;来管理事务是非常常见的做法。然而&#xff0c;在某些场景下&#xff0c;尤其是在同一个类的方法内部&#xff0c;一个非事务方法直接调用另一个带有事务注解的方法时&#xff0c;可能会遇到事务不生…

初中英语优秀作文分析-005How to Plan Our Life Wisely-如何明智地规划我们的生活

PDF格式公众号回复关键字:SHCZYF005 记忆树 1 The “double reduction policy” reduces the burden on students and offers us more spare time than before, but how to plan our life wisely? 翻译 “双减政策”减轻了学生的负担&#xff0c;给了我们比以前更多的业余…

Linux进程概念(二)

上期我们已经学习了进程的基础的内容&#xff0c;已经对进程的基本概念有了了解&#xff0c;知道了进程的组成&#xff0c; 本期我们将以操作为主进一步探讨进程的相关概念。 目录 查看进程 创建进程 查看进程 查看进程主要有两种方式。 ps ajx指令 在当前目录下有名为tes…

SpringBoot-注解@ImportResource引入自定义spring的配置xml文件和配置类

1、注解ImportResource 我们知道Spring的配置文件是可以有很多个的&#xff0c;我们在web.xml中如下配置就可以引入它们&#xff1a; SprongBoot默认已经给我们配置好了Spring&#xff0c;它的内部相当于已经有一个配置文件&#xff0c;那么我们想要添加新的配置文件怎么办&am…

SkyWalking 极简入门

1. 概述 1.1 概念 SkyWalking 是什么&#xff1f; FROM Apache SkyWalking 分布式系统的应用程序性能监视工具&#xff0c;专为微服务、云原生架构和基于容器&#xff08;Docker、K8s、Mesos&#xff09;架构而设计。 提供分布式追踪、服务网格遥测分析、度量聚合和可视化一体…

【CSS in Depth 2 精译】1.5 渐进式增强

文章目录 1.5 渐进式增强1.5.1 利用层叠规则实现渐进式增强1.5.2 渐进式增强的选择器1.5.3 利用 supports() 实现特性查询启用浏览器实验特性 1.5 渐进式增强 要用好 CSS 这样一门不断发展演进中的语言&#xff0c;其中一个重要的因素就是要与时俱进&#xff0c;及时了解哪些功…

AI 大模型企业应用实战(09)-LangChain的示例选择器

1 根据长度动态选择提示词示例组 1.1 案例 根据输入的提示词长度综合计算最终长度&#xff0c;智能截取或者添加提示词的示例。 from langchain.prompts import PromptTemplate from langchain.prompts import FewShotPromptTemplate from langchain.prompts.example_selecto…

PADS学习笔记

1.PADS设计PCB流程 封装库&#xff08;layout&#xff09;&#xff0c;原理图库&#xff08;logic&#xff09;的准备原件封装的匹配&#xff08;logic&#xff09;原理图的绘制&#xff08;logic&#xff09;导网表操作&#xff08;logic&#xff09;导入结构&#xff08;lay…

C++系列-String(一)

&#x1f308;个人主页&#xff1a;羽晨同学 &#x1f4ab;个人格言:“成为自己未来的主人~” string是用于字符串&#xff0c;可以增删改查 首先&#xff0c;我们来看一下string的底层 接下来&#xff0c;我们来看一下string的常用接口有哪些&#xff1a; #define _CRT_S…

【机器学习】音乐大模型的深入探讨——当机器有了创意,是机遇还是灾难?

&#x1f440;国内外音乐大模型基本情况&#x1f440; ♥概述♥ ✈✈✈如FreeCompose、一术科技等&#xff0c;这些企业专注于开发人工智能驱动的语音、音效和音乐生成工具&#xff0c;致力于利用核心技术驱动文化产业升级。虽然具体公司未明确提及&#xff0c;但可以预见的是…