Skip to content

Commit 511204e

Browse files
predict_from_raw_data.py: Add trainer_class argument
1 parent 1ea6759 commit 511204e

File tree

1 file changed

+7
-5
lines changed

1 file changed

+7
-5
lines changed

nnunetv2/inference/predict_from_raw_data.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,8 @@ def __init__(self,
6666

6767
def initialize_from_trained_model_folder(self, model_training_output_dir: str,
6868
use_folds: Union[Tuple[Union[int, str]], None],
69-
checkpoint_name: str = 'checkpoint_final.pth'):
69+
checkpoint_name: str = 'checkpoint_final.pth',
70+
trainer_class: Optional["nnUNetPredictor"] = None,):
7071
"""
7172
This is used when making predictions with a trained model
7273
"""
@@ -96,11 +97,12 @@ def initialize_from_trained_model_folder(self, model_training_output_dir: str,
9697
configuration_manager = plans_manager.get_configuration(configuration_name)
9798
# restore network
9899
num_input_channels = determine_num_input_channels(plans_manager, configuration_manager, dataset_json)
99-
trainer_class = recursive_find_python_class(join(nnunetv2.__path__[0], "training", "nnUNetTrainer"),
100-
trainer_name, 'nnunetv2.training.nnUNetTrainer')
101100
if trainer_class is None:
102-
raise RuntimeError(f'Unable to locate trainer class {trainer_name} in nnunetv2.training.nnUNetTrainer. '
103-
f'Please place it there (in any .py file)!')
101+
trainer_class = recursive_find_python_class(join(nnunetv2.__path__[0], "training", "nnUNetTrainer"),
102+
trainer_name, 'nnunetv2.training.nnUNetTrainer')
103+
if trainer_class is None:
104+
raise RuntimeError(f'Unable to locate trainer class {trainer_name} in nnunetv2.training.nnUNetTrainer. '
105+
f'Please place it there (in any .py file)!')
104106
network = trainer_class.build_network_architecture(
105107
configuration_manager.network_arch_class_name,
106108
configuration_manager.network_arch_init_kwargs,

0 commit comments

Comments
 (0)