@@ -66,7 +66,8 @@ def __init__(self,
6666
6767 def initialize_from_trained_model_folder (self , model_training_output_dir : str ,
6868 use_folds : Union [Tuple [Union [int , str ]], None ],
69- checkpoint_name : str = 'checkpoint_final.pth' ):
69+ checkpoint_name : str = 'checkpoint_final.pth' ,
70+ trainer_class : Optional ["nnUNetPredictor" ] = None ,):
7071 """
7172 This is used when making predictions with a trained model
7273 """
@@ -96,11 +97,12 @@ def initialize_from_trained_model_folder(self, model_training_output_dir: str,
9697 configuration_manager = plans_manager .get_configuration (configuration_name )
9798 # restore network
9899 num_input_channels = determine_num_input_channels (plans_manager , configuration_manager , dataset_json )
99- trainer_class = recursive_find_python_class (join (nnunetv2 .__path__ [0 ], "training" , "nnUNetTrainer" ),
100- trainer_name , 'nnunetv2.training.nnUNetTrainer' )
101100 if trainer_class is None :
102- raise RuntimeError (f'Unable to locate trainer class { trainer_name } in nnunetv2.training.nnUNetTrainer. '
103- f'Please place it there (in any .py file)!' )
101+ trainer_class = recursive_find_python_class (join (nnunetv2 .__path__ [0 ], "training" , "nnUNetTrainer" ),
102+ trainer_name , 'nnunetv2.training.nnUNetTrainer' )
103+ if trainer_class is None :
104+ raise RuntimeError (f'Unable to locate trainer class { trainer_name } in nnunetv2.training.nnUNetTrainer. '
105+ f'Please place it there (in any .py file)!' )
104106 network = trainer_class .build_network_architecture (
105107 configuration_manager .network_arch_class_name ,
106108 configuration_manager .network_arch_init_kwargs ,
0 commit comments