diff --git a/monai/apps/nnunet/nnunet_bundle.py b/monai/apps/nnunet/nnunet_bundle.py
index e358cd4b99..7672d977a0 100644
--- a/monai/apps/nnunet/nnunet_bundle.py
+++ b/monai/apps/nnunet/nnunet_bundle.py
@@ -133,7 +133,7 @@ def get_nnunet_trainer(
     cudnn.benchmark = True
 
     if pretrained_model is not None:
-        state_dict = torch.load(pretrained_model)
+        state_dict = torch.load(pretrained_model, weights_only=False)
         if "network_weights" in state_dict:
             nnunet_trainer.network._orig_mod.load_state_dict(state_dict["network_weights"])
     return nnunet_trainer
@@ -152,6 +152,12 @@ class ModelnnUNetWrapper(torch.nn.Module):
         The folder path where the model and related files are stored.
     model_name : str, optional
         The name of the model file, by default "model.pt".
+    dataset_json : dict, optional
+        The dataset JSON file containing dataset information.
+    plans : dict, optional
+        The plans JSON file containing model configuration.
+    nnunet_config : dict, optional
+        The nnUNet configuration dictionary containing model parameters.
 
     Attributes
     ----------
@@ -162,11 +168,19 @@ class ModelnnUNetWrapper(torch.nn.Module):
 
     Notes
     -----
-    This class integrates nnUNet model with MONAI framework by loading necessary configurations,
+    This class integrates the nnUNet model with the MONAI framework by loading configurations,
     restoring network architecture, and setting up the predictor for inference.
     """
 
-    def __init__(self, predictor: object, model_folder: Union[str, Path], model_name: str = "model.pt"):  # type: ignore
+    def __init__(
+        self,
+        predictor: object,
+        model_folder: Union[str, Path],
+        model_name: str = "model.pt",
+        dataset_json: Optional[dict] = None,
+        plans: Optional[dict] = None,
+        nnunet_config: Optional[dict] = None,
+    ):  # type: ignore
         super().__init__()
         self.predictor = predictor
 
@@ -175,30 +189,43 @@ def __init__(self, predictor: object, model_folder: Union[str, Path], model_name
         from nnunetv2.utilities.plans_handling.plans_handler import PlansManager
 
         # Block Added from nnUNet/nnunetv2/inference/predict_from_raw_data.py#nnUNetPredictor
-        dataset_json = load_json(join(Path(model_training_output_dir).parent, "dataset.json"))
-        plans = load_json(join(Path(model_training_output_dir).parent, "plans.json"))
+        if dataset_json is None:
+            dataset_json = load_json(join(Path(model_training_output_dir).parent, "dataset.json"))
+        if plans is None:
+            plans = load_json(join(Path(model_training_output_dir).parent, "plans.json"))
         plans_manager = PlansManager(plans)
 
         parameters = []
-        checkpoint = torch.load(
-            join(Path(model_training_output_dir).parent, "nnunet_checkpoint.pth"), map_location=torch.device("cpu")
-        )
-        trainer_name = checkpoint["trainer_name"]
-        configuration_name = checkpoint["init_args"]["configuration"]
-        inference_allowed_mirroring_axes = (
-            checkpoint["inference_allowed_mirroring_axes"]
-            if "inference_allowed_mirroring_axes" in checkpoint.keys()
-            else None
-        )
-        if Path(model_training_output_dir).joinpath(model_name).is_file():
-            monai_checkpoint = torch.load(join(model_training_output_dir, model_name), map_location=torch.device("cpu"))
+        if nnunet_config is None:
+            checkpoint = torch.load(
+                join(Path(model_training_output_dir).parent, "nnunet_checkpoint.pth"),
+                map_location=torch.device("cpu"),
+                weights_only=False,
+            )
+            trainer_name = checkpoint["trainer_name"]
+            configuration_name = checkpoint["init_args"]["configuration"]
+            inference_allowed_mirroring_axes = (
+                checkpoint["inference_allowed_mirroring_axes"]
+                if "inference_allowed_mirroring_axes" in checkpoint.keys()
+                else None
+            )
+        else:
+            trainer_name = nnunet_config["trainer_name"]
+            configuration_name = nnunet_config["configuration"]
+            inference_allowed_mirroring_axes = nnunet_config["inference_allowed_mirroring_axes"]
+
+        if Path(model_training_output_dir).joinpath(model_name).is_file() and model_name.endswith(".pt"):
+            monai_checkpoint = torch.load(
+                join(model_training_output_dir, model_name), map_location=torch.device("cpu"), weights_only=False
+            )
             if "network_weights" in monai_checkpoint.keys():
                 parameters.append(monai_checkpoint["network_weights"])
             else:
                 parameters.append(monai_checkpoint)
 
         configuration_manager = plans_manager.get_configuration(configuration_name)
+
         import nnunetv2
         from nnunetv2.utilities.find_class_by_name import recursive_find_python_class
         from nnunetv2.utilities.label_handling.label_handling import determine_num_input_channels
 
@@ -255,7 +282,16 @@ def forward(self, x: MetaTensor) -> MetaTensor:
         """
         if isinstance(x, MetaTensor):
             if "pixdim" in x.meta:
-                properties_or_list_of_properties = {"spacing": x.meta["pixdim"][0][1:4].numpy().tolist()}
+                if x.meta["pixdim"].ndim == 1:
+                    if x.meta["pixdim"][0] == 1:
+                        properties_or_list_of_properties = {"spacing": x.meta["pixdim"][1:4].tolist()}
+                    else:
+                        properties_or_list_of_properties = {"spacing": x.meta["pixdim"][:3].tolist()}
+                else:
+                    if x.meta["pixdim"][0][0] == 1:
+                        properties_or_list_of_properties = {"spacing": x.meta["pixdim"][0][1:4].numpy().tolist()}
+                    else:
+                        properties_or_list_of_properties = {"spacing": x.meta["pixdim"][0][:3].numpy().tolist()}
             elif "affine" in x.meta:
                 spacing = [
                     abs(x.meta["affine"][0][0].item()),
@@ -269,6 +305,8 @@ def forward(self, x: MetaTensor) -> MetaTensor:
             raise TypeError("Input must be a MetaTensor or a tuple of MetaTensors.")
 
         image_or_list_of_images = x.cpu().numpy()[0, :]
+        image_or_list_of_images = np.transpose(image_or_list_of_images, (0, 3, 2, 1))
+        properties_or_list_of_properties["spacing"] = properties_or_list_of_properties["spacing"][::-1]
 
         # input_files should be a list of file paths, one per modality
         prediction_output = self.predictor.predict_from_list_of_npy_arrays(  # type: ignore
@@ -286,11 +324,17 @@ def forward(self, x: MetaTensor) -> MetaTensor:
         for out in prediction_output:  # Add batch and channel dimensions
             out_tensors.append(torch.from_numpy(np.expand_dims(np.expand_dims(out, 0), 0)))
         out_tensor = torch.cat(out_tensors, 0)  # Concatenate along batch dimension
-
+        out_tensor = out_tensor.permute(0, 1, 4, 3, 2)
         return MetaTensor(out_tensor, meta=x.meta)
 
 
-def get_nnunet_monai_predictor(model_folder: Union[str, Path], model_name: str = "model.pt") -> ModelnnUNetWrapper:
+def get_nnunet_monai_predictor(
+    model_folder: Union[str, Path],
+    model_name: str = "model.pt",
+    dataset_json: Optional[dict] = None,
+    plans: Optional[dict] = None,
+    nnunet_config: Optional[dict] = None,
+) -> ModelnnUNetWrapper:
     """
     Initializes and returns a `nnUNetMONAIModelWrapper` containing the corresponding `nnUNetPredictor`.
     The model folder should contain the following files, created during training:
@@ -321,6 +365,12 @@ def get_nnunet_monai_predictor(model_folder: Union[str, Path], model_name: str =
         The folder where the model is stored.
     model_name : str, optional
         The name of the model file, by default "model.pt".
+    dataset_json : dict, optional
+        The dataset JSON file containing dataset information.
+    plans : dict, optional
+        The plans JSON file containing model configuration.
+    nnunet_config : dict, optional
+        The nnUNet configuration dictionary containing model parameters.
 
     Returns
     -------
@@ -335,12 +385,12 @@ def get_nnunet_monai_predictor(model_folder: Union[str, Path], model_name: str =
         use_gaussian=True,
         use_mirroring=False,
         device=torch.device("cuda", 0),
-        verbose=False,
-        verbose_preprocessing=False,
+        verbose=True,
+        verbose_preprocessing=True,
         allow_tqdm=True,
     )
     # initializes the network architecture, loads the checkpoint
-    wrapper = ModelnnUNetWrapper(predictor, model_folder, model_name)
+    wrapper = ModelnnUNetWrapper(predictor, model_folder, model_name, dataset_json, plans, nnunet_config)
     return wrapper
@@ -383,8 +433,12 @@ def convert_nnunet_to_monai_bundle(nnunet_config: dict, bundle_root_folder: str,
         dataset_name, f"{nnunet_trainer}__{nnunet_plans}__{nnunet_configuration}"
     )
 
-    nnunet_checkpoint_final = torch.load(Path(nnunet_model_folder).joinpath(f"fold_{fold}", "checkpoint_final.pth"))
-    nnunet_checkpoint_best = torch.load(Path(nnunet_model_folder).joinpath(f"fold_{fold}", "checkpoint_best.pth"))
+    nnunet_checkpoint_final = torch.load(
+        Path(nnunet_model_folder).joinpath(f"fold_{fold}", "checkpoint_final.pth"), weights_only=False
+    )
+    nnunet_checkpoint_best = torch.load(
+        Path(nnunet_model_folder).joinpath(f"fold_{fold}", "checkpoint_best.pth"), weights_only=False
+    )
 
     nnunet_checkpoint = {}
     nnunet_checkpoint["inference_allowed_mirroring_axes"] = nnunet_checkpoint_final["inference_allowed_mirroring_axes"]
@@ -470,7 +524,7 @@ def get_network_from_nnunet_plans(
 
     if model_ckpt is None:
         return network
     else:
-        state_dict = torch.load(model_ckpt)
+        state_dict = torch.load(model_ckpt, weights_only=False)
        network.load_state_dict(state_dict[model_key_in_ckpt])
         return network
@@ -534,7 +588,7 @@ def subfiles(
     Path(nnunet_model_folder).joinpath(f"fold_{fold}").mkdir(parents=True, exist_ok=True)
 
-    nnunet_checkpoint: dict = torch.load(f"{bundle_root_folder}/models/nnunet_checkpoint.pth")
+    nnunet_checkpoint: dict = torch.load(f"{bundle_root_folder}/models/nnunet_checkpoint.pth", weights_only=False)
 
     latest_checkpoints: list[str] = subfiles(
         Path(bundle_root_folder).joinpath("models", f"fold_{fold}"), prefix="checkpoint_epoch", sort=True
     )
@@ -545,7 +599,7 @@ def subfiles(
     epochs.sort()
     final_epoch: int = epochs[-1]
     monai_last_checkpoint: dict = torch.load(
-        f"{bundle_root_folder}/models/fold_{fold}/checkpoint_epoch={final_epoch}.pt"
+        f"{bundle_root_folder}/models/fold_{fold}/checkpoint_epoch={final_epoch}.pt", weights_only=False
     )
 
     best_checkpoints: list[str] = subfiles(
@@ -558,10 +612,11 @@ def subfiles(
     key_metrics.sort()
     best_key_metric: str = key_metrics[-1]
     monai_best_checkpoint: dict = torch.load(
-        f"{bundle_root_folder}/models/fold_{fold}/checkpoint_key_metric={best_key_metric}.pt"
+        f"{bundle_root_folder}/models/fold_{fold}/checkpoint_key_metric={best_key_metric}.pt", weights_only=False
     )
 
-    nnunet_checkpoint["optimizer_state"] = monai_last_checkpoint["optimizer_state"]
+    if "optimizer_state" in monai_last_checkpoint:
+        nnunet_checkpoint["optimizer_state"] = monai_last_checkpoint["optimizer_state"]
 
     nnunet_checkpoint["network_weights"] = odict()
 
@@ -577,7 +632,8 @@ def subfiles(
 
     nnunet_checkpoint["network_weights"] = odict()
 
-    nnunet_checkpoint["optimizer_state"] = monai_best_checkpoint["optimizer_state"]
+    if "optimizer_state" in monai_best_checkpoint:
+        nnunet_checkpoint["optimizer_state"] = monai_best_checkpoint["optimizer_state"]
 
     for key in monai_best_checkpoint["network_weights"]:
         nnunet_checkpoint["network_weights"][key] = monai_best_checkpoint["network_weights"][key]
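
Note on the recurring `weights_only=False` additions: PyTorch 2.6 changed the default of `torch.load` to `weights_only=True`, which refuses to unpickle checkpoints carrying arbitrary Python objects (trainer name, init args, optimizer state), so every full-checkpoint load above must opt out explicitly. A minimal sketch of the failure mode, using a hypothetical checkpoint path:

```python
# Minimal sketch, assuming PyTorch >= 2.6; the checkpoint path is hypothetical.
import torch

ckpt_path = "models/nnunet_checkpoint.pth"

# Under the 2.6 default (weights_only=True) this raises an UnpicklingError for
# checkpoints that pickle non-tensor objects:
#     checkpoint = torch.load(ckpt_path)
# For trusted files, weights_only=False restores the pre-2.6 behaviour:
checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=False)
print(checkpoint["trainer_name"])
```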
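
The new `dataset_json`, `plans`, and `nnunet_config` parameters let callers supply nnU-Net metadata in memory instead of reading `dataset.json`, `plans.json`, and `nnunet_checkpoint.pth` from disk; when `nnunet_config` is given, the checkpoint read is skipped entirely. A hedged usage sketch, with illustrative paths and dict contents that are not part of the diff:

```python
import json

from monai.apps.nnunet.nnunet_bundle import get_nnunet_monai_predictor

model_folder = "models/fold_0"  # hypothetical bundle layout

# Default behaviour (unchanged): dataset.json, plans.json and
# nnunet_checkpoint.pth are read from disk next to the model folder.
predictor = get_nnunet_monai_predictor(model_folder)

# New behaviour: pass the parsed metadata directly, and skip the
# nnunet_checkpoint.pth read by providing nnunet_config.
with open("models/dataset.json") as f:
    dataset_json = json.load(f)
with open("models/plans.json") as f:
    plans = json.load(f)
nnunet_config = {
    "trainer_name": "nnUNetTrainer",
    "configuration": "3d_fullres",
    "inference_allowed_mirroring_axes": None,
}
predictor = get_nnunet_monai_predictor(
    model_folder, dataset_json=dataset_json, plans=plans, nnunet_config=nnunet_config
)
```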
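
The `forward` changes reverse the spatial axis order on the way in and out: the input array is transposed `(0, 3, 2, 1)` and its spacing reversed before prediction, and the batched 5D output is permuted `(0, 1, 4, 3, 2)` back, consistent with nnU-Net consuming volumes in `(z, y, x)` order while MONAI keeps `(x, y, z)`. A small round-trip sketch with made-up shapes:

```python
# Illustrative only: shapes and spacing are invented to show the axis round trip.
import numpy as np

img = np.zeros((1, 64, 96, 128))  # (C, X, Y, Z) as MONAI lays it out
spacing = [1.0, 0.8, 0.5]         # (x, y, z) spacing from the metadata

img_zyx = np.transpose(img, (0, 3, 2, 1))  # (C, Z, Y, X) -> shape (1, 128, 96, 64)
spacing_zyx = spacing[::-1]                # (z, y, x)   -> [0.5, 0.8, 1.0]

# The batched 5D prediction (B, C, Z, Y, X) is permuted back to (B, C, X, Y, Z):
out = np.zeros((1, 1, 128, 96, 64))
out_xyz = out.transpose(0, 1, 4, 3, 2)
assert out_xyz.shape == (1, 1, 64, 96, 128)
```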