diff --git a/mammal/model.py b/mammal/model.py index 0dfed87..286ddde 100644 --- a/mammal/model.py +++ b/mammal/model.py @@ -1,3 +1,4 @@ +import copy import json import os from dataclasses import dataclass @@ -43,10 +44,29 @@ class MammalConfig(PretrainedConfig): random_weights: bool = False # If True, will not load the pre-trained weights @classmethod - def from_dict(cls, config_dict: dict[str, Any]) -> "MammalConfig": + def from_dict( + cls, config_dict: dict[str, Any], *, allow_config_mismatch: bool = False + ) -> "MammalConfig": if "t5_config" not in config_dict: raise ValueError(f"config_dict should have key 't5_config'. {config_dict=}") + if allow_config_mismatch: + # Allowing to load the model even if the incoming config dict has unexpected key(s) + config_dict = copy.deepcopy( + config_dict + ) # We don't want to change the incoming dict + mismatch_keys = [] + for incoming_config_key in list(config_dict.keys()): + if incoming_config_key not in cls.__dataclass_fields__: + # Incoming key isn't part of the expected config keys + mismatch_keys.append(incoming_config_key) + config_dict.pop(incoming_config_key) + + if len(mismatch_keys) > 0: + print( + f"Warning, mismatch detected! Make sure you know what you are doing... {mismatch_keys=}" + ) + # We want to instantiate each class from it's dict (json), using the parent class logic # HF don't support the case where there are nested *different* configs. config_dict["t5_config"] = T5Config.from_dict(config_dict["t5_config"]) @@ -333,11 +353,6 @@ def _save_pretrained( save_directory: Path, save_config_only: bool = False, ) -> None: - """ - :param mode: either 'config', 'state_dict' or 'all' - :param metadata: metadata to store with the model - :param tokenizer_relative_path: relative path of the tokenizer to store with the model - """ print(f"Saving @ {save_directory}") # Define paths @@ -355,6 +370,7 @@ def from_pretrained( cls, pretrained_model_name_or_path: str | Path, *, + allow_config_mismatch: bool = False, config: MammalConfig | str | os.PathLike | None = None, config_overrides: dict[str, Any] | None = None, strict: bool = True, @@ -470,7 +486,9 @@ def from_pretrained( if isinstance(config, str): with open(config, encoding="utf-8") as f: config = json.load(f) - config = MammalConfig.from_dict(config) + config = MammalConfig.from_dict( + config, allow_config_mismatch=allow_config_mismatch + ) # override configuration if requested if config_overrides is not None: