[Mamba2] Move dt calculations to kernel (#33520)
* use kernel for dt calculations

* add small test

* [run-slow] mamba2
vasqu authored Sep 19, 2024
1 parent 162056a commit b50ff59
Showing 2 changed files with 27 additions and 2 deletions.
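Note on the change: `cuda_kernels_forward` previously applied `softplus(time_step + self.dt_bias)` eagerly in PyTorch before invoking the fused scan. After this commit the raw `time_step` is passed through, and the kernel performs the same add-then-softplus itself via the new `dt_bias=self.dt_bias` and `dt_softplus=True` arguments. A minimal sketch of the transform being relocated (the shapes are illustrative assumptions, not values from the model config):

    import torch
    import torch.nn.functional as F

    batch, seq_len, num_heads = 2, 16, 24
    time_step = torch.randn(batch, seq_len, num_heads)  # raw dt from the input projection
    dt_bias = torch.randn(num_heads)                    # learned per-head bias

    # Before this commit: computed eagerly in modeling_mamba2.py.
    dt = F.softplus(time_step + dt_bias)

    # After this commit: the kernel receives the raw time_step and applies the
    # same add-then-softplus internally when called with dt_bias=... and
    # dt_softplus=True.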
src/transformers/models/mamba2/modeling_mamba2.py (2 additions, 1 deletion)
@@ -358,7 +358,6 @@ def cuda_kernels_forward(
                 dim=-1,
             )
 
-            time_step = nn.functional.softplus(time_step + self.dt_bias)
             # 1D Convolution
             if causal_conv1d_fn is None or self.activation not in ["silu", "swish"]:
                 hidden_states_B_C = self.act(
@@ -391,6 +390,8 @@ def cuda_kernels_forward(
                 z=None,
                 seq_idx=None,
                 return_final_states=True,
+                dt_bias=self.dt_bias,
+                dt_softplus=True,
                 **dt_limit_kwargs,
             )
             if ssm_state is not None and cache_params is not None:
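For context, a hedged, self-contained sketch of how the updated call site drives the kernel (requires a CUDA GPU and the `mamba_ssm` package; the tensor shapes are illustrative assumptions, and only the `dt_bias`/`dt_softplus` arguments come from the diff above):

    import torch
    from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined

    batch, seq_len, num_heads, head_dim, state_dim, n_groups = 2, 64, 24, 64, 128, 1
    device, dtype = "cuda", torch.bfloat16

    hidden_states = torch.randn(batch, seq_len, num_heads, head_dim, device=device, dtype=dtype)
    time_step = torch.randn(batch, seq_len, num_heads, device=device, dtype=dtype)  # raw dt, no softplus yet
    A = -torch.rand(num_heads, device=device)   # per-head state decay
    B = torch.randn(batch, seq_len, n_groups, state_dim, device=device, dtype=dtype)
    C = torch.randn(batch, seq_len, n_groups, state_dim, device=device, dtype=dtype)
    D = torch.ones(num_heads, device=device)    # skip connection
    dt_bias = torch.rand(num_heads, device=device)

    out, final_state = mamba_chunk_scan_combined(
        hidden_states, time_step, A, B, C, chunk_size=32,
        D=D, z=None, seq_idx=None, return_final_states=True,
        dt_bias=dt_bias,    # bias added to dt inside the kernel
        dt_softplus=True,   # softplus applied inside the kernel
    )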
tests/models/mamba2/test_modeling_mamba2.py (25 additions, 1 deletion)
@@ -35,7 +35,7 @@
         Mamba2ForCausalLM,
         Mamba2Model,
     )
-    from transformers.models.mamba2.modeling_mamba2 import Mamba2Cache
+    from transformers.models.mamba2.modeling_mamba2 import Mamba2Cache, Mamba2Mixer
     from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_0
 else:
     is_torch_greater_or_equal_than_2_0 = False
@@ -378,3 +378,27 @@ def test_batched_equivalence_without_cache(self):
             individual_gen = model.generate(**inputs, max_new_tokens=30, use_cache=True)
             individual_output = tokenizer.batch_decode(individual_gen, skip_special_tokens=True)[0]
             self.assertEqual(individual_output[:100], batched_output[index_gen][:100])
+
+    @slow
+    @require_torch_gpu
+    def test_mamba2_mixer_train_vs_eval_equivalence(self):
+        # Based on https://github.com/sustcsonglin/flash-linear-attention/issues/63
+        # Credit to zhixuan-lin
+
+        B, T, D = 4, 512, 768
+        dtype = torch.bfloat16
+        config = Mamba2Config(num_heads=24, head_dim=64, hidden_size=768, expand=2, n_groups=1)
+
+        torch.manual_seed(42)
+        with torch.amp.autocast(device_type="cuda", dtype=dtype):
+            with torch.no_grad():
+                mixer = Mamba2Mixer(config, layer_idx=0).to("cuda")
+                hidden_states = torch.rand(size=(B, T, D), dtype=dtype, device="cuda")
+
+                mixer.train()
+                out_train = mixer(hidden_states)
+
+                mixer.eval()
+                out_eval = mixer(hidden_states)
+
+                self.assertTrue(torch.allclose(out_train, out_eval, atol=1e-3))
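The new test runs the same `Mamba2Mixer` forward pass in `train()` and `eval()` mode under bf16 autocast and asserts that the two outputs agree to `atol=1e-3`. Because it is gated by `@slow` and `@require_torch_gpu`, it only executes on a CUDA machine with slow tests enabled, e.g. (assumed invocation):

RUN_SLOW=1 pytest tests/models/mamba2/test_modeling_mamba2.py -k test_mamba2_mixer_train_vs_eval_equivalence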
