Merge branch 'denliu/moe_parallel_states' into 'main'
MoE parallel folding: separate MoE parallel states from dense

See merge request ADLR/megatron-lm!1940
ko3n1g committed Nov 23, 2024
2 parents a9d040c + 7f22e21 commit d392f9c
Showing 34 changed files with 1,850 additions and 707 deletions.
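For orientation, here is a small sketch of the arithmetic behind "parallel folding" as I read this MR (my own illustration, not code from the change): expert layers re-factorize the same non-pipeline ranks that dense layers use, with the new expert_tensor_parallel_size knob, so the expert data-parallel size follows from the dense sizes.

def expert_data_parallel_size(tp: int, cp: int, dp: int, etp: int, ep: int) -> int:
    # Per pipeline stage, the non-pipeline ranks are covered by tp * cp * dp for
    # dense layers and by etp * ep * edp for expert layers; this solves for edp.
    # Illustrative only, not MCore internals.
    dense_domain = tp * cp * dp
    expert_model = etp * ep
    assert dense_domain % expert_model == 0, "etp * ep must divide tp * cp * dp"
    return dense_domain // expert_model

# Hypothetical example: 64 GPUs per pipeline stage, dense tp=4, cp=2, dp=8;
# experts use etp=1, ep=8, leaving an expert data-parallel size of 8.
print(expert_data_parallel_size(tp=4, cp=2, dp=8, etp=1, ep=8))  # -> 8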
6 changes: 2 additions & 4 deletions megatron/core/distributed/distributed_data_parallel.py
@@ -232,7 +232,7 @@ def _allocate_buffers_for_parameters(
self.expert_parallel_buffers, self.expert_parallel_bucket_groups = (
_allocate_buffers_for_parameters(
expert_parallel_params,
parallel_state.get_data_modulo_expert_parallel_group(with_context_parallel=True),
parallel_state.get_expert_data_parallel_group(),
gradient_scaling_factor=expert_gradient_scaling_factor,
)
)
@@ -440,9 +440,7 @@ def broadcast_params(self):
is_expert_parallel = not getattr(param, 'allreduce', True)

if is_expert_parallel:
data_parallel_group = parallel_state.get_data_modulo_expert_parallel_group(
with_context_parallel=True
)
data_parallel_group = parallel_state.get_expert_data_parallel_group()
else:
data_parallel_group = parallel_state.get_data_parallel_group(
with_context_parallel=True
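Both hunks in this file apply the same routing; a hedged consolidation of that pattern, using only the getters visible above (assumes Megatron's parallel state has been initialized):

from megatron.core import parallel_state

def data_parallel_group_for(param):
    # Expert parameters are tagged with `allreduce = False`; they now use the
    # dedicated expert data-parallel group for grad buffers and broadcasts,
    # while dense parameters keep the context-parallel-aware DP group.
    is_expert_parallel = not getattr(param, 'allreduce', True)
    if is_expert_parallel:
        return parallel_state.get_expert_data_parallel_group()
    return parallel_state.get_data_parallel_group(with_context_parallel=True)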
76 changes: 43 additions & 33 deletions megatron/core/extensions/transformer_engine.py
@@ -13,14 +13,19 @@
from torch import Tensor
from torch.nn.parameter import Parameter

from megatron.core import ModelParallelConfig, parallel_state
from megatron.core import ModelParallelConfig
from megatron.core.dist_checkpointing.utils import replace_prefix_for_sharding
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.parallel_state import (
get_context_parallel_global_ranks,
get_context_parallel_group,
get_expert_data_parallel_rank,
get_expert_model_parallel_rank,
get_expert_model_parallel_world_size,
get_expert_tensor_parallel_group,
get_expert_tensor_parallel_rank,
get_expert_tensor_parallel_world_size,
get_hierarchical_context_parallel_groups,
get_tensor_and_expert_parallel_world_size,
get_tensor_model_parallel_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
@@ -162,19 +167,23 @@ def __init__(
extra_kwargs["ub_name"] = tp_comm_buffer_name

self.expert_parallel = self.config.expert_model_parallel_size > 1
if is_expert and self.expert_parallel:
if is_expert:
rng_tracker_name = get_expert_parallel_rng_tracker_name()
else:
rng_tracker_name = None
if is_te_min_version("1.7.0"):
extra_kwargs["rng_tracker_name"] = rng_tracker_name

# Disable communications in TE when using SP or EP by making TE agnostic of model parallel.
tp_size = self.config.tensor_model_parallel_size
tp_group = get_tensor_model_parallel_group(check_initialized=False)
if is_expert and (self.config.sequence_parallel or self.expert_parallel):
if self.config.moe_extended_tp:
tp_size = get_tensor_and_expert_parallel_world_size()
# Disable communications in TE when using TP or EP by making TE agnostic of model parallel.
if is_expert:
tp_group = get_expert_tensor_parallel_group(check_initialized=False)
tp_size = get_expert_tensor_parallel_world_size()
else:
tp_group = get_tensor_model_parallel_group(check_initialized=False)
tp_size = get_tensor_model_parallel_world_size()
explicit_expert_comm = is_expert and (tp_size > 1 or self.expert_parallel)

if explicit_expert_comm:
if parallel_mode == "column":
output_size = divide(output_size, tp_size)
elif parallel_mode == "row":
@@ -418,9 +427,13 @@ def __init__(
tp_comm_buffer_name=tp_comm_buffer_name,
)

world_size = get_tensor_model_parallel_world_size()
rank = get_tensor_model_parallel_rank()
if config.use_cpu_initialization:
if is_expert:
world_size = get_expert_tensor_parallel_world_size()
rank = get_expert_tensor_parallel_rank()
else:
world_size = get_tensor_model_parallel_world_size()
rank = get_tensor_model_parallel_rank()
output_size_per_partition = divide(output_size, world_size)
_ = _initialize_affine_weight_cpu(
self.weight,
@@ -492,9 +505,13 @@ def __init__(
is_expert=is_expert,
tp_comm_buffer_name=tp_comm_buffer_name,
)
world_size = get_tensor_model_parallel_world_size()
rank = get_tensor_model_parallel_rank()
if config.use_cpu_initialization:
if is_expert:
world_size = get_expert_tensor_parallel_world_size()
rank = get_expert_tensor_parallel_rank()
else:
world_size = get_tensor_model_parallel_world_size()
rank = get_tensor_model_parallel_rank()
input_size_per_partition = divide(input_size, world_size)
self.master_weight = _initialize_affine_weight_cpu(
self.weight,
@@ -760,19 +777,19 @@ def __init__(
extra_kwargs["ub_name"] = tp_comm_buffer_name

self.expert_parallel = self.config.expert_model_parallel_size > 1
if self.expert_parallel:
if is_expert:
extra_kwargs["rng_tracker_name"] = get_expert_parallel_rng_tracker_name()

# For MoE models, the comms between TP and EP group is explicitly handled by
# MoE token dispatcher. So we disable comms by making TE agnostic of model parallel.
self.explicit_expert_comm = is_expert and (
config.tensor_model_parallel_size > 1 or self.expert_parallel
)
tp_group = get_tensor_model_parallel_group(check_initialized=False)
if self.explicit_expert_comm and config.moe_extended_tp:
tp_size = parallel_state.get_tensor_and_expert_parallel_world_size()
# The comms between TP and EP group is explicitly handled by MoE token dispatcher.
# So we disable comms by making TE agnostic of model parallel.
if is_expert:
tp_group = get_expert_tensor_parallel_group(check_initialized=False)
tp_size = get_expert_tensor_parallel_world_size()
else:
tp_size = parallel_state.get_tensor_model_parallel_world_size()
tp_group = get_tensor_model_parallel_group(check_initialized=False)
tp_size = get_tensor_model_parallel_world_size()
self.explicit_expert_comm = is_expert and (tp_size > 1 or self.expert_parallel)

if self.explicit_expert_comm:
if parallel_mode == "column":
output_size = divide(output_size, tp_size)
@@ -922,12 +939,8 @@ def _sharded_state_dict_grouped(
"""
sharded_state_dict = {}
full_state_dict = self.state_dict(prefix='', keep_vars=True)
num_global_experts = (
parallel_state.get_expert_model_parallel_world_size() * self.num_gemms
)
local_expert_indices_offset = (
parallel_state.get_expert_model_parallel_rank() * self.num_gemms
)
num_global_experts = get_expert_model_parallel_world_size() * self.num_gemms
local_expert_indices_offset = get_expert_model_parallel_rank() * self.num_gemms
ep_axis = len(sharded_offsets)
extra_states = self._split_extra_state(full_state_dict['_extra_state'])
for gemm_idx in range(self.num_gemms):
@@ -964,10 +977,7 @@
assert (
len(replica_id) == 3
), f'Expected replica_id for {k} to be in (PP, TP, DP) format, got: {replica_id}'
sh_ten.replica_id = (
*replica_id[:2],
parallel_state.get_data_modulo_expert_parallel_rank(),
)
sh_ten.replica_id = (*replica_id[:2], get_expert_data_parallel_rank())
return sharded_state_dict
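As a worked example of the expert indexing in the hunk above (a sketch, assuming each EP rank owns num_gemms consecutive experts, as those offsets imply):

def local_expert_slice(ep_rank: int, ep_world_size: int, num_gemms: int):
    # Mirrors the offsets computed in _sharded_state_dict_grouped above.
    num_global_experts = ep_world_size * num_gemms
    offset = ep_rank * num_gemms
    return offset, offset + num_gemms, num_global_experts

# Hypothetical example: 4 EP ranks x 2 local GEMMs -> rank 3 owns experts [6, 8) of 8.
print(local_expert_slice(ep_rank=3, ep_world_size=4, num_gemms=2))  # (6, 8, 8)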

class TEColumnParallelGroupedLinear(TEGroupedLinear):
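Stepping back, the recurring change across the TE wrappers in this file is the TP group/size lookup plus the explicit_expert_comm flag; a minimal, hedged distillation follows (not the wrappers' actual code, and the tp_group handling is omitted):

def resolve_tp(is_expert: bool, expert_model_parallel_size: int,
               dense_tp_size: int, expert_tp_size: int):
    # Expert layers read their TP size from the expert parallel state; TE's own
    # TP communication is bypassed whenever the expert layer is sharded at all,
    # because the MoE token dispatcher handles those comms explicitly.
    tp_size = expert_tp_size if is_expert else dense_tp_size
    expert_parallel = expert_model_parallel_size > 1
    explicit_expert_comm = is_expert and (tp_size > 1 or expert_parallel)
    return tp_size, explicit_expert_comm

# Hypothetical example: an expert GEMM with expert TP = 1 but EP = 8 still
# bypasses TE's tensor-parallel comms, while the dense GEMM keeps them.
print(resolve_tp(True, 8, 4, 1))   # (1, True)
print(resolve_tp(False, 8, 4, 1))  # (4, False)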
12 changes: 8 additions & 4 deletions megatron/core/model_parallel_config.py
@@ -50,11 +50,12 @@ class ModelParallelConfig:
expert_model_parallel_size: int = 1
"""Distributes Moe Experts across sub data parallel dimension."""

expert_tensor_parallel_size: Optional[int] = None
"""Intra-layer tensor model parallelsm for expert layer. Splits tensors across GPU ranks."""

moe_extended_tp: bool = False
"""Alternative parallelization strategy for expert parallelism. Instead of distributing experts
across expert_model_parallel_size, each expert is sharded along extendended tensor parallel
domain (tensor_model_paralle_size * expert_model_parallel_size). It avoids the load balancing
problem with MOE training.
"""NOTE: Deprecated from MCore v0.10. This flag is ignored.
Its functionality is replaced by expert_tensor_parallel_size.
"""

###################
@@ -341,6 +342,9 @@ def __post_init__(self):
if self.tensor_model_parallel_size <= 1:
raise ValueError("Can not use sequence paralllelism without tensor parallelism")

if self.expert_tensor_parallel_size is None:
self.expert_tensor_parallel_size = self.tensor_model_parallel_size

if self.pipeline_model_parallel_size > 1:
if self.pipeline_dtype is None:
raise ValueError(
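A minimal sketch of the new default, simplified from the dataclass above (TinyParallelConfig is a hypothetical stand-in): expert_tensor_parallel_size falls back to the dense TP size when left unset.

from dataclasses import dataclass
from typing import Optional

@dataclass
class TinyParallelConfig:
    tensor_model_parallel_size: int = 1
    expert_tensor_parallel_size: Optional[int] = None

    def __post_init__(self):
        # Same fallback as ModelParallelConfig.__post_init__ above.
        if self.expert_tensor_parallel_size is None:
            self.expert_tensor_parallel_size = self.tensor_model_parallel_size

print(TinyParallelConfig(tensor_model_parallel_size=4).expert_tensor_parallel_size)  # 4
print(TinyParallelConfig(tensor_model_parallel_size=4,
                         expert_tensor_parallel_size=1).expert_tensor_parallel_size)  # 1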
18 changes: 7 additions & 11 deletions megatron/core/optimizer/__init__.py
@@ -419,23 +419,19 @@ def get_megatron_optimizer(
buffer_name='expert_parallel_buffers',
)
if len(moe_param_groups) > 0:
model_parallel_world_size = torch.distributed.get_world_size(mpu.get_model_parallel_group())
expert_parallel_rank = mpu.get_expert_model_parallel_rank()
model_parallel_rank = torch.distributed.get_rank(
mpu.get_expert_tensor_model_pipeline_parallel_group()
)
optimizers.append(
_get_megatron_optimizer_based_on_param_groups(
config,
model_chunks=model_chunks,
param_groups=moe_param_groups,
per_model_buffers=moe_buffers,
model_parallel_group=mpu.get_model_parallel_group(with_expert_parallel=True),
data_parallel_group=mpu.get_data_modulo_expert_parallel_group(
with_context_parallel=True
),
data_parallel_group_gloo=mpu.get_data_modulo_expert_parallel_group_gloo(
with_context_parallel=True
),
data_parallel_group_idx=expert_parallel_rank * model_parallel_world_size
+ model_parallel_rank,
model_parallel_group=mpu.get_expert_tensor_model_pipeline_parallel_group(),
data_parallel_group=mpu.get_expert_data_parallel_group(),
data_parallel_group_gloo=mpu.get_expert_data_parallel_group_gloo(),
data_parallel_group_idx=model_parallel_rank,
)
)

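For reference, the group wiring for the MoE optimizer after this change, condensed into a hedged helper (the getters are taken from the hunk above; the helper itself is illustrative, omits the buffer/config plumbing, and assumes Megatron's parallel state is initialized):

import torch
from megatron.core import parallel_state as mpu

def moe_optimizer_group_kwargs():
    # The expert TP x PP group plays the "model parallel" role for the MoE
    # optimizer, and its rank replaces the old ep_rank * mp_size + mp_rank index.
    model_parallel_group = mpu.get_expert_tensor_model_pipeline_parallel_group()
    return dict(
        model_parallel_group=model_parallel_group,
        data_parallel_group=mpu.get_expert_data_parallel_group(),
        data_parallel_group_gloo=mpu.get_expert_data_parallel_group_gloo(),
        data_parallel_group_idx=torch.distributed.get_rank(model_parallel_group),
    )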
Diffs for the remaining 30 changed files are not shown.