Skip to content

[feat] Hybrid Mamba model with Mamba and discrete Mamba 2 layers #194

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 53 commits into from
Apr 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
53 commits
Select commit Hold shift + click to select a range
ad2b8d5
wip
oleksost Feb 18, 2025
5dbc72a
wip
oleksost Mar 13, 2025
5213e9e
WIP
oleksost Mar 14, 2025
01d0fe4
mamba1 block
oleksost Mar 20, 2025
ae8f3ca
removed build from remote
oleksost Mar 20, 2025
fa6d3bc
removed unneccesary tests
oleksost Mar 20, 2025
963e674
removed unneccesary files
oleksost Mar 20, 2025
0842fcb
test
oleksost Mar 20, 2025
1c81719
test mamba1
oleksost Mar 24, 2025
37ba0d5
tensor dimentions
oleksost Mar 24, 2025
11a5db3
meta init with full model run
oleksost Mar 24, 2025
4af7eb7
training, but having backward issues
oleksost Mar 25, 2025
be93749
integration into training pipeline
oleksost Mar 30, 2025
dd469bc
mamba2
oleksost Mar 31, 2025
ebe1b75
renamed config + skip test
oleksost Mar 31, 2025
a4400fd
skip tests if mamba not installed
oleksost Mar 31, 2025
c49148c
pre-commits
oleksost Mar 31, 2025
5c8d930
cleanup
oleksost Mar 31, 2025
ef6791b
dependencies
oleksost Mar 31, 2025
f03dd10
descrete mamba2
oleksost Mar 31, 2025
2414252
Merge branch 'ssm_mamba2' into ssm
oleksost Mar 31, 2025
f4d411d
test
oleksost Mar 31, 2025
ee86c68
llamba checkpoint converter
oleksost Apr 3, 2025
2561738
cleanup
oleksost Apr 4, 2025
ad8a48c
test
oleksost Apr 4, 2025
5243a88
Merge branch 'main' into ssm
oleksost Apr 4, 2025
075a31f
mamba force build
oleksost Apr 7, 2025
a788989
mamba force build
oleksost Apr 7, 2025
2700660
mamba force build
oleksost Apr 7, 2025
baaf714
causal conv skip build
oleksost Apr 7, 2025
833b586
Merge branch 'main' into ssm
oleksost Apr 7, 2025
9e2897d
docs.yaml
oleksost Apr 7, 2025
b231cb8
MTP hardcoded
oleksost Apr 7, 2025
8ccaa28
import nvm
oleksost Apr 7, 2025
864fff2
remove dependency on cartesia
oleksost Apr 7, 2025
7f2b35f
save llamba
oleksost Apr 7, 2025
81c71af
addressed comments
oleksost Apr 8, 2025
7b7ce62
addressed comments
oleksost Apr 9, 2025
776e67b
Merge branch 'main' into ssm
oleksost Apr 9, 2025
3456884
nvm
oleksost Apr 10, 2025
b48f68d
renamed block pattern into block layout
oleksost Apr 11, 2025
9a35783
renames
oleksost Apr 11, 2025
32b8aa1
nvm
oleksost Apr 14, 2025
4f9aad0
wip
oleksost Apr 16, 2025
68de5d1
addressed comments
oleksost Apr 23, 2025
ebc516a
Merge branch 'main' into ssm
oleksost Apr 23, 2025
cb95e52
wip
oleksost Apr 23, 2025
79c9a4b
batch config
oleksost Apr 23, 2025
bb3ba66
clean up
oleksost Apr 23, 2025
a5297be
nvm
oleksost Apr 23, 2025
2d39857
tests
oleksost Apr 23, 2025
df032b5
nvm
oleksost Apr 23, 2025
c8fdbb9
identity activation into MLP
oleksost Apr 23, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ jobs:
run: |
pip install "torch>=2.2.2"
pip install pybind11
FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE FLASH_ATTENTION_FORCE_BUILD=TRUE pip install --no-build-isolation -e ".[CORE,OPTIONAL,DEV,DOCS]"
FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE FLASH_ATTENTION_FORCE_BUILD=TRUE MAMBA_SKIP_CUDA_BUILD=TRUE MAMBA_FORCE_BUILD=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE pip install --no-build-isolation -e ".[CORE,OPTIONAL,DEV,DOCS]"

- name: Run tests
run: pytest .
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/docs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ jobs:
- run: |
pip install "torch>=2.2.2"
pip install pybind11
FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE FLASH_ATTENTION_FORCE_BUILD=TRUE pip install --no-build-isolation -e ".[CORE,OPTIONAL,DEV,DOCS]"
FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE FLASH_ATTENTION_FORCE_BUILD=TRUE MAMBA_SKIP_CUDA_BUILD=TRUE MAMBA_FORCE_BUILD=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE pip install --no-build-isolation -e ".[CORE,OPTIONAL,DEV,DOCS]"
- name: Build the documentation
run: mkdocs build

Expand Down
3 changes: 3 additions & 0 deletions fast_llm/functional/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ class ActivationType(str, enum.Enum):
silu = "silu"
relu = "relu"
squared_relu = "squared_relu"
identity = "identity"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you please add support for it in the MLP? (I know it's triton but this one is trivial.)
https://github.com/ServiceNow/Fast-LLM/blob/main/fast_llm/functional/triton/mlp.py Or otherwise prevent it in the config?


@property
def activation_fn(self) -> typing.Callable[["torch.Tensor"], "torch.Tensor"]:
Expand Down Expand Up @@ -70,6 +71,7 @@ def _set_activation_fn_map() -> None:
ActivationType.silu: torch.nn.functional.silu,
ActivationType.relu: torch.nn.functional.relu,
ActivationType.squared_relu: lambda x: torch.pow(torch.nn.functional.relu(x), 2),
ActivationType.identity: lambda x: x,
}


Expand All @@ -80,6 +82,7 @@ def _set_activation_fn_map() -> None:
ActivationType.silu: "silu",
ActivationType.relu: "relu",
ActivationType.squared_relu: "relu2",
ActivationType.identity: "identity",
}
_ACTIVATION_HF_NAMES_INV = {value: key for key, value in _ACTIVATION_HF_NAMES.items()}

Expand Down
4 changes: 4 additions & 0 deletions fast_llm/functional/triton/mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,10 @@ def triton_mlp_activation_backward_kernel(
grad = 2 * relu_out
if gated or recompute:
out = relu_out * relu_out
elif activation_type == _TritonActivationType.identity:
grad = 1
if gated or recompute:
out = input_
else:
raise NotImplementedError()

Expand Down
10 changes: 10 additions & 0 deletions fast_llm/layers/language_model/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace
from fast_llm.engine.distributed.config import DistributedDimNames
from fast_llm.functional.config import CrossEntropyImpl
from fast_llm.layers.ssm.config import SSMArchitectureConfig, SSMConfig
from fast_llm.layers.transformer.config import TransformerArchitectureConfig, TransformerConfig
from fast_llm.utils import Assert

Expand Down Expand Up @@ -43,6 +44,13 @@ class LanguageModelArchitectureConfig(BaseModelArchitectureConfig):
desc="Configuration for the transformer architecture.",
hint=FieldHint.core,
)

ssm: SSMArchitectureConfig = Field(
default_factory=SSMArchitectureConfig,
desc="Configuration for the transformer architecture.",
hint=FieldHint.core,
)

max_position_embeddings: int = Field(
default=2048,
desc="Number of absolute position embeddings, if applicable.",
Expand Down Expand Up @@ -125,6 +133,8 @@ class LanguageModelBaseConfig(LanguageModelArchitectureConfig, BaseModelConfig):
architecture_class = LanguageModelArchitectureConfig

transformer: TransformerConfig = FieldUpdate(default_factory=TransformerConfig)
ssm: SSMConfig = FieldUpdate(default_factory=SSMConfig)

init_method_std_embed: float = Field(
default=None,
desc="Initialization scale for the vocabulary embedding and output weights (logits).",
Expand Down
135 changes: 135 additions & 0 deletions fast_llm/layers/ssm/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
from fast_llm.config import Field, FieldHint, FieldUpdate, check_field, config_class
from fast_llm.engine.base_model.config import BaseModelArchitectureConfig, BaseModelConfig
from fast_llm.functional.config import ActivationType
from fast_llm.layers.common.config import NormalizationArchitectureConfig, NormalizationConfig
from fast_llm.utils import Assert


class SSMDimNames:
model_dim = "model_dim" # Model dimension (D)
state_dim = "state_dim" # State dimension (N)
conv_dim = "conv_dim" # Dimension of the conv1d input in mamba layers
inner_dim = "inner_dim" # Inner dimension after expansion
dt_rank = "dt_rank" # Rank of Ξ”
inner_proj_mamba = "inner_proj_mamba" # Inner projection dimension for mamba
inner_proj_mamba2 = "inner_proj_mamba2" # Inner projection dimension for mamba2
x_proj_dim = "x_proj_dim" # X projection dimension
head_dim = "head_dim" # Dimension of the mamba2 head (P)
conv_kernel_size = "conv_kernel_size" # Kernel size of the conv1d in mamba layers
qk_heads = "qk_heads" # Number of QK heads
v_heads = "v_heads" # Number of V heads


@config_class()
class SSMArchitectureConfig(BaseModelArchitectureConfig):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please adjust field names for our naming conventions.

_abstract = False

# Normalization
normalization: NormalizationArchitectureConfig = Field(
default_factory=NormalizationArchitectureConfig,
desc="Configuration for the normalization layers architecture.",
hint=FieldHint.core,
)

expansion_factor: int = Field(
default=2, desc="Expansion factor for Mamba blocks.", hint=FieldHint.core, valid=check_field(Assert.gt, 0)
)

state_size: int = Field(
default=16,
desc="State size for Mamba blocks.",
hint=FieldHint.core,
valid=check_field(Assert.gt, 0),
)
conv_kernel_dimension: int = Field(
default=4,
desc="Conv kernel dimension for Mamba blocks.",
hint=FieldHint.core,
valid=check_field(Assert.gt, 0),
)

# Layer parameters
add_bias_linear: bool = Field(
default=False,
desc="Whether to use bias in SSM layers",
hint=FieldHint.core,
)

dt_rank: int = Field(
default=None,
desc="Rank of the Ξ” projection matrix. If 'None', will be set to ceil(hidden_size/16)",
hint=FieldHint.core,
)

chunk_size: int = Field(
default=256,
desc="Chunk size for Mamba2 blocks.",
hint=FieldHint.core,
)

n_qk_heads: int = Field(
default=32,
desc="Number of QK heads for Mamba2 blocks.",
hint=FieldHint.core,
)

n_v_heads: int = Field(
default=32,
desc="Number of V heads for Mamba2 blocks.",
hint=FieldHint.core,
)

activation_type: ActivationType = Field(
default=None,
desc="The MLP intermediate activation type. Default: SiLU for gated MLP, GeLU otherwise.",
hint=FieldHint.core,
)

def _validate(self) -> None:
with self._set_implicit_default():
if self.activation_type is None:
self.activation_type = ActivationType.silu
if self.dt_rank is None:
self.dt_rank = -1 # set to -1, it will be overwrittem in ssm validation

super()._validate()


@config_class()
class SSMConfig(SSMArchitectureConfig, BaseModelConfig):
"""Configuration for a Structured State Space Model (SSM) layer."""

normalization: NormalizationConfig = FieldUpdate(default_factory=NormalizationConfig)

debug_ssm: bool = Field(
default=False,
desc="debug_ssm",
hint=FieldHint.optional,
)

dt_min: float = Field(
default=0.001,
desc="Minimum step size for discretization",
hint=FieldHint.core,
valid=check_field(Assert.gt, 0),
)

dt_max: float = Field(
default=0.1,
desc="Maximum step size for discretization",
hint=FieldHint.core,
valid=check_field(Assert.gt, 0),
)

dt_init_floor: float = Field(
default=1e-4,
desc="Minimum value for initializing dt",
hint=FieldHint.core,
valid=check_field(Assert.gt, 0),
)

def _validate(self) -> None:
"""Validate configuration parameters."""

super()._validate()
Assert.geq(self.dt_max, self.dt_min)
Loading