Merge branch 'decouple_send_recv_issue' into 'main'
tunable schedule with overlapping

See merge request ADLR/megatron-lm!2117
jaredcasper committed Nov 1, 2024
2 parents 2e2bdf6 + 4295be1 commit 441cb92
Showing 10 changed files with 807 additions and 211 deletions.
35 changes: 33 additions & 2 deletions megatron/core/model_parallel_config.py
@@ -272,6 +272,27 @@ class ModelParallelConfig:
encoder and decoder (e.g., T5). Ignored if None.
"""

+overlap_p2p_comm_warmup_flush: bool = False
+"""If true, overlap communication and computation in the warmup and flush phases.
+Only valid when overlap_p2p_comm is True and batch_p2p_comm is False.
+Defaults to False.
+"""
+
+microbatch_group_size_per_vp_stage: Optional[int] = None
+"""This value specifies the number of micro-batches that are executed
+at a time for a given virtual stage (both forward and backward).
+Defaults (in the __post_init__() method below) to pipeline_model_parallel_size,
+which specifies a depth-first schedule.
+Example: for PP=2 VP=2, when microbatch_group_size_per_vp_stage=2 and
+num_microbatches=4, we have
+rank 0 | 0 1 0 1 2 3 2 3
+rank 1 | 0 1 0 1 2 3 2 3
+When microbatch_group_size_per_vp_stage=3 and num_microbatches=5,
+we have
+rank 0 | 0 1 2 0 1 2 3 4 3 4
+rank 1 | 0 1 2 0 1 2 3 4 3 4
+"""
+
###################
# CPU Offloading
###################
@@ -339,6 +360,16 @@ def __post_init__(self):
if self.expert_model_parallel_size > 1 and self.tensor_model_parallel_size > 1:
if self.sequence_parallel is False:
raise ValueError(
"When using expert parallelism and tensor parallelism, sequence parallelism "
"must be used"
"When using expert parallelism and tensor parallelism, "
"sequence parallelism must be used"
)
+
+if self.microbatch_group_size_per_vp_stage is None:
+self.microbatch_group_size_per_vp_stage = self.pipeline_model_parallel_size
+
+if self.overlap_p2p_comm_warmup_flush:
+if not self.overlap_p2p_comm or self.batch_p2p_comm:
+raise ValueError(
+"Pipeline parallel communication overlapping in warmup and flush is only "
+"compatible with overlap_p2p_comm but not batch_p2p_comm"
+)
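
For readers following the schedule change, here is a minimal standalone sketch of the per-rank microbatch ordering that the new microbatch_group_size_per_vp_stage docstring describes. It is not part of this commit, the function name `interleaved_order` is made up, and it only covers the ordering of microbatch indices shown in the two docstring examples.

```python
# Standalone sketch, not from the repository: enumerate the per-rank microbatch order
# implied by microbatch_group_size_per_vp_stage (ordering of indices only).
def interleaved_order(num_microbatches: int, vp_size: int, group_size: int) -> list[int]:
    order = []
    for start in range(0, num_microbatches, group_size):
        group = list(range(start, min(start + group_size, num_microbatches)))
        for _ in range(vp_size):  # each virtual stage runs the whole group before moving on
            order.extend(group)
    return order

# Reproduces the two docstring examples (PP=2, VP=2):
assert interleaved_order(4, vp_size=2, group_size=2) == [0, 1, 0, 1, 2, 3, 2, 3]
assert interleaved_order(5, vp_size=2, group_size=3) == [0, 1, 2, 0, 1, 2, 3, 4, 3, 4]
```

Setting group_size equal to pipeline_model_parallel_size (the default added in __post_init__) recovers the usual depth-first interleaved schedule.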
52 changes: 27 additions & 25 deletions megatron/core/pipeline_parallel/p2p_communication.py
@@ -1,8 +1,6 @@
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

-import operator
-from functools import reduce
-from typing import Callable, List, Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union

import torch

@@ -166,8 +164,7 @@ def _p2p_ops(
prev_pipeline_rank: int,
next_pipeline_rank: int,
):
-reqs = []
-rank = get_pipeline_model_parallel_rank()
+reqs = {}
even_send_odd_recv_group = group
if get_pipeline_model_parallel_world_size() == 2:
# Use the global process group for one of the two p2p communications
@@ -183,50 +180,50 @@
send_next_req = torch.distributed.isend(
tensor=tensor_send_next, dst=next_pipeline_rank, group=even_send_odd_recv_group
)
-reqs.append(send_next_req)
+reqs["send_next"] = send_next_req

if tensor_recv_prev is not None:
recv_prev_req = torch.distributed.irecv(
tensor=tensor_recv_prev, src=prev_pipeline_rank, group=even_recv_odd_send_group
)
-reqs.append(recv_prev_req)
+reqs["recv_prev"] = recv_prev_req

if tensor_send_prev is not None:
send_prev_req = torch.distributed.isend(
tensor=tensor_send_prev, dst=prev_pipeline_rank, group=even_send_odd_recv_group
)
-reqs.append(send_prev_req)
+reqs["send_prev"] = send_prev_req

if tensor_recv_next is not None:
recv_next_req = torch.distributed.irecv(
tensor=tensor_recv_next, src=next_pipeline_rank, group=even_recv_odd_send_group
)
-reqs.append(recv_next_req)
+reqs["recv_next"] = recv_next_req

else:
if tensor_recv_prev is not None:
recv_prev_req = torch.distributed.irecv(
tensor=tensor_recv_prev, src=prev_pipeline_rank, group=even_send_odd_recv_group
)
-reqs.append(recv_prev_req)
+reqs["recv_prev"] = recv_prev_req

if tensor_send_next is not None:
send_next_req = torch.distributed.isend(
tensor=tensor_send_next, dst=next_pipeline_rank, group=even_recv_odd_send_group
)
-reqs.append(send_next_req)
+reqs["send_next"] = send_next_req

if tensor_recv_next is not None:
recv_next_req = torch.distributed.irecv(
tensor=tensor_recv_next, src=next_pipeline_rank, group=even_send_odd_recv_group
)
-reqs.append(recv_next_req)
+reqs["recv_next"] = recv_next_req

if tensor_send_prev is not None:
send_prev_req = torch.distributed.isend(
tensor=tensor_send_prev, dst=prev_pipeline_rank, group=even_recv_odd_send_group
)
-reqs.append(send_prev_req)
+reqs["send_prev"] = send_prev_req
return reqs
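
A hedged illustration of why _p2p_ops now returns a dict keyed by direction instead of a flat list: named handles let an overlapping schedule wait on one specific receive before compute while the matching sends stay in flight. The helper `wait_for` below is hypothetical and not part of the commit.

```python
# Hypothetical consumer of the keyed request dict returned by _p2p_ops (illustration only).
def wait_for(reqs: dict, keys) -> None:
    """Block on the named P2P requests only, leaving the others in flight."""
    for key in keys:
        req = reqs.get(key)
        if req is not None:
            req.wait()

# e.g. before computing on the received activation:
#     wait_for(reqs, ["recv_prev"])
# and only once the send buffers are about to be reused or freed:
#     wait_for(reqs, ["send_prev", "send_next"])
```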


@@ -349,7 +346,10 @@ def _ring_exchange_wrapper(**kwargs):
assert not isinstance(prev_rank, list)
prev_rank = [prev_rank]

-reqs = []
+if config.use_ring_exchange_p2p or config.batch_p2p_comm:
+reqs = []
+else:
+reqs = {}
tensor_recv_prev_list = []
tensor_recv_next_list = []

@@ -366,20 +366,22 @@ def _ring_exchange_wrapper(**kwargs):
else:
tensor_recv_next = None

-reqs.extend(
-p2p_func(
-tensor_send_prev=tensor_send_prev,
-tensor_recv_prev=tensor_recv_prev,
-tensor_send_next=tensor_send_next,
-tensor_recv_next=tensor_recv_next,
-group=group,
-prev_pipeline_rank=pr,
-next_pipeline_rank=nr,
-)
+p2p_reqs = p2p_func(
+tensor_send_prev=tensor_send_prev,
+tensor_recv_prev=tensor_recv_prev,
+tensor_send_next=tensor_send_next,
+tensor_recv_next=tensor_recv_next,
+group=group,
+prev_pipeline_rank=pr,
+next_pipeline_rank=nr,
+)
+if isinstance(p2p_reqs, list):
+reqs.extend(p2p_reqs)
+else:
+reqs.update(p2p_reqs)

if wait_on_reqs and len(reqs) > 0:
-for req in reqs:
+for req in reqs if isinstance(reqs, list) else reqs.values():
req.wait()
reqs = None
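
Since ring-exchange and batched P2P still produce a flat list of handles while the unbatched path now produces a keyed dict, the wait loop above normalizes over both containers. A standalone sketch of that pattern (illustrative only; `drain` is not a Megatron function):

```python
# Illustration only: wait on every outstanding request, whether `reqs` is the list
# from the batched/ring-exchange paths or the keyed dict from the unbatched path.
def drain(reqs) -> None:
    if reqs is None:
        return
    handles = reqs if isinstance(reqs, list) else reqs.values()
    for req in handles:
        req.wait()
```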

