Table of Contents
- 1. Environment setup
- 1.1 Installing TensorRT
- 1.1.1 pip install
- 1.1.2 Archive install
- 2. Converting .pt to .engine
- 3. Problems encountered during conversion
1. Environment setup
1.1 Installing TensorRT
1.1.1 pip install
pip install tensorrt
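After installing, you can verify that the package is importable and check its version. This is a minimal sanity check only; it does not confirm that the CUDA runtime or the TensorRT DLLs are actually usable:

```python
# Report whether the tensorrt Python package is importable, and its version.
import importlib.util


def tensorrt_available():
    # find_spec returns None when the package cannot be imported
    return importlib.util.find_spec('tensorrt') is not None


if tensorrt_available():
    import tensorrt
    print('TensorRT', tensorrt.__version__)
else:
    print('tensorrt is not installed')
```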
1.1.2 Archive install
The pip install is likely to fail. The safest approach is to download the TensorRT archive instead, e.g. TensorRT-8.4.3.1.
File structure: (screenshot of the extracted archive omitted; the wheels are under the python subdirectory)
Install the wheel matching your Python version (cp37 = Python 3.7):
pip install tensorrt-8.4.3.1-cp37-none-win_amd64.whl
Configure environment variables: add the archive's lib directory (which contains the TensorRT DLLs) to PATH.
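If you prefer not to edit system settings, the lib directory can also be prepended to PATH at runtime, before importing tensorrt. A minimal sketch; the extraction path below is an assumption, adjust it to wherever you unpacked the archive:

```python
# Prepend a directory to PATH so the TensorRT DLLs can be found at runtime.
import os


def add_to_path(directory):
    # Prepend so this directory wins over any older TensorRT on PATH.
    os.environ['PATH'] = directory + os.pathsep + os.environ.get('PATH', '')
    return os.environ['PATH'].startswith(directory)


# Hypothetical extraction location; change to your own.
add_to_path(r'E:\TensorRT-8.4.3.1\lib')
```

Note that on Python 3.8+ under Windows, DLL resolution ignores PATH for extension modules, so `os.add_dll_directory(r'E:\TensorRT-8.4.3.1\lib')` may be needed as well.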
2. Converting .pt to .engine
Change into the yolov5-7.0 directory (note: the version must be 6.1 or above; the export.py in 6.0 does not support engine export):
python export.py --weights E:\code\other\tph-yolov5-main\runs\train\exp\weights\last.pt --include engine --imgsz 1536 --device 0 --half
3. Problems encountered during conversion
When converting the model with export.py from yolov5-6.1, the script from section 2 sometimes produces an engine that detects no targets at all. The workaround was to first export to ONNX, then build the engine from that ONNX file, while commenting out the ONNX-export call inside the export_engine function in export.py. This worked.
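The two-step workaround can be sketched as follows. The weights path is a placeholder, and both commands are assumed to run from the yolov5 directory:

```python
# Build the two export commands for the ONNX-first workaround:
# step 1 exports the ONNX model; step 2 (with the export_onnx call
# commented out inside export_engine) builds the engine from the
# ONNX file that step 1 left on disk next to the weights.
import subprocess


def two_step_export(weights, imgsz='1536', device='0'):
    onnx_cmd = ['python', 'export.py', '--weights', weights,
                '--include', 'onnx', '--imgsz', imgsz]
    engine_cmd = ['python', 'export.py', '--weights', weights,
                  '--include', 'engine', '--imgsz', imgsz,
                  '--device', device, '--half']
    return onnx_cmd, engine_cmd


# Usage (uncomment to actually run the exports):
# for cmd in two_step_export(r'runs\train\exp\weights\last.pt'):
#     subprocess.run(cmd, check=True)
```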
The modified export_engine (with the internal export_onnx call commented out) is shown below:
```python
def export_engine(model, im, file, train, half, simplify, workspace=4, verbose=False, prefix=colorstr('TensorRT:')):
    # YOLOv5 TensorRT export https://developer.nvidia.com/tensorrt
    try:
        check_requirements(('tensorrt',))
        import tensorrt as trt  # import was missing from the pasted snippet

        if trt.__version__[0] == '7':  # TensorRT 7 handling https://github.com/ultralytics/yolov5/issues/6012
            grid = model.model[-1].anchor_grid
            model.model[-1].anchor_grid = [a[..., :1, :1, :] for a in grid]
            export_onnx(model, im, file, 12, train, False, simplify)  # opset 12
            model.model[-1].anchor_grid = grid
        else:  # TensorRT >= 8
            check_version(trt.__version__, '8.0.0', hard=True)  # require tensorrt>=8.0.0
            # export_onnx(model, im, file, 13, train, False, simplify)  # opset 13 -- commented out; ONNX is exported beforehand
            onnx = file.with_suffix('.onnx')

        LOGGER.info(f'\n{prefix} starting export with TensorRT {trt.__version__}...')
        assert im.device.type != 'cpu', 'export running on CPU but must be on GPU, i.e. `python export.py --device 0`'
        assert onnx.exists(), f'failed to export ONNX file: {onnx}'
        f = file.with_suffix('.engine')  # TensorRT engine file
        logger = trt.Logger(trt.Logger.INFO)
        if verbose:
            logger.min_severity = trt.Logger.Severity.VERBOSE

        builder = trt.Builder(logger)
        config = builder.create_builder_config()
        config.max_workspace_size = workspace * 1 << 30

        flag = (1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
        network = builder.create_network(flag)
        parser = trt.OnnxParser(network, logger)
        if not parser.parse_from_file(str(onnx)):
            raise RuntimeError(f'failed to load ONNX file: {onnx}')

        inputs = [network.get_input(i) for i in range(network.num_inputs)]
        outputs = [network.get_output(i) for i in range(network.num_outputs)]
        LOGGER.info(f'{prefix} Network Description:')
        for inp in inputs:
            LOGGER.info(f'{prefix}\tinput "{inp.name}" with shape {inp.shape} and dtype {inp.dtype}')
        for out in outputs:
            LOGGER.info(f'{prefix}\toutput "{out.name}" with shape {out.shape} and dtype {out.dtype}')

        half &= builder.platform_has_fast_fp16
        LOGGER.info(f'{prefix} building FP{16 if half else 32} engine in {f}')
        if half:
            config.set_flag(trt.BuilderFlag.FP16)
        with builder.build_engine(network, config) as engine, open(f, 'wb') as t:
            t.write(engine.serialize())
        LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
        return f
    except Exception as e:
        LOGGER.info(f'\n{prefix} export failure: {e}')
```