Merge branch 'grad_overlap_with_pp' into 'main'
Enable grad overlap for pipeline parallelism

See merge request ADLR/megatron-lm!791
jaredcasper committed Sep 29, 2023
2 parents 8737bc1 + 299d8a5 commit c294408
Showing 4 changed files with 28 additions and 18 deletions.
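
Background for the diffs below: "overlapping grad reduce" launches gradient all-reduces asynchronously while the backward pass is still running, instead of reducing every gradient after backward completes. The following is a minimal, hypothetical illustration of that idea only (not Megatron's bucketed implementation); it assumes torch.distributed is initialized and a PyTorch version that provides Tensor.register_post_accumulate_grad_hook.

    import torch
    import torch.distributed as dist

    def register_overlapped_grad_reduce(model: torch.nn.Module, process_group=None):
        """Launch an async all-reduce for each parameter as soon as its grad is ready."""
        handles = []

        def make_hook(param):
            def hook(*unused):
                # param.grad has just been accumulated; reduce it while backward is
                # still computing gradients for earlier layers.
                handles.append(
                    dist.all_reduce(param.grad, group=process_group, async_op=True))
            return hook

        for param in model.parameters():
            if param.requires_grad:
                param.register_post_accumulate_grad_hook(make_hook(param))

        def finish():
            # Wait for all in-flight reductions before the optimizer step.
            for handle in handles:
                handle.wait()
            handles.clear()

        return finish

Megatron's DDP wrapper does this per contiguous bucket rather than per parameter, and only for the last microbatch of a gradient-accumulation step; the diffs below extend that machinery to the pipeline-parallel schedules.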
9 changes: 6 additions & 3 deletions megatron/arguments.py
@@ -174,9 +174,9 @@ def validate_args(args, defaults={}):
         print('using {} for parameters ...'.format(args.params_dtype),
               flush=True)

-    # Overlapping grad reduce only supported without pipeline parallelism right now.
+    # Overlapping grad reduce not supported with interleaved PP right now.
     if args.overlap_grad_reduce:
-        assert args.pipeline_model_parallel_size == 1
+        assert args.virtual_pipeline_model_parallel_size is None

     if args.dataloader_type is None:
         args.dataloader_type = 'single'
@@ -1012,8 +1012,11 @@ def _add_distributed_args(parser):
                        help='Timeout minutes for torch.distributed.')
     group.add_argument('--overlap-grad-reduce', action='store_true',
                        default=False, help='If set, overlap DDP grad reduce.')
+    group.add_argument('--no-delay-grad-reduce', action='store_false',
+                       help='If not set, delay grad reduction in all but first PP stage.',
+                       dest='delay_grad_reduce')
     group.add_argument('--no-scatter-gather-tensors-in-pipeline', action='store_false',
-                       help='Use scatter/gather to optimize communication of tensors in pipeline',
+                       help='If not set, use scatter/gather to optimize communication of tensors in pipeline.',
                        dest='scatter_gather_tensors_in_pipeline')
     group.add_argument('--use-ring-exchange-p2p', action='store_true',
                        default=False, help='If set, use custom-built ring exchange '
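
A small, illustrative sketch (not part of the commit) of how the new flag behaves: because --no-delay-grad-reduce uses action='store_false' with dest='delay_grad_reduce', delayed grad reduction is on by default and the flag turns it off.

    import argparse

    parser = argparse.ArgumentParser()
    group = parser.add_argument_group('distributed')
    group.add_argument('--overlap-grad-reduce', action='store_true',
                       default=False, help='If set, overlap DDP grad reduce.')
    group.add_argument('--no-delay-grad-reduce', action='store_false',
                       help='If not set, delay grad reduction in all but first PP stage.',
                       dest='delay_grad_reduce')

    args = parser.parse_args(['--overlap-grad-reduce'])
    assert args.overlap_grad_reduce and args.delay_grad_reduce      # delay defaults to True

    args = parser.parse_args(['--overlap-grad-reduce', '--no-delay-grad-reduce'])
    assert args.overlap_grad_reduce and not args.delay_grad_reduce  # explicitly disabled

Together with the validate_args change above, --overlap-grad-reduce is now allowed with pipeline parallelism in general; only interleaved (virtual) pipeline parallelism remains unsupported.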
20 changes: 6 additions & 14 deletions megatron/core/pipeline_parallel/schedules.py
@@ -5,7 +5,6 @@

 import torch
 from torch.autograd.variable import Variable
-from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP

 from megatron import core
 from megatron.core import parallel_state
@@ -315,8 +314,6 @@ def forward_backward_no_pipelining(
     config = get_model_config(model)

     no_sync_func = config.no_sync_func
-    if no_sync_func is None and isinstance(model, torchDDP):
-        no_sync_func = model.no_sync
     if no_sync_func is None:
         no_sync_func = contextlib.nullcontext

@@ -386,15 +383,6 @@ def forward_backward_pipelining_with_interleaving(

     # Disable async grad reductions
     no_sync_func = config.no_sync_func
-    if no_sync_func is None and all(isinstance(chunk, torchDDP) for chunk in model):
-
-        def multi_no_sync():
-            stack = contextlib.ExitStack()
-            for chunk in model:
-                stack.enter_context(chunk.no_sync())
-            return stack
-
-        no_sync_func = multi_no_sync
     if no_sync_func is None:
         no_sync_func = contextlib.nullcontext
     no_sync_context = None
@@ -1057,8 +1045,6 @@ def forward_backward_pipelining_without_interleaving(

     # Disable async grad reductions
     no_sync_func = config.no_sync_func
-    if no_sync_func is None and isinstance(model, torchDDP):
-        no_sync_func = model.no_sync
     if no_sync_func is None:
         no_sync_func = contextlib.nullcontext
     no_sync_context = None
@@ -1209,6 +1195,12 @@ def enable_grad_sync():
             input_tensor = input_tensors.pop(0)
             output_tensor = output_tensors.pop(0)

+            # Enable grad sync for the last microbatch in the batch if the full
+            # backward pass completes in the 1F1B stage.
+            if num_warmup_microbatches == 0 and last_iteration:
+                if config.grad_sync_func is None or rank == 0:
+                    enable_grad_sync()
+
             input_tensor_grad = backward_step(
                 input_tensor, output_tensor, output_tensor_grad, model_type, config
             )
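
For orientation, a minimal sketch of the synchronization pattern this schedule relies on (hypothetical helper, not the Megatron code): grad sync is suppressed via no_sync_func for all but the last microbatch, and enable_grad_sync() re-enables it just before the final backward step so the asynchronous reduction can overlap with it.

    import contextlib

    def run_backward_microbatches(microbatches, backward_step, no_sync_func=None):
        """Accumulate grads with sync disabled; enable sync only for the last microbatch."""
        if no_sync_func is None:
            no_sync_func = contextlib.nullcontext  # no DDP hook provided: sync every backward

        no_sync_context = None

        def disable_grad_sync():
            nonlocal no_sync_context
            if no_sync_context is None:
                no_sync_context = no_sync_func()
                no_sync_context.__enter__()

        def enable_grad_sync():
            nonlocal no_sync_context
            if no_sync_context is not None:
                no_sync_context.__exit__(None, None, None)
                no_sync_context = None

        disable_grad_sync()
        for i, microbatch in enumerate(microbatches):
            if i == len(microbatches) - 1:
                enable_grad_sync()  # last backward triggers (and overlaps with) grad reduce
            backward_step(microbatch)

With delayed grad reduce (config.grad_sync_func set), the hunk above keeps sync disabled on all but the first pipeline stage (rank == 0) even for the last 1F1B microbatch; those stages presumably issue their reduction later via config.grad_sync_func (the call site is not shown in this diff).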
10 changes: 10 additions & 0 deletions megatron/model/distributed.py
@@ -239,6 +239,11 @@ def done(self):
         for bucket in self.buckets:
             bucket.done()

+    def grad_sync(self):
+        """Synchronize grads."""
+        for bucket in self.buckets:
+            bucket.communicate()
+
     def mark_grad_as_done(self, param: torch.nn.Parameter):
         """
         When the number of microbatches is greater than 1, we only want
@@ -428,6 +433,11 @@ def no_sync(self):
             for grad_buffer in self.grad_buffers.values():
                 grad_buffer.is_last_microbatch = True

+    def grad_sync(self, *unused):
+        """Method to dispatch grad sync operations."""
+        for grad_buffer in self.grad_buffers.values():
+            grad_buffer.grad_sync()
+
     def zero_grad_buffer(self):
         """Set the grad buffer data to zero. Needs to be called at the
         begining of each iteration."""
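
The bucket internals are not part of this diff; as a rough sketch of the dispatch pattern (hypothetical class and names, assuming torch.distributed), communicate() launches an asynchronous all-reduce for one bucket and done() waits for it, so grad_sync() simply flushes every bucket.

    import torch
    import torch.distributed as dist

    class SketchBucket:
        """Hypothetical bucket over one contiguous shard of a grad buffer."""

        def __init__(self, data: torch.Tensor, process_group=None):
            self.data = data
            self.process_group = process_group
            self.handle = None

        def communicate(self):
            # Launch an async all-reduce so it overlaps with remaining backward compute.
            if self.handle is None:
                self.handle = dist.all_reduce(
                    self.data, group=self.process_group, async_op=True)

        def done(self):
            # Block until the pending all-reduce (if any) has completed.
            if self.handle is not None:
                self.handle.wait()
                self.handle = None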
7 changes: 6 additions & 1 deletion megatron/training.py
@@ -693,7 +693,12 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
     config.timers = timers
     # TODO: Remove this once we move DDP to Core.
     if len(model) == 1 and isinstance(model[0], DDP) and \
-        args.pipeline_model_parallel_size == 1:
+        args.overlap_grad_reduce:
+        assert config.no_sync_func is None, \
+            ('When overlap_grad_reduce is True, config.no_sync_func must be None; '
+             'a custom no_sync_func is not supported when overlapping grad-reduce')
+        if args.delay_grad_reduce:
+            config.grad_sync_func = model[0].grad_sync
         config.no_sync_func = model[0].no_sync

     timers('interval-time', log_level=0).start(barrier=True)
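
A condensed, hypothetical recap of the wiring this hunk sets up (cfg stands in for the config object and ddp_model for model[0]; simplified for illustration, not the actual training.py code):

    def configure_grad_overlap(cfg, ddp_model, overlap_grad_reduce, delay_grad_reduce):
        """Mirror of the training.py logic above, simplified for illustration."""
        if not overlap_grad_reduce:
            return  # schedules fall back to contextlib.nullcontext and sync every backward
        assert cfg.no_sync_func is None, \
            'a custom no_sync_func is not supported when overlapping grad-reduce'
        if delay_grad_reduce:
            # Non-first pipeline stages skip enable_grad_sync() in the 1F1B loop and are
            # expected to reduce later via cfg.grad_sync_func (call site not in this diff).
            cfg.grad_sync_func = ddp_model.grad_sync
        cfg.no_sync_func = ddp_model.no_sync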
