
Commit

Merge branch 'denliu/fix_moe_parallel_states' into 'main'
Fix a compatibility error introduced by !1940 for NeMo.

See merge request ADLR/megatron-lm!2393
ko3n1g committed Nov 28, 2024
2 parents f3e1afb + 6bd9255 commit 67a50f2
Showing 4 changed files with 107 additions and 10 deletions.
17 changes: 16 additions & 1 deletion megatron/core/parallel_state.py
@@ -1530,6 +1530,7 @@ def set_expert_model_parallel_rank(rank):


 def get_expert_tensor_parallel_group(check_initialized=True):
+    """Get the expert-tensor-parallel group the caller rank belongs to."""
     if check_initialized:
         assert (
             _EXPERT_TENSOR_PARALLEL_GROUP is not None
@@ -1574,7 +1575,7 @@ def set_expert_tensor_parallel_rank(rank):


 def get_expert_tensor_and_model_parallel_group(check_initialized=True):
-    """Get the tensor- and expert-parallel group the caller rank belongs to."""
+    """Get the expert-tensor and expert-model group the caller rank belongs to."""
     if check_initialized:
         assert (
             _EXPERT_TENSOR_AND_MODEL_PARALLEL_GROUP is not None
@@ -1602,25 +1603,39 @@ def get_expert_tensor_and_model_parallel_rank():


 def get_expert_tensor_model_pipeline_parallel_group():
+    """Get expert tensor-model-pipeline parallel group."""
     assert (
         _EXPERT_TENSOR_MODEL_PIPELINE_PARALLEL_GROUP is not None
     ), 'Expert tensor-model-pipeline parallel group is not initialized'
     return _EXPERT_TENSOR_MODEL_PIPELINE_PARALLEL_GROUP


 def get_expert_data_parallel_group():
+    """Get expert data parallel group."""
     assert _EXPERT_DATA_PARALLEL_GROUP is not None, 'Expert data parallel group is not initialized'
     return _EXPERT_DATA_PARALLEL_GROUP


+def get_data_modulo_expert_parallel_group():
+    """[Deprecated] Get expert data parallel group."""
+    warnings.warn(
+        "get_data_modulo_expert_parallel_group is deprecated, please use "
+        "get_expert_data_parallel_group instead.",
+        DeprecationWarning,
+    )
+    return get_expert_data_parallel_group()
+
+
 def get_expert_data_parallel_group_gloo():
+    """Get expert data parallel group-gloo."""
     assert (
         _EXPERT_DATA_PARALLEL_GROUP_GLOO is not None
     ), 'Expert data parallel group-gloo is not initialized'
     return _EXPERT_DATA_PARALLEL_GROUP_GLOO


 def get_expert_data_parallel_rank():
+    """Return caller's rank in the expert data parallel group."""
     if torch.distributed.is_available() and torch.distributed.is_initialized():
         return torch.distributed.get_rank(group=get_expert_data_parallel_group())
     else:
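The renamed accessor keeps the old entry point alive through a thin deprecation shim. Below is a minimal, self-contained sketch of how a caller or test could confirm that behavior; the group object is a stand-in so the snippet runs without torch.distributed, and it only mirrors the functions shown in the hunk above rather than importing Megatron itself.

# Standalone sketch mirroring the deprecation shim above; the group object is a
# placeholder so the example runs without an initialized torch.distributed backend.
import warnings

_EXPERT_DATA_PARALLEL_GROUP = object()  # stands in for a torch.distributed group


def get_expert_data_parallel_group():
    assert _EXPERT_DATA_PARALLEL_GROUP is not None, 'Expert data parallel group is not initialized'
    return _EXPERT_DATA_PARALLEL_GROUP


def get_data_modulo_expert_parallel_group():
    warnings.warn(
        "get_data_modulo_expert_parallel_group is deprecated, please use "
        "get_expert_data_parallel_group instead.",
        DeprecationWarning,
    )
    return get_expert_data_parallel_group()


# The deprecated name still returns the same group, but emits a DeprecationWarning.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    group = get_data_modulo_expert_parallel_group()

assert group is get_expert_data_parallel_group()
assert any(issubclass(w.category, DeprecationWarning) for w in caught)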
31 changes: 22 additions & 9 deletions megatron/core/transformer/moe/token_dispatcher.py
@@ -7,11 +7,9 @@

 from megatron.core.parallel_state import (
     get_expert_model_parallel_group,
-    get_expert_model_parallel_world_size,
     get_expert_tensor_and_model_parallel_group,
     get_expert_tensor_parallel_group,
     get_expert_tensor_parallel_rank,
-    get_expert_tensor_parallel_world_size,
 )
 from megatron.core.tensor_parallel import (
     all_to_all,
@@ -50,13 +48,28 @@ def __init__(self, config: TransformerConfig) -> None:
         self.config = config
         self.shared_experts: Optional[SharedExpertMLP] = None

-        if torch.distributed.is_available() and torch.distributed.is_initialized():
-            self.ep_group = get_expert_model_parallel_group()
-            self.ep_size = get_expert_model_parallel_world_size()
-            self.tp_group = get_expert_tensor_parallel_group()
-            self.tp_size = get_expert_tensor_parallel_world_size()
-            self.tp_rank = get_expert_tensor_parallel_rank()
-            self.tp_ep_group = get_expert_tensor_and_model_parallel_group()
+        self.tp_size = config.expert_tensor_parallel_size
+        self.ep_size = config.expert_model_parallel_size
+
+    @property
+    def ep_group(self):
+        """Get expert model parallel group."""
+        return get_expert_model_parallel_group()
+
+    @property
+    def tp_group(self):
+        """Get expert tensor parallel group."""
+        return get_expert_tensor_parallel_group()
+
+    @property
+    def tp_rank(self):
+        """Get expert tensor parallel rank."""
+        return get_expert_tensor_parallel_rank()
+
+    @property
+    def tp_ep_group(self):
+        """Get expert tensor and model parallel group."""
+        return get_expert_tensor_and_model_parallel_group()

     @abstractmethod
     def token_permutation(
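This hunk is the core of the fix: the dispatcher no longer captures process groups in __init__, where they may not exist yet under NeMo-style late initialization, but resolves them through properties at call time. The following is a minimal, standalone sketch of that lazy-lookup pattern, with hypothetical class and function names rather than Megatron code.

# Standalone sketch of the lazy group-lookup pattern applied above
# (hypothetical names; uses only the Python standard library).

class EagerDispatcher:
    def __init__(self, get_group):
        # Eager lookup: construction fails if parallel state is not initialized yet.
        self.group = get_group()


class LazyDispatcher:
    def __init__(self, get_group):
        self._get_group = get_group

    @property
    def group(self):
        # Lazy lookup: the group is resolved on first use, after initialization.
        return self._get_group()


def get_group_or_raise():
    raise RuntimeError("parallel state is not initialized")


# The lazy variant can be constructed before initialization, as NeMo does:
lazy = LazyDispatcher(get_group_or_raise)
try:
    EagerDispatcher(get_group_or_raise)
except RuntimeError as exc:
    print(f"eager construction failed: {exc}")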
19 changes: 19 additions & 0 deletions tests/unit_tests/test_utilities.py
@@ -102,3 +102,22 @@ def initialize_model_parallel(
             **kwargs,
         )
         Utils.inited = True
+
+    @staticmethod
+    def fake_initialize_model_parallel(
+        tensor_model_parallel_size=1,
+        pipeline_model_parallel_size=1,
+        virtual_pipeline_model_parallel_size=None,
+        expert_model_parallel_size=1,
+    ):
+        """Used for layer-wise UT as a proxy for NeMo-style initialization."""
+        ps.set_tensor_model_parallel_world_size(tensor_model_parallel_size)
+        ps.set_tensor_model_parallel_rank(0)
+
+        ps.set_expert_model_parallel_world_size(expert_model_parallel_size)
+        ps.set_expert_model_parallel_rank(0)
+        if virtual_pipeline_model_parallel_size is not None:
+            ps.set_virtual_pipeline_model_parallel_world_size(virtual_pipeline_model_parallel_size)
+            ps.set_virtual_pipeline_model_parallel_rank(0)
+
+        ps.set_pipeline_model_parallel_world_size(pipeline_model_parallel_size)
50 changes: 50 additions & 0 deletions tests/unit_tests/transformer/moe/test_moe_layer.py
@@ -69,5 +69,55 @@ def test_legacy_moe_layer(self, num_moe_experts, moe_token_dispatcher_type):
         )
         Utils.destroy_model_parallel()

+    @pytest.mark.parametrize("moe_token_dispatcher_type", ["allgather", "alltoall"])
+    @pytest.mark.parametrize("grouped_gemm", [True, False])
+    @pytest.mark.parametrize("tp_size,ep_size", [(1, 1), (2, 2)])
+    def test_moe_with_late_initialize(
+        self, moe_token_dispatcher_type, grouped_gemm, tp_size, ep_size
+    ):
+        num_moe_experts = 4
+        hidden_size = 12
+        transformer_config = TransformerConfig(
+            num_layers=1,
+            hidden_size=hidden_size,
+            num_attention_heads=4,
+            num_moe_experts=num_moe_experts,
+            use_cpu_initialization=True,
+            moe_router_load_balancing_type="aux_loss",
+            moe_router_topk=2,
+            moe_aux_loss_coeff=0.01,
+            add_bias_linear=False,
+            moe_grouped_gemm=grouped_gemm,
+            moe_token_dispatcher_type=moe_token_dispatcher_type,
+            tensor_model_parallel_size=tp_size,
+            expert_model_parallel_size=ep_size,
+            sequence_parallel=tp_size > 1,
+            bf16=True,
+            params_dtype=torch.bfloat16,
+        )
+        transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec(
+            num_experts=num_moe_experts, moe_grouped_gemm=grouped_gemm
+        )
+
+        # Fake initialization as NeMo does
+        Utils.fake_initialize_model_parallel(
+            tensor_model_parallel_size=tp_size, expert_model_parallel_size=ep_size
+        )
+        moe_layer = MoELayer(
+            transformer_config, transformer_layer_spec.submodules.mlp.submodules
+        ).cuda()
+
+        Utils.initialize_model_parallel(
+            tensor_model_parallel_size=tp_size, expert_model_parallel_size=ep_size
+        )
+        _set_random_seed(seed_=123, data_parallel_random_init=False)
+
+        input_data = torch.randn(
+            16, 4, hidden_size, device=torch.cuda.current_device(), dtype=torch.bfloat16
+        )
+        output = moe_layer(input_data)
+
+        Utils.destroy_model_parallel()

     def teardown_method(self, method):
         Utils.destroy_model_parallel()