diff --git a/megatron/core/distributed/distributed_data_parallel.py b/megatron/core/distributed/distributed_data_parallel.py index 5c9e1df842..300f3c71b9 100644 --- a/megatron/core/distributed/distributed_data_parallel.py +++ b/megatron/core/distributed/distributed_data_parallel.py @@ -232,7 +232,7 @@ def _allocate_buffers_for_parameters( self.expert_parallel_buffers, self.expert_parallel_bucket_groups = ( _allocate_buffers_for_parameters( expert_parallel_params, - parallel_state.get_data_modulo_expert_parallel_group(with_context_parallel=True), + parallel_state.get_expert_data_parallel_group(), gradient_scaling_factor=expert_gradient_scaling_factor, ) ) @@ -440,9 +440,7 @@ def broadcast_params(self): is_expert_parallel = not getattr(param, 'allreduce', True) if is_expert_parallel: - data_parallel_group = parallel_state.get_data_modulo_expert_parallel_group( - with_context_parallel=True - ) + data_parallel_group = parallel_state.get_expert_data_parallel_group() else: data_parallel_group = parallel_state.get_data_parallel_group( with_context_parallel=True diff --git a/megatron/core/extensions/transformer_engine.py b/megatron/core/extensions/transformer_engine.py index 9b7ecf3ffd..de757a461b 100644 --- a/megatron/core/extensions/transformer_engine.py +++ b/megatron/core/extensions/transformer_engine.py @@ -13,14 +13,19 @@ from torch import Tensor from torch.nn.parameter import Parameter -from megatron.core import ModelParallelConfig, parallel_state +from megatron.core import ModelParallelConfig from megatron.core.dist_checkpointing.utils import replace_prefix_for_sharding from megatron.core.packed_seq_params import PackedSeqParams from megatron.core.parallel_state import ( get_context_parallel_global_ranks, get_context_parallel_group, + get_expert_data_parallel_rank, + get_expert_model_parallel_rank, + get_expert_model_parallel_world_size, + get_expert_tensor_parallel_group, + get_expert_tensor_parallel_rank, + get_expert_tensor_parallel_world_size, get_hierarchical_context_parallel_groups, - get_tensor_and_expert_parallel_world_size, get_tensor_model_parallel_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, @@ -162,19 +167,23 @@ def __init__( extra_kwargs["ub_name"] = tp_comm_buffer_name self.expert_parallel = self.config.expert_model_parallel_size > 1 - if is_expert and self.expert_parallel: + if is_expert: rng_tracker_name = get_expert_parallel_rng_tracker_name() else: rng_tracker_name = None if is_te_min_version("1.7.0"): extra_kwargs["rng_tracker_name"] = rng_tracker_name - # Disable communications in TE when using SP or EP by making TE agnostic of model parallel. - tp_size = self.config.tensor_model_parallel_size - tp_group = get_tensor_model_parallel_group(check_initialized=False) - if is_expert and (self.config.sequence_parallel or self.expert_parallel): - if self.config.moe_extended_tp: - tp_size = get_tensor_and_expert_parallel_world_size() + # Disable communications in TE when using TP or EP by making TE agnostic of model parallel. 
+ if is_expert: + tp_group = get_expert_tensor_parallel_group(check_initialized=False) + tp_size = get_expert_tensor_parallel_world_size() + else: + tp_group = get_tensor_model_parallel_group(check_initialized=False) + tp_size = get_tensor_model_parallel_world_size() + explicit_expert_comm = is_expert and (tp_size > 1 or self.expert_parallel) + + if explicit_expert_comm: if parallel_mode == "column": output_size = divide(output_size, tp_size) elif parallel_mode == "row": @@ -418,9 +427,13 @@ def __init__( tp_comm_buffer_name=tp_comm_buffer_name, ) - world_size = get_tensor_model_parallel_world_size() - rank = get_tensor_model_parallel_rank() if config.use_cpu_initialization: + if is_expert: + world_size = get_expert_tensor_parallel_world_size() + rank = get_expert_tensor_parallel_rank() + else: + world_size = get_tensor_model_parallel_world_size() + rank = get_tensor_model_parallel_rank() output_size_per_partition = divide(output_size, world_size) _ = _initialize_affine_weight_cpu( self.weight, @@ -492,9 +505,13 @@ def __init__( is_expert=is_expert, tp_comm_buffer_name=tp_comm_buffer_name, ) - world_size = get_tensor_model_parallel_world_size() - rank = get_tensor_model_parallel_rank() if config.use_cpu_initialization: + if is_expert: + world_size = get_expert_tensor_parallel_world_size() + rank = get_expert_tensor_parallel_rank() + else: + world_size = get_tensor_model_parallel_world_size() + rank = get_tensor_model_parallel_rank() input_size_per_partition = divide(input_size, world_size) self.master_weight = _initialize_affine_weight_cpu( self.weight, @@ -760,19 +777,19 @@ def __init__( extra_kwargs["ub_name"] = tp_comm_buffer_name self.expert_parallel = self.config.expert_model_parallel_size > 1 - if self.expert_parallel: + if is_expert: extra_kwargs["rng_tracker_name"] = get_expert_parallel_rng_tracker_name() - # For MoE models, the comms between TP and EP group is explicitly handled by - # MoE token dispatcher. So we disable comms by making TE agnostic of model parallel. - self.explicit_expert_comm = is_expert and ( - config.tensor_model_parallel_size > 1 or self.expert_parallel - ) - tp_group = get_tensor_model_parallel_group(check_initialized=False) - if self.explicit_expert_comm and config.moe_extended_tp: - tp_size = parallel_state.get_tensor_and_expert_parallel_world_size() + # The comms between TP and EP group is explicitly handled by MoE token dispatcher. + # So we disable comms by making TE agnostic of model parallel. 
+    if is_expert: +        tp_group = get_expert_tensor_parallel_group(check_initialized=False) +        tp_size = get_expert_tensor_parallel_world_size() else: -            tp_size = parallel_state.get_tensor_model_parallel_world_size() +        tp_group = get_tensor_model_parallel_group(check_initialized=False) +        tp_size = get_tensor_model_parallel_world_size() +        self.explicit_expert_comm = is_expert and (tp_size > 1 or self.expert_parallel) + +        if self.explicit_expert_comm: if parallel_mode == "column": output_size = divide(output_size, tp_size) @@ -922,12 +939,8 @@ def _sharded_state_dict_grouped( """ sharded_state_dict = {} full_state_dict = self.state_dict(prefix='', keep_vars=True) - num_global_experts = ( - parallel_state.get_expert_model_parallel_world_size() * self.num_gemms - ) - local_expert_indices_offset = ( - parallel_state.get_expert_model_parallel_rank() * self.num_gemms - ) + num_global_experts = get_expert_model_parallel_world_size() * self.num_gemms + local_expert_indices_offset = get_expert_model_parallel_rank() * self.num_gemms ep_axis = len(sharded_offsets) extra_states = self._split_extra_state(full_state_dict['_extra_state']) for gemm_idx in range(self.num_gemms): @@ -964,10 +977,7 @@ def _sharded_state_dict_grouped( assert ( len(replica_id) == 3 ), f'Expected replica_id for {k} to be in (PP, TP, DP) format, got: {replica_id}' - sh_ten.replica_id = ( - *replica_id[:2], - parallel_state.get_data_modulo_expert_parallel_rank(), - ) + sh_ten.replica_id = (*replica_id[:2], get_expert_data_parallel_rank()) return sharded_state_dict class TEColumnParallelGroupedLinear(TEGroupedLinear): diff --git a/megatron/core/model_parallel_config.py b/megatron/core/model_parallel_config.py index ff8f45156b..46a03f6d6d 100644 --- a/megatron/core/model_parallel_config.py +++ b/megatron/core/model_parallel_config.py @@ -50,11 +50,12 @@ class ModelParallelConfig: expert_model_parallel_size: int = 1 """Distributes Moe Experts across sub data parallel dimension.""" + expert_tensor_parallel_size: Optional[int] = None + """Intra-layer tensor model parallelism for the expert layer. Splits tensors across GPU ranks.""" + moe_extended_tp: bool = False - """Alternative parallelization strategy for expert parallelism. Instead of distributing experts - across expert_model_parallel_size, each expert is sharded along extendended tensor parallel - domain (tensor_model_paralle_size * expert_model_parallel_size). It avoids the load balancing - problem with MOE training. + """NOTE: Deprecated from MCore v0.10. This flag is ignored. + Its functionality is replaced by expert_tensor_parallel_size.
""" ################### @@ -341,6 +342,9 @@ def __post_init__(self): if self.tensor_model_parallel_size <= 1: raise ValueError("Can not use sequence paralllelism without tensor parallelism") + if self.expert_tensor_parallel_size is None: + self.expert_tensor_parallel_size = self.tensor_model_parallel_size + if self.pipeline_model_parallel_size > 1: if self.pipeline_dtype is None: raise ValueError( diff --git a/megatron/core/optimizer/__init__.py b/megatron/core/optimizer/__init__.py index 7c61bbb3ba..71b1987c88 100644 --- a/megatron/core/optimizer/__init__.py +++ b/megatron/core/optimizer/__init__.py @@ -419,23 +419,19 @@ def get_megatron_optimizer( buffer_name='expert_parallel_buffers', ) if len(moe_param_groups) > 0: - model_parallel_world_size = torch.distributed.get_world_size(mpu.get_model_parallel_group()) - expert_parallel_rank = mpu.get_expert_model_parallel_rank() + model_parallel_rank = torch.distributed.get_rank( + mpu.get_expert_tensor_model_pipeline_parallel_group() + ) optimizers.append( _get_megatron_optimizer_based_on_param_groups( config, model_chunks=model_chunks, param_groups=moe_param_groups, per_model_buffers=moe_buffers, - model_parallel_group=mpu.get_model_parallel_group(with_expert_parallel=True), - data_parallel_group=mpu.get_data_modulo_expert_parallel_group( - with_context_parallel=True - ), - data_parallel_group_gloo=mpu.get_data_modulo_expert_parallel_group_gloo( - with_context_parallel=True - ), - data_parallel_group_idx=expert_parallel_rank * model_parallel_world_size - + model_parallel_rank, + model_parallel_group=mpu.get_expert_tensor_model_pipeline_parallel_group(), + data_parallel_group=mpu.get_expert_data_parallel_group(), + data_parallel_group_gloo=mpu.get_expert_data_parallel_group_gloo(), + data_parallel_group_idx=model_parallel_rank, ) ) diff --git a/megatron/core/parallel_state.py b/megatron/core/parallel_state.py index 500c06e17a..167be12f19 100644 --- a/megatron/core/parallel_state.py +++ b/megatron/core/parallel_state.py @@ -20,7 +20,6 @@ # Model parallel group (both intra- and pipeline) that the current rank belongs to. _MODEL_PARALLEL_GROUP = None # Model parallel group (both intra-, pipeline, and expert) that the current rank belongs to. -_MODEL_AND_EXPERT_PARALLEL_GROUP = None # Embedding group. _EMBEDDING_GROUP = None # Position embedding group. @@ -31,14 +30,31 @@ # tensor model parallel group and data parallel group combined # used for fp8 and moe training _TENSOR_AND_DATA_PARALLEL_GROUP = None -# Expert parallel group that the current rank belongs to. -_EXPERT_MODEL_PARALLEL_GROUP = None -_TENSOR_AND_EXPERT_PARALLEL_GROUP = None -_DATA_MODULO_EXPERT_PARALLEL_GROUP = None -_DATA_MODULO_EXPERT_PARALLEL_GROUP_GLOO = None -_DATA_MODULO_EXPERT_PARALLEL_GROUP_WITH_CP = None -_DATA_MODULO_EXPERT_PARALLEL_GROUP_WITH_CP_GLOO = None +### Expert-related parallel states +# Naming convention: +# _EXPERT prefix in group name means it's used for expert layer in MoE models. +# _EXPERT_MODEL denotes expert parallelism which splits number of experts across the group. +# _EXPERT_TENSOR denotes tensor parallelism of expert which splits tensor across the group. +# _EXPERT_DATA denotes data parallelism of expert which replicates weight across the group. + +# Expert model parallel group that current rank belongs to. +_EXPERT_MODEL_PARALLEL_GROUP = None +# Expert tensor parallel group that current rank belongs to. 
+_EXPERT_TENSOR_PARALLEL_GROUP = None +# Expert tensor and model combined parallel group +_EXPERT_TENSOR_AND_MODEL_PARALLEL_GROUP = None +# Expert tensor, model, pipeline combined parallel group +_EXPERT_TENSOR_MODEL_PIPELINE_PARALLEL_GROUP = None +# Expert data parallel group +_EXPERT_DATA_PARALLEL_GROUP = None +_EXPERT_DATA_PARALLEL_GROUP_GLOO = None +# Parallel state values changed on the fly +_MPU_EXPERT_MODEL_PARALLEL_WORLD_SIZE = None +_MPU_EXPERT_MODEL_PARALLEL_RANK = None +_MPU_EXPERT_TENSOR_PARALLEL_WORLD_SIZE = None +_MPU_EXPERT_TENSOR_PARALLEL_RANK = None +### End of expert related parallel states _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = None _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None @@ -49,12 +65,10 @@ # These values enable us to change the mpu sizes on the fly. _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = None _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None -_MPU_EXPERT_MODEL_PARALLEL_WORLD_SIZE = None _MPU_DATA_PARALLEL_WORLD_SIZE = None _MPU_DATA_PARALLEL_RANK = None _MPU_TENSOR_MODEL_PARALLEL_RANK = None _MPU_PIPELINE_MODEL_PARALLEL_RANK = None -_MPU_EXPERT_MODEL_PARALLEL_RANK = None # A list of ranks that have a copy of the embedding. _EMBEDDING_GLOBAL_RANKS = None @@ -183,15 +197,15 @@ def inner_product(a: List[int], b: List[int]) -> int: return sum([x * y for x, y in zip(a, b)]) def decompose(index, shape, stride=None): -        ''' +        """ This function solve the math problem below: There is an equation: index = sum(idx[i] * stride[i]) And given the value of index, stride. Return the idx. -        This function will used to get the pp/dp/pp_rank +        This function will be used to get the pp/dp/pp_rank from group_index and rank_in_group. -        ''' +        """ if stride is None: stride = prefix_product(shape) idx = [(index // d) % s for s, d in zip(shape, stride)] @@ -268,13 +282,18 @@ class RankGenerator(object): def __init__( self, tp: int, ep: int, dp: int, pp: int, cp: int, order: str, rank_offset: int = 0 ) -> None: +        assert ( +            ep == 1 or cp == 1 +        ), "Both EP and CP > 1 is not allowed in one rank generator. \ +            CP is only included in default RankGenerator, and EP only in expert RankGenerator." +        self.tp = tp self.ep = ep self.dp = dp self.pp = pp self.cp = cp self.rank_offset = rank_offset -        self.world_size = tp * dp * pp * cp +        self.world_size = tp * dp * pp * cp * ep self.name_to_size = { "tp": self.tp, @@ -286,10 +305,6 @@ def __init__( self.order = order order = order.lower() -        if 'ep' in order: -            if 'ep-dp' not in order and 'dp-ep' not in order: -                raise RuntimeError(f"The ep and dp must be adjacent in order ({self.order}).") - for name in self.name_to_size.keys(): if name not in order and self.name_to_size[name] != 1: raise RuntimeError( @@ -299,20 +314,11 @@ def __init__( elif name not in order: order = order + '-' + name -        self.order_w_ep = order -        self.order_wo_ep = '-'.join([token for token in order.split('-') if token != 'ep']) -        self.ordered_size_wo_ep = [] -        self.ordered_size_w_ep = [] +        self.order = order +        self.ordered_size = [] for token in order.split('-'): -            if token == 'dp': -                self.ordered_size_w_ep.append(self.dp // self.ep) -                self.ordered_size_wo_ep.append(self.dp) -            elif token == 'ep': -                self.ordered_size_w_ep.append(self.ep) -            else: -                self.ordered_size_w_ep.append(self.name_to_size[token]) -                self.ordered_size_wo_ep.append(self.name_to_size[token]) +            self.ordered_size.append(self.name_to_size[token]) def get_mask(self, order: str, token: str): """Create a mask for the specified tokens based on the given order.
@@ -329,7 +335,7 @@ def get_mask(self, order: str, token: str): mask[ordered_token.index(t)] = True return mask - def get_ranks(self, token, independent_ep=False): + def get_ranks(self, token): """Get rank group by input token. Args: @@ -338,22 +344,9 @@ def get_ranks(self, token, independent_ep=False): to obtain multiple parallel types, we can use a hyphen '-' to separate them. For example, if we want to obtain the TP_DP group, the token should be 'tp-dp'. - - independent_ep (bool: True): - This flag controls whether we treat EP and DP independently. - EP shares ranks with DP, if we want to get ranks related to - EP, we should set the flag. For example, get_ranks('dp', True) - will get DP modulo EP group, and get_ranks('dp', False) will - get full DP group. """ - if independent_ep: - parallel_size = self.ordered_size_w_ep - order = self.order_w_ep - else: - parallel_size = self.ordered_size_wo_ep - order = self.order_wo_ep - mask = self.get_mask(order, token) - ranks = generate_masked_orthogonal_rank_groups(self.world_size, parallel_size, mask) + mask = self.get_mask(self.order, token) + ranks = generate_masked_orthogonal_rank_groups(self.world_size, self.ordered_size, mask) if self.rank_offset > 0: for rank_group in ranks: for i in range(len(rank_group)): @@ -394,6 +387,7 @@ def initialize_model_parallel( context_parallel_size: int = 1, hierarchical_context_parallel_sizes: Optional[List[int]] = None, expert_model_parallel_size: int = 1, + expert_tensor_parallel_size: Optional[int] = None, nccl_communicator_config_path: Optional[str] = None, distributed_timeout_minutes: int = 30, order: str = "tp-cp-ep-dp-pp", @@ -475,6 +469,9 @@ def initialize_model_parallel( The number of Mixture of Experts parallel GPUs in each expert parallel group. + expert_tensor_parallel_size (int, default = tp_size): + The number of GPUs to split individual tensors of expert. + nccl_communicator_config_path (str, default = None): Path to the yaml file of NCCL communicator configurations. 
`min_ctas`, `max_ctas`, and `cga_cluster_size` can be set @@ -569,12 +566,6 @@ def initialize_model_parallel( data_parallel_size: int = world_size // total_model_size - if data_parallel_size % expert_model_parallel_size != 0: - raise RuntimeError( - f"data_parallel_size ({data_parallel_size}) is not divisible by " - "expert_model_parallel_size " - ) - encoder_world_size = encoder_model_size * data_parallel_size decoder_world_size = decoder_model_size * data_parallel_size @@ -626,7 +617,7 @@ def initialize_model_parallel( decoder_rank_generator = RankGenerator( tp=tensor_model_parallel_size, - ep=expert_model_parallel_size, + ep=1, dp=data_parallel_size, pp=pipeline_model_parallel_size, cp=context_parallel_size, @@ -634,13 +625,45 @@ def initialize_model_parallel( rank_offset=encoder_world_size, ) - def generator_wrapper(group_type, **kwargs): + # Build expert rank generator + if expert_tensor_parallel_size is None: + expert_tensor_parallel_size = tensor_model_parallel_size + expert_tensor_model_pipeline_parallel_size = ( + expert_tensor_parallel_size * expert_model_parallel_size * pipeline_model_parallel_size + ) + expert_data_parallel_size = decoder_world_size // expert_tensor_model_pipeline_parallel_size + if decoder_world_size % expert_tensor_model_pipeline_parallel_size != 0: + raise RuntimeError( + f"decoder world_size ({decoder_world_size}) is not divisible by expert_tensor_model_pipeline_parallel size ({expert_tensor_model_pipeline_parallel_size})" + ) + + # TODO: support expert specific ordering + expert_decoder_rank_generator = RankGenerator( + tp=expert_tensor_parallel_size, + ep=expert_model_parallel_size, + dp=expert_data_parallel_size, + pp=pipeline_model_parallel_size, + cp=1, + order=order, + rank_offset=encoder_world_size, + ) + + assert decoder_rank_generator.get_ranks("pp") == expert_decoder_rank_generator.get_ranks( + "pp" + ), f"Pipeline parallel groups are expected to be the same for Non-Expert and Expert part, \ + but got {decoder_rank_generator.get_ranks('pp')} and {expert_decoder_rank_generator.get_ranks('pp')}" + + def generator_wrapper(group_type, is_expert=False, **kwargs): """The `RankGenerator` class produces a hyper-rectangle for a given set of tensor, pipeline, data, expert, and context parallelism. If we have an encoder, in addition to the default decoder, we essentially instantiate two `RankGenerator` classes to construct the parallelism for each module separately, and we then have to stitch them together for the right groups. For now, this means pp and tp-pp.""" - d_ranks = decoder_rank_generator.get_ranks(group_type, **kwargs) + if is_expert: + d_ranks = expert_decoder_rank_generator.get_ranks(group_type, **kwargs) + else: + d_ranks = decoder_rank_generator.get_ranks(group_type, **kwargs) + if encoder_rank_generator is None: for x in d_ranks: yield x @@ -747,18 +770,6 @@ def generator_wrapper(group_type, **kwargs): if rank in ranks: _MODEL_PARALLEL_GROUP = group - # Build the model-parallel groups with expert parallel - global _MODEL_AND_EXPERT_PARALLEL_GROUP - assert ( - _MODEL_AND_EXPERT_PARALLEL_GROUP is None - ), 'model and expert parallel group is already initialized' - for ranks in generator_wrapper('tp-ep-pp', independent_ep=True): - group = torch.distributed.new_group( - ranks, timeout=timeout, pg_options=get_nccl_options('mp_exp', nccl_comm_cfgs) - ) - if rank in ranks: - _MODEL_AND_EXPERT_PARALLEL_GROUP = group - # Build the tensor model-parallel groups. 
global _TENSOR_MODEL_PARALLEL_GROUP global _TENSOR_MODEL_PARALLEL_GLOBAL_RANKS @@ -849,62 +860,68 @@ def generator_wrapper(group_type, **kwargs): if rank in ranks: _TENSOR_AND_CONTEXT_PARALLEL_GROUP = group - # Build the tensor + expert parallel groups + ### Expert-related parallel groups initialization + # Build the expert model parallel group global _EXPERT_MODEL_PARALLEL_GROUP assert _EXPERT_MODEL_PARALLEL_GROUP is None, 'Expert parallel group is already initialized' - global _TENSOR_AND_EXPERT_PARALLEL_GROUP - assert ( - _TENSOR_AND_EXPERT_PARALLEL_GROUP is None - ), 'Tensor + expert parallel group is already initialized' - global _DATA_MODULO_EXPERT_PARALLEL_GROUP - assert ( - _DATA_MODULO_EXPERT_PARALLEL_GROUP is None - ), 'Data modulo expert group is already initialized' - global _DATA_MODULO_EXPERT_PARALLEL_GROUP_WITH_CP + for ranks in generator_wrapper('ep', is_expert=True): + group = torch.distributed.new_group( + ranks, pg_options=get_nccl_options('exp', nccl_comm_cfgs) + ) + if rank in ranks: + _EXPERT_MODEL_PARALLEL_GROUP = group + + # Build the expert tensor parallel group + global _EXPERT_TENSOR_PARALLEL_GROUP assert ( - _DATA_MODULO_EXPERT_PARALLEL_GROUP_WITH_CP is None - ), 'Data modulo expert group with context parallel is already initialized' - global _DATA_MODULO_EXPERT_PARALLEL_GROUP_GLOO - global _DATA_MODULO_EXPERT_PARALLEL_GROUP_WITH_CP_GLOO + _EXPERT_TENSOR_PARALLEL_GROUP is None + ), 'Expert tensor model parallel group is already initialized' + for ranks in generator_wrapper('tp', is_expert=True): + group = torch.distributed.new_group( + ranks, timeout=timeout, pg_options=get_nccl_options('tp', nccl_comm_cfgs) + ) + if rank in ranks: + _EXPERT_TENSOR_PARALLEL_GROUP = group - for ranks in generator_wrapper('tp-ep', independent_ep=True): + # Build the tensor + expert parallel groups + global _EXPERT_TENSOR_AND_MODEL_PARALLEL_GROUP + assert ( + _EXPERT_TENSOR_AND_MODEL_PARALLEL_GROUP is None + ), 'Expert tensor + model parallel group is already initialized' + for ranks in generator_wrapper('tp-ep', is_expert=True): group = torch.distributed.new_group( ranks, timeout=timeout, pg_options=get_nccl_options('tp_exp', nccl_comm_cfgs) ) if rank in ranks: - _TENSOR_AND_EXPERT_PARALLEL_GROUP = group + _EXPERT_TENSOR_AND_MODEL_PARALLEL_GROUP = group - for ranks in generator_wrapper('ep', independent_ep=True): + # Build the expert+tensor+pipeline parallel groups + global _EXPERT_TENSOR_MODEL_PIPELINE_PARALLEL_GROUP + assert ( + _EXPERT_TENSOR_MODEL_PIPELINE_PARALLEL_GROUP is None + ), 'The expert_tensor_model_pipeline parallel group is already initialized' + for ranks in generator_wrapper('tp-ep-pp', is_expert=True): group = torch.distributed.new_group( - ranks, pg_options=get_nccl_options('exp', nccl_comm_cfgs) + ranks, timeout=timeout, pg_options=get_nccl_options('mp', nccl_comm_cfgs) ) if rank in ranks: - _EXPERT_MODEL_PARALLEL_GROUP = group + _EXPERT_TENSOR_MODEL_PIPELINE_PARALLEL_GROUP = group + + # Build the expert data parallel group + global _EXPERT_DATA_PARALLEL_GROUP + assert _EXPERT_DATA_PARALLEL_GROUP is None, 'Expert data group is already initialized' + global _EXPERT_DATA_PARALLEL_GROUP_GLOO + assert _EXPERT_DATA_PARALLEL_GROUP_GLOO is None, 'Expert data group-gloo is already initialized' - for ranks in generator_wrapper('dp', independent_ep=True): + for ranks in generator_wrapper('dp', is_expert=True): group = torch.distributed.new_group( - ranks, timeout=timeout, pg_options=get_nccl_options('dp_modulo_exp', nccl_comm_cfgs) + ranks, timeout=timeout, 
pg_options=get_nccl_options('dp', nccl_comm_cfgs) ) group_gloo = torch.distributed.new_group(ranks, backend="gloo") if rank in ranks: - _DATA_MODULO_EXPERT_PARALLEL_GROUP = group - _DATA_MODULO_EXPERT_PARALLEL_GROUP_GLOO = group_gloo - - for ranks in generator_wrapper('dp-cp', independent_ep=True): - # Lazy initialization of the group - if get_context_parallel_world_size() > 1: - group = torch.distributed.new_group( - ranks, - timeout=timeout, - pg_options=get_nccl_options('dp_modulo_exp_cp', nccl_comm_cfgs), - ) - group_gloo = torch.distributed.new_group(ranks, backend="gloo") - else: - group = _DATA_MODULO_EXPERT_PARALLEL_GROUP - group_gloo = _DATA_MODULO_EXPERT_PARALLEL_GROUP_GLOO - if rank in ranks: - _DATA_MODULO_EXPERT_PARALLEL_GROUP_WITH_CP = group - _DATA_MODULO_EXPERT_PARALLEL_GROUP_WITH_CP_GLOO = group_gloo + _EXPERT_DATA_PARALLEL_GROUP = group + _EXPERT_DATA_PARALLEL_GROUP_GLOO = group_gloo + ### End of expert related parallel groups initialization # Initialize global memory buffer # This isn't really "parallel state" but there isn't another good place to @@ -939,13 +956,8 @@ def model_parallel_is_initialized(): return True -def get_model_parallel_group(with_expert_parallel=False): +def get_model_parallel_group(): """Get the model-parallel group the caller rank belongs to.""" - if with_expert_parallel: - assert ( - _MODEL_AND_EXPERT_PARALLEL_GROUP is not None - ), 'model parallel group is not initialized' - return _MODEL_AND_EXPERT_PARALLEL_GROUP assert _MODEL_PARALLEL_GROUP is not None, 'model parallel group is not initialized' return _MODEL_PARALLEL_GROUP @@ -1074,56 +1086,6 @@ def get_tensor_and_context_parallel_group(): return _TENSOR_AND_CONTEXT_PARALLEL_GROUP -def get_expert_model_parallel_group(): - """Get the expert-model-parallel group the caller rank belongs to.""" - assert ( - _EXPERT_MODEL_PARALLEL_GROUP is not None - ), 'expert model parallel group is not initialized' - return _EXPERT_MODEL_PARALLEL_GROUP - - -def get_tensor_and_expert_parallel_group(): - """Get the tensor- and expert-parallel group the caller rank belongs to.""" - assert ( - _TENSOR_AND_EXPERT_PARALLEL_GROUP is not None - ), 'tensor and expert parallel group is not initialized' - return _TENSOR_AND_EXPERT_PARALLEL_GROUP - - -def get_data_modulo_expert_parallel_group(with_context_parallel=False): - """Get the data-modulo-expert-parallel group the caller rank belongs to.""" - if with_context_parallel: - assert ( - _DATA_MODULO_EXPERT_PARALLEL_GROUP_WITH_CP is not None - ), 'data modulo expert parallel group with context parallel is not initialized' - return _DATA_MODULO_EXPERT_PARALLEL_GROUP_WITH_CP - else: - assert ( - _DATA_MODULO_EXPERT_PARALLEL_GROUP is not None - ), 'data modulo expert parallel group is not initialized' - return _DATA_MODULO_EXPERT_PARALLEL_GROUP - - -def get_data_modulo_expert_parallel_group_gloo(with_context_parallel=False): - """Get the Gloo data-modulo-expert-parallel group the caller rank belongs to.""" - if with_context_parallel: - assert ( - _DATA_MODULO_EXPERT_PARALLEL_GROUP_WITH_CP_GLOO is not None - ), 'data modulo expert parallel group-gloo with context parallel is not initialized' - return _DATA_MODULO_EXPERT_PARALLEL_GROUP_WITH_CP_GLOO - else: - assert ( - _DATA_MODULO_EXPERT_PARALLEL_GROUP_GLOO is not None - ), 'data modulo expert parallel group-gloo is not initialized' - return _DATA_MODULO_EXPERT_PARALLEL_GROUP_GLOO - - -def set_expert_model_parallel_world_size(world_size): - """Sets the expert-model-parallel world size.""" - global 
_MPU_EXPERT_MODEL_PARALLEL_WORLD_SIZE - _MPU_EXPERT_MODEL_PARALLEL_WORLD_SIZE = world_size - - def set_tensor_model_parallel_world_size(world_size): """Set the tensor-model-parallel size""" global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE @@ -1168,12 +1130,6 @@ def get_pipeline_model_parallel_world_size(): return torch.distributed.get_world_size(group=pp_group) -def set_expert_model_parallel_rank(rank): - """Set expert-model-parallel rank.""" - global _MPU_EXPERT_MODEL_PARALLEL_RANK - _MPU_EXPERT_MODEL_PARALLEL_RANK = rank - - def set_tensor_model_parallel_rank(rank): """Set tensor-model-parallel rank.""" global _MPU_TENSOR_MODEL_PARALLEL_RANK @@ -1518,30 +1474,30 @@ def get_tensor_and_context_parallel_rank(): return 0 +### Expert-related parallel states functions +def get_expert_model_parallel_group(check_initialized=True): + """Get the expert-model-parallel group the caller rank belongs to.""" + if check_initialized: + assert ( + _EXPERT_MODEL_PARALLEL_GROUP is not None + ), 'expert model parallel group is not initialized' + return _EXPERT_MODEL_PARALLEL_GROUP + + def get_expert_model_parallel_world_size(): """Return world size for the expert-model-parallel group.""" if _MPU_EXPERT_MODEL_PARALLEL_WORLD_SIZE is not None: return _MPU_EXPERT_MODEL_PARALLEL_WORLD_SIZE if torch.distributed.is_available() and torch.distributed.is_initialized(): - tensor_and_expert_parallel_world_size = torch.distributed.get_world_size( - group=get_tensor_and_expert_parallel_group() - ) - return tensor_and_expert_parallel_world_size // get_tensor_model_parallel_world_size() + return torch.distributed.get_world_size(group=get_expert_model_parallel_group()) else: return 0 -def get_tensor_and_expert_parallel_world_size(): - """Return world size for the expert model parallel group times model parallel group. - Currently, each expert will also be distributed across TP group by default. 
- """ - if torch.distributed.is_available() and torch.distributed.is_initialized(): - tensor_and_expert_parallel_world_size = torch.distributed.get_world_size( - group=get_tensor_and_expert_parallel_group() - ) - return tensor_and_expert_parallel_world_size - else: - return 0 +def set_expert_model_parallel_world_size(world_size): + """Sets the expert-model-parallel world size.""" + global _MPU_EXPERT_MODEL_PARALLEL_WORLD_SIZE + _MPU_EXPERT_MODEL_PARALLEL_WORLD_SIZE = world_size def get_expert_model_parallel_rank(): @@ -1549,32 +1505,118 @@ def get_expert_model_parallel_rank(): if _MPU_EXPERT_MODEL_PARALLEL_RANK is not None: return _MPU_EXPERT_MODEL_PARALLEL_RANK if torch.distributed.is_available() and torch.distributed.is_initialized(): - tensor_and_expert_parallel_rank = torch.distributed.get_rank( - group=get_tensor_and_expert_parallel_group() - ) - return tensor_and_expert_parallel_rank // get_tensor_model_parallel_world_size() + return torch.distributed.get_rank(group=get_expert_model_parallel_group()) else: return 0 -def get_data_modulo_expert_parallel_rank(with_context_parallel=False): - """Return caller's rank in the context-parallel group.""" +def set_expert_model_parallel_rank(rank): + """Set expert-model-parallel rank.""" + global _MPU_EXPERT_MODEL_PARALLEL_RANK + _MPU_EXPERT_MODEL_PARALLEL_RANK = rank + + +def get_expert_tensor_parallel_group(check_initialized=True): + if check_initialized: + assert ( + _EXPERT_TENSOR_PARALLEL_GROUP is not None + ), 'Expert tensor parallel group is not initialized' + return _EXPERT_TENSOR_PARALLEL_GROUP + + +def get_expert_tensor_parallel_world_size(): + """Return world size for the expert tensor parallel group.""" + global _MPU_EXPERT_TENSOR_PARALLEL_WORLD_SIZE + if _MPU_EXPERT_TENSOR_PARALLEL_WORLD_SIZE is not None: + return _MPU_EXPERT_TENSOR_PARALLEL_WORLD_SIZE + # Use tensor parallel group world size for backward compability otherwise + if not _EXPERT_TENSOR_PARALLEL_GROUP: + return _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE + else: + return torch.distributed.get_world_size(group=get_expert_tensor_parallel_group()) + + +def set_expert_tensor_parallel_world_size(world_size): + "Set expert tensor model parallel size" + global _MPU_EXPERT_TENSOR_PARALLEL_WORLD_SIZE + _MPU_EXPERT_TENSOR_PARALLEL_WORLD_SIZE = world_size + + +def get_expert_tensor_parallel_rank(): + """Return my rank for the expert tensor parallel group.""" + global _MPU_EXPERT_TENSOR_PARALLEL_RANK + if _MPU_EXPERT_TENSOR_PARALLEL_RANK is not None: + return _MPU_EXPERT_TENSOR_PARALLEL_RANK + # Use tensor parallel group rank for backward compability otherwise + if not _EXPERT_TENSOR_PARALLEL_GROUP: + return _MPU_TENSOR_MODEL_PARALLEL_RANK + else: + return torch.distributed.get_rank(group=get_expert_tensor_parallel_group()) + + +def set_expert_tensor_parallel_rank(rank): + "Set expert tensor model parallel rank" + global _MPU_EXPERT_TENSOR_PARALLEL_RANK + _MPU_EXPERT_TENSOR_PARALLEL_RANK = rank + + +def get_expert_tensor_and_model_parallel_group(check_initialized=True): + """Get the tensor- and expert-parallel group the caller rank belongs to.""" + if check_initialized: + assert ( + _EXPERT_TENSOR_AND_MODEL_PARALLEL_GROUP is not None + ), 'Expert tensor and model parallel group is not initialized' + return _EXPERT_TENSOR_AND_MODEL_PARALLEL_GROUP + + +def get_expert_tensor_and_model_parallel_world_size(): + """Return world size for the expert model parallel group times expert tensor parallel group.""" if torch.distributed.is_available() and torch.distributed.is_initialized(): - return 
torch.distributed.get_rank( - group=get_data_modulo_expert_parallel_group(with_context_parallel=with_context_parallel) + world_size = torch.distributed.get_world_size( + group=get_expert_tensor_and_model_parallel_group() ) + return world_size else: return 0 -def get_tensor_and_expert_parallel_rank(): +def get_expert_tensor_and_model_parallel_rank(): """Return caller's rank in the joint tensor- and expert-model-parallel group.""" if torch.distributed.is_available() and torch.distributed.is_initialized(): - return torch.distributed.get_rank(group=get_tensor_and_expert_parallel_group()) + return torch.distributed.get_rank(group=get_expert_tensor_and_model_parallel_group()) else: return 0 +def get_expert_tensor_model_pipeline_parallel_group(): + assert ( + _EXPERT_TENSOR_MODEL_PIPELINE_PARALLEL_GROUP is not None + ), 'Expert tensor-model-pipeline parallel group is not initialized' + return _EXPERT_TENSOR_MODEL_PIPELINE_PARALLEL_GROUP + + +def get_expert_data_parallel_group(): + assert _EXPERT_DATA_PARALLEL_GROUP is not None, 'Expert data parallel group is not initialized' + return _EXPERT_DATA_PARALLEL_GROUP + + +def get_expert_data_parallel_group_gloo(): + assert ( + _EXPERT_DATA_PARALLEL_GROUP_GLOO is not None + ), 'Expert data parallel group-gloo is not initialized' + return _EXPERT_DATA_PARALLEL_GROUP_GLOO + + +def get_expert_data_parallel_rank(): + if torch.distributed.is_available() and torch.distributed.is_initialized(): + return torch.distributed.get_rank(group=get_expert_data_parallel_group()) + else: + return 0 + + +### End of expert-related functions region + + def _set_global_memory_buffer(): """Initialize global buffer.""" global _GLOBAL_MEMORY_BUFFER @@ -1618,9 +1660,6 @@ def destroy_model_parallel(): global _MODEL_PARALLEL_GROUP _MODEL_PARALLEL_GROUP = None - global _MODEL_AND_EXPERT_PARALLEL_GROUP - _MODEL_AND_EXPERT_PARALLEL_GROUP = None - global _TENSOR_MODEL_PARALLEL_GROUP _TENSOR_MODEL_PARALLEL_GROUP = None @@ -1657,18 +1696,6 @@ def destroy_model_parallel(): global _TENSOR_AND_CONTEXT_PARALLEL_GROUP _TENSOR_AND_CONTEXT_PARALLEL_GROUP = None - global _EXPERT_MODEL_PARALLEL_GROUP - _EXPERT_MODEL_PARALLEL_GROUP = None - - global _TENSOR_AND_EXPERT_PARALLEL_GROUP - _TENSOR_AND_EXPERT_PARALLEL_GROUP = None - - global _DATA_MODULO_EXPERT_PARALLEL_GROUP - _DATA_MODULO_EXPERT_PARALLEL_GROUP = None - - global _DATA_MODULO_EXPERT_PARALLEL_GROUP_WITH_CP - _DATA_MODULO_EXPERT_PARALLEL_GROUP_WITH_CP = None - global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = None @@ -1690,27 +1717,49 @@ def destroy_model_parallel(): global _GLOBAL_MEMORY_BUFFER _GLOBAL_MEMORY_BUFFER = None - global _MPU_EXPERT_MODEL_PARALLEL_WORLD_SIZE - _MPU_EXPERT_MODEL_PARALLEL_WORLD_SIZE = None - - global _MPU_EXPERT_MODEL_PARALLEL_RANK - _MPU_EXPERT_MODEL_PARALLEL_RANK = None - global _DATA_PARALLEL_GROUP_GLOO if _DATA_PARALLEL_GROUP_GLOO is not None: torch.distributed.destroy_process_group(_DATA_PARALLEL_GROUP_GLOO) _DATA_PARALLEL_GROUP_GLOO = None global _DATA_PARALLEL_GROUP_WITH_CP_GLOO + if _DATA_PARALLEL_GROUP_WITH_CP_GLOO is not None: + torch.distributed.destroy_process_group(_DATA_PARALLEL_GROUP_WITH_CP_GLOO) _DATA_PARALLEL_GROUP_WITH_CP_GLOO = None - global _DATA_MODULO_EXPERT_PARALLEL_GROUP_GLOO - if _DATA_MODULO_EXPERT_PARALLEL_GROUP_GLOO is not None: - torch.distributed.destroy_process_group(_DATA_MODULO_EXPERT_PARALLEL_GROUP_GLOO) - _DATA_MODULO_EXPERT_PARALLEL_GROUP_GLOO = None + ### Expert-related parallel states destory + global _EXPERT_MODEL_PARALLEL_GROUP + 
_EXPERT_MODEL_PARALLEL_GROUP = None + + global _MPU_EXPERT_MODEL_PARALLEL_WORLD_SIZE + _MPU_EXPERT_MODEL_PARALLEL_WORLD_SIZE = None + + global _MPU_EXPERT_MODEL_PARALLEL_RANK + _MPU_EXPERT_MODEL_PARALLEL_RANK = None + + global _EXPERT_TENSOR_PARALLEL_GROUP + _EXPERT_TENSOR_PARALLEL_GROUP = None + + global _MPU_EXPERT_TENSOR_PARALLEL_WORLD_SIZE + _MPU_EXPERT_TENSOR_PARALLEL_WORLD_SIZE = None + + global _MPU_EXPERT_TENSOR_PARALLEL_RANK + _MPU_EXPERT_TENSOR_PARALLEL_RANK = None + + global _EXPERT_TENSOR_AND_MODEL_PARALLEL_GROUP + _EXPERT_TENSOR_AND_MODEL_PARALLEL_GROUP = None + + global _EXPERT_TENSOR_MODEL_PIPELINE_PARALLEL_GROUP + _EXPERT_TENSOR_MODEL_PIPELINE_PARALLEL_GROUP = None + + global _EXPERT_DATA_PARALLEL_GROUP + _EXPERT_DATA_PARALLEL_GROUP = None - global _DATA_MODULO_EXPERT_PARALLEL_GROUP_WITH_CP_GLOO - _DATA_MODULO_EXPERT_PARALLEL_GROUP_WITH_CP_GLOO = None + global _EXPERT_DATA_PARALLEL_GROUP_GLOO + if _EXPERT_DATA_PARALLEL_GROUP_GLOO is not None: + torch.distributed.destroy_process_group(_EXPERT_DATA_PARALLEL_GROUP_GLOO) + _EXPERT_DATA_PARALLEL_GROUP_GLOO = None + ### End of expert-related parallel states destory global _MOE_LAYER_WISE_LOGGING_TRACKER _MOE_LAYER_WISE_LOGGING_TRACKER = {} diff --git a/megatron/core/tensor_parallel/__init__.py b/megatron/core/tensor_parallel/__init__.py index 41d87431fe..00bfe4f452 100644 --- a/megatron/core/tensor_parallel/__init__.py +++ b/megatron/core/tensor_parallel/__init__.py @@ -18,12 +18,10 @@ all_to_all_sp2hp, copy_to_tensor_model_parallel_region, gather_from_sequence_parallel_region, - gather_from_sequence_parallel_region_to_moe, gather_from_tensor_model_parallel_region, reduce_from_tensor_model_parallel_region, reduce_scatter_last_dim_to_tensor_parallel_region, reduce_scatter_to_sequence_parallel_region, - reduce_scatter_to_sequence_parallel_region_from_moe, scatter_to_sequence_parallel_region, scatter_to_tensor_model_parallel_region, ) @@ -71,6 +69,4 @@ "split_tensor_along_last_dim", "split_tensor_into_1d_equal_chunks", "gather_split_1d_tensor", - "gather_from_sequence_parallel_region_to_moe", - "reduce_scatter_to_sequence_parallel_region_from_moe", ] diff --git a/megatron/core/tensor_parallel/layers.py b/megatron/core/tensor_parallel/layers.py index 12d2be69a9..fde8c106f1 100644 --- a/megatron/core/tensor_parallel/layers.py +++ b/megatron/core/tensor_parallel/layers.py @@ -14,9 +14,9 @@ from megatron.core.model_parallel_config import ModelParallelConfig from megatron.core.parallel_state import ( + get_expert_tensor_parallel_rank, + get_expert_tensor_parallel_world_size, get_global_memory_buffer, - get_tensor_and_expert_parallel_rank, - get_tensor_and_expert_parallel_world_size, get_tensor_model_parallel_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, @@ -107,16 +107,14 @@ def maybe_copy(attribute): maybe_copy(attribute) -def _initialize_affine_weight_gpu( - weight, init_method, partition_dim, stride=1, expert_parallel=False -): +def _initialize_affine_weight_gpu(weight, init_method, partition_dim, stride=1, is_expert=False): """Initialize affine weight for model parallel on GPU.""" set_tensor_model_parallel_attributes( tensor=weight, is_parallel=True, dim=partition_dim, stride=stride ) - if not expert_parallel: + if not is_expert: with get_cuda_rng_tracker().fork(): init_method(weight) else: @@ -756,15 +754,13 @@ def __init__( self.config = config self.disable_grad_reduce = disable_grad_reduce - self.explicit_expert_comm = self.is_expert and ( - config.tensor_model_parallel_size > 1 or 
self.expert_parallel - ) - if self.explicit_expert_comm and config.moe_extended_tp: - world_size = get_tensor_and_expert_parallel_world_size() - rank = get_tensor_and_expert_parallel_rank() + if is_expert: + world_size = get_expert_tensor_parallel_world_size() + rank = get_expert_tensor_parallel_rank() else: world_size = get_tensor_model_parallel_world_size() rank = get_tensor_model_parallel_rank() + self.explicit_expert_comm = self.is_expert and (world_size > 1 or self.expert_parallel) self.output_size_per_partition = divide(output_size, world_size) @@ -807,7 +803,7 @@ def __init__( init_method, partition_dim=0, stride=stride, - expert_parallel=(self.is_expert and self.expert_parallel), + is_expert=self.is_expert, ) setattr(self.weight, 'allreduce', not (self.is_expert and self.expert_parallel)) @@ -1056,17 +1052,14 @@ def __init__( if self.sequence_parallel and not self.input_is_parallel: raise RuntimeError("To enable `sequence_parallel`, `input_is_parallel` must be `True`") - self.explicit_expert_comm = self.is_expert and ( - config.tensor_model_parallel_size > 1 or self.expert_parallel - ) - # Divide the weight matrix along the last dimension. - if self.explicit_expert_comm and config.moe_extended_tp: - world_size = get_tensor_and_expert_parallel_world_size() - rank = get_tensor_and_expert_parallel_rank() + if self.is_expert: + world_size = get_expert_tensor_parallel_world_size() + rank = get_expert_tensor_parallel_rank() else: world_size = get_tensor_model_parallel_world_size() rank = get_tensor_model_parallel_rank() + self.explicit_expert_comm = self.is_expert and (world_size > 1 or self.expert_parallel) self.input_size_per_partition = divide(input_size, world_size) @@ -1109,7 +1102,7 @@ def __init__( init_method, partition_dim=1, stride=stride, - expert_parallel=(self.is_expert and self.expert_parallel), + is_expert=self.is_expert, ) setattr(self.weight, 'allreduce', not (self.is_expert and self.expert_parallel)) diff --git a/megatron/core/tensor_parallel/mappings.py b/megatron/core/tensor_parallel/mappings.py index 3d541d2f02..cdd7206871 100644 --- a/megatron/core/tensor_parallel/mappings.py +++ b/megatron/core/tensor_parallel/mappings.py @@ -3,9 +3,7 @@ import torch from megatron.core.parallel_state import ( - get_expert_model_parallel_group, get_global_memory_buffer, - get_tensor_and_expert_parallel_group, get_tensor_model_parallel_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, @@ -54,11 +52,12 @@ def _split_along_last_dim(input_): return output -def _split_along_first_dim(input_): +def _split_along_first_dim(input_, group=None): """Split the tensor along its first dimension and keep the corresponding slice.""" - - world_size = get_tensor_model_parallel_world_size() + if group is None: + group = get_tensor_model_parallel_group() + world_size = torch.distributed.get_world_size(group) # Bypass the function if we are using only 1 GPU. 
if world_size == 1: return input_ @@ -69,7 +68,7 @@ def _split_along_first_dim(input_): dim_size % world_size == 0 ), "First dimension of the tensor should be divisible by tensor parallel size" local_dim_size = dim_size // world_size - rank = get_tensor_model_parallel_rank() + rank = torch.distributed.get_rank(group) dim_offset = rank * local_dim_size output = input_[dim_offset : dim_offset + local_dim_size].contiguous() @@ -112,7 +111,7 @@ def _reduce_scatter_along_last_dim(input_): return output -def _gather_along_first_dim(input_, output_split_sizes=None): +def _gather_along_first_dim(input_, group=None, output_split_sizes=None, use_global_buffer=False): """Gather tensors and concatenate along the first dimension. Args: @@ -126,7 +125,9 @@ def _gather_along_first_dim(input_, output_split_sizes=None): torch.Tensor: Gathered tensor. """ - world_size = get_tensor_model_parallel_world_size() + if group is None: + group = get_tensor_model_parallel_group() + world_size = torch.distributed.get_world_size(group) # Bypass the function if we are using only 1 GPU. if world_size == 1: return input_ @@ -135,20 +136,26 @@ def _gather_along_first_dim(input_, output_split_sizes=None): if output_split_sizes is None: dim_size[0] = dim_size[0] * world_size - output = torch.empty(dim_size, dtype=input_.dtype, device=torch.cuda.current_device()) - dist_all_gather_func(output, input_.contiguous(), group=get_tensor_model_parallel_group()) + if use_global_buffer: + output = get_global_memory_buffer().get_tensor(dim_size, input_.dtype, "mpu") + else: + output = torch.empty(dim_size, dtype=input_.dtype, device=torch.cuda.current_device()) + dist_all_gather_func(output, input_.contiguous(), group=group) else: dim_size[0] = sum(output_split_sizes) - output = torch.empty(dim_size, dtype=input_.dtype, device=torch.cuda.current_device()) + if use_global_buffer: + output = get_global_memory_buffer().get_tensor(dim_size, input_.dtype, "mpu") + else: + output = torch.empty(dim_size, dtype=input_.dtype, device=torch.cuda.current_device()) output_tensor_list = list(torch.split(output, output_split_sizes, dim=0)) - torch.distributed.all_gather( - output_tensor_list, input_, group=get_tensor_model_parallel_group() - ) + torch.distributed.all_gather(output_tensor_list, input_, group=group) return output -def _reduce_scatter_along_first_dim(input_, input_split_sizes=None): +def _reduce_scatter_along_first_dim( + input_, group=None, input_split_sizes=None, use_global_buffer=False +): """Reduce-scatter the input tensor across model parallel group. Args: @@ -157,7 +164,9 @@ def _reduce_scatter_along_first_dim(input_, input_split_sizes=None): the input splits along the first dimension for each rank. If None, equal splitting is assumed. Default: None. """ - world_size = get_tensor_model_parallel_world_size() + if group is None: + group = get_tensor_model_parallel_group() + world_size = torch.distributed.get_world_size(group) # Bypass the function if we are using only 1 GPU. 
if world_size == 1: return input_ @@ -170,74 +179,22 @@ def _reduce_scatter_along_first_dim(input_, input_split_sizes=None): dim_size[0] = dim_size[0] // world_size - output = torch.empty(dim_size, dtype=input_.dtype, device=torch.cuda.current_device()) - dist_reduce_scatter_func( - output, input_.contiguous(), group=get_tensor_model_parallel_group() - ) + if use_global_buffer: + output = get_global_memory_buffer().get_tensor(dim_size, input_.dtype, "mpu") + else: + output = torch.empty(dim_size, dtype=input_.dtype, device=torch.cuda.current_device()) + dist_reduce_scatter_func(output, input_.contiguous(), group=group) else: - rank = torch.distributed.get_rank(get_tensor_model_parallel_group()) + rank = torch.distributed.get_rank(group) input_tensor_list = list(torch.split(input_, input_split_sizes, dim=0)) - output = torch.empty_like(input_tensor_list[rank]) - torch.distributed.reduce_scatter( - output, input_tensor_list, group=get_tensor_model_parallel_group() - ) - return output - - -def _gather_along_first_dim_moe(input_, use_global_buffer=False): - """Gather tensors and concatenate along the first dimension.""" - group = get_tensor_and_expert_parallel_group() - world_size = torch.distributed.get_world_size(group=group) - # Bypass the function if we are using only 1 GPU. - if world_size == 1: - return input_ - - dim_size = list(input_.size()) - dim_size[0] = dim_size[0] * world_size - - if use_global_buffer: - output = get_global_memory_buffer().get_tensor(dim_size, input_.dtype, "mpu") - else: - output = torch.empty(dim_size, dtype=input_.dtype, device=torch.cuda.current_device()) - dist_all_gather_func(output, input_.contiguous(), group=group) - - return output - - -def _reduce_scatter_along_first_dim_moe(input_, use_global_buffer=False): - """Reduce-scatter the input tensor across model parallel group.""" - group = get_tensor_and_expert_parallel_group() - world_size = torch.distributed.get_world_size(group=group) - # Bypass the function if we are using only 1 GPU. - if world_size == 1: - return input_ - - dim_size = list(input_.size()) - assert dim_size[0] % world_size == 0 - dim_size[0] = dim_size[0] // world_size - - if use_global_buffer: - output = get_global_memory_buffer().get_tensor(dim_size, input_.dtype, "mpu") - else: - output = torch.empty(dim_size, dtype=input_.dtype, device=torch.cuda.current_device()) - dist_reduce_scatter_func(output, input_.contiguous(), group=group) - return output - - -def _gather_along_first_dim_expert_parallel(input_): - """Gather tensors and concatenate along the first dimension.""" - group = get_expert_model_parallel_group() - world_size = torch.distributed.get_world_size(group=group) - # Bypass the function if we are using only 1 GPU. 
- if world_size == 1: - return input_ - - dim_size = list(input_.size()) - dim_size[0] = dim_size[0] * world_size - - output = torch.empty(dim_size, dtype=input_.dtype, device=torch.cuda.current_device()) - dist_all_gather_func(output, input_.contiguous(), group=group) + if use_global_buffer: + output = get_global_memory_buffer().get_tensor( + input_tensor_list[rank].shape, input_.dtype, "mpu" + ) + else: + output = torch.empty_like(input_tensor_list[rank]) + torch.distributed.reduce_scatter(output, input_tensor_list, group=group) return output @@ -340,16 +297,32 @@ class _GatherFromSequenceParallelRegion(torch.autograd.Function): """Gather the input from sequence parallel region and concatinate.""" @staticmethod - def symbolic(graph, input_, tensor_parallel_output_grad=True, output_split_sizes=None): + def symbolic( + graph, + input_, + tensor_parallel_output_grad=True, + group=None, + output_split_sizes=None, + use_global_buffer=False, + ): """Symbolic function for tracing.""" - return _gather_along_first_dim(input_, output_split_sizes) + return _gather_along_first_dim(input_, group, output_split_sizes, use_global_buffer) @staticmethod - def forward(ctx, input_, tensor_parallel_output_grad=True, output_split_sizes=None): + def forward( + ctx, + input_, + tensor_parallel_output_grad=True, + group=None, + output_split_sizes=None, + use_global_buffer=False, + ): """Forward function.""" ctx.tensor_parallel_output_grad = tensor_parallel_output_grad + ctx.group = group ctx.output_split_sizes = output_split_sizes - return _gather_along_first_dim(input_, ctx.output_split_sizes) + ctx.use_global_buffer = use_global_buffer + return _gather_along_first_dim(input_, group, output_split_sizes, use_global_buffer) @staticmethod def backward(ctx, grad_output): @@ -362,76 +335,46 @@ def backward(ctx, grad_output): # output gradients need to be scattered. 
if tensor_parallel_output_grad: return ( - _reduce_scatter_along_first_dim(grad_output, ctx.output_split_sizes), + _reduce_scatter_along_first_dim( + grad_output, ctx.group, ctx.output_split_sizes, ctx.use_global_buffer + ), + None, + None, None, None, ) else: assert ctx.output_split_sizes is None - return _split_along_first_dim(grad_output), None, None + return _split_along_first_dim(grad_output, ctx.group), None, None, None, None class _ReduceScatterToSequenceParallelRegion(torch.autograd.Function): """Reduce scatter the input from the model parallel region.""" @staticmethod - def symbolic(graph, input_, input_split_sizes=None): + def symbolic(graph, input_, group=None, input_split_sizes=None, use_global_buffer=False): """Symbolic function for tracing.""" - return _reduce_scatter_along_first_dim(input_, input_split_sizes) + return _reduce_scatter_along_first_dim(input_, group, input_split_sizes, use_global_buffer) @staticmethod - def forward(ctx, input_, input_split_sizes=None): + def forward(ctx, input_, group=None, input_split_sizes=None, use_global_buffer=False): """Forward function.""" + ctx.group = group ctx.input_split_sizes = input_split_sizes - return _reduce_scatter_along_first_dim(input_, input_split_sizes) - - @staticmethod - def backward(ctx, grad_output): - """Backward function.""" - input_split_sizes = ctx.input_split_sizes - return _gather_along_first_dim(grad_output, input_split_sizes), None - - -class _GatherFromSequenceParallelRegionToMOE(torch.autograd.Function): - """Gather the input from model parallel region and concatenate.""" # TODO - - @staticmethod - def symbolic(graph, input_, use_global_buffer=False): - """Symbolic function for tracing.""" - return _gather_along_first_dim_moe(input_, use_global_buffer) - - @staticmethod - def forward(ctx, input_, use_global_buffer=False): - """Forward function.""" ctx.use_global_buffer = use_global_buffer - return _gather_along_first_dim_moe(input_, use_global_buffer) - - @staticmethod - def backward(ctx, grad_output): - """Backward function.""" - use_global_buffer = ctx.use_global_buffer - return _reduce_scatter_along_first_dim_moe(grad_output, use_global_buffer), None - - -class _ReduceScatterToSequenceParallelRegionFromMOE(torch.autograd.Function): - """Reduce scatter the input from the model parallel region.""" - - @staticmethod - def symbolic(graph, input_, use_global_buffer=False): - """Symbolic function for tracing.""" - return _reduce_scatter_along_first_dim_moe(input_, use_global_buffer) - - @staticmethod - def forward(ctx, input_, use_global_buffer=False): - """Forward function.""" - ctx.use_global_buffer = use_global_buffer - return _reduce_scatter_along_first_dim_moe(input_, use_global_buffer) + return _reduce_scatter_along_first_dim(input_, group, input_split_sizes, use_global_buffer) @staticmethod def backward(ctx, grad_output): """Backward function.""" + input_split_sizes = ctx.input_split_sizes use_global_buffer = ctx.use_global_buffer - return _gather_along_first_dim_moe(grad_output, use_global_buffer), None + return ( + _gather_along_first_dim(grad_output, ctx.group, input_split_sizes, use_global_buffer), + None, + None, + None, + ) class _AllGatherFromTensorParallelRegion(torch.autograd.Function): @@ -522,61 +465,59 @@ def backward(ctx, *grad_output): def copy_to_tensor_model_parallel_region(input_): - """Wrapper for autograd function""" + """Wrapper for autograd function: forward: copy, backward allreduce""" return _CopyToModelParallelRegion.apply(input_) def 
reduce_from_tensor_model_parallel_region(input_): - """Wrapper for autograd function""" + """Wrapper for autograd function: forward: all reduce, backward copy""" return _ReduceFromModelParallelRegion.apply(input_) def scatter_to_tensor_model_parallel_region(input_): - """Wrapper for autograd function""" + """Wrapper for autograd function: forward: RS, backward: AG """ return _ScatterToModelParallelRegion.apply(input_) def gather_from_tensor_model_parallel_region(input_): - """Wrapper for autograd function""" + """Wrapper for autograd function: forward: AG, backward: split """ return _GatherFromModelParallelRegion.apply(input_) def scatter_to_sequence_parallel_region(input_): - """Wrapper for autograd function""" + """Wrapper for autograd function: forward: split, backward: AG """ return _ScatterToSequenceParallelRegion.apply(input_) def gather_from_sequence_parallel_region( - input_, tensor_parallel_output_grad=True, output_split_sizes=None + input_, + tensor_parallel_output_grad=True, + group=None, + output_split_sizes=None, + use_global_buffer=False, ): - """Wrapper for autograd function""" + """Wrapper for autograd function: forward: AG, backward: RS """ return _GatherFromSequenceParallelRegion.apply( - input_, tensor_parallel_output_grad, output_split_sizes + input_, tensor_parallel_output_grad, group, output_split_sizes, use_global_buffer ) -def reduce_scatter_to_sequence_parallel_region(input_, input_split_sizes=None): - """Wrapper for autograd function""" - return _ReduceScatterToSequenceParallelRegion.apply(input_, input_split_sizes) - - -def gather_from_sequence_parallel_region_to_moe(input_, use_global_buffer=False): - """Wrapper for autograd function""" - return _GatherFromSequenceParallelRegionToMOE.apply(input_, use_global_buffer) - - -def reduce_scatter_to_sequence_parallel_region_from_moe(input_, use_global_buffer=False): - """Wrapper for autograd function""" - return _ReduceScatterToSequenceParallelRegionFromMOE.apply(input_, use_global_buffer) +def reduce_scatter_to_sequence_parallel_region( + input_, group=None, input_split_sizes=None, use_global_buffer=False +): + """Wrapper for autograd function: forward: RS, backward AG """ + return _ReduceScatterToSequenceParallelRegion.apply( + input_, group, input_split_sizes, use_global_buffer + ) def all_gather_last_dim_from_tensor_parallel_region(input_): - """Wrapper for autograd function""" + """Wrapper for autograd function: forward: AG, backward RS """ return _AllGatherFromTensorParallelRegion.apply(input_) def reduce_scatter_last_dim_to_tensor_parallel_region(input_): - """Wrapper for autograd function""" + """Wrapper for autograd function: forward: RS, backward AG: AG """ return _ReduceScatterToTensorParallelRegion.apply(input_) diff --git a/megatron/core/tensor_parallel/random.py b/megatron/core/tensor_parallel/random.py index 4b144d4163..f3d4ab772f 100644 --- a/megatron/core/tensor_parallel/random.py +++ b/megatron/core/tensor_parallel/random.py @@ -14,6 +14,7 @@ from megatron.core.parallel_state import ( get_expert_model_parallel_rank, + get_expert_tensor_parallel_rank, get_tensor_model_parallel_rank, ) from megatron.core.utils import is_te_min_version, safely_set_viewless_tensor_data @@ -198,13 +199,16 @@ def model_parallel_cuda_manual_seed(seed): initialized. Also, no torch.cuda.manual_seed should be called after this function. Basically, this is replacement for that function. 
-    Two set of RNG states are tracked:
+    Three sets of RNG states are tracked:
     default state: This is for data parallelism and is the same among a set of model parallel GPUs but different across different model parallel groups. This is used for example for dropout in the non-tensor-model-parallel regions.
     tensor-model-parallel state: This state is different among a set of model parallel GPUs, but the same across data parallel groups. This is used for example for dropout in model parallel regions.
+    expert-parallel-seed: This state is only used for the expert layer of MoE models.
+    It is different among expert-tensor and expert-model parallel GPUs, and the same
+    across expert-data parallel groups.
     """
     # 2718 is just for fun and any POSITIVE value will work.
     offset = seed + 2718
@@ -222,7 +226,7 @@ def model_parallel_cuda_manual_seed(seed):
     _CUDA_RNG_STATE_TRACKER.add(_MODEL_PARALLEL_RNG_TRACKER_NAME, tensor_model_parallel_seed)
     expert_parallel_seed = (
-        seed + 1024 + 100 * get_expert_model_parallel_rank() + get_tensor_model_parallel_rank()
+        seed + 1024 + 100 * get_expert_model_parallel_rank() + get_expert_tensor_parallel_rank()
     )
     _CUDA_RNG_STATE_TRACKER.add(_EXPERT_PARALLEL_RNG_TRACKER_NAME, expert_parallel_seed)
diff --git a/megatron/core/transformer/moe/README.md b/megatron/core/transformer/moe/README.md
index eeb2838cd2..e08f94f2c3 100644
--- a/megatron/core/transformer/moe/README.md
+++ b/megatron/core/transformer/moe/README.md
@@ -53,6 +53,7 @@ Megatron-Core offers rich parallelism mappings, combining Expert Parallelism wit
| --- | --- |
| --num-experts | Number of Experts in MoE (None means no MoE) |
| --expert-model-parallel-size | Degree of expert model parallelism. Default is 1. |
+| --expert-tensor-parallel-size | Degree of tensor model parallelism for the expert layers. Defaults to the value of --tensor-model-parallel-size. |
| --moe-grouped-gemm | When there are multiple experts per rank, launch multiple local GEMM kernels in multiple streams to improve the utilization and performance with GroupedLinear in TransformerEngine. |
| --moe-router-load-balancing-type | Determines the load balancing strategy for the router. "aux_loss" corresponds to the load balancing loss used in GShard and SwitchTransformer, "sinkhorn" corresponds to the balancing algorithm used in S-BASE, and "none" implies no load balancing. The default is "aux_loss". |
| --moe-router-topk | Number of experts to route to for each token. The default is 2. |
@@ -65,7 +66,6 @@ Megatron-Core offers rich parallelism mappings, combining Expert Parallelism wit
| --moe-pad-expert-input-to-capacity | Pads the input for each expert to match the expert capacity length, effective only after the --moe-expert-capacity-factor is set. |
| --moe-token-drop-policy | The policy to drop tokens. Can be either "probs" or "position". If "probs", the tokens with the lowest probabilities will be dropped. If "position", tokens at the end of each batch will be dropped. |
| --moe-layer-recompute | Enable activation checkpointing for moe_layer, should be used when memory is not sufficient. |
-| --moe-extended-tp | (Experimental) Alternative parallelization strategy for expert parallelism. Instead of distributing experts across *expert_model_parallel_size*, each expert is sharded along extendended tensor parallel domain (tensor_model_paralle_size * expert_model_parallel_size). It avoids the load balancing problem with MOE training. Only available with `--moe-token-dispatcher-type allgather`. |
| --moe-shared-expert-intermediate-size | Set shared expert total ffn hidden size. It should be equal to `num_shared_experts * ffn_size_of_each_shared_expert` if there are multiple shared experts. None means no shared expert. |
| --moe-shared-expert-overlap | (Experimental, may changed) If this is set, the communications/computations in the shared experts and the dispatcher will overlap (The `alltoall` dispatcher is needed.) Otherwise, the shared expert runs after the routed experts. |
| --moe-use-upcycling | Load the dense model checkpoint, convert it into an MoE model at runtime and start training. The converted model will be saved to the path specified by `--save` before training begins. Upcycling is implemented on the top of distributed checkpointing, so it supports parallel modes different from the dense model.|
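For example, the following minimal sketch shows how the folded sizes above map onto `parallel_state.initialize_model_parallel`. The sizes are illustrative only (they mirror the tp2/pp1/etp1/ep4 functional test added in this MR), and the snippet assumes it is launched with `torchrun` on 8 GPUs so the process-group environment variables are already set.

```python
# Minimal sketch: MoE Parallel Folding sizes for a single 8-GPU node.
# Assumes launch via torchrun so torch.distributed can read its env vars.
import torch

from megatron.core import parallel_state

torch.distributed.init_process_group(backend="nccl")
parallel_state.initialize_model_parallel(
    tensor_model_parallel_size=2,    # TP for the Attention/Dense layers
    pipeline_model_parallel_size=1,
    expert_model_parallel_size=4,    # EP for the MoE layers
    expert_tensor_parallel_size=1,   # MoE-specific TP set by --expert-tensor-parallel-size
)

# The expert layers now query their own parallel groups:
assert parallel_state.get_expert_tensor_parallel_world_size() == 1
assert parallel_state.get_expert_model_parallel_world_size() == 4
expert_dp_group = parallel_state.get_expert_data_parallel_group()
```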
@@ -328,6 +328,21 @@ Here we provide some general rules to get better performance:
 - The efficiency of CP largely depends on whether its communication can be overlapped with computation.
 - Emperically, use CP when sequence length >= 8K.
+### MoE Parallel Folding
+
+MoE Parallel Folding separates the MoE-related parallel groups from the Dense groups.
+1. Traditional MoE parallel groups are entangled with the Dense groups by using a 5-dimensional parallel group generator with the default order `tp-cp-ep-dp-pp`. The EP group in MoE is a sub-group of the DP group in Attention.
+2. With MoE Parallel Folding, we use a parallel group generator with order `tp-cp-dp-pp` for Attention and another with order `tp-ep-dp-pp` for MoE. The EPxTP group in MoE is a sub-group of the DPxCPxTP group in Attention.
+
+By setting `--expert-tensor-parallel-size`, we can set an MoE-specific TP size.
+
+#### Advantages of MoE Parallel Folding
+1. The CP and EP groups are folded together by default, so that:
+    1. The minimum number of GPUs required to enable both CP and EP is reduced. For example, the traditional mapping with (CP=8, EP=8) needs at least 64 GPUs, whereas the folded mapping only requires 8 GPUs.
+    2. Both the CP and EP communication can be kept within the NVLink domain.
+2. We can set different TP sizes for the Attention and MoE parts.
+    1. For MoE, EP is often more efficient than TP, but with the traditional mapping, using EP alone runs out of memory for most models.
+    2. With MoE Parallel Folding, we can enable TP for the Attention part and set TP=1 for the MoE part, which often yields better MFU.
 ### End-to-End Training Practice
 **Use the latest NVIDIA PyTorch or NeMo Docker Image**
@@ -352,7 +367,7 @@ Here we provide some general rules to get better performance:
 **OOM Caused by Token Distribution Imbalance when Training From Scratch**
 MoE suffers from a severe load imbalance issue when the router is under-trained, leading to the model easily running out of memory (OOM), which typically occurs in the first 100~300 steps when training from scratch. Therefore, there are two recommended ways during the first 200 steps to avoid the OOM problem, which can be removed after the token distribution is more stable:
-1. Use Extended-TP(`-moe-extended-tp`) to replace EP with TP in MoELayer, this can prevent the load imbalancing between EP ranks. Since current ETP implementation has some memeory overhead, you can further enable activation recomputation only for MoE Layer by adding `--moe-layer-recompute`.
+1. Increase `--expert-tensor-parallel-size` and decrease `--expert-model-parallel-size` to replace EP with TP in the MoELayer, which can prevent load imbalance between EP ranks. 
Since current ETP implementation has some memeory overhead, you can further enable activation recomputation only for MoE Layer by adding `--moe-layer-recompute`. 2. Setting capacity factor to a relatively small number like 1.0 by adding `--moe-token-capacity-factor 1.0`. ### Reference Best Parallel Mapping diff --git a/megatron/core/transformer/moe/experts.py b/megatron/core/transformer/moe/experts.py index f037ea2f0a..8389547de3 100644 --- a/megatron/core/transformer/moe/experts.py +++ b/megatron/core/transformer/moe/experts.py @@ -2,7 +2,7 @@ import itertools from copy import deepcopy -from functools import partial +from functools import partial, wraps from math import ceil from typing import Optional, Tuple @@ -46,6 +46,44 @@ HAVE_TE = False +def expert_dist_ckpt_decorator(func): + """Decorator of shared_state_dict in expert layer for distributed checkpoint. + + Since !1940, the TP size for Expert layer can be different with Attention. + To make distributed checkpoint work in such cases, we use a decorator to + replace the default TP parallel states with expert-TP parallel states. + """ + + @wraps(func) + def wrapper(*args, **kwargs): + # Store original states + original_rank = parallel_state._MPU_TENSOR_MODEL_PARALLEL_RANK + original_size = parallel_state._MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE + original_group = parallel_state._TENSOR_MODEL_PARALLEL_GROUP + try: + # Set new states + parallel_state._MPU_TENSOR_MODEL_PARALLEL_RANK = ( + parallel_state.get_expert_tensor_parallel_rank() + ) + parallel_state._MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = ( + parallel_state.get_expert_tensor_parallel_world_size() + ) + parallel_state._TENSOR_MODEL_PARALLEL_GROUP = ( + parallel_state.get_expert_tensor_parallel_group() + ) + + # Execute the function + result = func(*args, **kwargs) + finally: + # Restore original states + parallel_state._MPU_TENSOR_MODEL_PARALLEL_RANK = original_rank + parallel_state._MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = original_size + parallel_state._TENSOR_MODEL_PARALLEL_GROUP = original_group + return result + + return wrapper + + class GroupedMLP(MegatronModule): """An efficient implementation of the Experts layer using GroupedGEMM. @@ -76,11 +114,8 @@ def glu(x): self.activation_func = self.config.activation_func # How many feature each rank holds for fc1 and fc2, respectively. 
- self.moe_extended_tp = config.moe_extended_tp - if config.moe_extended_tp: - tp_size = parallel_state.get_tensor_and_expert_parallel_world_size() - else: - tp_size = parallel_state.get_tensor_model_parallel_world_size() + tp_size = parallel_state.get_expert_tensor_parallel_world_size() + tp_rank = parallel_state.get_expert_tensor_parallel_rank() fc1_output_size = self.config.ffn_hidden_size * self.num_local_experts if config.gated_linear_unit: @@ -119,6 +154,8 @@ def glu(x): partition_dim=1, init_method=config.init_method, params_dtype=config.params_dtype, + rank=tp_rank, + world_size=tp_size, ) _initialize_affine_weight_cpu( self.weight2, @@ -128,6 +165,8 @@ def glu(x): partition_dim=0, init_method=config.output_layer_init_method, params_dtype=config.params_dtype, + rank=tp_rank, + world_size=tp_size, ) else: self.weight1 = Parameter( @@ -148,16 +187,10 @@ def glu(x): ) if config.perform_initialization: _initialize_affine_weight_gpu( - self.weight1, - config.init_method, - partition_dim=1, - expert_parallel=self.expert_parallel, + self.weight1, config.init_method, partition_dim=1, is_expert=True ) _initialize_affine_weight_gpu( - self.weight2, - config.output_layer_init_method, - partition_dim=0, - expert_parallel=self.expert_parallel, + self.weight2, config.output_layer_init_method, partition_dim=0, is_expert=True ) setattr(self.weight1, 'allreduce', not self.expert_parallel) setattr(self.weight2, 'allreduce', not self.expert_parallel) @@ -203,6 +236,7 @@ def forward(self, permuted_local_hidden_states: torch.Tensor, tokens_per_expert: return fc2_output, None + @expert_dist_ckpt_decorator def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None): """ Maps local expert to global experts. @@ -210,11 +244,6 @@ def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None): whereas the optimizer states are not due to the limitation from weight transposing. That is, for finetuning scenario, the checkpoint is compatible with the SequentialMLP. 
""" - if self.moe_extended_tp: - raise NotImplementedError( - 'Currently distributed checkpointing is not supported for moe_extended_tp' - ) - sharded_state_dict = {} num_global_experts = ( parallel_state.get_expert_model_parallel_world_size() * self.num_local_experts @@ -226,11 +255,7 @@ def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None): tp_rank = parallel_state.get_tensor_model_parallel_rank() prepend_axis_num = len(sharded_offsets) - replica_id = ( - 0, - 0, - parallel_state.get_data_modulo_expert_parallel_rank(with_context_parallel=True), - ) + replica_id = (0, 0, parallel_state.get_expert_data_parallel_rank()) local_ffn_dim_size = ( self.weight2.numel() // self.num_local_experts // self.config.hidden_size @@ -542,7 +567,7 @@ def sh_ten_merge_fn(sub_state_dict, tp_axis: int, with_glu: bool): replica_id = ( 0, parallel_state.get_tensor_model_parallel_rank(), - parallel_state.get_data_modulo_expert_parallel_rank(with_context_parallel=True), + parallel_state.get_expert_data_parallel_rank(), ) # Add fake _extra_state to be compatible with SequentialMLP for expert_local_idx in range(self.num_local_experts): @@ -572,7 +597,6 @@ class TEGroupedMLP(MegatronModule): def __init__(self, num_local_experts, config: TransformerConfig, submodules: MLPSubmodules): super().__init__(config=config) - self.moe_extended_tp = config.moe_extended_tp self.num_local_experts = num_local_experts self.input_size = self.config.hidden_size @@ -685,6 +709,7 @@ def glu(x): return output, output_bias + @expert_dist_ckpt_decorator def sharded_state_dict( self, prefix: str = '', sharded_offsets: tuple = (), metadata: Optional[dict] = None ) -> ShardedStateDict: @@ -692,10 +717,6 @@ def sharded_state_dict( Maps local expert to global experts. The sharded state dict is interchangable with SequentialMLP's. 
""" - if self.moe_extended_tp: - raise NotImplementedError( - 'Currently distributed checkpointing is not supported for moe_extended_tp' - ) sharded_state_dict = {} for name, module in self._modules.items(): sub_sd = module.sharded_state_dict(f'{name}.', sharded_offsets, metadata) @@ -730,7 +751,6 @@ class SequentialMLP(MegatronModule): def __init__(self, num_local_experts, config: TransformerConfig, submodules: MLPSubmodules): super().__init__(config=config) self.add_bias = config.add_bias_linear - self.moe_extended_tp = config.moe_extended_tp self.num_local_experts = num_local_experts self.local_experts = torch.nn.ModuleList() for _ in range(self.num_local_experts): @@ -786,13 +806,9 @@ def forward(self, permuted_local_hidden_states: torch.Tensor, tokens_per_expert: return output_local, output_bias_local + @expert_dist_ckpt_decorator def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None): """Maps local expert to global experts.""" - if self.moe_extended_tp: - raise NotImplementedError( - 'Currently distributed checkpointing is not supported for moe_extended_tp' - ) - sharded_state_dict = {} num_global_experts = ( parallel_state.get_expert_model_parallel_world_size() * self.num_local_experts @@ -825,7 +841,7 @@ def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None): ), f'Expected replica_id for {k} to be in (PP, TP, DP) format, got: {replica_id}' sh_ten.replica_id = ( *replica_id[:2], - parallel_state.get_data_modulo_expert_parallel_rank(with_context_parallel=True), + parallel_state.get_expert_data_parallel_rank(), ) sharded_state_dict.update(expert_state_dict) diff --git a/megatron/core/transformer/moe/legacy_a2a_token_dispatcher.py b/megatron/core/transformer/moe/legacy_a2a_token_dispatcher.py index 326742484f..dd5f447dd3 100644 --- a/megatron/core/transformer/moe/legacy_a2a_token_dispatcher.py +++ b/megatron/core/transformer/moe/legacy_a2a_token_dispatcher.py @@ -6,7 +6,6 @@ import torch.distributed from megatron.core import parallel_state, tensor_parallel -from megatron.core.tensor_parallel.mappings import _gather_along_first_dim_expert_parallel from megatron.core.transformer.moe.moe_utils import ( get_capacity, permute, @@ -150,8 +149,8 @@ def preprocess(self, routing_map: torch.Tensor) -> torch.Tensor: .to(torch.device("cpu"), non_blocking=True) .numpy() ) - num_global_tokens_per_expert = _gather_along_first_dim_expert_parallel( - num_local_tokens_per_expert + num_global_tokens_per_expert = tensor_parallel.gather_from_sequence_parallel_region( + num_local_tokens_per_expert, group=self.ep_group ).reshape(ep_size, self.num_experts) self.num_global_tokens_per_local_expert = num_global_tokens_per_expert[ :, self.local_expert_indices[0] : self.local_expert_indices[-1] + 1 diff --git a/megatron/core/transformer/moe/moe_layer.py b/megatron/core/transformer/moe/moe_layer.py index 7c01f8208a..faefce4cf0 100644 --- a/megatron/core/transformer/moe/moe_layer.py +++ b/megatron/core/transformer/moe/moe_layer.py @@ -42,15 +42,11 @@ def __init__(self, config: TransformerConfig, layer_number: int = None): self.expert_parallel_size = parallel_state.get_expert_model_parallel_world_size() assert self.expert_parallel_size > 0, "Expected non-negative expert parallel size" - if self.config.moe_extended_tp: - self.num_local_experts = self.config.num_moe_experts - local_expert_indices_offset = 0 - else: - assert self.config.num_moe_experts % self.expert_parallel_size == 0 - self.num_local_experts = self.config.num_moe_experts // self.expert_parallel_size - 
local_expert_indices_offset = ( - parallel_state.get_expert_model_parallel_rank() * self.num_local_experts - ) + assert self.config.num_moe_experts % self.expert_parallel_size == 0 + self.num_local_experts = self.config.num_moe_experts // self.expert_parallel_size + local_expert_indices_offset = ( + parallel_state.get_expert_model_parallel_rank() * self.num_local_experts + ) self.use_shared_expert = self.config.moe_shared_expert_intermediate_size is not None self.shared_expert_overlap = self.config.moe_shared_expert_overlap diff --git a/megatron/core/transformer/moe/moe_utils.py b/megatron/core/transformer/moe/moe_utils.py index e35d64fa2e..0c1504d417 100644 --- a/megatron/core/transformer/moe/moe_utils.py +++ b/megatron/core/transformer/moe/moe_utils.py @@ -1,6 +1,7 @@ # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. import math +from typing import Optional import torch @@ -230,7 +231,7 @@ def sort_chunks_by_idxs(input: torch.Tensor, split_sizes: torch.Tensor, sorted_i def topk_softmax_with_capacity( logits: torch.Tensor, topk: int, - capacity_factor: float = None, + capacity_factor: Optional[float] = None, pad_to_capacity: bool = False, drop_policy: str = "probs", use_pre_softmax: bool = False, diff --git a/megatron/core/transformer/moe/token_dispatcher.py b/megatron/core/transformer/moe/token_dispatcher.py index 3d84f993ef..5db0d19fad 100644 --- a/megatron/core/transformer/moe/token_dispatcher.py +++ b/megatron/core/transformer/moe/token_dispatcher.py @@ -5,9 +5,16 @@ import torch -from megatron.core import parallel_state, tensor_parallel -from megatron.core.tensor_parallel.mappings import ( - _gather_along_first_dim_moe, +from megatron.core.parallel_state import ( + get_expert_model_parallel_group, + get_expert_model_parallel_world_size, + get_expert_tensor_and_model_parallel_group, + get_expert_tensor_parallel_group, + get_expert_tensor_parallel_rank, + get_expert_tensor_parallel_world_size, +) +from megatron.core.tensor_parallel import ( + all_to_all, gather_from_sequence_parallel_region, reduce_scatter_to_sequence_parallel_region, ) @@ -43,6 +50,14 @@ def __init__(self, config: TransformerConfig) -> None: self.config = config self.shared_experts: Optional[SharedExpertMLP] = None + if torch.distributed.is_available() and torch.distributed.is_initialized(): + self.ep_group = get_expert_model_parallel_group() + self.ep_size = get_expert_model_parallel_world_size() + self.tp_group = get_expert_tensor_parallel_group() + self.tp_size = get_expert_tensor_parallel_world_size() + self.tp_rank = get_expert_tensor_parallel_rank() + self.tp_ep_group = get_expert_tensor_and_model_parallel_group() + @abstractmethod def token_permutation( self, tokens: torch.Tensor, probs: torch.Tensor, routing_map: torch.Tensor @@ -131,25 +146,23 @@ def token_permutation( hidden_states = hidden_states.view(-1, self.hidden_shape[-1]) # Permute the tokens across the expert parallel devices. 
- if (self.config.tensor_model_parallel_size > 1) or ( - self.config.expert_model_parallel_size > 1 - ): + if self.tp_size > 1 or self.ep_size > 1: ## local_indices calculation with torch.no_grad(): # [num_local_tokens, num_experts] -> [num_global_tokens, num_experts], where: # num_local_tokens=(S/TP)*B, num_global_tokens=S*B*EP - routing_map = tensor_parallel.gather_from_sequence_parallel_region_to_moe( - routing_map + routing_map = gather_from_sequence_parallel_region( + routing_map, group=self.tp_ep_group ) ## local_probs calculation # max_prob: [S/TP*B, num_experts] -> global_probs: [S*B*EP, num_experts] - probs = tensor_parallel.gather_from_sequence_parallel_region_to_moe(probs) + probs = gather_from_sequence_parallel_region(probs, group=self.tp_ep_group) # Note that this allgather spans the communication domain of TP*EP. # [(S/TP)*B, H] -> [((S/TP)*B)*(TP*EP), H] = [S*B*EP, H] - hidden_states = tensor_parallel.gather_from_sequence_parallel_region_to_moe( - hidden_states, use_global_buffer=True + hidden_states = gather_from_sequence_parallel_region( + hidden_states, group=self.tp_ep_group, use_global_buffer=True ) self.hidden_shape_before_permute = hidden_states.shape @@ -210,20 +223,18 @@ def token_unpermutation(self, hidden_states: torch.Tensor, bias: torch.Tensor = output_bias_total = unpermuted_local_bias # Unpermute the tokens across ranks. - if (self.config.tensor_model_parallel_size > 1) or ( - self.config.expert_model_parallel_size > 1 - ): - output_total = tensor_parallel.reduce_scatter_to_sequence_parallel_region_from_moe( - output_total + if self.tp_size > 1 or self.ep_size > 1: + output_total = reduce_scatter_to_sequence_parallel_region( + output_total, group=self.tp_ep_group ) if self.add_bias: # Unpermute the bias across expert parallel devices. # bias is duplicated across tensor parallelism ranks; output_bias_total = ( - tensor_parallel.reduce_scatter_to_sequence_parallel_region_from_moe( - output_bias_total + reduce_scatter_to_sequence_parallel_region( + output_bias_total, group=self.tp_ep_group ) - / parallel_state.get_tensor_model_parallel_world_size() + / self.tp_size ) output_total = output_total.view(self.hidden_shape) @@ -236,6 +247,11 @@ def token_unpermutation(self, hidden_states: torch.Tensor, bias: torch.Tensor = class MoEAlltoAllTokenDispatcher(MoETokenDispatcher): """ AlltoAll-based token dispatcher. + + The workflow of AlltoAll token dispatcher is as follows: + (1) preprocess(): calculate necessary metadata for communication and permute + (2) token_permutation(): permute->A2A(EP)->AG(TP)->sort_chunk(if num_local_experts>1) + (3) token_unpermutation(): sort_chunk(if num_local_experts>1)->RS(TP)->A2A(EP)->unpermute """ def __init__( @@ -262,8 +278,6 @@ def __init__( assert ( self.local_expert_indices[i] == self.local_expert_indices[i + 1] - 1 ), "local_expert_indices must be continous" - self.ep_size = config.expert_model_parallel_size - self.tp_size = config.tensor_model_parallel_size self.probs = None # [ep_size]. Represents the number of tokens sent by the current rank to other @@ -324,7 +338,6 @@ def preprocess(self, routing_map: torch.Tensor) -> torch.Tensor: # [num_experts], number of tokens assigned to each expert from the current rank's input. num_local_tokens_per_expert = routing_map.sum(dim=0).long() - tp_rank = parallel_state.get_tensor_model_parallel_rank() if self.drop_and_pad: # Drop and pad the input to capacity. 
num_tokens = routing_map.size(0) * self.config.moe_router_topk @@ -380,7 +393,9 @@ def preprocess(self, routing_map: torch.Tensor) -> torch.Tensor: # expert by all ranks. # [tp_size, ep_size, num_experts] num_global_tokens_per_expert = ( - _gather_along_first_dim_moe(num_local_tokens_per_expert) + gather_from_sequence_parallel_region( + num_local_tokens_per_expert, group=self.tp_ep_group + ) .reshape(self.ep_size, self.tp_size, self.num_experts) .transpose(0, 1) ) @@ -394,7 +409,7 @@ def preprocess(self, routing_map: torch.Tensor) -> torch.Tensor: # self.output_splits represents the number of tokens received by the current rank # from other EP rank. self.output_splits = ( - num_global_tokens_per_rank[tp_rank] + num_global_tokens_per_rank[self.tp_rank] .to(torch.device("cpu"), non_blocking=True) .numpy() ) @@ -471,18 +486,16 @@ def token_permutation( # Perform expert parallel AlltoAll communication if self.cuda_sync_point == "before_ep_alltoall": torch.cuda.current_stream().synchronize() - global_input_tokens = tensor_parallel.all_to_all( - parallel_state.get_expert_model_parallel_group(), - permutated_local_input_tokens, - self.output_splits, - self.input_splits, + global_input_tokens = all_to_all( + self.ep_group, permutated_local_input_tokens, self.output_splits, self.input_splits ) if self.shared_experts is not None: self.shared_experts.linear_fc1_forward_and_act(global_input_tokens) - if parallel_state.get_tensor_model_parallel_world_size() > 1: + if self.tp_size > 1: global_input_tokens = gather_from_sequence_parallel_region( global_input_tokens, + group=self.tp_group, output_split_sizes=( self.output_splits_tp.tolist() if self.output_splits_tp is not None else None ), @@ -502,7 +515,7 @@ def token_permutation( return global_input_tokens, tokens_per_expert def token_unpermutation( - self, hidden_states: torch.Tensor, bias: torch.Tensor = None + self, hidden_states: torch.Tensor, bias: Optional[torch.Tensor] = None ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: """ Reverse the token permutation to restore the original order. @@ -531,9 +544,10 @@ def token_unpermutation( self.restore_output_by_local_experts, ) - if parallel_state.get_tensor_model_parallel_world_size() > 1: + if self.tp_size > 1: hidden_states = reduce_scatter_to_sequence_parallel_region( hidden_states, + group=self.tp_group, input_split_sizes=( self.output_splits_tp.tolist() if self.output_splits_tp is not None else None ), @@ -541,11 +555,8 @@ def token_unpermutation( # Perform expert parallel AlltoAll communication # hidden_states: [SEQL, H] -> [SEQL, H/TP] - permutated_local_input_tokens = tensor_parallel.all_to_all( - parallel_state.get_expert_model_parallel_group(), - hidden_states, - self.input_splits, - self.output_splits, + permutated_local_input_tokens = all_to_all( + self.ep_group, hidden_states, self.input_splits, self.output_splits ) if self.shared_experts is not None: self.shared_experts.linear_fc2_forward(permutated_local_input_tokens) diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py index 28c1830e63..48ad00cf66 100644 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py @@ -526,17 +526,13 @@ def __post_init__(self): self.init_method_std, self.num_layers ) - if self.moe_extended_tp: - if self.moe_token_dispatcher_type != 'allgather': - raise ValueError( - "Moe extended TP parallelism only applies to allgather based token dispatcher." 
- ) - extended_tp_size = self.tensor_model_parallel_size * self.expert_model_parallel_size - if self.ffn_hidden_size % extended_tp_size != 0: - raise ValueError( - f'ffn_hidden_size: {self.ffn_hidden_size} must be divisible by ' - f'extended_tp_size {extended_tp_size}' - ) + if ( + self.moe_token_dispatcher_type == "alltoall_seq" + and self.tensor_model_parallel_size != self.expert_tensor_parallel_size + ): + raise ValueError( + "alltoall_seq dispatcher not support different TP size for MoE and Dense layer." + ) if self.num_moe_experts and self.fp8: # TE version below 1.7.0 will raise Error when handle zeros tokens for expert diff --git a/megatron/legacy/model/transformer.py b/megatron/legacy/model/transformer.py index dda550551a..db48d607e7 100644 --- a/megatron/legacy/model/transformer.py +++ b/megatron/legacy/model/transformer.py @@ -20,14 +20,14 @@ from megatron.core.jit import jit_fuser from megatron.core.num_microbatches_calculator import get_num_microbatches from megatron.core.parallel_state import ( - get_tensor_and_expert_parallel_group, + get_expert_tensor_and_model_parallel_group, get_tensor_model_parallel_group, ) from megatron.core.tensor_parallel import ( - gather_from_sequence_parallel_region_to_moe, + gather_from_sequence_parallel_region, + reduce_scatter_to_sequence_parallel_region, get_cuda_rng_tracker, get_data_parallel_rng_tracker_name, - reduce_scatter_to_sequence_parallel_region_from_moe, ) from megatron.legacy.model.enums import AttnMaskType, AttnType, LayerType from megatron.legacy.model.fused_bias_gelu import bias_gelu_impl @@ -221,10 +221,11 @@ def __init__(self, config): for i in range(self.num_local_experts): self.local_experts.append(ParallelMLP(config, is_expert=True)) + self.tp_ep_group = get_expert_tensor_and_model_parallel_group() + def gather_indices(self, local_indices): """ Gather tensors and concatinate along the first dimension.""" - group = get_tensor_and_expert_parallel_group() - world_size = torch.distributed.get_world_size(group=group) + world_size = torch.distributed.get_world_size(group=self.tp_ep_group) # Bypass the function if we are using only 1 GPU. 
        if world_size == 1:
            return local_indices
@@ -236,7 +237,7 @@ def gather_indices(self, local_indices):
        output = torch.empty(dim_size, dtype=local_indices.dtype, device=torch.cuda.current_device())
        torch.distributed._all_gather_base(
-            output, local_indices.contiguous(), group=group
+            output, local_indices.contiguous(), group=self.tp_ep_group
        )
        return output
@@ -269,7 +270,7 @@ def forward(self, hidden_states):
        # Each vector could be routed differently
        if self.sequence_parallel or (self.expert_parallel_size > 1):
            global_hidden_states = \
-                gather_from_sequence_parallel_region_to_moe(hidden_states)
+                gather_from_sequence_parallel_region(hidden_states, group=self.tp_ep_group)
            global_indices = self.gather_indices(max_ind)
        else:
            global_hidden_states = hidden_states
@@ -291,10 +292,10 @@ def forward(self, hidden_states):
        if self.sequence_parallel or (self.expert_parallel_size > 1):
            output_total = \
-                reduce_scatter_to_sequence_parallel_region_from_moe(output_total)
+                reduce_scatter_to_sequence_parallel_region(output_total, group=self.tp_ep_group)
            if self.add_bias:
                output_bias_total = \
-                    reduce_scatter_to_sequence_parallel_region_from_moe(output_bias_total)
+                    reduce_scatter_to_sequence_parallel_region(output_bias_total, group=self.tp_ep_group)
                # bias is duplicated across tensor parallelism ranks;
                # reduce scatter reduces bias across tensor parallel_ranks
diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py
index a4c5ae87ff..19a2086124 100644
--- a/megatron/training/arguments.py
+++ b/megatron/training/arguments.py
@@ -5,13 +5,12 @@
 import argparse
 import dataclasses
 import json
-import logging
 import os
-import torch
 import types
 import warnings
 from packaging.version import Version as PkgVersion
+import torch
 import torch.nn.functional as F
 from megatron.core.dist_checkpointing.validation import StrictHandling
@@ -229,6 +228,9 @@ def validate_args(args, defaults={}):
        assert args.hierarchical_context_parallel_sizes is not None, \
        "--hierarchical-context-parallel-sizes must be set when a2a+p2p is used in cp comm"
+    if args.expert_tensor_parallel_size is None:
+        args.expert_tensor_parallel_size = args.tensor_model_parallel_size
+
    # Deprecated arguments
    assert args.batch_size is None, '--batch-size argument is no longer ' \
        'valid, use --micro-batch-size instead'
@@ -1963,6 +1965,8 @@ def _add_moe_args(parser):
    group = parser.add_argument_group(title="moe")
    group.add_argument('--expert-model-parallel-size', type=int, default=1,
                       help='Degree of expert model parallelism.')
+    group.add_argument('--expert-tensor-parallel-size', type=int, default=None,
+                       help='Degree of expert tensor parallelism. Default is None, which will be set to the value of --tensor-model-parallel-size.')
    group.add_argument('--num-experts', type=int, default=None,
                       help='Number of Experts in MoE (None means no MoE)')
    group.add_argument('--moe-shared-expert-intermediate-size', type=int, default=None,
@@ -2005,7 +2009,7 @@ def _add_moe_args(parser):
    group.add_argument('--moe-layer-recompute', action='store_true',
                       help='Enable checkpointing for moe_layer, should be used when memory is not sufficient.')
    group.add_argument('--moe-extended-tp', action='store_true',
-                       help='Alternative to expert parallelism, all experts are sharded across TPXEP domain.')
+                       help='Deprecated. Use --expert-tensor-parallel-size instead.')
    group.add_argument('--moe-use-upcycling', action='store_true',
                       help='Load a checkpoint of a dense model, convert it into an MoE model, and save the converted model to the path specified by --save. 
' 'Upcycling is implemented on the top of distributed checkpointing, so it supports parallel modes different from the dense model.') diff --git a/megatron/training/checkpointing.py b/megatron/training/checkpointing.py index 12d50bd278..777461b9a8 100644 --- a/megatron/training/checkpointing.py +++ b/megatron/training/checkpointing.py @@ -391,7 +391,7 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler, num_floati # Collect args, model, RNG. if not torch.distributed.is_initialized() \ - or mpu.get_data_modulo_expert_parallel_rank(with_context_parallel=True) == 0 \ + or mpu.get_expert_data_parallel_rank() == 0 \ or ckpt_type != CheckpointType.LEGACY: optim_sd_kwargs = {} if ckpt_type != CheckpointType.LEGACY and args.use_distributed_optimizer: diff --git a/megatron/training/initialize.py b/megatron/training/initialize.py index f72c1b9eb8..a0861c9f85 100644 --- a/megatron/training/initialize.py +++ b/megatron/training/initialize.py @@ -284,6 +284,7 @@ def _initialize_distributed(get_embedding_ranks, get_position_embedding_ranks): context_parallel_size=args.context_parallel_size, hierarchical_context_parallel_sizes=args.hierarchical_context_parallel_sizes, expert_model_parallel_size=args.expert_model_parallel_size, + expert_tensor_parallel_size=args.expert_tensor_parallel_size, distributed_timeout_minutes=args.distributed_timeout_minutes, nccl_communicator_config_path=args.nccl_communicator_config_path, order='tp-cp-ep-dp-pp' if not args.use_tp_pp_dp_mapping else 'tp-pp-dp', diff --git a/megatron/training/utils.py b/megatron/training/utils.py index 9c6e95c1ad..59bee81476 100644 --- a/megatron/training/utils.py +++ b/megatron/training/utils.py @@ -65,8 +65,9 @@ def calc_params_l2_norm(model): args = get_args() if not isinstance(model, list): model = [model] - # Remove duplicate params. + # Seperate moe and dense params params_data = [] + moe_params_data = [] data_parallel_group = None for model_chunk in model: @@ -76,17 +77,16 @@ def calc_params_l2_norm(model): if not (param.requires_grad and is_not_tp_duplicate): continue assert is_not_tp_duplicate - if mpu.get_expert_model_parallel_rank() > 0: - if not getattr(param, 'allreduce', True): - assert param_is_not_shared(param) - param = to_local_if_dtensor(param) - params_data.append(param.data.float() if args.bf16 else param.data) + if not getattr(param, 'allreduce', True): + assert param_is_not_shared(param) + param = to_local_if_dtensor(param) + moe_params_data.append(param.data.float() if args.bf16 else param.data) else: if param_is_not_shared(param): param = to_local_if_dtensor(param) params_data.append(param.data.float() if args.bf16 else param.data) - # Calculate norm + # Calculate dense param norm dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device='cuda') norm, _ = multi_tensor_applier( multi_tensor_l2norm, @@ -101,19 +101,28 @@ def calc_params_l2_norm(model): op=torch.distributed.ReduceOp.SUM, group=data_parallel_group) - if mpu.get_expert_model_parallel_world_size() == 1: - # Sum across all model-parallel GPUs(tensor + pipeline). - torch.distributed.all_reduce(norm_2, - op=torch.distributed.ReduceOp.SUM, - group=mpu.get_model_parallel_group()) - else: - # Sum across tensor, pipeline and expert model-parallel GPUs. 
- torch.distributed.all_reduce(norm_2, - op=torch.distributed.ReduceOp.SUM, - group=mpu.get_tensor_and_expert_parallel_group()) - torch.distributed.all_reduce(norm_2, - op=torch.distributed.ReduceOp.SUM, - group=mpu.get_pipeline_model_parallel_group()) + # Sum across all model-parallel GPUs(tensor + pipeline). + torch.distributed.all_reduce( + norm_2, + op=torch.distributed.ReduceOp.SUM, + group=mpu.get_model_parallel_group() + ) + # Calculate moe norm + if len(moe_params_data) > 0: + moe_norm, _ = multi_tensor_applier( + multi_tensor_l2norm, + dummy_overflow_buf, + [moe_params_data], + False # no per-parameter norm + ) + moe_norm_2 = moe_norm * moe_norm + # Sum across expert tensor, model and pipeline parallel GPUs. + torch.distributed.all_reduce( + moe_norm_2, + op=torch.distributed.ReduceOp.SUM, + group=mpu.get_expert_tensor_model_pipeline_parallel_group() + ) + norm_2 += moe_norm_2 return norm_2.item() ** 0.5 diff --git a/tests/functional_tests/jet_recipes/gpt.yaml b/tests/functional_tests/jet_recipes/gpt.yaml index 3ee2581981..f252510c1f 100644 --- a/tests/functional_tests/jet_recipes/gpt.yaml +++ b/tests/functional_tests/jet_recipes/gpt.yaml @@ -71,6 +71,7 @@ products: - gpt3_mr_mcore_te_tp2_pp1_resume_torch_dist_te_8experts2parallel_dist_optimizer_dgx_a100_1N8G - gpt3_mr_mcore_te_tp2_pp1_resume_torch_dist_te_8experts2parallel_groupedGEMM_dgx_a100_1N8G - gpt3_mr_mcore_te_tp2_pp1_resume_torch_dist_te_8experts2parallel_top2router_dgx_a100_1N8G + - gpt3_mr_mcore_te_tp2_pp1_te_8experts_etp1_ep4_dgx_a100_1N8G - gpt3_mr_mcore_te_tp2_pp1_te_8experts2parallel_dist_optimizer_dgx_a100_1N8G - gpt3_mr_mcore_te_tp2_pp1_te_8experts2parallel_overlap_grad_reduce_param_gather_groupedGEMM_dgx_a100_1N8G - gpt3_mr_mcore_te_tp2_pp1_te_8experts2parallel_top2router_dgx_a100_1N8G diff --git a/tests/functional_tests/test_cases/gpt/gpt3_mr_mcore_te_tp2_pp1_te_8experts_etp1_ep4_dgx_a100_1N8G/golden_values_dev.json b/tests/functional_tests/test_cases/gpt/gpt3_mr_mcore_te_tp2_pp1_te_8experts_etp1_ep4_dgx_a100_1N8G/golden_values_dev.json new file mode 100644 index 0000000000..36c9e2356a --- /dev/null +++ b/tests/functional_tests/test_cases/gpt/gpt3_mr_mcore_te_tp2_pp1_te_8experts_etp1_ep4_dgx_a100_1N8G/golden_values_dev.json @@ -0,0 +1,493 @@ +{ + "forward-backward-time": { + "start_step": 0, + "end_step": 50, + "step_interval": 5, + "values": [ + 5.87989, + 0.25748, + 0.25366, + 0.25572, + 0.2567, + 0.25799, + 0.26476, + 0.26513, + 0.27047, + 0.26564 + ] + }, + "forward-compute-time": { + "start_step": 0, + "end_step": 50, + "step_interval": 5, + "values": [ + 3.77461, + 0.14169, + 0.13928, + 0.14013, + 0.14114, + 0.14295, + 0.14946, + 0.14968, + 0.15533, + 0.1511 + ] + }, + "backward-compute-time": { + "start_step": 0, + "end_step": 50, + "step_interval": 5, + "values": [ + 0.70676, + 0.11366, + 0.11287, + 0.11354, + 0.11325, + 0.11292, + 0.11324, + 0.114, + 0.11328, + 0.11353 + ] + }, + "batch-generator-time": { + "start_step": 0, + "end_step": 50, + "step_interval": 5, + "values": [ + 0.53331, + 0.00182, + 0.00166, + 0.00153, + 0.00159, + 0.00154, + 0.00168, + 0.00158, + 0.00165, + 0.00159 + ] + }, + "layernorm-grads-all-reduce-time": { + "start_step": 0, + "end_step": 50, + "step_interval": 5, + "values": [ + 0.00268, + 0.00176, + 0.00167, + 0.00206, + 0.00204, + 0.0017, + 0.00191, + 0.00171, + 0.002, + 0.00164 + ] + }, + "embedding-grads-all-reduce-time": { + "start_step": 0, + "end_step": 50, + "step_interval": 5, + "values": [ + 7e-05, + 4e-05, + 4e-05, + 5e-05, + 4e-05, + 4e-05, + 4e-05, + 4e-05, + 
4e-05, + 4e-05 + ] + }, + "all-grads-sync-time": { + "start_step": 0, + "end_step": 50, + "step_interval": 5, + "values": [ + 1.39476, + 0.00284, + 0.00279, + 0.00279, + 0.00281, + 0.00285, + 0.00281, + 0.00279, + 0.00282, + 0.00279 + ] + }, + "optimizer-copy-to-main-grad-time": { + "start_step": 0, + "end_step": 50, + "step_interval": 5, + "values": [ + 0.00037, + 0.0003, + 0.00028, + 0.00026, + 0.00024, + 0.00027, + 0.00027, + 0.00026, + 0.00023, + 0.00022 + ] + }, + "optimizer-inner-step-time": { + "start_step": 0, + "end_step": 50, + "step_interval": 5, + "values": [ + 0.00756, + 0.0018, + 0.00179, + 0.00178, + 0.00179, + 0.00178, + 0.00179, + 0.0018, + 0.00177, + 0.00176 + ] + }, + "optimizer-copy-main-to-model-params-time": { + "start_step": 0, + "end_step": 50, + "step_interval": 5, + "values": [ + 0.00143, + 0.00111, + 0.00111, + 0.0011, + 0.00109, + 0.0011, + 0.0011, + 0.0011, + 0.00108, + 0.00115 + ] + }, + "optimizer-time": { + "start_step": 0, + "end_step": 50, + "step_interval": 5, + "values": [ + 1.52684, + 0.01306, + 0.01274, + 0.01275, + 0.01268, + 0.01284, + 0.01269, + 0.01278, + 0.01244, + 0.01255 + ] + }, + "learning-rate": { + "start_step": 0, + "end_step": 50, + "step_interval": 5, + "values": [ + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0 + ] + }, + "learning-rate vs samples": { + "start_step": 0, + "end_step": 50, + "step_interval": 5, + "values": [ + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0 + ] + }, + "batch-size": { + "start_step": 0, + "end_step": 50, + "step_interval": 5, + "values": [ + 32.0, + 32.0, + 32.0, + 32.0, + 32.0, + 32.0, + 32.0, + 32.0, + 32.0, + 32.0 + ] + }, + "batch-size vs samples": { + "start_step": 0, + "end_step": 50, + "step_interval": 5, + "values": [ + 32.0, + 32.0, + 32.0, + 32.0, + 32.0, + 32.0, + 32.0, + 32.0, + 32.0, + 32.0 + ] + }, + "lm loss": { + "start_step": 0, + "end_step": 50, + "step_interval": 5, + "values": [ + 10.81298, + 10.87741, + 10.87628, + 10.80047, + 10.67764, + 10.5788, + 10.06451, + 10.18736, + 10.08297, + 9.75169 + ] + }, + "lm loss vs samples": { + "start_step": 0, + "end_step": 50, + "step_interval": 5, + "values": [ + 10.81298, + 10.87741, + 10.87628, + 10.80047, + 10.67764, + 10.5788, + 10.06451, + 10.18736, + 10.08297, + 9.75169 + ] + }, + "loss-scale": { + "start_step": 0, + "end_step": 50, + "step_interval": 5, + "values": [ + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0 + ] + }, + "loss-scale vs samples": { + "start_step": 0, + "end_step": 50, + "step_interval": 5, + "values": [ + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0 + ] + }, + "grad-norm": { + "start_step": 0, + "end_step": 50, + "step_interval": 5, + "values": [ + 8.33414, + 5.78016, + 5.87842, + 6.80216, + 6.7125, + 6.39007, + 8.68862, + 5.16113, + 4.57425, + 4.41469 + ] + }, + "grad-norm vs samples": { + "start_step": 0, + "end_step": 50, + "step_interval": 5, + "values": [ + 8.33414, + 5.78016, + 5.87842, + 6.80216, + 6.7125, + 6.39007, + 8.68862, + 5.16113, + 4.57425, + 4.41469 + ] + }, + "num-zeros": { + "start_step": 0, + "end_step": 50, + "step_interval": 5, + "values": [ + 26888.0, + 32285.0, + 33214.0, + 31691.0, + 28562.0, + 30589.0, + 28925.0, + 33010.0, + 33385.0, + 35045.0 + ] + }, + "num-zeros vs samples": { + "start_step": 0, + "end_step": 50, + "step_interval": 5, + "values": [ + 26888.0, + 32285.0, + 33214.0, + 31691.0, + 28562.0, + 30589.0, + 28925.0, + 33010.0, + 33385.0, + 35045.0 + ] + }, + "params-norm": { + "start_step": 0, + 
"end_step": 50, + "step_interval": 5, + "values": [ + 262.92148, + 262.92148, + 262.92148, + 262.92148, + 262.92145, + 262.92145, + 262.92142, + 262.9213, + 262.92111, + 262.92087 + ] + }, + "params-norm vs samples": { + "start_step": 0, + "end_step": 50, + "step_interval": 5, + "values": [ + 262.92148, + 262.92148, + 262.92148, + 262.92148, + 262.92145, + 262.92145, + 262.92142, + 262.9213, + 262.92111, + 262.92087 + ] + }, + "load_balancing_loss": { + "start_step": 0, + "end_step": 50, + "step_interval": 5, + "values": [ + 1.03508, + 1.03273, + 1.02893, + 1.03497, + 1.04648, + 1.04875, + 1.09296, + 1.10445, + 1.12111, + 1.13657 + ] + }, + "iteration-time": { + "start_step": 0, + "end_step": 50, + "step_interval": 5, + "values": [ + 7.81347, + 0.28438, + 0.27865, + 0.2808, + 0.28157, + 0.28301, + 0.28981, + 0.29022, + 0.29452, + 0.28987 + ] + }, + "lm loss validation": { + "start_step": 0, + "end_step": 2, + "step_interval": 5, + "values": [ + 9.79266 + ] + }, + "lm loss validation vs samples": { + "start_step": 0, + "end_step": 2, + "step_interval": 5, + "values": [ + 9.79266 + ] + }, + "lm loss validation ppl": { + "start_step": 0, + "end_step": 2, + "step_interval": 5, + "values": [ + 17901.80664 + ] + }, + "lm loss validation ppl vs samples": { + "start_step": 0, + "end_step": 2, + "step_interval": 5, + "values": [ + 17901.80664 + ] + } +} \ No newline at end of file diff --git a/tests/functional_tests/test_cases/gpt/gpt3_mr_mcore_te_tp2_pp1_te_8experts_etp1_ep4_dgx_a100_1N8G/golden_values_lts.json b/tests/functional_tests/test_cases/gpt/gpt3_mr_mcore_te_tp2_pp1_te_8experts_etp1_ep4_dgx_a100_1N8G/golden_values_lts.json new file mode 100644 index 0000000000..45b9cdd270 --- /dev/null +++ b/tests/functional_tests/test_cases/gpt/gpt3_mr_mcore_te_tp2_pp1_te_8experts_etp1_ep4_dgx_a100_1N8G/golden_values_lts.json @@ -0,0 +1,493 @@ +{ + "forward-backward-time": { + "start_step": 0, + "end_step": 50, + "step_interval": 5, + "values": [ + 13.47392, + 0.25841, + 0.27289, + 0.25653, + 0.26625, + 0.25628, + 0.26339, + 0.26204, + 0.2749, + 0.28151 + ] + }, + "forward-compute-time": { + "start_step": 0, + "end_step": 50, + "step_interval": 5, + "values": [ + 8.79707, + 0.14316, + 0.15675, + 0.14123, + 0.15065, + 0.14186, + 0.14773, + 0.14675, + 0.15897, + 0.16523 + ] + }, + "backward-compute-time": { + "start_step": 0, + "end_step": 50, + "step_interval": 5, + "values": [ + 0.73122, + 0.11386, + 0.1138, + 0.11348, + 0.11317, + 0.11208, + 0.11347, + 0.11357, + 0.11427, + 0.11465 + ] + }, + "batch-generator-time": { + "start_step": 0, + "end_step": 50, + "step_interval": 5, + "values": [ + 0.77139, + 0.0019, + 0.00182, + 0.00185, + 0.00185, + 0.00197, + 0.00171, + 0.00165, + 0.00182, + 0.00166 + ] + }, + "layernorm-grads-all-reduce-time": { + "start_step": 0, + "end_step": 50, + "step_interval": 5, + "values": [ + 0.00311, + 0.00225, + 0.0023, + 0.00216, + 0.00213, + 0.00207, + 0.00206, + 0.00196, + 0.00208, + 0.00197 + ] + }, + "embedding-grads-all-reduce-time": { + "start_step": 0, + "end_step": 50, + "step_interval": 5, + "values": [ + 8e-05, + 4e-05, + 4e-05, + 4e-05, + 4e-05, + 4e-05, + 4e-05, + 4e-05, + 4e-05, + 4e-05 + ] + }, + "all-grads-sync-time": { + "start_step": 0, + "end_step": 50, + "step_interval": 5, + "values": [ + 4.01852, + 0.00289, + 0.00287, + 0.00289, + 0.00286, + 0.00286, + 0.00285, + 0.00294, + 0.00296, + 0.00282 + ] + }, + "optimizer-copy-to-main-grad-time": { + "start_step": 0, + "end_step": 50, + "step_interval": 5, + "values": [ + 0.00047, + 0.00032, + 0.00033, + 0.0003, + 
0.00031, + 0.00028, + 0.00025, + 0.00026, + 0.00027, + 0.00026 + ] + }, + "optimizer-inner-step-time": { + "start_step": 0, + "end_step": 50, + "step_interval": 5, + "values": [ + 0.00803, + 0.00182, + 0.00185, + 0.00182, + 0.00184, + 0.00179, + 0.00184, + 0.00178, + 0.0018, + 0.00179 + ] + }, + "optimizer-copy-main-to-model-params-time": { + "start_step": 0, + "end_step": 50, + "step_interval": 5, + "values": [ + 0.00153, + 0.00114, + 0.00114, + 0.00113, + 0.00114, + 0.00112, + 0.00117, + 0.00111, + 0.00111, + 0.0011 + ] + }, + "optimizer-time": { + "start_step": 0, + "end_step": 50, + "step_interval": 5, + "values": [ + 2.65854, + 0.01318, + 0.01283, + 0.01264, + 0.01264, + 0.01242, + 0.01289, + 0.01226, + 0.01232, + 0.01228 + ] + }, + "learning-rate": { + "start_step": 0, + "end_step": 50, + "step_interval": 5, + "values": [ + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0 + ] + }, + "learning-rate vs samples": { + "start_step": 0, + "end_step": 50, + "step_interval": 5, + "values": [ + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0 + ] + }, + "batch-size": { + "start_step": 0, + "end_step": 50, + "step_interval": 5, + "values": [ + 32.0, + 32.0, + 32.0, + 32.0, + 32.0, + 32.0, + 32.0, + 32.0, + 32.0, + 32.0 + ] + }, + "batch-size vs samples": { + "start_step": 0, + "end_step": 50, + "step_interval": 5, + "values": [ + 32.0, + 32.0, + 32.0, + 32.0, + 32.0, + 32.0, + 32.0, + 32.0, + 32.0, + 32.0 + ] + }, + "lm loss": { + "start_step": 0, + "end_step": 50, + "step_interval": 5, + "values": [ + 10.81298, + 10.87741, + 10.87628, + 10.80047, + 10.67764, + 10.5788, + 10.06451, + 10.18736, + 10.08297, + 9.75169 + ] + }, + "lm loss vs samples": { + "start_step": 0, + "end_step": 50, + "step_interval": 5, + "values": [ + 10.81298, + 10.87741, + 10.87628, + 10.80047, + 10.67764, + 10.5788, + 10.06451, + 10.18736, + 10.08297, + 9.75169 + ] + }, + "loss-scale": { + "start_step": 0, + "end_step": 50, + "step_interval": 5, + "values": [ + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0 + ] + }, + "loss-scale vs samples": { + "start_step": 0, + "end_step": 50, + "step_interval": 5, + "values": [ + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0 + ] + }, + "grad-norm": { + "start_step": 0, + "end_step": 50, + "step_interval": 5, + "values": [ + 8.33414, + 5.78016, + 5.87842, + 6.80216, + 6.7125, + 6.39007, + 8.68862, + 5.16113, + 4.57425, + 4.41469 + ] + }, + "grad-norm vs samples": { + "start_step": 0, + "end_step": 50, + "step_interval": 5, + "values": [ + 8.33414, + 5.78016, + 5.87842, + 6.80216, + 6.7125, + 6.39007, + 8.68862, + 5.16113, + 4.57425, + 4.41469 + ] + }, + "num-zeros": { + "start_step": 0, + "end_step": 50, + "step_interval": 5, + "values": [ + 26888.0, + 32285.0, + 33214.0, + 31691.0, + 28562.0, + 30589.0, + 28925.0, + 33010.0, + 33385.0, + 35045.0 + ] + }, + "num-zeros vs samples": { + "start_step": 0, + "end_step": 50, + "step_interval": 5, + "values": [ + 26888.0, + 32285.0, + 33214.0, + 31691.0, + 28562.0, + 30589.0, + 28925.0, + 33010.0, + 33385.0, + 35045.0 + ] + }, + "params-norm": { + "start_step": 0, + "end_step": 50, + "step_interval": 5, + "values": [ + 262.92148, + 262.92148, + 262.92148, + 262.92148, + 262.92145, + 262.92145, + 262.92142, + 262.9213, + 262.92111, + 262.92087 + ] + }, + "params-norm vs samples": { + "start_step": 0, + "end_step": 50, + "step_interval": 5, + "values": [ + 262.92148, + 262.92148, + 262.92148, + 262.92148, + 262.92145, + 262.92145, + 262.92142, + 262.9213, + 
262.92111, + 262.92087 + ] + }, + "load_balancing_loss": { + "start_step": 0, + "end_step": 50, + "step_interval": 5, + "values": [ + 1.03508, + 1.03273, + 1.02893, + 1.03497, + 1.04648, + 1.04875, + 1.09296, + 1.10445, + 1.12111, + 1.13657 + ] + }, + "iteration-time": { + "start_step": 0, + "end_step": 50, + "step_interval": 5, + "values": [ + 16.86916, + 0.28405, + 0.29778, + 0.28081, + 0.29056, + 0.28009, + 0.28785, + 0.28603, + 0.29846, + 0.30491 + ] + }, + "lm loss validation": { + "start_step": 0, + "end_step": 2, + "step_interval": 5, + "values": [ + 9.79266 + ] + }, + "lm loss validation vs samples": { + "start_step": 0, + "end_step": 2, + "step_interval": 5, + "values": [ + 9.79266 + ] + }, + "lm loss validation ppl": { + "start_step": 0, + "end_step": 2, + "step_interval": 5, + "values": [ + 17901.80664 + ] + }, + "lm loss validation ppl vs samples": { + "start_step": 0, + "end_step": 2, + "step_interval": 5, + "values": [ + 17901.80664 + ] + } +} \ No newline at end of file diff --git a/tests/functional_tests/test_cases/gpt/gpt3_mr_mcore_te_tp2_pp1_te_8experts_etp1_ep4_dgx_a100_1N8G/model_config.yaml b/tests/functional_tests/test_cases/gpt/gpt3_mr_mcore_te_tp2_pp1_te_8experts_etp1_ep4_dgx_a100_1N8G/model_config.yaml new file mode 100644 index 0000000000..85b76573a8 --- /dev/null +++ b/tests/functional_tests/test_cases/gpt/gpt3_mr_mcore_te_tp2_pp1_te_8experts_etp1_ep4_dgx_a100_1N8G/model_config.yaml @@ -0,0 +1,59 @@ +ENV_VARS: + CUDA_DEVICE_MAX_CONNECTIONS: 1 + NVTE_ALLOW_NONDETERMINISTIC_ALGO: 0 + NCCL_ALGO: Tree + CUBLAS_WORKSPACE_CONFIG: :4096:8 + N_REPEATS: 5 +MODEL_ARGS: + --num-layers: 12 + --hidden-size: 512 + --num-attention-heads: 8 + --log-params-norm: true + --log-num-zeros-in-grad: true + --log-validation-ppl-to-tensorboard: true + --log-timers-to-tensorboard: true + --tensorboard-dir: ${TENSORBOARD_PATH} + --micro-batch-size: 4 + --global-batch-size: 32 + --seq-length: 1024 + --max-position-embeddings: 1024 + --train-iters: 50 + --timing-log-level: 2 + --lr-decay-iters: 320000 + --save: ${CHECKPOINT_PATH} + --load: ${CHECKPOINT_PATH} + --data-path: ${DATA_PATH}/my-gpt3_00_text_document + --vocab-file: ${DATA_PATH}/bpe/vocab.json + --merge-file: ${DATA_PATH}/bpe/merges.txt + --split: 949,50,1 + --distributed-backend: nccl + --lr: 0.00015 + --lr-decay-style: cosine + --min-lr: 1.0e-5 + --weight-decay: 1e-2 + --clip-grad: 1.0 + --lr-warmup-fraction: .01 + --log-interval: 1 + --save-interval: 10000 + --eval-interval: 1000 + --eval-iters: 10 + --transformer-impl: transformer_engine + --tensor-model-parallel-size: 2 + --pipeline-model-parallel-size: 1 + --expert-model-parallel-size: 4 + --expert-tensor-parallel-size: 1 + --disable-bias-linear: true + --sequence-parallel: true + --num-experts: 8 + --moe-router-load-balancing-type: aux_loss + --moe-router-topk: 2 + --moe-aux-loss-coeff: 1e-2 + --deterministic-mode: true + --no-gradient-accumulation-fusion: true + --moe-grouped-gemm: true + --attention-softmax-in-fp32: true + --use-mcore-models: true + --ckpt-format: torch_dist + --data-cache-path: ${DATA_CACHE_PATH} + --bf16: true +TEST_TYPE: regular diff --git a/tests/unit_tests/dist_checkpointing/models/test_moe_experts.py b/tests/unit_tests/dist_checkpointing/models/test_moe_experts.py index aab901b50a..e5e3ac98bd 100644 --- a/tests/unit_tests/dist_checkpointing/models/test_moe_experts.py +++ b/tests/unit_tests/dist_checkpointing/models/test_moe_experts.py @@ -87,37 +87,63 @@ def teardown_method(self, method): Utils.destroy_model_parallel() @pytest.mark.parametrize( - 
"use_fpsl,src_tp_pp_exp,dest_tp_pp_exp,use_glu", + "use_fpsl,src_tp_pp_ep_etp,dest_tp_pp_ep_etp,use_glu", [ # changing PP is impossible because the number of layers must be the same - (False, (2, 4, 1), (2, 4, 1), False), - (True, (2, 4, 1), (2, 4, 1), False), - (False, (1, 1, 1), (1, 1, 1), False), - (True, (1, 1, 1), (1, 1, 4), False), - (False, (1, 1, 8), (1, 1, 2), False), - (False, (2, 2, 2), (4, 2, 1), False), - (True, (1, 1, 4), (8, 1, 1), False), - (False, (1, 8, 1), (1, 8, 1), False), - (False, (1, 1, 4), (2, 1, 1), False), - (False, (1, 1, 1), (1, 1, 1), True), - (False, (1, 1, 1), (1, 1, 4), True), - (True, (1, 1, 1), (2, 1, 1), True), - (False, (1, 1, 4), (8, 1, 1), True), + (False, (2, 4, 1, 2), (2, 4, 1, 2), False), + (True, (2, 4, 1, 2), (2, 4, 1, 2), False), + (False, (2, 4, 1, 2), (1, 4, 1, 2), False), + (True, (2, 1, 1, 2), (1, 1, 1, 2), False), + (False, (1, 1, 1, 1), (1, 1, 1, 1), False), + (True, (1, 1, 1, 1), (1, 1, 4, 1), False), + (False, (1, 1, 8, 1), (1, 1, 2, 1), False), + (False, (2, 2, 2, 2), (4, 2, 1, 4), False), + (True, (1, 1, 4, 1), (8, 1, 1, 1), False), + (False, (1, 8, 1, 1), (1, 8, 1, 1), False), + (False, (1, 1, 4, 1), (2, 1, 1, 2), False), + (False, (2, 1, 4, 1), (2, 1, 1, 4), False), + (False, (1, 1, 1, 1), (1, 1, 1, 1), True), + (False, (1, 1, 1, 1), (1, 1, 4, 1), True), + (True, (1, 1, 1, 1), (2, 1, 1, 1), True), + (False, (1, 1, 4, 1), (8, 1, 1, 8), True), ], ) @pytest.mark.parametrize("expert_type", expert_type) + @pytest.mark.parametrize( + "load_order,store_order", + [ + ("tp-ep-dp-pp", "tp-ep-dp-pp"), + # ("tp-ep-dp-pp", "ep-tp-dp-pp"), + # ("ep-tp-dp-pp", "ep-tp-dp-pp"), + # ("ep-tp-dp-pp", "tp-ep-dp-pp"), + ], + ) def test_parallel_reconfiguration_e2e( - self, tmp_path_dist_ckpt, src_tp_pp_exp, dest_tp_pp_exp, use_glu, use_fpsl, expert_type + self, + tmp_path_dist_ckpt, + src_tp_pp_ep_etp, + dest_tp_pp_ep_etp, + use_glu, + use_fpsl, + expert_type, + load_order, + store_order, ): - """Test model saving and loading with different TP/PP/expert parallelism""" - src_tp, src_pp, src_exp = src_tp_pp_exp - dest_tp, dest_pp, dest_exp = dest_tp_pp_exp + """Test model saving and loading with different TP/PP/EP/ETP(expert-tensor-parallel)""" + src_tp, src_pp, src_ep, src_etp = src_tp_pp_ep_etp + dest_tp, dest_pp, dest_ep, dest_etp = dest_tp_pp_ep_etp if expert_type == 'grouped': add_bias_linear = False else: add_bias_linear = True # Save checkpoint A - Utils.initialize_model_parallel(src_tp, src_pp, expert_model_parallel_size=src_exp) + Utils.initialize_model_parallel( + src_tp, + src_pp, + expert_model_parallel_size=src_ep, + expert_tensor_parallel_size=src_etp, + order=store_order, + ) with TempNamedDir( tmp_path_dist_ckpt / 'test_expert_layer_reconfiguration_model_A' ) as ckpt_dir_A, TempNamedDir( @@ -138,9 +164,15 @@ def test_parallel_reconfiguration_e2e( save(sharded_state_dict, ckpt_dir_A, save_strategy) Utils.destroy_model_parallel() - # Load checkpoint A with different TP/PP/expert and save as checkpoint B + # Load checkpoint A with different TP/PP/EP and save as checkpoint B # No FPS this time, only FPL - Utils.initialize_model_parallel(dest_tp, dest_pp, expert_model_parallel_size=dest_exp) + Utils.initialize_model_parallel( + dest_tp, + dest_pp, + expert_model_parallel_size=dest_ep, + expert_tensor_parallel_size=dest_etp, + order=load_order, + ) model_B = initialize_expert_layer( 1, use_glu, expert_type, add_bias_linear=add_bias_linear ) diff --git a/tests/unit_tests/tensor_parallel/test_mappings.py 
b/tests/unit_tests/tensor_parallel/test_mappings.py index d5bc3f2127..3c5536f27a 100644 --- a/tests/unit_tests/tensor_parallel/test_mappings.py +++ b/tests/unit_tests/tensor_parallel/test_mappings.py @@ -1,3 +1,4 @@ +import pytest import torch from megatron.core.tensor_parallel import mappings @@ -90,6 +91,7 @@ def test_ScatterToSequenceParallelRegion(): Utils.destroy_model_parallel() +@pytest.mark.internal def test_GatherFromSequenceParallelRegion(): Utils.initialize_model_parallel(4, 2) input_data = torch.ones(4).cuda() * Utils.rank @@ -110,6 +112,8 @@ def test_GatherFromSequenceParallelRegion(): class Ctx: tensor_parallel_output_grad = True output_split_sizes = None + group = None + use_global_buffer = False output_data = mappings._GatherFromSequenceParallelRegion.backward(Ctx(), input_data) expected_output = torch.ones((1, 4)).cuda() * 4 * int(Utils.rank % 4) @@ -117,6 +121,7 @@ class Ctx: Utils.destroy_model_parallel() +@pytest.mark.internal def test_ReduceScatterToSequenceParallelRegion(): Utils.initialize_model_parallel(4, 2) input_data = torch.vstack( @@ -133,12 +138,14 @@ def test_ReduceScatterToSequenceParallelRegion(): class Ctx: input_split_sizes = None + group = None + use_global_buffer = False - output_data, _ = mappings._ReduceScatterToSequenceParallelRegion.backward(Ctx(), input_data) + output_data = mappings._ReduceScatterToSequenceParallelRegion.backward(Ctx(), input_data) expected_output = torch.concat( (torch.ones(4) * 0, torch.ones(4) * 1, torch.ones(4) * 2, torch.ones(4) * 3) ).cuda() if Utils.rank >= 4: expected_output = expected_output + 4 - assert torch.equal(output_data, expected_output) + assert torch.equal(output_data[0], expected_output) Utils.destroy_model_parallel() diff --git a/tests/unit_tests/test_parallel_state.py b/tests/unit_tests/test_parallel_state.py index 9778822aad..ca5185b28e 100644 --- a/tests/unit_tests/test_parallel_state.py +++ b/tests/unit_tests/test_parallel_state.py @@ -1,5 +1,3 @@ -import os - import pytest import torch @@ -40,6 +38,10 @@ def test_initialize_and_destroy_model_parallel(order): assert ps.get_tensor_model_parallel_group() is not None assert ps.get_pipeline_model_parallel_group() is not None assert ps.get_data_parallel_group() is not None + assert ps.get_expert_model_parallel_group() is not None + assert ps.get_expert_tensor_parallel_group() is not None + assert ps.get_expert_data_parallel_group() is not None + assert ps.get_expert_tensor_model_pipeline_parallel_group() is not None Utils.destroy_model_parallel() assert ps._MODEL_PARALLEL_GROUP is None @@ -74,6 +76,15 @@ def test_tensor_model_parellel_world_size(order): Utils.destroy_model_parallel() +@pytest.mark.parametrize('order', test_parallel_order) +def test_expert_tensor_parellel_world_size(order): + Utils.initialize_model_parallel(expert_tensor_parallel_size=world_size, order=order) + assert ps.get_expert_tensor_parallel_world_size() == world_size + ps.set_expert_tensor_parallel_world_size(None) + assert ps.get_expert_tensor_parallel_world_size() == world_size + Utils.destroy_model_parallel() + + @pytest.mark.parametrize('order', test_parallel_order) def test_pipeline_model_parallel_world_size(order): Utils.initialize_model_parallel(pipeline_model_parallel_size=world_size, order=order) @@ -92,6 +103,15 @@ def test_tensor_model_parallel_rank(order): Utils.destroy_model_parallel() +@pytest.mark.parametrize('order', test_parallel_order) +def test_moe_tensor_model_parellel_rank(order): + Utils.initialize_model_parallel(expert_tensor_parallel_size=world_size, order=order) + 
assert ps.get_expert_tensor_parallel_rank() == rank + ps.set_expert_tensor_parallel_rank(None) + assert ps.get_expert_tensor_parallel_rank() == rank + Utils.destroy_model_parallel() + + @pytest.mark.parametrize('order', test_parallel_order) def test_pipeline_model_parallel_rank(order): Utils.initialize_model_parallel(pipeline_model_parallel_size=world_size, order=order) @@ -167,6 +187,7 @@ def test_encoder_tensor_pipeline_parallelism(order): Utils.destroy_model_parallel() +@pytest.mark.internal @pytest.mark.parametrize( 'src_tp_pp, ep_size', [ @@ -192,12 +213,12 @@ def test_different_initialize_order_consistency(src_tp_pp, ep_size): tp_g = torch.distributed.get_process_group_ranks(ps.get_tensor_model_parallel_group()) dp_g = torch.distributed.get_process_group_ranks(ps.get_data_parallel_group(False)) pp_g = torch.distributed.get_process_group_ranks(ps.get_pipeline_model_parallel_group()) - dp_no_ep_g = torch.distributed.get_process_group_ranks( - ps.get_data_modulo_expert_parallel_group() - ) + dp_no_ep_g = torch.distributed.get_process_group_ranks(ps.get_expert_data_parallel_group()) cp_g = torch.distributed.get_process_group_ranks(ps.get_context_parallel_group()) mp_g = torch.distributed.get_process_group_ranks(ps.get_model_parallel_group()) - tp_ep_g = torch.distributed.get_process_group_ranks(ps.get_tensor_and_expert_parallel_group()) + tp_ep_g = torch.distributed.get_process_group_ranks( + ps.get_expert_tensor_and_model_parallel_group() + ) tp_dp_g = torch.distributed.get_process_group_ranks( ps.get_tensor_and_data_parallel_group(False) ) @@ -216,12 +237,12 @@ def test_different_initialize_order_consistency(src_tp_pp, ep_size): assert dp_g == torch.distributed.get_process_group_ranks(ps.get_data_parallel_group(False)) assert pp_g == torch.distributed.get_process_group_ranks(ps.get_pipeline_model_parallel_group()) assert dp_no_ep_g == torch.distributed.get_process_group_ranks( - ps.get_data_modulo_expert_parallel_group() + ps.get_expert_data_parallel_group() ) assert cp_g == torch.distributed.get_process_group_ranks(ps.get_context_parallel_group()) assert mp_g == torch.distributed.get_process_group_ranks(ps.get_model_parallel_group()) assert tp_ep_g == torch.distributed.get_process_group_ranks( - ps.get_tensor_and_expert_parallel_group() + ps.get_expert_tensor_and_model_parallel_group() ) assert tp_dp_g == torch.distributed.get_process_group_ranks( ps.get_tensor_and_data_parallel_group(False) @@ -261,6 +282,7 @@ def test_different_initialize_order_unconsistency(src_tp_pp, ep_size): Utils.destroy_model_parallel() +@pytest.mark.internal @pytest.mark.parametrize( 'nodes, num_gpu, tp, pp, cp, ep', [ @@ -389,54 +411,37 @@ def golden_rank_result_from_past_code( ranks = ranks + list(range(start_rank, end_rank)) tp_dp_group.append(list(ranks)) - tp_ep_group = [] - dp_no_ep_group = [] - dp_no_ep_group_with_cp = [] + expert_tp_ep_group = [] + expert_dp_group = [] + expert_data_parallel_size = world_size // ( + tensor_model_parallel_size * pipeline_model_parallel_size * expert_model_parallel_size + ) all_ranks = torch.arange(world_size).reshape( ( pipeline_model_parallel_size, - data_parallel_size // expert_model_parallel_size, + expert_data_parallel_size, expert_model_parallel_size, - context_parallel_size, tensor_model_parallel_size, ) ) - # 'pp edp ep cp tp -> (pp edp cp) (ep tp)' - tp_ep_rearrange = torch.transpose(all_ranks, 2, 3) + # (pp, dp, ep, tp) -> (pp*dp, ep*tp) tp_ep_rearrange = torch.reshape( - tp_ep_rearrange, (-1, expert_model_parallel_size * tensor_model_parallel_size) + 
all_ranks, (-1, expert_model_parallel_size * tensor_model_parallel_size) ) - tp_ep_rearrange = tp_ep_rearrange.tolist() - tp_ep_rearrange.sort() - for tensor_and_expert_parallel_ranks in tp_ep_rearrange: - tensor_and_expert_parallel_ranks = list(tensor_and_expert_parallel_ranks) - tensor_and_expert_parallel_ranks.sort() - tp_ep_group.append(tensor_and_expert_parallel_ranks) - # 'pp edp ep cp tp -> (pp ep cp tp) edp' - edp_rearrange = torch.transpose(all_ranks, 1, 4) - edp_rearrange = torch.reshape( - edp_rearrange, (-1, data_parallel_size // expert_model_parallel_size) + num_tp_ep_groups = tp_ep_rearrange.shape[0] + for i in range(num_tp_ep_groups): + expert_tensor_and_model_parallel_ranks = tp_ep_rearrange[i].tolist() + expert_tp_ep_group.append(expert_tensor_and_model_parallel_ranks) + + # (pp, dp, ep, tp) -> (pp*ep*tp, dp) + expert_dp_rearrange = torch.permute(all_ranks, (0, 2, 3, 1)).reshape( + -1, expert_data_parallel_size ) - edp_rearrange = edp_rearrange.tolist() - edp_rearrange.sort() - for expert_data_parallel_ranks in edp_rearrange: - expert_data_parallel_ranks = list(expert_data_parallel_ranks) - expert_data_parallel_ranks.sort() - dp_no_ep_group.append(expert_data_parallel_ranks) - # 'pp edp ep cp tp -> (pp ep tp) (cp edp)' - edp_cp_rearrange = torch.transpose(all_ranks, 1, 2) - edp_cp_rearrange = torch.transpose(edp_cp_rearrange, 2, 4) - edp_cp_rearrange = torch.reshape( - edp_cp_rearrange, - (-1, context_parallel_size * data_parallel_size // expert_model_parallel_size), - ) - edp_cp_rearrange = edp_cp_rearrange.tolist() - edp_cp_rearrange.sort() - for expert_data_parallel_ranksj_with_cp in edp_cp_rearrange: - expert_data_parallel_ranksj_with_cp = list(expert_data_parallel_ranksj_with_cp) - expert_data_parallel_ranksj_with_cp.sort() - dp_no_ep_group_with_cp.append(expert_data_parallel_ranksj_with_cp) + num_expert_dp_groups = world_size // expert_data_parallel_size + for i in range(num_expert_dp_groups): + expert_dp_ranks = expert_dp_rearrange[i].tolist() + expert_dp_group.append(expert_dp_ranks) return ( dp_groups, @@ -447,13 +452,13 @@ def golden_rank_result_from_past_code( pp_group, tp_dp_group, tp_dp_cp_group, - tp_ep_group, - dp_no_ep_group, - dp_no_ep_group_with_cp, + expert_tp_ep_group, + expert_dp_group, ) world_size = nodes * num_gpu dp = world_size // (tp * pp * cp) + expert_dp = world_size // (tp * ep * pp) assert dp % ep == 0, f"dp size ({dp}) is not divisible by ep {ep} ." 
assert ( world_size % (tp * pp * cp) == 0 @@ -467,9 +472,8 @@ def golden_rank_result_from_past_code( pp_group, tp_dp_group, tp_dp_cp_group, - tp_ep_group, - dp_no_ep_group, - dp_no_ep_group_with_cp, + expert_tp_ep_group, + expert_dp_group, ) = golden_rank_result_from_past_code( world_size=world_size, tensor_model_parallel_size=tp, @@ -477,7 +481,10 @@ def golden_rank_result_from_past_code( context_parallel_size=cp, expert_model_parallel_size=ep, ) - rank_generator = ps.RankGenerator(tp=tp, ep=ep, dp=dp, pp=pp, cp=cp, order="tp-cp-ep-dp-pp") + rank_generator = ps.RankGenerator(tp=tp, ep=1, dp=dp, pp=pp, cp=cp, order="tp-cp-dp-pp") + expert_rank_generator = ps.RankGenerator( + tp=tp, ep=ep, dp=expert_dp, pp=pp, cp=1, order="tp-ep-dp-pp" + ) assert dp_groups == rank_generator.get_ranks( "dp" ), f"{dp_groups} != {rank_generator.get_ranks('dp')}" @@ -502,12 +509,9 @@ def golden_rank_result_from_past_code( assert tp_dp_cp_group == rank_generator.get_ranks( "tp-dp-cp" ), f"{tp_dp_cp_group} != {rank_generator.get_ranks('tp-dp-cp')}" - assert tp_ep_group == rank_generator.get_ranks( - "tp-ep", independent_ep=True - ), f"{tp_ep_group} != {rank_generator.get_ranks('tp-ep', independent_ep=True)}." - assert dp_no_ep_group == rank_generator.get_ranks( - "dp", independent_ep=True - ), f"{dp_no_ep_group} != {rank_generator.get_ranks('dp', independent_ep=True)}." - assert dp_no_ep_group_with_cp == rank_generator.get_ranks( - "dp-cp", independent_ep=True - ), f"{dp_no_ep_group_with_cp} != {rank_generator.get_ranks('dp-cp', independent_ep=True)}." + assert expert_tp_ep_group == expert_rank_generator.get_ranks( + "tp-ep" + ), f"{expert_tp_ep_group} != {expert_rank_generator.get_ranks('tp-ep')}." + assert expert_dp_group == expert_rank_generator.get_ranks( + "dp" + ), f"{expert_dp_group} != {expert_rank_generator.get_ranks('dp')}." 
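Reviewer note: the reworked golden-rank helper above collapses the old five-axis (pp, edp, ep, cp, tp) layout into a four-axis (pp, expert_dp, ep, tp) grid, because context parallelism is now folded into the expert data-parallel dimension (which is also why the separate dp_no_ep_group_with_cp list disappears), and the test then checks the result against a second RankGenerator built with order="tp-ep-dp-pp". The snippet below is an illustrative sketch of the same grouping arithmetic, independent of Megatron and using hypothetical sizes (world_size=16, tp=2, ep=2, pp=2); it is not part of the patch.

# Sketch of the expert-parallel rank grouping the test asserts (hypothetical sizes).
import torch

world_size, tp, ep, pp = 16, 2, 2, 2
expert_dp = world_size // (tp * ep * pp)  # context parallelism is absorbed here

# Lay ranks out on a (pp, expert_dp, ep, tp) grid, matching order="tp-ep-dp-pp"
# (tp varies fastest, pp slowest).
grid = torch.arange(world_size).reshape(pp, expert_dp, ep, tp)

# Expert tensor-and-model-parallel groups: flatten (ep, tp) per (pp, expert_dp) slice.
tp_ep_groups = grid.reshape(-1, ep * tp).tolist()

# Expert data-parallel groups: move expert_dp to the innermost axis, then flatten.
expert_dp_groups = grid.permute(0, 2, 3, 1).reshape(-1, expert_dp).tolist()

print(tp_ep_groups)      # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]
print(expert_dp_groups)  # [[0, 4], [1, 5], [2, 6], [3, 7], [8, 12], [9, 13], [10, 14], [11, 15]]

With these sizes the two lists mirror what expert_rank_generator.get_ranks("tp-ep") and expert_rank_generator.get_ranks("dp") are asserted to return above.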
diff --git a/tests/unit_tests/transformer/moe/test_a2a_token_dispatcher.py b/tests/unit_tests/transformer/moe/test_a2a_token_dispatcher.py index 2e8f67fd44..bb834a9661 100644 --- a/tests/unit_tests/transformer/moe/test_a2a_token_dispatcher.py +++ b/tests/unit_tests/transformer/moe/test_a2a_token_dispatcher.py @@ -63,7 +63,7 @@ def test_capacity_forward_backward(self, tp_size, ep_size): moe_expert_capacity_factor=0.5, moe_pad_expert_input_to_capacity=False, ) - container.dispacher_capacity_test() + container.dispatcher_capacity_test() @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.internal diff --git a/tests/unit_tests/transformer/moe/test_aux_loss.py b/tests/unit_tests/transformer/moe/test_aux_loss.py index 2b7b2e109b..50567e1930 100644 --- a/tests/unit_tests/transformer/moe/test_aux_loss.py +++ b/tests/unit_tests/transformer/moe/test_aux_loss.py @@ -18,6 +18,7 @@ def partition_input(self, input): output.requires_grad = True return output + @pytest.mark.internal def aux_loss_test(self, input, baseline_grad): partitioned_input = self.partition_input(input) moe_layer = self.moe_layer @@ -56,6 +57,7 @@ def setup_method(self, method): def teardown_method(self, method): Utils.destroy_model_parallel() + @pytest.mark.internal @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.internal @pytest.mark.parametrize( @@ -75,6 +77,7 @@ def test_allgather_dispatcher(self, tp_size, ep_size, cp_size): ) container.aux_loss_test(self.input, self.baseline_grad) + @pytest.mark.internal @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.internal @pytest.mark.parametrize( diff --git a/tests/unit_tests/transformer/moe/test_grouped_mlp.py b/tests/unit_tests/transformer/moe/test_grouped_mlp.py index 043bdc8c58..4748cbc887 100644 --- a/tests/unit_tests/transformer/moe/test_grouped_mlp.py +++ b/tests/unit_tests/transformer/moe/test_grouped_mlp.py @@ -312,6 +312,7 @@ def test_constructor(self): self.fc2_ffn_hidden_size, ) + @pytest.mark.internal @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.internal def test_gpu_forward_backward(self): @@ -355,6 +356,7 @@ def test_gpu_forward_backward(self): for smm_result, gmm_result in zip(smm_results, gmm_results): torch.testing.assert_close(smm_result, gmm_result) + @pytest.mark.internal @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.internal def test_gpu_forward_backward_with_no_tokens_allocated(self): diff --git a/tests/unit_tests/transformer/moe/test_routers.py b/tests/unit_tests/transformer/moe/test_routers.py index c1633834b6..2b3e098dbc 100644 --- a/tests/unit_tests/transformer/moe/test_routers.py +++ b/tests/unit_tests/transformer/moe/test_routers.py @@ -44,6 +44,7 @@ def test_constructor(self): num_weights = sum([p.numel() for p in self.router.parameters()]) assert num_weights == 12 * 4, num_weights + @pytest.mark.internal @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.internal @pytest.mark.parametrize("moe_router_pre_softmax", [(True), (False)]) @@ -56,6 +57,7 @@ def test_router_forward(self, moe_router_pre_softmax): hidden_states = hidden_states.cuda() scores, indices = self.router(hidden_states) + @pytest.mark.internal @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.internal def test_aux_loss(self): diff --git 
a/tests/unit_tests/transformer/moe/test_sequential_mlp.py b/tests/unit_tests/transformer/moe/test_sequential_mlp.py index f473d409db..2a005555d5 100644 --- a/tests/unit_tests/transformer/moe/test_sequential_mlp.py +++ b/tests/unit_tests/transformer/moe/test_sequential_mlp.py @@ -111,6 +111,7 @@ def setup_method(self, method): self.num_local_experts, self.transformer_config, self.te_mlp_spec ) + @pytest.mark.internal @pytest.mark.skipif( not is_te_min_version("1.7.0"), reason="Transformer Engine under v1.7.0 doesn't support MoE training.", @@ -127,6 +128,7 @@ def test_constructor(self): self.te_sequential_mlp.local_experts[i].linear_fc2.weight, ) + @pytest.mark.internal @pytest.mark.skipif( not is_te_min_version("1.7.0"), reason="Transformer Engine under v1.7.0 doesn't support MoE training.", @@ -149,6 +151,7 @@ def test_gpu_forward(self): output_te, _ = self.te_sequential_mlp(hidden_states, tokens_per_expert) assert torch.equal(output_local, output_te) + @pytest.mark.internal @pytest.mark.skipif( not is_te_min_version("1.7.0"), reason="Transformer Engine under v1.7.0 doesn't support MoE training.", @@ -173,6 +176,7 @@ def test_gpu_forward_with_one_local_expert(self): output_te, _ = te_sequential_mlp(hidden_states, tokens_per_expert) assert torch.equal(output_local, output_te) + @pytest.mark.internal @pytest.mark.skipif( not is_te_min_version("1.7.0"), reason="Transformer Engine under v1.7.0 doesn't support MoE training.", diff --git a/tests/unit_tests/transformer/moe/test_token_dispatcher.py b/tests/unit_tests/transformer/moe/test_token_dispatcher.py index e85f8512b4..6bf79bbe7e 100644 --- a/tests/unit_tests/transformer/moe/test_token_dispatcher.py +++ b/tests/unit_tests/transformer/moe/test_token_dispatcher.py @@ -21,6 +21,7 @@ def __init__( ep_size, pp_size, cp_size=1, + moe_tp_size=None, data_parallel_random_init=False, num_moe_experts=8, moe_router_topk=2, @@ -32,11 +33,14 @@ def __init__( **kwargs, ): self.num_local_experts = num_moe_experts // ep_size + if moe_tp_size is None: + moe_tp_size = tp_size Utils.initialize_model_parallel( tensor_model_parallel_size=tp_size, pipeline_model_parallel_size=pp_size, expert_model_parallel_size=ep_size, context_parallel_size=cp_size, + expert_tensor_parallel_size=moe_tp_size, ) _set_random_seed(seed_=123, data_parallel_random_init=data_parallel_random_init) local_expert_indices_offset = ( @@ -45,12 +49,12 @@ def __init__( self.local_expert_indices = [ local_expert_indices_offset + i for i in range(self.num_local_experts) ] - self.config = TransformerConfig( tensor_model_parallel_size=tp_size, expert_model_parallel_size=ep_size, pipeline_model_parallel_size=pp_size, context_parallel_size=cp_size, + expert_tensor_parallel_size=moe_tp_size, moe_router_topk=moe_router_topk, num_moe_experts=num_moe_experts, moe_router_load_balancing_type=moe_router_load_balancing_type, @@ -59,9 +63,8 @@ def __init__( moe_pad_expert_input_to_capacity=moe_pad_expert_input_to_capacity, moe_aux_loss_coeff=moe_aux_loss_coeff, num_layers=1, - moe_extended_tp=kwargs.get("moe_extended_tp", False), moe_grouped_gemm=kwargs.get("moe_grouped_gemm", False), - hidden_size=kwargs.get("hidden_size", 1024), + hidden_size=kwargs.get("hidden_size", 16), num_attention_heads=kwargs.get("num_attention_heads", 8), use_cpu_initialization=kwargs.get("use_cpu_initialization", True), sequence_parallel=tp_size > 1, @@ -69,19 +72,24 @@ def __init__( ) # init moe layer + self.moe_layer = self.new_moe_layer() + + def new_moe_layer(self): transformer_layer_spec = 
get_gpt_layer_with_transformer_engine_spec( - num_experts=num_moe_experts, moe_grouped_gemm=kwargs.get("moe_grouped_gemm", False) + num_experts=self.config.num_moe_experts, moe_grouped_gemm=self.config.moe_grouped_gemm ) - self.moe_layer = MoELayer( - self.config, transformer_layer_spec.submodules.mlp.submodules + moe_layer = MoELayer( + copy.deepcopy(self.config), transformer_layer_spec.submodules.mlp.submodules ).cuda() - self.moe_layer.set_layer_number(0) + moe_layer.set_layer_number(0) + return moe_layer def __del__(self): torch.distributed.barrier() torch.cuda.synchronize() Utils.destroy_model_parallel() + @pytest.mark.internal def dispatcher_dropless_test(self): moe_layer = self.moe_layer bs = 32 @@ -103,13 +111,7 @@ def dispatcher_dropless_test(self): moe_layer.token_dispatcher.token_permutation(hidden_states, probs, indices) ) - if self.config.moe_extended_tp: - scale = ( - moe_layer.config.tensor_model_parallel_size - * moe_layer.config.expert_model_parallel_size - ) - else: - scale = moe_layer.config.tensor_model_parallel_size + scale = moe_layer.config.expert_tensor_parallel_size permuted_local_hidden_states /= scale @@ -127,14 +129,13 @@ def dispatcher_dropless_test(self): hidden_states.grad, ans ), "Restored hidden states do not match original hidden states" - def dispacher_capacity_test(self): + @pytest.mark.internal + def dispatcher_capacity_test(self): moe_layer = self.moe_layer - hidden_states = torch.randn((256, moe_layer.config.hidden_size)) + hidden_states = torch.randn((16, moe_layer.config.hidden_size)) hidden_states = hidden_states.cuda() hidden_states.requires_grad = True probs, indices = moe_layer.router(hidden_states) - tp_size = moe_layer.config.tensor_model_parallel_size - tp_rank = parallel_state.get_tensor_model_parallel_rank() # Create the answer. prob_mask = probs != 0 @@ -163,27 +164,17 @@ def dispacher_capacity_test(self): hidden_states.grad, restored_hidden_states_answer ), "Gradient of hidden states should be same as hidden states" + @pytest.mark.internal def dispatcher_drop_and_pad_test(self): "Test if the tokens are dropped and padded correctly" moe_layer = self.moe_layer - moe_layer_2 = copy.deepcopy(moe_layer) - hidden_states = torch.randn((256, moe_layer.config.hidden_size)).cuda() + + hidden_states = torch.randn((16, moe_layer.config.hidden_size)).cuda() hidden_states.requires_grad = True - # Create the answer. moe_layer.config.moe_pad_expert_input_to_capacity = False moe_layer.token_dispatcher.drop_and_pad = False - # Uncomment these lines to help bug location. 
- # hidden_states = torch.ones((8, moe_layer.config.hidden_size)).cuda() - # hidden_states = hidden_states * torch.range(1, 8).unsqueeze(1).cuda() - # hidden_states.requires_grad = True - # indices_1 = torch.tensor([[0, 0], [1, 1], [2, 2], [3, 3], [4, 4], [5, 5], [6, 6], [7, 7]]).cuda() - # probs_1 = torch.ones_like(indices_1) - # indices_2 = torch.tensor([[0, 0], [1, 1], [2, 2], [3, 3], [4, 4], [5, 5], [6, 6], [7, 7]]).cuda() - # probs_2 = torch.ones_like(indices_2) - # num_local_tokens_per_expert = torch.tensor([2, 2, 2, 2, 2, 2, 2, 2]).cuda() - probs_1, indices_1 = moe_layer.router(hidden_states) (permuted_input_1, tokens_per_expert) = moe_layer.token_dispatcher.token_permutation( hidden_states, probs_1, indices_1 @@ -198,6 +189,11 @@ def dispatcher_drop_and_pad_test(self): torch.cuda.synchronize() # End + moe_layer_2 = self.new_moe_layer() + moe_layer_2.load_state_dict(moe_layer.state_dict()) + moe_layer_2.config.moe_pad_expert_input_to_capacity = True + moe_layer_2.token_dispatcher.drop_and_pad = True + probs_2, indices_2 = moe_layer_2.router(hidden_states) (permuted_input_2, tokens_per_expert) = moe_layer_2.token_dispatcher.token_permutation( hidden_states, probs_2, indices_2 @@ -231,6 +227,7 @@ def setup_method(self, method): def teardown_method(self, method): Utils.destroy_model_parallel() + @pytest.mark.internal @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.internal @pytest.mark.parametrize("tp_size,ep_size", [(8, 1), (1, 8), (2, 4), (1, 1)]) @@ -247,19 +244,25 @@ def test_forward_backward(self, tp_size, ep_size): container.dispatcher_dropless_test() + @pytest.mark.internal @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.internal - @pytest.mark.parametrize("tp_size,ep_size", [(2, 4)]) - def test_extend_tp_forward_backward(self, tp_size, ep_size): + @pytest.mark.parametrize( + "tp_size,ep_size,moe_tp_size", [(1, 1, 8), (1, 2, 4), (1, 4, 2), (2, 2, 4)] + ) + def test_moe_tp_forward_backward(self, tp_size, ep_size, moe_tp_size): container = MoEModelTestContainer( tp_size=tp_size, ep_size=ep_size, pp_size=1, + moe_tp_size=moe_tp_size, num_moe_experts=8, moe_router_topk=2, moe_router_load_balancing_type="aux_loss", moe_token_dispatcher_type="allgather", - moe_extended_tp=True, + sequence_parallel=True, + moe_grouped_gemm=True, + use_cpu_initialization=False, ) container.dispatcher_dropless_test()
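Reviewer note: with moe_extended_tp removed, MoEModelTestContainer now threads a separate moe_tp_size through both Utils.initialize_model_parallel and TransformerConfig as expert_tensor_parallel_size, so the expert tensor-parallel width can differ from the dense tp_size. Below is a minimal usage sketch of the updated container, mirroring the (tp_size=1, ep_size=2, moe_tp_size=4) parametrization of test_moe_tp_forward_backward and assuming an 8-GPU torch.distributed launch; it is hypothetical and not part of the patch.

# Dense TP of 1 with expert TP of 4 and EP of 2; expert TP x EP spans all 8 ranks.
container = MoEModelTestContainer(
    tp_size=1,
    ep_size=2,
    pp_size=1,
    moe_tp_size=4,            # forwarded as expert_tensor_parallel_size
    num_moe_experts=8,
    moe_router_topk=2,
    moe_router_load_balancing_type="aux_loss",
    moe_token_dispatcher_type="allgather",
    sequence_parallel=True,
    moe_grouped_gemm=True,
    use_cpu_initialization=False,
)
# Permuted activations are now rescaled by config.expert_tensor_parallel_size
# rather than tensor_model_parallel_size inside dispatcher_dropless_test.
container.dispatcher_dropless_test()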