From 6e4e9df20bdf8fadc4ecb79a51944adfde38ab99 Mon Sep 17 00:00:00 2001
From: Zijie Yan
Date: Wed, 11 Sep 2024 06:45:45 -0700
Subject: [PATCH] ADLR/megatron-lm!2088 - Add MoE interface tests and move
 other tests to internal

---
 .../transformer/moe/test_aux_loss.py          |  2 +
 .../transformer/moe/test_grouped_mlp.py       |  8 ++
 .../transformer/moe/test_moe_layer.py         | 73 +++++++++++++++++++
 .../transformer/moe/test_routers.py           |  2 +
 .../transformer/moe/test_sequential_mlp.py    |  6 ++
 5 files changed, 91 insertions(+)
 create mode 100644 tests/unit_tests/transformer/moe/test_moe_layer.py

diff --git a/tests/unit_tests/transformer/moe/test_aux_loss.py b/tests/unit_tests/transformer/moe/test_aux_loss.py
index 2e26f01551..2b7b2e109b 100644
--- a/tests/unit_tests/transformer/moe/test_aux_loss.py
+++ b/tests/unit_tests/transformer/moe/test_aux_loss.py
@@ -57,6 +57,7 @@ def teardown_method(self, method):
         Utils.destroy_model_parallel()
 
     @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+    @pytest.mark.internal
     @pytest.mark.parametrize(
         "tp_size,ep_size,cp_size", [(8, 1, 1), (4, 2, 1), (1, 1, 8), (2, 1, 4), (2, 2, 2)]
     )
@@ -75,6 +76,7 @@ def test_allgather_dispatcher(self, tp_size, ep_size, cp_size):
         container.aux_loss_test(self.input, self.baseline_grad)
 
     @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+    @pytest.mark.internal
     @pytest.mark.parametrize(
         "tp_size,ep_size,cp_size", [(8, 1, 1), (4, 2, 1), (1, 1, 8), (2, 1, 4), (2, 2, 2)]
     )
diff --git a/tests/unit_tests/transformer/moe/test_grouped_mlp.py b/tests/unit_tests/transformer/moe/test_grouped_mlp.py
index 757be59232..c95e184897 100644
--- a/tests/unit_tests/transformer/moe/test_grouped_mlp.py
+++ b/tests/unit_tests/transformer/moe/test_grouped_mlp.py
@@ -92,6 +92,7 @@ def setup_method(self, method, use_cpu_initialization=False, swiglu=True):
     def teardown_method(self, method):
         Utils.destroy_model_parallel()
 
+    @pytest.mark.internal
     def test_constructor(self):
         assert isinstance(self.sequential_mlp, MoELayer)
         assert isinstance(self.grouped_mlp, MoELayer)
@@ -130,6 +131,7 @@ def test_constructor(self):
             self.grouped_mlp.experts.weight1.shape == self.grouped_mlp.experts.weight2.t().shape
         )
 
+    @pytest.mark.internal
     def test_weight_init_value_the_same(self):
         gmm_w1 = self.grouped_mlp.experts.weight1.view(self.num_experts, -1, self.hidden_size)
         gmm_w2 = self.grouped_mlp.experts.weight2.view(self.num_experts, self.hidden_size, -1)
@@ -153,6 +155,7 @@ def test_weight_init_value_the_same(self):
         assert torch.equal(gmm_expert2_fc2, smm_expert2_fc2)
 
     @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+    @pytest.mark.internal
     @pytest.mark.skipif(
         not DEVICE_CAPABILITY or DEVICE_CAPABILITY[0] < 8,
         reason='GroupedGEMM kernels are not supported on this device.',
     )
@@ -175,6 +178,7 @@ def test_gpu_forward(self):
         # assert torch.equal(output_smm, output_gmm)
 
     @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+    @pytest.mark.internal
     @pytest.mark.skipif(
         not DEVICE_CAPABILITY or DEVICE_CAPABILITY[0] < 8,
         reason='GroupedGEMM kernels are not supported on this device.',
     )
@@ -193,6 +197,7 @@ def test_gpu_forward_with_no_tokens_allocated(self):
             assert str(e) == "Input batch_sizes should not be all zeros!"
 
     @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+    @pytest.mark.internal
     @pytest.mark.skipif(
         not DEVICE_CAPABILITY or DEVICE_CAPABILITY[0] < 8,
         reason='GroupedGEMM kernels are not supported on this device.',
     )
@@ -274,6 +279,7 @@ def setup_method(self, method, use_cpu_initialization=False, swiglu=True):
     def teardown_method(self, method):
         Utils.destroy_model_parallel()
 
+    @pytest.mark.internal
     def test_constructor(self):
         assert isinstance(self.sequential_mlp, MoELayer)
         assert isinstance(self.grouped_mlp, MoELayer)
@@ -308,6 +314,7 @@ def test_constructor(self):
         )
 
     @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+    @pytest.mark.internal
     def test_gpu_forward_backward(self):
         self.sequential_mlp.cuda()
         self.grouped_mlp.cuda()
@@ -350,6 +357,7 @@ def test_gpu_forward_backward(self):
         torch.testing.assert_close(smm_result, gmm_result)
 
     @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+    @pytest.mark.internal
     def test_gpu_forward_backward_with_no_tokens_allocated(self):
         """Test the case when no token is allocated for groupedGEMM kernels."""
         self.grouped_mlp.cuda()
diff --git a/tests/unit_tests/transformer/moe/test_moe_layer.py b/tests/unit_tests/transformer/moe/test_moe_layer.py
new file mode 100644
index 0000000000..e65e7f2253
--- /dev/null
+++ b/tests/unit_tests/transformer/moe/test_moe_layer.py
@@ -0,0 +1,73 @@
+# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
+
+import pytest
+import torch
+
+from megatron.core.models.gpt.gpt_layer_specs import (
+    get_gpt_layer_local_spec,
+    get_gpt_layer_with_transformer_engine_spec,
+)
+from megatron.core.transformer.moe.moe_layer import MoELayer
+from megatron.core.transformer.moe.router import Router
+from megatron.core.transformer.transformer_config import TransformerConfig
+from megatron.training.initialize import _set_random_seed
+from tests.unit_tests.test_utilities import Utils
+
+
+class TestMoELayerInit:
+    def setup_method(self, method):
+        pass
+
+    @pytest.mark.parametrize("moe_token_dispatcher_type", ["allgather", "alltoall"])
+    @pytest.mark.parametrize("num_moe_experts", [1, 2])
+    @pytest.mark.parametrize("grouped_gemm", [True, False])
+    def test_te_moe_layer(self, num_moe_experts, moe_token_dispatcher_type, grouped_gemm):
+        Utils.initialize_model_parallel(1, 1)
+        _set_random_seed(seed_=123, data_parallel_random_init=False)
+        self.transformer_config = TransformerConfig(
+            num_layers=1,
+            hidden_size=12,
+            num_attention_heads=4,
+            num_moe_experts=num_moe_experts,
+            use_cpu_initialization=True,
+            moe_token_dispatcher_type=moe_token_dispatcher_type,
+            moe_router_topk=2,
+            moe_aux_loss_coeff=0.01,
+            moe_grouped_gemm=grouped_gemm,
+            add_bias_linear=False,
+        )
+        transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec(
+            num_experts=num_moe_experts, moe_grouped_gemm=grouped_gemm
+        )
+        moe_layer = MoELayer(
+            self.transformer_config, transformer_layer_spec.submodules.mlp.submodules
+        )
+        Utils.destroy_model_parallel()
+
+    @pytest.mark.parametrize("moe_token_dispatcher_type", ["allgather", "alltoall"])
+    @pytest.mark.parametrize("num_moe_experts", [1, 2])
+    def test_legacy_moe_layer(self, num_moe_experts, moe_token_dispatcher_type):
+        Utils.initialize_model_parallel(1, 1)
+        _set_random_seed(seed_=123, data_parallel_random_init=False)
+        num_moe_experts = 4
+        self.transformer_config = TransformerConfig(
+            num_layers=1,
+            hidden_size=12,
+            num_attention_heads=4,
+            num_moe_experts=num_moe_experts,
+            use_cpu_initialization=True,
+            moe_router_load_balancing_type="aux_loss",
+            moe_router_topk=2,
+            moe_aux_loss_coeff=0.01,
+            add_bias_linear=False,
+        )
+        transformer_layer_spec = get_gpt_layer_local_spec(
+            num_experts=num_moe_experts, moe_grouped_gemm=False
+        )
+        moe_layer = MoELayer(
+            self.transformer_config, transformer_layer_spec.submodules.mlp.submodules
+        )
+        Utils.destroy_model_parallel()
+
+    def teardown_method(self, method):
+        Utils.destroy_model_parallel()
diff --git a/tests/unit_tests/transformer/moe/test_routers.py b/tests/unit_tests/transformer/moe/test_routers.py
index b1d07d054a..c1633834b6 100644
--- a/tests/unit_tests/transformer/moe/test_routers.py
+++ b/tests/unit_tests/transformer/moe/test_routers.py
@@ -45,6 +45,7 @@ def test_constructor(self):
         assert num_weights == 12 * 4, num_weights
 
     @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+    @pytest.mark.internal
     @pytest.mark.parametrize("moe_router_pre_softmax", [(True), (False)])
     def test_router_forward(self, moe_router_pre_softmax):
         with torch.no_grad():
@@ -56,6 +57,7 @@ def test_router_forward(self, moe_router_pre_softmax):
             scores, indices = self.router(hidden_states)
 
     @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+    @pytest.mark.internal
     def test_aux_loss(self):
         self.sequential_mlp = self.sequential_mlp.cuda()
 
diff --git a/tests/unit_tests/transformer/moe/test_sequential_mlp.py b/tests/unit_tests/transformer/moe/test_sequential_mlp.py
index df1002677c..40a0caf31a 100644
--- a/tests/unit_tests/transformer/moe/test_sequential_mlp.py
+++ b/tests/unit_tests/transformer/moe/test_sequential_mlp.py
@@ -50,12 +50,14 @@ def setup_method(self, method):
     def teardown_method(self, method):
         Utils.destroy_model_parallel()
 
+    @pytest.mark.internal
     def test_constructor(self):
         assert isinstance(self.sequential_mlp, MoELayer)
 
         num_weights = sum([p.numel() for p in self.sequential_mlp.parameters()])
         assert num_weights == 3696
 
+    @pytest.mark.internal
     @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
     def test_gpu_forward(self):
         sequential_mlp = self.sequential_mlp
@@ -118,6 +120,7 @@ def setup_method(self, method):
         te_version < packaging.version.Version("1.7.0"),
         reason="Transformer Engine under v1.7.0 doesn't support MoE training.",
     )
+    @pytest.mark.internal
     def test_constructor(self):
         for i in range(self.num_local_experts):
             assert torch.equal(
@@ -133,6 +136,7 @@ def test_constructor(self):
         te_version < packaging.version.Version("1.7.0"),
         reason="Transformer Engine under v1.7.0 doesn't support MoE training.",
     )
+    @pytest.mark.internal
     def test_gpu_forward(self):
         self.local_sequential_mlp.cuda()
         self.te_sequential_mlp.cuda()
@@ -154,6 +158,7 @@ def test_gpu_forward(self):
         te_version < packaging.version.Version("1.7.0"),
         reason="Transformer Engine under v1.7.0 doesn't support MoE training.",
     )
+    @pytest.mark.internal
    def test_gpu_forward_with_one_local_expert(self):
         model_parallel_cuda_manual_seed(123)
         local_sequential_mlp = SequentialMLP(1, self.transformer_config, self.local_mlp_spec)
@@ -177,6 +182,7 @@ def test_gpu_forward_with_one_local_expert(self):
         te_version < packaging.version.Version("1.7.0"),
         reason="Transformer Engine under v1.7.0 doesn't support MoE training.",
     )
+    @pytest.mark.internal
     def test_gpu_forward_with_no_tokens_allocated(self):
         self.local_sequential_mlp.cuda()
         self.te_sequential_mlp.cuda()
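
Note on usage: the tests marked with @pytest.mark.internal above can be separated from the new interface tests in test_moe_layer.py through pytest marker expressions. A minimal conftest.py sketch for registering the marker, assuming it is not already declared in the repository's pytest configuration (hypothetical, not part of this patch):

    # conftest.py -- sketch only; assumes the "internal" marker is not yet
    # registered elsewhere. Registering it avoids unknown-marker warnings and
    # lets tests be selected or deselected by marker expression.
    def pytest_configure(config):
        config.addinivalue_line(
            "markers", "internal: tests intended to run only in internal CI"
        )

With the marker registered, running
    pytest tests/unit_tests/transformer/moe -m "not internal"
exercises only the new MoE interface tests, while
    pytest tests/unit_tests/transformer/moe -m internal
selects the tests moved to internal by this patch.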