Preface
Run FP16 inference on ViT (T2T-ViT).
Reference links:
https://github.com/open-mmlab/mmpretrain/tree/master/configs/t2t_vit
Docs for running/testing and ONNX export:
https://mmclassification.readthedocs.io/en/latest/getting_started.html#inference-and-test-a-dataset
https://mmclassification.readthedocs.io/en/latest/tools/pytorch2onnx.html?highlight=onnx
Run code and convert to ONNX
# run acc
python3 tools/test.py configs/t2t_vit/t2t-vit-t-14_8xb64_in1k.py \
    ./t2t-vit-t-14_8xb64_in1k_20211220-f7378dd5.pth \
    --metrics accuracy --out result.pkl 2>&1 | tee test_run.log
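If you want to peek at the dumped predictions afterwards, a minimal sketch follows; that result.pkl holds one class-score vector per validation image is my assumption and may differ across mmcls versions.

import pickle
import numpy as np

# A sketch: inspect the predictions tools/test.py wrote with --out.
# Assumes result.pkl is a list of per-image class-score vectors.
with open("result.pkl", "rb") as f:
    results = pickle.load(f)

scores = np.asarray(results)
print(scores.shape)                 # e.g. (50000, 1000) on ImageNet val
print(scores.argmax(axis=1)[:10])   # top-1 class ids of the first 10 images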
# convert onnx, bs=1
pip install onnx onnxsim
python3 tools/deployment/pytorch2onnx.py \
    configs/t2t_vit/t2t-vit-t-14_8xb64_in1k.py \
    --checkpoint ./t2t-vit-t-14_8xb64_in1k_20211220-f7378dd5.pth \
    --output-file ./t2t-vit-t-14.onnx 2>&1 | tee test_onnx.log

# simplify
onnxsim t2t-vit-t-14.onnx t2t-vit-t-14-sim.onnx

# onnx bs=64
To get a bs=64 model, change the input data shape in the export script, or use dynamic axes; see the sketch below.
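A hedged sketch of the dynamic-axes route: torch.onnx.export accepts a dynamic_axes mapping that keeps the batch dimension symbolic. The toy model below is a stand-in of my own; in practice you would build the T2T-ViT from the config and checkpoint, as tools/deployment/pytorch2onnx.py does, and export it the same way.

import torch
import torch.nn as nn

# Toy stand-in model; substitute the real T2T-ViT built from the mmcls config.
model = nn.Sequential(
    nn.Conv2d(3, 8, 3), nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(8, 1000)
).eval()
dummy = torch.randn(1, 3, 224, 224)

torch.onnx.export(
    model, dummy, "model_dynamic.onnx",
    input_names=["input"], output_names=["output"],
    # the batch dimension stays symbolic, so one file serves bs=1 and bs=64
    dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}},
    opset_version=11,
)

The fixed-shape alternative is simply re-exporting with a (64, 3, 224, 224) dummy input before running onnxsim, since onnxsim folds the static shapes into the graph.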
Test inference with onnxruntime
import onnxruntime
import onnx
import torch
import numpy as np

# check that the ONNX graph is well formed
def checknet():
    model = onnx.load("./t2t-vit-t-14_bs64-sim.onnx")
    onnx.checker.check_model(model)
    # Print a human readable representation of the graph
    # print(onnx.helper.printable_graph(model.graph))

# run one random batch through onnxruntime
def runonnx():
    image = torch.randn(64, 3, 224, 224).numpy().astype(np.float32)
    session = onnxruntime.InferenceSession("./t2t-vit-t-14_bs64-sim.onnx")
    # session.get_modelmeta()
    input_name = session.get_inputs()[0].name
    output_name = session.get_outputs()[0].name
    print('Input Name:', input_name)
    print('Output Name:', output_name)
    output2 = session.run([output_name], {input_name: image})
    print(output2[0].shape)

if __name__ == '__main__':
    checknet()
    runonnx()
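The goal stated in the preface is FP16 inference, but the cast itself is not shown above. One common route (an assumption on my part, not this post's own tooling) is the float16 helper from onnxconverter-common (pip install onnxconverter-common); a minimal sketch:

import onnx
from onnxconverter_common import float16

# A sketch: cast the FP32 graph to FP16 weights/activations.
# keep_io_types leaves the model inputs/outputs as FP32, so the
# onnxruntime feed above works unchanged.
model = onnx.load("./t2t-vit-t-14_bs64-sim.onnx")
model_fp16 = float16.convert_float_to_float16(model, keep_io_types=True)
onnx.save(model_fp16, "./t2t-vit-t-14_bs64-sim-fp16.onnx")

To actually get half-precision kernels, run the FP16 file with onnxruntime-gpu, e.g. InferenceSession(path, providers=["CUDAExecutionProvider"]).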
Save per-node predictions with onnxruntime
import onnxruntime
import onnx
import torch
import numpy as np
import onnx.helper as helper
import os

# expose every node's first output as a graph output so onnxruntime can return it
def get_onnx_node():
    base_path = "./"
    onnx_file = os.path.join(base_path, "t2t-vit-t-14_bs64-sim.onnx")
    save_onnx = os.path.join(base_path, "t2t-vit-t-14_bs64-sim-out.onnx")
    model = onnx.load(onnx_file)
    out_names = []
    for i, node in enumerate(model.graph.node):
        out_names.append(node.output[0])
    for out_name in out_names:
        intermediate_layer_value_info = helper.ValueInfoProto()
        intermediate_layer_value_info.name = out_name
        model.graph.output.append(intermediate_layer_value_info)
    onnx.save(model, save_onnx)

# run a random batch and dump every exposed output to a .npy file
def get_onnx_layer_out():
    base_path = "./"
    onnx_file = os.path.join(base_path, "t2t-vit-t-14_bs64-sim-out.onnx")
    OutDir = "./onnx_file/"
    os.makedirs(OutDir, exist_ok=True)  # the original assumed this directory already existed
    image = torch.randn(64, 3, 224, 224).numpy().astype(np.float32)
    session = onnxruntime.InferenceSession(onnx_file)
    input_name = session.get_inputs()[0].name
    outputs = [x.name for x in session.get_outputs()]
    preds = session.run(outputs, {input_name: image})
    for name, value in zip(outputs, preds):
        # sanitize node names containing '/' so they form valid file names
        file = os.path.join(OutDir, name.replace("/", "_") + ".npy")
        np.save(file, value, allow_pickle=True, fix_imports=True)

if __name__ == '__main__':
    get_onnx_node()
    get_onnx_layer_out()
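These per-node dumps are mainly useful for localizing precision drift, e.g. after the FP16 cast above. A minimal comparison sketch follows; the second directory name is hypothetical, and both runs must be fed the same saved input rather than a fresh torch.randn.

import os
import numpy as np

# A sketch: compare per-node dumps from an FP32 run and an FP16 run.
# Both directory names are assumptions; adjust to your own layout.
fp32_dir = "./onnx_file/"
fp16_dir = "./onnx_file_fp16/"

for fname in sorted(os.listdir(fp32_dir)):
    other = os.path.join(fp16_dir, fname)
    if not fname.endswith(".npy") or not os.path.exists(other):
        continue
    ref = np.load(os.path.join(fp32_dir, fname)).astype(np.float32)
    test = np.load(other).astype(np.float32)
    # the first node whose difference jumps is where FP16 starts to diverge
    print(f"{fname}: max abs diff = {np.abs(ref - test).max():.6f}")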