
[Prototype] Make the model config override the pretrained config #171

Closed · wants to merge 1 commit into from
19 changes: 16 additions & 3 deletions fast_llm/config.py
@@ -125,6 +125,8 @@ def __init__(
# Should raise an Exception in case of failure, and return the validated value.
# Run before the default validation (type check).
valid: typing.Optional[typing.Callable[[typing.Any], typing.Any]] = None,
# Option to skip (postpone) instantiation of a `Config` field.
auto_instantiate: bool = True,
default=dataclasses.MISSING,
default_factory=dataclasses.MISSING,
init: bool = True,
@@ -152,6 +154,7 @@ def __init__(
self.doc = doc
self.hint = hint
self.valid = valid
self.auto_instantiate = auto_instantiate


class FieldUpdate(dict):
@@ -265,6 +268,10 @@ def wrap(cls):
return wrap(cls)


# A marker to prevent auto instantiation of a config.
NoAutoInstantiate = object()


@dataclasses.dataclass()
class Config:
"""
@@ -712,10 +719,16 @@ def _from_dict(
continue
if flat:
if isinstance(field.type, type) and issubclass(field.type, Config):
if flat:
out_arg_dict[name] = field.type._from_dict(default, False, True)
assert isinstance(field.default_factory, type) and issubclass(
field.default_factory, field.type
)
if field.auto_instantiate:
if flat:
out_arg_dict[name] = field.default_factory._from_dict(default, False, True)
else:
out_arg_dict[name] = field.default_factory._from_dict(default.pop(name, {}), strict)
else:
out_arg_dict[name] = field.type._from_dict(default.pop(name, {}), strict)
out_arg_dict[name] = default.pop(name, {})
elif name in default:
out_arg_dict[name] = default.pop(name)
else:
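The net effect of the new `auto_instantiate` flag: when it is `False`, `_from_dict` no longer builds the nested sub-config from the field's `default_factory`; it keeps the raw dict (`default.pop(name, {})`) so the owning config can merge it with other sources before instantiating it. A minimal standalone sketch of that behavior, using a simplified stand-in rather than Fast-LLM's actual classes:

```python
# Standalone sketch (not Fast-LLM code) of what `auto_instantiate=False` changes:
# the nested section stays a plain dict instead of being instantiated from the
# field's `default_factory`, so the owner can merge it with a pretrained config
# before building the final object.
import dataclasses


@dataclasses.dataclass
class ModelConfig:
    hidden_size: int = 1024
    num_layers: int = 12


def from_dict(section: dict, auto_instantiate: bool = True):
    if auto_instantiate:
        # Default behavior: eagerly build the sub-config.
        return ModelConfig(**section)
    # New option: postpone instantiation and keep the raw dict.
    return dict(section)


eager = from_dict({"hidden_size": 2048})
deferred = from_dict({"hidden_size": 2048}, auto_instantiate=False)
assert isinstance(eager, ModelConfig) and eager.hidden_size == 2048
assert deferred == {"hidden_size": 2048}
```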
41 changes: 12 additions & 29 deletions fast_llm/engine/multi_stage/config.py
@@ -248,44 +248,22 @@ def get_base_model_config_class(cls) -> type[BaseModelConfig]:
def from_pretrained(
cls,
pretrained: CheckpointLoadMetadataConfig,
default: typing.Self | None = None,
*updates: dict[str | tuple[str, ...], typing.Any] | None,
) -> typing.Self:
# TODO: Add *updates?
assert pretrained.path is not None
metadata = cls.load_metadata(pretrained)
return cls.from_metadata(pretrained, metadata, default)
return cls.from_metadata(cls.load_metadata(pretrained), *updates)

@classmethod
def from_metadata(
cls,
pretrained: CheckpointLoadMetadataConfig,
metadata: "CheckpointMetadata",
default: typing.Self | None = None,
updates: dict[str | tuple[str, ...], typing.Any] | None = None,
*updates: dict[str | tuple[str, ...], typing.Any] | None,
) -> typing.Self:
# TODO: Standardize to *updates?
# TODO v0.3: Update, remove support for older checkpoints.
if metadata.fast_llm_version.major != 0 or metadata.fast_llm_version.minor not in (0, 1, 2):
raise ValueError(f"Invalid checkpoint version: {metadata.fast_llm_version}")
pretrained_config = cls.from_dict(metadata.config)
if not pretrained.load_config.load_architecture:
assert default is not None
config = default.to_copy()
config.base_model.compare_architecture(pretrained_config.base_model, pretrained.compare_log_fn)
elif pretrained.load_config.load_fast_llm:
config = pretrained_config
else:
with NoAutoValidate():
config = cls() if default is None else default.to_copy()
if pretrained.load_config.load_base_model:
config.base_model = pretrained_config.base_model
else:
config.base_model = config.base_model.to_copy(pretrained_config.base_model.get_architecture())
config.validate()

if updates:
config = config.to_copy(updates)
return config
return cls.from_dict(metadata.config, *updates)

@classmethod
def load_metadata(cls, config: CheckpointLoadMetadataConfig) -> "CheckpointMetadata":
@@ -315,7 +293,10 @@ class PretrainedFastLLMModelConfig(Config):
_abstract = True
# These configs may be overridden with the pretrained config during validation, so we should be careful about accessing them beforehand.
model: FastLLMModelConfig = Field(
default_factory=FastLLMModelConfig, desc="Configuration for the Fast-LLM model.", hint=FieldHint.core
default_factory=FastLLMModelConfig,
desc="Configuration for the Fast-LLM model.",
hint=FieldHint.core,
auto_instantiate=False,
)
pretrained: CheckpointLoadConfig = Field(
default_factory=CheckpointLoadConfig,
@@ -327,8 +308,10 @@ def _validate(self) -> None:
assert self.model is not None
self.pretrained.setup(self.model)
self.pretrained.validate()
if self.pretrained.path is not None:
self.model = self.model.from_pretrained(self.pretrained, default=self.model)
if self.pretrained.path is None:
self.model = self.get_field("model").default_factory.from_dict(self.model)
else:
self.model = self.model.from_pretrained(self.pretrained, self.model)
self._setup()
super()._validate()

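Taken together, `from_pretrained` and `from_metadata` now accept variadic `*updates` override dicts, and `_validate` passes the user-supplied model section (left un-instantiated thanks to `auto_instantiate=False`) as such an update, so explicitly configured values override the pretrained checkpoint rather than being replaced by it. A hedged sketch of that merge order, using a simplified `apply_updates` helper instead of `Config.from_dict`:

```python
# Simplified illustration (assumed semantics, not the Fast-LLM implementation):
# the checkpoint config is the base, and each update dict is applied on top of
# it in order, so the model config provided by the user wins.
import copy
from typing import Any


def apply_updates(base: dict[str, Any], *updates: dict[str | tuple[str, ...], Any] | None) -> dict[str, Any]:
    config = copy.deepcopy(base)
    for update in updates:
        if not update:
            continue
        for key, value in update.items():
            # Keys may be nested, e.g. ("base_model", "hidden_size").
            keys = key if isinstance(key, tuple) else (key,)
            target = config
            for part in keys[:-1]:
                target = target.setdefault(part, {})
            target[keys[-1]] = value
    return config


pretrained = {"base_model": {"hidden_size": 4096, "num_layers": 32}}
overrides = {("base_model", "hidden_size"): 2048}
merged = apply_updates(pretrained, overrides)
assert merged["base_model"]["hidden_size"] == 2048  # user value overrides the checkpoint
assert merged["base_model"]["num_layers"] == 32     # everything else comes from the checkpoint
```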
2 changes: 1 addition & 1 deletion fast_llm/models/custom/config.py
@@ -55,7 +55,7 @@ def get_huggingface_model_class(cls) -> type["HuggingfaceCustomModelForCausalLM"]

@config_class()
class PretrainedCustomModelConfig(PretrainedGPTModelConfig):
model: CustomModelConfig = FieldUpdate(default_factory=CustomModelConfig)
model: CustomModelConfig = FieldUpdate()


@config_class()
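With the model section no longer auto-instantiated from the field's `default_factory`, the subclass presumably only needs to narrow the annotation; `FieldUpdate` (a plain dict of attribute overrides, per `class FieldUpdate(dict)` above) would inherit everything else, including `auto_instantiate=False`, from the parent field. A rough sketch of that override-dict idea under those assumptions:

```python
# Rough sketch of the FieldUpdate mechanism as assumed here: a dict of attribute
# overrides merged into the Field inherited from the parent config class.
import dataclasses
from typing import Any


@dataclasses.dataclass
class Field:
    desc: str = ""
    hint: str = "core"
    auto_instantiate: bool = True
    default_factory: Any = None


class FieldUpdate(dict):
    """Attribute overrides applied to the parent class's field definition."""


def resolve(parent_field: Field, update: FieldUpdate) -> Field:
    # Start from the parent's attributes and replace only what the update names.
    merged = dataclasses.asdict(parent_field)
    merged.update(update)
    return Field(**merged)


parent = Field(desc="Configuration for the Fast-LLM model.", auto_instantiate=False)
child = resolve(parent, FieldUpdate())  # empty update: everything is inherited
assert child.auto_instantiate is False and child.desc == parent.desc
```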