Merge branch 'dnarayanan/skip_all_gather_first_iteration' into 'main'
Remove all-gather before first iteration to not spread corrupted values

See merge request ADLR/megatron-lm!2414
deepakn94 committed Dec 9, 2024
2 parents 37cd8f2 + 44b6480 commit d4e72c0
Showing 4 changed files with 56 additions and 36 deletions.
6 changes: 4 additions & 2 deletions megatron/core/distributed/distributed_data_parallel.py
@@ -297,9 +297,10 @@ def enable_forward_pre_hook(self):
self._make_forward_pre_hook()
)

def disable_forward_pre_hook(self):
def disable_forward_pre_hook(self, param_sync: bool = True):
"""
Disable forward pre-hooks needed for param all-gather overlap with forward compute.
Skip synchronous param all-gather if `param_sync` is False.
"""
assert self.use_forward_hook
# De-register forward pre-hook for all sub-modules.
@@ -310,7 +311,8 @@ def disable_forward_pre_hook(self):
assert len(self.remove_forward_pre_hook_handles) == 0

# Force synchronize parameters.
self.start_param_sync(force_sync=True)
if param_sync:
self.start_param_sync(force_sync=True)

def _make_forward_pre_hook(self):
"""
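Note: the snippet below is an illustrative sketch, not part of this commit or of the Megatron-LM source. It shows the forward pre-hook pattern that `enable_forward_pre_hook` / `disable_forward_pre_hook` toggle, and what the new `param_sync` flag skips; the class name and method bodies here are simplified assumptions.

    import torch.nn as nn

    # Sketch: each submodule gets a forward pre-hook that would dispatch the async
    # param all-gather for its bucket; disabling the hooks can either force one
    # final synchronous all-gather or, with param_sync=False, skip it entirely.
    class OverlapHooksSketch(nn.Module):
        def __init__(self, module: nn.Module):
            super().__init__()
            self.module = module
            self._handles = []

        def _pre_hook(self, submodule, args):
            # Placeholder for "kick off async all-gather of this submodule's params".
            return None

        def enable_forward_pre_hook(self):
            for m in self.module.modules():
                self._handles.append(m.register_forward_pre_hook(self._pre_hook))

        def disable_forward_pre_hook(self, param_sync: bool = True):
            for handle in self._handles:
                handle.remove()
            self._handles.clear()
            if param_sync:
                # The real class calls self.start_param_sync(force_sync=True) here;
                # passing param_sync=False is what lets the first iteration avoid
                # all-gathering possibly corrupted values.
                pass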
34 changes: 15 additions & 19 deletions megatron/core/distributed/param_and_grad_buffer.py
@@ -270,13 +270,12 @@ def start_grad_sync(self):
if self.ddp_config.average_in_collective:
reduce_op = torch.distributed.ReduceOp.AVG

# Stream synchronization logic of the CUDA streams that is
# implemented below for the gradient reduction within and across
# distributed optimizer instances.
# We use the following stream synchronization for the gradient reduction
# within and across DistOpt instances.

# Compute Stream - -------------Gradient Compute-------------------
# Comm. Stream - ------(wait for nccl)-----(wait for nccl)-------
# NCCL Stream - -------RS------ -------AR------
# Compute Stream: -------------Gradient compute-------------------
# Comm. Stream: ------(wait for NCCL)-----(wait for NCCL)-------
# NCCL Stream: -------RS------ -------AR------

# Use async communications only when overlap_grad_reduce is True.
async_op = (
@@ -287,13 +286,13 @@
self.ddp_config.num_distributed_optimizer_instances > 1
and self.ddp_config.overlap_grad_reduce
):
# Assign a communication stream if we use partial DP DistOpt and we
# need to overlap communication
# Assign a communication stream if we have multiple DistOpt instances and we
# need to overlap communication.
stream_context = torch.cuda.stream(self.communication_stream)

# The RS/AR communication stream needs to wait for the default stream
# to complete its gradient computation before launching the next
# gradient reduction collective
# gradient reduction collective.
self.communication_stream.wait_stream(torch.cuda.default_stream())
else:
stream_context = nullcontext()
@@ -314,24 +313,21 @@
local_data_view,
bucket.grad_data,
op=reduce_op,
group=self.intra_distributed_optimizer_instance_group,
group=communication_group,
async_op=async_op,
)
else:
torch.distributed.all_reduce(
bucket.grad_data,
op=reduce_op,
group=self.data_parallel_group,
async_op=async_op,
bucket.grad_data, op=reduce_op, group=communication_group, async_op=async_op
)

# When enabling partial DP domain DistOpt, we need to All-Reduce across all partial domains
# With multiple DistOpt instances, we need to all-reduce across instances.
if (
self.ddp_config.use_distributed_optimizer
and self.ddp_config.num_distributed_optimizer_instances > 1
):

# Create a new coalescing facility for the inter partial DP-AllReduce here
# Create a new coalescing manager for the inter-instance all-reduce.
with stream_context, _coalescing_manager(
self.inter_distributed_optimizer_instance_group, async_ops=async_op
) as cm:
@@ -366,13 +362,13 @@ def finish_grad_sync(self):
communication call to complete. When ddp_config.overlap_grad_reduce is set to False,
makes synchronous call.
"""
# If overlap_grad_reduce is False, start (and finish) synchronous communication call here.
self.param_gather_dispatched = False
# If overlap_grad_reduce is False, start (and finish) synchronous communication call here.
if not self.ddp_config.overlap_grad_reduce:
self.start_grad_sync()
return
# When using partial DP DistOpt, we don't need to sync as we launch comms on a separate
# communication stream
# When using multiple DistOpt instances, we don't need to sync here as we launch
# communications on a separate communication stream.
if self.ddp_config.num_distributed_optimizer_instances > 1:
torch.cuda.default_stream().wait_stream(self.communication_stream)
return
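The comments above describe the intended stream choreography: the dedicated communication stream waits for the default (compute) stream, the reduce-scatter runs within a DistOpt instance, the local shard is then all-reduced across instances, and `finish_grad_sync` only has to make the default stream wait on the communication stream. Below is a standalone sketch of that pattern, not the buffer code itself; `intra_group` / `inter_group` stand in for the intra- and inter-instance process groups, and the function names are assumptions.

    import torch
    import torch.distributed as dist
    from contextlib import nullcontext

    def reduce_grads_sketch(grad_data, local_view, comm_stream, intra_group, inter_group,
                            overlap_grad_reduce=True, multi_instance=True):
        # Async collectives only when the reduction overlaps with compute across instances.
        async_op = overlap_grad_reduce and multi_instance
        ctx = torch.cuda.stream(comm_stream) if async_op else nullcontext()
        if async_op:
            # The comm. stream must see the finished gradient compute before launching
            # the RS/AR collectives (the "wait for NCCL" arrows in the diagram above).
            comm_stream.wait_stream(torch.cuda.default_stream())
        with ctx:
            # Reduce-scatter within the DistOpt instance...
            dist.reduce_scatter_tensor(local_view, grad_data, op=dist.ReduceOp.AVG,
                                       group=intra_group, async_op=async_op)
            if multi_instance:
                # ...then all-reduce the local shard across instances.
                dist.all_reduce(local_view, op=dist.ReduceOp.AVG,
                                group=inter_group, async_op=async_op)

    def finish_grad_sync_sketch(comm_stream):
        # Mirrors finish_grad_sync() with multiple instances: the default stream just
        # waits on the communication stream instead of waiting on work handles.
        torch.cuda.default_stream().wait_stream(comm_stream)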
12 changes: 0 additions & 12 deletions megatron/core/optimizer/optimizer.py
@@ -213,13 +213,6 @@ def scale_loss(self, loss: torch.Tensor) -> torch.Tensor:
"""Simple scaling."""
return self.get_loss_scale() * loss

def start_param_sync(self, model_index: int, *unused):
"""
Start parameter synchronization for all optimizers.
This is a no-op for all non-distributed optimizers.
"""
pass

@abstractmethod
def reload_model_params(self):
"""Refreshes any internal state from the current model parameters.
@@ -1062,8 +1055,3 @@ def load_parameter_state(self, filename: str, *, update_legacy_format: bool = Fa
optimizer.load_parameter_state_from_dp_zero(
state_dict, update_legacy_format=update_legacy_format
)

def start_param_sync(self, model_index: int, *unused):
"""Start parameter synchronization for all optimizers."""
for optimizer in self.chained_optimizers:
optimizer.start_param_sync(model_index, *unused)
40 changes: 37 additions & 3 deletions megatron/training/training.py
@@ -1113,10 +1113,10 @@ def enable_forward_pre_hook(model_chunks):
model_chunk.enable_forward_pre_hook()


def disable_forward_pre_hook(model_chunks):
def disable_forward_pre_hook(model_chunks, param_sync=True):
for model_chunk in model_chunks:
assert isinstance(model_chunk, DDP)
model_chunk.disable_forward_pre_hook()
model_chunk.disable_forward_pre_hook(param_sync=param_sync)


def save_checkpoint_and_time(iteration, model, optimizer, opt_param_scheduler,
@@ -1412,6 +1412,23 @@ def get_e2e_base_metrics():
with_stack=True)
prof.start()

start_iteration = iteration
# Disable forward pre-hook to start training to ensure that errors in checkpoint loading
# or random initialization don't propagate to all ranks in first all-gather (which is a
# no-op if things work correctly).
if args.use_distributed_optimizer and args.overlap_param_gather:
disable_forward_pre_hook(model, param_sync=False)
# Also remove param_sync_func temporarily so that sync calls made in
# `forward_backward_func` are no-ops.
param_sync_func = config.param_sync_func
config.param_sync_func = None
# Also, check weight hash across DP replicas to be very pedantic.
if args.check_weight_hash_across_dp_replicas_interval is not None:
assert check_param_hashes_across_dp_replicas(model, cross_check=True), \
"Parameter hashes not matching across DP replicas"
torch.distributed.barrier()
print_rank_0(f">>> Weight hashes match after {iteration} iterations...")

# Run training iterations till done.
while iteration < args.train_iters:
if args.profile and torch.distributed.get_rank() in args.profile_ranks:
Expand Down Expand Up @@ -1456,7 +1473,24 @@ def get_e2e_base_metrics():
checkpointing_context, train_data_iterator=train_data_iterator)
if should_exit:
break

# Enable forward pre-hooks after first set of forward and backward passes.
# When running in fp16, skip all NaN iterations until steady-state loss scaling value
# is reached.
if iteration == start_iteration:
if skipped_iter:
# Only enable forward pre-hook after a training step has successfully run. Relevant
# for fp16 codepath where first XX iterations are skipped until steady-state loss
# scale value is reached.
start_iteration = iteration + 1
else:
# Enable forward pre-hook after training step has successfully run. All subsequent
# forward passes will use the forward pre-hook / `param_sync_func` in
# `forward_backward_func`.
if args.use_distributed_optimizer and args.overlap_param_gather:
enable_forward_pre_hook(model)
config.param_sync_func = param_sync_func

iteration += 1
batch_size = mpu.get_data_parallel_world_size() * \
args.micro_batch_size * \
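The training-loop change disables the pre-hooks (without a param sync) before the first iteration, optionally cross-checks parameter hashes across data-parallel replicas, and re-enables the hooks only after the first non-skipped step. The hash check itself (`check_param_hashes_across_dp_replicas`) is not shown in this diff; the snippet below is a simplified, assumed version of what such a check can look like, not the actual helper.

    import torch
    import torch.distributed as dist

    def params_match_across_dp_sketch(model_chunks, dp_group) -> bool:
        # One scalar fingerprint per rank: the sum of all parameter sums. Cheap and
        # not a real hash, but enough to catch corrupted or uninitialized weights.
        local = torch.zeros(1, dtype=torch.float64, device='cuda')
        for chunk in model_chunks:
            for p in chunk.parameters():
                local += p.detach().double().sum()
        world = dist.get_world_size(group=dp_group)
        gathered = [torch.zeros_like(local) for _ in range(world)]
        dist.all_gather(gathered, local, group=dp_group)
        # All DP replicas should hold identical parameters at this point.
        return all(torch.equal(gathered[0], g) for g in gathered)

If such a check fails on any rank, the assertion in the diff above aborts training before the first all-gather can spread the bad values to the other ranks.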
