Merge branch 'xren/context_parallelism' into 'main'
create processor group for context parallelism

See merge request ADLR/megatron-lm!714
jaredcasper committed Oct 6, 2023
2 parents 1254ee8 + 8879205 commit 4c0daab
Showing 1 changed file with 204 additions and 41 deletions.
245 changes: 204 additions & 41 deletions megatron/core/parallel_state.py
@@ -50,6 +50,20 @@
# rank when broadcasting weights from src to all other data parallel ranks
_DATA_PARALLEL_GLOBAL_RANKS = None

# Context parallel group that the current rank belongs to
_CONTEXT_PARALLEL_GROUP = None
# A list of global ranks for each context parallel group to ease calculation of the
# destination rank when exchanging KV/dKV between context parallel ranks
_CONTEXT_PARALLEL_GLOBAL_RANKS = None

# Data parallel group information with context parallel combined.
_DATA_PARALLEL_GROUP_WITH_CP = None
_DATA_PARALLEL_GROUP_WITH_CP_GLOO = None
_DATA_PARALLEL_GLOBAL_RANKS_WITH_CP = None

# combined parallel group of TP, DP, and CP used for fp8
_TENSOR_AND_DATA_PARALLEL_GROUP_WITH_CP = None

# Memory buffers to avoid dynamic memory allocation
_GLOBAL_MEMORY_BUFFER = None

@@ -60,6 +74,7 @@ def initialize_model_parallel(
virtual_pipeline_model_parallel_size: Optional[int] = None,
pipeline_model_parallel_split_rank: Optional[int] = None,
use_sharp: bool = False,
context_parallel_size: int = 1,
) -> None:
"""Initialize model data parallel groups.
@@ -105,6 +120,30 @@
within each data-parallel process group, which specifies
the SHARP application target groups.
context_parallel_size (int, default = 1):
The number of tensor parallel GPU groups across which the
network input sequence length is split. Attention computation
requires tokens of the full sequence length, so the GPUs in a
context parallel group need to communicate with each other to
exchange information about the other sequence chunks. Each GPU
and its counterparts in the other tensor parallel groups compose
a context parallel group.
For example, assume we have 8 GPUs. If the tensor model parallel
size is 4 and the context parallel size is 2, the network input
is split into two sequence chunks, which are processed by two
different groups of 4 GPUs: one chunk is processed by GPU0-3,
the other chunk by GPU4-7. Four groups are built for context
parallel communication, [GPU0, GPU4], [GPU1, GPU5], [GPU2, GPU6],
and [GPU3, GPU7], as illustrated by the sketch following this
docstring excerpt.
Context parallelism partitions the sequence length, so it has no
impact on the weights, which means the weights are duplicated
among the GPUs in a context parallel group. Hence, an all-reduce
of the weight gradients is required in the backward pass. For
simplicity, we piggyback the context parallel GPUs onto the data
parallel group for the weight gradient all-reduce.
Let's say we have a total of 16 GPUs denoted by g0 ... g15 and we
use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
the model pipeline. The present function will
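As a concrete illustration of the grouping described in the context_parallel_size docstring above, the short standalone sketch below reproduces the 8-GPU example with tensor_model_parallel_size = 4 and context_parallel_size = 2. It is editorial material, not part of the commit, and assumes nothing beyond plain Python.

# Hypothetical sketch (not part of the commit): reproduce the 8-GPU example
# from the docstring, with tensor_model_parallel_size = 4 and context_parallel_size = 2.
tensor_model_parallel_size = 4
context_parallel_size = 2
world_size = tensor_model_parallel_size * context_parallel_size  # 8 GPUs

# Each sequence chunk is handled by one tensor parallel group of 4 GPUs.
tensor_parallel_groups = [
    list(range(start, start + tensor_model_parallel_size))
    for start in range(0, world_size, tensor_model_parallel_size)
]
print(tensor_parallel_groups)  # [[0, 1, 2, 3], [4, 5, 6, 7]]

# A context parallel group pairs each GPU with the counterparts that hold the
# same tensor-parallel shard but a different sequence chunk.
context_parallel_groups = [
    list(range(k, world_size, tensor_model_parallel_size))
    for k in range(tensor_model_parallel_size)
]
print(context_parallel_groups)  # [[0, 4], [1, 5], [2, 6], [3, 7]]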
@@ -126,19 +165,23 @@
assert torch.distributed.is_initialized()
world_size: int = torch.distributed.get_world_size()

if world_size % (tensor_model_parallel_size * pipeline_model_parallel_size) != 0:
if (
world_size
% (tensor_model_parallel_size * pipeline_model_parallel_size * context_parallel_size)
!= 0
):
raise RuntimeError(
f"world_size ({world_size}) is not divisible by tensor_model_parallel_size "
f"({tensor_model_parallel_size}) x pipeline_model_parallel_size ({pipeline_model_parallel_size})"
f"({tensor_model_parallel_size}) x pipeline_model_parallel_size ({pipeline_model_parallel_size}) "
f"x context_parallel_size ({context_parallel_size})"
)

data_parallel_size: int = world_size // (
tensor_model_parallel_size * pipeline_model_parallel_size
tensor_model_parallel_size * pipeline_model_parallel_size * context_parallel_size
)

num_tensor_model_parallel_groups: int = world_size // tensor_model_parallel_size
num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size
num_data_parallel_groups: int = world_size // data_parallel_size

if virtual_pipeline_model_parallel_size is not None:
if not pipeline_model_parallel_size > 2:
@@ -160,20 +203,33 @@
global _DATA_PARALLEL_GROUP
global _DATA_PARALLEL_GROUP_GLOO
global _DATA_PARALLEL_GLOBAL_RANKS
global _DATA_PARALLEL_GROUP_WITH_CP
global _DATA_PARALLEL_GROUP_WITH_CP_GLOO
global _DATA_PARALLEL_GLOBAL_RANKS_WITH_CP
assert _DATA_PARALLEL_GROUP is None, 'data parallel group is already initialized'
all_data_parallel_group_ranks = []
all_data_parallel_group_ranks_with_cp = []
for i in range(pipeline_model_parallel_size):
start_rank = i * num_pipeline_model_parallel_groups
end_rank = (i + 1) * num_pipeline_model_parallel_groups
for j in range(tensor_model_parallel_size):
ranks = range(start_rank + j, end_rank, tensor_model_parallel_size)
all_data_parallel_group_ranks.append(list(ranks))
for j in range(context_parallel_size * tensor_model_parallel_size):
ranks = range(
start_rank + j, end_rank, context_parallel_size * tensor_model_parallel_size
)
group = torch.distributed.new_group(ranks)
group_gloo = torch.distributed.new_group(ranks, backend="gloo")
if rank in ranks:
_DATA_PARALLEL_GROUP = group
_DATA_PARALLEL_GROUP_GLOO = group_gloo
_DATA_PARALLEL_GLOBAL_RANKS = ranks
for j in range(tensor_model_parallel_size):
ranks_with_cp = range(start_rank + j, end_rank, tensor_model_parallel_size)
all_data_parallel_group_ranks_with_cp.append(list(ranks_with_cp))
group_with_cp = torch.distributed.new_group(ranks_with_cp)
group_with_cp_gloo = torch.distributed.new_group(ranks_with_cp, backend="gloo")
if rank in ranks_with_cp:
_DATA_PARALLEL_GROUP_WITH_CP = group_with_cp
_DATA_PARALLEL_GROUP_WITH_CP_GLOO = group_with_cp_gloo
_DATA_PARALLEL_GLOBAL_RANKS_WITH_CP = ranks_with_cp

# Apply SHARP to DP process groups
if use_sharp:
@@ -189,18 +245,40 @@
"`#SBATCH_NETWORK=sharp` should be set in the sbatch script."
)
torch.distributed.barrier(
group=get_data_parallel_group(), device_ids=[torch.cuda.current_device()]
group=get_data_parallel_group(with_context_parallel=context_parallel_size > 1),
device_ids=[torch.cuda.current_device()],
)
# Set `NCCL_SHARP_DISABLE=1` to restrict SHARP application to DP process groups
os.environ["NCCL_SHARP_DISABLE"] = "1"

# Build the context-parallel groups.
global _CONTEXT_PARALLEL_GROUP
global _CONTEXT_PARALLEL_GLOBAL_RANKS
assert _CONTEXT_PARALLEL_GROUP is None, 'context parallel group is already initialized'
for i in range(pipeline_model_parallel_size):
for j in range(data_parallel_size):
start_rank = (
i * num_pipeline_model_parallel_groups
+ j * tensor_model_parallel_size * context_parallel_size
)
end_rank = (
i * num_pipeline_model_parallel_groups
+ (j + 1) * tensor_model_parallel_size * context_parallel_size
)
for k in range(tensor_model_parallel_size):
ranks = range(start_rank + k, end_rank, tensor_model_parallel_size)
group = torch.distributed.new_group(ranks)
if rank in ranks:
_CONTEXT_PARALLEL_GROUP = group
_CONTEXT_PARALLEL_GLOBAL_RANKS = ranks
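For the same assumed toy layout (8 GPUs, TP=2, CP=2, DP=2, PP=1), the context-parallel loop above can be mirrored without torch.distributed; this editorial sketch is not part of the commit.

# Standalone mirror of the context parallel group enumeration above
# (assumed toy sizes: TP=2, CP=2, DP=2, PP=1 on 8 GPUs).
tensor_model_parallel_size = 2
context_parallel_size = 2
data_parallel_size = 2
pipeline_model_parallel_size = 1
world_size = 8
num_pipeline_model_parallel_groups = world_size // pipeline_model_parallel_size

context_parallel_groups = []
for i in range(pipeline_model_parallel_size):
    for j in range(data_parallel_size):
        start_rank = (
            i * num_pipeline_model_parallel_groups
            + j * tensor_model_parallel_size * context_parallel_size
        )
        end_rank = (
            i * num_pipeline_model_parallel_groups
            + (j + 1) * tensor_model_parallel_size * context_parallel_size
        )
        for k in range(tensor_model_parallel_size):
            context_parallel_groups.append(
                list(range(start_rank + k, end_rank, tensor_model_parallel_size))
            )

print(context_parallel_groups)  # [[0, 2], [1, 3], [4, 6], [5, 7]]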

# Build the model-parallel groups.
global _MODEL_PARALLEL_GROUP
assert _MODEL_PARALLEL_GROUP is None, 'model parallel group is already initialized'
for i in range(data_parallel_size):
for i in range(data_parallel_size * context_parallel_size):
ranks = [
data_parallel_group_ranks[i]
for data_parallel_group_ranks in all_data_parallel_group_ranks
data_parallel_group_ranks_with_cp[i]
for data_parallel_group_ranks_with_cp in all_data_parallel_group_ranks_with_cp
]
group = torch.distributed.new_group(ranks)
if rank in ranks:
@@ -268,18 +346,33 @@ def initialize_model_parallel(

# Build the tensor + data parallel groups.
global _TENSOR_AND_DATA_PARALLEL_GROUP
global _TENSOR_AND_DATA_PARALLEL_GROUP_WITH_CP
assert (
_TENSOR_AND_DATA_PARALLEL_GROUP is None
), 'Tensor + data parallel group is already initialized'
tensor_and_data_group_size: int = tensor_model_parallel_size * data_parallel_size
num_tensor_and_data_groups: int = world_size // tensor_and_data_group_size
for i in range(num_tensor_and_data_groups):
start_rank = i * tensor_and_data_group_size
end_rank = (i + 1) * tensor_and_data_group_size
tensor_and_data_group_size_with_cp: int = tensor_model_parallel_size * data_parallel_size * context_parallel_size
num_tensor_and_data_groups_with_cp: int = world_size // tensor_and_data_group_size_with_cp
for i in range(num_tensor_and_data_groups_with_cp):
start_rank = i * tensor_and_data_group_size_with_cp
end_rank = start_rank + tensor_and_data_group_size_with_cp
ranks = range(start_rank, end_rank)
group = torch.distributed.new_group(ranks)
if rank in ranks:
_TENSOR_AND_DATA_PARALLEL_GROUP = group
_TENSOR_AND_DATA_PARALLEL_GROUP_WITH_CP = group

for j in range(context_parallel_size):
ranks = []
for k in range(data_parallel_size):
start_rank = (
i * tensor_and_data_group_size_with_cp
+ j * tensor_model_parallel_size
+ k * tensor_model_parallel_size * context_parallel_size
)
end_rank = start_rank + tensor_model_parallel_size
ranks = ranks + list(range(start_rank, end_rank))
group = torch.distributed.new_group(ranks)
if rank in ranks:
_TENSOR_AND_DATA_PARALLEL_GROUP = group

# Initialize global memory buffer
# This isn't really "parallel state" but there isn't another good place to
@@ -327,16 +420,40 @@ def get_pipeline_model_parallel_group():
return _PIPELINE_MODEL_PARALLEL_GROUP


def get_data_parallel_group():
def get_data_parallel_group(with_context_parallel=False):
"""Get the data parallel group the caller rank belongs to."""
assert _DATA_PARALLEL_GROUP is not None, 'data parallel group is not initialized'
return _DATA_PARALLEL_GROUP
if with_context_parallel:
assert (
_DATA_PARALLEL_GROUP_WITH_CP is not None
), 'data parallel group with context parallel combined is not initialized'
return _DATA_PARALLEL_GROUP_WITH_CP
else:
assert _DATA_PARALLEL_GROUP is not None, 'data parallel group is not initialized'
return _DATA_PARALLEL_GROUP


def get_data_parallel_group_gloo():
def get_data_parallel_group_gloo(with_context_parallel=False):
"""Get the data parallel group-gloo the caller rank belongs to."""
assert _DATA_PARALLEL_GROUP_GLOO is not None, 'data parallel group-gloo is not initialized'
return _DATA_PARALLEL_GROUP_GLOO
if with_context_parallel:
assert (
_DATA_PARALLEL_GROUP_WITH_CP_GLOO is not None
), 'data parallel group-gloo with context parallel combined is not initialized'
return _DATA_PARALLEL_GROUP_WITH_CP_GLOO
else:
assert _DATA_PARALLEL_GROUP_GLOO is not None, 'data parallel group-gloo is not initialized'
return _DATA_PARALLEL_GROUP_GLOO


def get_context_parallel_group():
"""Get the context parallel group the caller rank belongs to."""
assert _CONTEXT_PARALLEL_GROUP is not None, 'context parallel group is not initialized'
return _CONTEXT_PARALLEL_GROUP


def get_context_parallel_global_ranks():
"""Get all global ranks of the context parallel group that the caller rank belongs to."""
assert _CONTEXT_PARALLEL_GLOBAL_RANKS is not None, 'context parallel group is not initialized'
return _CONTEXT_PARALLEL_GLOBAL_RANKS


def get_embedding_group():
Expand All @@ -351,20 +468,32 @@ def get_position_embedding_group():
return _POSITION_EMBEDDING_GROUP


def get_amax_reduction_group():
def get_amax_reduction_group(with_context_parallel=False):
"""Get the FP8 amax reduction group the caller rank belongs to."""
assert (
_TENSOR_AND_DATA_PARALLEL_GROUP is not None
), 'FP8 amax reduction group is not initialized'
return _TENSOR_AND_DATA_PARALLEL_GROUP
if with_context_parallel:
assert (
_TENSOR_AND_DATA_PARALLEL_GROUP_WITH_CP is not None
), 'FP8 amax reduction group is not initialized'
return _TENSOR_AND_DATA_PARALLEL_GROUP_WITH_CP
else:
assert (
_TENSOR_AND_DATA_PARALLEL_GROUP is not None
), 'FP8 amax reduction group is not initialized'
return _TENSOR_AND_DATA_PARALLEL_GROUP


def get_tensor_and_data_parallel_group():
def get_tensor_and_data_parallel_group(with_context_parallel=False):
"""Get the tensor and data parallel group the caller rank belongs to."""
assert (
_TENSOR_AND_DATA_PARALLEL_GROUP is not None
), 'tensor and data parallel group is not initialized'
return _TENSOR_AND_DATA_PARALLEL_GROUP
if with_context_parallel:
assert (
_TENSOR_AND_DATA_PARALLEL_GROUP_WITH_CP is not None
), 'tensor and data parallel group is not initialized'
return _TENSOR_AND_DATA_PARALLEL_GROUP_WITH_CP
else:
assert (
_TENSOR_AND_DATA_PARALLEL_GROUP is not None
), 'tensor and data parallel group is not initialized'
return _TENSOR_AND_DATA_PARALLEL_GROUP


def set_tensor_model_parallel_world_size(world_size):
@@ -552,11 +681,17 @@ def get_tensor_model_parallel_src_rank():
return (global_rank // local_world_size) * local_world_size


def get_data_parallel_src_rank():
def get_data_parallel_src_rank(with_context_parallel=False):
"""Calculate the global rank corresponding to the first local rank
in the data parallel group."""
assert _DATA_PARALLEL_GLOBAL_RANKS is not None, "Data parallel group is not initialized"
return _DATA_PARALLEL_GLOBAL_RANKS[0]
if with_context_parallel:
assert (
_DATA_PARALLEL_GLOBAL_RANKS_WITH_CP is not None
), "Data parallel group with context parallel combined is not initialized"
return _DATA_PARALLEL_GLOBAL_RANKS_WITH_CP[0]
else:
assert _DATA_PARALLEL_GLOBAL_RANKS is not None, "Data parallel group is not initialized"
return _DATA_PARALLEL_GLOBAL_RANKS[0]


def get_pipeline_model_parallel_first_rank():
@@ -590,18 +725,38 @@ def get_pipeline_model_parallel_prev_rank():
return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline - 1) % world_size]


def get_data_parallel_world_size():
def get_data_parallel_world_size(with_context_parallel=False):
"""Return world size for the data parallel group."""
if torch.distributed.is_available() and torch.distributed.is_initialized():
return torch.distributed.get_world_size(group=get_data_parallel_group())
return torch.distributed.get_world_size(
group=get_data_parallel_group(with_context_parallel=with_context_parallel)
)
else:
return 0


def get_data_parallel_rank():
def get_data_parallel_rank(with_context_parallel=False):
"""Return my rank for the data parallel group."""
if torch.distributed.is_available() and torch.distributed.is_initialized():
return torch.distributed.get_rank(group=get_data_parallel_group())
return torch.distributed.get_rank(
group=get_data_parallel_group(with_context_parallel=with_context_parallel)
)
else:
return 0


def get_context_parallel_world_size():
"""Return world size for the context parallel group."""
if torch.distributed.is_available() and torch.distributed.is_initialized():
return torch.distributed.get_world_size(group=get_context_parallel_group())
else:
return 0


def get_context_parallel_rank():
"""Return my rank for the context parallel group."""
if torch.distributed.is_available() and torch.distributed.is_initialized():
return torch.distributed.get_rank(group=get_context_parallel_group())
else:
return 0
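A hypothetical caller (not shown in this commit) might combine the new accessors to find its ring neighbors for KV/dKV exchange. The sketch assumes torch.distributed and the model parallel state have already been initialized, e.g. via initialize_model_parallel(..., context_parallel_size=2).

# Hypothetical usage sketch (not part of this commit). Assumes
# initialize_model_parallel(..., context_parallel_size=2) has already run.
from megatron.core import parallel_state

cp_ranks = parallel_state.get_context_parallel_global_ranks()
cp_size = parallel_state.get_context_parallel_world_size()
cp_rank = parallel_state.get_context_parallel_rank()

# Ring-style neighbors inside the context parallel group: the global rank this
# rank would send its KV chunk to, and the one it would receive from.
send_to_global_rank = cp_ranks[(cp_rank + 1) % cp_size]
recv_from_global_rank = cp_ranks[(cp_rank - 1) % cp_size]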

@@ -635,12 +790,20 @@ def destroy_model_parallel():
_PIPELINE_MODEL_PARALLEL_GROUP = None
global _DATA_PARALLEL_GROUP
_DATA_PARALLEL_GROUP = None
global _DATA_PARALLEL_GROUP_WITH_CP
_DATA_PARALLEL_GROUP_WITH_CP = None
global _CONTEXT_PARALLEL_GROUP
_CONTEXT_PARALLEL_GROUP = None
global _CONTEXT_PARALLEL_GLOBAL_RANKS
_CONTEXT_PARALLEL_GLOBAL_RANKS = None
global _EMBEDDING_GROUP
_EMBEDDING_GROUP = None
global _POSITION_EMBEDDING_GROUP
_POSITION_EMBEDDING_GROUP = None
global _TENSOR_AND_DATA_PARALLEL_GROUP
_TENSOR_AND_DATA_PARALLEL_GROUP = None
global _TENSOR_AND_DATA_PARALLEL_GROUP_WITH_CP
_TENSOR_AND_DATA_PARALLEL_GROUP_WITH_CP = None
global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
_VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = None
global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
