From 44b6480511f194ccb3943fbf590bc146e6612160 Mon Sep 17 00:00:00 2001 From: Deepak Narayanan Date: Mon, 9 Dec 2024 11:10:20 -0800 Subject: [PATCH] ADLR/megatron-lm!2414 - Remove all-gather before first iteration to not spread corrupted values --- .../distributed/distributed_data_parallel.py | 6 ++- .../core/distributed/param_and_grad_buffer.py | 34 +++++++--------- megatron/core/optimizer/optimizer.py | 12 ------ megatron/training/training.py | 40 +++++++++++++++++-- 4 files changed, 56 insertions(+), 36 deletions(-) diff --git a/megatron/core/distributed/distributed_data_parallel.py b/megatron/core/distributed/distributed_data_parallel.py index 3a23426eca..6b3d50bd6e 100644 --- a/megatron/core/distributed/distributed_data_parallel.py +++ b/megatron/core/distributed/distributed_data_parallel.py @@ -297,9 +297,10 @@ def enable_forward_pre_hook(self): self._make_forward_pre_hook() ) - def disable_forward_pre_hook(self): + def disable_forward_pre_hook(self, param_sync: bool = True): """ Disable forward pre-hooks needed for param all-gather overlap with forward compute. + Skip synchronous param all-gather if `param_sync` is False. """ assert self.use_forward_hook # De-register forward pre-hook for all sub-modules. @@ -310,7 +311,8 @@ def disable_forward_pre_hook(self): assert len(self.remove_forward_pre_hook_handles) == 0 # Force synchronize parameters. - self.start_param_sync(force_sync=True) + if param_sync: + self.start_param_sync(force_sync=True) def _make_forward_pre_hook(self): """ diff --git a/megatron/core/distributed/param_and_grad_buffer.py b/megatron/core/distributed/param_and_grad_buffer.py index 00c8fdd69d..5095a7c7f3 100644 --- a/megatron/core/distributed/param_and_grad_buffer.py +++ b/megatron/core/distributed/param_and_grad_buffer.py @@ -270,13 +270,12 @@ def start_grad_sync(self): if self.ddp_config.average_in_collective: reduce_op = torch.distributed.ReduceOp.AVG - # Stream synchronization logic of the CUDA streams that is - # implemented below for the gradient reduction within and across - # distributed optimizer instances. + # We use the following stream synchronization for the gradient reduction + # within and across DistOpt instances. - # Compute Stream - -------------Gradient Compute------------------- - # Comm. Stream - ------(wait for nccl)-----(wait for nccl)------- - # NCCL Stream - -------RS------ -------AR------ + # Compute Stream: -------------Gradient compute------------------- + # Comm. Stream: ------(wait for NCCL)-----(wait for NCCL)------- + # NCCL Stream: -------RS------ -------AR------ # Use async communications only when overlap_grad_reduce is True. async_op = ( @@ -287,13 +286,13 @@ def start_grad_sync(self): self.ddp_config.num_distributed_optimizer_instances > 1 and self.ddp_config.overlap_grad_reduce ): - # Assign a communication stream if we use partial DP DistOpt and we - # need to overlap communication + # Assign a communication stream if we have multiple DistOpt instances and we + # need to overlap communication. stream_context = torch.cuda.stream(self.communication_stream) # The RS/AR communication stream needs to wait for the default stream # to complete its gradient computation before launching the next - # gradient reduction collective + # gradient reduction collective. 
self.communication_stream.wait_stream(torch.cuda.default_stream()) else: stream_context = nullcontext() @@ -314,24 +313,21 @@ def start_grad_sync(self): local_data_view, bucket.grad_data, op=reduce_op, - group=self.intra_distributed_optimizer_instance_group, + group=communication_group, async_op=async_op, ) else: torch.distributed.all_reduce( - bucket.grad_data, - op=reduce_op, - group=self.data_parallel_group, - async_op=async_op, + bucket.grad_data, op=reduce_op, group=communication_group, async_op=async_op ) - # When enabling partial DP domain DistOpt, we need to All-Reduce across all partial domains + # With multiple DistOpt instances, we need to all-reduce across instances. if ( self.ddp_config.use_distributed_optimizer and self.ddp_config.num_distributed_optimizer_instances > 1 ): - # Create a new coalescing facility for the inter partial DP-AllReduce here + # Create a new coalescing manager for the inter-instance all-reduce. with stream_context, _coalescing_manager( self.inter_distributed_optimizer_instance_group, async_ops=async_op ) as cm: @@ -366,13 +362,13 @@ def finish_grad_sync(self): communication call to complete. When ddp_config.overlap_grad_reduce is set to False, makes synchronous call. """ - # If overlap_grad_reduce is False, start (and finish) synchronous communication call here. self.param_gather_dispatched = False + # If overlap_grad_reduce is False, start (and finish) synchronous communication call here. if not self.ddp_config.overlap_grad_reduce: self.start_grad_sync() return - # When using partial DP DistOpt, we don't need to sync as we launch comms on a separate - # communication stream + # When using multiple DistOpt instances, we don't need to sync here as we launch + # communications on a separate communication stream. if self.ddp_config.num_distributed_optimizer_instances > 1: torch.cuda.default_stream().wait_stream(self.communication_stream) return diff --git a/megatron/core/optimizer/optimizer.py b/megatron/core/optimizer/optimizer.py index c48bb580d8..a0f35065ab 100644 --- a/megatron/core/optimizer/optimizer.py +++ b/megatron/core/optimizer/optimizer.py @@ -213,13 +213,6 @@ def scale_loss(self, loss: torch.Tensor) -> torch.Tensor: """Simple scaling.""" return self.get_loss_scale() * loss - def start_param_sync(self, model_index: int, *unused): - """ - Start parameter synchronization for all optimizers. - This is a no-op for all non-distributed optimizers. - """ - pass - @abstractmethod def reload_model_params(self): """Refreshes any internal state from the current model parameters. 
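[Illustrative aside, not part of the patch] To make the stream-synchronization comments in `start_grad_sync` above concrete, the sketch below shows, under simplifying assumptions, how the two-level gradient reduction is organized when `num_distributed_optimizer_instances > 1`: a reduce-scatter within each DistOpt instance, then an all-reduce of the resulting shard across instances, both launched on a dedicated communication stream that first waits on the default stream. The function name `two_level_grad_reduce` and its arguments are hypothetical; the real implementation lives in `ParamAndGradBuffer.start_grad_sync` / `finish_grad_sync` and additionally handles bucketing, gradient scaling, padding, and coalescing of the inter-instance all-reduces.

    import torch
    import torch.distributed as dist


    def two_level_grad_reduce(
        grad_data: torch.Tensor,                  # flat, 1-D gradient buffer for one bucket
        intra_instance_group: dist.ProcessGroup,  # ranks within one DistOpt instance
        inter_instance_group: dist.ProcessGroup,  # shard-owner ranks across instances
        communication_stream: torch.cuda.Stream,
    ) -> None:
        """Reduce-scatter within a DistOpt instance, then all-reduce the local
        shard across instances, on a dedicated communication stream (sketch)."""
        # The communication stream must not start reducing gradients before the
        # default (compute) stream has finished producing them.
        communication_stream.wait_stream(torch.cuda.default_stream())

        with torch.cuda.stream(communication_stream):
            # Reduce-scatter within the instance: each rank keeps one shard.
            # Assumes grad_data.numel() is divisible by the instance size (the
            # real buffer pads buckets to guarantee this).
            instance_size = dist.get_world_size(group=intra_instance_group)
            shard_size = grad_data.numel() // instance_size
            rank_in_instance = dist.get_rank(group=intra_instance_group)
            local_shard = grad_data.narrow(0, rank_in_instance * shard_size, shard_size)
            dist.reduce_scatter_tensor(
                local_shard, grad_data, op=dist.ReduceOp.AVG, group=intra_instance_group
            )
            # All-reduce the shard across DistOpt instances so every instance
            # ends up with the same fully reduced gradient shard.
            dist.all_reduce(local_shard, op=dist.ReduceOp.AVG, group=inter_instance_group)

        # Mirror of finish_grad_sync: the default stream waits for the
        # communication stream before the optimizer consumes the gradients.
        torch.cuda.default_stream().wait_stream(communication_stream)

Because the collectives are queued on their own stream, they can overlap with backward compute of later buckets; only the final wait_stream call orders them ahead of the optimizer step.
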
@@ -1062,8 +1055,3 @@ def load_parameter_state(self, filename: str, *, update_legacy_format: bool = Fa optimizer.load_parameter_state_from_dp_zero( state_dict, update_legacy_format=update_legacy_format ) - - def start_param_sync(self, model_index: int, *unused): - """Start parameter synchronization for all optimizers.""" - for optimizer in self.chained_optimizers: - optimizer.start_param_sync(model_index, *unused) diff --git a/megatron/training/training.py b/megatron/training/training.py index cffde8830e..741a8bf0a6 100644 --- a/megatron/training/training.py +++ b/megatron/training/training.py @@ -1113,10 +1113,10 @@ def enable_forward_pre_hook(model_chunks): model_chunk.enable_forward_pre_hook() -def disable_forward_pre_hook(model_chunks): +def disable_forward_pre_hook(model_chunks, param_sync=True): for model_chunk in model_chunks: assert isinstance(model_chunk, DDP) - model_chunk.disable_forward_pre_hook() + model_chunk.disable_forward_pre_hook(param_sync=param_sync) def save_checkpoint_and_time(iteration, model, optimizer, opt_param_scheduler, @@ -1412,6 +1412,23 @@ def get_e2e_base_metrics(): with_stack=True) prof.start() + start_iteration = iteration + # Disable forward pre-hook to start training to ensure that errors in checkpoint loading + # or random initialization don't propagate to all ranks in first all-gather (which is a + # no-op if things work correctly). + if args.use_distributed_optimizer and args.overlap_param_gather: + disable_forward_pre_hook(model, param_sync=False) + # Also remove param_sync_func temporarily so that sync calls made in + # `forward_backward_func` are no-ops. + param_sync_func = config.param_sync_func + config.param_sync_func = None + # Also, check weight hash across DP replicas to be very pedantic. + if args.check_weight_hash_across_dp_replicas_interval is not None: + assert check_param_hashes_across_dp_replicas(model, cross_check=True), \ + "Parameter hashes not matching across DP replicas" + torch.distributed.barrier() + print_rank_0(f">>> Weight hashes match after {iteration} iterations...") + # Run training iterations till done. while iteration < args.train_iters: if args.profile and torch.distributed.get_rank() in args.profile_ranks: @@ -1456,7 +1473,24 @@ def get_e2e_base_metrics(): checkpointing_context, train_data_iterator=train_data_iterator) if should_exit: break - # why is skipped_iter ignored? + + # Enable forward pre-hooks after first set of forward and backward passes. + # When running in fp16, skip all NaN iterations until steady-state loss scaling value + # is reached. + if iteration == start_iteration: + if skipped_iter: + # Only enable forward pre-hook after a training step has successfully run. Relevant + # for fp16 codepath where first XX iterations are skipped until steady-state loss + # scale value is reached. + start_iteration = iteration + 1 + else: + # Enable forward pre-hook after training step has successfully run. All subsequent + # forward passes will use the forward pre-hook / `param_sync_func` in + # `forward_backward_func`. + if args.use_distributed_optimizer and args.overlap_param_gather: + enable_forward_pre_hook(model) + config.param_sync_func = param_sync_func + iteration += 1 batch_size = mpu.get_data_parallel_world_size() * \ args.micro_batch_size * \