安装tensorRT:
1、下载与电脑中cuda和cudnn版本对应的tensorRT(比如我的是TensorRT-8.2.1.8.Windows10.x86_64.cuda-11.4.cudnn8.2)
2、打开目录里面有python文件夹,找到对应python版本的whl文件(我的是tensorrt-8.2.1.8-cp38-none-win_amd64.whl) 因为我python安装的是3.8版本
3、终端安装:pip install tensorrt-8.2.1.8-cp38-none-win_amd64.whl
4、结束
import os

import tensorrt as trt
def get_DynEngine(onnx_file_path, engine_file_path, patchsize, max_workspace_size, max_batch_size):
    """Load a serialized TensorRT engine if one is cached on disk, otherwise
    build a new dynamic-batch engine from an ONNX model and save it.

    Args:
        onnx_file_path: Path to the source ONNX model.
        engine_file_path: Path where the serialized engine is cached / written.
        patchsize: Spatial input size, e.g. [D, H, W]; the profile assumes the
            network input is shaped (batch, 1, *patchsize) -- TODO confirm the
            channel dimension against the exported model.
        max_workspace_size: Builder scratch-memory cap in bytes.
        max_batch_size: Largest batch the optimization profile (and inference)
            may use.

    Returns:
        The serialized engine (trt.IHostMemory from a fresh build, or bytes
        read back from the cache file), or None if parsing/building fails.
    """
    TRT_LOGGER = trt.Logger()
    trt.init_libnvinfer_plugins(TRT_LOGGER, "")

    # Honor the documented contract: reuse a previously built engine when the
    # cache file already exists instead of rebuilding from ONNX every run.
    if os.path.exists(engine_file_path):
        print(f'Reading serialized engine from {engine_file_path}')
        with open(engine_file_path, 'rb') as f:
            return f.read()

    explicit_batch = 1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    builder = trt.Builder(TRT_LOGGER)
    network = builder.create_network(explicit_batch)
    config = builder.create_builder_config()
    parser = trt.OnnxParser(network, TRT_LOGGER)
    print("common.EXPLICIT_BATCH:", explicit_batch)
    # Maximum builder workspace; lower this if the GPU runs out of memory.
    config.max_workspace_size = max_workspace_size
    config.set_flag(trt.BuilderFlag.FP16)
    print("max_workspace_size:", config.max_workspace_size)
    builder.max_batch_size = max_batch_size  # inference must keep batch_size <= max_batch_size

    if not os.path.exists(onnx_file_path):
        print(f'onnx file {onnx_file_path} not found,please run torch_2_onnx.py first to generate it')
        exit(1)  # non-zero: a missing model is a failure, not success

    print(f'Loading ONNX file from path {onnx_file_path}...')
    with open(onnx_file_path, 'rb') as model:
        print('Beginning ONNX file parsing')
        if not parser.parse(model.read()):
            print('ERROR:Failed to parse the ONNX file')
            for error in range(parser.num_errors):
                print(parser.get_error(error))
            return None

    inputs = [network.get_input(i) for i in range(network.num_inputs)]
    print("input", inputs)
    outputs = [network.get_output(i) for i in range(network.num_outputs)]
    print("out:", outputs)
    print("Network Description")
    for inp in inputs:
        print("Input '{}' with shape {} and dtype {} . ".format(inp.name, inp.shape, inp.dtype))
    for output in outputs:
        print("Output '{}' with shape {} and dtype {} . ".format(output.name, output.shape, output.dtype))

    # Dynamic input: each dynamic input must be bound to an optimization
    # profile giving (min, opt, max) shapes; inference shapes must fall inside
    # this range.
    profile = builder.create_optimization_profile()
    print("network.get_input(0).name:", network.get_input(0).name)
    profile.set_shape(network.get_input(0).name,
                      (1, 1, *patchsize),
                      (1, 1, *patchsize),
                      (max_batch_size, 1, *patchsize))
    config.add_optimization_profile(profile)

    print('Completed parsing the ONNX file')
    print(f'Building an engine from file {onnx_file_path}; this may take a while...')
    engine = builder.build_serialized_network(network, config)
    # build_serialized_network returns None on failure; writing None would
    # raise a confusing TypeError, so fail explicitly instead.
    if engine is None:
        print('ERROR: failed to build the serialized engine')
        return None
    print('Completed creating Engine')
    with open(engine_file_path, 'wb') as f:
        f.write(engine)
    return engine


if __name__ == "__main__":
    get_DynEngine("1.onnx", "2.engine", [96, 160, 160], 5 * (1 << 30), 2)