From b50ff5993a5d8b2a3d8c7558e81684f8803b044a Mon Sep 17 00:00:00 2001 From: Anton Vlasjuk <73884904+vasqu@users.noreply.github.com> Date: Thu, 19 Sep 2024 18:41:17 +0200 Subject: [PATCH] [`Mamba2`] Move dt calculations to kernel (#33520) * use kernel for dt calculations * add small test * [run-slow] mamba2 --- .../models/mamba2/modeling_mamba2.py | 3 ++- tests/models/mamba2/test_modeling_mamba2.py | 26 ++++++++++++++++++- 2 files changed, 27 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/mamba2/modeling_mamba2.py b/src/transformers/models/mamba2/modeling_mamba2.py index 69390ea9ad2b8d..7b414ff9570d9a 100644 --- a/src/transformers/models/mamba2/modeling_mamba2.py +++ b/src/transformers/models/mamba2/modeling_mamba2.py @@ -358,7 +358,6 @@ def cuda_kernels_forward( dim=-1, ) - time_step = nn.functional.softplus(time_step + self.dt_bias) # 1D Convolution if causal_conv1d_fn is None or self.activation not in ["silu", "swish"]: hidden_states_B_C = self.act( @@ -391,6 +390,8 @@ def cuda_kernels_forward( z=None, seq_idx=None, return_final_states=True, + dt_bias=self.dt_bias, + dt_softplus=True, **dt_limit_kwargs, ) if ssm_state is not None and cache_params is not None: diff --git a/tests/models/mamba2/test_modeling_mamba2.py b/tests/models/mamba2/test_modeling_mamba2.py index a1e2138d4d6d78..55c18abe6b96af 100644 --- a/tests/models/mamba2/test_modeling_mamba2.py +++ b/tests/models/mamba2/test_modeling_mamba2.py @@ -35,7 +35,7 @@ Mamba2ForCausalLM, Mamba2Model, ) - from transformers.models.mamba2.modeling_mamba2 import Mamba2Cache + from transformers.models.mamba2.modeling_mamba2 import Mamba2Cache, Mamba2Mixer from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_0 else: is_torch_greater_or_equal_than_2_0 = False @@ -378,3 +378,27 @@ def test_batched_equivalence_without_cache(self): individual_gen = model.generate(**inputs, max_new_tokens=30, use_cache=True) individual_output = tokenizer.batch_decode(individual_gen, skip_special_tokens=True)[0] self.assertEqual(individual_output[:100], batched_output[index_gen][:100]) + + @slow + @require_torch_gpu + def test_mamba2_mixer_train_vs_eval_equivalence(self): + # Based on https://github.com/sustcsonglin/flash-linear-attention/issues/63 + # Credit to zhixuan-lin + + B, T, D = 4, 512, 768 + dtype = torch.bfloat16 + config = Mamba2Config(num_heads=24, head_dim=64, hidden_size=768, expand=2, n_groups=1) + + torch.manual_seed(42) + with torch.amp.autocast(device_type="cuda", dtype=dtype): + with torch.no_grad(): + mixer = Mamba2Mixer(config, layer_idx=0).to("cuda") + hidden_states = torch.rand(size=(B, T, D), dtype=dtype, device="cuda") + + mixer.train() + out_train = mixer(hidden_states) + + mixer.eval() + out_eval = mixer(hidden_states) + + self.assertTrue(torch.allclose(out_train, out_eval, atol=1e-3))