From 4e3840535b1912222aa5e9c8c1705b947792f8da Mon Sep 17 00:00:00 2001 From: Deepak Narayanan Date: Fri, 23 Aug 2024 17:46:41 -0700 Subject: [PATCH] ADLR/megatron-lm!1874 - Overlap param all-gather with optimizer step and fix alignment of AGs across pipeline stages --- examples/gpt3/gpt_config.yaml | 4 +- megatron/core/optimizer/__init__.py | 193 ++++++++++++------ megatron/core/optimizer/distrib_optimizer.py | 58 ++++-- megatron/core/optimizer/optimizer.py | 34 ++- megatron/core/optimizer/optimizer_config.py | 8 + megatron/training/arguments.py | 37 +++- megatron/training/checkpointing.py | 3 +- megatron/training/training.py | 10 +- tests/functional_tests/jet_recipes/gpt.yaml | 1 + .../golden_values.json | 1 + .../model_config.yaml | 57 ++++++ tests/unit_tests/dist_checkpointing/utils.py | 1 + 12 files changed, 311 insertions(+), 96 deletions(-) create mode 100644 tests/functional_tests/test_cases/gpt/gpt3_mr_mcore_te_tp1_pp4_vp1_dist_optimizer_overlap_grad_reduce_param_gather_overlap_optimizer_dgx_a100_1N8G/golden_values.json create mode 100644 tests/functional_tests/test_cases/gpt/gpt3_mr_mcore_te_tp1_pp4_vp1_dist_optimizer_overlap_grad_reduce_param_gather_overlap_optimizer_dgx_a100_1N8G/model_config.yaml diff --git a/examples/gpt3/gpt_config.yaml b/examples/gpt3/gpt_config.yaml index 0e6408867c..443e4b79b8 100644 --- a/examples/gpt3/gpt_config.yaml +++ b/examples/gpt3/gpt_config.yaml @@ -215,9 +215,9 @@ fp16_lm_cross_entropy: False distributed_backend: nccl distributed_timeout_minutes: 10 overlap_grad_reduce: False -delay_grad_reduce: True +align_grad_reduce: True overlap_param_gather: False -delay_param_gather: False +align_param_gather: False scatter_gather_tensors_in_pipeline: True local_rank: null lazy_mpu_init: null diff --git a/megatron/core/optimizer/__init__.py b/megatron/core/optimizer/__init__.py index 65f72ec8c8..d06911f1b9 100644 --- a/megatron/core/optimizer/__init__.py +++ b/megatron/core/optimizer/__init__.py @@ -1,6 +1,6 @@ # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. import logging -from typing import Callable, Dict, List, Optional +from typing import Callable, Dict, List, Optional, Tuple import torch @@ -42,10 +42,13 @@ def _get_param_groups( model_chunks: List[MegatronModule], - no_weight_decay_cond: Callable, - scale_lr_cond: Callable, + no_weight_decay_cond: Optional[Callable], + scale_lr_cond: Optional[Callable], lr_mult: float, - use_decoupled_learning_rate: bool, + lr: float, + min_lr: float, + decoupled_lr: Optional[float], + decoupled_min_lr: Optional[float], ) -> List[Dict]: """Create parameter groups for optimizer. @@ -57,18 +60,23 @@ def _get_param_groups( Args: model_chunks (List[MegatronModule]): model chunks to create parameter groups for. - no_weight_decay_cond (func): function to determine whether a parameter - should not perform weight decay. - scale_lr_cond (func): function to determine whether a parameter + no_weight_decay_cond (func, optional): function to determine whether a + parameter should not perform weight decay. + scale_lr_cond (func, optional): function to determine whether a parameter should have a scaled learning rate. lr_mult (float): learning rate multiplier for parameters that satisfy scale_lr_cond. - use_decoupled_learning_rate (bool): true if using decoupled learning rate. + lr (float): learning rate. + min_lr (float): minimum learning rate. + decoupled_lr (Optional[float]): optional decoupled learning rate. + decoupled_min_lr (Optional[float]): optional decoupled minimum learning rate. 
Returns: List of parameter groups. """ + use_decoupled_learning_rate = decoupled_lr is not None + # Map (wd_mult, lr_mult, is_expert_parallel, is_decoupled_lr) to params. params_map = {} for model_chunk in model_chunks: @@ -113,15 +121,22 @@ def _get_param_groups( param_groups = [] for (wd_mult, _lr_mult, is_expert_parallel, is_decoupled_lr), params in params_map.items(): assert len(params) > 0 - param_groups.append( - { - 'params': params, - 'wd_mult': wd_mult, - 'lr_mult': _lr_mult, - 'is_expert_parallel': is_expert_parallel, - 'is_decoupled_lr': is_decoupled_lr, - } - ) + param_group = { + 'params': params, + 'wd_mult': wd_mult, + 'lr_mult': _lr_mult, + 'is_expert_parallel': is_expert_parallel, + 'is_decoupled_lr': is_decoupled_lr, + } + param_groups.append(param_group) + + param_groups = _update_min_and_max_lr_in_param_groups( + param_groups, + lr=lr, + min_lr=min_lr, + decoupled_lr=decoupled_lr, + decoupled_min_lr=decoupled_min_lr, + ) return param_groups @@ -165,6 +180,56 @@ def _update_min_and_max_lr_in_param_groups( return param_groups +def _get_param_groups_and_buffers( + model_chunks: List[MegatronModule], + model_chunk_offset: int, + config: OptimizerConfig, + no_weight_decay_cond: Optional[Callable], + scale_lr_cond: Optional[Callable], + lr_mult: float, + filter_fn: Callable, + buffer_name: str, +) -> Tuple[List[Dict], Dict[int, ParamAndGradBuffer]]: + """Returns parameter groups and buffer for optimizer. + + Args: + model_chunks (List[MegatronModule]): model chunks to create parameter + groups for. + model_chunk_offset (int): offset of model_chunks in global model_chunks list. + config (OptimizerConfig): optimizer configuration object. + no_weight_decay_cond (func, optional): function to determine whether a + parameter should not perform weight decay. + scale_lr_cond (func, optional): function to determine whether a parameter + should have a scaled learning rate. + lr_mult (float): learning rate multiplier for parameters that + satisfy scale_lr_cond. + lr (float): learning rate. + min_lr (float): minimum learning rate. + filter_fn (callable): filtering function for param_groups. + buffer_name (str): name of buffer. + + Returns: + List of parameter groups and dictionary of model chunk IDs to buffers. + """ + param_groups = _get_param_groups( + model_chunks, + no_weight_decay_cond, + scale_lr_cond, + lr_mult, + lr=config.lr, + min_lr=config.min_lr, + decoupled_lr=config.decoupled_lr, + decoupled_min_lr=config.decoupled_min_lr, + ) + param_groups = list(filter(filter_fn, param_groups)) + buffers = {} + for model_chunk_idx, model_chunk in enumerate(model_chunks): + if hasattr(model_chunk, buffer_name): + buffers[model_chunk_idx + model_chunk_offset] = getattr(model_chunk, buffer_name) + + return param_groups, buffers + + def _get_megatron_optimizer_based_on_param_groups( config: OptimizerConfig, param_groups: List, @@ -173,6 +238,7 @@ def _get_megatron_optimizer_based_on_param_groups( data_parallel_group: Optional[torch.distributed.ProcessGroup] = None, data_parallel_group_gloo: Optional[torch.distributed.ProcessGroup] = None, data_parallel_group_idx: Optional[int] = None, + overlap_param_gather_with_optimizer_step: bool = False, ) -> MegatronOptimizer: """Get Megatron optimizer based on parameter groups. @@ -186,6 +252,8 @@ def _get_megatron_optimizer_based_on_param_groups( group for distributed optimizer. Defaults to None. data_parallel_group_idx (int, optional): data-parallel group index for distributed optimizer. Defaults to None. 
+ overlap_param_gather_with_optimizer_step (bool, optional): if true, overlap parameter + all-gather with optimizer step if using distributed optimizer. Defaults to False. Returns: Instance of MegatronOptimizer. @@ -255,6 +323,7 @@ def init_state_fn(opt): data_parallel_group=data_parallel_group, data_parallel_group_gloo=data_parallel_group_gloo, data_parallel_group_idx=data_parallel_group_idx, + overlap_param_gather_with_optimizer_step=overlap_param_gather_with_optimizer_step, ) else: optimizer = Float16OptimizerWithFloat16Params(*optimizer_args) @@ -294,48 +363,56 @@ def get_megatron_optimizer( log_single_rank(logger, logging.INFO, f'Setting up optimizer with config {config}') - # Collect param groups. - param_groups = _get_param_groups( - model_chunks, - no_weight_decay_cond, - scale_lr_cond, - lr_mult, - use_decoupled_learning_rate=config.decoupled_lr is not None, - ) - param_groups = _update_min_and_max_lr_in_param_groups( - param_groups, - lr=config.lr, - min_lr=config.min_lr, - decoupled_lr=config.decoupled_lr, - decoupled_min_lr=config.decoupled_min_lr, - ) - - # Collect grad buffers for distributed optimizer. - per_model_buffers = {} - per_model_ep_buffers = {} - for model_idx, model_chunk in enumerate(model_chunks): - if hasattr(model_chunk, 'buffers'): - per_model_buffers[model_idx] = model_chunk.buffers - per_model_ep_buffers[model_idx] = model_chunk.expert_parallel_buffers - - # Split param groups into dense and MoE params (since data-parallel groups for MoE - # parameters can be different with expert parallelism). - dense_param_groups = list(filter(lambda g: not g['is_expert_parallel'], param_groups)) - moe_param_groups = list(filter(lambda g: g['is_expert_parallel'], param_groups)) - - # Create optimizers. + # Separate out first model chunk if overlapping param AG with optimizer step. 
+ if config.overlap_param_gather_with_optimizer_step: + all_dense_model_chunks = [[model_chunks[0]], model_chunks[1:]] + overlap_param_gather_with_optimizer_step_flags = [True, False] + else: + all_dense_model_chunks = [model_chunks] + overlap_param_gather_with_optimizer_step_flags = [False] model_parallel_rank = torch.distributed.get_rank(mpu.get_model_parallel_group()) - optimizers = [ - _get_megatron_optimizer_based_on_param_groups( - config, - param_groups=dense_param_groups, - per_model_buffers=per_model_buffers, - model_parallel_group=mpu.get_model_parallel_group(), - data_parallel_group=mpu.get_data_parallel_group(with_context_parallel=True), - data_parallel_group_gloo=mpu.get_data_parallel_group_gloo(with_context_parallel=True), - data_parallel_group_idx=model_parallel_rank, + + optimizers = [] + model_chunk_offset = 0 + for dense_model_chunks, overlap_param_gather_with_optimizer_step in zip( + all_dense_model_chunks, overlap_param_gather_with_optimizer_step_flags + ): + param_groups, buffers = _get_param_groups_and_buffers( + dense_model_chunks, + model_chunk_offset=model_chunk_offset, + config=config, + no_weight_decay_cond=no_weight_decay_cond, + scale_lr_cond=scale_lr_cond, + lr_mult=lr_mult, + filter_fn=lambda g: not g['is_expert_parallel'], + buffer_name='buffers', + ) + optimizers.append( + _get_megatron_optimizer_based_on_param_groups( + config, + param_groups=param_groups, + per_model_buffers=buffers, + model_parallel_group=mpu.get_model_parallel_group(), + data_parallel_group=mpu.get_data_parallel_group(with_context_parallel=True), + data_parallel_group_gloo=mpu.get_data_parallel_group_gloo( + with_context_parallel=True + ), + data_parallel_group_idx=model_parallel_rank, + overlap_param_gather_with_optimizer_step=overlap_param_gather_with_optimizer_step, + ) ) - ] + model_chunk_offset += 1 + + moe_param_groups, moe_buffers = _get_param_groups_and_buffers( + model_chunks, + model_chunk_offset=0, + config=config, + no_weight_decay_cond=no_weight_decay_cond, + scale_lr_cond=scale_lr_cond, + lr_mult=lr_mult, + filter_fn=lambda g: g['is_expert_parallel'], + buffer_name='expert_parallel_buffers', + ) if len(moe_param_groups) > 0: model_parallel_world_size = torch.distributed.get_world_size(mpu.get_model_parallel_group()) expert_parallel_rank = mpu.get_expert_model_parallel_rank() @@ -343,7 +420,7 @@ def get_megatron_optimizer( _get_megatron_optimizer_based_on_param_groups( config, param_groups=moe_param_groups, - per_model_buffers=per_model_ep_buffers, + per_model_buffers=moe_buffers, model_parallel_group=mpu.get_model_parallel_group(with_expert_parallel=True), data_parallel_group=mpu.get_data_modulo_expert_parallel_group( with_context_parallel=True diff --git a/megatron/core/optimizer/distrib_optimizer.py b/megatron/core/optimizer/distrib_optimizer.py index b42b493fc4..c211619d0e 100644 --- a/megatron/core/optimizer/distrib_optimizer.py +++ b/megatron/core/optimizer/distrib_optimizer.py @@ -21,7 +21,7 @@ HAVE_APEX_OR_TE = False -from .. import parallel_state, tensor_parallel +from .. import tensor_parallel from ..config_logger import has_config_logger_enabled, log_config_to_disk from ..dist_checkpointing import ShardedTensor from ..dist_checkpointing.dict_utils import nested_values @@ -93,7 +93,7 @@ def _build_model_gbuf_param_range_map( buffer shard ranges, specific to each data-parallel (DP) rank's set of 'owned' parameters. 
Each grad buffer (padded to be an even multiple of DP-world-size) is conceptually divided into DP-world-size - contiguous regions, where each DP rank 'owns' a contiguous regions. + contiguous regions, where each DP rank 'owns' a contiguous region. Ownership in this sense means DP rank is responsible for reducing the relevant subset of grads, and updating the relevant subset of params. @@ -393,6 +393,7 @@ def __init__( data_parallel_group: torch.distributed.ProcessGroup, data_parallel_group_gloo: torch.distributed.ProcessGroup, data_parallel_group_idx: int, + overlap_param_gather_with_optimizer_step: bool = False, ): """ Distributed optimizer, for all data types (fp16, bf16, and fp32). @@ -422,6 +423,8 @@ def __init__( (used in checkpoint loading and saving). data_parallel_group_idx (int): index in data-parallel group (used by distributed checkpointing logic). + overlap_param_gather_with_optimizer_step (bool, optional): if true, overlap parameter + all-gather with optimizer step. Defaults to False. """ if has_config_logger_enabled(config): @@ -516,6 +519,7 @@ def __init__( self.num_all_gather_handles = len(self.all_gather_handle_index_to_bucket_index_map) self.overlap_param_gather = self.config.overlap_param_gather + self.overlap_param_gather_with_optimizer_step = overlap_param_gather_with_optimizer_step self.remove_pre_hook_handle = None if self.overlap_param_gather: self.enable_pre_hook() @@ -547,6 +551,7 @@ def disable_pre_hook(self): # Make sure all-gathers are completed as needed. self._reset_metadata_and_sync_gather_all_model_params(force_sync=True) + self.update_successful = False def _get_model_param_range_map(self, param: torch.nn.Parameter): """ @@ -1490,7 +1495,14 @@ def zero_grad(self, set_to_none: bool = True): # pre-hook when this all-gather finishes (to ensure that the communication # kernels don't head-of-line block the compute kernels since we run with # CUDA_DEVICE_MAX_CONNECTIONS=1 to support sequence parallelism). - if self.overlap_param_gather: + # If aligning param all-gather across pipeline stages, all-gather is dispatched + # by start_param_sync calls in core/pipeline_parallelism/schedules.py. + # If overlapping param all-gather with optimizer step, then all-gather has + # already been dispatched in optimizer step. + skip_dispatch = ( + self.config.align_param_gather or self.overlap_param_gather_with_optimizer_step + ) + if self.overlap_param_gather and not skip_dispatch: self._dispatch_gather_model_params(all_gather_handle_index=0) def _get_model_param_buffer_dp_views(self): @@ -1587,25 +1599,47 @@ def hook(module, *unused): # non-expert params. if param in self.param_to_all_gather_handle_index_map: all_gather_handle_index = self.param_to_all_gather_handle_index_map[param] - self._finish_param_sync_helper(all_gather_handle_index) + # If aligning param all-gather across pipeline stages, all-gather is dispatched + # by start_param_sync calls in core/pipeline_parallelism/schedules.py. + # If overlapping param all-gather with optimizer step, then all-gather has + # already been dispatched in optimizer step. + skip_dispatch = ( + self.config.align_param_gather + or self.overlap_param_gather_with_optimizer_step + ) + self._finish_param_sync_helper( + all_gather_handle_index, skip_dispatch=skip_dispatch + ) return hook - def finish_param_sync(self, model_index: int, *unused): + def start_param_sync(self, model_index: int, *unused, force_dispatch: bool = False): """ - Finishes all necessary param syncs for the model_index'th model chunk. 
+ Starts all necessary param syncs for the model_index'th model chunk. Args: model_index (int): index of model chunk to synchronize params. + force_dispatch (bool, optional): force dispatch regardless of other settings. """ if model_index not in self.model_index_to_all_gather_handle_index_map: return - all_gather_handle_indices = self.model_index_to_all_gather_handle_index_map[model_index] - for all_gather_handle_index in all_gather_handle_indices: - self._finish_param_sync_helper(all_gather_handle_index) + if self.overlap_param_gather_with_optimizer_step and not force_dispatch: + return - def _finish_param_sync_helper(self, all_gather_handle_index: int): + # If overlapping param AG with optimizer step, AG has already been dispatched. + if self.update_successful: + all_gather_handle_indices = self.model_index_to_all_gather_handle_index_map[model_index] + with torch.distributed._coalescing_manager( + group=self.data_parallel_group, async_ops=self.overlap_param_gather + ) as cm: + for all_gather_handle_index in all_gather_handle_indices: + self._dispatch_gather_model_params(all_gather_handle_index) + if self.overlap_param_gather: + for all_gather_handle_index in all_gather_handle_indices: + self.all_gather_handles[all_gather_handle_index] = cm + + def _finish_param_sync_helper(self, all_gather_handle_index: int, skip_dispatch: bool = False): """ Waits on all_gather_handle if necessary, then dispatches the next all-gather as necessary. @@ -1625,7 +1659,7 @@ def _finish_param_sync_helper(self, all_gather_handle_index: int): # (since we run with CUDA_DEVICE_MAX_CONNECTIONS=1 to support sequence # parallelism). next_all_gather_handle_index = all_gather_handle_index + 1 - if next_all_gather_handle_index < self.num_all_gather_handles: + if next_all_gather_handle_index < self.num_all_gather_handles and not skip_dispatch: self._dispatch_gather_model_params(next_all_gather_handle_index) def _collect_main_grad_data_for_unscaling(self): @@ -1744,7 +1778,7 @@ def _reset_metadata_and_sync_gather_all_model_params(self, force_sync: bool): # is explicitly set to True (e.g., if we are going to turn off all-gather overlapping for # validation / test iterations). 
if not self.overlap_param_gather or force_sync: - for all_gather_handle_index in range(self.num_all_gather_handles): + for all_gather_handle_index in range(len(self.all_gather_handles)): self._dispatch_gather_model_params(all_gather_handle_index, force_sync=force_sync) @torch.no_grad() diff --git a/megatron/core/optimizer/optimizer.py b/megatron/core/optimizer/optimizer.py index 2a48c12d37..9b998c14ad 100644 --- a/megatron/core/optimizer/optimizer.py +++ b/megatron/core/optimizer/optimizer.py @@ -154,6 +154,7 @@ def step_with_ready_grads(self) -> bool: @torch.no_grad() def get_grad_norm(self): + """Compute and return grad norm.""" grads_for_norm = self.get_main_grads_for_grad_norm() total_norm = get_grad_norm_fp32( grads_for_norm, model_parallel_group=self.get_model_parallel_group() @@ -161,7 +162,7 @@ def get_grad_norm(self): return total_norm def clip_grad_norm(self, clip_grad: float) -> float: - """Compute grad norm.""" + """Compute and return grad norm, also clip grads.""" params = self.get_parameters() grads_for_norm = self.get_main_grads_for_grad_norm() grad_norm = get_grad_norm_fp32( @@ -177,6 +178,7 @@ def count_zeros(self) -> float: @abstractmethod def zero_grad(self, set_to_none: bool = True): + """Zero gradients and prepare for next forward pass.""" pass @abstractmethod @@ -191,9 +193,9 @@ def scale_loss(self, loss: torch.Tensor) -> torch.Tensor: """Simple scaling.""" return self.get_loss_scale() * loss - def finish_param_sync(self, model_index: int): + def start_param_sync(self, model_index: int, *unused): """ - Finish parameter synchronization for all optimizers. + Start parameter synchronization for all optimizers. This is a no-op for all non-distributed optimizers. """ pass @@ -209,10 +211,12 @@ def reload_model_params(self): @abstractmethod def state_dict(self): + """Return state_dict.""" pass @abstractmethod def load_state_dict(self, state_dict): + """Load pass-in `state_dict`.""" pass # Promote state so it can be retrieved or set via @@ -857,6 +861,7 @@ def __iter__(self): yield (idx, inner_key) def items(self): + """Return generator over underlying items.""" for idx, inner_dict in enumerate(self._inner_dicts): for inner_key, value in inner_dict.items(): yield (idx, inner_key), value @@ -873,10 +878,14 @@ class ChainedOptimizer(MegatronOptimizer): """ def __init__(self, chained_optimizers: List[MegatronOptimizer]): + self.config = getattr(chained_optimizers[0], 'config', None) + for optimizer in chained_optimizers[1:]: + assert self.config == getattr(optimizer, 'config', None) self.chained_optimizers = chained_optimizers @property def param_groups(self) -> List[dict]: + """Get param_groups aggregated over underlying optimizers.""" param_groups = [] for optimizer in self.chained_optimizers: param_groups += optimizer.param_groups @@ -940,12 +949,16 @@ def prepare_grads(self) -> bool: def step_with_ready_grads(self) -> bool: """Step the optimizer with ready gradients, return successful.""" success = True - for optimizer in self.chained_optimizers: + for optimizer_idx, optimizer in enumerate(self.chained_optimizers): success &= optimizer.step_with_ready_grads() + if self.config.overlap_param_gather_with_optimizer_step and optimizer_idx == 0: + assert success + optimizer.start_param_sync(model_index=0, force_dispatch=True) return success def disable_pre_hook(self): + """Disable pre-hooks for underlying distributed optimizers.""" for optimizer in self.chained_optimizers: if ( not optimizer.config.use_distributed_optimizer @@ -958,6 +971,7 @@ def disable_pre_hook(self): 
optimizer.disable_pre_hook() def enable_pre_hook(self): + """Enable pre-hooks for underlying distributed optimizers.""" for optimizer in self.chained_optimizers: if ( not optimizer.config.use_distributed_optimizer @@ -1028,7 +1042,7 @@ def save_parameter_state(self, filename: str): if save_states: torch.save(states, filename) - def load_parameter_state(self, filename: str): + def load_parameter_state(self, filename: str, *, update_legacy_format: bool = False): """Load the distributed parameter states of all optimizers from a file. Args: @@ -1044,9 +1058,11 @@ def load_parameter_state(self, filename: str): states = torch.load(filename) state_dict = states[idx] if states else None - optimizer.load_parameter_state_from_dp_zero(state_dict) + optimizer.load_parameter_state_from_dp_zero( + state_dict, update_legacy_format=update_legacy_format + ) - def finish_param_sync(self, model_index: int): - """Finish parameter synchronization for all optimizers.""" + def start_param_sync(self, model_index: int, *unused): + """Start parameter synchronization for all optimizers.""" for optimizer in self.chained_optimizers: - optimizer.finish_param_sync(model_index) + optimizer.start_param_sync(model_index, *unused) diff --git a/megatron/core/optimizer/optimizer_config.py b/megatron/core/optimizer/optimizer_config.py index 8b8413a36a..31c67e14f1 100644 --- a/megatron/core/optimizer/optimizer_config.py +++ b/megatron/core/optimizer/optimizer_config.py @@ -100,6 +100,14 @@ class OptimizerConfig: overlap_param_gather: bool = False """If true, overlap param all-gather with forward compute in distributed optimizer.""" + overlap_param_gather_with_optimizer_step: bool = False + """If true, overlap param all-gather of first bucket with optimizer step.""" + + align_param_gather: bool = False + """If true, all PP stages will launch param all-gathers simultaneously. Otherwise, each + PP stage will independently launch as needed. + """ + ################ # Miscellaneous ################ diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index 46f573a2b2..c39c19b498 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -161,6 +161,9 @@ def validate_args(args, defaults={}): # Load saved args from Retro (if applicable). load_retro_args(args) + # Set args.use_dist_ckpt from args.ckpt_format. + update_use_dist_ckpt(args) + if args.encoder_tensor_model_parallel_size > 0: assert args.encoder_pipeline_model_parallel_size > 0, "encoder_pipeline_model_parallel_size must be defined." 
assert args.num_attention_heads % args.encoder_tensor_model_parallel_size == 0 @@ -208,7 +211,6 @@ def validate_args(args, defaults={}): args.pipeline_model_parallel_size -= args.encoder_pipeline_model_parallel_size assert args.pipeline_model_parallel_size > 0 - if args.tp_comm_overlap: assert args.sequence_parallel == True, 'Tensor parallel communication/GEMM overlap can happen only when sequence parallelism is enabled' @@ -293,10 +295,24 @@ def validate_args(args, defaults={}): assert args.use_distributed_optimizer, \ '--overlap-param-gather only supported with distributed optimizer' assert args.overlap_grad_reduce, \ - '--overlap-grad-reduce should be turned on when using --overlap-param-gather' + 'Must use --overlap-param-gather with --overlap-grad-reduce' assert not args.use_legacy_models, \ '--overlap-param-gather only supported with MCore models' + if args.overlap_param_gather_with_optimizer_step: + assert args.use_distributed_optimizer, \ + '--overlap-param-gather-with-optimizer-step only supported with distributed optimizer' + assert args.overlap_param_gather, \ + 'Must use --overlap-param-gather-with-optimizer-step with --overlap-param-gather' + assert args.virtual_pipeline_model_parallel_size is not None, \ + '--overlap-param-gather-with-optimizer-step only supported with interleaved pipeline parallelism' + assert not args.use_dist_ckpt, \ + '--overlap-param-gather-with-optimizer-step not supported with distributed checkpointing yet' + + if args.align_param_gather: + assert args.virtual_pipeline_model_parallel_size is not None, \ + '--align-param-gather only supported with interleaved pipeline parallelism' + # Parameters dtype. args.params_dtype = torch.float if args.fp16: @@ -516,9 +532,6 @@ def validate_args(args, defaults={}): assert args.pipeline_model_parallel_size == 1, \ "retro currently does not support pipeline parallelism." - # Set args.use_dist_ckpt from args.ckpt_format. - update_use_dist_ckpt(args) - if args.decoupled_lr is not None or args.decoupled_min_lr is not None: assert not args.use_legacy_models, \ '--decoupled-lr and --decoupled-min-lr is not supported in legacy models.' @@ -1498,17 +1511,21 @@ def _add_distributed_args(parser): 'weight gradient computation of vocabulary projection is deferred, defaults to 0 which' 'means all the micro-batches are deferred. Invalid if `defer-embedding-wgrad-compute`' 'is not set') - group.add_argument('--no-delay-grad-reduce', action='store_false', - help='If not set, delay / synchronize grad reductions in all but first PP stage.', - dest='delay_grad_reduce') + group.add_argument('--no-align-grad-reduce', action='store_false', + help='If not set, all PP stages will launch gradient reduces simultaneously. 
' + 'Otherwise, each PP stage will independently launch as needed.', + dest='align_grad_reduce') group.add_argument('--ddp-bucket-size', type=int, default=None, help='Bucket size for data-parallel communication') group.add_argument('--ddp-average-in-collective', action='store_true', default=False, help='If set, average directly in data-parallel communication collective.') group.add_argument('--overlap-param-gather', action='store_true', default=False, help='If set, overlap param all-gather in distributed optimizer.') - group.add_argument('--delay-param-gather', action='store_true', - default=False, help='If set, delay / synchronize param all-gathers in all but first PP stage.') + group.add_argument('--overlap-param-gather-with-optimizer-step', action='store_true', + default=False, help='If set, overlap param all-gather of first bucket with optimizer step.') + group.add_argument('--align-param-gather', action='store_true', default=False, + help='If set, all PP stages will launch param all-gathers simultaneously. ' + 'Otherwise, each PP stage will independently launch as needed.') group.add_argument('--no-scatter-gather-tensors-in-pipeline', action='store_false', help='If not set, use scatter/gather to optimize communication of tensors in pipeline.', dest='scatter_gather_tensors_in_pipeline') diff --git a/megatron/training/checkpointing.py b/megatron/training/checkpointing.py index 9319fe09ee..fca80acc91 100644 --- a/megatron/training/checkpointing.py +++ b/megatron/training/checkpointing.py @@ -1082,7 +1082,8 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri optim_checkpoint_name = \ get_distributed_optimizer_checkpoint_name( model_checkpoint_name) - optimizer.load_parameter_state(optim_checkpoint_name, update_legacy_format=args.ckpt_convert_update_legacy_dist_opt_format) + optimizer.load_parameter_state(optim_checkpoint_name, + update_legacy_format=args.ckpt_convert_update_legacy_dist_opt_format) # Load scheduler. if opt_param_scheduler is not None: diff --git a/megatron/training/training.py b/megatron/training/training.py index 75a5b0bff7..b7e2230ed2 100644 --- a/megatron/training/training.py +++ b/megatron/training/training.py @@ -4,6 +4,7 @@ import dataclasses from datetime import datetime +import functools import gc import logging import math @@ -493,12 +494,13 @@ def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap check_for_nan_in_grad=args.check_for_nan_in_loss_and_grad, bucket_size=args.ddp_bucket_size, average_in_collective=args.ddp_average_in_collective) + overlap_param_gather_with_optimizer_step = getattr(args, 'overlap_param_gather_with_optimizer_step', False) model = [DDP(config, ddp_config, model_chunk, # Turn off bucketing for model_chunk 2 onwards, since communication for these # model chunks is overlapped with compute anyway. - disable_bucketing=(model_chunk_idx > 0)) + disable_bucketing=(model_chunk_idx > 0) or overlap_param_gather_with_optimizer_step) for (model_chunk_idx, model_chunk) in enumerate(model)] # Broadcast params from data parallel src rank to other data parallel ranks. 
@@ -1067,12 +1069,12 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler, config.no_sync_func = [model_chunk.no_sync for model_chunk in model] if len(model) == 1: config.no_sync_func = config.no_sync_func[0] - if args.delay_grad_reduce: + if args.align_grad_reduce: config.grad_sync_func = [model_chunk.start_grad_sync for model_chunk in model] if len(model) == 1: config.grad_sync_func = config.grad_sync_func[0] - if args.overlap_param_gather and args.delay_param_gather: - config.param_sync_func = [lambda x: optimizer.finish_param_sync(model_index, x) + if args.overlap_param_gather and args.align_param_gather: + config.param_sync_func = [functools.partial(optimizer.start_param_sync, model_index) for model_index in range(len(model))] if len(model) == 1: config.param_sync_func = config.param_sync_func[0] diff --git a/tests/functional_tests/jet_recipes/gpt.yaml b/tests/functional_tests/jet_recipes/gpt.yaml index 4ee46eaf7e..d7d14eae4e 100644 --- a/tests/functional_tests/jet_recipes/gpt.yaml +++ b/tests/functional_tests/jet_recipes/gpt.yaml @@ -55,6 +55,7 @@ products: - gpt3_mr_mcore_te_tp1_pp4_vp1_dgx_a100_1N8G - gpt3_mr_mcore_te_tp1_pp4_vp1_dist_optimizer_overlap_grad_reduce_dgx_a100_1N8G - gpt3_mr_mcore_te_tp1_pp4_vp1_dist_optimizer_overlap_grad_reduce_param_gather_dgx_a100_1N8G + - gpt3_mr_mcore_te_tp1_pp4_vp1_dist_optimizer_overlap_grad_reduce_param_gather_overlap_optimizer_dgx_a100_1N8G - gpt3_mr_mcore_te_tp1_pp4_vp1_dist_optimizer_overlap_grad_reduce_untied_dgx_a100_1N8G - gpt3_mr_mcore_te_tp1_pp4_vp1_resume_torch_decoupled_lr_dgx_a100_1N8G - gpt3_mr_mcore_te_tp1_pp4_vp1_resume_torch_dist_calculate_per_token_loss_dgx_a100_1N8G diff --git a/tests/functional_tests/test_cases/gpt/gpt3_mr_mcore_te_tp1_pp4_vp1_dist_optimizer_overlap_grad_reduce_param_gather_overlap_optimizer_dgx_a100_1N8G/golden_values.json b/tests/functional_tests/test_cases/gpt/gpt3_mr_mcore_te_tp1_pp4_vp1_dist_optimizer_overlap_grad_reduce_param_gather_overlap_optimizer_dgx_a100_1N8G/golden_values.json new file mode 100644 index 0000000000..549ceb7eab --- /dev/null +++ b/tests/functional_tests/test_cases/gpt/gpt3_mr_mcore_te_tp1_pp4_vp1_dist_optimizer_overlap_grad_reduce_param_gather_overlap_optimizer_dgx_a100_1N8G/golden_values.json @@ -0,0 +1 @@ +{"lm loss": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [10.82005, 10.87449, 10.87799, 10.79508, 10.68166, 10.59514, 10.10042, 10.21238, 10.13865, 9.80879]}, "num-zeros": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [1559.0, 1719.0, 1857.0, 1746.0, 1883.0, 1738.0, 1475.0, 1851.0, 2303.0, 2258.0]}, "iteration_timing_avg": 0.12873676470588236} \ No newline at end of file diff --git a/tests/functional_tests/test_cases/gpt/gpt3_mr_mcore_te_tp1_pp4_vp1_dist_optimizer_overlap_grad_reduce_param_gather_overlap_optimizer_dgx_a100_1N8G/model_config.yaml b/tests/functional_tests/test_cases/gpt/gpt3_mr_mcore_te_tp1_pp4_vp1_dist_optimizer_overlap_grad_reduce_param_gather_overlap_optimizer_dgx_a100_1N8G/model_config.yaml new file mode 100644 index 0000000000..7cc5c29ce9 --- /dev/null +++ b/tests/functional_tests/test_cases/gpt/gpt3_mr_mcore_te_tp1_pp4_vp1_dist_optimizer_overlap_grad_reduce_param_gather_overlap_optimizer_dgx_a100_1N8G/model_config.yaml @@ -0,0 +1,57 @@ +ENV_VARS: + CUDA_DEVICE_MAX_CONNECTIONS: 1 + NVTE_ALLOW_NONDETERMINISTIC_ALGO: 0 + NCCL_ALGO: Tree + CUBLAS_WORKSPACE_CONFIG: :4096:8 +MODEL_ARGS: + --num-layers: 12 + --hidden-size: 512 + --num-attention-heads: 8 + --log-params-norm: true + --log-num-zeros-in-grad: true + 
--log-validation-ppl-to-tensorboard: true + --log-timers-to-tensorboard: true + --tensorboard-dir: ${TENSORBOARD_PATH} + --micro-batch-size: 4 + --global-batch-size: 32 + --seq-length: 1024 + --max-position-embeddings: 1024 + --train-iters: 50 + --timing-log-level: 2 + --lr-decay-iters: 320000 + --save: ${CHECKPOINT_PATH} + --load: ${CHECKPOINT_PATH} + --data-path: ${DATA_PATH}/my-gpt3_00_text_document + --vocab-file: ${DATA_PATH}/bpe/vocab.json + --merge-file: ${DATA_PATH}/bpe/merges.txt + --split: 949,50,1 + --distributed-backend: nccl + --lr: 0.00015 + --lr-decay-style: cosine + --min-lr: 1.0e-5 + --weight-decay: 1e-2 + --clip-grad: 1.0 + --lr-warmup-fraction: .01 + --log-interval: 1 + --save-interval: 10000 + --eval-interval: 1000 + --eval-iters: 10 + --transformer-impl: transformer_engine + --tensor-model-parallel-size: 1 + --pipeline-model-parallel-size: 4 + --num-layers-per-virtual-pipeline-stage: 1 + --use-distributed-optimizer: true + --overlap-grad-reduce: true + --overlap-param-gather: true + --overlap-param-gather-with-optimizer-step: true + --align-param-gather: true + --check-weight-hash-across-dp-replicas-interval: 10 + --ckpt-fully-parallel-load: true + --deterministic-mode: true + --no-gradient-accumulation-fusion: true + --attention-softmax-in-fp32: true + --use-mcore-models: true + --ckpt-format: torch + --data-cache-path: ${DATA_CACHE_PATH} + --bf16: true +TEST_TYPE: regular \ No newline at end of file diff --git a/tests/unit_tests/dist_checkpointing/utils.py b/tests/unit_tests/dist_checkpointing/utils.py index e58b7f0822..e4a007aa75 100644 --- a/tests/unit_tests/dist_checkpointing/utils.py +++ b/tests/unit_tests/dist_checkpointing/utils.py @@ -54,6 +54,7 @@ def init_basic_mock_args(args, tp, pp, bf16=True): args.bf16 = bf16 args.accumulate_allreduce_grads_in_fp32 = False args.overlap_grad_reduce = False + args.overlap_param_gather_with_optimizer_step = False args.use_distributed_optimizer = True args.ddp_bucket_size = None args.check_for_nan_in_loss_and_grad = False
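
Editor's note: the core idea of this patch is easiest to see in ChainedOptimizer.step_with_ready_grads(). Below is a minimal, self-contained Python sketch of that control flow — it is not Megatron-LM code; names such as ToyOptimizer, ToyChainedOptimizer, and the print-based "dispatch" are illustrative stand-ins for the real DistributedOptimizer machinery. When overlap_param_gather_with_optimizer_step is enabled, the dense parameters are split so the first model chunk gets its own optimizer, and that optimizer's param all-gather is force-dispatched right after its step, overlapping communication with the optimizer step of the remaining chunks.

from typing import List


class ToyOptimizer:
    """Stand-in for a per-chunk distributed optimizer."""

    def __init__(self, name: str, overlap_param_gather_with_optimizer_step: bool = False):
        self.name = name
        self.overlap_param_gather_with_optimizer_step = overlap_param_gather_with_optimizer_step

    def step_with_ready_grads(self) -> bool:
        print(f"{self.name}: optimizer step")
        return True

    def start_param_sync(self, model_index: int, force_dispatch: bool = False):
        # In the real DistributedOptimizer this dispatches the (possibly async)
        # param all-gathers for the given model chunk; here we only log it.
        if self.overlap_param_gather_with_optimizer_step and not force_dispatch:
            # All-gather will be force-dispatched by the chained optimizer instead.
            return
        print(f"{self.name}: dispatch param all-gather for model chunk {model_index}")


class ToyChainedOptimizer:
    """Stand-in for ChainedOptimizer over the per-chunk optimizers."""

    def __init__(self, chained_optimizers: List[ToyOptimizer], overlap_first_chunk: bool):
        self.chained_optimizers = chained_optimizers
        self.overlap_param_gather_with_optimizer_step = overlap_first_chunk

    def step_with_ready_grads(self) -> bool:
        success = True
        for idx, opt in enumerate(self.chained_optimizers):
            success &= opt.step_with_ready_grads()
            # Mirrors the patch: right after the first optimizer (owning only the
            # first model chunk) steps, force-dispatch its param all-gather so the
            # communication overlaps with the steps of the remaining optimizers.
            if self.overlap_param_gather_with_optimizer_step and idx == 0:
                assert success
                opt.start_param_sync(model_index=0, force_dispatch=True)
        return success


if __name__ == "__main__":
    opts = [
        ToyOptimizer("chunk0-optimizer", overlap_param_gather_with_optimizer_step=True),
        ToyOptimizer("remaining-chunks-optimizer"),
    ]
    ToyChainedOptimizer(opts, overlap_first_chunk=True).step_with_ready_grads()

In the patch itself the behavior is opt-in: --overlap-param-gather-with-optimizer-step requires --use-distributed-optimizer, --overlap-param-gather, and interleaved pipeline parallelism, and is not yet supported with distributed checkpointing. The companion --align-param-gather flag (also requiring interleaved pipeline parallelism) makes all pipeline stages launch their param all-gathers together via start_param_sync calls from the pipeline schedule instead of dispatching them independently from the pre-forward hook. The new functional test's model_config.yaml exercises both flags.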