nnunetv2 系列：使用默认的预测类推理 2D 数据
这里参考源代码 `nnUNet/nnunetv2/inference/predict_from_raw_data.py` 中给出的示例进行调整和测试。
代码示例
"""Run nnU-Net v2 inference on a folder of 2D images with the default nnUNetPredictor.

Adapted from the usage example in nnunetv2/inference/predict_from_raw_data.py.
"""
from time import perf_counter

from torch import device

from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor

# Absolute paths are used directly instead of building them with
# ``join(nnUNet_results, ...)`` from ``nnunetv2.paths``; according to the
# original notes, non-absolute paths led to errors.
MODEL_FOLDER = (
    "/home/bio/family/segmenation/nnUNet/nnUNet_results/"
    "Dataset500_ScleraIrisSegmentation/nnUNetTrainer__nnUNetPlans__2d"
)
# The input images must already be converted to nnU-Net raw format: per the
# original notes, file names have to look like ``*_0000.png``.
INPUT_FOLDER = "/home/bio/family/segmenation/nnUNet/afamily_test/inference/imagesTr"
OUTPUT_FOLDER = (
    "/home/bio/family/segmenation/nnUNet/afamily_test/inference/imagesTr_predict"
)


def main(
    model_folder=MODEL_FOLDER,
    input_folder=INPUT_FOLDER,
    output_folder=OUTPUT_FOLDER,
):
    """Load fold 0 of a trained nnU-Net model and segment every image in a folder.

    Args:
        model_folder: Absolute path of the trained-model folder
            (``<dataset>/<trainer>__<plans>__<config>``).
        input_folder: Absolute path of the folder of input images, already in
            nnU-Net naming (``*_0000.png``).
        output_folder: Absolute path where the predicted segmentations are written.
    """
    start = perf_counter()

    # Instantiate the predictor: sliding window with 50% tile overlap,
    # Gaussian weighting and mirroring (test-time augmentation), on GPU 0.
    predictor = nnUNetPredictor(
        tile_step_size=0.5,
        use_gaussian=True,
        use_mirroring=True,
        perform_everything_on_device=True,
        device=device("cuda", 0),
        verbose=False,
        verbose_preprocessing=False,
        allow_tqdm=True,
    )

    # Build the network architecture and load the best checkpoint of fold 0.
    predictor.initialize_from_trained_model_folder(
        model_folder,
        use_folds=(0,),
        checkpoint_name="checkpoint_best.pth",
    )

    # Variant 1 from the nnU-Net example: pass input and output folders.
    # NOTE(review): per the original notes, nnU-Net also writes JSON files
    # into the output folder by default; the author suggested commenting out
    # that code inside nnU-Net if they are not wanted.
    predictor.predict_from_files(
        input_folder,
        output_folder,
        save_probabilities=False,
        # presumably skips cases whose prediction already exists; verify
        # against the nnUNetPredictor documentation.
        overwrite=False,
        num_processes_preprocessing=1,
        num_processes_segmentation_export=1,
        folder_with_segs_from_prev_stage=None,
        num_parts=1,
        part_id=0,
    )

    print(f"Time taken: {perf_counter() - start}")


if __name__ == "__main__":
    main()