Merge branch 'papakipos/mamba-unit-tests' into 'main'
Add unit tests for Mamba hybrid model sub-units

See merge request ADLR/megatron-lm!2233
jaredcasper committed Oct 30, 2024
2 parents 66cc8c0 + 2e4e0d9 commit 92ae1d7
Showing 4 changed files with 176 additions and 0 deletions.
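For reference, the new test files can be run directly with pytest. The snippet below is a minimal sketch and not part of the commit; it assumes pytest is installed and a CUDA device is available, since the forward tests move modules and tensors to the GPU with .cuda().

# Hypothetical runner, not part of this commit: invoke the new Mamba unit tests.
import pytest

exit_code = pytest.main(
    [
        "tests/unit_tests/models/test_mamba_model.py",
        "tests/unit_tests/ssm/test_mamba_block.py",
        "tests/unit_tests/ssm/test_mamba_layer.py",
        "tests/unit_tests/ssm/test_mamba_mixer.py",
        "-q",  # quiet output
    ]
)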
9 changes: 9 additions & 0 deletions tests/unit_tests/models/test_mamba_model.py
@@ -121,3 +121,12 @@ def test_save_load(self, tmp_path):
        torch.save(self.model.state_dict(), path)

        self.model.load_state_dict(torch.load(path))

    def test_layer_numbers(self):
        """
        The layer numbers should start at one (for the embedding layer) and go up
        incrementally from there. This is required for PEFT to work.
        """
        model = self.model
        for expected, layer in enumerate(model.decoder.layers, start=1):
            assert expected == layer.layer_number, "layer numbers are incorrect"
70 changes: 70 additions & 0 deletions tests/unit_tests/ssm/test_mamba_block.py
@@ -0,0 +1,70 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

import pytest
import torch

from megatron.core.models.mamba.mamba_layer_specs import mamba_stack_spec
from megatron.core.ssm.mamba_block import MambaStack
from megatron.core.ssm.mamba_hybrid_layer_allocation import Symbols
from megatron.core.ssm.mamba_layer import MambaLayer
from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed
from megatron.core.transformer.attention import SelfAttention
from megatron.core.transformer.mlp import MLP
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.transformer.transformer_layer import TransformerLayer
from tests.unit_tests.test_utilities import Utils


class TestMambaBlock:

    def setup_method(self, method):
        Utils.initialize_model_parallel(1, 1)
        model_parallel_cuda_manual_seed(123)
        # Note that test_layer_types verifies these types and the ordering
        hybrid_override_pattern = Symbols.MAMBA + Symbols.ATTENTION + Symbols.MLP
        transformer_config = TransformerConfig(
            hidden_size=256,  # The Mamba layer places several constraints on this
            # Need to specify num_attention_heads and num_layers or TransformerConfig
            # will generate errors.
            num_layers=len(hybrid_override_pattern),
            num_attention_heads=4,
            use_cpu_initialization=True,
        )
        modules = mamba_stack_spec.submodules
        self.block = MambaStack(
            transformer_config, modules, hybrid_override_pattern=hybrid_override_pattern
        )

    def teardown_method(self, method):
        Utils.destroy_model_parallel()

    def test_gpu_forward(self):
        block = self.block
        block.cuda()
        micro_batch_size = 2
        sequence_length = 32
        hidden_states = torch.ones((sequence_length, micro_batch_size, block.config.hidden_size))
        hidden_states = hidden_states.cuda()
        attention_mask = torch.ones(
            (micro_batch_size, 1, sequence_length, sequence_length), dtype=bool
        )
        attention_mask = attention_mask.cuda()
        output = block(hidden_states, attention_mask=attention_mask)
        assert output.shape[0] == sequence_length
        assert output.shape[1] == micro_batch_size
        assert output.shape[2] == block.config.hidden_size
        assert output.dtype == torch.float32

    def test_layer_types(self):
        """
        Make sure that the layer types specified with hybrid_override_pattern
        were honored.
        """
        block = self.block
        layers = block.layers
        # Note that this matches the order specified by hybrid_override_pattern in setup_method
        assert type(layers[0]) == MambaLayer
        assert type(layers[1]) == TransformerLayer
        assert type(layers[1].self_attention) == SelfAttention
        assert type(layers[2]) == TransformerLayer
        assert type(layers[2].mlp) == MLP
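
As context for setup_method above: hybrid_override_pattern is a plain string built by concatenating the Symbols constants, one character per layer, which is why num_layers is set to len(hybrid_override_pattern). The sketch below shows how a larger hybrid stack could be described the same way; the repeat counts are illustrative only and not taken from this commit.

# Illustrative only: two Mamba layers, then one attention and one MLP layer,
# repeated three times. Each symbol contributes exactly one layer.
from megatron.core.ssm.mamba_hybrid_layer_allocation import Symbols

pattern = (Symbols.MAMBA * 2 + Symbols.ATTENTION + Symbols.MLP) * 3
num_layers = len(pattern)  # TransformerConfig.num_layers must match this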
47 changes: 47 additions & 0 deletions tests/unit_tests/ssm/test_mamba_layer.py
@@ -0,0 +1,47 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

import pytest
import torch

from megatron.core.models.mamba.mamba_layer_specs import mamba_stack_spec
from megatron.core.ssm.mamba_layer import MambaLayer
from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed
from megatron.core.transformer.transformer_config import TransformerConfig
from tests.unit_tests.test_utilities import Utils


class TestMambaLayer:

    def setup_method(self, method):
        Utils.initialize_model_parallel(1, 1)
        model_parallel_cuda_manual_seed(123)
        transformer_config = TransformerConfig(
            hidden_size=256,  # The Mamba layer places several constraints on this
            # Need to specify num_attention_heads and num_layers or TransformerConfig
            # will generate errors.
            num_layers=1,
            num_attention_heads=1,
            use_cpu_initialization=True,
        )
        modules = mamba_stack_spec.submodules.mamba_layer.submodules
        self.layer = MambaLayer(transformer_config, modules)

    def teardown_method(self, method):
        Utils.destroy_model_parallel()

    def test_gpu_forward(self):
        layer = self.layer
        layer.cuda()
        micro_batch_size = 2
        sequence_length = 32
        hidden_states = torch.ones((sequence_length, micro_batch_size, layer.config.hidden_size))
        hidden_states = hidden_states.cuda()
        attention_mask = torch.ones(
            (micro_batch_size, 1, sequence_length, sequence_length), dtype=bool
        )
        attention_mask = attention_mask.cuda()
        output = layer(hidden_states, attention_mask=attention_mask)
        assert output.shape[0] == sequence_length
        assert output.shape[1] == micro_batch_size
        assert output.shape[2] == layer.config.hidden_size
        assert output.dtype == torch.float32
50 changes: 50 additions & 0 deletions tests/unit_tests/ssm/test_mamba_mixer.py
@@ -0,0 +1,50 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

import pytest
import torch

from megatron.core.models.mamba.mamba_layer_specs import mamba_stack_spec
from megatron.core.ssm.mamba_mixer import MambaMixer
from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed
from megatron.core.transformer.transformer_config import TransformerConfig
from tests.unit_tests.test_utilities import Utils


class TestMambaMixer:

    def setup_method(self, method):
        Utils.initialize_model_parallel(1, 1)
        model_parallel_cuda_manual_seed(123)
        transformer_config = TransformerConfig(
            hidden_size=256,  # The Mamba layer places several constraints on this
            # Need to specify num_attention_heads and num_layers or TransformerConfig
            # will generate errors.
            num_layers=1,
            num_attention_heads=1,
            use_cpu_initialization=True,
        )
        modules = mamba_stack_spec.submodules.mamba_layer.submodules.mixer.submodules
        self.mixer = MambaMixer(transformer_config, modules, transformer_config.hidden_size)
        self.mixer_no_mem_eff_path = MambaMixer(
            transformer_config, modules, transformer_config.hidden_size, use_mem_eff_path=False
        )

    def teardown_method(self, method):
        Utils.destroy_model_parallel()

    @pytest.mark.parametrize("use_mem_eff_path", [True, False])
    def test_gpu_forward(self, use_mem_eff_path):
        if use_mem_eff_path:
            mixer = self.mixer
        else:
            mixer = self.mixer_no_mem_eff_path
        mixer.cuda()
        micro_batch_size = 2
        sequence_length = 32
        hidden_states = torch.ones((sequence_length, micro_batch_size, mixer.config.hidden_size))
        hidden_states = hidden_states.cuda()
        output, bias = mixer(hidden_states)
        assert output.shape[0] == sequence_length
        assert output.shape[1] == micro_batch_size
        assert output.shape[2] == mixer.config.hidden_size
        assert output.dtype == torch.float32
