Merge branch 'add_hierarchical_cp_comm_group' into 'main'
Add hierarchical cp comm group

See merge request ADLR/megatron-lm!2279
ericharper committed Nov 15, 2024
2 parents 2163865 + 645c329 commit 2bdc60c
Showing 11 changed files with 213 additions and 1 deletion.
10 changes: 10 additions & 0 deletions megatron/core/extensions/transformer_engine.py
@@ -19,6 +19,7 @@
from megatron.core.parallel_state import (
get_context_parallel_global_ranks,
get_context_parallel_group,
get_hierarchical_context_parallel_groups,
get_tensor_and_expert_parallel_world_size,
get_tensor_model_parallel_group,
get_tensor_model_parallel_rank,
@@ -593,6 +594,15 @@ def __init__(
if is_te_min_version("1.10.0"):
if cp_comm_type is None:
extra_kwargs["cp_comm_type"] = "p2p"
elif cp_comm_type == "a2a+p2p":
assert is_te_min_version("1.12.0"), (
f"Transformer-Engine v{get_te_version()} must be >= 1.12.0 to support"
"hierarchical cp commucation."
)
extra_kwargs["cp_comm_type"] = "a2a+p2p"
extra_kwargs["cp_group"] = get_hierarchical_context_parallel_groups(
check_initialized=False
)
else:
extra_kwargs["cp_comm_type"] = cp_comm_type
else:
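For reference, the gating above can be restated as a small standalone sketch. The helper name select_cp_kwargs and the tuple-based version comparison are illustrative only; the real code relies on is_te_min_version(), get_te_version(), and get_hierarchical_context_parallel_groups().

from typing import Optional

def select_cp_kwargs(cp_comm_type: Optional[str], te_version: tuple, hierarchical_groups=None) -> dict:
    """Sketch of the extra kwargs TEDotProductAttention would receive for a given cp_comm_type."""
    extra_kwargs = {}
    if te_version >= (1, 10, 0):
        if cp_comm_type is None:
            extra_kwargs["cp_comm_type"] = "p2p"  # default: ring P2P
        elif cp_comm_type == "a2a+p2p":
            # Hierarchical CP needs TE >= 1.12.0 and the per-level process groups.
            assert te_version >= (1, 12, 0), "a2a+p2p requires Transformer Engine >= 1.12.0"
            extra_kwargs["cp_comm_type"] = "a2a+p2p"
            extra_kwargs["cp_group"] = hierarchical_groups
        else:
            extra_kwargs["cp_comm_type"] = cp_comm_type
    return extra_kwargs

print(select_cp_kwargs("a2a+p2p", (1, 12, 0), hierarchical_groups=["<a2a group>", "<p2p group>"]))
# {'cp_comm_type': 'a2a+p2p', 'cp_group': ['<a2a group>', '<p2p group>']}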
8 changes: 8 additions & 0 deletions megatron/core/model_parallel_config.py
@@ -39,6 +39,14 @@ class ModelParallelConfig:
context_parallel_size: int = 1
"""Splits network input along sequence dimension across GPU ranks."""

hierarchical_context_parallel_sizes: list[int] = None
"""Degrees of the hierarchical context parallelism. Users should provide a list to specify
the sizes for different levels. Taking the a2a+p2p cp comm type as example, it contains
groups of two levels, so the first value of the list indicates the group size of the a2a
communication type, and the second value indicates the group size of the p2p communication
type.
"""

expert_model_parallel_size: int = 1
"""Distributes Moe Experts across sub data parallel dimension."""

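To make the level semantics concrete: with the a2a+p2p comm type, a setting such as hierarchical_context_parallel_sizes = [2, 4] (illustrative values) means an a2a group size of 2 at the inner level and a p2p group size of 4 at the outer level, and the product of the per-level sizes must equal context_parallel_size. A minimal sanity check:

import math

hierarchical_context_parallel_sizes = [2, 4]  # [a2a group size, p2p group size] -- example values
context_parallel_size = 8

# The total CP degree is the product of the per-level group sizes.
assert context_parallel_size == math.prod(hierarchical_context_parallel_sizes)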
53 changes: 53 additions & 0 deletions megatron/core/parallel_state.py
@@ -79,6 +79,8 @@
# A list of global ranks for each context parallel group to ease calculation of the
# destination rank when exchanging KV/dKV between context parallel ranks
_CONTEXT_PARALLEL_GLOBAL_RANKS = None
# Hierarchical context parallel groups
_HIERARCHICAL_CONTEXT_PARALLEL_GROUPS = []

# Data parallel group information with context parallel combined.
_DATA_PARALLEL_GROUP_WITH_CP = None
@@ -226,6 +228,40 @@ def decompose(index, shape, stride=None):
return ranks


def create_hierarchical_parallel_groups(
rank, ranks, group_size, hierarchical_group_sizes, pg_options
):
"""Create hierarchical groups for one parallelism.
Taking a group size of 16 as example, so we have a total of 16 GPUs denoted by g0 ... g15.
If the hierarchical group sizes are [2,2,4], we use 2 GPUs in the first and second level
of sub-groups, and 4 GPUs in the last level of sub groups. The present function will
create 8 level-1 sub-groups, 8 level-2 sub-groups and 4 level-3 sub-groups as:
8 level-1 sub-groups:
[g0, g1], [g2, g3], [g4, g5], [g6, g7], [g8, g9], [g10, g11], [g12, g13], [g14, g15]
8 level-2 sub-groups:
[g0, g2], [g1, g3], [g4, g6], [g5, g7], [g8, g10], [g9, g11], [g12, g14], [g13, g15]
4 level-3 sub-groups:
[g0, g4, g8, g12], [g1, g5, g9, g13], [g2, g6, g10, g14], [g3, g7, g11, g15]
"""

hierarchical_groups = []
accumulated_group_sizes = 1
processed_group_sizes = 1
for hierarchical_group_size in hierarchical_group_sizes:
accumulated_group_sizes *= hierarchical_group_size
for k in range(group_size // accumulated_group_sizes):
for j in range(processed_group_sizes):
global_sub_ranks = [
ranks[j + i * processed_group_sizes + k * accumulated_group_sizes]
for i in range(hierarchical_group_size)
]
sub_group = torch.distributed.new_group(global_sub_ranks, pg_options=pg_options)
if rank in global_sub_ranks:
hierarchical_groups.append(sub_group)
processed_group_sizes *= hierarchical_group_size
return hierarchical_groups


class RankGenerator(object):
"""A class for generating rank groups for different modes of parallelism."""

@@ -356,6 +392,7 @@ def initialize_model_parallel(
pipeline_model_parallel_split_rank: Optional[int] = None,
use_sharp: bool = False,
context_parallel_size: int = 1,
hierarchical_context_parallel_sizes: List[int] = None,
expert_model_parallel_size: int = 1,
nccl_communicator_config_path: Optional[str] = None,
distributed_timeout_minutes: int = 30,
@@ -691,6 +728,15 @@ def generator_wrapper(group_type, **kwargs):
if rank in ranks:
_CONTEXT_PARALLEL_GROUP = group
_CONTEXT_PARALLEL_GLOBAL_RANKS = ranks
if hierarchical_context_parallel_sizes:
global _HIERARCHICAL_CONTEXT_PARALLEL_GROUPS
_HIERARCHICAL_CONTEXT_PARALLEL_GROUPS += create_hierarchical_parallel_groups(
rank,
ranks,
context_parallel_size,
hierarchical_context_parallel_sizes,
get_nccl_options('cp', nccl_comm_cfgs),
)

# Build the model-parallel groups.
global _MODEL_PARALLEL_GROUP
@@ -962,6 +1008,13 @@ def get_context_parallel_global_ranks(check_initialized=True):
return _CONTEXT_PARALLEL_GLOBAL_RANKS


def get_hierarchical_context_parallel_groups(check_initialized=True):
"""Get the inner ring of context parallel group the caller rank belongs to."""
if check_initialized:
assert _HIERARCHICAL_CONTEXT_PARALLEL_GROUPS is not None
return _HIERARCHICAL_CONTEXT_PARALLEL_GROUPS


def get_embedding_group():
"""Get the embedding group the caller rank belongs to."""
assert _EMBEDDING_GROUP is not None, 'embedding group is not initialized'
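The grouping produced by create_hierarchical_parallel_groups can be reproduced without torch.distributed. The sketch below mirrors the indexing of that function (the helper name enumerate_hierarchical_subgroups is hypothetical) and prints the level-1/2/3 sub-groups from the docstring example, i.e. a group size of 16 with hierarchical sizes [2, 2, 4].

def enumerate_hierarchical_subgroups(ranks, hierarchical_group_sizes):
    """Return, per level, the lists of ranks that would form each sub-group."""
    group_size = len(ranks)
    levels = []
    accumulated = 1  # product of level sizes up to and including the current level
    processed = 1    # product of level sizes below the current level
    for level_size in hierarchical_group_sizes:
        accumulated *= level_size
        level_groups = []
        for k in range(group_size // accumulated):
            for j in range(processed):
                level_groups.append(
                    [ranks[j + i * processed + k * accumulated] for i in range(level_size)]
                )
        levels.append(level_groups)
        processed *= level_size
    return levels

for level, groups in enumerate(enumerate_hierarchical_subgroups(list(range(16)), [2, 2, 4]), start=1):
    print(f"level-{level}:", groups)
# level-1: [[0, 1], [2, 3], [4, 5], ..., [14, 15]]
# level-2: [[0, 2], [1, 3], [4, 6], [5, 7], ..., [13, 15]]
# level-3: [[0, 4, 8, 12], [1, 5, 9, 13], [2, 6, 10, 14], [3, 7, 11, 15]]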
5 changes: 4 additions & 1 deletion megatron/core/transformer/transformer_config.py
@@ -311,13 +311,16 @@ class TransformerConfig(ModelParallelConfig):
"""Inter-gpu communication type for context parallelism.
str: all layers share same communication type.
List[str]: each layer has its separate communication type.
cp_comm_type of each layer can be "p2p" or "all_gather" or "a2a" or "a2a+p2p".
"p2p": Exchange KV chunks with P2P communications in ring topology. P2P is async and can be
overlapped with attention compute.
"all_gather": All-gather to get full sequence of KV before attention. The all-gather is not
async, and cannot be overlapped.
"a2a": Like DeepSpeed Ulysses, scatter attention heads across the CP group, and gather to get
full sequence of QKV.
"a2a+p2p": A hierarchical implementation of context parallelism to attention.
It uses A2A communications in low-level CP groups (e.g., via NVLink),
and P2P communications in high-level CP groups (e.g., via IBLink).
"""

####################
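To illustrate the per-layer list form (the layer split below is purely illustrative), a 12-layer model could use ring P2P for the first eight layers and the hierarchical a2a+p2p scheme for the last four:

# Hypothetical per-layer setting: one entry per transformer layer.
num_layers = 12
cp_comm_type = ["p2p"] * 8 + ["a2a+p2p"] * 4
assert len(cp_comm_type) == num_layers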
24 changes: 24 additions & 0 deletions megatron/training/arguments.py
@@ -199,12 +199,14 @@ def validate_args(args, defaults={}):
if args.rank == 0:
print('using world size: {}, data-parallel size: {}, '
'context-parallel size: {}, '
'hierarchical context-parallel sizes: {}, '
'tensor-model-parallel size: {}, '
'encoder-tensor-model-parallel size: {}, '
'pipeline-model-parallel size: {}, '
'encoder-pipeline-model-parallel size: {}'.format(
args.world_size, args.data_parallel_size,
args.context_parallel_size,
args.hierarchical_context_parallel_sizes,
args.tensor_model_parallel_size,
args.encoder_tensor_model_parallel_size,
args.pipeline_model_parallel_size,
@@ -216,6 +218,13 @@
args.pipeline_model_parallel_size -= args.encoder_pipeline_model_parallel_size
assert args.pipeline_model_parallel_size > 0

if args.hierarchical_context_parallel_sizes:
from numpy import prod
assert args.context_parallel_size == prod(args.hierarchical_context_parallel_sizes), \
"--context-parallel-size must equal the product of --hierarchical-context-parallel-sizes"
if "a2a+p2p" in args.cp_comm_type:
assert args.hierarchical_context_parallel_sizes is not None, \
"--hierarchical-context-parallel-sizes must be set when a2a+p2p is used as the cp comm type"

# Deprecated arguments
assert args.batch_size is None, '--batch-size argument is no longer ' \
'valid, use --micro-batch-size instead'
@@ -727,6 +736,9 @@ def core_transformer_config_from_args(args, config_class=None):
kw_args['num_query_groups'] = None
kw_args['config_logger_dir'] = args.config_logger_dir

if len(args.cp_comm_type) == 1:
kw_args['cp_comm_type'] = args.cp_comm_type[0]

# Return config.
return config_class(**kw_args)

@@ -1643,6 +1655,18 @@ def _add_distributed_args(parser):
"It is still not in a stable release stage, and may therefore contain bugs or other potential issues.")
group.add_argument('--context-parallel-size', type=int, default=1,
help='Degree of context parallelism.')
group.add_argument('--cp-comm-type', nargs='+', type=str, default=["p2p"],
help='Inter-gpu communication type for context parallelism: '
'p2p, a2a, all_gather, or a2a+p2p. If a single type is provided, '
'all layers will share the same communication type. Users can also '
'specify a separate type for each layer, e.g.: '
'--cp-comm-type p2p p2p a2a a2a a2a+p2p a2a+p2p')
group.add_argument('--hierarchical-context-parallel-sizes', nargs='+', type=int, default=None,
help='Degrees of hierarchical context parallelism. Users should '
'provide a list specifying the group size at each level. '
'For example, --hierarchical-context-parallel-sizes 2 4 means every two adjacent GPUs '
'form the first level of cp groups, and the cp ranks with the same parity '
'form the second level of cp groups.')
group.add_argument('--nccl-communicator-config-path', type=str, default=None,
help='Path to the yaml file with NCCL communicator '
'configurations. The number of min/max thread groups and thread '
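Because --cp-comm-type uses nargs='+', the parsed attribute is always a list, and core_transformer_config_from_args unwraps a single-element list to a plain string. A standalone sketch of that parsing behaviour (the parser construction here only reproduces the two new arguments):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--cp-comm-type', nargs='+', type=str, default=["p2p"])
parser.add_argument('--hierarchical-context-parallel-sizes', nargs='+', type=int, default=None)

args = parser.parse_args('--cp-comm-type a2a+p2p --hierarchical-context-parallel-sizes 2 4'.split())
print(args.cp_comm_type)                         # ['a2a+p2p']
print(args.hierarchical_context_parallel_sizes)  # [2, 4]

# Mirror of the unwrapping done in core_transformer_config_from_args:
cp_comm_type = args.cp_comm_type[0] if len(args.cp_comm_type) == 1 else args.cp_comm_type
print(cp_comm_type)                              # a2a+p2p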
1 change: 1 addition & 0 deletions megatron/training/initialize.py
@@ -282,6 +282,7 @@ def _initialize_distributed(get_embedding_ranks, get_position_embedding_ranks):
args.virtual_pipeline_model_parallel_size,
args.pipeline_model_parallel_split_rank,
context_parallel_size=args.context_parallel_size,
hierarchical_context_parallel_sizes=args.hierarchical_context_parallel_sizes,
expert_model_parallel_size=args.expert_model_parallel_size,
distributed_timeout_minutes=args.distributed_timeout_minutes,
nccl_communicator_config_path=args.nccl_communicator_config_path,
2 changes: 2 additions & 0 deletions tests/functional_tests/jet_recipes/gpt.yaml
@@ -107,6 +107,8 @@ products:
- gpt3_mr_mcore_te_tp2_pp2_cp2_nondeterministic_dgx_a100_1N8G # cp and attention
- gpt3_mr_mcore_te_tp2_pp1_resume_torch_dist_cp2_nondeterministic_dgx_a100_1N8G # cp and attention
- gpt3_mr_mcore_te_tp2_pp2_resume_torch_dist_cp2_nondeterministic_dgx_a100_1N8G # cp and attention
- gpt3_mr_mcore_te_tp1_pp2_cp4_a2a_p2p_nondeterministic_dgx_a100_1N8G # cp and attention with a2a+p2p comm type
- gpt3_mr_mcore_te_tp1_pp2_resume_torch_dist_cp4_a2a_p2p_nondeterministic_dgx_a100_1N8G # cp and attention with a2a+p2p comm type
- environment: [lts, dev]
scope: [nightly]
platforms: [dgx_a100]
@@ -0,0 +1 @@
{"lm loss": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [10.82974, 10.85934, 10.88536, 10.78981, 10.64534, 10.56415, 9.99534, 10.13972, 10.06259, 9.71481]}, "num-zeros": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [261.0, 256.0, 258.0, 250.0, 243.0, 265.0, 254.0, 299.0, 299.0, 294.0]}, "iteration_timing_avg": 0.3993126470588235}
@@ -0,0 +1 @@
{"lm loss": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [10.85803, 10.88122, 10.85832, 10.80987, 10.66115, 10.55375, 10.01843, 10.14234, 10.05958, 9.71149]}, "num-zeros": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [244.0, 231.0, 243.0, 257.0, 247.0, 267.0, 256.0, 299.0, 318.0, 325.0]}, "iteration_timing_avg": 0.3993126470588235}
@@ -0,0 +1,54 @@
ENV_VARS:
CUDA_DEVICE_MAX_CONNECTIONS: 1
NVTE_ALLOW_NONDETERMINISTIC_ALGO: 1
NVTE_FUSED_ATTN: 0
NVTE_FLASH_ATTN: 1
MODEL_ARGS:
--num-layers: 12
--hidden-size: 512
--num-attention-heads: 8
--log-params-norm: true
--log-num-zeros-in-grad: true
--log-validation-ppl-to-tensorboard: true
--log-timers-to-tensorboard: true
--tensorboard-dir: ${TENSORBOARD_PATH}
--micro-batch-size: 4
--global-batch-size: 32
--seq-length: 1024
--max-position-embeddings: 1024
--train-iters: 50
--timing-log-level: 2
--lr-decay-iters: 320000
--save: ${CHECKPOINT_PATH}
--load: ${CHECKPOINT_PATH}
--data-path: ${DATA_PATH}/my-gpt3_00_text_document
--vocab-file: ${DATA_PATH}/bpe/vocab.json
--merge-file: ${DATA_PATH}/bpe/merges.txt
--split: 949,50,1
--distributed-backend: nccl
--lr: 0.00015
--lr-decay-style: cosine
--min-lr: 1.0e-5
--weight-decay: 1e-2
--clip-grad: 1.0
--lr-warmup-fraction: .01
--log-interval: 1
--save-interval: 10000
--eval-interval: 1000
--eval-iters: 10
--transformer-impl: transformer_engine
--tensor-model-parallel-size: 1
--pipeline-model-parallel-size: 2
--context-parallel-size: 4
--cp-comm-type: a2a+p2p
--hierarchical-context-parallel-sizes: 2 2
--sequence-parallel: true
--hidden-dropout: 0.0
--attention-dropout: 0.0
--no-gradient-accumulation-fusion: true
--attention-softmax-in-fp32: true
--use-mcore-models: true
--ckpt-format: torch_dist
--data-cache-path: ${DATA_CACHE_PATH}
--bf16: true
TEST_TYPE: regular
@@ -0,0 +1,55 @@
ENV_VARS:
CUDA_DEVICE_MAX_CONNECTIONS: 1
NVTE_ALLOW_NONDETERMINISTIC_ALGO: 1
NVTE_FUSED_ATTN: 0
NVTE_FLASH_ATTN: 1
MODEL_ARGS:
--num-layers: 12
--hidden-size: 512
--num-attention-heads: 8
--log-params-norm: true
--log-num-zeros-in-grad: true
--log-validation-ppl-to-tensorboard: true
--log-timers-to-tensorboard: true
--tensorboard-dir: ${TENSORBOARD_PATH}
--micro-batch-size: 4
--global-batch-size: 32
--seq-length: 1024
--max-position-embeddings: 1024
--train-iters: 100
--timing-log-level: 2
--lr-decay-iters: 320000
--save: ${CHECKPOINT_PATH}
--load: ${CHECKPOINT_PATH}
--data-path: ${DATA_PATH}/my-gpt3_00_text_document
--vocab-file: ${DATA_PATH}/bpe/vocab.json
--merge-file: ${DATA_PATH}/bpe/merges.txt
--split: 949,50,1
--distributed-backend: nccl
--lr: 0.00015
--lr-decay-style: cosine
--min-lr: 1.0e-5
--weight-decay: 1e-2
--clip-grad: 1.0
--lr-warmup-fraction: .01
--log-interval: 1
--save-interval: 50
--eval-interval: 1000
--eval-iters: 10
--transformer-impl: transformer_engine
--tensor-model-parallel-size: 1
--pipeline-model-parallel-size: 2
--context-parallel-size: 4
--cp-comm-type: a2a+p2p
--hierarchical-context-parallel-sizes: 2 2
--sequence-parallel: true
--hidden-dropout: 0.0
--attention-dropout: 0.0
--no-gradient-accumulation-fusion: true
--attention-softmax-in-fp32: true
--use-checkpoint-opt_param-scheduler: true
--use-mcore-models: true
--ckpt-format: torch_dist
--data-cache-path: ${DATA_CACHE_PATH}
--bf16: true
TEST_TYPE: ckpt-resume
