diff --git a/megatron/core/parallel_state.py b/megatron/core/parallel_state.py
index f0112b7a04..a008f6bf44 100644
--- a/megatron/core/parallel_state.py
+++ b/megatron/core/parallel_state.py
@@ -1530,6 +1530,7 @@ def set_expert_model_parallel_rank(rank):
 
 
 def get_expert_tensor_parallel_group(check_initialized=True):
+    """Get the expert-tensor-parallel group the caller rank belongs to."""
     if check_initialized:
         assert (
             _EXPERT_TENSOR_PARALLEL_GROUP is not None
@@ -1574,7 +1575,7 @@ def set_expert_tensor_parallel_rank(rank):
 
 
 def get_expert_tensor_and_model_parallel_group(check_initialized=True):
-    """Get the tensor- and expert-parallel group the caller rank belongs to."""
+    """Get the expert-tensor- and expert-model-parallel group the caller rank belongs to."""
     if check_initialized:
         assert (
             _EXPERT_TENSOR_AND_MODEL_PARALLEL_GROUP is not None
@@ -1602,6 +1603,7 @@ def get_expert_tensor_and_model_parallel_rank():
 
 
 def get_expert_tensor_model_pipeline_parallel_group():
+    """Get the expert tensor-model-pipeline parallel group."""
     assert (
         _EXPERT_TENSOR_MODEL_PIPELINE_PARALLEL_GROUP is not None
     ), 'Expert tensor-model-pipeline parallel group is not initialized'
@@ -1609,11 +1611,23 @@ def get_expert_tensor_model_pipeline_parallel_group():
 
 
 def get_expert_data_parallel_group():
+    """Get the expert data parallel group."""
     assert _EXPERT_DATA_PARALLEL_GROUP is not None, 'Expert data parallel group is not initialized'
     return _EXPERT_DATA_PARALLEL_GROUP
 
 
+def get_data_modulo_expert_parallel_group():
+    """[Deprecated] Alias for get_expert_data_parallel_group."""
+    warnings.warn(
+        "get_data_modulo_expert_parallel_group is deprecated, please use "
+        "get_expert_data_parallel_group instead.",
+        DeprecationWarning,
+    )
+    return get_expert_data_parallel_group()
+
+
 def get_expert_data_parallel_group_gloo():
+    """Get the Gloo-backend expert data parallel group."""
     assert (
         _EXPERT_DATA_PARALLEL_GROUP_GLOO is not None
     ), 'Expert data parallel group-gloo is not initialized'
@@ -1621,6 +1635,7 @@ def get_expert_data_parallel_group_gloo():
 
 
 def get_expert_data_parallel_rank():
+    """Return the caller's rank in the expert data parallel group."""
     if torch.distributed.is_available() and torch.distributed.is_initialized():
         return torch.distributed.get_rank(group=get_expert_data_parallel_group())
     else:
diff --git a/megatron/core/transformer/moe/token_dispatcher.py b/megatron/core/transformer/moe/token_dispatcher.py
index 5db0d19fad..dbd768ddae 100644
--- a/megatron/core/transformer/moe/token_dispatcher.py
+++ b/megatron/core/transformer/moe/token_dispatcher.py
@@ -7,11 +7,9 @@
 
 from megatron.core.parallel_state import (
     get_expert_model_parallel_group,
-    get_expert_model_parallel_world_size,
     get_expert_tensor_and_model_parallel_group,
     get_expert_tensor_parallel_group,
     get_expert_tensor_parallel_rank,
-    get_expert_tensor_parallel_world_size,
 )
 from megatron.core.tensor_parallel import (
     all_to_all,
@@ -50,13 +48,28 @@ def __init__(self, config: TransformerConfig) -> None:
         self.config = config
         self.shared_experts: Optional[SharedExpertMLP] = None
 
-        if torch.distributed.is_available() and torch.distributed.is_initialized():
-            self.ep_group = get_expert_model_parallel_group()
-            self.ep_size = get_expert_model_parallel_world_size()
-            self.tp_group = get_expert_tensor_parallel_group()
-            self.tp_size = get_expert_tensor_parallel_world_size()
-            self.tp_rank = get_expert_tensor_parallel_rank()
-            self.tp_ep_group = get_expert_tensor_and_model_parallel_group()
+        self.tp_size = config.expert_tensor_parallel_size
+        self.ep_size = config.expert_model_parallel_size
+
+    @property
+    def ep_group(self):
+        """Get the expert model parallel group."""
+        return get_expert_model_parallel_group()
+
+    @property
+    def tp_group(self):
+        """Get the expert tensor parallel group."""
+        return get_expert_tensor_parallel_group()
+
+    @property
+    def tp_rank(self):
+        """Get the expert tensor parallel rank."""
+        return get_expert_tensor_parallel_rank()
+
+    @property
+    def tp_ep_group(self):
+        """Get the expert tensor and model parallel group."""
+        return get_expert_tensor_and_model_parallel_group()
 
     @abstractmethod
     def token_permutation(
diff --git a/tests/unit_tests/test_utilities.py b/tests/unit_tests/test_utilities.py
index 123154bbfe..410350be19 100644
--- a/tests/unit_tests/test_utilities.py
+++ b/tests/unit_tests/test_utilities.py
@@ -102,3 +102,22 @@ def initialize_model_parallel(
             **kwargs,
         )
         Utils.inited = True
+
+    @staticmethod
+    def fake_initialize_model_parallel(
+        tensor_model_parallel_size=1,
+        pipeline_model_parallel_size=1,
+        virtual_pipeline_model_parallel_size=None,
+        expert_model_parallel_size=1,
+    ):
+        """Used for layer-wise unit tests as a proxy for NeMo-style initialization."""
+        ps.set_tensor_model_parallel_world_size(tensor_model_parallel_size)
+        ps.set_tensor_model_parallel_rank(0)
+
+        ps.set_expert_model_parallel_world_size(expert_model_parallel_size)
+        ps.set_expert_model_parallel_rank(0)
+        if virtual_pipeline_model_parallel_size is not None:
+            ps.set_virtual_pipeline_model_parallel_world_size(virtual_pipeline_model_parallel_size)
+            ps.set_virtual_pipeline_model_parallel_rank(0)
+
+        ps.set_pipeline_model_parallel_world_size(pipeline_model_parallel_size)
diff --git a/tests/unit_tests/transformer/moe/test_moe_layer.py b/tests/unit_tests/transformer/moe/test_moe_layer.py
index e65e7f2253..591ba4d4ab 100644
--- a/tests/unit_tests/transformer/moe/test_moe_layer.py
+++ b/tests/unit_tests/transformer/moe/test_moe_layer.py
@@ -69,5 +69,55 @@ def test_legacy_moe_layer(self, num_moe_experts, moe_token_dispatcher_type):
         )
         Utils.destroy_model_parallel()
 
+    @pytest.mark.parametrize("moe_token_dispatcher_type", ["allgather", "alltoall"])
+    @pytest.mark.parametrize("grouped_gemm", [True, False])
+    @pytest.mark.parametrize("tp_size,ep_size", [(1, 1), (2, 2)])
+    def test_moe_with_late_initialize(
+        self, moe_token_dispatcher_type, grouped_gemm, tp_size, ep_size
+    ):
+        num_moe_experts = 4
+        hidden_size = 12
+        transformer_config = TransformerConfig(
+            num_layers=1,
+            hidden_size=hidden_size,
+            num_attention_heads=4,
+            num_moe_experts=num_moe_experts,
+            use_cpu_initialization=True,
+            moe_router_load_balancing_type="aux_loss",
+            moe_router_topk=2,
+            moe_aux_loss_coeff=0.01,
+            add_bias_linear=False,
+            moe_grouped_gemm=grouped_gemm,
+            moe_token_dispatcher_type=moe_token_dispatcher_type,
+            tensor_model_parallel_size=tp_size,
+            expert_model_parallel_size=ep_size,
+            sequence_parallel=tp_size > 1,
+            bf16=True,
+            params_dtype=torch.bfloat16,
+        )
+        transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec(
+            num_experts=num_moe_experts, moe_grouped_gemm=grouped_gemm
+        )
+
+        # Fake initialization as NeMo does: set parallel state without creating process groups.
+        Utils.fake_initialize_model_parallel(
+            tensor_model_parallel_size=tp_size, expert_model_parallel_size=ep_size
+        )
+        moe_layer = MoELayer(
+            transformer_config, transformer_layer_spec.submodules.mlp.submodules
+        ).cuda()
+
+        Utils.initialize_model_parallel(
+            tensor_model_parallel_size=tp_size, expert_model_parallel_size=ep_size
+        )
+        _set_random_seed(seed_=123, data_parallel_random_init=False)
+
+        input_data = torch.randn(
+            16, 4, hidden_size, device=torch.cuda.current_device(), dtype=torch.bfloat16
+        )
+        output = moe_layer(input_data)
+
+        Utils.destroy_model_parallel()
+
     def teardown_method(self, method):
         Utils.destroy_model_parallel()
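
Usage sketch (not part of the diff): how the deprecation shim and the accessors touched above are expected to behave once parallel state has been initialized, e.g. under a torchrun launch after initialize_model_parallel(); only the function names that appear in the diff are real, the surrounding variable names are illustrative.

import warnings

from megatron.core import parallel_state

# Deprecated alias: should emit a DeprecationWarning and hand back the very
# same group handle as the new accessor.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    legacy_group = parallel_state.get_data_modulo_expert_parallel_group()
assert legacy_group is parallel_state.get_expert_data_parallel_group()
assert any(issubclass(w.category, DeprecationWarning) for w in caught)

# The token dispatcher now resolves groups lazily through properties, so these
# lookups only need to succeed by the time of the first forward pass, not at
# module construction time.
ep_group = parallel_state.get_expert_model_parallel_group()
tp_rank = parallel_state.get_expert_tensor_parallel_rank()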