TensorRT加速推理入门-1:Pytorch转ONNX

这篇文章,用于记录将TransReID的pytorch模型转换为onnx的学习过程,期间参考和学习了许多大佬编写的博客,在参考文章这一章节中都已列出,非常感谢。

1. 在pytorch下使用ONNX主要步骤

1.1. 环境准备

安装onnxruntime包
安装教程可参考:
onnx模型预测环境安装笔记
onnxruntime配置
CPU版本:
直接pip安装

pip install onnxruntime

GPU版本:
先查看自己CUDA版本然后在下面的链接去找对应的onnxruntime的版本
CUDA版本的查询,可参考这个
onnxruntime版本查询
查询到对应版本,直接pip安装即可,例如

pip install onnxruntime-gpu==1.13.1

安装onnxsim包

pip install onnx-simplifier

1.2. 搭建 PyTorch 模型(TransReID)

def get_net(model_path,opt_=False):if opt_:cfg.merge_from_file("/home/TransReID-main/configs/OCC_Duke/vit_transreid_stride.yml")#cfg.freeze()train_loader, train_loader_normal, val_loader, num_query, num_classes, camera_num, view_num = make_dataloader(cfg)net = make_model(cfg, num_class=num_classes, camera_num=camera_num, view_num = view_num)else:cfg.merge_from_file("/home/TransReID-main/configs/OCC_Duke/vit_transreid_stride.yml")train_loader, train_loader_normal, val_loader, num_query, num_classes, camera_num, view_num = make_dataloader(cfg)net = make_model(cfg, num_class=num_classes, camera_num=camera_num, view_num = view_num)#state_dict = torch.load(model_path, map_location=torch.device('cpu'))['state_dict']state_dict = torch.load(model_path, map_location=torch.device('cpu'))model_state_dict=net.state_dict()for key in list(state_dict.keys()):if key[7:] in model_state_dict.keys():model_state_dict[key[7:]]=state_dict[key]net.load_state_dict(model_state_dict)return net

1.3. pytorch模型转换为 ONNX 模型

这个提供了静态转换(静态转换支持静态输入)和动态转换(动态转换支持动态输入)两个函数,可根据需要选择。

def convert_onnx_dynamic(model,save_path,simp=False):x = torch.randn(4, 3, 256,128)input_name = 'input'output_name = 'class'torch.onnx.export(model,x,save_path,input_names = [input_name],output_names = [output_name],dynamic_axes= {input_name: {0: 'B'},output_name: {0: 'B'}})if simp:onnx_model = onnx.load(save_path) model_simp, check = simplify(onnx_model,input_shapes={'input':(4,3,256,128)},dynamic_input_shape=True)assert check, "Simplified ONNX model could not be validated"onnx.save(model_simp, save_path)print('simplify onnx done')def convert_onnx(model,save_path,batch=1,simp=False):input_names = ['input']output_names=['class']x = torch.randn(batch, 3, 256, 128)for para in model.parameters():para.requires_grad = False# model_script = torch.jit.script(model)# model_trace = torch.jit.trace(model, x)torch.onnx.export(model, x, save_path,input_names =input_names,output_names=output_names, opset_version=12)if simp:onnx_model = onnx.load(save_path) model_simp, check = simplify(onnx_model)assert check, "Simplified ONNX model could not be validated"onnx.save(model_simp, save_path)print('simplify onnx done')

pytorch 转 onnx 仅仅需要一个函数 torch.onnx.export,来看看该函数的参数和用法。

torch.onnx.export(model, args, path, export_params, verbose, input_names, output_names, do_constant_folding, dynamic_axes, opset_version)
参数用法
model需要导出的pytorch模型
args模型的任意一组输入(模拟实际输入数据的大小,比如三通道的512*512大小的图片,就可以设置为torch.randn(1, 3 , 512, 512)
path输出的onnx模型的位置,例如yolov5.onnx
export_params输出模型是否可训练。default=True,表示导出trained model, 否则untrained。
verbose是否打印模型转换信息,default=None
opset_versiononnx算子集的版本
input_names模型的输入节点名称(自己定义的),如果不写,默认输出数字类型 的名称
output_name模型的输出节点名称(自己定义的), 如果不写,默认输出数字类型的名称
do_constant_folding是否使用常量折叠,默认即可。default=True。
dynamic_axes设置动态输入输出,用法:“输入输出名:[支持动态的维度”,如"支持动态的维度设置为[0, 2, 3]"则表示第0维,第2维,第3维支持动态输入输出。
模型的输入输出有时是可变的,如rnn,或者输出图像的batch可变,可通过该参数设置。如输入层的shape为(b, 3, h, w), 其中batch、height、width是可变的,但是chancel是固定三通道。格式如下:1)仅list(int)dynamic_axes={‘input’:[0, 2, 3], ‘output’:{0:‘batch’, 1:‘c’}} 2)仅dict<int, string> dynamic_axes={‘input’:{‘input’:{0:‘batch’, 2:‘height’, 3:‘width’}, ‘output’:{0:‘batch’, 1:‘c’}} 3)mixed dynamic_axes={‘input’:{0:‘batch’, 2:‘height’, 3:‘width’}, ‘output’:[0,1]}

注意onnx不支持结构中带有if语句的模型,如:

当我们在网络中嵌入一些if选择性的语句时,不好意思,模型不会考虑这些, 它只会记录下运行时走过的节点,不会根据if的实际情况来选择走哪条路, 所以势必会丢弃一部分节点,而丢弃哪些则是根据我们转模型时的输入来定的,一旦指定了,后面运行onnx模型都会如此。另一个问题就是,我们在代码中有一些循环或者迭代的操作时,要注意,尤其是我们的迭代次数是根据输入不同 会有变化时,也会因为这些操作导致后面的推理出现意外错误,正像前面说的,模型转换不喜欢不确定的东西,它会把这些变量dump成常量,所以会导致推理 错误。

对于实际部署的需求,很多时候pytorch是不满足的,所以需要转成其他模型格式来加快推理。常用的就是onnx,onnx天然支持很多框架模型的转换,如Pytorch,tf,darknet,caffe等。而pytorch也给我们提供了对应的接口,就是torch.onnx.export。下面具体到每一步。
原文来自:Windows下使用ONNX+pytorch记录

首先,环境和依赖:onnx包,cuda和cudnn,我用的版本号分别是1.7.0, 10.1, 7.5.4。
我们需要提供一个pytorch的模型,然后调用torch.onnx.export,同时还需要提供另外一些参数。我们一个个来分析,一是我们要给一个dummy input, 就是随便指定一个和我们实际输入时尺寸相同的一个随机数,是Tensor类型的,然后我们要指定转换的device,即是在gpu还是cpu。 然后我们要给一个input_names和output_names,这是绑定输入和输出,当然输入和输出可能不止一个,那就根据实际的输入和输出个数来给出name列表,
如果我们指定的输入和输出名和实际的网络结构不一致的话,onnx会自动给我们设置一个名字。一般是数字字符串。
输入和输出的绑定之后,我我们们可以看到还有一个参数叫做dynamic_axes,这是做什么的呢?哦,这是指定动态输入的,为了满足我们实际推理过程中,可能每张图片的分辨率不一样,所以允许我们给每个维度设置动态输入,这样是不是灵活多了?然后,设置完这些参数和输入,我们就可以开始转换模型了,如果不报错就是成功了,会在当前目录下生成一个.onnx文件。
原文来自: 一文掌握Pytorch-onnx-tensorrt模型转换

1.4 onnx-simplifier简化onnx模型

model_simp, check = simplify(onnx_model,input_shapes={'input':(4,3,256,128)},dynamic_input_shape=True)

Pytorch转换为ONNX的完整代码pytorch_to_onnx.py

import json
import os
import onnx
import torch
import argparse
import torch.nn as nn
from onnxsim import simplify
from collections import OrderedDict
import torch.nn.functional as F# TransReID的模型构建需要的包
from model.make_model import *
from config import cfg
from datasets import make_dataloader os.environ['CUDA_VISIBLE_DEVICES'] = '1'def convert_onnx_dynamic(model,save_path,simp=False):x = torch.randn(4, 3, 256,128)input_name = 'input'output_name = 'class'torch.onnx.export(model,x,save_path,input_names = [input_name],output_names = [output_name],dynamic_axes= {input_name: {0: 'B'},output_name: {0: 'B'}})if simp:onnx_model = onnx.load(save_path) model_simp, check = simplify(onnx_model,input_shapes={'input':(4,3,256,128)},dynamic_input_shape=True)assert check, "Simplified ONNX model could not be validated"onnx.save(model_simp, save_path)print('simplify onnx done')def convert_onnx(model,save_path,batch=1,simp=False):input_names = ['input']output_names=['class']x = torch.randn(batch, 3, 256, 128)for para in model.parameters():para.requires_grad = False# model_script = torch.jit.script(model)# model_trace = torch.jit.trace(model, x)torch.onnx.export(model, x, save_path,input_names =input_names,output_names=output_names, opset_version=12)if simp:onnx_model = onnx.load(save_path) model_simp, check = simplify(onnx_model)assert check, "Simplified ONNX model could not be validated"onnx.save(model_simp, save_path)print('simplify onnx done')def get_net(model_path,opt_=False):if opt_:cfg.merge_from_file("/home/TransReID-main/configs/OCC_Duke/vit_transreid_stride.yml")#cfg.freeze()train_loader, train_loader_normal, val_loader, num_query, num_classes, camera_num, view_num = make_dataloader(cfg)net = make_model(cfg, num_class=num_classes, camera_num=camera_num, view_num = view_num)else:cfg.merge_from_file("/home/TransReID-main/configs/OCC_Duke/vit_transreid_stride.yml")train_loader, train_loader_normal, val_loader, num_query, num_classes, camera_num, view_num = make_dataloader(cfg)net = make_model(cfg, num_class=num_classes, camera_num=camera_num, view_num = view_num)#state_dict = torch.load(model_path, map_location=torch.device('cpu'))['state_dict']state_dict = torch.load(model_path, map_location=torch.device('cpu'))model_state_dict=net.state_dict()for key in list(state_dict.keys()):if key[7:] in model_state_dict.keys():model_state_dict[key[7:]]=state_dict[key]net.load_state_dict(model_state_dict)return netif __name__=="__main__":parser = argparse.ArgumentParser(description='torch to onnx describe.')parser.add_argument("--model_path",type = str,default="/home/TransReID-main/weights/vit_transreid_occ_duke.pth",help="torch weight path, default is MobileViT_Pytorch/weights-file/model_best.pth.tar.")parser.add_argument("--save_path",type=str,default="/home/TransReID-main/weights/vit_transreid_occ_duke_v2.onnx",help="save direction of onnx models,default is ./target/MobileViT.onnx.")parser.add_argument("--batch",type=int,default=1,help="batchsize of onnx models, default is 1.")parser.add_argument("--opt",default=False, action='store_true',help="model optmization , default is False.")parser.add_argument("--dynamic",default=False, action='store_true',help="export  dynamic onnx model , default is False.")args = parser.parse_args()
#   print(args)#net = get_net(args.model_path,opt_=args.opt)net = get_net(args.model_path)if args.dynamic:convert_onnx_dynamic(net,args.save_path,simp=True)else:with torch.no_grad():convert_onnx(net,args.save_path,simp=True,batch=args.batch)

1.5 查看onnx模型

当将pytorch模型保存为 ONNX 之后,可以使用一款名为 Netron 的软件打开 .onnx 文件,查看模型结构。

2. 参考文章

[1] Windows下使用ONNX+pytorch记录
[2] pytorch-onnx-tensorrt全链路简单教程(支持动态输入)
[3] PyTorch语义分割模型转ONNX以及对比转换后的效果(PyTorch2ONNX、Torch2ONNX、pth2onnx、pt2onnx、修改名称、转换、测试、加载ONNX、运行ONNX)
[4] ONNX系列一:ONNX的使用,从转化到推理

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/611447.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

uniapp 多轴图,双轴图,指定哪几个数据在哪个轴上显示

这里使用的在这里导入&#xff0c; 秋云 ucharts echarts 高性能跨全端图表组件 - DCloud 插件市场 这里我封装成一个组件&#xff0c;自适应的&#xff0c;可以直接复制到自己的项目中 <template><qiun-data-charts type"mix":opts"opts":cha…

一网打尽!最佳新闻资讯App推荐,谁是你的首选?

注重个性化推荐&#xff0c;推荐&#xff1a;今日头条、一点资讯 注重传统新闻阅读体验&#xff0c;推荐&#xff1a;网易新闻、新浪新闻、搜狐新闻、东方头条 注重订阅推送阅读体验&#xff0c;推荐&#xff1a;红板报、ZAKER 注重评论和用户社区交互&#xff0c;推荐&…

[后端] 微服务的前世今生

微服务的前世今生 整体脉络: 单体 -> 垂直划分 -> SOA -> micro service 微服务 -> services mesh服务网格 -> future 文章目录 微服务的前世今生单一应用架构特征优点&#xff1a;缺点&#xff1a; 垂直应用架构特征优点缺点 SOA 面向服务架构特征优点缺点 微服…

2024年中国杭州|网络安全技能大赛(CTF)正式开启竞赛报名

前言 一、CTF简介 CTF&#xff08;Capture The Flag&#xff09;中文一般译作夺旗赛&#xff0c;在网络安全领域中指的是网络安全技术人员之间进行技术竞技的一种比赛形式。CTF起源于1996年DEFCON全球黑客大会&#xff0c;以代替之前黑客们通过互相发起真实攻击进行技术比拼的…

c++ fork, execl 参数 logcat | grep

Linux进程编程(PS: exec族函数、system、popen函数)_linux popen函数会新建进程吗-CSDN博客 execvp函数详解_如何在C / C 中使用execvp&#xff08;&#xff09;函数-CSDN博客 C语言的多进程fork()、函数exec*()、system()与popen()函数_c语言 多进程-CSDN博客 Linux---fork…

#3. 全排列问题

【问题描述】 输出自然数 1 到 n 所有不重复的排列&#xff0c;即 n 的全排列&#xff0c;要求所产生的任一数字序列中不允许出现重复的数字。 【输入格式】 一个整数n(1≤n≤9) 【输出格式】 按照字典序的顺序输出由 1&#xff5e;n 组成的所有不重复的数字序列&#xff…

生成式人工智能市场规模、趋势和统计数据(2024-2026)

生成式人工智能市场规模、趋势和统计数据&#xff08;2024-2026&#xff09; 目录 生成式人工智能市场规模、趋势和统计数据&#xff08;2024-2026&#xff09;一、生成式人工智能行业亮点二、生成式人工智能市场规模三、生成式人工智能市场增长预测四、生成式人工智能采用统计…

Linux信号处理浅析

一、信号从发送到被处理经历的过程 1、常见概念 (1) 信号阻塞 阻塞&#xff0c;即被进程拉黑&#xff0c;信号被发送后&#xff0c;分为两种情况&#xff0c;一种是被阻塞了&#xff08;被拉黑了&#xff09;&#xff0c;一种是没有被阻塞。 (2) 信号未决 在信号被进程处理…

RT-Thread:SPI万能驱动 SFUD 驱动Flash W25Q64,通过 STM32CubeMX 配置 STM32 SPI 驱动

关键词&#xff1a;SFUD,FLASH,W25Q64&#xff0c;W25Q128&#xff0c;STM32F407 说明&#xff1a;RT-Thread 系统 使用 SPI万能驱动 SFUD 驱动 Flash W25Q64&#xff0c;通过 STM32CubeMX 配置 STM32 SPI 驱动。 提示&#xff1a;SFUD添加后的存储位置 1.打开RT-Thread Sett…

K8S的部署策略,重建更新和滚动更新

Deployment Strategies 部署战略 When it comes time to change the version of software implementing your service, a Kubernetes deployment supports two different rollout strategies: RecreateRollingUpdate 当需要更改实施服务的软件版本时&#xff0c;Kubernetes …

跟着我学Python进阶篇:02.面向对象(上)

往期文章 跟着我学Python基础篇&#xff1a;01.初露端倪 跟着我学Python基础篇&#xff1a;02.数字与字符串编程 跟着我学Python基础篇&#xff1a;03.选择结构 跟着我学Python基础篇&#xff1a;04.循环 跟着我学Python基础篇&#xff1a;05.函数 跟着我学Python基础篇&#…

力扣2696. 删除子串后的字符串最小长度

Problem: 2696. 删除子串后的字符串最小长度 文章目录 思路解题方法复杂度Code 思路 可以知道能够消除的只有AB 和CD 的者两种排列顺序方式&#xff0c;但是也许在发生一次消除后还会引发后续的消除可能性。 元素从前向后进行检测&#xff0c;如果是A或者C进行标记入栈&#xf…

Linux C/C++ 显示NIC流量统计信息

NIC流量统计信息是由操作系统维护的。当数据包通过NIC传输时&#xff0c;操作系统会更新相关的计数器。这些计数器记录了数据包的发送和接收数量、字节数等。通过读取这些计数器&#xff0c;我们可以获得关于网络流量的信息。 为什么需要这些信息? 可以使用这些信息来监控网络…

Java建筑工程建设智慧工地源码

智慧工地管理平台依托物联网、互联网&#xff0c;建立云端大数据管理平台&#xff0c;形成“端云大数据”的业务体系和新的管理模式&#xff0c;从施工现场源头抓起&#xff0c;最大程度的收集人员、安全、环境、材料等关键业务数据&#xff0c;打通从一线操作与远程监管的数据…

C //练习 5-4 编写函数strend(s, t)。如果字符串t出现在字符串s的尾部,该函数返回1;否则返回0。

C程序设计语言 &#xff08;第二版&#xff09; 练习 5-4 练习 5-4 编写函数strend(s, t)。如果字符串t出现在字符串s的尾部&#xff0c;该函数返回1&#xff1b;否则返回0。 注意&#xff1a;代码在win32控制台运行&#xff0c;在不同的IDE环境下&#xff0c;有部分可能需要…

C++11_lambda表达式

文章目录 一、lambda表达式1.lambda的组成2.[capture-list] 的其他使用方法2.1混合捕捉 二、lambda表达式的使用场景1.替代仿函数 总结 一、lambda表达式 lambda表达式是C11新引入的功能&#xff0c;它的用法与我们之前学过的C语法有些不同。 1.lambda的组成 [capture-list] …

【Java万花筒】解锁Java并发之门:深入理解多线程

Java并发编程艺术&#xff1a;深度剖析多线程利器 前言 在当今软件开发的世界中&#xff0c;多线程编程已经变得愈发重要。面对多核处理器的普及和复杂的系统架构&#xff0c;开发人员需要深入了解并发编程的原理和实践&#xff0c;以充分发挥硬件的性能潜力。本文将带您深入…

【C++进阶05】AVL树的介绍及模拟实现

一、AVL树的概念 二叉搜索树的缺点 二叉搜索树虽可以缩短查找效率 但如果数据有序或接近有序 二叉搜索树将退化为单支树 查找元素相当于在顺序表中搜索元素&#xff0c;效率低下 AVL树便是解决此问题 向二叉搜索树中插入新结点 并保证每个结点的左右子树 高度之差的绝对值不超…

新能源汽车的三电指的是什么,作用有什么区别。

问题描述&#xff1a;新能源汽车的三电指的是什么&#xff0c;作用有什么区别。 问题解答&#xff1a; "新能源汽车的三电"通常指的是新能源汽车中的三大核心电气系统&#xff0c;包括&#xff1a;高压电池系统、电动机驱动系统和电子控制系统。这三者协同工作&…

Java诊断利器Arthas

https://arthas.aliyun.com/doc/https://arthas.aliyun.com/doc/ 原理 利用java.lang.instrument(容器类) 做动态 Instrumentation(执行容器) 是 JDK5 的新特性。使用 Instrumentation&#xff0c;开发者可以构建一个独立于应用程序的代理程序&#xff08;Agent&#xff09;&…