yolov5 torch转tensorrt详解【推荐】

转化函数

# 可以在https://github.com/ultralytics/yolov5/blob/master/export.py里面找到
def export_engine(model, im, file, half, dynamic, simplify, workspace=4, verbose=False, prefix=colorstr('TensorRT:')):# YOLOv5 TensorRT export https://developer.nvidia.com/tensorrtassert im.device.type != 'cpu', 'export running on CPU but must be on GPU, i.e. `python export.py --device 0`'try:import tensorrt as trtexcept Exception:if platform.system() == 'Linux':check_requirements('nvidia-tensorrt', cmds='-U --index-url https://pypi.ngc.nvidia.com')import tensorrt as trtif trt.__version__[0] == '7':  # TensorRT 7 handling https://github.com/ultralytics/yolov5/issues/6012grid = model.model[-1].anchor_gridmodel.model[-1].anchor_grid = [a[..., :1, :1, :] for a in grid]export_onnx(model, im, file, 12, dynamic, simplify)  # opset 12model.model[-1].anchor_grid = gridelse:  # TensorRT >= 8check_version(trt.__version__, '8.0.0', hard=True)  # require tensorrt>=8.0.0export_onnx(model, im, file, 12, dynamic, simplify)  # opset 12onnx = file.with_suffix('.onnx')LOGGER.info(f'\n{prefix} starting export with TensorRT {trt.__version__}...')assert onnx.exists(), f'failed to export ONNX file: {onnx}'f = file.with_suffix('.engine')  # TensorRT engine filelogger = trt.Logger(trt.Logger.INFO)if verbose:logger.min_severity = trt.Logger.Severity.VERBOSEbuilder = trt.Builder(logger)config = builder.create_builder_config()config.max_workspace_size = workspace * 1 << 30# config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace << 30)  # fix TRT 8.4 deprecation noticeflag = (1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))network = builder.create_network(flag)parser = trt.OnnxParser(network, logger)if not parser.parse_from_file(str(onnx)):raise RuntimeError(f'failed to load ONNX file: {onnx}')inputs = [network.get_input(i) for i in range(network.num_inputs)]outputs = [network.get_output(i) for i in range(network.num_outputs)]for inp in inputs:LOGGER.info(f'{prefix} input "{inp.name}" with shape{inp.shape} {inp.dtype}')for out in outputs:LOGGER.info(f'{prefix} output "{out.name}" with shape{out.shape} {out.dtype}')if dynamic:if im.shape[0] <= 1:LOGGER.warning(f'{prefix} WARNING ⚠️ --dynamic model requires maximum --batch-size argument')profile = builder.create_optimization_profile()for inp in inputs:profile.set_shape(inp.name, (1, *im.shape[1:]), (max(1, im.shape[0] // 2), *im.shape[1:]), im.shape)config.add_optimization_profile(profile)LOGGER.info(f'{prefix} building FP{16 if builder.platform_has_fast_fp16 and half else 32} engine as {f}')if builder.platform_has_fast_fp16 and half:config.set_flag(trt.BuilderFlag.FP16)with builder.build_engine(network, config) as engine, open(f, 'wb') as t:t.write(engine.serialize())return f, None

步骤 1: 导入库和检查 GPU 可用性

assert im.device.type != 'cpu', 'export running on CPU but must be on GPU, i.e. `python export.py --device 0`'
try:import tensorrt as trt
except Exception:if platform.system() == 'Linux':check_requirements('nvidia-tensorrt', cmds='-U --index-url https://pypi.ngc.nvidia.com')import tensorrt as trt
  • 确保模型在 GPU 上运行,如果在 CPU 上运行,抛出异常。
  • 尝试导入 tensorrt 库,如果失败并且系统是 Linux,通过 check_requirements 函数安装 nvidia-tensorrt
  • 再次尝试导入 tensorrt 库。

步骤 2: 处理 TensorRT 版本 7 的兼容性

if trt.__version__[0] == '7':grid = model.model[-1].anchor_gridmodel.model[-1].anchor_grid = [a[..., :1, :1, :] for a in grid]export_onnx(model, im, file, 12, dynamic, simplify)  # opset 12model.model[-1].anchor_grid = grid
else:check_version(trt.__version__, '8.0.0', hard=True)  # require tensorrt>=8.0.0export_onnx(model, im, file, 12, dynamic, simplify)  # opset 12
  • 如果 TensorRT 版本是 7,调整 YOLOv5 模型的锚点网格,导出 ONNX 文件,然后恢复原始的锚点网格。
  • 如果 TensorRT 版本大于等于 8,检查 TensorRT 版本是否满足要求(至少 8.0.0),然后导出 ONNX 文件。

步骤 3: 将模型导出为 ONNX 格式

onnx = file.with_suffix('.onnx')
export_onnx(model, im, file, 12, dynamic, simplify)  # opset 12
  • 指定 ONNX 文件的路径,并调用 export_onnx 函数将 YOLOv5 模型导出为 ONNX 格式。

步骤 4: 初始化 TensorRT 组件

LOGGER.info(f'\n{prefix} starting export with TensorRT {trt.__version__}...')
assert onnx.exists(), f'failed to export ONNX file: {onnx}'
f = file.with_suffix('.engine')  # TensorRT 引擎文件
logger = trt.Logger(trt.Logger.INFO)
  • 记录 TensorRT 版本信息。
  • 确保 ONNX 文件存在。
  • 指定 TensorRT 引擎文件的路径。
  • 初始化 TensorRT 的日志记录器。

步骤 5: 创建 TensorRT 构建器和配置

builder = trt.Builder(logger)
config = builder.create_builder_config()
config.max_workspace_size = workspace * 1 << 30
  • 创建 TensorRT 构建器。
  • 创建构建器配置对象。
  • 配置最大工作空间大小。
补充说明:
config.max_workspace_size = workspace * 1 << 30

这行代码设置了 TensorRT 构建配置对象 config 的最大工作空间大小max_workspace_size:

  • 1 << 30 表示将二进制数 1 左移 30 位。在计算机中,左移操作相当于乘以 2 的指定次方。因此,1 << 30 相当于 2 的 30 次方,即 2^30。

  • workspace 乘以 2^30 就是将其转换为字节。这是因为在计算机存储中,通常使用字节为基本单位。

在这里,workspace * 1 << 30 计算出的值将工作空间大小设置为 workspace GB。你可以根据系统的内存情况和模型的复杂性调整此值,以确保在构建 TensorRT 引擎时有足够的内存可用。

步骤 6: 创建 TensorRT 网络和 ONNX 解析器

flag = (1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
network = builder.create_network(flag)
parser = trt.OnnxParser(network, logger)
if not parser.parse_from_file(str(onnx)):raise RuntimeError(f'failed to load ONNX file: {onnx}')
  • 创建 TensorRT 网络,启用显式批处理。
  • 使用 ONNX 解析器解析 ONNX 文件,构建 TensorRT 网络。
补充说明:
flag = (1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))

这里创建了一个标志 flag,使用位运算左移的方式将 1 移动到 trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH 这个标志所表示的位置上。这个标志表示在创建网络时使用显式批处理。

步骤 7: 显示输入和输出信息

inputs = [network.get_input(i) for i in range(network.num_inputs)]
outputs = [network.get_output(i) for i in range(network.num_outputs)]
for inp in inputs:LOGGER.info(f'{prefix} input "{inp.name}" with shape{inp.shape} {inp.dtype}')
for out in outputs:LOGGER.info(f'{prefix} output "{out.name}" with shape{out.shape} {out.dtype}')
  • 获取 TensorRT 网络的输入和输出信息。
  • 打印输入和输出的名称、形状和数据类型。

步骤 8: 处理动态 TensorRT 优化

if dynamic:if im.shape[0] <= 1:LOGGER.warning(f'{prefix} WARNING ⚠️ --dynamic model requires maximum --batch-size argument')profile = builder.create_optimization_profile()for inp in inputs:profile.set_shape(inp.name, (1, *im.shape[1:]), (max(1, im.shape[0] // 2), *im.shape[1:]), im.shape)config.add_optimization_profile(profile)
  • 如果启用动态优化,创建优化配置文件。
  • 设置输入的形状,以便在不同批次大小下进行优化。
补充说明:
profile.set_shape(inp.name, (1, *im.shape[1:]), (max(1, im.shape[0] // 2), *im.shape[1:]), im.shape)

用于设置 TensorRT 动态优化配置文件的输入形状。让我们逐步解释这行代码:

  • profile 是 TensorRT 中的优化配置文件(trt.OptimizationProfile)。
  • inp.name 是当前输入张量的名称。
  • (1, *im.shape[1:]) 设置了最小的输入形状,其中批次大小(batch size)为 1,其余维度与 im 的形状相同。
  • (max(1, im.shape[0] // 2), *im.shape[1:]) 设置了最大的输入形状,其中批次大小(batch size)为 im.shape[0] // 2,其余维度与 im 的形状相同。
  • im.shape 是当前输入张量的形状。

这行代码的目的是为动态 TensorRT 模型创建一个优化配置文件,并设置输入形状的范围,以便在运行时适应不同批次大小的输入。这对于处理动态批次大小的模型非常有用,允许模型在训练和推理中适应不同大小的输入数据。

步骤 9: 构建 TensorRT 引擎

LOGGER.info(f'{prefix} building FP{16 if builder.platform_has_fast_fp16 and half else 32} engine as {f}')
if builder.platform_has_fast_fp16 and half:config.set_flag(trt.BuilderFlag.FP16)
with builder.build_engine(network, config) as engine, open(f, 'wb') as t:t.write(engine.serialize())
  • 记录正在构建的 TensorRT 引擎的精度信息(FP16 或 FP32)。
  • 如果支持 FP16 且指定使用 FP16,则设置相应标志。
  • 使用构建器、配置和网络构建 TensorRT 引擎。
  • 将引擎序列化并写入指定的文件。

步骤 10: 返回引擎文件路径

return f, None
  • 最终,函数返回 TensorRT 引擎文件的路径。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/674980.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

MVCC多版本并发控制

MVCC mvcc是多版本并发控制。维护一个数据的多个版本&#xff0c;使读写没有冲突 隐式字段 DB_TRX_ID:最近修改事务id&#xff0c;记录插入这条记录或最后一次修改该记录的事务id DB_ROLL_PTR:回滚指针&#xff0c;指向这条记录的上一个版本&#xff0c;用于配合undo log&…

筛法思想的题目

这道题目比较经典&#xff0c;或者说这种思想比较经典。 这种筛法的思想。 我们正着想对于每一个 n 、 n − 1 、 n − 2 、 . . . 、 2 、 1 n、 n-1、n-2、...、2、1 n、n−1、n−2、...、2、1都分解一遍质因数显然是来不及的时间复杂度达到 O ( n n ) O(n \sqrt{n}) O(nn ​…

Open CASCADE学习|点和曲线的相互转化

目录 1、把曲线离散成点 1.1按数量离散 1.2按长度离散 1.3按弦高离散 2、由点合成曲线 2.1B样条插值 2.2B样条近似 1、把曲线离散成点 计算机图形学中绘制曲线&#xff0c;无论是绘制参数曲线还是非参数曲线&#xff0c;都需要先将参数曲线进行离散化&#xff0c;通过离…

LayUI中表格树折叠 --

1、先将插件源码进行下载&#xff0c;新建 tableTree.js 文件&#xff0c;将源码放进去 2、将 tableTree.js 文件 配置之后&#xff0c;在需要使用的页面进行引入&#xff1a; layui.define(["tableTree"],function (exports) {var tableTree layui.tableTree;// …

2024年【天津市安全员B证】模拟试题及天津市安全员B证模拟考试题库

题库来源&#xff1a;安全生产模拟考试一点通公众号小程序 天津市安全员B证模拟试题是安全生产模拟考试一点通生成的&#xff0c;天津市安全员B证证模拟考试题库是根据天津市安全员B证最新版教材汇编出天津市安全员B证仿真模拟考试。2024年【天津市安全员B证】模拟试题及天津市…

蓝桥杯官网练习题(大臣的旅费)

问题描述 很久以前&#xff0c;T 王国空前繁荣。为了更好地管理国家&#xff0c;王国修建了大量的快速路&#xff0c;用于连接首都和王国内的各大城市。 为节省经费&#xff0c;T 国的大臣们经过思考&#xff0c;制定了一套优秀的修建方案&#xff0c;使得任何一个大城市都能从…

Redis——缓存设计与优化

讲解Redis的缓存设计与优化&#xff0c;以及在生产环境中遇到的Redis常见问题&#xff0c;例如缓存雪崩和缓存穿透&#xff0c;还讲解了相关问题的解决方案。 1、Redis缓存的优点和缺点 1.1、缓存优点&#xff1a; 高速读写&#xff1a;Redis可以帮助解决由于数据库压力造成…

安全的接口访问策略

渗透测试 一、Token与签名 一般客户端和服务端的设计过程中&#xff0c;大部分分为有状态和无状态接口。 一般用户登录状态下&#xff0c;判断用户是否有权限或者能否请求接口&#xff0c;都是根据用户登录成功后&#xff0c;服务端授予的token进行控制的。 但并不是说有了tok…

【LeetCode】332. 重新安排行程(困难)——代码随想录算法训练营Day30

题目链接&#xff1a;332. 重新安排行程 题目描述 给你一份航线列表 tickets &#xff0c;其中 tickets[i] [fromi, toi] 表示飞机出发和降落的机场地点。请你对该行程进行重新规划排序。 所有这些机票都属于一个从 JFK&#xff08;肯尼迪国际机场&#xff09;出发的先生&a…

椭圆曲线加密

椭圆曲线加密&#xff08;Elliptic Curve Cryptography&#xff0c;ECC&#xff09;是一种公钥加密算法&#xff0c;它基于椭圆曲线上的数学运算来实现安全的通信。 以下是椭圆曲线加密的基本过程&#xff1a; 1. 参数选择&#xff1a;选择一个适当的椭圆曲线和一个基础点。椭…

C#(C Sharp)学习笔记_运算符与布尔类型【四】

算术运算符 所谓算术运算符&#xff1a;就是现实中的加减乘除之类的符号&#xff0c;但在编程语言中&#xff0c;它们又有不同于现实的语法。下面就介绍一下算术运算符的各种符号包括计算案例。 运算符描述实例(设a为4&#xff1b;b为2)把两个操作数相加A B 将得到 6-从第一…

JVM-运行时数据区程序计数器

运行时数据区 Java虚拟机在运行Java程序过程中管理的内存区域&#xff0c;称之为运行时数据区。《Java虚拟机规范》中规定了每一部分的作用。 程序计数器的定义 程序计数器&#xff08;Program Counter Register&#xff09;也叫PC寄存器&#xff0c;每个线程会通过程序计数器…

1.3 Verilog 环境搭建详解教程

学习 Verilog 做仿真时&#xff0c;可选择不同仿真环境。FPGA 开发环境有 Xilinx 公司的 ISE&#xff08;目前已停止更新&#xff09;&#xff0c;VIVADO&#xff1b;因特尔公司的 Quartus II&#xff1b;ASIC 开发环境有 Synopsys 公司的 VCS &#xff1b;很多人也在用 Icarus…

PyTorch 2.2 中文官方教程(三)

使用 PyTorch 构建模型 原文&#xff1a;pytorch.org/tutorials/beginner/introyt/modelsyt_tutorial.html 译者&#xff1a;飞龙 协议&#xff1a;CC BY-NC-SA 4.0 注意 点击这里下载完整示例代码 介绍 || 张量 || 自动微分 || 构建模型 || TensorBoard 支持 || 训练模型 ||…

一些学习的总结帖子

一、Spring 参考链接1 参考链接2 参考链接3 二、多线程 并发的理解 参考链接1 三、redis 参考链接1 四、rabbitmq 五、数据库 数据库事务的概念及其原理 数据库事务 六、other 添加链接描述

Why React Doesn‘t Need jQuery?

a revolution library – 一个革命性的库greatly simplified tasks such as … – 极大的简化了…任务DOM manipulation – DOM操作event handling – 事件处理animation creation – 动画创建Ajax request – Ajax请求with the rise of modern front frameworks – 随着现代前…

Java风暴:打造高效作家信息管理平台

✍✍计算机编程指导师 ⭐⭐个人介绍&#xff1a;自己非常喜欢研究技术问题&#xff01;专业做Java、Python、微信小程序、安卓、大数据、爬虫、Golang、大屏等实战项目。 ⛽⛽实战项目&#xff1a;有源码或者技术上的问题欢迎在评论区一起讨论交流&#xff01; ⚡⚡ Java实战 |…

Web项目利用EasyExcel实现Excel的导出操作

早期Java使用的一些解析&#xff0c;到处excel的框架存在种种问题被遗弃&#xff0c;现在使用阿里巴巴所提供的EasyExcel已成为一种主流&#xff0c;本篇将详细介绍该功能在Web项目中如何实际应用。 详细操作文档&#xff1a;写Excel | Easy Excel 一、项目演示 在后台管理界…

windows下使用bat打开程序,并解决闪退问题

1.如何使用bat打开一个已经编译好的exe文件 示例&#xff1a;start /d"F:\testProject\bin\Debug" Shell_Component.exestart 空格 /d(后面不要空格) 引号并包裹exe程序路径 空格 exe名称 参考&#xff1a;https://blog.csdn.net/zhangshengqiang168/article/d…

Nginx与history路由模式:刷新页面404问题

使用nginx部署前端项目&#xff0c;路由模式采用history模式时&#xff0c;刷新页面之后&#xff0c;显示404。 路由模式 前端路由的基本作用为&#xff1a; ①当浏览器地址变化时&#xff0c;切换页面&#xff1b; ②点击浏览器后退、前进按钮时&#xff0c;更新网页内容&…