Converting an nnU-Net v2 model to ONNX
Once exported to ONNX, the model can be further converted into a TensorRT engine file, which makes it convenient to use from C++ (see the TensorRT category for details); a minimal sketch of that engine step is given directly below, and the ONNX export script itself follows after it.
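For reference, here is a minimal sketch of the ONNX-to-engine step, assuming the TensorRT 8.x Python bindings are installed. The file names and the input shape are placeholders matching the export script further below; in practice the trtexec command-line tool can do the same job.

import tensorrt as trt

TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(TRT_LOGGER)
network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
parser = trt.OnnxParser(network, TRT_LOGGER)

# parse the ONNX file produced by the export script below
with open('vessel.onnx', 'rb') as f:
    if not parser.parse(f.read()):
        for i in range(parser.num_errors):
            print(parser.get_error(i))
        raise RuntimeError('failed to parse the ONNX model')

config = builder.create_builder_config()
config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 30)  # 1 GiB workspace

# the batch dimension is exported as dynamic, so an optimization profile is required
profile = builder.create_optimization_profile()
profile.set_shape('input', (1, 1, 90, 160, 160), (1, 1, 90, 160, 160), (2, 1, 90, 160, 160))
config.add_optimization_profile(profile)

# build and save the serialized engine for later use from C++
serialized_engine = builder.build_serialized_network(network, config)
with open('vessel.engine', 'wb') as f:
    f.write(serialized_engine)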
import torch
from nnunetv2.inference.predict_from_raw_data import load_what_we_need
from nnunetv2.utilities.dataset_name_id_conversion import convert_id_to_dataset_name
from nnunetv2.paths import nnUNet_results
from nnunetv2.experiment_planning.plan_and_preprocess_api import plan_experiments
import json
def load_json(file: str):
    with open(file, 'r') as f:
        a = json.load(f)
    return a

from preprocess.run_training import run_training
import os

if __name__ == "__main__":
    dataset_id = 9
    task_name = "vessel"
    checkpoint_name = 'state_dict.pth'
    configuration = '3d_fullres'
    fold = '1'
    patch_size = [90, 160, 160]

    dataset_name = convert_id_to_dataset_name(dataset_id)
    model_training_output_dir = os.path.join(nnUNet_results, dataset_name)

    # load the trained weights onto the CPU
    checkpoint = torch.load(os.path.join(nnUNet_results, dataset_name, f'fold_{fold}', checkpoint_name),
                            map_location=torch.device('cpu'))

    # dataset metadata needed to rebuild the plan and the network
    dataset_json_file = os.path.join(model_training_output_dir, 'dataset.json')
    dataset_fingerprint_file = os.path.join(model_training_output_dir, 'dataset_fingerprint.json')
    plan_file = os.path.join(model_training_output_dir, 'plan.json')
    dataset_json = load_json(dataset_json_file)
    dataset_fingerprint = load_json(dataset_fingerprint_file)

    # regenerate the experiment plan that was used during training
    plan = plan_experiments(dataset_id, dataset_json, dataset_fingerprint,
                            gpu_memory_target_in_gb=8, overwrite_plans_name='nnUNetPlans')

    parameters, configuration_manager, inference_allowed_mirroring_axes, \
        plans_manager, network, trainer_name = \
        load_what_we_need(model_training_output_dir, dataset_id, configuration, fold,
                          checkpoint_name, dataset_json, plan)

    # instantiate the trainer so it builds the network architecture, then load the weights
    f = int(fold) if fold != 'all' else fold
    nnunet_trainer = run_training(dataset_id, configuration, f, plan, dataset_json)
    if not nnunet_trainer.was_initialized:
        nnunet_trainer.initialize()
    net = nnunet_trainer.network
    net.load_state_dict(checkpoint["models_state_dict"][0])
    net.eval()

    # export with a dummy input of the training patch size; the batch dimension stays dynamic
    dummy_input = torch.randn(1, 1, *patch_size)  # .to("cuda")
    torch.onnx.export(net,
                      dummy_input,
                      os.path.join(nnUNet_results, dataset_name, f'fold_{fold}', f'{task_name}.onnx'),
                      input_names=['input'],
                      output_names=['output'],
                      dynamic_axes={'input': {0: 'batch_size'},
                                    'output': {0: 'batch_size'}})
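After the export it is worth a quick sanity check that the ONNX model reproduces the PyTorch outputs. Below is a minimal sketch using onnxruntime and numpy (assumed to be installed); it reuses net, patch_size and the output path variables from the script above, and the deep-supervision handling is an assumption about how the trainer's network returns its outputs.

import numpy as np
import onnxruntime as ort

onnx_path = os.path.join(nnUNet_results, dataset_name, f'fold_{fold}', f'{task_name}.onnx')
session = ort.InferenceSession(onnx_path, providers=['CPUExecutionProvider'])

# run the same random input through both models
x = torch.randn(1, 1, *patch_size)
with torch.no_grad():
    torch_out = net(x)
# networks trained with deep supervision return a list; compare the full-resolution head
if isinstance(torch_out, (list, tuple)):
    torch_out = torch_out[0]

onnx_out = session.run(['output'], {'input': x.numpy()})[0]
print('max abs diff:', np.abs(torch_out.numpy() - onnx_out).max())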