diff --git a/monai/handlers/checkpoint_loader.py b/monai/handlers/checkpoint_loader.py index ffc59e7732..810f233594 100644 --- a/monai/handlers/checkpoint_loader.py +++ b/monai/handlers/checkpoint_loader.py @@ -99,8 +99,14 @@ def __call__(self, engine: Engine) -> None: checkpoint = {k: checkpoint} # skip items that don't match data shape + # only apply copy_model_state for torch.nn.Module objects + # this is needed to handle shape mismatches for models for k, obj in self.load_dict.items(): - checkpoint[k] = copy_model_state(obj, checkpoint, inplace=False)[0] + if isinstance(obj, torch.nn.Module) and k in checkpoint: + _, _, _ = copy_model_state(obj, checkpoint[k], inplace=True) + # after copy_model_state, the model's state_dict has been updated + # with matching shapes, so we need to update the checkpoint + checkpoint[k] = obj.state_dict() # save current max epochs setting in the engine, don't overwrite it if larger than max_epochs in checkpoint prior_max_epochs = engine.state.max_epochs