diff --git a/megatron/core/parallel_state.py b/megatron/core/parallel_state.py
index 095231b051..e9043b647c 100644
--- a/megatron/core/parallel_state.py
+++ b/megatron/core/parallel_state.py
@@ -227,7 +227,7 @@ def decompose(index, shape, stride=None):
 
 
 class RankGenerator(object):
-    """A class for generating rank groups based on various parallelism strategies."""
+    """A class for generating rank groups for different modes of parallelism."""
 
     def __init__(
         self, tp: int, ep: int, dp: int, pp: int, cp: int, order: str, rank_offset: int = 0
@@ -294,9 +294,9 @@ def get_mask(self, order: str, token: str):
         return mask
 
     def get_ranks(self, token, independent_ep=False):
-        '''Get rank group by input token.
+        """Get rank group by input token.
 
-        Arguments:
+        Args:
             token (str):
                 Specify the ranks type that want to get. If we want
                 to obtain multiple parallel types, we can use a hyphen
@@ -309,7 +309,7 @@
                 EP, we should set the flag. For example, get_ranks('dp', True)
                 will get DP modulo EP group, and get_ranks('dp', False) will
                 get full DP group.
-        '''
+        """
         if independent_ep:
             parallel_size = self.ordered_size_w_ep
             order = self.order_w_ep
@@ -884,7 +884,7 @@ def is_unitialized() -> bool:
 
 
 def model_parallel_is_initialized():
-    """Check if model and data parallel groups are initialized."""
+    """Check if model- and data-parallel groups are initialized."""
     if (
         _TENSOR_MODEL_PARALLEL_GROUP is None
         or _PIPELINE_MODEL_PARALLEL_GROUP is None
@@ -895,7 +895,7 @@ def model_parallel_is_initialized():
 
 
 def get_model_parallel_group(with_expert_parallel=False):
-    """Get the model parallel group the caller rank belongs to."""
+    """Get the model-parallel group the caller rank belongs to."""
     if with_expert_parallel:
         assert (
             _MODEL_AND_EXPERT_PARALLEL_GROUP is not None
@@ -906,7 +906,7 @@ def get_model_parallel_group(with_expert_parallel=False):
 
 
 def get_tensor_model_parallel_group(check_initialized=True):
-    """Get the tensor model parallel group the caller rank belongs to."""
+    """Get the tensor-model-parallel group the caller rank belongs to."""
     if check_initialized:
         assert (
             _TENSOR_MODEL_PARALLEL_GROUP is not None
@@ -915,7 +915,7 @@ def get_tensor_model_parallel_group(check_initialized=True):
 
 
 def get_pipeline_model_parallel_group():
-    """Get the pipeline model parallel group the caller rank belongs to."""
+    """Get the pipeline-model-parallel group the caller rank belongs to."""
     assert (
         _PIPELINE_MODEL_PARALLEL_GROUP is not None
     ), 'pipeline_model parallel group is not initialized'
@@ -923,7 +923,7 @@ def get_pipeline_model_parallel_group():
 
 
 def get_data_parallel_group(with_context_parallel=False):
-    """Get the data parallel group the caller rank belongs to."""
+    """Get the data-parallel group the caller rank belongs to."""
     if with_context_parallel:
         assert (
             _DATA_PARALLEL_GROUP_WITH_CP is not None
@@ -935,7 +935,7 @@ def get_data_parallel_group(with_context_parallel=False):
 
 
 def get_data_parallel_group_gloo(with_context_parallel=False):
-    """Get the data parallel group-gloo the caller rank belongs to."""
+    """Get the Gloo data-parallel group the caller rank belongs to."""
     if with_context_parallel:
         assert (
             _DATA_PARALLEL_GROUP_WITH_CP_GLOO is not None
@@ -947,14 +947,14 @@ def get_data_parallel_group_gloo(with_context_parallel=False):
 
 
 def get_context_parallel_group(check_initialized=True):
-    """Get the context parallel group the caller rank belongs to."""
+    """Get the context-parallel group the caller rank belongs to."""
     if check_initialized:
        assert _CONTEXT_PARALLEL_GROUP is not None, 'context parallel group is not initialized'
     return _CONTEXT_PARALLEL_GROUP
 
 
 def get_context_parallel_global_ranks(check_initialized=True):
-    """Get all global ranks of the context parallel group that the caller rank belongs to."""
+    """Get all global ranks of the context-parallel group that the caller rank belongs to."""
     if check_initialized:
         assert (
             _CONTEXT_PARALLEL_GLOBAL_RANKS is not None
@@ -1001,7 +1001,7 @@ def get_amax_reduction_group(with_context_parallel=False, tp_only_amax_red=False
 
 
 def get_tensor_and_data_parallel_group(with_context_parallel=False):
-    """Get the tensor and data parallel group the caller rank belongs to."""
+    """Get the tensor- and data-parallel group the caller rank belongs to."""
     if with_context_parallel:
         assert (
             _TENSOR_AND_DATA_PARALLEL_GROUP_WITH_CP is not None
@@ -1015,7 +1015,7 @@ def get_tensor_and_data_parallel_group(with_context_parallel=False):
 
 
 def get_tensor_and_context_parallel_group():
-    """Get the tensor and context parallel group the caller rank belongs to."""
+    """Get the tensor- and context-parallel group the caller rank belongs to."""
     assert (
         _TENSOR_AND_CONTEXT_PARALLEL_GROUP is not None
     ), 'tensor and context parallel group is not initialized'
@@ -1023,7 +1023,7 @@ def get_tensor_and_context_parallel_group():
 
 
 def get_expert_model_parallel_group():
-    """Get the expert model parallel group the caller rank belongs to."""
+    """Get the expert-model-parallel group the caller rank belongs to."""
     assert (
         _EXPERT_MODEL_PARALLEL_GROUP is not None
     ), 'expert model parallel group is not initialized'
@@ -1031,7 +1031,7 @@ def get_expert_model_parallel_group():
 
 
 def get_tensor_and_expert_parallel_group():
-    """Get the tensor and expert parallel group the caller rank belongs to."""
+    """Get the tensor- and expert-parallel group the caller rank belongs to."""
     assert (
         _TENSOR_AND_EXPERT_PARALLEL_GROUP is not None
     ), 'tensor and expert parallel group is not initialized'
@@ -1039,7 +1039,7 @@ def get_tensor_and_expert_parallel_group():
 
 
 def get_data_modulo_expert_parallel_group(with_context_parallel=False):
-    """Get the data modulo expert parallel group the caller rank belongs to."""
+    """Get the data-modulo-expert-parallel group the caller rank belongs to."""
     if with_context_parallel:
         assert (
             _DATA_MODULO_EXPERT_PARALLEL_GROUP_WITH_CP is not None
@@ -1053,7 +1053,7 @@ def get_data_modulo_expert_parallel_group(with_context_parallel=False):
 
 
 def get_data_modulo_expert_parallel_group_gloo(with_context_parallel=False):
-    """Get the data modulo expert parallel group gloo the caller rank belongs to."""
+    """Get the Gloo data-modulo-expert-parallel group the caller rank belongs to."""
     if with_context_parallel:
         assert (
             _DATA_MODULO_EXPERT_PARALLEL_GROUP_WITH_CP_GLOO is not None
@@ -1067,31 +1067,31 @@ def get_data_modulo_expert_parallel_group_gloo(with_context_parallel=False):
 
 
 def set_expert_model_parallel_world_size(world_size):
-    """Sets the expert model parallel world size."""
+    """Set the expert-model-parallel world size."""
     global _MPU_EXPERT_MODEL_PARALLEL_WORLD_SIZE
     _MPU_EXPERT_MODEL_PARALLEL_WORLD_SIZE = world_size
 
 
 def set_tensor_model_parallel_world_size(world_size):
-    """Set the tensor model parallel size"""
+    """Set the tensor-model-parallel world size."""
     global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
     _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = world_size
 
 
 def set_pipeline_model_parallel_world_size(world_size):
-    """Set the pipeline model parallel size"""
+    """Set the pipeline-model-parallel world size."""
    global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
     _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = world_size
 
 
 def set_virtual_pipeline_model_parallel_world_size(world_size):
-    """Set the pipeline model parallel size"""
+    """Set the virtual-pipeline-model-parallel world size."""
     global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
     _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = world_size
 
 
 def get_tensor_model_parallel_world_size():
-    """Return world size for the tensor model parallel group."""
+    """Return world size for the tensor-model-parallel group."""
     global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
     if _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE is not None:
         return _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
@@ -1099,14 +1099,14 @@ def get_tensor_model_parallel_world_size():
 
 
 def get_pipeline_model_parallel_world_size():
-    """Return world size for the pipeline model parallel group."""
+    """Return world size for the pipeline-model-parallel group."""
     global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
     if _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE is not None:
         return _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
 
     pp_group = get_pipeline_model_parallel_group()
     if isinstance(pp_group, list):
-        # I am assuming that each pp group is the same size.
+        # Implicit assumption that each PP group is the same size.
         sizes = []
         for group in _PIPELINE_GLOBAL_RANKS:
             sizes.append(len(group))
@@ -1117,31 +1117,31 @@
 
 
 def set_expert_model_parallel_rank(rank):
-    """Set expert model parallel rank."""
+    """Set expert-model-parallel rank."""
     global _MPU_EXPERT_MODEL_PARALLEL_RANK
     _MPU_EXPERT_MODEL_PARALLEL_RANK = rank
 
 
 def set_tensor_model_parallel_rank(rank):
-    """Set tensor model parallel rank."""
+    """Set tensor-model-parallel rank."""
     global _MPU_TENSOR_MODEL_PARALLEL_RANK
     _MPU_TENSOR_MODEL_PARALLEL_RANK = rank
 
 
 def set_pipeline_model_parallel_rank(rank):
-    """Set pipeline model parallel rank."""
+    """Set pipeline-model-parallel rank."""
     global _MPU_PIPELINE_MODEL_PARALLEL_RANK
     _MPU_PIPELINE_MODEL_PARALLEL_RANK = rank
 
 
 def set_pipeline_model_parallel_split_rank(rank):
-    """Set pipeline model parallel split rank. DEPRECATED."""
+    """Set pipeline-model-parallel split rank. DEPRECATED."""
     global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
     _PIPELINE_MODEL_PARALLEL_SPLIT_RANK = rank
 
 
 def get_tensor_model_parallel_rank():
-    """Return my rank for the tensor model parallel group."""
+    """Return caller's rank in the tensor-model-parallel group."""
     global _MPU_TENSOR_MODEL_PARALLEL_RANK
     if _MPU_TENSOR_MODEL_PARALLEL_RANK is not None:
         return _MPU_TENSOR_MODEL_PARALLEL_RANK
@@ -1149,14 +1149,14 @@ def get_tensor_model_parallel_rank():
 
 
 def get_pipeline_model_parallel_rank():
-    """Return my rank for the pipeline model parallel group."""
+    """Return caller's rank in the pipeline-model-parallel group."""
     global _MPU_PIPELINE_MODEL_PARALLEL_RANK
     if _MPU_PIPELINE_MODEL_PARALLEL_RANK is not None:
         return _MPU_PIPELINE_MODEL_PARALLEL_RANK
     rank = torch.distributed.get_rank()
     pp_group = get_pipeline_model_parallel_group()
     if isinstance(pp_group, list):
-        # I am assuming that if i exist in multiple pp groups, then I am in the same index.
+        # Assume that if the caller exists in multiple PP groups, it has the same index in each.
        indices = []
         for group in _PIPELINE_GLOBAL_RANKS:
             for i, r in enumerate(group):
@@ -1169,7 +1169,7 @@
 
 
 def get_pipeline_model_parallel_split_rank():
-    """Return pipeline model parallel split rank."""
+    """Return pipeline-model-parallel split rank."""
     global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
     return _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
 
@@ -1186,7 +1186,7 @@ def is_pipeline_first_stage(ignore_virtual=False):
 
 
 def is_pipeline_last_stage(ignore_virtual=False):
-    """Return True if in the last pipeline model-parallel stage, False otherwise."""
+    """Return True if in the last pipeline-model-parallel stage, False otherwise."""
     if not ignore_virtual:
         virtual_pipeline_model_parallel_world_size = (
             get_virtual_pipeline_model_parallel_world_size()
@@ -1334,8 +1334,7 @@ def get_data_parallel_src_rank(with_context_parallel=False):
 
 
 def get_pipeline_model_parallel_first_rank():
-    """Return the global rank of the first process in the pipeline for the
-    current tensor parallel group"""
+    """Return the global rank of the first stage in the current rank's pipeline."""
     assert _PIPELINE_GLOBAL_RANKS is not None, "Pipeline parallel group is not initialized"
     if isinstance(_PIPELINE_GLOBAL_RANKS[0], list):
         # I assume the first rank is the same for all pp groups right now.
@@ -1347,17 +1346,17 @@
 
 
 def get_pipeline_model_parallel_last_rank():
-    """Return the global rank of the last process in the pipeline for the
-    current tensor parallel group"""
+    """Return the global rank of the last stage in the current rank's pipeline."""
     assert _PIPELINE_GLOBAL_RANKS is not None, "Pipeline parallel group is not initialized"
     last_rank_local = get_pipeline_model_parallel_world_size() - 1
     return _PIPELINE_GLOBAL_RANKS[last_rank_local]
 
 
 def get_pipeline_model_parallel_next_rank():
-    """Return the global rank that follows the caller in the pipeline, for each pipeline group that
-    the rank is part of. If it's just part of one group, an int is returned,
-    otherwise a list of ints.
+    """Return the global rank that follows the caller in the pipeline, for each
+    pipeline-parallel group that the rank is part of.
+
+    If it is just part of one group, an int is returned, otherwise a list of ints.
     """
     assert _PIPELINE_GLOBAL_RANKS is not None, "Pipeline parallel group is not initialized"
     rank_in_pipeline = get_pipeline_model_parallel_rank()
@@ -1372,9 +1371,10 @@
 
 
 def get_pipeline_model_parallel_prev_rank():
-    """Return the global rank that preceeds the caller in the pipeline, for each pipeline group that
-    the rank is part of. If it's just part of one group, an int is returned,
-    otherwise a list of ints.
+    """Return the global rank that precedes the caller in the pipeline, for each
+    pipeline-parallel group that the rank is part of.
+
+    If it is just part of one group, an int is returned, otherwise a list of ints.
""" assert _PIPELINE_GLOBAL_RANKS is not None, "Pipeline parallel group is not initialized" rank_in_pipeline = get_pipeline_model_parallel_rank() @@ -1408,7 +1408,7 @@ def set_data_parallel_rank(rank): def get_data_parallel_rank(with_context_parallel=False): - """Return my rank for the data parallel group.""" + """Return caller's rank in the data-parallel group.""" global _MPU_DATA_PARALLEL_RANK if _MPU_DATA_PARALLEL_RANK is not None: return _MPU_DATA_PARALLEL_RANK @@ -1429,7 +1429,7 @@ def get_context_parallel_world_size(): def get_context_parallel_rank(): - """Return my rank for the context parallel group.""" + """Return caller's rank in the context-parallel group.""" if torch.distributed.is_available() and torch.distributed.is_initialized(): return torch.distributed.get_rank(group=get_context_parallel_group()) else: @@ -1437,7 +1437,7 @@ def get_context_parallel_rank(): def get_tensor_and_context_parallel_world_size(): - """Return world size for the tensor and context parallel group""" + """Return world size for the tensor and context-parallel group.""" if torch.distributed.is_available() and torch.distributed.is_initialized(): return torch.distributed.get_world_size(group=get_tensor_and_context_parallel_group()) else: @@ -1445,7 +1445,7 @@ def get_tensor_and_context_parallel_world_size(): def get_tensor_and_context_parallel_rank(): - """Return my rank for the tensor and context parallel group.""" + """Return caller's rank in the joint tensor-model-parallel and context-parallel group.""" if torch.distributed.is_available() and torch.distributed.is_initialized(): return torch.distributed.get_rank(group=get_tensor_and_context_parallel_group()) else: @@ -1453,7 +1453,7 @@ def get_tensor_and_context_parallel_rank(): def get_expert_model_parallel_world_size(): - """Return world size for the expert model parallel group""" + """Return world size for the expert-model-parallel group.""" if _MPU_EXPERT_MODEL_PARALLEL_WORLD_SIZE is not None: return _MPU_EXPERT_MODEL_PARALLEL_WORLD_SIZE if torch.distributed.is_available() and torch.distributed.is_initialized(): @@ -1479,7 +1479,7 @@ def get_tensor_and_expert_parallel_world_size(): def get_expert_model_parallel_rank(): - """Return my rank for the expert parallel group""" + """Return caller's rank in the expert-model-parallel group.""" if _MPU_EXPERT_MODEL_PARALLEL_RANK is not None: return _MPU_EXPERT_MODEL_PARALLEL_RANK if torch.distributed.is_available() and torch.distributed.is_initialized(): @@ -1492,7 +1492,7 @@ def get_expert_model_parallel_rank(): def get_data_modulo_expert_parallel_rank(with_context_parallel=False): - """Return my rank for the context parallel group.""" + """Return caller's rank in the context-parallel group.""" if torch.distributed.is_available() and torch.distributed.is_initialized(): return torch.distributed.get_rank( group=get_data_modulo_expert_parallel_group(with_context_parallel=with_context_parallel) @@ -1502,7 +1502,7 @@ def get_data_modulo_expert_parallel_rank(with_context_parallel=False): def get_tensor_and_expert_parallel_rank(): - """Return my rank for the tensor and expert parallel group""" + """Return caller's rank in the joint tensor- and expert-model-parallel group.""" if torch.distributed.is_available() and torch.distributed.is_initialized(): return torch.distributed.get_rank(group=get_tensor_and_expert_parallel_group()) else: @@ -1510,7 +1510,7 @@ def get_tensor_and_expert_parallel_rank(): def _set_global_memory_buffer(): - """Initialize global buffer""" + """Initialize global buffer.""" global 
    global _GLOBAL_MEMORY_BUFFER
     assert _GLOBAL_MEMORY_BUFFER is None, 'global memory buffer is already initialized'
     _GLOBAL_MEMORY_BUFFER = GlobalMemoryBuffer()
@@ -1529,7 +1529,8 @@ def destroy_global_memory_buffer():
 
 
 def get_all_ranks():
-    """Retrieve the ranks for various parallel groups associated with the current rank."""
+    """Get caller's rank in the tensor-model-parallel, data-parallel, context-parallel,
+    pipeline-model-parallel, and expert-model-parallel groups."""
    ranks = [
         get_tensor_model_parallel_rank(),
         get_data_parallel_rank(),
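
Note: the get_ranks token semantics described in the docstring above are easiest to follow with a
concrete example. The sketch below is illustrative only and is not part of the patch; it assumes
megatron-core is installed and that RankGenerator keeps the constructor signature shown in the
hunk at line 227. The printed groups are for a hypothetical 8-rank layout.

    # Illustrative sketch of RankGenerator token semantics (not part of the patch).
    from megatron.core.parallel_state import RankGenerator

    # 8 ranks total: tp=2 * cp=1 * ep=1 * dp=2 * pp=2, using the default order string.
    gen = RankGenerator(tp=2, ep=1, dp=2, pp=2, cp=1, order="tp-cp-ep-dp-pp")

    print(gen.get_ranks("tp"))     # tensor-model-parallel groups, e.g. [[0, 1], [2, 3], [4, 5], [6, 7]]
    print(gen.get_ranks("tp-pp"))  # joint tensor- and pipeline-parallel groups via a hyphenated token
    print(gen.get_ranks("dp"))     # full data-parallel groups (independent_ep=False)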
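
Likewise, the rank and world-size accessors whose docstrings this patch rewords are normally
called after model parallelism has been initialized. A minimal sketch, assuming the public
megatron.core API is otherwise unchanged and the script is launched on 8 GPUs with
torchrun --nproc_per_node=8:

    # Minimal usage sketch of the accessor API touched by this patch (not part of the patch).
    import os

    import torch
    from megatron.core import parallel_state

    local_rank = int(os.environ["LOCAL_RANK"])  # set by torchrun
    torch.cuda.set_device(local_rank)
    torch.distributed.init_process_group(backend="nccl")

    # tp=2 * pp=2 over 8 ranks leaves dp=2 (cp and ep default to 1).
    parallel_state.initialize_model_parallel(
        tensor_model_parallel_size=2,
        pipeline_model_parallel_size=2,
    )

    # Each process reports its rank in every group it belongs to.
    print(
        f"global={torch.distributed.get_rank()} "
        f"tp={parallel_state.get_tensor_model_parallel_rank()} "
        f"pp={parallel_state.get_pipeline_model_parallel_rank()} "
        f"dp={parallel_state.get_data_parallel_rank()}"
    )

    parallel_state.destroy_model_parallel()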