From 1a04b8939e3b081ad647f4850d94f6be967f7a00 Mon Sep 17 00:00:00 2001 From: openhands Date: Wed, 24 Jun 2026 21:19:37 +0000 Subject: [PATCH] Fix CheckpointLoader bug with strict_shape=False and multiple objects - Only apply copy_model_state for torch.nn.Module objects to avoid errors when loading non-iterable objects like optimizers - Update checkpoint with model's state_dict after copy_model_state to ensure Checkpoint.load_objects can load the remaining objects correctly --- monai/handlers/checkpoint_loader.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/monai/handlers/checkpoint_loader.py b/monai/handlers/checkpoint_loader.py index ffc59e7732..810f233594 100644 --- a/monai/handlers/checkpoint_loader.py +++ b/monai/handlers/checkpoint_loader.py @@ -99,8 +99,14 @@ def __call__(self, engine: Engine) -> None: checkpoint = {k: checkpoint} # skip items that don't match data shape + # only apply copy_model_state for torch.nn.Module objects + # this is needed to handle shape mismatches for models for k, obj in self.load_dict.items(): - checkpoint[k] = copy_model_state(obj, checkpoint, inplace=False)[0] + if isinstance(obj, torch.nn.Module) and k in checkpoint: + _, _, _ = copy_model_state(obj, checkpoint[k], inplace=True) + # after copy_model_state, the model's state_dict has been updated + # with matching shapes, so we need to update the checkpoint + checkpoint[k] = obj.state_dict() # save current max epochs setting in the engine, don't overwrite it if larger than max_epochs in checkpoint prior_max_epochs = engine.state.max_epochs