TRT4-trt-integrate - 1 YOLOV5导出、编译、推理

 模型导出

 修改Image的Input动态维度

首先可以看到这个模型导出的时候Input有三个维度都是动态,而我们之前说过只需要一个batch维度是动态,所以要在export的export onnx 进行修改,将

torch.onnx.export(model, im, f, verbose=False, opset_version=opset,training=torch.onnx.TrainingMode.TRAINING if train else torch.onnx.TrainingMode.EVAL,do_constant_folding=not train,input_names=['images'],output_names=['output'],dynamic_axes={'images': {0: 'batch', 2: 'height', 3: 'width'},  # shape(1,3,640,640)'output': {0: 'batch', 1: 'anchors'}  # shape(1,25200,85)} if dynamic else None)

改为:

        torch.onnx.export(model, im, f, verbose=False, opset_version=opset,training=torch.onnx.TrainingMode.TRAINING if train else torch.onnx.TrainingMode.EVAL,do_constant_folding=not train,input_names=['images'],output_names=['output'],dynamic_axes={'images': {0: 'batch'},  # shape(1,3,640,640)'output': {0: 'batch'}  # shape(1,25200,85)} if dynamic else None)

修改完的已经变成了只有batch是动态维度 。

修改output

而且也可以看到,这里的output输出有四个tensor,其中三个都是fpn结构,80*80 , 40*40 , 20*20,这些我们在这里去掉,仅保留拼接后的结果。

将yolov5-6.0/models/yolo.py中的Class detect修改:

return x if self.training else (torch.cat(z, 1), x)

-->

return x if self.training else torch.cat(z, 1)

修改完毕的output仅保留拼接后的结果。

剪去多余节点:

之后发现这个onnx还是很丑啊

发现真正导致变丑的原因在于这些节点 比如Gather。所以下一步就是要干掉他。

这一步就是将:

      for i in range(self.nl):x[i] = self.m[i](x[i])  # convbs, _, ny, nx = x[i].shape  # x(bs,255,20,20) to x(bs,3,20,20,85)x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()

改为:

            x[i] = self.m[i](x[i])  # convbs, _, ny, nx = map(int,x[i].shape)  # x(bs,255,20,20) to x(bs,3,20,20,85)x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()

可以看到明显改善了不少。

调整reshape

 但是在reshape中可以看到中间维度是-1,而我们要的是batch是-1

 

 而bs是-1 , y.view还有个-1,这肯定是不行的,那么我们就要手动把这个计算出来,首先y的shape和x的shape一样,x的shape是bs*self.na*self.no*ny*nx,那么这里就是y.view(bs , self.na*nx*ny,self,no)

之后保存再次导出,可以看到已经变成batch的-1了

修改多余节点:

 但是发现还有比如expand这种节点,推断可能是由于数据跟踪引起的

    def _make_grid(self, nx=20, ny=20, i=0):d = self.anchors[i].deviceyv, xv = torch.meshgrid([torch.arange(ny).to(d), torch.arange(nx).to(d)])grid = torch.stack((xv, yv), 2).expand((1, self.na, ny, nx, 2)).float()anchor_grid = (self.anchors[i].clone() * self.stride[i]) \.view((1, self.na, 1, 1, 2)).expand((1, self.na, ny, nx, 2)).float()

所以在这里就没必要每个都保存起来,直接给一个常量值就可以。

anchor_grid = (self.anchors[i].clone() * self.stride[i]).view(1,-1,1,1,2)

然后将所有用到self.anchor_grid的部分都替换为anchor_grid。

这样看起来就变成很平整的样子了

 

刚刚的那一大堆就变成了1*3*1*1*2这样子的常量值,这个kind是Initializer,就是常量的这个意思。

CPP推理过程:

TRT:

 YOLO:

 可以看到置信度稍微有一些区别。

输入input作warpaffine:

因为我们的输入是一个确认了的输入:是640*640。所以要对图像做一个类似warpaffine。就是等比缩放剧中填充

 

///// letter boxauto image = cv::imread("car.jpg");// 通过双线性插值对图像进行resizefloat scale_x = input_width / (float)image.cols;float scale_y = input_height / (float)image.rows;float scale = std::min(scale_x, scale_y);float i2d[6], d2i[6];// resize图像,源图像和目标图像几何中心的对齐i2d[0] = scale;  i2d[1] = 0;  i2d[2] = (-scale * image.cols + input_width + scale  - 1) * 0.5;i2d[3] = 0;  i2d[4] = scale;  i2d[5] = (-scale * image.rows + input_height + scale - 1) * 0.5;cv::Mat m2x3_i2d(2, 3, CV_32F, i2d);  // image to dst(network), 2x3 matrixcv::Mat m2x3_d2i(2, 3, CV_32F, d2i);  // dst to image, 2x3 matrixcv::invertAffineTransform(m2x3_i2d, m2x3_d2i);  // 计算一个反仿射变换
//为什么要计算逆矩阵:因为正矩阵是图像变成warpaffine的过程,逆变换是把框变回到图像尺度的过程cv::Mat input_image(input_height, input_width, CV_8UC3);cv::warpAffine(image, input_image, m2x3_i2d, input_image.size(), cv::INTER_LINEAR, cv::BORDER_CONSTANT, cv::Scalar::all(114));  // 对图像做平移缩放旋转变换,可逆,填充全是常量值,114cv::imwrite("input-image.jpg", input_image);
//存储一下warpaffine效果int image_area = input_image.cols * input_image.rows;unsigned char* pimage = input_image.data;float* phost_b = input_data_host + image_area * 0;float* phost_g = input_data_host + image_area * 1;float* phost_r = input_data_host + image_area * 2;for(int i = 0; i < image_area; ++i, pimage += 3){// 注意这里的顺序rgb调换了*phost_r++ = pimage[0] / 255.0f;*phost_g++ = pimage[1] / 255.0f;*phost_b++ = pimage[2] / 255.0f;}///checkRuntime(cudaMemcpyAsync(input_data_device, input_data_host, input_numel * sizeof(float), cudaMemcpyHostToDevice, stream));

存储一下warpaffine效果

之后就是作推理:

// 3x3输入,对应3x3输出auto output_dims = engine->getBindingDimensions(1);int output_numbox = output_dims.d[1];int output_numprob = output_dims.d[2];int num_classes = output_numprob - 5;//类别数int output_numel = input_batch * output_numbox * output_numprob;float* output_data_host = nullptr;float* output_data_device = nullptr;checkRuntime(cudaMallocHost(&output_data_host, sizeof(float) * output_numel));checkRuntime(cudaMalloc(&output_data_device, sizeof(float) * output_numel));// 明确当前推理时,使用的数据输入大小auto input_dims = engine->getBindingDimensions(0);input_dims.d[0] = input_batch;execution_context->setBindingDimensions(0, input_dims);float* bindings[] = {input_data_device, output_data_device};bool success      = execution_context->enqueueV2((void**)bindings, stream, nullptr);checkRuntime(cudaMemcpyAsync(output_data_host, output_data_device, sizeof(float) * output_numel, cudaMemcpyDeviceToHost, stream));checkRuntime(cudaStreamSynchronize(stream));

这个结果就是我们之前YOLOV5的predict(https://blog.csdn.net/zhuangtu1999/article/details/131499750?spm=1001.2014.3001.5501)

但在这里根之前不太一样了:

vector<vector<float>> bboxes;float confidence_threshold = 0.25;float nms_threshold = 0.5;for(int i = 0; i < output_numbox; ++i){float* ptr = output_data_host + i * output_numprob;float objness = ptr[4];if(objness < confidence_threshold)continue;float* pclass = ptr + 5;int label     = std::max_element(pclass, pclass + num_classes) - pclass;float prob    = pclass[label];float confidence = prob * objness;if(confidence < confidence_threshold)continue;// 中心点、宽、高float cx     = ptr[0];float cy     = ptr[1];float width  = ptr[2];float height = ptr[3];// 预测框float left   = cx - width * 0.5;float top    = cy - height * 0.5;float right  = cx + width * 0.5;float bottom = cy + height * 0.5;// 对应图上的位置float image_base_left   = d2i[0] * left   + d2i[2];float image_base_right  = d2i[0] * right  + d2i[2];float image_base_top    = d2i[0] * top    + d2i[5];float image_base_bottom = d2i[0] * bottom + d2i[5];bboxes.push_back({image_base_left, image_base_top, image_base_right, image_base_bottom, (float)label, confidence});}printf("decoded bboxes.size = %d\n", bboxes.size());

这里的预测框,left,top等等对应的是warpaffine之后的图片,但我们要做的是把他在原来的图片上加入回来,所以还要做一个反变换的过程。

这里也是我们值前提到过,咱们只有缩放和平移的时候,有效的参数只有三个:scale , dx , dy,这里对应的就是d2i[0] , d2i[2] , d2i[5]。

在之后就是nms:

// nms非极大抑制std::sort(bboxes.begin(), bboxes.end(), [](vector<float>& a, vector<float>& b){return a[5] > b[5];});std::vector<bool> remove_flags(bboxes.size());std::vector<vector<float>> box_result;box_result.reserve(bboxes.size());auto iou = [](const vector<float>& a, const vector<float>& b){float cross_left   = std::max(a[0], b[0]);float cross_top    = std::max(a[1], b[1]);float cross_right  = std::min(a[2], b[2]);float cross_bottom = std::min(a[3], b[3]);float cross_area = std::max(0.0f, cross_right - cross_left) * std::max(0.0f, cross_bottom - cross_top);float union_area = std::max(0.0f, a[2] - a[0]) * std::max(0.0f, a[3] - a[1]) + std::max(0.0f, b[2] - b[0]) * std::max(0.0f, b[3] - b[1]) - cross_area;if(cross_area == 0 || union_area == 0) return 0.0f;return cross_area / union_area;};for(int i = 0; i < bboxes.size(); ++i){if(remove_flags[i]) continue;auto& ibox = bboxes[i];box_result.emplace_back(ibox);for(int j = i + 1; j < bboxes.size(); ++j){if(remove_flags[j]) continue;auto& jbox = bboxes[j];if(ibox[4] == jbox[4]){// class matchedif(iou(ibox, jbox) >= nms_threshold)remove_flags[j] = true;}}}printf("box_result.size = %d\n", box_result.size());

通过cv::rectangle画框:

    for(int i = 0; i < box_result.size(); ++i){auto& ibox = box_result[i];float left = ibox[0];float top = ibox[1];float right = ibox[2];float bottom = ibox[3];int class_label = ibox[4];float confidence = ibox[5];cv::Scalar color;tie(color[0], color[1], color[2]) = random_color(class_label);//通过标签随机选择颜色cv::rectangle(image, cv::Point(left, top), cv::Point(right, bottom), color, 3);auto name      = cocolabels[class_label];auto caption   = cv::format("%s %.2f", name, confidence);int text_width = cv::getTextSize(caption, 0, 1, 2, nullptr).width + 10;cv::rectangle(image, cv::Point(left-3, top-33), cv::Point(left + text_width, top), color, -1);cv::putText(image, caption, cv::Point(left, top-5), 0, 1, cv::Scalar::all(0), 2, 16);}cv::imwrite("image-draw.jpg", image);checkRuntime(cudaStreamDestroy(stream));checkRuntime(cudaFreeHost(input_data_host));checkRuntime(cudaFreeHost(output_data_host));checkRuntime(cudaFree(input_data_device));checkRuntime(cudaFree(output_data_device));
}

总结:

在这次过程中,warpaffine(预处理)和后处理过程都可以使用我们之前的核函数去处理,这一部分打包到GPU上的话性能会变得更高。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/2721.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

华为云子网路由表作用及价值

子网路由表 子网路由表作用云专线、VPN的配置与子网路由表强关联&#xff0c;本质是在相应的子网路由表中添加了一条路由Nat路由表问题地址变更问题snat和dnat 子网路由表作用 子网内部作为一个二层网络&#xff0c;通过mac地址互通&#xff0c;不通过路由互通。跨子网&#x…

实时网络更改检测

未经授权的配置更改可能会对业务连续性造成严重破坏&#xff0c;这就是为什么使用实时更改检测来检测和跟踪更改是网络管理员的一项关键任务。尽管可以手动跟踪更改&#xff0c;但此方法往往非常耗时&#xff0c;并且通常会导致人为错误&#xff0c;例如在跟踪时错过关键网络设…

企业需要一个数字体验平台(DXP)吗?

数字体验平台是一个软件框架&#xff0c;通过与不同的业务系统喝解决方案集成&#xff0c;帮助企业和机构建立、管理和优化跨渠道的数字体验。帮助企业实现跨网站、电子邮件、移动应用、社交平台、电子商务站点、物联网设备、数字标牌、POS系统等传播内容&#xff0c;除了为其中…

文心一言 VS 讯飞星火 VS chatgpt (58)-- 算法导论6.4 2题

文心一言 VS 讯飞星火 VS chatgpt &#xff08;58&#xff09;-- 算法导论6.4 2题 二、试分析在使用下列循环不变量时&#xff0c;HEAPSORT 的正确性&#xff1a;在算法的第 2~5行 for 循环每次迭代开始时&#xff0c;子数组 A[1…i]是一个包含了数组A[1…n]中第i小元素的最大…

如果微信消息显示“已读”的话......

近日&#xff0c;一则 #如果微信显示已读的话# 话题冲上了微博热搜榜单。 “已读”是很多社交软件拥有的功能&#xff0c;如果对方接收并查看了消息&#xff0c;就会在消息上显示“已读”&#xff0c;但目前微信还没有推出这项功能。 对于“已读”功能&#xff0c;不少网友纷纷…

自动化用例编写思路 (使用pytest编写一个测试脚本)

目录 一&#xff0c;明确测试对象 二&#xff0c;编写测试用例 构造请求数据 封装测试代码 断言设置 三&#xff0c;执行脚本获取测试结果 四&#xff0c;总结 经过之前的学习铺垫&#xff0c;我们尝试着利用pytest框架编写一条接口自动化测试用例&#xff0c;来厘清接口…

【CNN记录】pytorch中BatchNorm2d

torch.nn.BatchNorm2d(num_features, eps1e-05, momentum0.1, affineTrue, track_running_statsTrue, deviceNone, dtypeNone) 功能&#xff1a;对输入的四维数组进行批量标准化处理&#xff08;归一化&#xff09; 计算公式如下&#xff1a; 对于所有的batch中样本的同一个ch…

商城-学习整理-基础-环境搭建(二)

目录 一、环境搭建1、安装linux虚拟机1&#xff09;下载&安装 VirtualBox https://www.virtualbox.org/&#xff0c;要开启 CPU 虚拟化2&#xff09;虚拟机的网络设置3&#xff09;虚拟机允许使用账号密码登录4&#xff09;VirtualBox冲突5&#xff09;修改 linux 的 yum 源…

PyCharm 常用快捷键

目录 1、代码编辑快捷键 2、搜索/替换快捷键 3、代码运行快捷键 4、代码调试快捷键 5、应用搜索快捷键 6、代码重构快捷键 7、动态模块快捷键 8、导航快捷键 9、通用快捷键 1、代码编辑快捷键 序号快捷键作用1CTRLALTSPACE快速导入任意类2CTRLSHIFTENTER代码补全3SHI…

$.getScript()方法获取js文件

通过$.getScript(‘xxxx.js’)获取xxxx.js文件&#xff0c;这时的ajax是一个get请求的状态&#xff0c;如果进行了入参data的赋值那么他就会跟在url后面,同理获取json文件&#xff0c;css文件。 一开始没想起这茬。。。

曲师大2023大一新生排位赛-B.Sort题解

题目描述 插入排序是一种非常常见且简单的排序算法。王同学是一名大一的新生&#xff0c;今天许师哥刚刚在上课的时候讲了插入排序算法。 假设比较两个元素的时间为 &#xff0c;则插入排序可以以 的时间复杂度完成长度为 n&#xfffd; 的数组的排序。不妨假设这 n 个数字分…

如何在PADS Logic中查找器件

PADS Logic提供类似于Windows的查找功能&#xff0c;可以进行器件的查找。 &#xff08;1&#xff09;在Logic设计界面中&#xff0c;将菜单显示中的“选择工具栏”进行打开&#xff0c;如图1所示&#xff0c;会弹出对应的“选择工具栏”的分栏菜单选项&#xff0c;如图2所示。…

IDE /字符串 /字符编码与文本文件(如cpp源代码文件)

文章目录 概述文本编辑器如何识别文件的编码格式优先推测使用了UTF-8编码&#xff1f;字符编码的BOM字节序标记重分析各文本编辑器下的测试效果Qt Creator的文本编辑器系统记事本VS的文本编辑器Notepad 编译器与代码文件的字符编码ANSI编码其他 概述 前期在整理 《IDE/VS项目属…

大数据存储架构详解:数据仓库、数据集市、数据湖、数据网格、湖仓一体

前言 本文隶属于专栏《大数据理论体系》&#xff0c;该专栏为笔者原创&#xff0c;引用请注明来源&#xff0c;不足和错误之处请在评论区帮忙指出&#xff0c;谢谢&#xff01; 本专栏目录结构和参考文献请见大数据理论体系 思维导图 数据仓库 数据仓库是一个面向主题的&…

如何提升环境、生态、水文、土地、土壤、农业、大气等领域的数据分析能力

专题一、空间数据获取与制图 1.1 软件安装与应用讲解 1.2 空间数据介绍 1.3海量空间数据下载 1.4 ArcGIS软件快速入门 1.5 Geodatabase地理数据库 专题二、ArcGIS专题地图制作 2.1专题地图制作规范 2.2 空间数据的准备与处理 2.3 空间数据可视化&#xff1a;地图符号与…

酷开科技以内容为核心打造OTT大屏营销投放新体系

如何打造“因地制宜”的营销策略&#xff0c;围绕内容场景&#xff0c;搭建更具效能的OTT大屏营销投放体系&#xff1f;是一个值得思考的问题。 酷开科技OTT大屏营销&#xff0c; 以营销内容为核心、通过更加立体化的沟通模式&#xff0c;创新性整合和打通多元资源&#xff0c…

数据结构--图的存储邻接表法

数据结构–图的存储邻接表法 邻接矩阵&#xff1a; 数组实现的顺序存储&#xff0c;空间复杂度高&#xff0c;不适合存储稀疏图 邻接表&#xff1a; 顺序链式存储 邻接表法&#xff08;顺序链式存储&#xff09; //边/弧 typedef struct ArcNode {int adjvex; //边/弧指向哪个…

echarts 单数据,平滑曲线柱状图显示

var myChart echarts.init(document.getElementById(main)); let namelist [23/01, 23/02, 23/03, 23/04, 23/05, 23/06, YTD] let planList [10.9, 22.6, 15.6, 14.4, 12.0, 12.3, 14.6] let target 14 // 指定图表的配置项和数据 var option { tooltip: { },//提示语 xA…

wampserver的mysql8.0版本在my.ini文件中加入skip_grant_tables无效等一系列问题。

背景&#xff1a;安装了新的wampserver之后&#xff0c;php版本mysql8.0.31&#xff0c;想打开phpadmin可视化管理页面&#xff0c;后来忘记密码了&#xff0c;报错&#xff1a;ERROR 1045 (28000): Access denied for user rootlocalhost (using password: No)&#xff0c;只能…

Linux搭建SVN环境(最新版)

最新版本号(svn-1.14) https://opensource.wandisco.com/centos/7 更新版本库 sudo tee /etc/yum.repos.d/wandisco-svn.repo <<-EOF [WandiscoSVN] nameWandisco SVN Repo baseurlhttp://opensource.wandisco.com/centos/$releasever/svn-1.14/RPMS/$basearch/ enabled…