diff --git a/fast_llm/config.py b/fast_llm/config.py
index f1c88965..a5e4803f 100644
--- a/fast_llm/config.py
+++ b/fast_llm/config.py
@@ -125,6 +125,8 @@ def __init__(
         # Should raise an Exception in case of failure, and return the validated value.
         # Run before the default validation (type check).
         valid: typing.Optional[typing.Callable[[typing.Any], typing.Any]] = None,
+        # Option to skip (postpone) instantiation of a `Config` field.
+        auto_instantiate: bool = True,
         default=dataclasses.MISSING,
         default_factory=dataclasses.MISSING,
         init: bool = True,
@@ -152,6 +154,7 @@ def __init__(
         self.doc = doc
         self.hint = hint
         self.valid = valid
+        self.auto_instantiate = auto_instantiate


 class FieldUpdate(dict):
@@ -265,6 +268,10 @@ def wrap(cls):
     return wrap(cls)


+# A marker to prevent auto instantiation of a config.
+NoAutoInstantiate = object()
+
+
 @dataclasses.dataclass()
 class Config:
     """
@@ -712,10 +719,16 @@ def _from_dict(
                 continue
             if flat:
                 if isinstance(field.type, type) and issubclass(field.type, Config):
-                    if flat:
-                        out_arg_dict[name] = field.type._from_dict(default, False, True)
+                    assert isinstance(field.default_factory, type) and issubclass(
+                        field.default_factory, field.type
+                    )
+                    if field.auto_instantiate:
+                        if flat:
+                            out_arg_dict[name] = field.default_factory._from_dict(default, False, True)
+                        else:
+                            out_arg_dict[name] = field.default_factory._from_dict(default.pop(name, {}), strict)
                     else:
-                        out_arg_dict[name] = field.type._from_dict(default.pop(name, {}), strict)
+                        out_arg_dict[name] = default.pop(name, {})
                 elif name in default:
                     out_arg_dict[name] = default.pop(name)
                 else:
diff --git a/fast_llm/engine/multi_stage/config.py b/fast_llm/engine/multi_stage/config.py
index d6997105..a816467c 100644
--- a/fast_llm/engine/multi_stage/config.py
+++ b/fast_llm/engine/multi_stage/config.py
@@ -248,44 +248,22 @@ def get_base_model_config_class(cls) -> type[BaseModelConfig]:
     def from_pretrained(
         cls,
         pretrained: CheckpointLoadMetadataConfig,
-        default: typing.Self | None = None,
+        *updates: dict[str | tuple[str, ...], typing.Any] | None,
     ) -> typing.Self:
-        # TODO: Add *updates?
         assert pretrained.path is not None
-        metadata = cls.load_metadata(pretrained)
-        return cls.from_metadata(pretrained, metadata, default)
+        return cls.from_metadata(cls.load_metadata(pretrained), *updates)

     @classmethod
     def from_metadata(
         cls,
-        pretrained: CheckpointLoadMetadataConfig,
         metadata: "CheckpointMetadata",
-        default: typing.Self | None = None,
-        updates: dict[str | tuple[str, ...], typing.Any] | None = None,
+        *updates: dict[str | tuple[str, ...], typing.Any] | None,
     ) -> typing.Self:
         # TODO: Standardize to *updates?
         # TODO v0.3: Update, remove support for older checkpoints.
         if metadata.fast_llm_version.major != 0 or metadata.fast_llm_version.minor not in (0, 1, 2):
             raise ValueError(f"Invalid checkpoint version: {metadata.fast_llm_version}")
-        pretrained_config = cls.from_dict(metadata.config)
-        if not pretrained.load_config.load_architecture:
-            assert default is not None
-            config = default.to_copy()
-            config.base_model.compare_architecture(pretrained_config.base_model, pretrained.compare_log_fn)
-        elif pretrained.load_config.load_fast_llm:
-            config = pretrained_config
-        else:
-            with NoAutoValidate():
-                config = cls() if default is None else default.to_copy()
-            if pretrained.load_config.load_base_model:
-                config.base_model = pretrained_config.base_model
-            else:
-                config.base_model = config.base_model.to_copy(pretrained_config.base_model.get_architecture())
-            config.validate()
-
-        if updates:
-            config = config.to_copy(updates)
-        return config
+        return cls.from_dict(metadata.config, *updates)

     @classmethod
     def load_metadata(cls, config: CheckpointLoadMetadataConfig) -> "CheckpointMetadata":
@@ -315,7 +293,10 @@ class PretrainedFastLLMModelConfig(Config):
     _abstract = True
     # This configs may be overridden with the pretrained config during validation, so we should be careful about accessing them before.
     model: FastLLMModelConfig = Field(
-        default_factory=FastLLMModelConfig, desc="Configuration for the Fast-LLM model.", hint=FieldHint.core
+        default_factory=FastLLMModelConfig,
+        desc="Configuration for the Fast-LLM model.",
+        hint=FieldHint.core,
+        auto_instantiate=False,
     )
     pretrained: CheckpointLoadConfig = Field(
         default_factory=CheckpointLoadConfig,
@@ -327,8 +308,10 @@ def _validate(self) -> None:
         assert self.model is not None
         self.pretrained.setup(self.model)
         self.pretrained.validate()
-        if self.pretrained.path is not None:
-            self.model = self.model.from_pretrained(self.pretrained, default=self.model)
+        if self.pretrained.path is None:
+            self.model = self.get_field("model").default_factory.from_dict(self.model)
+        else:
+            self.model = self.model.from_pretrained(self.pretrained, self.model)
         self._setup()
         super()._validate()

diff --git a/fast_llm/models/custom/config.py b/fast_llm/models/custom/config.py
index b86722f8..1daaf2c8 100644
--- a/fast_llm/models/custom/config.py
+++ b/fast_llm/models/custom/config.py
@@ -55,7 +55,7 @@ def get_huggingface_model_class(cls) -> type["HuggingfaceCustomModelForCausalLM"]:

 @config_class()
 class PretrainedCustomModelConfig(PretrainedGPTModelConfig):
-    model: CustomModelConfig = FieldUpdate(default_factory=CustomModelConfig)
+    model: CustomModelConfig = FieldUpdate()


 @config_class()
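For context, a minimal standalone sketch of the pattern this patch introduces: with `auto_instantiate=False`, a nested `Config` field stays a plain dict through `_from_dict`, and validation later either instantiates it from the field's default factory (no checkpoint) or passes it as an update on top of the pretrained config, mirroring the new `_validate` branch above. The `ToyModelConfig` / `ToyTrainerConfig` names below are hypothetical stand-ins, not Fast-LLM classes.

# Hypothetical, self-contained sketch; not Fast-LLM code.
import dataclasses
import typing


@dataclasses.dataclass
class ToyModelConfig:
    hidden_size: int = 1024
    num_layers: int = 12

    @classmethod
    def from_dict(cls, d: dict[str, typing.Any], *updates: dict[str, typing.Any]) -> "ToyModelConfig":
        # Later dicts override earlier ones, like the *updates of from_pretrained.
        merged = dict(d)
        for update in updates:
            merged.update(update)
        return cls(**merged)


@dataclasses.dataclass
class ToyTrainerConfig:
    # Analogue of Field(..., auto_instantiate=False): kept as a raw dict until validation.
    model: dict[str, typing.Any] | ToyModelConfig = dataclasses.field(default_factory=dict)
    pretrained_path: str | None = None

    def validate(self) -> None:
        if self.pretrained_path is None:
            # No checkpoint: instantiate the sub-config straight from the raw dict.
            self.model = ToyModelConfig.from_dict(self.model)
        else:
            # Checkpoint present: the raw dict becomes an update applied on top of
            # the pretrained config (stand-in values for metadata.config below).
            pretrained = {"hidden_size": 4096, "num_layers": 32}
            self.model = ToyModelConfig.from_dict(pretrained, self.model)


config = ToyTrainerConfig(model={"num_layers": 24}, pretrained_path="some/checkpoint")
config.validate()
print(config.model)  # ToyModelConfig(hidden_size=4096, num_layers=24)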