# PyTorch模型转ONNX模型 (export the PyTorch model to ONNX)
import torch
import torch.nn as nn  # not used below; kept from the original script
from backbone import OXI_Net

# Build the network and load its trained weights.
# map_location="cpu" lets a checkpoint that was saved from a GPU be loaded on a
# CPU-only machine; exporting to ONNX does not require a GPU.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
model = OXI_Net()
model.load_state_dict(torch.load('./cnn_model_50.pth', map_location='cpu'))
# Put the model in inference mode before tracing it for export.
model.eval()

input_names = ['image']
output_names = ['label']
# Dummy input used only to trace the graph: batch 1, 3 channels, 224x224.
x = torch.randn(1, 3, 224, 224)
torch.onnx.export(
    model,
    x,
    './cnn_model_50.onnx',
    input_names=input_names,
    output_names=output_names,
)
print("ONNX模型导出成功!")
# 使用ONNX Runtime运行ONNX模型 (run the ONNX model with ONNX Runtime)
import onnxruntime
import numpy as np
import torchvision.transforms as transforms
import torchvision.transforms.functional as functional
from PIL import Image

onnx_model_path = "./cnn_model_50.onnx"
# Since ONNX Runtime 1.9, builds that enable non-CPU execution providers raise
# unless `providers` is passed explicitly; using every available provider
# keeps CPU-only builds working and lets GPU builds use the GPU.
session = onnxruntime.InferenceSession(
    onnx_model_path, providers=onnxruntime.get_available_providers()
)
input_name = session.get_inputs()[0].name
output_name = session.get_outputs()[0].name

# Load the image and force 3-channel RGB: BMP files may be grayscale or
# palette-based, while the exported model takes 3-channel input.
# convert() returns a fully loaded copy, so the file can be closed right away.
with Image.open('./idcard.bmp') as img:
    image = img.convert('RGB')
# Keep the top-left 648x648 region.
image = functional.crop(image, top=0, left=0, height=648, width=648)

# Deterministic preprocessing only. The original RandomRotation(10) is a
# training-time augmentation; applying it here made predictions vary between
# runs on the same image, so it has been removed.
transform = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])
input_data = transform(image)
# Add the batch dimension -> (1, 3, 224, 224), matching the export-time input.
input_data = input_data.unsqueeze(0)
# ONNX Runtime takes NumPy arrays; ToTensor already produced float32.
input_data = input_data.numpy()

output = session.run([output_name], {input_name: input_data})
predicted_result = output[0]
# Flat argmax over the whole output; with a batch of 1 this is the class index
# (assumes the model emits one score vector per image -- confirm for OXI_Net).
predicted_class = np.argmax(predicted_result)
print("预测的结果为:", predicted_result)
print("预测的类别索引为:", predicted_class)