Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow config mismatches #24

Merged
merged 4 commits into from
Dec 1, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 25 additions & 7 deletions mammal/model.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import copy
import json
import os
from dataclasses import dataclass
Expand Down Expand Up @@ -43,10 +44,29 @@ class MammalConfig(PretrainedConfig):
random_weights: bool = False # If True, will not load the pre-trained weights

@classmethod
def from_dict(cls, config_dict: dict[str, Any]) -> "MammalConfig":
def from_dict(
cls, config_dict: dict[str, Any], *, allow_config_mismatch: bool = False
) -> "MammalConfig":
if "t5_config" not in config_dict:
raise ValueError(f"config_dict should have key 't5_config'. {config_dict=}")

if allow_config_mismatch:
# Allowing to load the model even if the incoming config dict has unexpected key(s)
config_dict = copy.deepcopy(
config_dict
) # We don't want to change the incoming dict
mismatch_keys = []
for incoming_config_key in list(config_dict.keys()):
if incoming_config_key not in cls.__dataclass_fields__:
# Incoming key isn't part of the expected config keys
mismatch_keys.append(incoming_config_key)
config_dict.pop(incoming_config_key)

if len(mismatch_keys) > 0:
print(
f"Warning, mismatch detected! Make sure you know what you are doing... {mismatch_keys=}"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🤣

)

# We want to instantiate each class from its dict (json), using the parent class logic
# HF doesn't support the case where there are nested *different* configs.
config_dict["t5_config"] = T5Config.from_dict(config_dict["t5_config"])
Expand Down Expand Up @@ -333,11 +353,6 @@ def _save_pretrained(
save_directory: Path,
save_config_only: bool = False,
) -> None:
"""
:param mode: either 'config', 'state_dict' or 'all'
:param metadata: metadata to store with the model
:param tokenizer_relative_path: relative path of the tokenizer to store with the model
"""
print(f"Saving @ {save_directory}")

# Define paths
Expand All @@ -355,6 +370,7 @@ def from_pretrained(
cls,
pretrained_model_name_or_path: str | Path,
*,
allow_config_mismatch: bool = False,
config: MammalConfig | str | os.PathLike | None = None,
config_overrides: dict[str, Any] | None = None,
strict: bool = True,
Expand Down Expand Up @@ -470,7 +486,9 @@ def from_pretrained(
if isinstance(config, str):
with open(config, encoding="utf-8") as f:
config = json.load(f)
config = MammalConfig.from_dict(config)
config = MammalConfig.from_dict(
config, allow_config_mismatch=allow_config_mismatch
)

# override configuration if requested
if config_overrides is not None:
Expand Down