Skip to content

Commit

Permalink
Fix backward compatibility issues in model checkpoint loading (#4199)
Browse files Browse the repository at this point in the history
  • Loading branch information
eugene123tw authored Feb 5, 2025
1 parent 3c949c5 commit 396be9c
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 0 deletions.
7 changes: 7 additions & 0 deletions src/otx/core/model/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,13 @@ def on_load_checkpoint(self, checkpoint: dict[str, Any]) -> None:
super().on_load_checkpoint(checkpoint)

if ckpt_label_info := checkpoint.get("label_info", None):
if isinstance(ckpt_label_info, LabelInfo) and not hasattr(ckpt_label_info, "label_ids"):
# NOTE: This is for backward compatibility
ckpt_label_info = LabelInfo(
label_groups=ckpt_label_info.label_groups,
label_names=ckpt_label_info.label_names,
label_ids=[str(i) for i in range(len(ckpt_label_info.label_names))],
)
self._label_info = ckpt_label_info

if ckpt_tile_config := checkpoint.get("tile_config", None):
Expand Down
6 changes: 6 additions & 0 deletions src/otx/engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -594,6 +594,12 @@ def export(
kwargs_user_input.update(infer_reference_info_root=self.model.infer_reference_info_root)

model_cls = self.model.__class__
if hasattr(self.model, "model_name"):
# NOTE: This is a solution to fix backward compatibility issue.
# If the model has `model_name` attribute, it will be passed to the `load_from_checkpoint` method,
# making sure previous model trained without model_name can be loaded.
kwargs_user_input["model_name"] = self.model.model_name

self.model = model_cls.load_from_checkpoint(
checkpoint_path=checkpoint,
map_location="cpu",
Expand Down

0 comments on commit 396be9c

Please sign in to comment.