前言
构建 ONNX 模型的方式通常有两种:
1、通过框架代码导出为 ONNX 结构,比如 PyTorch → ONNX
2、通过 onnx 提供的接口自定义结点和图,直接生成 ONNX 结构
本文主要简单学习和使用这两种构建 ONNX 结构的方式,下面以 Gather 结点为例进行分析。
采用 onnx 自定义结点的方式构建(即上述第 2 种方式)
from onnx import TensorProto, helper
import onnx
import onnxruntime
import numpy as np


def construct_model_gather(input_shape0, input_shape1, axis, output_shape, indices):
    """Build a single-node ONNX model that applies ``Gather`` to one input.

    The graph has one input ``data`` (FLOAT, shape ``input_shape0``), one
    output ``output`` (FLOAT, shape ``output_shape``) and keeps ``indices``
    as an INT64 initializer of shape ``input_shape1``.

    Args:
        input_shape0: Shape of the ``data`` graph input.
        input_shape1: Shape of the ``indices`` initializer.
        axis: Value stored in the Gather node's ``axis`` attribute.
        output_shape: Shape declared for the ``output`` graph output.
        indices: Index values, as any array-like of integers; it is
            flattened before being stored in the initializer.

    Returns:
        The model produced by ``onnx.helper.make_model`` after it has passed
        ``onnx.checker.check_model``.
    """
    print("construct model start... \n")

    # Hand make_tensor a flat list rather than an N-D array: the tensor's
    # shape is carried by `dims`, and flat values avoid any dependence on how
    # the installed onnx version treats multi-dimensional `vals`.
    flat_indices = np.asarray(indices, dtype=np.int64).ravel().tolist()
    initializer = [
        helper.make_tensor("indices", TensorProto.INT64, input_shape1, flat_indices)
    ]

    # Keyword arguments of make_node become node attributes.
    gather_node = helper.make_node(
        "Gather",
        inputs=["data", "indices"],
        outputs=["output"],
        name="Gather_test",
        axis=axis,
    )

    graph = helper.make_graph(
        nodes=[gather_node],
        name="test_graph",
        # Replace with the description of your own input if needed.
        inputs=[
            helper.make_tensor_value_info("data", TensorProto.FLOAT, tuple(input_shape0))
        ],
        outputs=[
            helper.make_tensor_value_info("output", TensorProto.FLOAT, tuple(output_shape))
        ],
        initializer=initializer,
    )

    opset_imports = [onnx.helper.make_operatorsetid("", version=12)]
    model = onnx.helper.make_model(graph, opset_imports=opset_imports)
    # NOTE(review): this pins the model to the IR version of the installed
    # onnx package; confirm the installed onnxruntime accepts that version.
    model.ir_version = onnx.IR_VERSION

    onnx.checker.check_model(model)
    print("construct model done... \n")
    return model


def run(model_file, input_data1):
    """Run the model stored at ``model_file`` and return its first output.

    Args:
        model_file: Path of the ONNX model passed to
            ``onnxruntime.InferenceSession``.
        input_data1: Array fed to the model's first input.

    Returns:
        The first element of the list returned by ``session.run``.
    """
    print('run start....\n')

    session = onnxruntime.InferenceSession(
        model_file, providers=['CPUExecutionProvider']
    )
    input_name1 = session.get_inputs()[0].name
    print(f'input_data1:{input_data1}')

    # Passing None as the output list requests every model output.
    pred_onx = session.run(None, {input_name1: input_data1})
    print(f'pred_onx:{pred_onx}\n')
    print(f'pred_onx[0].shape:{pred_onx[0].shape}\n')

    print('run done....\n')
    return pred_onx[0]


def main():
    """Build, save and run a Gather model for each test case.

    For every case the onnxruntime result is compared with ``np.take`` on the
    same data and indices, and the maximum absolute difference is printed.
    """
    # Each case: (data shape, indices shape, axis, expected output shape).
    test_cases = [
        ([3, 3], [1, 2], 1, [3, 1, 2]),
        ([8, 32, 51, 80], [15, 37], 2, [8, 32, 15, 37, 80]),
        ([8, 32, 15, 37, 80], [15, 66], 4, [8, 32, 15, 37, 15, 66]),
    ]
    model_file = "test_gather_normal.onnx"

    for input_shape0, input_shape1, axis, output_shape in test_cases:
        # Random indices are drawn from [0, size of the gathered axis).
        index_max = input_shape0[axis]
        indices = np.random.randint(
            index_max, size=tuple(input_shape1)
        ).astype(np.int64)

        model = construct_model_gather(
            input_shape0, input_shape1, axis, output_shape, indices
        )
        onnx.save(model, model_file)

        input_data1 = np.random.random(tuple(input_shape0)).astype(np.float32)
        onnx_output = run(model_file, input_data1)

        # NumPy reference result for the same gather.
        np_out = np.take(input_data1, indices, axis=axis)
        diff = np.abs(np_out - onnx_output).max()
        print(f"test_Gather input_shape: {input_shape0} ,shape:{input_shape1}, axis:{axis} max diff:{diff}\n")


if __name__ == '__main__':
    main()