yolov5导出onnx模型问题

为了适配C++工程代码,我在导出onnx模型时,会把models/yolo.py里面的forward函数改成下面这样,

    #转模型def forward(self, x):z = []  # inference outputfor i in range(self.nl):x[i] = self.m[i](x[i])  # convbs, _, ny, nx = x[i].shape  # x(bs,255,20,20) to x(bs,3,20,20,85)x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()if not self.training:  # inferenceif self.onnx_dynamic or self.grid[i].shape[2:4] != x[i].shape[2:4]:self.grid[i], self.anchor_grid[i] = self._make_grid(nx, ny, i)y = x[i].sigmoid()if self.inplace:y[..., 0:2] = (y[..., 0:2] * 2 - 0.5 + self.grid[i]) * self.stride[i]  # xyy[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]  # wh z.append(y.view(bs, -1, self.no))                    else:  # for YOLOv5 on AWS Inferentia https://github.com/ultralytics/yolov5/pull/2953xy = (y[..., 0:2] * 2 - 0.5 + self.grid[i]) * self.stride[i]  # xywh = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]  # whanchor, conf, prob = torch.split(y, [4, 1, self.nc], dim=4)# add a idx (label ids before prob)# oriidxs = torch.argmax(prob, dim=-1).unsqueeze(axis=-1).type(x[i].dtype)# new#idxs = torch.max(prob, dim=-1)[1].data.unsqueeze(axis=-1).type(x[i].dtype)y = torch.cat((xy, wh, conf, idxs, prob), -1)z.append(y.view(bs, -1, self.no + 1))return x if self.training else (torch.cat(z, 1))

也就是把后面类别得分中最大的那个计算出来赋值给idxs,

原来的yolov5输出是x y w h box_score  label1_confidence label2_confidence ....  labeln_confidence.

我改完之后,输出变成x y w h box_score idxs  label1_confidence label2_confidence ....  labeln_confidence.

然后之前我都是在转onnx之前手动的去改代码,然后转完模型再改回来因为train和detect也要用到这个yolo.py中的forward函数,但是后来某项目中,要实现一个自动训练、自动检测、自动转模型,这就不能我手动改了,所以我第一个方法是我复制一份yolo.py复制成yolo_onnx.py,然后export.py中from models.yolo_onnx import Detect,这种方法不可行,因为其他还有还有很多地方也是用的from models.yolo import Detect,最后用的方法如下:

首先在yolo.py中的Detect类中增加一个成员export

class Detect(nn.Module):stride = None  # strides computed during buildonnx_dynamic = False  # ONNX export parameterexport = False  #增加的成员......

然后我在export.py的run函数中给这个值赋值为true


@torch.no_grad()
def run(data=ROOT / 'data/coco128.yaml',  # 'dataset.yaml path'weights=ROOT / 'yolov5s.pt',  # weights pathimgsz=(640, 640),  # image (height, width)batch_size=1,  # batch sizedevice='cpu',  # cuda device, i.e. 0 or 0,1,2,3 or cpuinclude=('torchscript', 'onnx'),  # include formatshalf=False,  # FP16 half-precision exportinplace=False,  # set YOLOv5 Detect() inplace=Truetrain=False,  # model.train() modeoptimize=False,  # TorchScript: optimize for mobileint8=False,  # CoreML/TF INT8 quantizationdynamic=False,  # ONNX/TF: dynamic axessimplify=False,  # ONNX: simplify modelopset=12,  # ONNX: opset versionverbose=False,  # TensorRT: verbose logworkspace=4,  # TensorRT: workspace size (GB)nms=False,  # TF: add NMS to modelagnostic_nms=False,  # TF: add agnostic NMS to modeltopk_per_class=100,  # TF.js NMS: topk per class to keeptopk_all=100,  # TF.js NMS: topk for all classes to keepiou_thres=0.45,  # TF.js NMS: IoU thresholdconf_thres=0.25  # TF.js NMS: confidence threshold):t = time.time()include = [x.lower() for x in include]  # to lowercaseformats = tuple(export_formats()['Argument'][1:])  # --include argumentsflags = [x in include for x in formats]assert sum(flags) == len(include), f'ERROR: Invalid --include {include}, valid --include arguments are {formats}'jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs = flags  # export booleansfile = Path(url2file(weights) if str(weights).startswith(('http:/', 'https:/')) else weights)  # PyTorch weights# Load PyTorch modeldevice = select_device(device)assert not (device.type == 'cpu' and half), '--half only compatible with GPU export, i.e. use --device 0'model = attempt_load(weights, map_location=device, inplace=True, fuse=True)  # load FP32 modelnc, names = model.nc, model.names  # number of classes, class namesmodel.model[-1].export = True# Checksimgsz *= 2 if len(imgsz) == 1 else 1  # expandopset = 12 if ('openvino' in include) else opset  # OpenVINO requires opset <= 12assert nc == len(names), f'Model class count {nc} != len(names) {len(names)}'# Inputgs = int(max(model.stride))  # grid size (max stride)imgsz = [check_img_size(x, gs) for x in imgsz]  # verify img_size are gs-multiplesim = torch.zeros(batch_size, 3, *imgsz).to(device)  # image size(1,3,320,192) BCHW iDetection# Update modelif half:im, model = im.half(), model.half()  # to FP16model.train() if train else model.eval()  # training mode = no Detect() layer grid constructionfor k, m in model.named_modules():if isinstance(m, Conv):  # assign export-friendly activationsif isinstance(m.act, nn.SiLU):m.act = SiLU()elif isinstance(m, Detect):m.inplace = inplacem.onnx_dynamic = dynamicif hasattr(m, 'forward_export'):m.forward = m.forward_export  # assign custom forward (optional)for _ in range(2):y = model(im)  # dry runsshape = tuple(y[0].shape)  # model output shapeLOGGER.info(f"\n{colorstr('PyTorch:')} starting from {file} with output shape {shape} ({file_size(file):.1f} MB)")# Exportsf = [''] * 10  # exported filenameswarnings.filterwarnings(action='ignore', category=torch.jit.TracerWarning)  # suppress TracerWarningif jit:f[0] = export_torchscript(model, im, file, optimize)if engine:  # TensorRT required before ONNXf[1] = export_engine(model, im, file, train, half, simplify, workspace, verbose)if onnx or xml:  # OpenVINO requires ONNXf[2] = export_onnx(model, im, file, opset, train, dynamic, simplify)if xml:  # OpenVINOf[3] = export_openvino(model, im, file)if coreml:_, f[4] = export_coreml(model, im, file)# TensorFlow Exportsif any((saved_model, pb, tflite, edgetpu, tfjs)):if int8 or edgetpu:  # TFLite --int8 bug https://github.com/ultralytics/yolov5/issues/5707check_requirements(('flatbuffers==1.12',))  # required before `import tensorflow`assert not (tflite and tfjs), 'TFLite and TF.js models must be exported separately, please pass only one type.'model, f[5] = export_saved_model(model, im, file, dynamic, tf_nms=nms or agnostic_nms or tfjs,agnostic_nms=agnostic_nms or tfjs, topk_per_class=topk_per_class,topk_all=topk_all, conf_thres=conf_thres, iou_thres=iou_thres)  # keras modelif pb or tfjs:  # pb prerequisite to tfjsf[6] = export_pb(model, im, file)if tflite or edgetpu:f[7] = export_tflite(model, im, file, int8=int8 or edgetpu, data=data, ncalib=100)if edgetpu:f[8] = export_edgetpu(model, im, file)if tfjs:f[9] = export_tfjs(model, im, file)# Finishf = [str(x) for x in f if x]  # filter out '' and Noneif any(f):LOGGER.info(f'\nExport complete ({time.time() - t:.2f}s)'f"\nResults saved to {colorstr('bold', file.parent.resolve())}"f"\nDetect:          python detect.py --weights {f[-1]}"f"\nPyTorch Hub:     model = torch.hub.load('ultralytics/yolov5', 'custom', '{f[-1]}')"f"\nValidate:        python val.py --weights {f[-1]}"f"\nVisualize:       https://netron.app")return f  # return list of exported files/dirs

然后修改yolo.py中的forward函数,增加分支判断

def forward(self, x):if self.export:print("self.export===============",self.export)z = []  # inference outputfor i in range(self.nl):x[i] = self.m[i](x[i])  # convbs, _, ny, nx = x[i].shape  # x(bs,255,20,20) to x(bs,3,20,20,85)x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()if not self.training:  # inferenceif self.onnx_dynamic or self.grid[i].shape[2:4] != x[i].shape[2:4]:self.grid[i], self.anchor_grid[i] = self._make_grid(nx, ny, i)y = x[i].sigmoid()if self.inplace:y[..., 0:2] = (y[..., 0:2] * 2 - 0.5 + self.grid[i]) * self.stride[i]  # xyy[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]  # wh z.append(y.view(bs, -1, self.no))                    else:  # for YOLOv5 on AWS Inferentia https://github.com/ultralytics/yolov5/pull/2953xy = (y[..., 0:2] * 2 - 0.5 + self.grid[i]) * self.stride[i]  # xywh = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]  # whanchor, conf, prob = torch.split(y, [4, 1, self.nc], dim=4)# add a idx (label ids before prob)# oriidxs = torch.argmax(prob, dim=-1).unsqueeze(axis=-1).type(x[i].dtype)# new#idxs = torch.max(prob, dim=-1)[1].data.unsqueeze(axis=-1).type(x[i].dtype)y = torch.cat((xy, wh, conf, idxs, prob), -1)z.append(y.view(bs, -1, self.no + 1))return x if self.training else (torch.cat(z, 1))else:print("self.export===============",self.export)z = []  # inference outputfor i in range(self.nl):x[i] = self.m[i](x[i])  # convbs, _, ny, nx = x[i].shape  # x(bs,255,20,20) to x(bs,3,20,20,85)x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()if not self.training:  # inferenceif self.onnx_dynamic or self.grid[i].shape[2:4] != x[i].shape[2:4]:self.grid[i], self.anchor_grid[i] = self._make_grid(nx, ny, i)y = x[i].sigmoid()if self.inplace:y[..., 0:2] = (y[..., 0:2] * 2 - 0.5 + self.grid[i]) * self.stride[i]  # xyy[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]  # whelse:  # for YOLOv5 on AWS Inferentia https://github.com/ultralytics/yolov5/pull/2953xy = (y[..., 0:2] * 2 - 0.5 + self.grid[i]) * self.stride[i]  # xywh = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]  # why = torch.cat((xy, wh, y[..., 4:]), -1)z.append(y.view(bs, -1, self.no))return x if self.training else (torch.cat(z, 1), x)

这样就可以实现train和export分别跑不同的代码了。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/661396.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

超级详细——手撕贪吃蛇小游戏!

目录 前言 1. Win32 API介绍 1.1 Win32 API 1.2 控制台程序 1.3 控制台屏幕上的坐标COORD 1.4 GetStdHandle 1.5 GetConsoleCursorInfo 1.6 CONSOLE_CURSOR_INFO 1.7 SetConsoleCursorInfo 1.8 SetConsoleCursorPosition 1.8 GetAsyncKeyState 2.贪吃蛇游戏设计 2.…

物联网浏览器(IoTBrowser)-Modbus协议集成和测试

Modbus协议在应用中一般用来与PLC或者其他硬件设备通讯&#xff0c;Modbus集成到IoTBrowser使用串口插件模式开发&#xff0c;不同的是采用命令函数&#xff0c;具体可以参考前面几篇文章。目前示例实现了Modbus-Rtu和Modbus-Tcp两种&#xff0c;通过js可以与Modbus进行通讯控制…

代码随想录算法训练营第三十六天| 435. 无重叠区间、763.划分字母区间、56. 合并区间

435. 无重叠区间 题目链接&#xff1a;力扣&#xff08;LeetCode&#xff09;官网 - 全球极客挚爱的技术成长平台 解题思路&#xff1a;按左边界进行由小到大排序&#xff0c;然后前一个的右边界和后一个的左边界相比&#xff0c;看是否相交&#xff0c;如果相交&#xff0c;…

【EI会议征稿中|ACM出版】#先投稿,先送审#第三届网络安全、人工智能与数字经济国际学术会议(CSAIDE 2024)​

#先投稿&#xff0c;先送审#ACM出版#第三届网络安全、人工智能与数字经济国际学术会议&#xff08;CSAIDE 2024&#xff09; 2024 3rd International Conference on Cyber Security, Artificial Intelligence and Digital Economy 2024年3月8日-10日 | 中国济南 会议官网&…

oracle数据库慢查询SQL

目录 场景&#xff1a; 环境&#xff1a; 慢SQL查询一&#xff1a; 问题一&#xff1a;办件列表查询慢 分析&#xff1a; 解决方法&#xff1a; 问题二&#xff1a;系统性卡顿 分析&#xff1a; 解决方法&#xff1a; 慢SQL查询二 扩展&#xff1a; 场景&#xff1a; 线…

CXO清单:低代码平台必备的16个基本功能:从需求到实现的全面指南

对于 CIO、CTO 和 CDO&#xff08;在此统称为 CXO&#xff09;来说&#xff0c;认识到快速变化的技术和竞争格局以及他们在组织中的角色变化至关重要。处理持续不断的软件开发请求、考虑不断变化的业务流程、提高客户和法规的透明度、提高企业数据安全性以及在短时间内扩展基础…

精酿啤酒:麦芽汁的煮沸与沸腾时间的影响

在啤酒酿造过程中&#xff0c;麦芽汁的煮沸与沸腾时间是关键的工艺参数&#xff0c;对啤酒的品质和口感具有显著影响。对于Fendi Club啤酒来说&#xff0c;合理控制煮沸与沸腾时间更是重要。 首先&#xff0c;麦芽汁的煮沸时间对啤酒的口感和稳定性有重要影响。煮沸时间过短&am…

如何使用宝塔面板搭建MySQL 5.5数据库并实现公网远程连接

文章目录 前言1.Mysql服务安装2.创建数据库3.安装cpolar3.2 创建HTTP隧道 4.远程连接5.固定TCP地址5.1 保留一个固定的公网TCP端口地址5.2 配置固定公网TCP端口地址 前言 宝塔面板的简易操作性,使得运维难度降低,简化了Linux命令行进行繁琐的配置,下面简单几步,通过宝塔面板cp…

详解Keras3.0 Layer API: Base RNN layer

RNN layer keras.layers.RNN(cell,return_sequencesFalse,return_stateFalse,go_backwardsFalse,statefulFalse,unrollFalse,zero_output_for_maskFalse,**kwargs ) 参数说明 cell: 这是循环神经网络的单元类型&#xff0c;可以是LSTM、GRU等。它定义了循环神经网络的基本单…

linux系统上C程序的编译、运行及调试-gcc

gcc -o timer timer.c &#xff1a;生成可执行文件main&#xff0c;依托main.c,也可依托多个文件./timer :运行代码

【0254】深入分析Query Execution(二)

上一篇:【0253】深入分析Query Execution(一) 1. 转换(Transformation) 在下一阶段,可以对查询进行转换(重写, rewritten)。 PostgreSQL核心使用转换有几个目的。其中之一是将解析树中的视图名称替换为与该视图的基本查询相对应的子树。 使用转换的另一种情况是行级…

Skywalking的Trace Profiling 代码级性能剖析功能应用详解

代码级性能剖析 Skywalking 提供了Trace Profiling功能对具体出现问题的span进行代码级性能剖析。 代码级性能剖析就是利用方法栈快照&#xff0c;并对方法执行情况进行分析和汇总。并结合有限的分布式追踪 span 上下文&#xff0c;对代码执行速度进行估算。性能剖析激活时&a…

[C#][opencvsharp]winform实现自定义卷积核锐化和USM锐化

【锐化介绍】 图像锐化(image sharpening)是补偿图像的轮廓&#xff0c;增强图像的边缘及灰度跳变的部分&#xff0c;使图像变得清晰&#xff0c;分为空间域处理和频域处理两类。图像锐化是为了突出图像上地物的边缘、轮廓&#xff0c;或某些线性目标要素的特征。这种滤波方法…

【Python_PySide6学习笔记(三十三)】文本编辑框QTextEdit添加图片

文本编辑框QTextEdit添加图片 文本编辑框QTextEdit添加图片前言一、创建 QTextEdit 对象二、通过 QImage 加载图片,并调整图片的大小及比例三、创建 QTextCursor 对象四、通过QTextCursor 对象的 insertImage() 将图片插入到 QTextEdit 中五、完整代码及实现效果文本编辑框QTe…

详细分析SpringSecurity中的@PreAuthorize注解

目录 1. 基本知识2. 使用方式2.1 配置类2.2 直接使用 1. 基本知识 在Java中&#xff0c;PreAuthorize 是Spring Security框架中的一个注解&#xff0c;用于在方法调用之前对用户的权限进行验证。 允许在方法级别定义访问控制规则&#xff0c;确保只有满足指定条件的用户才能调…

boost asio对于epoll关闭套接字顺序

其方法定义在 boost::system::error_code reactive_socket_service_base::close(reactive_socket_service_base::base_implementation_type& impl,boost::system::error_code& ec) {if (is_open(impl)){BOOST_ASIO_HANDLER_OPERATION(("socket", &impl,…

HarmonyOS ArkUI基础学习01

以下涉及的项目源码地址&#xff1a; https://gitee.com/jiangqianghua/harmony-test 更多学习资源资源点我获取 1. 一些常用组件方法 加载resource/base/element/string.json资源 Text($r("app.string.Username_label"))设置颜色 Color.red“#ff00ff”读取资源文…

Java基础 集合(二)List详解

目录 简介 数组与集合的区别如下&#xff1a; 介绍 AbstractList 和 AbstractSequentialList Vector 替代方案 Stack ArrayList LinkedList 前言-与正文无关 生活远不止眼前的苦劳与奔波&#xff0c;它还充满了无数值得我们去体验和珍惜的美好事物。在这个快节奏的世界…

nodejs+vue+ElementUi家庭美食菜谱分享网站_in9c2

&#xff08;设计制作有一定的安全性&#xff1b;数据库方面主要采用的是MySQL来进行开发&#xff0c;其特点是稳定性好&#xff0c;数据库存储容量大&#xff0c;处理能力快等优势&#xff1b;服务器采用的是Tomcat服务&#xff0c;能够提供稳固的运行平台&#xff0c;确保系统…

JavaSE-项目小结-IP归属地查询(本地IP地址库)

一、项目介绍 1. 背景 IP地址是网络通信中的重要标识&#xff0c;通过分析IP地址的归属地信息&#xff0c;可以帮助我们了解访问来源、用户行为和网络安全等关键信息。例如应用于网站访问日志分析&#xff1a;通过分析访问日志中的IP地址&#xff0c;了解网站访问者的地理位置分…