
Extend add_linear_biases to support a dictionary of sub-layers to which linear bias should be added. #158

Draft: wants to merge 2 commits into base: main
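
With this change, add_linear_biases still accepts a plain boolean, and additionally accepts a dictionary mapping sub-layer keys to either "*" or layer-index ranges. A minimal usage sketch, mirroring the fixtures and assertions in tests/test_config.py below:

from fast_llm.layers.transformer.config import TransformerArchitectureConfig, TransformerSubLayerKeys

# Add biases to the attention query projection in every layer, and to the MLP
# sub-layers only for selected (user-facing, 0-based) layer indices.
config = TransformerArchitectureConfig(
    num_layers=10,
    add_linear_biases={
        "layers.self_attn.query": "*",
        "layers.mlp.layer_1": "1:10:3, 9",
        "layers.mlp.layer_2": "5:7",
    },
)

# Layers query the config per sub-layer instead of reading a single flag.
assert config.should_add_linear_bias(2, TransformerSubLayerKeys.mlp_layer1)
assert not config.should_add_linear_bias(1, TransformerSubLayerKeys.mlp_layer1)
assert config.should_add_linear_bias(5, TransformerSubLayerKeys.attn_query)
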
13 changes: 9 additions & 4 deletions fast_llm/layers/transformer/attention.py
@@ -10,7 +10,12 @@
from fast_llm.functional.rotary import apply_rotary_embeddings
from fast_llm.functional.triton.rotary import triton_rotary_autograd_
from fast_llm.layers.common.linear import InputParallelLinear, OutputParallelLinear
from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames, TransformerKwargs
from fast_llm.layers.transformer.config import (
TransformerConfig,
TransformerDimNames,
TransformerKwargs,
TransformerSubLayerKeys,
)
from fast_llm.logging import log_distributed_grad, log_distributed_tensor
from fast_llm.tensor import TensorMeta, init_normal_, init_zeros_
from fast_llm.utils import Assert
@@ -102,7 +107,7 @@ def __init__(
self.query = OutputParallelLinear(
hidden_dim,
self._tensor_space.get_tensor_dim(TransformerDimNames.composite_query),
bias=self._config.add_linear_biases,
bias=self._config.should_add_linear_bias(self._layer_index, TransformerSubLayerKeys.attn_query),
weight_init_method=init_method_qkv,
bias_init_method=init_method_qkv if self._config.random_bias_init else init_zeros_,
sequence_parallel=self._sequence_parallel,
@@ -111,7 +116,7 @@
self.key_value = OutputParallelLinear(
hidden_dim,
self._tensor_space.get_tensor_dim(TransformerDimNames.composite_key_value),
bias=self._config.add_linear_biases,
bias=self._config.should_add_linear_bias(self._layer_index, TransformerSubLayerKeys.attn_key_value),
weight_init_method=init_method_qkv,
bias_init_method=init_method_qkv if self._config.random_bias_init else init_zeros_,
sequence_parallel=self._sequence_parallel,
@@ -123,7 +128,7 @@
self.dense = InputParallelLinear(
self._tensor_space.get_tensor_dim(TransformerDimNames.composite_dense),
hidden_dim,
bias=self._config.add_linear_biases,
bias=self._config.should_add_linear_bias(self._layer_index, TransformerSubLayerKeys.attn_dense),
weight_init_method=init_method_std_attn_proj,
bias_init_method=init_method_std_attn_proj if self._config.random_bias_init else init_zeros_,
sequence_parallel=self._sequence_parallel,
76 changes: 75 additions & 1 deletion fast_llm/layers/transformer/config.py
@@ -1,6 +1,8 @@
import enum
import itertools
import logging
import math
import re
import typing
import warnings

@@ -149,6 +151,14 @@ class RotaryConfig(RotaryArchitectureConfig, BaseModelConfig):
pass


class TransformerSubLayerKeys(str, enum.Enum):
attn_query = "layers.self_attn.query"
attn_key_value = "layers.self_attn.key_value"
attn_dense = "layers.self_attn.dense"
mlp_layer1 = "layers.mlp.layer_1"
mlp_layer2 = "layers.mlp.layer_2"


@config_class()
class TransformerArchitectureConfig(BaseModelArchitectureConfig):
_abstract = False
@@ -174,7 +184,11 @@ class TransformerArchitectureConfig(BaseModelArchitectureConfig):
hint=FieldHint.core,
valid=check_field(Assert.gt, 0),
)
add_linear_biases: bool = Field(default=True, desc="Add biases to all dense layers.", hint=FieldHint.core)
add_linear_biases: bool | dict[TransformerSubLayerKeys, str] = Field(
default=True,
desc="Add biases to all or selected dense layers. Accepted values: True, False, or a dict with keys from TransformerSubLayerKeys and values as '*' or index ranges.",
hint=FieldHint.core,
)
ffn_hidden_size: int = Field(
default=None,
desc="Hidden dimension of the MLP intermediate state. Default: 4 * hidden_size.",
@@ -234,6 +248,10 @@ class TransformerArchitectureConfig(BaseModelArchitectureConfig):
hint=FieldHint.feature,
)

_parsed_add_linear_biases: bool | dict[TransformerSubLayerKeys, set[int] | str] = Field(
default=None, init=False, repr=False
)

def _validate(self) -> None:
if self.ffn_hidden_size is None:
self.ffn_hidden_size = 4 * self.hidden_size
@@ -243,14 +261,70 @@ def _validate(self) -> None:
self.activation_type = ActivationType.silu if self.gated else ActivationType.gelu
self.projection_size = self.num_attention_heads * self.kv_channels
self.num_unshared_experts = self.num_experts - self.num_shared_experts

        # Validate before the parent validation so that an invalid TransformerSubLayerKeys key raises an assertion error first.
self._validate_add_linear_biases()
self._parse_add_linear_biases()

super()._validate()

if not TritonConfig.TRITON_ENABLED:
warnings.warn("Triton is disabled, but triton rotary kernel will be used anyway.")

Assert.leq(self.num_shared_experts, self.num_experts)
Assert.leq(self.num_shared_experts + self.num_experts_per_token, self.num_experts)
Assert.multiple(self.num_attention_heads, self.head_groups)

def _validate_add_linear_biases(self) -> None:
"""Validate the `add_linear_biases` parameter."""
if isinstance(self.add_linear_biases, dict):
Assert.gt(len(self.add_linear_biases), 0)
for key, value in self.add_linear_biases.items():
Assert.incl(key, TransformerSubLayerKeys) # Assert valid sublayer key
Assert.custom(
lambda val: val == "*" or re.match(r"^\d+(:\d+(:\d+)?)?(,\s*\d+(:\d+(:\d+)?)?)*$", val),
value,
) # Assert valid range format

def _parse_add_linear_biases(self) -> bool | dict[TransformerSubLayerKeys, set[int] | str]:
"""Parse `add_linear_biases` and store the result for quick lookup."""
if isinstance(self.add_linear_biases, bool):
self._parsed_add_linear_biases = self.add_linear_biases
return

parsed = {}
for key, value in self.add_linear_biases.items():
parsed[key] = self._parse_indices(value)
self._parsed_add_linear_biases = parsed

def _parse_indices(self, indices_str: str) -> set[int]:
"""Parse layer indices from a string like '1:10:2, 20, 30' or '*'."""
indices = []
        # Layers are numbered from 1 because layer 0 is the embedding layer in Fast-LLM.
if indices_str == "*":
indices.extend(range(1, self.num_layers + 1))
else:
for part in indices_str.split(","):
part = part.strip()
if ":" in part:
parts = list(map(int, part.split(":")))
start, stop = parts[0] + 1, parts[1] + 1
step = parts[2] if len(parts) == 3 else 1
indices.extend(range(start, stop, step))
else:
indices.append(int(part) + 1)
return set(itertools.chain(indices))

def should_add_linear_bias(self, layer_index: int, sublayer_key: TransformerSubLayerKeys) -> bool:
"""Check if linear bias should be added for a given layer and sublayer."""
if isinstance(self._parsed_add_linear_biases, bool):
return self._parsed_add_linear_biases

if sublayer_key in self._parsed_add_linear_biases:
return layer_index in self._parsed_add_linear_biases[sublayer_key]

return False

@classmethod
def _from_dict(
cls,
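
As a quick check on the range syntax and the one-based shift (layer 0 is the embedding layer, so user-facing indices are shifted by +1 internally), here is a standalone sketch of the rule _parse_indices implements, not part of the diff; the expected sets match the assertions in tests/test_config.py below.

def parse_indices(indices_str: str, num_layers: int) -> set[int]:
    # "*" selects every transformer layer (internal indices 1..num_layers).
    if indices_str == "*":
        return set(range(1, num_layers + 1))
    indices = []
    for part in indices_str.split(","):
        part = part.strip()
        if ":" in part:
            # start:stop[:step], interpreted like a Python slice over 0-based indices.
            start, stop, *rest = map(int, part.split(":"))
            step = rest[0] if rest else 1
            indices.extend(range(start + 1, stop + 1, step))
        else:
            indices.append(int(part) + 1)
    return set(indices)

assert parse_indices("*", 10) == set(range(1, 11))
assert parse_indices("1:10:3, 9", 10) == {2, 5, 8, 10}
assert parse_indices("5:7", 10) == {6, 7}
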
4 changes: 2 additions & 2 deletions fast_llm/layers/transformer/mixture_of_experts.py
@@ -40,11 +40,11 @@ class MixtureOfExpertMLP(MLPBase):

_group: ProcessGroup

def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp"):
def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, layer_index: int, name: str = "mlp"):
Assert.gt(config.num_experts, 1)
# TODO: Implement?
assert not config.add_linear_biases, "Biases not supported for MoE."
super().__init__(config, tensor_space, name)
super().__init__(config, tensor_space, layer_index, name)
self._config = config
self._tensor_space = tensor_space
self._debug_mode = self._config.debug_transformer or self._config.debug_transformer_memory
13 changes: 7 additions & 6 deletions fast_llm/layers/transformer/mlp.py
@@ -8,15 +8,16 @@
from fast_llm.functional.config import TritonConfig
from fast_llm.functional.triton.mlp import mlp_autograd, torch_mlp_activation, triton_mlp_activation_autograd
from fast_llm.layers.common.linear import LinearBase
from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames
from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames, TransformerSubLayerKeys
from fast_llm.tensor import init_normal_, init_zeros_
from fast_llm.utils import Assert


class MLPBase(Layer, ABC):
def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp"):
def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, layer_index: int, name: str = "mlp"):
super().__init__()
self._name = name
self._layer_index = layer_index

init_method_1 = init_normal_(
std=config.init_method_std_mlp_1,
@@ -42,15 +43,15 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s
self.layer_1 = LinearBase(
hidden_dim,
tensor_space.get_tensor_dim(TransformerDimNames.composite_gated_expert_mlp),
bias=config.add_linear_biases,
bias=config.should_add_linear_bias(self._layer_index, TransformerSubLayerKeys.mlp_layer1),
weight_init_method=init_method_1,
bias_init_method=init_method_1 if config.random_bias_init else init_zeros_,
lr_scale=tuple(config.mlp_lr_scale),
)
self.layer_2 = LinearBase(
self._intermediate_dim,
hidden_dim,
bias=config.add_linear_biases,
bias=config.should_add_linear_bias(self._layer_index, TransformerSubLayerKeys.mlp_layer2),
weight_init_method=init_method_2,
bias_init_method=init_method_2 if config.random_bias_init else init_zeros_,
auto_bias_grad_accumulation=tensor_space.distributed_config.tensor_parallel > 1,
@@ -60,9 +61,9 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s


class MLP(MLPBase):
def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp"):
def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, layer_index: int, name: str = "mlp"):
Assert.eq(config.num_experts, 1)
super().__init__(config, tensor_space, name)
super().__init__(config, tensor_space, layer_index, name)

def forward(
self,
2 changes: 1 addition & 1 deletion fast_llm/layers/transformer/transformer.py
@@ -43,7 +43,7 @@ def __init__(
self.self_attn = Attention(self._config, self._tensor_space, layer_index)

self.mlp = (MixtureOfExpertMLP if self._config.num_experts > 1 else MLP)(
self._config, self._tensor_space, f"{self.name} mlp"
self._config, self._tensor_space, self._layer_index, f"{self.name} mlp"
)

@torch.compile
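
Because MLPBase, MLP and MixtureOfExpertMLP now take a layer_index, code constructing them directly needs the extra argument. A minimal sketch of the new call, reusing the setup from the constructor test in tests/test_attention.py below; whether the MLP builds cleanly outside a full Distributed context, as Attention does in that test, is an assumption here.

from fast_llm.engine.config_utils.tensor_space import TensorSpace
from fast_llm.engine.distributed.config import DistributedConfig
from fast_llm.layers.transformer.config import TransformerConfig
from fast_llm.layers.transformer.mlp import MLP

transformer_conf = TransformerConfig(num_layers=2, num_attention_heads=2, hidden_size=16)
distributed_config = DistributedConfig()
tensor_space = TensorSpace(distributed_config=distributed_config)
transformer_conf.setup_tensor_space(tensor_space)

# MLP for the first transformer layer; layer indices start at 1.
mlp = MLP(transformer_conf, tensor_space, 1, "mlp")
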
16 changes: 16 additions & 0 deletions tests/test_attention.py
@@ -1,6 +1,8 @@
import unittest.mock
from fast_llm.layers.transformer.attention import Attention
from fast_llm.layers.transformer.config import TransformerConfig
from fast_llm.engine.distributed.config import DistributedConfig
from fast_llm.engine.config_utils.tensor_space import TensorSpace


def test_decide_window_size():
@@ -20,3 +22,17 @@ def test_decide_window_size():
# Arrange - Case 3: max_window_layers is None (always return window_size)
attention._config = TransformerConfig(window_size=512, max_window_layers=None)
assert attention._decide_window_size() == 512


def test_attention_constructor():
transformer_conf = TransformerConfig(
num_layers=2,
num_attention_heads=2,
hidden_size=16,
)
distributed_config = DistributedConfig()
tensor_space = TensorSpace(distributed_config=distributed_config)
transformer_conf.setup_tensor_space(tensor_space)

Attention(transformer_conf, tensor_space, 1)

99 changes: 98 additions & 1 deletion tests/test_config.py
@@ -5,7 +5,11 @@
import yaml


from fast_llm.layers.transformer.config import TransformerConfig
from fast_llm.layers.transformer.config import (
TransformerConfig,
TransformerArchitectureConfig,
TransformerSubLayerKeys,
)
from fast_llm.utils import Assert
from fast_llm.engine.distributed.config import DistributedConfig
from fast_llm.engine.config_utils.data_type import DataType
@@ -84,3 +88,96 @@ def test_do_use_flash_attention():
mock_distributed_config.training_dtype = DataType.float32
with pytest.raises(AssertionError):
config.do_use_flash_attention(mock_distributed_config)


@pytest.fixture
def config_with_true_biases():
"""Fixture for TransformerArchitectureConfig with True add_linear_biases."""
return TransformerArchitectureConfig(add_linear_biases=True)


@pytest.fixture
def config_with_false_biases():
"""Fixture for TransformerArchitectureConfig with False add_linear_biases."""
return TransformerArchitectureConfig(add_linear_biases=False)


@pytest.fixture
def config_with_dict_biases():
"""Fixture for TransformerArchitectureConfig with dict add_linear_biases."""
return TransformerArchitectureConfig(
        num_layers=10,
add_linear_biases={
"layers.self_attn.query": "*",
"layers.mlp.layer_1": "1:10:3, 9",
"layers.mlp.layer_2": "5:7",
}
)


def test_add_linear_biases_bool_true(config_with_true_biases):
"""Test case for add_linear_biases set to True (default)."""
assert config_with_true_biases._parsed_add_linear_biases == True


def test_add_linear_biases_bool_false(config_with_false_biases):
"""Test case for add_linear_biases set to False."""
assert config_with_false_biases._parsed_add_linear_biases == False


def test_add_linear_biases_dict_valid(config_with_dict_biases):
"""Test case for add_linear_biases with valid dictionary."""
assert config_with_dict_biases._parsed_add_linear_biases == {
TransformerSubLayerKeys.attn_query: {1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
TransformerSubLayerKeys.mlp_layer1: {2, 5, 8, 10},
TransformerSubLayerKeys.mlp_layer2: {6, 7},
}


def test_invalid_key_in_dict():
"""Test case where an invalid key is provided in add_linear_biases dictionary."""
with pytest.raises(AssertionError):
# Using an invalid key in the dictionary.
TransformerArchitectureConfig(add_linear_biases={"invalid_key": "*"})


def test_invalid_range_format():
"""Test case where invalid range format is provided."""
with pytest.raises(AssertionError):
TransformerArchitectureConfig(add_linear_biases={TransformerSubLayerKeys.mlp_layer1: "1:10:3, abc"})


def test_empty_add_linear_biases():
"""Test case for empty add_linear_biases dictionary."""
with pytest.raises(AssertionError): # Expecting AssertionError for invalid empty dictionary
TransformerArchitectureConfig(add_linear_biases={})


def test_should_add_linear_bias_for_layer_and_sublayer(config_with_dict_biases):
"""Test case for should_add_linear_bias based on layer index and sublayer key."""

# Layer 1 and sublayer mlp_layer1
assert config_with_dict_biases.should_add_linear_bias(1, TransformerSubLayerKeys.mlp_layer1) == False

# Layer 2 and sublayer mlp_layer1
assert config_with_dict_biases.should_add_linear_bias(2, TransformerSubLayerKeys.mlp_layer1) == True

# Layer 9 and sublayer mlp_layer1
assert config_with_dict_biases.should_add_linear_bias(9, TransformerSubLayerKeys.mlp_layer1) == False

# Layer 6 and sublayer mlp_layer2
assert config_with_dict_biases.should_add_linear_bias(6, TransformerSubLayerKeys.mlp_layer2) == True

# Layer 5 and sublayer attn_query
assert config_with_dict_biases.should_add_linear_bias(5, TransformerSubLayerKeys.attn_query) == True


def test_should_add_linear_bias_for_bool_true(config_with_true_biases):
"""Test case for add_linear_biases set to True (should always return True)."""
assert config_with_true_biases.should_add_linear_bias(10, TransformerSubLayerKeys.mlp_layer1) == True


def test_should_add_linear_bias_for_bool_false(config_with_false_biases):
"""Test case for add_linear_biases set to False (should always return False)."""
assert config_with_false_biases.should_add_linear_bias(10, TransformerSubLayerKeys.mlp_layer1) == False
