转化函数
# 可以在https://github.com/ultralytics/yolov5/blob/master/export.py里面找到
def export_engine(model, im, file, half, dynamic, simplify, workspace=4, verbose=False, prefix=colorstr('TensorRT:')):# YOLOv5 TensorRT export https://developer.nvidia.com/tensorrtassert im.device.type != 'cpu', 'export running on CPU but must be on GPU, i.e. `python export.py --device 0`'try:import tensorrt as trtexcept Exception:if platform.system() == 'Linux':check_requirements('nvidia-tensorrt', cmds='-U --index-url https://pypi.ngc.nvidia.com')import tensorrt as trtif trt.__version__[0] == '7': # TensorRT 7 handling https://github.com/ultralytics/yolov5/issues/6012grid = model.model[-1].anchor_gridmodel.model[-1].anchor_grid = [a[..., :1, :1, :] for a in grid]export_onnx(model, im, file, 12, dynamic, simplify) # opset 12model.model[-1].anchor_grid = gridelse: # TensorRT >= 8check_version(trt.__version__, '8.0.0', hard=True) # require tensorrt>=8.0.0export_onnx(model, im, file, 12, dynamic, simplify) # opset 12onnx = file.with_suffix('.onnx')LOGGER.info(f'\n{prefix} starting export with TensorRT {trt.__version__}...')assert onnx.exists(), f'failed to export ONNX file: {onnx}'f = file.with_suffix('.engine') # TensorRT engine filelogger = trt.Logger(trt.Logger.INFO)if verbose:logger.min_severity = trt.Logger.Severity.VERBOSEbuilder = trt.Builder(logger)config = builder.create_builder_config()config.max_workspace_size = workspace * 1 << 30# config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace << 30) # fix TRT 8.4 deprecation noticeflag = (1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))network = builder.create_network(flag)parser = trt.OnnxParser(network, logger)if not parser.parse_from_file(str(onnx)):raise RuntimeError(f'failed to load ONNX file: {onnx}')inputs = [network.get_input(i) for i in range(network.num_inputs)]outputs = [network.get_output(i) for i in range(network.num_outputs)]for inp in inputs:LOGGER.info(f'{prefix} input "{inp.name}" with shape{inp.shape} {inp.dtype}')for out in outputs:LOGGER.info(f'{prefix} output "{out.name}" with shape{out.shape} {out.dtype}')if dynamic:if im.shape[0] <= 1:LOGGER.warning(f'{prefix} WARNING ⚠️ --dynamic model requires maximum --batch-size argument')profile = builder.create_optimization_profile()for inp in inputs:profile.set_shape(inp.name, (1, *im.shape[1:]), (max(1, im.shape[0] // 2), *im.shape[1:]), im.shape)config.add_optimization_profile(profile)LOGGER.info(f'{prefix} building FP{16 if builder.platform_has_fast_fp16 and half else 32} engine as {f}')if builder.platform_has_fast_fp16 and half:config.set_flag(trt.BuilderFlag.FP16)with builder.build_engine(network, config) as engine, open(f, 'wb') as t:t.write(engine.serialize())return f, None
步骤 1: 导入库和检查 GPU 可用性
assert im.device.type != 'cpu', 'export running on CPU but must be on GPU, i.e. `python export.py --device 0`'
try:import tensorrt as trt
except Exception:if platform.system() == 'Linux':check_requirements('nvidia-tensorrt', cmds='-U --index-url https://pypi.ngc.nvidia.com')import tensorrt as trt
- 确保模型在 GPU 上运行,如果在 CPU 上运行,抛出异常。
- 尝试导入
tensorrt
库,如果失败并且系统是 Linux,通过check_requirements
函数安装nvidia-tensorrt
。 - 再次尝试导入
tensorrt
库。
步骤 2: 处理 TensorRT 版本 7 的兼容性
if trt.__version__[0] == '7':grid = model.model[-1].anchor_gridmodel.model[-1].anchor_grid = [a[..., :1, :1, :] for a in grid]export_onnx(model, im, file, 12, dynamic, simplify) # opset 12model.model[-1].anchor_grid = grid
else:check_version(trt.__version__, '8.0.0', hard=True) # require tensorrt>=8.0.0export_onnx(model, im, file, 12, dynamic, simplify) # opset 12
- 如果 TensorRT 版本是 7,调整 YOLOv5 模型的锚点网格,导出 ONNX 文件,然后恢复原始的锚点网格。
- 如果 TensorRT 版本大于等于 8,检查 TensorRT 版本是否满足要求(至少 8.0.0),然后导出 ONNX 文件。
步骤 3: 将模型导出为 ONNX 格式
onnx = file.with_suffix('.onnx')
export_onnx(model, im, file, 12, dynamic, simplify) # opset 12
- 指定 ONNX 文件的路径,并调用
export_onnx
函数将 YOLOv5 模型导出为 ONNX 格式。
步骤 4: 初始化 TensorRT 组件
LOGGER.info(f'\n{prefix} starting export with TensorRT {trt.__version__}...')
assert onnx.exists(), f'failed to export ONNX file: {onnx}'
f = file.with_suffix('.engine') # TensorRT 引擎文件
logger = trt.Logger(trt.Logger.INFO)
- 记录 TensorRT 版本信息。
- 确保 ONNX 文件存在。
- 指定 TensorRT 引擎文件的路径。
- 初始化 TensorRT 的日志记录器。
步骤 5: 创建 TensorRT 构建器和配置
builder = trt.Builder(logger)
config = builder.create_builder_config()
config.max_workspace_size = workspace * 1 << 30
- 创建 TensorRT 构建器。
- 创建构建器配置对象。
- 配置最大工作空间大小。
补充说明:
config.max_workspace_size = workspace * 1 << 30
这行代码设置了 TensorRT 构建配置对象 config 的最大工作空间大小max_workspace_size:
-
1 << 30
表示将二进制数1
左移 30 位。在计算机中,左移操作相当于乘以 2 的指定次方。因此,1 << 30
相当于 2 的 30 次方,即 2^30。 -
将
workspace
乘以 2^30 就是将其转换为字节。这是因为在计算机存储中,通常使用字节为基本单位。
在这里,workspace * 1 << 30
计算出的值将工作空间大小设置为 workspace GB。你可以根据系统的内存情况和模型的复杂性调整此值,以确保在构建 TensorRT 引擎时有足够的内存可用。
步骤 6: 创建 TensorRT 网络和 ONNX 解析器
flag = (1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
network = builder.create_network(flag)
parser = trt.OnnxParser(network, logger)
if not parser.parse_from_file(str(onnx)):raise RuntimeError(f'failed to load ONNX file: {onnx}')
- 创建 TensorRT 网络,启用显式批处理。
- 使用 ONNX 解析器解析 ONNX 文件,构建 TensorRT 网络。
补充说明:
flag = (1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
这里创建了一个标志 flag
,使用位运算左移的方式将 1
移动到 trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH
这个标志所表示的位置上。这个标志表示在创建网络时使用显式批处理。
步骤 7: 显示输入和输出信息
inputs = [network.get_input(i) for i in range(network.num_inputs)]
outputs = [network.get_output(i) for i in range(network.num_outputs)]
for inp in inputs:LOGGER.info(f'{prefix} input "{inp.name}" with shape{inp.shape} {inp.dtype}')
for out in outputs:LOGGER.info(f'{prefix} output "{out.name}" with shape{out.shape} {out.dtype}')
- 获取 TensorRT 网络的输入和输出信息。
- 打印输入和输出的名称、形状和数据类型。
步骤 8: 处理动态 TensorRT 优化
if dynamic:if im.shape[0] <= 1:LOGGER.warning(f'{prefix} WARNING ⚠️ --dynamic model requires maximum --batch-size argument')profile = builder.create_optimization_profile()for inp in inputs:profile.set_shape(inp.name, (1, *im.shape[1:]), (max(1, im.shape[0] // 2), *im.shape[1:]), im.shape)config.add_optimization_profile(profile)
- 如果启用动态优化,创建优化配置文件。
- 设置输入的形状,以便在不同批次大小下进行优化。
补充说明:
profile.set_shape(inp.name, (1, *im.shape[1:]), (max(1, im.shape[0] // 2), *im.shape[1:]), im.shape)
用于设置 TensorRT 动态优化配置文件的输入形状。让我们逐步解释这行代码:
profile
是 TensorRT 中的优化配置文件(trt.OptimizationProfile
)。inp.name
是当前输入张量的名称。(1, *im.shape[1:])
设置了最小的输入形状,其中批次大小(batch size)为 1,其余维度与im
的形状相同。(max(1, im.shape[0] // 2), *im.shape[1:])
设置了最大的输入形状,其中批次大小(batch size)为im.shape[0] // 2
,其余维度与im
的形状相同。im.shape
是当前输入张量的形状。
这行代码的目的是为动态 TensorRT 模型创建一个优化配置文件,并设置输入形状的范围,以便在运行时适应不同批次大小的输入。这对于处理动态批次大小的模型非常有用,允许模型在训练和推理中适应不同大小的输入数据。
步骤 9: 构建 TensorRT 引擎
LOGGER.info(f'{prefix} building FP{16 if builder.platform_has_fast_fp16 and half else 32} engine as {f}')
if builder.platform_has_fast_fp16 and half:config.set_flag(trt.BuilderFlag.FP16)
with builder.build_engine(network, config) as engine, open(f, 'wb') as t:t.write(engine.serialize())
- 记录正在构建的 TensorRT 引擎的精度信息(FP16 或 FP32)。
- 如果支持 FP16 且指定使用 FP16,则设置相应标志。
- 使用构建器、配置和网络构建 TensorRT 引擎。
- 将引擎序列化并写入指定的文件。
步骤 10: 返回引擎文件路径
return f, None
- 最终,函数返回 TensorRT 引擎文件的路径。