Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion monai/handlers/checkpoint_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,14 @@ def __call__(self, engine: Engine) -> None:
checkpoint = {k: checkpoint}

# skip items that don't match data shape
# only apply copy_model_state for torch.nn.Module objects
# this is needed to handle shape mismatches for models
for k, obj in self.load_dict.items():
checkpoint[k] = copy_model_state(obj, checkpoint, inplace=False)[0]
if isinstance(obj, torch.nn.Module) and k in checkpoint:
_, _, _ = copy_model_state(obj, checkpoint[k], inplace=True)
# after copy_model_state, the model's state_dict has been updated
# with matching shapes, so we need to update the checkpoint
checkpoint[k] = obj.state_dict()

# save current max epochs setting in the engine, don't overwrite it if larger than max_epochs in checkpoint
prior_max_epochs = engine.state.max_epochs
Expand Down