diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 44ef86f4..912ddaf5 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -29,7 +29,7 @@ jobs: run: | pip install "torch>=2.2.2" pip install pybind11 - FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE FLASH_ATTENTION_FORCE_BUILD=TRUE pip install --no-build-isolation -e ".[CORE,OPTIONAL,DEV,DOCS]" + FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE FLASH_ATTENTION_FORCE_BUILD=TRUE MAMBA_SKIP_CUDA_BUILD=TRUE MAMBA_FORCE_BUILD=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE pip install --no-build-isolation -e ".[CORE,OPTIONAL,DEV,DOCS]" - name: Run tests run: pytest . diff --git a/.github/workflows/docs.yaml b/.github/workflows/docs.yaml index 91118ca4..549140ca 100644 --- a/.github/workflows/docs.yaml +++ b/.github/workflows/docs.yaml @@ -31,7 +31,7 @@ jobs: - run: | pip install "torch>=2.2.2" pip install pybind11 - FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE FLASH_ATTENTION_FORCE_BUILD=TRUE pip install --no-build-isolation -e ".[CORE,OPTIONAL,DEV,DOCS]" + FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE FLASH_ATTENTION_FORCE_BUILD=TRUE MAMBA_SKIP_CUDA_BUILD=TRUE MAMBA_FORCE_BUILD=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE pip install --no-build-isolation -e ".[CORE,OPTIONAL,DEV,DOCS]" - name: Build the documentation run: mkdocs build diff --git a/fast_llm/functional/config.py b/fast_llm/functional/config.py index 7284ca07..5324ffeb 100644 --- a/fast_llm/functional/config.py +++ b/fast_llm/functional/config.py @@ -43,6 +43,7 @@ class ActivationType(str, enum.Enum): silu = "silu" relu = "relu" squared_relu = "squared_relu" + identity = "identity" @property def activation_fn(self) -> typing.Callable[["torch.Tensor"], "torch.Tensor"]: @@ -70,6 +71,7 @@ def _set_activation_fn_map() -> None: ActivationType.silu: torch.nn.functional.silu, ActivationType.relu: torch.nn.functional.relu, ActivationType.squared_relu: lambda x: torch.pow(torch.nn.functional.relu(x), 2), + ActivationType.identity: lambda x: x, } @@ -80,6 +82,7 @@ def _set_activation_fn_map() -> None: ActivationType.silu: "silu", ActivationType.relu: "relu", ActivationType.squared_relu: "relu2", + ActivationType.identity: "identity", } _ACTIVATION_HF_NAMES_INV = {value: key for key, value in _ACTIVATION_HF_NAMES.items()} diff --git a/fast_llm/functional/triton/mlp.py b/fast_llm/functional/triton/mlp.py index 8ab275ab..5b220b1a 100644 --- a/fast_llm/functional/triton/mlp.py +++ b/fast_llm/functional/triton/mlp.py @@ -119,6 +119,10 @@ def triton_mlp_activation_backward_kernel( grad = 2 * relu_out if gated or recompute: out = relu_out * relu_out + elif activation_type == _TritonActivationType.identity: + grad = 1 + if gated or recompute: + out = input_ else: raise NotImplementedError() diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 302c0983..b4b4e187 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -5,6 +5,7 @@ from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.engine.distributed.config import DistributedDimNames from fast_llm.functional.config import CrossEntropyImpl +from fast_llm.layers.ssm.config import SSMArchitectureConfig, SSMConfig from fast_llm.layers.transformer.config import TransformerArchitectureConfig, TransformerConfig from fast_llm.utils import Assert @@ -43,6 +44,13 @@ class LanguageModelArchitectureConfig(BaseModelArchitectureConfig): desc="Configuration for the transformer 
architecture.", hint=FieldHint.core, ) + + ssm: SSMArchitectureConfig = Field( + default_factory=SSMArchitectureConfig, + desc="Configuration for the transformer architecture.", + hint=FieldHint.core, + ) + max_position_embeddings: int = Field( default=2048, desc="Number of absolute position embeddings, if applicable.", @@ -125,6 +133,8 @@ class LanguageModelBaseConfig(LanguageModelArchitectureConfig, BaseModelConfig): architecture_class = LanguageModelArchitectureConfig transformer: TransformerConfig = FieldUpdate(default_factory=TransformerConfig) + ssm: SSMConfig = FieldUpdate(default_factory=SSMConfig) + init_method_std_embed: float = Field( default=None, desc="Initialization scale for the vocabulary embedding and output weights (logits).", diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py new file mode 100644 index 00000000..984858fc --- /dev/null +++ b/fast_llm/layers/ssm/config.py @@ -0,0 +1,135 @@ +from fast_llm.config import Field, FieldHint, FieldUpdate, check_field, config_class +from fast_llm.engine.base_model.config import BaseModelArchitectureConfig, BaseModelConfig +from fast_llm.functional.config import ActivationType +from fast_llm.layers.common.config import NormalizationArchitectureConfig, NormalizationConfig +from fast_llm.utils import Assert + + +class SSMDimNames: + model_dim = "model_dim" # Model dimension (D) + state_dim = "state_dim" # State dimension (N) + conv_dim = "conv_dim" # Dimension of the conv1d input in mamba layers + inner_dim = "inner_dim" # Inner dimension after expansion + dt_rank = "dt_rank" # Rank of Δ + inner_proj_mamba = "inner_proj_mamba" # Inner projection dimension for mamba + inner_proj_mamba2 = "inner_proj_mamba2" # Inner projection dimension for mamba2 + x_proj_dim = "x_proj_dim" # X projection dimension + head_dim = "head_dim" # Dimension of the mamba2 head (P) + conv_kernel_size = "conv_kernel_size" # Kernel size of the conv1d in mamba layers + qk_heads = "qk_heads" # Number of QK heads + v_heads = "v_heads" # Number of V heads + + +@config_class() +class SSMArchitectureConfig(BaseModelArchitectureConfig): + _abstract = False + + # Normalization + normalization: NormalizationArchitectureConfig = Field( + default_factory=NormalizationArchitectureConfig, + desc="Configuration for the normalization layers architecture.", + hint=FieldHint.core, + ) + + expansion_factor: int = Field( + default=2, desc="Expansion factor for Mamba blocks.", hint=FieldHint.core, valid=check_field(Assert.gt, 0) + ) + + state_size: int = Field( + default=16, + desc="State size for Mamba blocks.", + hint=FieldHint.core, + valid=check_field(Assert.gt, 0), + ) + conv_kernel_dimension: int = Field( + default=4, + desc="Conv kernel dimension for Mamba blocks.", + hint=FieldHint.core, + valid=check_field(Assert.gt, 0), + ) + + # Layer parameters + add_bias_linear: bool = Field( + default=False, + desc="Whether to use bias in SSM layers", + hint=FieldHint.core, + ) + + dt_rank: int = Field( + default=None, + desc="Rank of the Δ projection matrix. 
If 'None', will be set to ceil(hidden_size/16)", + hint=FieldHint.core, + ) + + chunk_size: int = Field( + default=256, + desc="Chunk size for Mamba2 blocks.", + hint=FieldHint.core, + ) + + n_qk_heads: int = Field( + default=32, + desc="Number of QK heads for Mamba2 blocks.", + hint=FieldHint.core, + ) + + n_v_heads: int = Field( + default=32, + desc="Number of V heads for Mamba2 blocks.", + hint=FieldHint.core, + ) + + activation_type: ActivationType = Field( + default=None, + desc="The MLP intermediate activation type. Default: SiLU for gated MLP, GeLU otherwise.", + hint=FieldHint.core, + ) + + def _validate(self) -> None: + with self._set_implicit_default(): + if self.activation_type is None: + self.activation_type = ActivationType.silu + if self.dt_rank is None: + self.dt_rank = -1 # set to -1, it will be overwrittem in ssm validation + + super()._validate() + + +@config_class() +class SSMConfig(SSMArchitectureConfig, BaseModelConfig): + """Configuration for a Structured State Space Model (SSM) layer.""" + + normalization: NormalizationConfig = FieldUpdate(default_factory=NormalizationConfig) + + debug_ssm: bool = Field( + default=False, + desc="debug_ssm", + hint=FieldHint.optional, + ) + + dt_min: float = Field( + default=0.001, + desc="Minimum step size for discretization", + hint=FieldHint.core, + valid=check_field(Assert.gt, 0), + ) + + dt_max: float = Field( + default=0.1, + desc="Maximum step size for discretization", + hint=FieldHint.core, + valid=check_field(Assert.gt, 0), + ) + + dt_init_floor: float = Field( + default=1e-4, + desc="Minimum value for initializing dt", + hint=FieldHint.core, + valid=check_field(Assert.gt, 0), + ) + + def _validate(self) -> None: + """Validate configuration parameters.""" + + super()._validate() + Assert.geq(self.dt_max, self.dt_min) diff --git a/fast_llm/layers/ssm/discrete_mamba2.py b/fast_llm/layers/ssm/discrete_mamba2.py new file mode 100644 index 00000000..f233d2f4 --- /dev/null +++ b/fast_llm/layers/ssm/discrete_mamba2.py @@ -0,0 +1,214 @@ +import math + +import causal_conv1d +import einops +import mamba_ssm.ops.triton.ssd_combined +import torch + +from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace +from fast_llm.layers.common.linear import Linear +from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames +from fast_llm.tensor import ParameterMeta, init_ones_, init_uniform_, init_zeros_, kaiming_init_ + +""" +This code is adapted fropm https://github.com/cartesia-ai/edge/blob/main/cartesia-pytorch/cartesia_pytorch/Llamba/mixers/discrete_mamba2.py +""" + + +def bias_init_method(conv_weight): + fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(conv_weight) + bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 + return init_uniform_(-bound, bound) + + +class DiscreteMamba2(torch.nn.Module): + """DiscreteMamba2 (taken github.com/goombalab/phi-mamba.git).""" + + def __init__( + self, + config: SSMConfig, + layer_idx: int, + tensor_space: TensorSpace, + ): + """ + See the class .kernel.SSKernel for the kernel constructor which accepts kernel_args. + TODO: check what this comment means + Relevant options that are worth considering and tuning include "mode" + "measure", "dt_min", "dt_max", "lr". + + Other options are all experimental and should not need to be configured. 
+ """ + # factory_kwargs = {"device": "meta"} # , "dtype": torch.bfloat16} + super().__init__() + self.config: SSMConfig = config + bias = config.add_bias_linear + self.layer_idx = layer_idx + + td_inner = tensor_space.get_tensor_dim(SSMDimNames.inner_dim) + td_state = tensor_space.get_tensor_dim(SSMDimNames.state_dim) + td_model = tensor_space.get_tensor_dim(SSMDimNames.model_dim) + td_conv = tensor_space.get_tensor_dim(SSMDimNames.conv_dim) + td_n_qk_heads = tensor_space.get_tensor_dim(SSMDimNames.qk_heads) + td_n_v_heads = tensor_space.get_tensor_dim(SSMDimNames.v_heads) + td_conv_kernel = tensor_space.get_tensor_dim(SSMDimNames.conv_kernel_size) + td_inner_proj = tensor_space.get_tensor_dim(SSMDimNames.inner_proj_mamba2) + + self.d_model = td_model.size + self.d_inner = td_inner.size + self.d_state = td_state.size + self.chunk_size = config.chunk_size + self.n_qk_heads = td_n_qk_heads.size + self.n_v_heads = td_n_v_heads.size + self.conv_kernel_size = td_conv_kernel.size + + self.act = config.activation_type.activation_fn + self.activation_name = config.activation_type.name + + # TODO: double check innitializations + # Projections + self.in_proj = Linear(td_model, td_inner_proj, bias=bias, weight_init_method=kaiming_init_(td_model.size)) + self.z_bias = ( + ParameterMeta.from_dims( + (td_inner,), + weight_decay=False, + init_method=init_zeros_, + ) + if not bias + else 0.0 + ) + + # Convolutional layer + self.conv1d_weight = ParameterMeta.from_dims( + (td_conv, TensorDim("1", 1), td_conv_kernel), + init_method=init_uniform_( + 1 / math.sqrt(td_conv.size * td_conv_kernel.size), 1 / math.sqrt(td_conv.size * td_conv_kernel.size) + ), # see https://github.com/pytorch/pytorch/blob/1eba9b3aa3c43f86f4a2c807ac8e12c4a7767340/torch/nn/modules/conv.py#L180C53-L180C67 + ) + self.conv1d_bias = ParameterMeta.from_dims((td_conv,), init_method=bias_init_method(self.conv1d_weight)) + + # D "skip" parameter + self.D = ParameterMeta.from_dims( + (td_n_qk_heads,), + weight_decay=False, + init_method=init_ones_, + ) + + # out_proj + self.out_proj = Linear( + td_inner, + td_model, + bias=bias, + weight_init_method=kaiming_init_(td_inner.size), + ) + + @property + def d_output(self): + """Returns the output dimension of the model.""" + return self.d_model + + @property + def state_to_tensor(self): + """Returns the state of the model as a tensor.""" + return self.layer.state_to_tensor + + def forward(self, hidden_states, kwargs): + """ + ON variable names and pep8: keeping some variable names as in the original code for clarity. + + Args: + u: (B, L, D), + + Returns: + outputs: dict. + outputs["hidden_states"]: (B, L, D). + outputs["state"]: inference cache. + """ + u = hidden_states + outputs = {} + # assert state is None + batch, seqlen, dim = u.shape + + state = None + + # Hacky way to initialize state during inference + chunk_size = self.chunk_size if state is None else seqlen + + # Pad input to nearest multiple of chunklen + padded_len = (1 + (seqlen - 1) // chunk_size) * chunk_size + u = torch.nn.functional.pad(u, (0, 0, 0, padded_len - seqlen)) + + # Project input + xBCzA_log = self.in_proj(u) + + xBC, z, A_log = torch.split( + xBCzA_log, + [ + self.d_inner + 2 * self.n_qk_heads * self.d_state, + self.d_inner, + self.n_v_heads, + ], + dim=-1, + ) + + if state is not None: + # If we just take xBC[:, :, -self.d_conv :], it will error if seqlen < self.d_conv + # Instead torch.nn.functional.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise. 
+ xBC_t = einops.rearrange(xBC[:, :seqlen, :], "b l d -> b d l") + state["conv"].copy_( + torch.nn.functional.pad(xBC_t, (self.conv_kernel_size - xBC_t.shape[-1], 0)) + ) # Update state (B D W) + + # Convolutional layer + xBC = self.convolutional_forward(xBC, padded_len) + + x, B, C = torch.split( + xBC, + [ + self.d_inner, + self.n_qk_heads * self.d_state, + self.n_qk_heads * self.d_state, + ], + dim=-1, + ) + + x = einops.rearrange(x, "b l (h n) -> b l h n", h=self.n_v_heads) + B = einops.rearrange(B, "b l (h n) -> b l h n", h=self.n_qk_heads) + C = einops.rearrange(C, "b l (h n) -> b l h n", h=self.n_qk_heads) + + # SSM forward + result = mamba_ssm.ops.triton.ssd_combined.mamba_chunk_scan_combined( + x=x / torch.nn.functional.softplus(A_log).to(x.dtype).unsqueeze(-1), + dt=A_log, + dt_softplus=True, + A=-torch.ones(self.n_v_heads, device=A_log.device), + B=B, + C=C, + chunk_size=chunk_size, + # initial_states=(state["ssm"] if state is not None else None), # currently not supported by mamba_ssm.utils.generation + return_final_states=(state is not None), + ) + + if state is not None: + y, ssm_state = result + state["ssm"].copy_(ssm_state) + else: + y = result + + Du = torch.einsum("h,blhp->blhp", self.D, x) + y = einops.rearrange(y + Du, "b l h p -> b l (h p)") + + # Norm and gate + out = self.out_proj(y * torch.nn.functional.silu(z + self.z_bias)) + outputs["hidden_states"] = out[:, :seqlen, :] + + # TODO: since we do not support inference for now, we only return the hidden states for now. + return outputs["hidden_states"].contiguous(), None + + def convolutional_forward(self, xBC, padded_len): + """Convolutional layer forward pass for the full sequence.""" + xBC = causal_conv1d.causal_conv1d_fn( + xBC.transpose(1, 2), + einops.rearrange(self.conv1d_weight, "d 1 w -> d w"), + self.conv1d_bias, + activation=None if self.activation_name == "identity" else self.activation_name, + ).transpose(1, 2) + return xBC diff --git a/fast_llm/layers/ssm/llamba_block.py b/fast_llm/layers/ssm/llamba_block.py new file mode 100644 index 00000000..22135638 --- /dev/null +++ b/fast_llm/layers/ssm/llamba_block.py @@ -0,0 +1,34 @@ +import typing + +from fast_llm.layers.transformer.transformer import BaseBlock + +if typing.TYPE_CHECKING: + from fast_llm.engine.config_utils.tensor_space import TensorSpace + from fast_llm.layers.ssm.config import SSMConfig + from fast_llm.layers.transformer.config import TransformerConfig + + +class LlambaBlock(BaseBlock): + """ + A transformer-like decoder block with a SSM mixer, see https://arxiv.org/abs/2502.14458 + """ + + name = "Llamba block" + _mixer_module_name = "mixer" + + def __init__( + self, + config_transformer: "TransformerConfig", + config_ssm: "SSMConfig", + tensor_space: "TensorSpace", + mixer_cls, + layer_index: int, + return_input: bool = False, + ): + self.mixer_cls = mixer_cls + self._config_ssm = config_ssm + self._debug_mode = self._config_ssm.debug_ssm + super().__init__(config_transformer, tensor_space, layer_index, return_input) + + def _create_mixer(self): + self.mixer = self.mixer_cls(self._config_ssm, layer_idx=self._layer_index, tensor_space=self._tensor_space) diff --git a/fast_llm/layers/ssm/mamba_layer.py b/fast_llm/layers/ssm/mamba_layer.py new file mode 100644 index 00000000..1695cf2f --- /dev/null +++ b/fast_llm/layers/ssm/mamba_layer.py @@ -0,0 +1,173 @@ +import math +from typing import Callable + +import einops +import mamba_ssm.ops.selective_scan_interface +import torch + +from fast_llm.engine.config_utils.tensor_space import TensorDim, 
TensorSpace +from fast_llm.layers.common.linear import Linear +from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames +from fast_llm.tensor import ParameterMeta, init_ones_, kaiming_init_ + +""" +Note: this is mostly addapted from https://github.com/Zyphra/Zamba2, similar code is aslo in https://github.com/state-spaces/mamba. +For now it only supports training and not inference. +This works with triton 3.1.0 +""" + + +def init_A(d_state, d_inner) -> Callable[[ParameterMeta, torch.Tensor, torch.Generator], torch.Tensor]: + def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator): # noqa + # S4D real initialization + # TODO: adopt this innitialization to work for tensor parallel setting! + A = einops.repeat(torch.arange(1, d_state + 1, dtype=torch.float32), "n -> d n", d=d_inner).contiguous() + A_log = torch.log(A) # Keep A_log in fp32 + if tensor.shape != A_log.shape: + if tensor.numel() == A_log.numel(): + tensor_view = tensor.view(d_inner, d_state) + tensor_view.copy_(A_log) + else: + raise ValueError(f"Tensor size {tensor.numel()} doesn't match expected size {A_log.numel()}") + else: + tensor.copy_(A_log) + return tensor + + return init_ + + +def init_dtprojbias( + d_inner: int, dt_max: float, dt_min: float, dt_init_floor: float, factory_kwargs: dict +) -> Callable[[ParameterMeta, torch.Tensor, torch.Generator], torch.Tensor]: + def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator): # noqa + dt = torch.exp( + torch.rand(d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min) + ).clamp(min=dt_init_floor) + # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 + inv_dt = dt + torch.log(-torch.expm1(-dt)) + tensor.copy_(inv_dt) + return tensor + + return init_ + + +class MambaLayer(torch.nn.Module): + def __init__( + self, + config: SSMConfig, + layer_idx: int, + tensor_space: TensorSpace, + ): + factory_kwargs = {} + super().__init__() + self.config: SSMConfig = config + self.layer_idx = layer_idx + + self._debug_mode = config.debug_ssm + + # Tensor dims: + td_inner = tensor_space.get_tensor_dim(SSMDimNames.inner_dim) + td_inner_proj = tensor_space.get_tensor_dim( + SSMDimNames.inner_proj_mamba + ) # TensorDim("D_inner_2", self.d_inner * 2) + tdt_rank = tensor_space.get_tensor_dim(SSMDimNames.dt_rank) + td_x_proj = tensor_space.get_tensor_dim(SSMDimNames.x_proj_dim) + td_state = tensor_space.get_tensor_dim(SSMDimNames.state_dim) + td_model = tensor_space.get_tensor_dim(SSMDimNames.model_dim) + td_conv_kernel = tensor_space.get_tensor_dim(SSMDimNames.conv_kernel_size) + self.d_conv = td_conv_kernel.size + self.d_inner = td_inner.size + self.d_state = td_state.size + self.d_model = td_model.size + self.dt_rank = tdt_rank.size + + self.in_proj_weight = ParameterMeta.from_dims( + (td_inner_proj, td_model), + init_method=kaiming_init_(td_model.size), + ) + + self.conv1d_weight = ParameterMeta.from_dims( + (td_inner, TensorDim("D_inner_2", self.d_inner // self.d_inner), td_conv_kernel), + init_method=kaiming_init_(td_inner.size), + ) + + self.conv1d_bias = None + + self.activation = "silu" + self.act = torch.nn.SiLU() + + self.x_proj = Linear( + td_inner, + td_x_proj, + weight_init_method=kaiming_init_(td_inner.size), + bias=False, + **factory_kwargs, + ) + self.x_proj.weight.auto_grad_accumulation = True + + # TODO: the weights are innitialized a bit differently here https://github.com/state-spaces/mamba/blob/0cce0fa645f100f00620ddf2333c2b7712abfdec/mamba_ssm/modules/mamba_simple.py#L82 + 
self.dt_proj_weight = ParameterMeta.from_dims( + (td_inner, tdt_rank), + init_method=kaiming_init_(tdt_rank.size), + ) + + self.dt_proj_bias = ParameterMeta.from_dims( + (td_inner,), + init_method=init_dtprojbias( + self.d_inner, self.config.dt_max, self.config.dt_min, self.config.dt_init_floor, factory_kwargs + ), + ) + + self.A_log = ParameterMeta.from_dims( + (td_inner, td_state), + weight_decay=False, + init_method=init_A(self.d_state, self.d_inner), + ) + + # D "skip" parameter + self.D = ParameterMeta.from_dims( + (td_inner,), + weight_decay=False, + init_method=init_ones_, + ) + + self.out_proj = Linear( + td_inner, + td_model, + bias=False, # TODO: note, if bias is used there is a problem in the MambaInnerFn.backward for the bias grads. I think this bias is not used in other mamba repos. + weight_init_method=kaiming_init_(td_model.size), + **factory_kwargs, + ) + self.out_proj.weight.auto_grad_accumulation = True + + def forward(self, hidden_states, kwargs): + batch, seqlen, dim = hidden_states.shape + + # We do matmul and transpose BLH -> HBL at the same time + xz = einops.rearrange( + self.in_proj_weight @ einops.rearrange(hidden_states, "b l d -> d (b l)"), + "d (b l) -> b d l", + l=seqlen, + ) + if self._debug_mode: + print("XZ: ", xz.shape) + + A = -torch.exp(self.A_log.float()) # (d_inner, d_state) + # In the backward pass we write dx and dz next to each other to avoid torch.cat + # not, if we wanbt to support inference, we would need to imp.lement slow path here, see https://github.com/Zyphra/Zamba2/blob/1b182f40f2257f822cc06dd785df53d67d691a15/mamba_layer.py#L172s + out = mamba_ssm.ops.selective_scan_interface.mamba_inner_fn( + xz, + self.conv1d_weight, + self.conv1d_bias, + self.x_proj.weight, + self.dt_proj_weight, + self.out_proj.weight, + self.out_proj.bias, # is None here + A, + None, # input-dependent B + None, # input-dependent C + self.D.float(), + delta_bias=self.dt_proj_bias.float(), + delta_softplus=True, + ) + return out, None diff --git a/fast_llm/layers/transformer/transformer.py b/fast_llm/layers/transformer/transformer.py index 311403fc..92df1893 100644 --- a/fast_llm/layers/transformer/transformer.py +++ b/fast_llm/layers/transformer/transformer.py @@ -1,3 +1,4 @@ +import abc import logging import typing @@ -17,33 +18,31 @@ logger = logging.getLogger(__name__) -class TransformerLayer(Layer): +class BaseBlock(Layer, abc.ABC): """ - A transformer decoder layer. + A transformer-like decoder base block block with abstract mixer. """ + name = "Transformer layer" + _mixer_module_name = "self_attn" + def __init__( - self, - config: TransformerConfig, - tensor_space: TensorSpace, - layer_index: int, - return_input: bool = False, + self, config: TransformerConfig, tensor_space: TensorSpace, layer_index: int, return_input: bool = False ): super().__init__() - self._config = config - self._tensor_space = tensor_space - self._dropout_p = self._config.hidden_dropout + self._config: TransformerConfig = config + self._tensor_space: TensorSpace = tensor_space + self._dropout_p: float = self._config.hidden_dropout # For multi-token prediction, return a stack of shared_hidden and transformer_output. 
- self._return_input = return_input + self._return_input: bool = return_input self._layer_index = layer_index self._debug_mode = self._config.debug_transformer or self._config.debug_transformer_memory - hidden_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.hidden) self.norm_1 = self._config.normalization.get_layer(hidden_dim) self.norm_2 = self._config.normalization.get_layer(hidden_dim) - self.self_attn = Attention(self._config, self._tensor_space, layer_index) + self._create_mixer() self.mlp = (MixtureOfExpertMLP if self._config.num_experts > 1 else MLP)( self._config, self._tensor_space, f"{self.name} mlp" @@ -53,6 +52,10 @@ def __init__( self.norm_1 = self._config.peft.apply_other(self.norm_1) self.norm_2 = self._config.peft.apply_other(self.norm_2) + @abc.abstractmethod + def _create_mixer(self): + pass + @torch.compile def _bias_dropout_add( self, input_: torch.Tensor, bias: torch.Tensor | None, residual: torch.Tensor @@ -63,7 +66,7 @@ def _bias_dropout_add( @property def name(self) -> str: - return f"Transformer layer {self._layer_index}" + return f"{self._name} {self._layer_index}" def _get_meta(self, tensor: torch.Tensor, name: str, kwargs: dict): dims = kwargs[TransformerKwargs.hidden_dims] @@ -111,13 +114,13 @@ def forward( hidden_states = self.norm_1(input_) if self._debug_mode: self._debug_log(hidden_states, "Norm 1", kwargs) - hidden_states, bias = self.self_attn(hidden_states, kwargs) + hidden_states, bias = getattr(self, self._mixer_module_name)(hidden_states, kwargs) if self._debug_mode: - self._debug_log(hidden_states, "Attn output", kwargs, bias=bias) + self._debug_log(hidden_states, f"{self._mixer_module_name} output", kwargs, bias=bias) with set_generator(generator): input_ = self._bias_dropout_add(hidden_states, bias, input_) if self._debug_mode: - self._debug_log(input_, "Attn residual", kwargs) + self._debug_log(input_, f"{self._mixer_module_name} residual", kwargs) hidden_states = self.norm_2(input_) if self._debug_mode: self._debug_log(hidden_states, "Norm 2", kwargs) @@ -131,3 +134,16 @@ def forward( if self._return_input: hidden_states = torch.stack((fw_input, hidden_states), dim=0) return hidden_states + + +class TransformerLayer(BaseBlock): + name = "Transformer layer" + _mixer_module_name = "self_attn" + + def __init__( + self, config: TransformerConfig, tensor_space: TensorSpace, layer_index: int, return_input: bool = False + ): + super().__init__(config, tensor_space, layer_index, return_input) + + def _create_mixer(self): + self.self_attn = Attention(self._config, self._tensor_space, self._layer_index) diff --git a/fast_llm/models/auto.py b/fast_llm/models/auto.py index 905552a8..8f16aaea 100644 --- a/fast_llm/models/auto.py +++ b/fast_llm/models/auto.py @@ -2,6 +2,7 @@ from fast_llm.engine.training.config import TrainerConfig from fast_llm.models.custom.config import CustomModelConfig, CustomTrainerConfig from fast_llm.models.gpt.config import GPTModelConfig, GPTTrainerConfig +from fast_llm.models.ssm.config import HybridSSMModelConfig, HybridTrainerConfig from fast_llm.utils import Registry model_registry = Registry[str, FastLLMModelConfig]( @@ -11,6 +12,7 @@ for model in [ GPTModelConfig, CustomModelConfig, + HybridSSMModelConfig, ] }, ) @@ -22,6 +24,7 @@ for trainer in [ GPTTrainerConfig, CustomTrainerConfig, + HybridTrainerConfig, ] }, ) diff --git a/fast_llm/models/ssm/config.py b/fast_llm/models/ssm/config.py new file mode 100644 index 00000000..b38467d3 --- /dev/null +++ b/fast_llm/models/ssm/config.py @@ -0,0 +1,159 @@ +import logging 
+import math +import typing + +from fast_llm.config import Field, FieldHint, FieldUpdate, config_class +from fast_llm.data.data.gpt.config import GPTDataConfig +from fast_llm.engine.checkpoint.config import CheckpointFormat, CheckpointHandler +from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace +from fast_llm.engine.multi_stage.config import FastLLMModelConfig, PretrainedFastLLMModelConfig +from fast_llm.engine.training.config import TrainerConfig +from fast_llm.layers.language_model.config import LanguageModelArchitectureConfig, LanguageModelBaseConfig +from fast_llm.layers.ssm.config import SSMDimNames +from fast_llm.models.gpt.config import GPTBatchConfig +from fast_llm.utils import Assert + +if typing.TYPE_CHECKING: + from fast_llm.models.ssm.model import HybridSSMModel + +logger = logging.getLogger(__name__) + + +@config_class +class HybridSSMArchitectureConfig(LanguageModelArchitectureConfig): + _abstract = False + + hybrid_block_layout: list[str] = Field( + default_factory=lambda: ["m2"], + desc="Pattern of blocks to use in the model. 't' for Transformer, 'm' for Mamba1, 'm2' for Descrete Mamba2.", + hint=FieldHint.core, + ) + + +@config_class() +class HybridSSMBaseModelConfig(LanguageModelBaseConfig, HybridSSMArchitectureConfig): + architecture_class = HybridSSMArchitectureConfig + + use_megatron_initialization: bool = Field( + default=False, desc="Exactly match the initialization of a Megatron model.", hint=FieldHint.testing + ) # TODO: is this needed? + + def setup_tensor_space(self, tensor_space: TensorSpace) -> None: + """ + Setup the tensor space for the model. + Some of these can be setup directly in the layer config, but keeping them here for clarity. + """ + super().setup_tensor_space(tensor_space) + if not "m2" in self.hybrid_block_layout and not "m" in self.hybrid_block_layout: + raise ValueError( + "Block pattern must contain at least one 'm' or 'm2', use gpt model for transformer only architectures" + ) + + if self.ssm.dt_rank < 0: + mamba_dt_rank = math.ceil(self.transformer.hidden_size / 16) + else: + mamba_dt_rank = self.ssm.dt_rank + + d_inner = int(self.ssm.expansion_factor * self.transformer.hidden_size) + # Hidden dimension + tensor_space.add_tensor_dim(TensorDim(SSMDimNames.model_dim, self.transformer.hidden_size)) + # Mamba-specific dimensions + tensor_space.add_tensor_dim(TensorDim(SSMDimNames.inner_dim, d_inner)) + tensor_space.add_tensor_dim(TensorDim(SSMDimNames.state_dim, self.ssm.state_size)) + tensor_space.add_tensor_dim(TensorDim(SSMDimNames.dt_rank, mamba_dt_rank)) + tensor_space.add_tensor_dim(TensorDim(SSMDimNames.x_proj_dim, mamba_dt_rank + self.ssm.state_size * 2)) + tensor_space.add_tensor_dim(TensorDim(SSMDimNames.conv_kernel_size, self.ssm.conv_kernel_dimension)) + tensor_space.add_tensor_dim(TensorDim(SSMDimNames.inner_proj_mamba, d_inner * 2)) + + if "m2" in self.hybrid_block_layout: + # Mamba2 specific dimensions + # as per https://github.com/cartesia-ai/edge/blob/a0e121ebed3d2324c6d762b0e211a08d62583681/cartesia-pytorch/cartesia_pytorch/Llamba/mixers/discrete_mamba2.py#L66C3-L66C4 + headdim = d_inner // self.ssm.n_v_heads + Assert.eq(self.ssm.n_v_heads, d_inner // headdim) + Assert.eq(d_inner % headdim, 0) + Assert.eq(self.ssm.n_v_heads % self.ssm.n_qk_heads, 0) + + conv_dim = d_inner + 2 * self.ssm.n_qk_heads * self.ssm.state_size + inner_proj_dim = 2 * d_inner + 2 * self.ssm.n_qk_heads * self.ssm.state_size + self.ssm.n_v_heads + + tensor_space.add_tensor_dim(TensorDim(SSMDimNames.head_dim, headdim)) + 
tensor_space.add_tensor_dim(TensorDim(SSMDimNames.qk_heads, self.ssm.n_qk_heads)) + tensor_space.add_tensor_dim(TensorDim(SSMDimNames.v_heads, self.ssm.n_v_heads)) + tensor_space.add_tensor_dim(TensorDim(SSMDimNames.inner_proj_mamba2, inner_proj_dim)) + tensor_space.add_tensor_dim(TensorDim(SSMDimNames.conv_dim, conv_dim)) + + def _validate(self): + if len(self.hybrid_block_layout) != self.transformer.num_layers: + len_block_layout = len(self.hybrid_block_layout) + if self.transformer.num_layers % len_block_layout != 0: + raise ValueError( + f"hybrid_block_layout length {len(self.hybrid_block_layout)} does not match num_layers {self.transformer.num_layers}" + ) + num_repeats = int(self.transformer.num_layers // len_block_layout) + logger.warning( + f"hybrid_block_layout length {len(self.hybrid_block_layout)} does not match num_layers {self.transformer.num_layers}, will repeat {self.hybrid_block_layout} {num_repeats} times" + ) + self.hybrid_block_layout = self.hybrid_block_layout * num_repeats + + Assert.eq(len(self.hybrid_block_layout), self.transformer.num_layers) + Assert.custom( + lambda _: all(block_type in ["t", "m", "m2"] for block_type in self.hybrid_block_layout), + f"Invalid block type: {self.hybrid_block_layout}. Must be 't' or 'm' or 'm2'", + ) + + super()._validate() + + +class LLambaHuggingfaceCheckpointFormat(CheckpointFormat): + support_optimizer: typing.ClassVar[bool] = False + name: typing.ClassVar[str] = "llamba" + + @classmethod + def get_handler_class(cls) -> type[CheckpointHandler]: + from fast_llm.models.ssm.conversion import LLambaHuggingfaceCheckpointHandler + + return LLambaHuggingfaceCheckpointHandler + + +@config_class() +class HybridSSMModelConfig(FastLLMModelConfig): + _abstract = False + model_name: typing.ClassVar[str] = "hybrid_ssm" + base_model: HybridSSMBaseModelConfig = FieldUpdate(default_factory=HybridSSMBaseModelConfig) + checkpoint_formats = FastLLMModelConfig.checkpoint_formats + (LLambaHuggingfaceCheckpointFormat,) + + @classmethod + def get_model_class(cls) -> type["HybridSSMModel"]: + from fast_llm.models.ssm.model import HybridSSMModel + + return HybridSSMModel + + @classmethod + def get_huggingface_model_class(cls) -> type["HuggingfaceHybridSSMModelForCausalLM"]: + from fast_llm.models.ssm.huggingface import HuggingfaceHybridSSMModelForCausalLM + + return HuggingfaceHybridSSMModelForCausalLM + + def _validate(self): + logger.warning( + "HybridSSMModelConfig is being instantiated. This model is experimental and may not work as expected." 
+ ) + super()._validate() + + +@config_class() +class PretrainedHybridSSMModelConfig(PretrainedFastLLMModelConfig): + _abstract = False + model: HybridSSMModelConfig = FieldUpdate(default_factory=HybridSSMModelConfig) + + +@config_class() +class HybridTrainerConfig(PretrainedHybridSSMModelConfig, TrainerConfig): + data: GPTDataConfig = FieldUpdate(default_factory=GPTDataConfig) + batch: GPTBatchConfig = FieldUpdate(default_factory=GPTBatchConfig) + + @classmethod + def get_trainer_class(cls) -> type["SSMTrainer"]: + from fast_llm.models.ssm.trainer import SSMTrainer + + return SSMTrainer diff --git a/fast_llm/models/ssm/conversion.py b/fast_llm/models/ssm/conversion.py new file mode 100644 index 00000000..190b2ffa --- /dev/null +++ b/fast_llm/models/ssm/conversion.py @@ -0,0 +1,284 @@ +import json +import os +import pathlib +import typing + +from fast_llm.engine.checkpoint.config import CheckpointFormat +from fast_llm.engine.checkpoint.external import ( + ConstantImportParamConverter, + IgnoreImportWeightConverter, + MappedConfigParamConverter, + ParamConverter, + RenameParamConverter, + SplitWeightConverter, + WeightConverter, +) +from fast_llm.engine.checkpoint.huggingface import HuggingfaceStateDictCheckpointHandler +from fast_llm.engine.multi_stage.config import FastLLMModelConfig +from fast_llm.functional.config import ActivationType +from fast_llm.layers.common.config import NormalizationType +from fast_llm.models.gpt.conversion import MLPLayer2Converter +from fast_llm.models.ssm.config import HybridSSMModelConfig, LLambaHuggingfaceCheckpointFormat +from fast_llm.models.ssm.model import HybridSSMModel +from fast_llm.utils import Assert + +if typing.TYPE_CHECKING: + pass + + +class LLambaHuggingfaceCheckpointHandler(HuggingfaceStateDictCheckpointHandler): + _model: HybridSSMModel + _model_class: typing.ClassVar[FastLLMModelConfig] = HybridSSMModelConfig + format: typing.ClassVar[type[CheckpointFormat]] = LLambaHuggingfaceCheckpointFormat + + @classmethod + def _create_config_converters(cls) -> list[ParamConverter]: + """ + Create config converters for the model, see args under https://huggingface.co/cartesia-ai/Llamba-8B/blob/main/config.json + """ + return super()._create_config_converters() + [ + ConstantImportParamConverter(fast_llm_names=(("use_position_embeddings",),), fast_llm_value=False), + RenameParamConverter( + fast_llm_names=(("transformer", "num_layers"),), + export_names=(("n_layer",),), + ), + ConstantImportParamConverter(fast_llm_names=(("transformer", "gated"),), fast_llm_value=True), + # TODO: is there an equivalen of pad_vocab_size_multiple in FastLLM, does it matter? 
+ RenameParamConverter( + fast_llm_names=(("transformer", "normalization", "epsilon"),), export_names=(("norm_epsilon",),) + ), + RenameParamConverter( + fast_llm_names=(("ssm", "normalization", "epsilon"),), export_names=(("norm_epsilon",),) + ), + ConstantImportParamConverter( + fast_llm_names=(("transformer", "normalization", "type"),), fast_llm_value=NormalizationType.rms_norm + ), + RenameParamConverter( + fast_llm_names=(("vocab_size",),), + export_names=(("vocab_size",),), + ), + RenameParamConverter( + fast_llm_names=(("tie_word_embeddings",),), + export_names=(("tie_embeddings",),), + ), + RenameParamConverter( + fast_llm_names=(("transformer", "hidden_size"),), + export_names=(("d_model",),), + ), + RenameParamConverter( + fast_llm_names=(("transformer", "ffn_hidden_size"),), + export_names=( + ( + "mlp_cfg", + "intermediate_size", + ), + ), + ), + RenameParamConverter( + fast_llm_names=(("transformer", "add_linear_biases"),), + export_names=( + ( + "mlp_cfg", + "bias", + ), + ), + ), + MappedConfigParamConverter( + fast_llm_names=(("transformer", "activation_type"),), + export_names=( + ( + "mlp_cfg", + "act_fn", + ), + ), + fast_llm_value=ActivationType.from_hf_name, + export_value=lambda activation_type: activation_type.hf_name, + ), + RenameParamConverter( + fast_llm_names=(("ssm", "state_size"),), + export_names=( + ( + "ssm_cfg", + "d_state", + ), + ), + ), + RenameParamConverter( + fast_llm_names=(("ssm", "n_v_heads"),), + export_names=( + ( + "ssm_cfg", + "n_v_heads", + ), + ), + ), + RenameParamConverter( + fast_llm_names=(("ssm", "n_qk_heads"),), + export_names=( + ( + "ssm_cfg", + "n_qk_heads", + ), + ), + ), + RenameParamConverter( + fast_llm_names=(("ssm", "expansion_factor"),), + export_names=( + ( + "ssm_cfg", + "expand", + ), + ), + ), + RenameParamConverter( + fast_llm_names=(("ssm", "chunk_size"),), + export_names=( + ( + "ssm_cfg", + "chunk_size", + ), + ), + ), + RenameParamConverter( + fast_llm_names=(("ssm", "add_bias_linear"),), + export_names=( + ( + "ssm_cfg", + "bias", + ), + ), + ), + MappedConfigParamConverter( + fast_llm_names=(("ssm", "activation_type"),), + export_names=( + ( + "ssm_cfg", + "activation", + ), + ), + fast_llm_value=ActivationType.from_hf_name, + export_value=lambda activation_type: activation_type.hf_name, + ), + ] + + def _create_weight_converters(self) -> list[WeightConverter]: + converters = [] + num_layers = self._model.config.base_model.transformer.num_layers + norm_bias: bool = False + ssm_bias: bool = self._model.config.base_model.ssm.add_bias_linear + + # Embedding and output + if self._model.config.base_model.tie_word_embeddings: + converters.append(WeightConverter("layers.0.word_embeddings_weight", "backbone.embedding.weight")) + converters.append(IgnoreImportWeightConverter((), "lm_head.weight")) + else: + converters.append(WeightConverter("layers.0.word_embeddings_weight", "backbone.embedding.weight")) + converters.append(WeightConverter(f"layers.{num_layers + 1}.output_weights", "lm_head.weight")) + + # Final norm + converters += self._get_weight_and_bias_converters( + f"layers.{num_layers + 1}.final_norm", "backbone.final_layernorm", norm_bias + ) + + for i in range(num_layers): + # SSM + converters += self._get_weight_and_bias_converters( + f"layers.{i+1}.mixer.in_proj", f"backbone.layers.{i}.mixer.in_proj", ssm_bias + ) + converters += self._get_weight_and_bias_converters( + f"layers.{i+1}.mixer.out_proj", f"backbone.layers.{i}.mixer.out_proj", ssm_bias + ) + converters.append( + 
WeightConverter(f"layers.{i+1}.mixer.D", f"backbone.layers.{i}.mixer.D", self._model.config.base_model) + ) + converters.append( + WeightConverter( + f"layers.{i+1}.mixer.z_bias", f"backbone.layers.{i}.mixer.z_bias", self._model.config.base_model + ) + ) + converters.append( + WeightConverter( + f"layers.{i+1}.mixer.conv1d_weight", + f"backbone.layers.{i}.mixer.conv1d.weight", + self._model.config.base_model, + ) + ) + converters.append( + WeightConverter( + f"layers.{i+1}.mixer.conv1d_bias", + f"backbone.layers.{i}.mixer.conv1d.bias", + self._model.config.base_model, + ) + ) + + # Norm + converters += self._get_weight_and_bias_converters( + f"layers.{i+1}.norm_1", f"backbone.layers.{i}.input_layernorm", norm_bias + ) + converters += self._get_weight_and_bias_converters( + f"layers.{i+1}.norm_2", f"backbone.layers.{i}.post_attention_layernorm", norm_bias + ) + + # MLP + converters += self._get_mlp_converters(f"layers.{i+1}", f"backbone.layers.{i}") + + return converters + + def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]: + linear_bias: bool = self._model.config.base_model.transformer.add_linear_biases + return [ + *self._get_weight_and_bias_converters( + f"{fast_llm_prefix}.mlp.layer_1", + (f"{hf_prefix}.mlp.gate_proj", f"{hf_prefix}.mlp.up_proj"), + linear_bias, + SplitWeightConverter, + ), + *self._get_weight_and_bias_converters( + f"{fast_llm_prefix}.mlp.layer_2", + f"{hf_prefix}.mlp.down_proj", + linear_bias, + MLPLayer2Converter, + ), + ] + + def _get_weight_and_bias_converters( + self, + fast_llm_prefix: str | tuple[str, ...], + hf_prefix: str | tuple[str, ...], + use_bias: bool, + cls=WeightConverter, + ) -> list[WeightConverter]: + if isinstance(fast_llm_prefix, str): + fast_llm_prefix = (fast_llm_prefix,) + if isinstance(hf_prefix, str): + hf_prefix = (hf_prefix,) + converters = [ + cls( + tuple(f"{prefix}.weight" for prefix in fast_llm_prefix), + tuple(f"{prefix}.weight" for prefix in hf_prefix), + self._model.config.base_model, + ) + ] + if use_bias: + converters.append( + cls( + tuple(f"{prefix}.bias" for prefix in fast_llm_prefix), + tuple(f"{prefix}.bias" for prefix in hf_prefix), + self._model.config.base_model, + ) + ) + return converters + + @classmethod + def _load_config(cls, directory: pathlib.Path | str) -> dict: + if not os.path.exists(directory / "config.json"): + raise FileNotFoundError(f"config.json not found in {directory}") + with open(directory / "config.json") as f: + config = json.load(f) + Assert.eq(config["model_type"], cls.get_huggingface_model_type()) + return config + + @classmethod + def _save_config(cls, directory: pathlib.Path | str, config: dict[str, typing.Any]) -> None: + with open(directory / "config.json", "w") as f: + json.dump(config, f) diff --git a/fast_llm/models/ssm/huggingface.py b/fast_llm/models/ssm/huggingface.py new file mode 100644 index 00000000..77cd346f --- /dev/null +++ b/fast_llm/models/ssm/huggingface.py @@ -0,0 +1,21 @@ +import logging + +from fast_llm.engine.huggingface.config import HuggingfaceModelConfig +from fast_llm.models.gpt.huggingface import HuggingfaceGPTModelForCausalLM +from fast_llm.models.ssm.config import HybridSSMModelConfig +from fast_llm.models.ssm.model import HybridSSMModel + +logger = logging.getLogger(__name__) + + +class HuggingfaceSSMModelConfig(HuggingfaceModelConfig): + model_type = "fast_llm_ssm" + model_config_class = HybridSSMModelConfig + fast_llm_config: HybridSSMModelConfig + + +class 
HuggingfaceHybridSSMModelForCausalLM(HuggingfaceGPTModelForCausalLM):
+    config_class = HuggingfaceSSMModelConfig
+    config: HuggingfaceSSMModelConfig
+    model_class = HybridSSMModel
+    _fast_llm_model: HybridSSMModel
diff --git a/fast_llm/models/ssm/model.py b/fast_llm/models/ssm/model.py
new file mode 100644
index 00000000..33d2c185
--- /dev/null
+++ b/fast_llm/models/ssm/model.py
@@ -0,0 +1,91 @@
+import logging
+import typing
+
+from fast_llm.engine.base_model.base_model import Layer
+from fast_llm.engine.distributed.config import DistributedConfig
+from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel
+from fast_llm.layers.language_model.embedding import LanguageModelEmbedding
+from fast_llm.layers.language_model.head import LanguageModelHead
+from fast_llm.layers.ssm.discrete_mamba2 import DiscreteMamba2
+from fast_llm.layers.ssm.llamba_block import LlambaBlock
+from fast_llm.layers.ssm.mamba_layer import MambaLayer
+from fast_llm.layers.transformer.transformer import TransformerLayer
+from fast_llm.models.gpt.model import GPTBaseModel
+from fast_llm.models.ssm.config import HybridSSMBaseModelConfig, HybridSSMModelConfig
+
+logger = logging.getLogger(__name__)
+
+
+class HybridSSMBaseModel[ConfigType: HybridSSMBaseModelConfig](GPTBaseModel[ConfigType]):
+    """
+    A hybrid model that interleaves Transformer and Mamba blocks.
+    Right now only LlambaBlock is supported.
+    As for the mixer, the transformer block uses MHA; the LlambaBlock supports Mamba1 and discrete Mamba2.
+    """
+
+    config_class: typing.ClassVar[type[HybridSSMBaseModelConfig]] = HybridSSMBaseModelConfig
+    _is_setup: bool = False
+
+    def __init__(
+        self,
+        config: HybridSSMBaseModelConfig,
+        distributed_config: DistributedConfig,
+    ):
+        self.SSM_BLOCK_CLS = LlambaBlock  # TODO: extend to other block types if needed
+        super().__init__(config, distributed_config)
+
+    def get_layers(self) -> list[Layer]:
+        """
+        Create a list of layers for the model, interleaving Transformer and Mamba blocks
+        according to the block pattern.
+        """
+        layers = [LanguageModelEmbedding(self._config, self._tensor_space)]
+
+        # Create blocks according to pattern
+        for i, block_type in enumerate(self._config.hybrid_block_layout):
+            if block_type == "t":
+                # Transformer block
+                layers.append(
+                    TransformerLayer(
+                        self._config.transformer,
+                        self._tensor_space,
+                        layer_index=i + 1,
+                    )
+                )
+            elif block_type == "m2":
+                mamba_block = self.SSM_BLOCK_CLS(
+                    config_transformer=self._config.transformer,
+                    config_ssm=self._config.ssm,
+                    mixer_cls=DiscreteMamba2,
+                    layer_index=i + 1,
+                    tensor_space=self._tensor_space,
+                )
+                layers.append(mamba_block)
+
+            elif block_type == "m":
+                # Create Mamba block
+                mamba_block = self.SSM_BLOCK_CLS(
+                    config_transformer=self._config.transformer,
+                    config_ssm=self._config.ssm,
+                    mixer_cls=MambaLayer,
+                    layer_index=i + 1,
+                    tensor_space=self._tensor_space,
+                )
+                layers.append(mamba_block)
+
+            else:
+                raise ValueError(f"Invalid block type: {block_type}. Must be 't' or 'm' or 'm2'")
+
+        # Add the language model head
+        layers.append(LanguageModelHead(self._config, self._tensor_space, prediction_distance=0))
+
+        return layers
+
+
+class HybridSSMModel[ConfigType: HybridSSMModelConfig](FastLLMModel[ConfigType]):
+    """
+    A hybrid model that combines Transformer and SSM blocks.
+ """ + + config_class: typing.ClassVar[type[HybridSSMModelConfig]] = HybridSSMModelConfig + base_model_class: typing.ClassVar[type[HybridSSMBaseModel]] = HybridSSMBaseModel diff --git a/fast_llm/models/ssm/trainer.py b/fast_llm/models/ssm/trainer.py new file mode 100644 index 00000000..c0e5be26 --- /dev/null +++ b/fast_llm/models/ssm/trainer.py @@ -0,0 +1,10 @@ +import typing + +from fast_llm.models.gpt.trainer import GPTTrainer +from fast_llm.models.ssm.config import HybridTrainerConfig +from fast_llm.models.ssm.model import HybridSSMModel + + +class SSMTrainer[ConfigType: HybridTrainerConfig](GPTTrainer[ConfigType]): + config_class: typing.ClassVar[type[HybridTrainerConfig]] = HybridTrainerConfig + model_class: typing.ClassVar[type[HybridSSMModel]] = HybridSSMModel diff --git a/fast_llm/tensor.py b/fast_llm/tensor.py index f59927b6..84930756 100644 --- a/fast_llm/tensor.py +++ b/fast_llm/tensor.py @@ -1,3 +1,4 @@ +import math import typing import torch @@ -325,6 +326,10 @@ def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) return init_ +def kaiming_init_(d_in): + return init_normal_(0.0, math.sqrt(2.0 / d_in)) + + def init_uniform_( low=0.0, high=1.0, min_val=None, max_val=None ) -> typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], torch.Tensor]: diff --git a/fast_llm/tools/cli.py b/fast_llm/tools/cli.py index e9df18ed..0cc02f42 100644 --- a/fast_llm/tools/cli.py +++ b/fast_llm/tools/cli.py @@ -28,9 +28,13 @@ def fast_llm(args=None): raise RuntimeError("Unknown subcommand") Runnable.parse_and_run(unparsed) except ValidationError: + if sys.gettrace(): + raise log_main_rank(traceback.format_exc(), log_fn=logger.error) sys.exit(1) except Exception: # noqa + if sys.gettrace(): + raise logger.critical(traceback.format_exc()) sys.exit(1) diff --git a/setup.cfg b/setup.cfg index c21f02a7..9b944b27 100644 --- a/setup.cfg +++ b/setup.cfg @@ -26,6 +26,8 @@ CORE = safetensors>=0.4.4 # Update the base image (version fixed to ensure there is a wheel for the base image), may need --no-build-isolation flash-attn==2.7.2.post1 + mamba_ssm[causal-conv1d]==2.2.4 + # Required for some optional features and tools. 
OPTIONAL = diff --git a/tests/common.py b/tests/common.py index dfdee964..5bd9563f 100644 --- a/tests/common.py +++ b/tests/common.py @@ -21,6 +21,7 @@ Qwen2GPTHuggingfaceCheckpointFormat, Starcoder2GPTHuggingfaceCheckpointFormat, ) +from fast_llm.models.ssm.config import LLambaHuggingfaceCheckpointFormat from fast_llm.tools.train import CliTrainingConfig from tests.compare_tensor_logs import CompareConfig, compare_tensor_logs @@ -201,6 +202,10 @@ ] CONFIG_LLAMA_MTP_COMMON = CONFIG_LLAMA_MTP_FAST_LLM + ["model.distributed.training_dtype=bf16"] +CONFIG_LLAMBA_FAST_LLM = CONFIG_LLAMA_FAST_LLM + ["model.base_model.hybrid_block_layout==['t','m']"] +CONFIG_LLAMBA_MEGATRON = CONFIG_LLAMA_MEGATRON + [] +CONFIG_LLAMBA_COMMON = CONFIG_LLAMBA_FAST_LLM + _CONFIGS = { "gpt2": ("gpt", CONFIG_GPT2_FAST_LLM, CONFIG_GPT2_MEGATRON, CONFIG_GPT2_COMMON, None), "sc1": ("gpt", CONFIG_SC1_FAST_LLM, CONFIG_SC1_MEGATRON, CONFIG_SC1_COMMON, None), @@ -253,6 +258,13 @@ CONFIG_MIXTRAL_COMMON, MixtralGPTHuggingfaceCheckpointFormat, ), + "llamba": ( + "hybrid_ssm", + CONFIG_LLAMBA_FAST_LLM, + CONFIG_LLAMBA_MEGATRON, + CONFIG_LLAMBA_COMMON, + LLambaHuggingfaceCheckpointFormat, + ), "mixtral-yarn": ( "gpt", CONFIG_MIXTRAL_YARN_FAST_LLM, @@ -269,7 +281,6 @@ ), } - TEST_MODEL_TYPE, CONFIG_FAST_LLM, CONFIG_GPT2, CONFIG_COMMON, HUGGINGFACE_CHECKPOINT_FORMAT = _CONFIGS[TEST_MODEL] diff --git a/tests/test_config.py b/tests/test_config.py index 79c6738d..b954b61c 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -145,6 +145,7 @@ def test_pretrained_config(load_config: ModelConfigType): "activation_type": "silu", # Implicit default, non-default value "head_groups": 4, }, + "ssm": {"dt_rank": -1, "activation_type": "silu"}, "tie_word_embeddings": False, }, "multi_stage": {"zero_stage": 3}, @@ -165,6 +166,7 @@ def test_pretrained_config(load_config: ModelConfigType): "hidden_size": 512, # Override, affects derived value (kv channels) "head_groups": 1, # Override to default }, + "ssm": {"dt_rank": 10, "activation_type": "silu"}, "vocab_size": 1000, } pretrained_config = PretrainedGPTModelConfig.from_dict( @@ -195,6 +197,7 @@ def test_pretrained_config(load_config: ModelConfigType): "activation_type": "silu", "head_groups": 1, }, + "ssm": {"dt_rank": 10, "activation_type": "silu"}, "tie_word_embeddings": False, "vocab_size": 1000, } diff --git a/tests/test_ssms.py b/tests/test_ssms.py new file mode 100644 index 00000000..5863f903 --- /dev/null +++ b/tests/test_ssms.py @@ -0,0 +1,340 @@ +import pathlib +from functools import partial + +import pytest +import torch + +from fast_llm.config import NoAutoValidate +from fast_llm.engine.checkpoint.config import CheckpointLoadConfig +from fast_llm.engine.config_utils.tensor_space import TensorSpace +from fast_llm.engine.distributed.config import DistributedConfig, PhaseType +from fast_llm.engine.distributed.distributed import Distributed +from fast_llm.engine.schedule.config import ScheduleConfig +from fast_llm.engine.schedule.runner import ScheduleRunner +from fast_llm.engine.schedule.schedule import Schedule +from fast_llm.layers.language_model.config import LanguageModelKwargs, LanguageModelLossNames +from fast_llm.layers.transformer.config import TransformerConfig, TransformerKwargs +from fast_llm.models.gpt.config import GPTBatchConfig, LlamaGPTHuggingfaceCheckpointFormat +from fast_llm.models.ssm.config import LLambaHuggingfaceCheckpointFormat + +try: + from fast_llm.layers.ssm.config import SSMConfig + from fast_llm.layers.ssm.discrete_mamba2 import DiscreteMamba2 + 
from fast_llm.layers.ssm.llamba_block import LlambaBlock + from fast_llm.layers.ssm.mamba_layer import MambaLayer + from fast_llm.models.ssm.model import HybridSSMBaseModel, HybridSSMBaseModelConfig, HybridSSMModel +except ImportError: + MambaLayer, LlambaBlock, HybridSSMBaseModel, HybridSSMBaseModelConfig, DiscreteMamba2 = ( + None, + None, + None, + None, + None, + ) + # Mamba not isntalled, skipping tests + +try: + from cartesia_pytorch.Llamba.llamba import LlambaLMHeadModel as LMHeadModel +except ImportError: + LMHeadModel = None + +run_test = MambaLayer is not None and torch.cuda.is_available() + + +def materialize_meta_tensors(model, tensor_space): + # Materialize parameters that are on meta device + for name, param in model.named_parameters(): + if param.device.type == "meta": + # Check if the parameter is a custom tensor type + if hasattr(param, "tensor_name") and hasattr(param, "init_parameter"): + param_data = param.new_empty(param.shape, device="cuda") + # Initialize param_data + param.init_parameter(param_data, tensor_space.distributed) + # Replace the parameter in the module + module_path, param_name = name.rsplit(".", 1) if "." in name else (None, name) + module = model + if module_path is not None: + for part in module_path.split("."): + module = getattr(module, part) + param = torch.nn.Parameter(param_data, requires_grad=param.requires_grad) + # TODO: add param_grad_is_zero etc., grad_buffer, etc., see test_mlp_recomputation + param.grad = None + param.grad_buffer = torch.empty_like(param) + param.param_grad_is_zero = True + module._parameters[param_name] = param + return model + + +@pytest.fixture +def distributed_config(): + return DistributedConfig( + tensor_parallel=1, + pipeline_parallel=1, + sequence_data_parallel=1, + local_world_size=1, + world_size=1, + ) + + +@pytest.fixture +def distributed(distributed_config): + return Distributed(config=distributed_config) + + +def get_hybrid_config(hybrid_block_layout=["t", "m", "t", "m"]): + config = HybridSSMBaseModelConfig( + transformer=TransformerConfig(num_layers=len(hybrid_block_layout)), + ssm=SSMConfig(), + hybrid_block_layout=hybrid_block_layout, + init_method_std_embed=0.02, + init_method_min_embed=-0.02, + init_method_max_embed=0.02, + use_position_embeddings=True, + tie_word_embeddings=False, + ) + return config + + +def get_hf_llamba_out(input_ids, path, format): + if format == LLambaHuggingfaceCheckpointFormat: + from cartesia_pytorch.Llamba.llamba import LlambaLMHeadModel as LMHeadModel + elif format == LlamaGPTHuggingfaceCheckpointFormat: + from transformers import LlamaForCausalLM as LMHeadModel + else: + raise ValueError(f"Invalid format: {format}") + + model = LMHeadModel.from_pretrained(path, strict=True).to("cuda") + parameter_sum = sum(p.detach().cpu().numpy().sum() for p in model.parameters()) + print(f"Parameter sum: {parameter_sum}") + output = model(input_ids) + del model + torch.cuda.empty_cache() + return output, parameter_sum + + +@pytest.mark.slow +@pytest.mark.skipif( + not run_test or LMHeadModel is None, + reason=f"Skipping because one of the following: cartesia_pytorch.Llamba not installed or no CUDA available or Mamba not installed", +) +def test_load_from_llamba_checkpoint(distributed_config): + """ + Test to check whether the of Fast-LLM and Huggingface checkpoint loading for Llamba-1B produce the same results. 
+ """ + vocab_size = 128256 # from https://huggingface.co/cartesia-ai/Llamba-1B/blob/main/config.json + batch_size = 2 + seq_length = 32 + + path = pathlib.Path("/mnt/checkpoints_fml/pretrained_models/Llamba-1B") + format = LLambaHuggingfaceCheckpointFormat + + x = torch.randint(0, vocab_size, (batch_size, seq_length), device="cuda") + hf_logits, parameter_sum_hf = get_hf_llamba_out(x, path, format) + hf_logits = hf_logits["logits"].cpu() + + # Create checkpoint load config + checkpoint_config = CheckpointLoadConfig(path=path, format=format, model_weights=True, optimizer_state=False) + # Initialize model + model = HybridSSMModel.from_pretrained(checkpoint_config) + param_sum = 0 + for stage in model.stages: + for fsdp in stage.fsdps: + if hasattr(fsdp, "_weight_shard"): + param_sum += torch.sum(fsdp._weight_shard).item() + assert torch.abs(torch.tensor(param_sum) - parameter_sum_hf) < 1e-1 + + # model = GPTModel.from_pretrained(checkpoint_config) + assert model.config.base_model.vocab_size == vocab_size + schedule_config = ScheduleConfig() + with NoAutoValidate(): + batch_config = GPTBatchConfig(micro_batch_size=batch_size, sequence_length=seq_length) + batch_config.setup(distributed_config) + batch_config.validate() + schedule_runner = ScheduleRunner( + config=schedule_config, + multi_stage=model, + distributed_config=model.distributed.config, + ) + schedule = Schedule( + multi_stage=model, + batch_config=batch_config, + schedule_config=schedule_config, + distributed_config=model.distributed.config, + phase=PhaseType.inference, + ) + schedule_runner.setup(model.distributed, optimizer=None) + + common_kwargs = { + TransformerKwargs.sequence_first: True, + TransformerKwargs.grad_output: False, + } + input_data = [(x, common_kwargs)] + + losses, success, metrics = schedule_runner.run_step( + iter([input_data]), schedule, iteration=0, return_metrics=True, preprocessed=True + ) + + logits = input_data[0][1]["logits"].cpu() + assert torch.allclose(logits, hf_logits, atol=1e-2) + + +@pytest.mark.skipif(not run_test, reason="No CUDA available or Mamba not installed") +@pytest.mark.parametrize( + "hybrid_block_layout,LAYER_CLS", + [ + (["m", "t"], MambaLayer), + (["m2", "t"], DiscreteMamba2), + ], + ids=["mamba", "descrete_mamba2"], +) +def test_mamba_layer(distributed_config, distributed, hybrid_block_layout, LAYER_CLS): + hybrid_config = get_hybrid_config(hybrid_block_layout=hybrid_block_layout) + tensor_space = TensorSpace(distributed_config=distributed_config) + hybrid_config.setup_tensor_space(tensor_space) + layer = LAYER_CLS(hybrid_config.ssm, layer_idx=0, tensor_space=tensor_space) + tensor_space.setup(distributed) + materialize_meta_tensors(layer, tensor_space) + layer.to(distributed.device) + + batch_size = 2 + seq_length = 32 + hidden_size = hybrid_config.transformer.hidden_size + x = torch.randn(batch_size, seq_length, hidden_size, device=distributed.device) + + # Run forward pass + output, _ = layer(x, {}) + + loss = output.sum() + loss.backward() + # Basic shape checkss + assert output.shape == x.shape + assert not torch.isnan(output).any() + assert not torch.isinf(output).any() + + +@pytest.mark.skipif(not run_test, reason="No CUDA available or Mamba not installed") +def test_mamba_block(distributed_config, distributed): + hybrid_config = get_hybrid_config(hybrid_block_layout=["m", "t"]) + tensor_space = TensorSpace(distributed_config=distributed_config) + tensor_space.setup(distributed) + hybrid_config.setup_tensor_space(tensor_space) + layer_idx = 0 + + mixer_cls = 
partial(MambaLayer, layer_idx=layer_idx) + block = LlambaBlock( + hybrid_config.transformer, + hybrid_config.ssm, + mixer_cls=mixer_cls, + tensor_space=tensor_space, + layer_index=layer_idx, + ) + + materialize_meta_tensors(block, tensor_space) + block.to("cuda") + + batch_size = 2 + seq_length = 32 + hidden_size = hybrid_config.transformer.hidden_size + x = torch.randn(batch_size, seq_length, hidden_size, device=distributed.device) + + hidden_states = block(x, {}) + loss = hidden_states.sum() + loss.backward() + + assert hidden_states.shape == x.shape + assert not torch.isnan(hidden_states).any() + assert not torch.isinf(hidden_states).any() + + +@pytest.mark.skipif(not run_test, reason="No CUDA available or Mamba not installed") +@pytest.mark.parametrize( + "hybrid_block_layout", + [ + (["m", "t"]), + (["m2", "t"]), + ], + ids=["mamba", "descrete_mamba2"], +) +def test_hybrid_model_train_with_fast_mode(distributed_config, hybrid_block_layout): + hybrid_config = get_hybrid_config(hybrid_block_layout=hybrid_block_layout) + model = HybridSSMBaseModel(hybrid_config, distributed_config) + distributed = Distributed(distributed_config) + model.setup(distributed) + tensor_space = model._tensor_space + materialize_meta_tensors(model, tensor_space) + model.to("cuda") + + batch_size = 2 + seq_length = 32 + x = torch.randint(0, 49152, (batch_size, seq_length), device="cuda") + position_ids = torch.arange(seq_length, device="cuda", dtype=torch.int64) + attention_mask = torch.ones((1, 1, 1, 1), device="cuda", dtype=torch.bool) # will be broadcasted to right shape + labels = torch.randint(0, 49152, (batch_size, seq_length), device="cuda") + losses = {LanguageModelLossNames.language_model_loss: []} + output = model( + x, + { + "position_ids": position_ids, + TransformerKwargs.sequence_first: True, + TransformerKwargs.attention_mask: attention_mask, + TransformerKwargs.attention_mask_value: -100, + TransformerKwargs.grad_output: True, + LanguageModelKwargs.labels: labels, + }, + losses=losses, + ) + loss = sum(losses[LanguageModelLossNames.language_model_loss]) + loss.backward() + + +# TODO: added this when inference enabled +# No inference for now +# @dataclass +# class InferenceParams: +# max_seqlen: int +# max_batch_size: int +# sequence_len_offset: int = 0 +# key_value_memory_dict: dict = None + +# def __post_init__(self): +# if self.key_value_memory_dict is None: +# self.key_value_memory_dict = {} + + +# @pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA available") +# def test_hybrid_model_inference(distributed_config, hybrid_config): +# hybrid_config.ssm.use_fast_path = False +# model = HybridSSMBaseModel(hybrid_config, distributed_config) +# distributed = Distributed(distributed_config) +# model.setup(distributed) +# tensor_space = model._tensor_space +# materialize_meta_tensors(model, tensor_space) +# model.to("cuda") +# # print(model) + +# batch_size = 2 +# seq_length = 32 +# x = torch.randint(0, 49152, (batch_size, seq_length), device="cuda") +# position_ids = torch.arange(seq_length, device="cuda", dtype=torch.int64) +# attention_mask = torch.ones((1, 1, 1, 1), device="cuda", dtype=torch.bool) # will be broadcasted to right shape +# labels = torch.randint(0, 49152, (batch_size, seq_length), device="cuda") +# max_new_tokens = 10 + +# inference_params = InferenceParams( +# max_seqlen=len(x[0]) + max_new_tokens, max_batch_size=x.shape[0], sequence_len_offset=0 +# ) +# losses = {LanguageModelLossNames.language_model_loss: []} + +# output = model( +# x, +# { +# "position_ids": 
position_ids, +# TransformerKwargs.sequence_first: True, +# TransformerKwargs.attention_mask: attention_mask, +# TransformerKwargs.attention_mask_value: -100, +# TransformerKwargs.grad_output: True, +# LanguageModelKwargs.labels: labels, +# "inference_params": inference_params, +# }, +# losses=losses, +# ) diff --git a/tests/test_triton_kernels.py b/tests/test_triton_kernels.py index 1ace81d7..11804631 100644 --- a/tests/test_triton_kernels.py +++ b/tests/test_triton_kernels.py @@ -169,7 +169,14 @@ def test_triton_normalization(has_bias, zero_centered): @requires_cuda @pytest.mark.parametrize("gated", [True, False]) @pytest.mark.parametrize( - "activation_type", [ActivationType.gelu, ActivationType.silu, ActivationType.relu, ActivationType.squared_relu] + "activation_type", + [ + ActivationType.gelu, + ActivationType.silu, + ActivationType.relu, + ActivationType.squared_relu, + ActivationType.identity, + ], ) @pytest.mark.parametrize("recompute", [True, False]) def test_triton_mlp_activation(gated, activation_type, recompute):
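
For context, a minimal sketch of how the new hybrid configuration is expected to be used, mirroring get_hybrid_config() in tests/test_ssms.py above. The layer count and initialization values are illustrative only; a layout containing only "t" blocks is rejected by setup_tensor_space (the GPT model should be used for transformer-only architectures).

from fast_llm.layers.ssm.config import SSMConfig
from fast_llm.layers.transformer.config import TransformerConfig
from fast_llm.models.ssm.config import HybridSSMBaseModelConfig

# Four layers, alternating a transformer block ("t") and a discrete Mamba2 block ("m2").
# A shorter pattern whose length divides num_layers is also accepted: validation repeats
# ["t", "m2"] twice to obtain ["t", "m2", "t", "m2"] and logs a warning.
config = HybridSSMBaseModelConfig(
    transformer=TransformerConfig(num_layers=4),
    ssm=SSMConfig(),
    hybrid_block_layout=["t", "m2", "t", "m2"],
    init_method_std_embed=0.02,
    init_method_min_embed=-0.02,
    init_method_max_embed=0.02,
    use_position_embeddings=True,
    tie_word_embeddings=False,
)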