diff --git a/Megatron-LM b/Megatron-LM index cb6baf17..fe1f23cf 160000 --- a/Megatron-LM +++ b/Megatron-LM @@ -1 +1 @@ -Subproject commit cb6baf171d064db6c2fee52f32dc1b51a2e6538d +Subproject commit fe1f23cf029d088c30a86989562b671af8967129 diff --git a/docs/developer_guide/conversion.md b/docs/developer_guide/conversion.md index 0620beae..76d8bfa3 100644 --- a/docs/developer_guide/conversion.md +++ b/docs/developer_guide/conversion.md @@ -232,7 +232,7 @@ Continuing our `AwesomeModel` handler example, we define: def _create_weight_converters(self) -> list[WeightConverter]: converters = [] # The set of converters may depend on the base model configuration, which is accessible through `self._model.base_model_config`. - num_layers = self._model.config.base_model.transformer.num_layers + num_layers = self._model.config.base_model.layers.default.num_layers # A simple renaming example, for the word embeddings. converters.append(WeightConverter("layers.0.word_embeddings_weight", "model.embed_tokens.weight")) diff --git a/fast_llm/config.py b/fast_llm/config.py index f1c88965..089e6508 100644 --- a/fast_llm/config.py +++ b/fast_llm/config.py @@ -888,6 +888,7 @@ def __init_subclass__(cls): valid=value.pop("valid", base_class_field.valid), default=value.pop("default", base_class_field.default), default_factory=value.pop("default_factory", base_class_field.default_factory), + init=value.pop("init", base_class_field.init), repr=value.pop("repr", base_class_field.repr), hash=value.pop("hash", base_class_field.hash), compare=value.pop("compare", base_class_field.compare), diff --git a/fast_llm/engine/config_utils/tensor_space.py b/fast_llm/engine/config_utils/tensor_space.py index 8bc86b73..73e11095 100644 --- a/fast_llm/engine/config_utils/tensor_space.py +++ b/fast_llm/engine/config_utils/tensor_space.py @@ -113,10 +113,12 @@ class TensorSpace: _is_setup: bool = False _distributed: "Distributed" - def __init__(self, distributed_config: DistributedConfig): + def __init__(self, distributed_config: DistributedConfig, _parent: "TensorSpace|None" = None): self._distributed_config = distributed_config self._tensor_dims: dict[str, TensorDim] = {} self.add_tensor_dim(TensorDim(DefaultDimNames.scalar, 1)) + self._parent = _parent + self._sub_spaces: dict[str, TensorSpace] = {} def setup(self, distributed: "Distributed") -> None: assert distributed.config is self._distributed_config @@ -146,5 +148,17 @@ def add_tensor_dim(self, dim: TensorDim) -> None: Assert.eq(dim.parallel_dim, self._distributed_config.distributed_dims[dim.parallel_dim.name]) self._tensor_dims[dim.name] = dim + def add_sub_space(self, name: str) -> "TensorSpace": + self._sub_spaces[name] = TensorSpace(self._distributed_config, _parent=self) + return self._sub_spaces[name] + + def get_sub_space(self, name: str) -> "TensorSpace": + return self._sub_spaces[name] + def get_tensor_dim(self, name: str) -> TensorDim: - return self._tensor_dims[name] + if name in self._tensor_dims: + return self._tensor_dims[name] + elif self._parent is not None: + return self._parent.get_tensor_dim(name) + else: + raise KeyError(name) diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 8e3a467c..b4ba271f 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -32,11 +32,7 @@ class LanguageModelKwargs: @config_class() class LanguageModelArchitectureConfig(BaseModelArchitectureConfig): - transformer: TransformerArchitectureConfig = Field( - 
default_factory=TransformerArchitectureConfig, - desc="Configuration for the transformer architecture.", - hint=FieldHint.core, - ) + layers: TransformerArchitectureConfig = Field(default_factory=TransformerArchitectureConfig) max_position_embeddings: int = Field( default=2048, desc="Number of absolute position embeddings, if applicable.", @@ -60,11 +56,11 @@ class LanguageModelArchitectureConfig(BaseModelArchitectureConfig): def _validate(self) -> None: if self.use_position_embeddings is None: - self.use_position_embeddings = not self.transformer.rotary.enabled + self.use_position_embeddings = not self.layers.default.rotary.enabled super()._validate() def setup_tensor_space(self, tensor_space: TensorSpace) -> None: - self.transformer.setup_tensor_space(tensor_space) + self.layers.setup_tensor_space(tensor_space) tensor = tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.tensor) # Embedding dimensions @@ -97,6 +93,17 @@ def from_flat_dict( cls._handle_renamed_field(default, "zero_centered_normalization", "zero_centered") return super().from_flat_dict(default, strict) + @classmethod + def _from_dict( + cls, + default: dict[str, typing.Any], + strict: bool = True, + flat: bool = False, + ) -> typing.Self: + # TODO v0.x: Remove backward compatibility. + cls._handle_renamed_field(default, "transformer", ("layers", "default")) + return super()._from_dict(default, strict, flat) + @config_class() class LanguageModelBaseConfig(LanguageModelArchitectureConfig, BaseModelConfig): @@ -111,7 +118,7 @@ class LanguageModelBaseConfig(LanguageModelArchitectureConfig, BaseModelConfig): architecture_class = LanguageModelArchitectureConfig - transformer: TransformerConfig = FieldUpdate(default_factory=TransformerConfig) + layers: TransformerConfig = FieldUpdate(default_factory=TransformerConfig) init_method_std_embed: float = Field( default=None, desc="Initialization scale for the vocabulary embedding and output weights (logits).", @@ -175,14 +182,14 @@ class LanguageModelBaseConfig(LanguageModelArchitectureConfig, BaseModelConfig): ) def _validate(self) -> None: - if self.transformer.init_method_std is None: - self.transformer.init_method_std = self.transformer.hidden_size**-0.5 + if self.layers.default.init_method_std is None: + self.layers.default.init_method_std = self.layers.default.hidden_size**-0.5 if self.init_method_std_embed is None: - self.init_method_std_embed = self.transformer.init_method_std + self.init_method_std_embed = self.layers.default.init_method_std if self.init_method_max_embed is None: - self.init_method_max_embed = self.transformer.init_method_max + self.init_method_max_embed = self.layers.default.init_method_max if self.init_method_min_embed is None: - self.init_method_min_embed = self.transformer.init_method_min + self.init_method_min_embed = self.layers.default.init_method_min if self.init_method_max_embed is not None and self.init_method_min_embed is not None: Assert.leq(self.init_method_min_embed, self.init_method_max_embed) super()._validate() diff --git a/fast_llm/layers/language_model/embedding.py b/fast_llm/layers/language_model/embedding.py index 67e7eb53..6a25e386 100644 --- a/fast_llm/layers/language_model/embedding.py +++ b/fast_llm/layers/language_model/embedding.py @@ -37,13 +37,13 @@ def __init__( self._tensor_space = tensor_space self._residual_dtype = ( self._distributed_config.optimization_dtype - if config.transformer.full_precision_residual + if config.layers.default.full_precision_residual else self._distributed_config.training_dtype 
).torch self._group_size = self._distributed_config.tensor_parallel self._sequence_parallel = self._distributed_config.sequence_tensor_parallel self._parallel_embeddings = tensor_space.distributed_config.tensor_parallel > 1 and config.parallel_embeddings - self._dropout_p = config.transformer.hidden_dropout + self._dropout_p = config.layers.default.hidden_dropout self._use_absolute_position_embeddings = config.use_absolute_position_embeddings hidden_dim = tensor_space.get_tensor_dim(TransformerDimNames.hidden) diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 4c03e393..f9b3dbf3 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -40,7 +40,7 @@ def __init__( tensor_space: TensorSpace, ): super().__init__(config) - self._debug_transformer = config.transformer.debug_transformer + self._debug_transformer = config.layers.default.debug_transformer self._tie_word_embeddings = config.tie_word_embeddings self._tensor_space = tensor_space @@ -56,7 +56,7 @@ def __init__( hidden_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.hidden) - self.final_norm = config.transformer.normalization.get_layer(hidden_dim) + self.final_norm = config.layers.default.normalization.get_layer(hidden_dim) self._logits_scale_factor = config.logits_scale_factor self._z_loss_factor = config.logit_z_loss diff --git a/fast_llm/layers/transformer/attention.py b/fast_llm/layers/transformer/attention.py index f64de9f1..752bd8eb 100644 --- a/fast_llm/layers/transformer/attention.py +++ b/fast_llm/layers/transformer/attention.py @@ -10,11 +10,7 @@ from fast_llm.functional.rotary import apply_rotary_embeddings from fast_llm.functional.triton.rotary import triton_rotary_autograd_ from fast_llm.layers.common.linear import InputParallelLinear, OutputParallelLinear -from fast_llm.layers.transformer.config import ( - TransformerConfig, - TransformerDimNames, - TransformerKwargs, -) +from fast_llm.layers.transformer.config import TransformerDimNames, TransformerKwargs, TransformerLayerConfig from fast_llm.logging import log_distributed_grad, log_distributed_tensor from fast_llm.tensor import TensorMeta, init_normal_, init_zeros_ from fast_llm.utils import Assert @@ -69,14 +65,14 @@ class Attention(torch.nn.Module): def __init__( self, - config: TransformerConfig, + config: TransformerLayerConfig, tensor_space: TensorSpace, layer_index, ): super().__init__() self._config = config self._tensor_space = tensor_space - Assert.in_range_incl(layer_index, 1, self._config.num_layers) + Assert.in_range(layer_index, 0, self._config.num_layers) self._layer_index = layer_index self._sequence_parallel = self._tensor_space.distributed_config.sequence_tensor_parallel self._debug_transformer = self._config.debug_transformer @@ -161,10 +157,10 @@ def _attn_fused( query, key, beta=0, - alpha=self._softmax_scale / self._layer_index, + alpha=self._softmax_scale / (self._layer_index + 1), ).view(b, self._local_head_groups, sq, self._local_heads_per_group, sk) - attn_weights = attn_weights.to(torch.float32) * self._layer_index + attn_weights = attn_weights.to(torch.float32) * (self._layer_index + 1) attn_weights = torch.where(mask, attn_weights, mask_value) attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1).to(query.dtype) diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py index cf985392..cc6d5a66 100644 --- a/fast_llm/layers/transformer/config.py +++ b/fast_llm/layers/transformer/config.py @@ -4,7 
+4,16 @@ import typing import warnings -from fast_llm.config import Field, FieldHint, FieldUpdate, check_field, config_class, skip_valid_if_none +from fast_llm.config import ( + Config, + Field, + FieldHint, + FieldUpdate, + check_field, + config_class, + process_field, + skip_valid_if_none, +) from fast_llm.engine.base_model.config import BaseModelArchitectureConfig, BaseModelConfig from fast_llm.engine.config_utils.data_type import DataType from fast_llm.engine.config_utils.tensor_space import CompositeTensorDim, TensorDim, TensorSpace @@ -156,7 +165,7 @@ class AddLinearBiasChoices(str, enum.Enum): @config_class() -class TransformerArchitectureConfig(BaseModelArchitectureConfig): +class TransformerLayerArchitectureConfig(BaseModelArchitectureConfig): _abstract = False normalization: NormalizationArchitectureConfig = Field( default_factory=NormalizationArchitectureConfig, @@ -367,7 +376,7 @@ def setup_tensor_space(self, tensor_space: TensorSpace) -> None: @config_class() -class TransformerConfig(TransformerArchitectureConfig, BaseModelConfig): +class TransformerLayerConfig(TransformerLayerArchitectureConfig, BaseModelConfig): normalization: NormalizationConfig = FieldUpdate(default_factory=NormalizationConfig) rotary: RotaryConfig = FieldUpdate(default_factory=RotaryConfig) # Default: hidden_size**-0.5 @@ -618,8 +627,133 @@ def do_use_flash_attention(self, distributed_config: DistributedConfig) -> bool: DataType.bfloat16, ) - # Config parameter `window_size` only can be used with flash attention if not use_flash_attention: - Assert.is_(self.window_size, None) + assert self.max_window_layers is None return use_flash_attention + + +@config_class() +class RangeConfig(Config): + """ + A configuration that defines a range of values, to be used for example in python `slice` or `range`. + """ + + # TODO: Not specific to transformers, move elsewhere? + begin: int = Field( + default=0, + desc="The beginning of the range.", + hint=FieldHint.optional, + ) + end: int | None = Field( + default=None, + desc="The end of the range (excluded).", + hint=FieldHint.optional, + ) + step: int = Field( + default=1, + desc="The step for the range.", + hint=FieldHint.optional, + ) + + def in_range(self, index) -> bool: + """ + Checks whether `index` is in `range(begin, end, step)`. + """ + return ( + index >= self.begin and (self.end is None or index < self.end) and ((index - self.begin) % self.step == 0) + ) + + +def process_config_updates(updates: dict[str | tuple[str, ...], typing.Any]) -> dict[tuple[str, ...], typing.Any]: + return {(tuple(key.split("/")) if isinstance(key, str) else key): value for (key, value) in updates.items()} + + +@config_class() +class TransformerLayerRangeArchitectureConfig(BaseModelArchitectureConfig): + _abstract = False + layer_ranges: list[RangeConfig] = Field( + default_factory=RangeConfig, + desc="Layer range.", + hint=FieldHint.core, + ) + updates: dict[tuple[str, ...], typing.Any] = Field( + default_factory=dict, valid=process_field(process_config_updates) + ) + config: TransformerLayerArchitectureConfig = Field(init=False) + _default: TransformerLayerArchitectureConfig = Field(init=False) + + def setup(self, default: TransformerLayerArchitectureConfig) -> None: + assert not hasattr(self, "_default") + self._default = default + + def _validate(self) -> None: + assert hasattr(self, "_default") + assert len(self.layer_ranges) > 0 + super()._validate() + # Create the full config from the default and updates. 
+ # We use `default.from_dict` so we also have the appropriate class in `TransformerLayerRangeConfig`. + # For the architecture class we need to set `strict=False` because of possible non-architecture parameters. + self.config = self._default.from_dict(self._default, self.updates, strict=isinstance(self, BaseModelConfig)) + self.config.validate() + + def in_range(self, index) -> bool: + return any(layer_range.in_range(index) for layer_range in self.layer_ranges) + + +@config_class() +class TransformerLayerRangeConfig(TransformerLayerRangeArchitectureConfig, BaseModelConfig): + config: TransformerLayerConfig = FieldUpdate() + _default: TransformerLayerConfig = FieldUpdate() + + +@config_class() +class TransformerArchitectureConfig(BaseModelArchitectureConfig): + _abstract = False + layers: list[TransformerLayerRangeArchitectureConfig] = Field(default_factory=list) + default: TransformerLayerArchitectureConfig = Field(default_factory=TransformerLayerArchitectureConfig) + + def _validate(self) -> None: + for layer in self.layers: + layer.setup(self.default) + super()._validate() + for layer in self.layers: + # Hidden layers must match + Assert.eq(layer.config.hidden_size, self.default.hidden_size) + # TODO: Move elsewhere? Kept here because used in a few places like default initialization. + Assert.eq(layer.config.num_layers, self.default.num_layers) + # TODO: Rotary preprocessor doesn't support variations across layers. + Assert.eq(layer.config.rotary.to_serialized(), self.default.rotary.to_serialized()) + + def get_layer_config_and_tensor_space( + self, index: int, tensor_space: TensorSpace + ) -> tuple[TransformerLayerArchitectureConfig, TensorSpace]: + for i, layer in enumerate(self.layers): + if layer.in_range(index): + return layer.config, tensor_space.get_sub_space(f"transformer_layers_{i}") + return self.default, tensor_space + + def setup_tensor_space(self, tensor_space: TensorSpace) -> None: + assert self._validated + self.default.setup_tensor_space(tensor_space) + for i, layer in enumerate(self.layers): + layer.config.setup_tensor_space(tensor_space.add_sub_space(f"transformer_layers_{i}")) + + +@config_class() +class TransformerConfig(TransformerArchitectureConfig, BaseModelConfig): + layers: list[TransformerLayerRangeConfig] = FieldUpdate() + default: TransformerLayerConfig = FieldUpdate(default_factory=TransformerLayerConfig) + + def _validate(self) -> None: + super()._validate() + for layer in self.layers: + # Hidden layers must match + Assert.eq(layer.config.full_precision_residual, self.default.full_precision_residual) + if self.layers: + warnings.warn("Variable layer configuration is experimental. 
Use with caution.") + + def get_layer_config_and_tensor_space( + self, index: int, tensor_space: TensorSpace + ) -> tuple[TransformerLayerConfig, TensorSpace]: + return super().get_layer_config_and_tensor_space(index, tensor_space) diff --git a/fast_llm/layers/transformer/mixture_of_experts.py b/fast_llm/layers/transformer/mixture_of_experts.py index 85c6686f..c4405174 100644 --- a/fast_llm/layers/transformer/mixture_of_experts.py +++ b/fast_llm/layers/transformer/mixture_of_experts.py @@ -13,9 +13,9 @@ from fast_llm.layers.common.linear import Linear from fast_llm.layers.transformer.config import ( RoutingType, - TransformerConfig, TransformerDimNames, TransformerKwargs, + TransformerLayerConfig, TransformerLossNames, ) from fast_llm.layers.transformer.mlp import MLPBase @@ -40,7 +40,7 @@ class MixtureOfExpertMLP(MLPBase): _group: ProcessGroup - def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp"): + def __init__(self, config: TransformerLayerConfig, tensor_space: TensorSpace, name: str = "mlp"): Assert.gt(config.num_experts, 1) # TODO: Implement? assert not config.add_linear_biases, "Biases not supported for MoE." diff --git a/fast_llm/layers/transformer/mlp.py b/fast_llm/layers/transformer/mlp.py index adc6242d..5a42bdb9 100644 --- a/fast_llm/layers/transformer/mlp.py +++ b/fast_llm/layers/transformer/mlp.py @@ -8,13 +8,13 @@ from fast_llm.functional.config import TritonConfig from fast_llm.functional.triton.mlp import mlp_autograd, torch_mlp_activation, triton_mlp_activation_autograd from fast_llm.layers.common.linear import LinearBase -from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames +from fast_llm.layers.transformer.config import TransformerDimNames, TransformerLayerConfig from fast_llm.tensor import init_normal_, init_zeros_ from fast_llm.utils import Assert class MLPBase(Layer, ABC): - def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp"): + def __init__(self, config: TransformerLayerConfig, tensor_space: TensorSpace, name: str = "mlp"): super().__init__() self._name = name @@ -60,7 +60,7 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s class MLP(MLPBase): - def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp"): + def __init__(self, config: TransformerLayerConfig, tensor_space: TensorSpace, name: str = "mlp"): Assert.eq(config.num_experts, 1) super().__init__(config, tensor_space, name) diff --git a/fast_llm/layers/transformer/preprocessing.py b/fast_llm/layers/transformer/preprocessing.py index a509ce6a..aa2c33af 100644 --- a/fast_llm/layers/transformer/preprocessing.py +++ b/fast_llm/layers/transformer/preprocessing.py @@ -18,7 +18,7 @@ logger = logging.getLogger(__name__) -def apply_llama3_scaling(config: RotaryConfig, frequencies: torch.Tensor) -> torch.Tensor: +def apply_llama3_scaling(config: RotaryConfig, frequencies: torch.Tensor) -> tuple[torch.Tensor, float]: """ Llama3 scaling: https://github.com/meta-llama/llama-models/blob/baf7b01b6e62bc7126c7b558d2b67d4533142680/models/llama3/reference_impl/model.py#L45-L67 """ @@ -40,7 +40,7 @@ def apply_llama3_scaling(config: RotaryConfig, frequencies: torch.Tensor) -> tor return torch.tensor(new_frequencies, dtype=frequencies.dtype, device=frequencies.device), 1.0 -def apply_yarn_scaling(config: RotaryConfig, frequencies: torch.Tensor, kv_channels, sequence_length) -> torch.Tensor: +def apply_yarn_scaling(config: RotaryConfig, frequencies: 
torch.Tensor, kv_channels) -> tuple[torch.Tensor, float]: """ Yarn scaling: https://github.com/huggingface/transformers/blob/006d9249ec0270ff6c4d3840979d23fe94bdc763/src/transformers/modeling_rope_utils.py#L163 @@ -49,7 +49,6 @@ def apply_yarn_scaling(config: RotaryConfig, frequencies: torch.Tensor, kv_chann base = config.theta partial_rotary_factor = 1.0 dim = int(kv_channels * partial_rotary_factor) - max_position_embeddings = sequence_length factor = config.scale_factor attention_factor = config.attention_factor @@ -75,7 +74,6 @@ def linear_ramp_factor(min, max, dim): ramp_func = torch.clamp(linear_func, 0, 1) return ramp_func - # Note on variable naming: "interpolation" comes from the original technique, where we interpolate the position IDs # to expand the possible context length. In other words, interpolation = apply scaling factor. # pos_freqs = base ** (torch.arange(0, dim, 2).float().to(frequencies.device) / dim) @@ -99,7 +97,6 @@ def linear_ramp_factor(min, max, dim): return inv_freq, attention_factor - def get_rotary_frequencies( config: RotaryConfig, sequence_length, @@ -118,7 +115,7 @@ def get_rotary_frequencies( if config.type == RotaryEmbeddingType.llama3: frequencies, attention_scaling = apply_llama3_scaling(config, frequencies) elif config.type == RotaryEmbeddingType.yarn: - frequencies, attention_scaling = apply_yarn_scaling(config, frequencies, kv_channels, sequence_length) + frequencies, attention_scaling = apply_yarn_scaling(config, frequencies, kv_channels) else: attention_scaling = 1.0 angles = torch.outer(positions, frequencies) @@ -205,13 +202,23 @@ def __init__( config: TransformerConfig, tensor_space: TensorSpace, ): - self._config = config self._tensor_space = tensor_space self._distributed_config = self._tensor_space.distributed_config - assert not self._config.do_use_flash_attention(self._distributed_config) + all_configs = [config.default] + [layer.config for layer in config.layers] + self._enabled = not all( + layer_config.do_use_flash_attention(self._distributed_config) for layer_config in all_configs + ) + if self._enabled: + window_sizes = {layer_config.window_size for layer_config in all_configs} + if len(window_sizes) != 1: + raise ValueError("Variable window size not supported for backup attention.") + self._window_size = window_sizes.pop() + self._scalar_dim = self._tensor_space.get_tensor_dim(DefaultDimNames.scalar) def create_tensors(self, sequence_length: int) -> None: + if not self._enabled: + return if sequence_length <= self._tensor_cache_max_sequence_length: return self._tensor_cache_max_sequence_length = sequence_length @@ -221,8 +228,8 @@ def create_tensors(self, sequence_length: int) -> None: dtype=torch.bool, device=self._tensor_space.distributed.device, ).tril_() - if self._config.window_size is not None: - self._mask.triu_(-self._config.window_size + 1) + if self._window_size is not None: + self._mask.triu_(-self._window_size + 1) self._mask_value = torch.full( [], torch.finfo(self._distributed_config.training_dtype.torch).min, @@ -231,6 +238,8 @@ def create_tensors(self, sequence_length: int) -> None: ) def preprocess(self, kwargs: dict[str, typing.Any]) -> None: + if not self._enabled: + return sequence_k = kwargs[TransformerKwargs.sequence_k_dim].size kwargs[TransformerKwargs.attention_mask] = self._mask[ None, None, sequence_k - kwargs[TransformerKwargs.sequence_q_dim].size : sequence_k, None, :sequence_k @@ -238,6 +247,8 @@ def preprocess(self, kwargs: dict[str, typing.Any]) -> None: kwargs[TransformerKwargs.attention_mask_value] = 
self._mask_value def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: + if not self._enabled: + return kwargs[TransformerKwargs.attention_mask] = TensorMeta.from_dims( ( self._scalar_dim, diff --git a/fast_llm/layers/transformer/transformer.py b/fast_llm/layers/transformer/transformer.py index 4780dd3a..bd351358 100644 --- a/fast_llm/layers/transformer/transformer.py +++ b/fast_llm/layers/transformer/transformer.py @@ -8,7 +8,7 @@ from fast_llm.engine.config_utils.run import log_pipeline_parallel_main_rank from fast_llm.engine.config_utils.tensor_space import TensorSpace from fast_llm.layers.transformer.attention import Attention -from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames, TransformerKwargs +from fast_llm.layers.transformer.config import TransformerDimNames, TransformerKwargs, TransformerLayerConfig from fast_llm.layers.transformer.mixture_of_experts import MixtureOfExpertMLP from fast_llm.layers.transformer.mlp import MLP from fast_llm.logging import log_distributed_grad, log_distributed_tensor, log_memory_usage @@ -24,7 +24,7 @@ class TransformerLayer(Layer): def __init__( self, - config: TransformerConfig, + config: TransformerLayerConfig, tensor_space: TensorSpace, layer_index: int, ): diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py index 51c8a3b7..ba53dd40 100644 --- a/fast_llm/models/gpt/conversion.py +++ b/fast_llm/models/gpt/conversion.py @@ -24,14 +24,14 @@ from fast_llm.functional.config import ActivationType from fast_llm.functional.rotary import convert_rotary_complex_to_real, convert_rotary_real_to_complex from fast_llm.layers.common.config import NormalizationType -from fast_llm.layers.transformer.config import RotaryEmbeddingType, RoutingType, TransformerConfig +from fast_llm.layers.transformer.config import RotaryEmbeddingType, RoutingType from fast_llm.models.gpt.config import ( GPTArchitectureConfig, GPTModelConfig, LlamaGPTHuggingfaceCheckpointFormat, - Qwen2GPTHuggingfaceCheckpointFormat, MistralGPTHuggingfaceCheckpointFormat, MixtralGPTHuggingfaceCheckpointFormat, + Qwen2GPTHuggingfaceCheckpointFormat, Starcoder2GPTHuggingfaceCheckpointFormat, ) from fast_llm.models.gpt.model import GPTModel @@ -52,16 +52,16 @@ def export_weight( self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: (query,) = weight - if self._config.transformer.rotary.complex_format: - query = convert_rotary_complex_to_real(query[:], self._config.transformer.kv_channels, 0) + if self._config.layers.default.rotary.complex_format: + query = convert_rotary_complex_to_real(query[:], self._config.layers.default.kv_channels, 0) return (query,) def import_weight( self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] 
) -> tuple[torch.Tensor | SafeTensorSlice, ...]: (query,) = weight - if self._config.transformer.rotary.complex_format: - query = convert_rotary_real_to_complex(query[:], self._config.transformer.kv_channels, 0) + if self._config.layers.default.rotary.complex_format: + query = convert_rotary_real_to_complex(query[:], self._config.layers.default.kv_channels, 0) return (query,) @@ -74,16 +74,16 @@ def export_weight( ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: (key_value,) = weight key, value = key_value[:].chunk(2) - if self._config.transformer.rotary.complex_format: - key = convert_rotary_complex_to_real(key, self._config.transformer.kv_channels, 0) + if self._config.layers.default.rotary.complex_format: + key = convert_rotary_complex_to_real(key, self._config.layers.default.kv_channels, 0) return key, value def import_weight( self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: key, value = weight - if self._config.transformer.rotary.complex_format: - key = convert_rotary_real_to_complex(key[:], self._config.transformer.kv_channels, 0) + if self._config.layers.default.rotary.complex_format: + key = convert_rotary_real_to_complex(key[:], self._config.layers.default.kv_channels, 0) key_value = torch.cat([key[:], value[:]]) return (key_value,) @@ -116,34 +116,37 @@ class CommonHuggingfaceCheckpointHandler(HuggingfaceStateDictCheckpointHandler): @classmethod def _create_config_converters(cls) -> list[ParamConverter]: return super()._create_config_converters() + [ + # Variable layer config not supported. + # TODO: Find a way to support variable non-architecture parameters. + ConstantImportParamConverter(fast_llm_names=(("layers", "layers"),), fast_llm_value=[]), ConstantImportParamConverter(fast_llm_names=(("use_position_embeddings",),), fast_llm_value=False), RenameParamConverter( - fast_llm_names=(("transformer", "rotary", "theta"),), export_names=(("rope_theta",),) + fast_llm_names=(("layers", "default", "rotary", "theta"),), export_names=(("rope_theta",),) ), MappedConfigParamConverter( - fast_llm_names=(("transformer", "activation_type"),), + fast_llm_names=(("layers", "default", "activation_type"),), export_names=(("hidden_act",),), fast_llm_value=ActivationType.from_hf_name, export_value=lambda activation_type: activation_type.hf_name, ), RenameParamConverter( - fast_llm_names=(("transformer", "num_layers"),), + fast_llm_names=(("layers", "default", "num_layers"),), export_names=(("num_hidden_layers",),), ), RenameParamConverter( - fast_llm_names=(("transformer", "hidden_size"),), + fast_llm_names=(("layers", "default", "hidden_size"),), export_names=(("hidden_size",),), ), RenameParamConverter( - fast_llm_names=(("transformer", "num_attention_heads"),), + fast_llm_names=(("layers", "default", "num_attention_heads"),), export_names=(("num_attention_heads",),), ), RenameParamConverter( - fast_llm_names=(("transformer", "head_groups"),), + fast_llm_names=(("layers", "default", "head_groups"),), export_names=(("num_key_value_heads",),), ), RenameParamConverter( - fast_llm_names=(("transformer", "ffn_hidden_size"),), + fast_llm_names=(("layers", "default", "ffn_hidden_size"),), export_names=(("intermediate_size",),), ), RenameParamConverter( @@ -160,14 +163,15 @@ def _create_config_converters(cls) -> list[ParamConverter]: def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]: pass - def _create_weight_converters( self, ) -> list[WeightConverter]: converters = [] - num_layers = 
self._model.config.base_model.transformer.num_layers - norm_bias: bool = self._model.config.base_model.transformer.normalization.type == NormalizationType.layer_norm - transformer_config: TransformerConfig = self._model.config.base_model.transformer + num_layers = self._model.config.base_model.layers.default.num_layers + norm_bias: bool = ( + self._model.config.base_model.layers.default.normalization.type == NormalizationType.layer_norm + ) + layer_config = self._model.config.base_model.layers.default # Embedding and output if self._model.config.base_model.tie_word_embeddings: @@ -187,19 +191,19 @@ def _create_weight_converters( converters += self._get_weight_and_bias_converters( f"layers.{i+1}.self_attn.query", f"model.layers.{i}.self_attn.q_proj", - transformer_config.add_attn_qkv_bias, + layer_config.add_attn_qkv_bias, QueryWeightConverter, ) converters += self._get_weight_and_bias_converters( f"layers.{i+1}.self_attn.key_value", (f"model.layers.{i}.self_attn.k_proj", f"model.layers.{i}.self_attn.v_proj"), - transformer_config.add_attn_qkv_bias, + layer_config.add_attn_qkv_bias, KeyValueWeightConverter, ) converters += self._get_weight_and_bias_converters( f"layers.{i+1}.self_attn.dense", f"model.layers.{i}.self_attn.o_proj", - transformer_config.add_attn_dense_bias, + layer_config.add_attn_dense_bias, ) # Norm @@ -252,28 +256,31 @@ def _create_config_converters(cls) -> list[ParamConverter]: return super()._create_config_converters() + [ ConstantExportParamConverter(export_names=(("architectures",),), export_value=["Starcoder2ForCausalLM"]), ConstantImportParamConverter( - fast_llm_names=(("transformer", "rotary", "type"),), fast_llm_value=RotaryEmbeddingType.default + fast_llm_names=(("layers", "default", "rotary", "type"),), fast_llm_value=RotaryEmbeddingType.default ), ConstantImportParamConverter( - fast_llm_names=(("transformer", "normalization", "type"),), fast_llm_value=NormalizationType.layer_norm + fast_llm_names=(("layers", "default", "normalization", "type"),), + fast_llm_value=NormalizationType.layer_norm, ), RenameParamConverter( - fast_llm_names=(("transformer", "normalization", "epsilon"),), export_names=(("norm_epsilon",),) + fast_llm_names=(("layers", "default", "normalization", "epsilon"),), export_names=(("norm_epsilon",),) + ), + ConstantImportParamConverter(fast_llm_names=(("layers", "default", "gated"),), fast_llm_value=False), + ConstantImportParamConverter( + fast_llm_names=(("layers", "default", "add_linear_biases"),), fast_llm_value=True ), - ConstantImportParamConverter(fast_llm_names=(("transformer", "gated"),), fast_llm_value=False), - ConstantImportParamConverter(fast_llm_names=(("transformer", "add_linear_biases"),), fast_llm_value=True), ] def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]: - transformer_config: TransformerConfig = self._model.config.base_model.transformer + layer_config = self._model.config.base_model.layers.default return [ *self._get_weight_and_bias_converters( - f"{fast_llm_prefix}.mlp.layer_1", f"{hf_prefix}.mlp.c_fc", transformer_config.add_mlp_bias + f"{fast_llm_prefix}.mlp.layer_1", f"{hf_prefix}.mlp.c_fc", layer_config.add_mlp_bias ), *self._get_weight_and_bias_converters( f"{fast_llm_prefix}.mlp.layer_2", f"{hf_prefix}.mlp.c_proj", - transformer_config.add_mlp_bias, + layer_config.add_mlp_bias, MLPLayer2Converter, ), ] @@ -284,27 +291,30 @@ class CommonLlamaHuggingfaceCheckpointHandler(CommonHuggingfaceCheckpointHandler def _create_config_converters(cls) -> list[ParamConverter]: return 
super()._create_config_converters() + [ ConstantImportParamConverter( - fast_llm_names=(("transformer", "normalization", "type"),), fast_llm_value=NormalizationType.rms_norm + fast_llm_names=(("layers", "default", "normalization", "type"),), + fast_llm_value=NormalizationType.rms_norm, ), RenameParamConverter( - fast_llm_names=(("transformer", "normalization", "epsilon"),), export_names=(("rms_norm_eps",),) + fast_llm_names=(("layers", "default", "normalization", "epsilon"),), export_names=(("rms_norm_eps",),) ), RenameParamConverter( - fast_llm_names=(("transformer", "kv_channels"),), - export_names=(("head_dim"),), + fast_llm_names=(("layers", "default", "kv_channels"),), + export_names=(("head_dim",),), + ), + ConstantImportParamConverter(fast_llm_names=(("layers", "default", "gated"),), fast_llm_value=True), + ConstantImportParamConverter( + fast_llm_names=(("layers", "default", "add_linear_biases"),), fast_llm_value=False ), - ConstantImportParamConverter(fast_llm_names=(("transformer", "gated"),), fast_llm_value=True), - ConstantImportParamConverter(fast_llm_names=(("transformer", "add_linear_biases"),), fast_llm_value=False), RopeScalingParamConverter( fast_llm_names=( - ("transformer", "rotary", "type"), - ("transformer", "rotary", "scale_factor"), - ("transformer", "rotary", "low_frequency_factor"), - ("transformer", "rotary", "high_frequency_factor"), - ("transformer", "rotary", "original_context_length"), - ("transformer", "rotary", "attention_factor"), - ("transformer", "rotary", "beta_fast"), - ("transformer", "rotary", "beta_slow"), + ("layers", "default", "rotary", "type"), + ("layers", "default", "rotary", "scale_factor"), + ("layers", "default", "rotary", "low_frequency_factor"), + ("layers", "default", "rotary", "high_frequency_factor"), + ("layers", "default", "rotary", "original_context_length"), + ("layers", "default", "rotary", "attention_factor"), + ("layers", "default", "rotary", "beta_fast"), + ("layers", "default", "rotary", "beta_slow"), ), export_names=(("rope_scaling",),), ), @@ -364,18 +374,18 @@ def _create_config_converters(cls) -> list[ParamConverter]: ] def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]: - transformer_config: TransformerConfig = self._model.config.base_model.transformer + layer_config = self._model.config.base_model.layers.default return [ *self._get_weight_and_bias_converters( f"{fast_llm_prefix}.mlp.layer_1", (f"{hf_prefix}.mlp.gate_proj", f"{hf_prefix}.mlp.up_proj"), - transformer_config.add_mlp_bias, + layer_config.add_mlp_bias, SplitWeightConverter, ), *self._get_weight_and_bias_converters( f"{fast_llm_prefix}.mlp.layer_2", f"{hf_prefix}.mlp.down_proj", - transformer_config.add_mlp_bias, + layer_config.add_mlp_bias, MLPLayer2Converter, ), ] @@ -411,25 +421,26 @@ def _create_config_converters(cls) -> list[ParamConverter]: return super()._create_config_converters() + [ ConstantExportParamConverter(export_names=(("architectures",),), export_value=["Qwen2ForCausalLM"]), ConstantImportParamConverter( - fast_llm_names=(("transformer", "normalization", "type"),), fast_llm_value=NormalizationType.rms_norm + fast_llm_names=(("layers", "default", "normalization", "type"),), + fast_llm_value=NormalizationType.rms_norm, ), RenameParamConverter( - fast_llm_names=(("transformer", "normalization", "epsilon"),), export_names=(("rms_norm_eps",),) + fast_llm_names=(("layers", "default", "normalization", "epsilon"),), export_names=(("rms_norm_eps",),) ), - 
ConstantImportParamConverter(fast_llm_names=(("transformer", "gated"),), fast_llm_value=True), + ConstantImportParamConverter(fast_llm_names=(("layers", "default", "gated"),), fast_llm_value=True), ConstantImportParamConverter( - fast_llm_names=(("transformer", "add_linear_biases"),), fast_llm_value="only_attn_qkv" + fast_llm_names=(("layers", "default", "add_linear_biases"),), fast_llm_value="only_attn_qkv" ), RopeScalingParamConverter( fast_llm_names=( - ("transformer", "rotary", "type"), - ("transformer", "rotary", "scale_factor"), - ("transformer", "rotary", "low_frequency_factor"), - ("transformer", "rotary", "high_frequency_factor"), - ("transformer", "rotary", "original_context_length"), - ("transformer", "rotary", "attention_factor"), - ("transformer", "rotary", "beta_fast"), - ("transformer", "rotary", "beta_slow"), + ("layers", "default", "rotary", "type"), + ("layers", "default", "rotary", "scale_factor"), + ("layers", "default", "rotary", "low_frequency_factor"), + ("layers", "default", "rotary", "high_frequency_factor"), + ("layers", "default", "rotary", "original_context_length"), + ("layers", "default", "rotary", "attention_factor"), + ("layers", "default", "rotary", "beta_fast"), + ("layers", "default", "rotary", "beta_slow"), ), export_names=(("rope_scaling",),), ), @@ -437,18 +448,18 @@ def _create_config_converters(cls) -> list[ParamConverter]: ] def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]: - transformer_config: TransformerConfig = self._model.config.base_model.transformer + layer_config = self._model.config.base_model.layers.default return [ *self._get_weight_and_bias_converters( f"{fast_llm_prefix}.mlp.layer_1", (f"{hf_prefix}.mlp.gate_proj", f"{hf_prefix}.mlp.up_proj"), - transformer_config.add_mlp_bias, + layer_config.add_mlp_bias, SplitWeightConverter, ), *self._get_weight_and_bias_converters( f"{fast_llm_prefix}.mlp.layer_2", f"{hf_prefix}.mlp.down_proj", - transformer_config.add_mlp_bias, + layer_config.add_mlp_bias, MLPLayer2Converter, ), ] @@ -486,19 +497,20 @@ def _create_config_converters(cls) -> list[ParamConverter]: return super()._create_config_converters() + [ ConstantExportParamConverter(export_names=(("architectures",),), export_value=["MixtralForCausalLM"]), ConstantImportParamConverter( - fast_llm_names=(("transformer", "expert_routing_type"),), fast_llm_value=RoutingType.topk + fast_llm_names=(("layers", "default", "expert_routing_type"),), fast_llm_value=RoutingType.topk ), RenameParamConverter( - fast_llm_names=(("transformer", "num_experts"),), export_names=(("num_local_experts",),) + fast_llm_names=(("layers", "default", "num_experts"),), export_names=(("num_local_experts",),) ), RenameParamConverter( - fast_llm_names=(("transformer", "num_experts_per_token"),), export_names=(("num_experts_per_tok",),) + fast_llm_names=(("layers", "default", "num_experts_per_token"),), + export_names=(("num_experts_per_tok",),), ), IgnoreImportParamConverter(export_names=(("sliding_window",),), ignore_export_value=None), ] def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]: - num_experts = self._model.config.base_model.transformer.num_experts + num_experts = self._model.config.base_model.layers.default.num_experts return [ WeightConverter(f"{fast_llm_prefix}.mlp.router.weight", f"{hf_prefix}.block_sparse_moe.gate.weight"), SplitWeightConverter( diff --git a/fast_llm/models/gpt/megatron.py b/fast_llm/models/gpt/megatron.py index 842a064e..975fd0c7 100644 --- 
a/fast_llm/models/gpt/megatron.py +++ b/fast_llm/models/gpt/megatron.py @@ -1,6 +1,6 @@ import typing -from fast_llm.layers.transformer.config import TransformerConfig +from fast_llm.layers.transformer.config import TransformerLayerConfig from fast_llm.utils import Assert, div if typing.TYPE_CHECKING: @@ -12,7 +12,7 @@ def get_init_megatron( - meta: "ParameterMeta", config: TransformerConfig + meta: "ParameterMeta", config: TransformerLayerConfig ) -> typing.Callable[["torch.Tensor", "Distributed"], "torch.Tensor"]: def init_megatron(tensor: "torch.Tensor", distributed: "Distributed"): Assert.eq(distributed.config.world_size, 1) @@ -49,7 +49,7 @@ def set_megatron_distributed_seeds(config: "DistributedConfig") -> None: def _init_attention_megatron( - config: TransformerConfig, meta: "ParameterMeta", tensor: "torch.Tensor", distributed: "Distributed" + config: TransformerLayerConfig, meta: "ParameterMeta", tensor: "torch.Tensor", distributed: "Distributed" ) -> "torch.Tensor": # Megatron combines q and kv and inverts the initialization order of qkv and dense layers. # It also always treats the tensors as tensor-parallel and uses a different rotary embedding format. @@ -114,7 +114,7 @@ def _init_position_embeddings_megatron( def _init_transposed_mlp_weight_megatron( - config: TransformerConfig, meta: "ParameterMeta", tensor: "torch.Tensor", distributed: "Distributed" + config: TransformerLayerConfig, meta: "ParameterMeta", tensor: "torch.Tensor", distributed: "Distributed" ) -> "torch.Tensor": import torch @@ -138,7 +138,7 @@ def _init_moe_router_megatron( def _init_moe_mlp_megatron( - config: TransformerConfig, meta: "ParameterMeta", tensor: "torch.Tensor", distributed: "Distributed" + config: TransformerLayerConfig, meta: "ParameterMeta", tensor: "torch.Tensor", distributed: "Distributed" ) -> "torch.Tensor": assert meta.param_init_method is not None generator = distributed.tp_init_generator if meta.is_tensor_parallel else distributed.pp_init_generator diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 8aa68333..f522ebb8 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -49,32 +49,27 @@ def __init__( distributed_config: DistributedConfig, ): super().__init__(config, distributed_config) - self._use_flash_attention = self._config.transformer.do_use_flash_attention(distributed_config) if self._config.use_megatron_initialization: for param in self.parameters(): Assert.custom(isinstance, param, ParameterMeta) - param.init_parameter = get_init_megatron(param, self._config.transformer) # Noqa + param.init_parameter = get_init_megatron(param, self._config.layers.default) # Noqa if self._config.use_absolute_position_embeddings: self._position_embedding_preprocessor = PositionEmbeddingPreprocessor(self._config, self._tensor_space) - if self._config.transformer.rotary.enabled: + if self._config.layers.default.rotary.enabled: self._rotary_embedding_preprocessor = RotaryEmbeddingPreprocessor( - self._config.transformer.rotary, self._tensor_space - ) - if not self._use_flash_attention: - self._backup_attention_preprocessor = BackupAttentionPreprocessor( - self._config.transformer, self._tensor_space + self._config.layers.default.rotary, self._tensor_space ) + self._backup_attention_preprocessor = BackupAttentionPreprocessor(self._config.layers, self._tensor_space) def get_layers(self) -> list[Layer]: return [ LanguageModelEmbedding(self._config, self._tensor_space), *[ TransformerLayer( - self._config.transformer, - self._tensor_space, - 
layer_index=i + 1, + *self._config.layers.get_layer_config_and_tensor_space(layer_index, self._tensor_space), + layer_index=layer_index, ) - for i in range(self._config.transformer.num_layers) + for layer_index in range(self._config.layers.default.num_layers) ], LanguageModelHead(self._config, self._tensor_space), ] @@ -175,10 +170,9 @@ def preprocess_meta( ) if self._config.use_absolute_position_embeddings: self._position_embedding_preprocessor.preprocess_meta(kwargs) - if self._config.transformer.rotary.enabled: + if self._config.layers.default.rotary.enabled: self._rotary_embedding_preprocessor.preprocess_meta(kwargs) - if not self._use_flash_attention: - self._backup_attention_preprocessor.preprocess_meta(kwargs) + self._backup_attention_preprocessor.preprocess_meta(kwargs) preprocessed_meta.append((tokens, kwargs)) return preprocessed_meta @@ -214,10 +208,9 @@ def preprocess( if self._config.use_absolute_position_embeddings: self._position_embedding_preprocessor.create_tensors(sequence_length) - if self._config.transformer.rotary.enabled: + if self._config.layers.default.rotary.enabled: self._rotary_embedding_preprocessor.create_tensors(sequence_length) - if not self._use_flash_attention: - self._backup_attention_preprocessor.create_tensors(sequence_length) + self._backup_attention_preprocessor.create_tensors(sequence_length) preprocessed = [] presents = None @@ -257,10 +250,9 @@ def preprocess( kwargs[LanguageModelKwargs.labels] = labels if self._config.use_absolute_position_embeddings: self._position_embedding_preprocessor.preprocess(kwargs) - if self._config.transformer.rotary.enabled: + if self._config.layers.default.rotary.enabled: self._rotary_embedding_preprocessor.preprocess(kwargs) - if not self._use_flash_attention: - self._backup_attention_preprocessor.preprocess(kwargs) + self._backup_attention_preprocessor.preprocess(kwargs) preprocessed.append((tokens, kwargs)) return preprocessed @@ -290,22 +282,22 @@ def loss_defs(self) -> list[LossDef]: LossDef(name=LanguageModelLossNames.language_model_loss, formatted_name="language model loss", count=1) ] if ( - self._config.transformer.num_experts > 1 - and self._config.transformer.expert_routing_type == RoutingType.topk + self._config.layers.default.num_experts > 1 + and self._config.layers.default.expert_routing_type == RoutingType.topk ): loss_defs.append( LossDef( name=TransformerLossNames.load_balancing_loss, formatted_name="load balancing loss", - count=self._config.transformer.num_layers, + count=self._config.layers.default.num_layers, ) ) - if self._config.transformer.expert_z_loss_coefficient: + if self._config.layers.default.expert_z_loss_coefficient: loss_defs.append( LossDef( name=TransformerLossNames.router_z_loss, formatted_name="router z loss", - count=self._config.transformer.num_layers, + count=self._config.layers.default.num_layers, ) ) if self._config.logit_z_loss: diff --git a/fast_llm/models/gpt/trainer.py b/fast_llm/models/gpt/trainer.py index 7b03a7b4..21e08229 100644 --- a/fast_llm/models/gpt/trainer.py +++ b/fast_llm/models/gpt/trainer.py @@ -26,7 +26,7 @@ def get_tflops(self, phase: PhaseType, elapsed_time_per_iteration) -> tuple[int, # TODO: Do in model, automate/generalize, get other stats """Get tflop/s/GPU from global-batch-size and elapsed-time""" checkpoint_activations_factor = 3 if phase == PhaseType.training else 1 - transformer_config = self._config.model.base_model.transformer + transformer_config = self._config.model.base_model.layers.default sequence_length = self._config.batch.sequence_length 
tokens = self._config.batch.batch_size * sequence_length diff --git a/tests/test_attention.py b/tests/test_attention.py index c8b91d76..0d90246e 100644 --- a/tests/test_attention.py +++ b/tests/test_attention.py @@ -1,8 +1,9 @@ import unittest.mock -from fast_llm.layers.transformer.attention import Attention -from fast_llm.layers.transformer.config import TransformerConfig -from fast_llm.engine.distributed.config import DistributedConfig + from fast_llm.engine.config_utils.tensor_space import TensorSpace +from fast_llm.engine.distributed.config import DistributedConfig +from fast_llm.layers.transformer.attention import Attention +from fast_llm.layers.transformer.config import TransformerLayerConfig def test_decide_window_size(): @@ -10,23 +11,23 @@ def test_decide_window_size(): attention._decide_window_size = Attention._decide_window_size.__get__(attention) # Attach real method # Arrange - Case 1: window_size is returned (layer_index >= max_window_layers) - attention._config = TransformerConfig(window_size=512, max_window_layers=2) + attention._config = TransformerLayerConfig(window_size=512, max_window_layers=2) attention._layer_index = 2 assert attention._decide_window_size() == 512 # Arrange - Case 2: window_size is None (layer_index < max_window_layers) - attention._config = TransformerConfig(window_size=512, max_window_layers=2) + attention._config = TransformerLayerConfig(window_size=512, max_window_layers=2) attention._layer_index = 1 assert attention._decide_window_size() is None # Arrange - Case 3: max_window_layers is None (always return window_size) - attention._config = TransformerConfig(window_size=512, max_window_layers=None) + attention._config = TransformerLayerConfig(window_size=512, max_window_layers=None) assert attention._decide_window_size() == 512 def test_attention_constructor(): - transformer_conf = TransformerConfig( - num_layers=2, + transformer_conf = TransformerLayerConfig( + num_layers=2, num_attention_heads=2, hidden_size=16, ) @@ -35,4 +36,3 @@ def test_attention_constructor(): transformer_conf.setup_tensor_space(tensor_space) Attention(transformer_conf, tensor_space, 1) - diff --git a/tests/test_config.py b/tests/test_config.py index 7141812a..20958906 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,19 +1,11 @@ import pathlib -import pytest import subprocess -import unittest.mock -import yaml +import pytest +import yaml -from fast_llm.layers.transformer.config import ( - TransformerConfig, - TransformerArchitectureConfig, - AddLinearBiasChoices, -) -from fast_llm.engine.distributed.config import DistributedConfig -from fast_llm.engine.config_utils.data_type import DataType from fast_llm.config import ValidationError - +from fast_llm.layers.transformer.config import AddLinearBiasChoices, TransformerLayerArchitectureConfig from fast_llm.models.auto import trainer_registry @@ -64,78 +56,73 @@ def test_validate_example_config(): trainer_registry["gpt"].from_dict(fast_llm_config_dict) -def test_do_use_flash_attention(): - # Create a mock DistributedConfig - mock_distributed_config = unittest.mock.Mock(spec=DistributedConfig) - - # Test case 1: use_flash_attention is True and training_dtype is float16 - config = TransformerConfig(use_flash_attention=True, window_size=None) - mock_distributed_config.training_dtype = DataType.float16 - assert config.do_use_flash_attention(mock_distributed_config) is True - - # Test case 2: use_flash_attention is False - config = TransformerConfig(use_flash_attention=False, window_size=None) - 
mock_distributed_config.training_dtype = DataType.float16 - assert config.do_use_flash_attention(mock_distributed_config) is False - - # Test case 3: use_flash_attention is True but training_dtype is not float16 or bfloat16 - config = TransformerConfig(use_flash_attention=True, window_size=None) - mock_distributed_config.training_dtype = DataType.float32 - assert config.do_use_flash_attention(mock_distributed_config) is False - - # Test case 4: use_flash_attention is False and window_size is not None - config = TransformerConfig(use_flash_attention=False, window_size=512) - mock_distributed_config.training_dtype = DataType.float32 - with pytest.raises(AssertionError): - config.do_use_flash_attention(mock_distributed_config) - - def test_add_linear_biases_valid_values(): # Valid boolean values - assert TransformerArchitectureConfig(add_linear_biases=True).add_linear_biases is True - assert TransformerArchitectureConfig(add_linear_biases=False).add_linear_biases is False + assert TransformerLayerArchitectureConfig(add_linear_biases=True).add_linear_biases is True + assert TransformerLayerArchitectureConfig(add_linear_biases=False).add_linear_biases is False # Valid enum values - assert TransformerArchitectureConfig(add_linear_biases="nowhere").add_linear_biases == AddLinearBiasChoices.nowhere assert ( - TransformerArchitectureConfig(add_linear_biases="everywhere").add_linear_biases + TransformerLayerArchitectureConfig(add_linear_biases="nowhere").add_linear_biases + == AddLinearBiasChoices.nowhere + ) + assert ( + TransformerLayerArchitectureConfig(add_linear_biases="everywhere").add_linear_biases == AddLinearBiasChoices.everywhere ) assert ( - TransformerArchitectureConfig(add_linear_biases="only_attn_qkv").add_linear_biases == AddLinearBiasChoices.only_attn_qkv + TransformerLayerArchitectureConfig(add_linear_biases="only_attn_qkv").add_linear_biases + == AddLinearBiasChoices.only_attn_qkv ) def test_add_linear_biases_invalid_values(): with pytest.raises(ValidationError): - TransformerArchitectureConfig(add_linear_biases="invalid_value") + TransformerLayerArchitectureConfig(add_linear_biases="invalid_value") with pytest.raises(ValidationError): - TransformerArchitectureConfig(add_linear_biases=123) + TransformerLayerArchitectureConfig(add_linear_biases=123) with pytest.raises(ValidationError): - TransformerArchitectureConfig(add_linear_biases=None) + TransformerLayerArchitectureConfig(add_linear_biases=None) def test_add_mlp_bias(): - assert TransformerArchitectureConfig(add_linear_biases=True).add_mlp_bias is True - assert TransformerArchitectureConfig(add_linear_biases=False).add_mlp_bias is False - assert TransformerArchitectureConfig(add_linear_biases=AddLinearBiasChoices.everywhere).add_mlp_bias is True - assert TransformerArchitectureConfig(add_linear_biases=AddLinearBiasChoices.nowhere).add_mlp_bias is False - assert TransformerArchitectureConfig(add_linear_biases=AddLinearBiasChoices.only_attn_qkv).add_mlp_bias is False + assert TransformerLayerArchitectureConfig(add_linear_biases=True).add_mlp_bias is True + assert TransformerLayerArchitectureConfig(add_linear_biases=False).add_mlp_bias is False + assert TransformerLayerArchitectureConfig(add_linear_biases=AddLinearBiasChoices.everywhere).add_mlp_bias is True + assert TransformerLayerArchitectureConfig(add_linear_biases=AddLinearBiasChoices.nowhere).add_mlp_bias is False + assert ( + TransformerLayerArchitectureConfig(add_linear_biases=AddLinearBiasChoices.only_attn_qkv).add_mlp_bias is False + ) def test_add_attn_qkv_bias(): - 
assert TransformerArchitectureConfig(add_linear_biases=True).add_attn_qkv_bias is True - assert TransformerArchitectureConfig(add_linear_biases=False).add_attn_qkv_bias is False - assert TransformerArchitectureConfig(add_linear_biases=AddLinearBiasChoices.everywhere).add_attn_qkv_bias is True - assert TransformerArchitectureConfig(add_linear_biases=AddLinearBiasChoices.nowhere).add_attn_qkv_bias is False - assert TransformerArchitectureConfig(add_linear_biases=AddLinearBiasChoices.only_attn_qkv).add_attn_qkv_bias is True + assert TransformerLayerArchitectureConfig(add_linear_biases=True).add_attn_qkv_bias is True + assert TransformerLayerArchitectureConfig(add_linear_biases=False).add_attn_qkv_bias is False + assert ( + TransformerLayerArchitectureConfig(add_linear_biases=AddLinearBiasChoices.everywhere).add_attn_qkv_bias is True + ) + assert ( + TransformerLayerArchitectureConfig(add_linear_biases=AddLinearBiasChoices.nowhere).add_attn_qkv_bias is False + ) + assert ( + TransformerLayerArchitectureConfig(add_linear_biases=AddLinearBiasChoices.only_attn_qkv).add_attn_qkv_bias + is True + ) def test_add_attn_dense_bias(): - assert TransformerArchitectureConfig(add_linear_biases=True).add_attn_dense_bias is True - assert TransformerArchitectureConfig(add_linear_biases=False).add_attn_dense_bias is False - assert TransformerArchitectureConfig(add_linear_biases=AddLinearBiasChoices.everywhere).add_attn_dense_bias is True - assert TransformerArchitectureConfig(add_linear_biases=AddLinearBiasChoices.nowhere).add_attn_dense_bias is False - assert TransformerArchitectureConfig(add_linear_biases=AddLinearBiasChoices.only_attn_qkv).add_attn_dense_bias is False + assert TransformerLayerArchitectureConfig(add_linear_biases=True).add_attn_dense_bias is True + assert TransformerLayerArchitectureConfig(add_linear_biases=False).add_attn_dense_bias is False + assert ( + TransformerLayerArchitectureConfig(add_linear_biases=AddLinearBiasChoices.everywhere).add_attn_dense_bias + is True + ) + assert ( + TransformerLayerArchitectureConfig(add_linear_biases=AddLinearBiasChoices.nowhere).add_attn_dense_bias is False + ) + assert ( + TransformerLayerArchitectureConfig(add_linear_biases=AddLinearBiasChoices.only_attn_qkv).add_attn_dense_bias + is False + ) diff --git a/tests/test_mlp.py b/tests/test_mlp.py index bcfbaf69..7fea9ba5 100644 --- a/tests/test_mlp.py +++ b/tests/test_mlp.py @@ -1,12 +1,12 @@ -from fast_llm.layers.transformer.mlp import MLP -from fast_llm.layers.transformer.mixture_of_experts import MixtureOfExpertMLP -from fast_llm.layers.transformer.config import TransformerConfig -from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.engine.config_utils.tensor_space import TensorSpace +from fast_llm.engine.distributed.config import DistributedConfig +from fast_llm.layers.transformer.config import TransformerLayerConfig +from fast_llm.layers.transformer.mixture_of_experts import MixtureOfExpertMLP +from fast_llm.layers.transformer.mlp import MLP def test_mlp_constructor(): - transformer_conf = TransformerConfig( + transformer_conf = TransformerLayerConfig( num_layers=2, num_attention_heads=2, hidden_size=16, @@ -19,12 +19,8 @@ def test_mlp_constructor(): def test_moe_mlp_constructor(): - transformer_conf = TransformerConfig( - num_layers=2, - num_attention_heads=2, - hidden_size=16, - num_experts=2, - add_linear_biases=False + transformer_conf = TransformerLayerConfig( + num_layers=2, num_attention_heads=2, hidden_size=16, num_experts=2, add_linear_biases=False ) 
distributed_config = DistributedConfig() tensor_space = TensorSpace(distributed_config=distributed_config) diff --git a/tests/test_transformer.py b/tests/test_transformer.py new file mode 100644 index 00000000..93287c7a --- /dev/null +++ b/tests/test_transformer.py @@ -0,0 +1,56 @@ +from fast_llm.engine.config_utils.data_type import DataType +from fast_llm.engine.distributed.config import DistributedConfig +from fast_llm.layers.common.config import NormalizationType +from fast_llm.models.gpt.config import GPTBaseModelConfig +from fast_llm.models.gpt.model import GPTBaseModel +from fast_llm.utils import Assert + + +def test_variable_window_size(): + model = GPTBaseModel( + GPTBaseModelConfig.from_dict( + { + "layers": { + "default": {"window_size": 1024, "num_layers": 8, "normalization": {"type": "rms_norm"}}, + "layers": [ + { + # Layers 5, 6 and 7 + "layer_ranges": [{"begin": 5, "end": None}], + # Update normalization epsilon, keep rms norm. + "updates": {"window_size": None, "normalization/epsilon": 1}, + }, + { + # Layers 0, 3 and 5, but 5 already covered above so excluded. + "layer_ranges": [{"begin": 0, "end": 1}, {"begin": 3, "end": 6, "step": 2}], + # Override the whole normalization, type reverts back to default (layer_norm) + "updates": {"window_size": 512, "ffn_hidden_size": 64, "normalization": {"epsilon": 1}}, + }, + ], + } + } + ), + DistributedConfig(training_dtype=DataType.bfloat16), + ) + Assert.eq( + [layer._config.window_size for layer in model.layers[1:-1]], [512, 1024, 1024, 512, 1024, None, None, None] + ) + Assert.eq( + [layer._config.normalization.type for layer in model.layers[1:-1]], + [ + NormalizationType.layer_norm, + NormalizationType.rms_norm, + NormalizationType.rms_norm, + NormalizationType.layer_norm, + NormalizationType.rms_norm, + NormalizationType.rms_norm, + NormalizationType.rms_norm, + NormalizationType.rms_norm, + ], + ) + Assert.eq([layer._config.normalization.epsilon for layer in model.layers[1:-1]], [1, 1e-5, 1e-5, 1, 1e-5, 1, 1, 1]) + Assert.eq( + [layer._config.ffn_hidden_size for layer in model.layers[1:-1]], [64, 4096, 4096, 64, 4096, 4096, 4096, 4096] + ) + # Non-architecture parameters (`window_size`) need to be ignored when converting to architecture config. + # (See `TransformerLayerRangeArchitectureConfig.setup`.) + model.config.get_architecture()
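A minimal, self-contained sketch (plain Python, outside the patch) of the layer-range resolution behaviour introduced above: `in_range` mirrors `RangeConfig.in_range`, and `resolve` mirrors the first-match lookup in `TransformerArchitectureConfig.get_layer_config_and_tensor_space`, where layers matched by no override fall back to `default`. The helper names `in_range`/`resolve` and the override labels are illustrative only; the range values are copied from `tests/test_transformer.py`.

# Sketch of the layer-range matching added in this PR (hypothetical helper names).
def in_range(index: int, begin: int = 0, end: int | None = None, step: int = 1) -> bool:
    # Same predicate as RangeConfig.in_range: index is in range(begin, end, step).
    return index >= begin and (end is None or index < end) and (index - begin) % step == 0


def resolve(index: int, overrides: list[dict]) -> str:
    # First matching override wins, as in get_layer_config_and_tensor_space;
    # unmatched layers use the default layer config.
    for override in overrides:
        if any(in_range(index, **layer_range) for layer_range in override["layer_ranges"]):
            return override["name"]
    return "default"


# Ranges from tests/test_transformer.py: the first entry covers layers 5-7; the
# second covers layers 0, 3 and 5, but layer 5 is already claimed by the first.
overrides = [
    {"name": "no_window", "layer_ranges": [{"begin": 5}]},
    {"name": "small_window", "layer_ranges": [{"begin": 0, "end": 1}, {"begin": 3, "end": 6, "step": 2}]},
]
assert [resolve(i, overrides) for i in range(8)] == [
    "small_window", "default", "default", "small_window",
    "default", "no_window", "no_window", "no_window",
]

Run as-is, the assertion passes and reproduces the per-layer window-size pattern checked by the new test, [512, 1024, 1024, 512, 1024, None, None, None].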