From 645c329d07b906464b33aad310ab9fb2b829ac09 Mon Sep 17 00:00:00 2001
From: Hongbin Liu
Date: Fri, 15 Nov 2024 02:35:27 -0800
Subject: [PATCH] ADLR/megatron-lm!2279 - Add hierarchical cp comm group

Co-authored-by: root
Co-authored-by: root
Co-authored-by: root
Co-authored-by: root
---
 .../core/extensions/transformer_engine.py    | 10 ++++
 megatron/core/model_parallel_config.py       |  8 +++
 megatron/core/parallel_state.py              | 53 ++++++++++++++++++
 .../core/transformer/transformer_config.py   |  5 +-
 megatron/training/arguments.py               | 24 ++++++++
 megatron/training/initialize.py              |  1 +
 tests/functional_tests/jet_recipes/gpt.yaml  |  2 +
 .../golden_values_dev.json                   |  1 +
 .../golden_values_lts.json                   |  1 +
 .../model_config.yaml                        | 54 ++++++++++++++++++
 .../model_config.yaml                        | 55 +++++++++++++++++++
 11 files changed, 213 insertions(+), 1 deletion(-)
 create mode 100644 tests/functional_tests/test_cases/gpt/gpt3_mr_mcore_te_tp1_pp2_cp4_a2a_p2p_nondeterministic_dgx_a100_1N8G/golden_values_dev.json
 create mode 100644 tests/functional_tests/test_cases/gpt/gpt3_mr_mcore_te_tp1_pp2_cp4_a2a_p2p_nondeterministic_dgx_a100_1N8G/golden_values_lts.json
 create mode 100644 tests/functional_tests/test_cases/gpt/gpt3_mr_mcore_te_tp1_pp2_cp4_a2a_p2p_nondeterministic_dgx_a100_1N8G/model_config.yaml
 create mode 100644 tests/functional_tests/test_cases/gpt/gpt3_mr_mcore_te_tp1_pp2_resume_torch_dist_cp4_a2a_p2p_nondeterministic_dgx_a100_1N8G/model_config.yaml

diff --git a/megatron/core/extensions/transformer_engine.py b/megatron/core/extensions/transformer_engine.py
index 7ca2cdeea5..449f0b7580 100644
--- a/megatron/core/extensions/transformer_engine.py
+++ b/megatron/core/extensions/transformer_engine.py
@@ -19,6 +19,7 @@
 from megatron.core.parallel_state import (
     get_context_parallel_global_ranks,
     get_context_parallel_group,
+    get_hierarchical_context_parallel_groups,
     get_tensor_and_expert_parallel_world_size,
     get_tensor_model_parallel_group,
     get_tensor_model_parallel_rank,
@@ -593,6 +594,15 @@ def __init__(
         if is_te_min_version("1.10.0"):
             if cp_comm_type is None:
                 extra_kwargs["cp_comm_type"] = "p2p"
+            elif cp_comm_type == "a2a+p2p":
+                assert is_te_min_version("1.12.0"), (
+                    f"Transformer-Engine v{get_te_version()} must be >= 1.12.0 to support"
+                    " hierarchical cp communication."
+                )
+                extra_kwargs["cp_comm_type"] = "a2a+p2p"
+                extra_kwargs["cp_group"] = get_hierarchical_context_parallel_groups(
+                    check_initialized=False
+                )
             else:
                 extra_kwargs["cp_comm_type"] = cp_comm_type
         else:
diff --git a/megatron/core/model_parallel_config.py b/megatron/core/model_parallel_config.py
index 936ac1edf7..ceca67c354 100644
--- a/megatron/core/model_parallel_config.py
+++ b/megatron/core/model_parallel_config.py
@@ -39,6 +39,14 @@ class ModelParallelConfig:
     context_parallel_size: int = 1
     """Splits network input along sequence dimension across GPU ranks."""
 
+    hierarchical_context_parallel_sizes: list[int] = None
+    """Degrees of the hierarchical context parallelism. Users should provide a list to specify
+    the sizes for different levels. Taking the a2a+p2p cp comm type as an example, it contains
+    groups of two levels, so the first value of the list indicates the group size of the a2a
+    communication type, and the second value indicates the group size of the p2p communication
+    type.
+ """ + expert_model_parallel_size: int = 1 """Distributes Moe Experts across sub data parallel dimension.""" diff --git a/megatron/core/parallel_state.py b/megatron/core/parallel_state.py index c2f47b0c61..d31efd9219 100644 --- a/megatron/core/parallel_state.py +++ b/megatron/core/parallel_state.py @@ -79,6 +79,8 @@ # A list of global ranks for each context parallel group to ease calculation of the # destination rank when exchanging KV/dKV between context parallel_ranks _CONTEXT_PARALLEL_GLOBAL_RANKS = None +# Hierarchical context parallel groups +_HIERARCHICAL_CONTEXT_PARALLEL_GROUPS = [] # Data parallel group information with context parallel combined. _DATA_PARALLEL_GROUP_WITH_CP = None @@ -226,6 +228,40 @@ def decompose(index, shape, stride=None): return ranks +def create_hierarchical_parallel_groups( + rank, ranks, group_size, hierarchical_group_sizes, pg_options +): + """Create hierarchical groups for one parallelism. + Taking a group size of 16 as example, so we have a total of 16 GPUs denoted by g0 ... g15. + If the hierarchical group sizes are [2,2,4], we use 2 GPUs in the first and second level + of sub-groups, and 4 GPUs in the last level of sub groups. The present function will + create 8 level-1 sub-groups, 8 level-2 sub-groups and 4 level-3 sub-groups as: + 8 level-1 sub-groups: + [g0, g1], [g2, g3], [g4, g5], [g6, g7], [g8, g9], [g10, g11], [g12, g13], [g14, g15] + 8 level-2 sub-groups: + [g0, g2], [g1, g3], [g4, g6], [g5, g7], [g8, g10], [g9, g11], [g12, g14], [g13, g15] + 4 level-3 sub-groups: + [g0, g4, g8, g12], [g1, g5, g9, g13], [g2, g6, g10, g14], [g3, g7, g11, g15] + """ + + hierarchical_groups = [] + accumulated_group_sizes = 1 + processed_group_sizes = 1 + for hierarchical_group_size in hierarchical_group_sizes: + accumulated_group_sizes *= hierarchical_group_size + for k in range(group_size // accumulated_group_sizes): + for j in range(processed_group_sizes): + global_sub_ranks = [ + ranks[j + i * processed_group_sizes + k * accumulated_group_sizes] + for i in range(hierarchical_group_size) + ] + sub_group = torch.distributed.new_group(global_sub_ranks, pg_options=pg_options) + if rank in global_sub_ranks: + hierarchical_groups.append(sub_group) + processed_group_sizes *= hierarchical_group_size + return hierarchical_groups + + class RankGenerator(object): """A class for generating rank groups for different modes of parallelism.""" @@ -356,6 +392,7 @@ def initialize_model_parallel( pipeline_model_parallel_split_rank: Optional[int] = None, use_sharp: bool = False, context_parallel_size: int = 1, + hierarchical_context_parallel_sizes: List[int] = None, expert_model_parallel_size: int = 1, nccl_communicator_config_path: Optional[str] = None, distributed_timeout_minutes: int = 30, @@ -691,6 +728,15 @@ def generator_wrapper(group_type, **kwargs): if rank in ranks: _CONTEXT_PARALLEL_GROUP = group _CONTEXT_PARALLEL_GLOBAL_RANKS = ranks + if hierarchical_context_parallel_sizes: + global _HIERARCHICAL_CONTEXT_PARALLEL_GROUPS + _HIERARCHICAL_CONTEXT_PARALLEL_GROUPS += create_hierarchical_parallel_groups( + rank, + ranks, + context_parallel_size, + hierarchical_context_parallel_sizes, + get_nccl_options('cp', nccl_comm_cfgs), + ) # Build the model-parallel groups. 
     global _MODEL_PARALLEL_GROUP
@@ -962,6 +1008,13 @@ def get_context_parallel_global_ranks(check_initialized=True):
     return _CONTEXT_PARALLEL_GLOBAL_RANKS
 
 
+def get_hierarchical_context_parallel_groups(check_initialized=True):
+    """Get the hierarchical context parallel groups the caller rank belongs to."""
+    if check_initialized:
+        assert _HIERARCHICAL_CONTEXT_PARALLEL_GROUPS is not None
+    return _HIERARCHICAL_CONTEXT_PARALLEL_GROUPS
+
+
 def get_embedding_group():
     """Get the embedding group the caller rank belongs to."""
     assert _EMBEDDING_GROUP is not None, 'embedding group is not initialized'
diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py
index d22a72d130..28c1830e63 100644
--- a/megatron/core/transformer/transformer_config.py
+++ b/megatron/core/transformer/transformer_config.py
@@ -311,13 +311,16 @@ class TransformerConfig(ModelParallelConfig):
     """Inter-gpu communication type for context parallelism.
     str: all layers share same communication type.
     List[str]: each layer has its separate communication type.
-    cp_comm_type of each layer can be "p2p" or "all_gather" or "a2a".
+    cp_comm_type of each layer can be "p2p" or "all_gather" or "a2a" or "a2a+p2p".
     "p2p": Exchange KV chunks with P2P communications in ring topology. P2P is async and can be
            overlapped with attention compute.
     "all_gather": All-gather to get full sequence of KV before attention. The all-gather is not
                   async, and cannot be overlapped.
     "a2a": Like DeepSpeed Ulysses, scatter attention heads across the CP group, and gather to get
            full sequence of QKV.
+    "a2a+p2p": A hierarchical implementation of context parallelism for attention.
+               It uses A2A communications in low-level CP groups (e.g., via NVLink),
+               and P2P communications in high-level CP groups (e.g., via IBLink).
""" #################### diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index 5791aecb04..650a713fc3 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -199,12 +199,14 @@ def validate_args(args, defaults={}): if args.rank == 0: print('using world size: {}, data-parallel size: {}, ' 'context-parallel size: {}, ' + 'hierarchical context-parallel sizes: {}' 'tensor-model-parallel size: {}, ' 'encoder-tensor-model-parallel size: {}, ' 'pipeline-model-parallel size: {}, ' 'encoder-pipeline-model-parallel size: {}'.format( args.world_size, args.data_parallel_size, args.context_parallel_size, + args.hierarchical_context_parallel_sizes, args.tensor_model_parallel_size, args.encoder_tensor_model_parallel_size, args.pipeline_model_parallel_size, @@ -216,6 +218,13 @@ def validate_args(args, defaults={}): args.pipeline_model_parallel_size -= args.encoder_pipeline_model_parallel_size assert args.pipeline_model_parallel_size > 0 + if args.hierarchical_context_parallel_sizes: + from numpy import prod + assert args.context_parallel_size == prod(args.hierarchical_context_parallel_sizes) + if "a2a+p2p" in args.cp_comm_type: + assert args.hierarchical_context_parallel_sizes is not None, \ + "--hierarchical-context-parallel-sizes must be set when a2a+p2p is used in cp comm" + # Deprecated arguments assert args.batch_size is None, '--batch-size argument is no longer ' \ 'valid, use --micro-batch-size instead' @@ -727,6 +736,9 @@ def core_transformer_config_from_args(args, config_class=None): kw_args['num_query_groups'] = None kw_args['config_logger_dir'] = args.config_logger_dir + if len(args.cp_comm_type) == 1: + kw_args['cp_comm_type'] = args.cp_comm_type[0] + # Return config. return config_class(**kw_args) @@ -1643,6 +1655,18 @@ def _add_distributed_args(parser): "It is still not in a stable release stage, and may therefore contain bugs or other potential issues.") group.add_argument('--context-parallel-size', type=int, default=1, help='Degree of context parallelism.') + group.add_argument('--cp-comm-type', nargs='+', type=str, default=["p2p"], + help='Inter-gpu communication type for context parallelism: ' + 'p2p, a2a, allgather or a2a+p2p. If a single string is provided, ' + 'all layers will share the same communication type. Users can also ' + 'specify separated types for each layer like ' + '--cp-comm-type p2p p2p a2a a2a a2a+p2p a2a+p2p') + group.add_argument('--hierarchical-context-parallel-sizes', nargs='+', type=int, default=None, + help='Degrees of the hierarchical context parallelism. Users should ' + 'provide a list to specify the sizes for different levels. ' + '--hierarchical-context-parallel-sizes 2 4 indicates every two adjacent gpus ' + 'forms the first level of cp groups and the cp ranks with the same odevity ' + 'forms the second level of cp groups.') group.add_argument('--nccl-communicator-config-path', type=str, default=None, help='Path to the yaml file with NCCL communicator ' 'configurations. 
                       'configurations. The number of min/max thread groups and thread '
diff --git a/megatron/training/initialize.py b/megatron/training/initialize.py
index 17c25e77d4..f72c1b9eb8 100644
--- a/megatron/training/initialize.py
+++ b/megatron/training/initialize.py
@@ -282,6 +282,7 @@ def _initialize_distributed(get_embedding_ranks, get_position_embedding_ranks):
             args.virtual_pipeline_model_parallel_size,
             args.pipeline_model_parallel_split_rank,
             context_parallel_size=args.context_parallel_size,
+            hierarchical_context_parallel_sizes=args.hierarchical_context_parallel_sizes,
             expert_model_parallel_size=args.expert_model_parallel_size,
             distributed_timeout_minutes=args.distributed_timeout_minutes,
             nccl_communicator_config_path=args.nccl_communicator_config_path,
diff --git a/tests/functional_tests/jet_recipes/gpt.yaml b/tests/functional_tests/jet_recipes/gpt.yaml
index 2d722adeef..3ee2581981 100644
--- a/tests/functional_tests/jet_recipes/gpt.yaml
+++ b/tests/functional_tests/jet_recipes/gpt.yaml
@@ -107,6 +107,8 @@ products:
       - gpt3_mr_mcore_te_tp2_pp2_cp2_nondeterministic_dgx_a100_1N8G # cp and attention
       - gpt3_mr_mcore_te_tp2_pp1_resume_torch_dist_cp2_nondeterministic_dgx_a100_1N8G # cp and attention
       - gpt3_mr_mcore_te_tp2_pp2_resume_torch_dist_cp2_nondeterministic_dgx_a100_1N8G # cp and attention
+      - gpt3_mr_mcore_te_tp1_pp2_cp4_a2a_p2p_nondeterministic_dgx_a100_1N8G # cp and attention with a2a+p2p comm type
+      - gpt3_mr_mcore_te_tp1_pp2_resume_torch_dist_cp4_a2a_p2p_nondeterministic_dgx_a100_1N8G # cp and attention with a2a+p2p comm type
  - environment: [lts, dev]
    scope: [nightly]
    platforms: [dgx_a100]
diff --git a/tests/functional_tests/test_cases/gpt/gpt3_mr_mcore_te_tp1_pp2_cp4_a2a_p2p_nondeterministic_dgx_a100_1N8G/golden_values_dev.json b/tests/functional_tests/test_cases/gpt/gpt3_mr_mcore_te_tp1_pp2_cp4_a2a_p2p_nondeterministic_dgx_a100_1N8G/golden_values_dev.json
new file mode 100644
index 0000000000..206d78993a
--- /dev/null
+++ b/tests/functional_tests/test_cases/gpt/gpt3_mr_mcore_te_tp1_pp2_cp4_a2a_p2p_nondeterministic_dgx_a100_1N8G/golden_values_dev.json
@@ -0,0 +1 @@
+{"lm loss": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [10.82974, 10.85934, 10.88536, 10.78981, 10.64534, 10.56415, 9.99534, 10.13972, 10.06259, 9.71481]}, "num-zeros": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [261.0, 256.0, 258.0, 250.0, 243.0, 265.0, 254.0, 299.0, 299.0, 294.0]}, "iteration_timing_avg": 0.3993126470588235}
diff --git a/tests/functional_tests/test_cases/gpt/gpt3_mr_mcore_te_tp1_pp2_cp4_a2a_p2p_nondeterministic_dgx_a100_1N8G/golden_values_lts.json b/tests/functional_tests/test_cases/gpt/gpt3_mr_mcore_te_tp1_pp2_cp4_a2a_p2p_nondeterministic_dgx_a100_1N8G/golden_values_lts.json
new file mode 100644
index 0000000000..c0c3ead53e
--- /dev/null
+++ b/tests/functional_tests/test_cases/gpt/gpt3_mr_mcore_te_tp1_pp2_cp4_a2a_p2p_nondeterministic_dgx_a100_1N8G/golden_values_lts.json
@@ -0,0 +1 @@
+{"lm loss": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [10.85803, 10.88122, 10.85832, 10.80987, 10.66115, 10.55375, 10.01843, 10.14234, 10.05958, 9.71149]}, "num-zeros": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [244.0, 231.0, 243.0, 257.0, 247.0, 267.0, 256.0, 299.0, 318.0, 325.0]}, "iteration_timing_avg": 0.3993126470588235}
diff --git a/tests/functional_tests/test_cases/gpt/gpt3_mr_mcore_te_tp1_pp2_cp4_a2a_p2p_nondeterministic_dgx_a100_1N8G/model_config.yaml b/tests/functional_tests/test_cases/gpt/gpt3_mr_mcore_te_tp1_pp2_cp4_a2a_p2p_nondeterministic_dgx_a100_1N8G/model_config.yaml
new file mode 100644
index 0000000000..4af4dd14f1
--- /dev/null
+++ b/tests/functional_tests/test_cases/gpt/gpt3_mr_mcore_te_tp1_pp2_cp4_a2a_p2p_nondeterministic_dgx_a100_1N8G/model_config.yaml
@@ -0,0 +1,54 @@
+ENV_VARS:
+  CUDA_DEVICE_MAX_CONNECTIONS: 1
+  NVTE_ALLOW_NONDETERMINISTIC_ALGO: 1
+  NVTE_FUSED_ATTN: 0
+  NVTE_FLASH_ATTN: 1
+MODEL_ARGS:
+  --num-layers: 12
+  --hidden-size: 512
+  --num-attention-heads: 8
+  --log-params-norm: true
+  --log-num-zeros-in-grad: true
+  --log-validation-ppl-to-tensorboard: true
+  --log-timers-to-tensorboard: true
+  --tensorboard-dir: ${TENSORBOARD_PATH}
+  --micro-batch-size: 4
+  --global-batch-size: 32
+  --seq-length: 1024
+  --max-position-embeddings: 1024
+  --train-iters: 50
+  --timing-log-level: 2
+  --lr-decay-iters: 320000
+  --save: ${CHECKPOINT_PATH}
+  --load: ${CHECKPOINT_PATH}
+  --data-path: ${DATA_PATH}/my-gpt3_00_text_document
+  --vocab-file: ${DATA_PATH}/bpe/vocab.json
+  --merge-file: ${DATA_PATH}/bpe/merges.txt
+  --split: 949,50,1
+  --distributed-backend: nccl
+  --lr: 0.00015
+  --lr-decay-style: cosine
+  --min-lr: 1.0e-5
+  --weight-decay: 1e-2
+  --clip-grad: 1.0
+  --lr-warmup-fraction: .01
+  --log-interval: 1
+  --save-interval: 10000
+  --eval-interval: 1000
+  --eval-iters: 10
+  --transformer-impl: transformer_engine
+  --tensor-model-parallel-size: 1
+  --pipeline-model-parallel-size: 2
+  --context-parallel-size: 4
+  --cp-comm-type: a2a+p2p
+  --hierarchical-context-parallel-sizes: 2 2
+  --sequence-parallel: true
+  --hidden-dropout: 0.0
+  --attention-dropout: 0.0
+  --no-gradient-accumulation-fusion: true
+  --attention-softmax-in-fp32: true
+  --use-mcore-models: true
+  --ckpt-format: torch_dist
+  --data-cache-path: ${DATA_CACHE_PATH}
+  --bf16: true
+TEST_TYPE: regular
diff --git a/tests/functional_tests/test_cases/gpt/gpt3_mr_mcore_te_tp1_pp2_resume_torch_dist_cp4_a2a_p2p_nondeterministic_dgx_a100_1N8G/model_config.yaml b/tests/functional_tests/test_cases/gpt/gpt3_mr_mcore_te_tp1_pp2_resume_torch_dist_cp4_a2a_p2p_nondeterministic_dgx_a100_1N8G/model_config.yaml
new file mode 100644
index 0000000000..fef1224040
--- /dev/null
+++ b/tests/functional_tests/test_cases/gpt/gpt3_mr_mcore_te_tp1_pp2_resume_torch_dist_cp4_a2a_p2p_nondeterministic_dgx_a100_1N8G/model_config.yaml
@@ -0,0 +1,55 @@
+ENV_VARS:
+  CUDA_DEVICE_MAX_CONNECTIONS: 1
+  NVTE_ALLOW_NONDETERMINISTIC_ALGO: 1
+  NVTE_FUSED_ATTN: 0
+  NVTE_FLASH_ATTN: 1
+MODEL_ARGS:
+  --num-layers: 12
+  --hidden-size: 512
+  --num-attention-heads: 8
+  --log-params-norm: true
+  --log-num-zeros-in-grad: true
+  --log-validation-ppl-to-tensorboard: true
+  --log-timers-to-tensorboard: true
+  --tensorboard-dir: ${TENSORBOARD_PATH}
+  --micro-batch-size: 4
+  --global-batch-size: 32
+  --seq-length: 1024
+  --max-position-embeddings: 1024
+  --train-iters: 100
+  --timing-log-level: 2
+  --lr-decay-iters: 320000
+  --save: ${CHECKPOINT_PATH}
+  --load: ${CHECKPOINT_PATH}
+  --data-path: ${DATA_PATH}/my-gpt3_00_text_document
+  --vocab-file: ${DATA_PATH}/bpe/vocab.json
+  --merge-file: ${DATA_PATH}/bpe/merges.txt
+  --split: 949,50,1
+  --distributed-backend: nccl
+  --lr: 0.00015
+  --lr-decay-style: cosine
+  --min-lr: 1.0e-5
+  --weight-decay: 1e-2
+  --clip-grad: 1.0
+  --lr-warmup-fraction: .01
+  --log-interval: 1
+  --save-interval: 50
+  --eval-interval: 1000
+  --eval-iters: 10
+  --transformer-impl: transformer_engine
+  --tensor-model-parallel-size: 1
+  --pipeline-model-parallel-size: 2
+  --context-parallel-size: 4
+  --cp-comm-type: a2a+p2p
+  --hierarchical-context-parallel-sizes: 2 2
+  --sequence-parallel: true
+  --hidden-dropout: 0.0
+  --attention-dropout: 0.0
+  --no-gradient-accumulation-fusion: true
+  --attention-softmax-in-fp32: true
+  --use-checkpoint-opt_param-scheduler: true
+  --use-mcore-models: true
+  --ckpt-format: torch_dist
+  --data-cache-path: ${DATA_CACHE_PATH}
+  --bf16: true
+TEST_TYPE: ckpt-resume
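
For illustration only (not part of the patch): the following standalone sketch mirrors the index math of create_hierarchical_parallel_groups from megatron/core/parallel_state.py, but collects plain rank lists instead of calling torch.distributed.new_group, so it runs without a process group. The helper name hierarchical_rank_groups is hypothetical.

# Standalone sketch of the hierarchical sub-group rank math. It assumes plain Python
# lists stand in for the torch.distributed process groups created by the real code.
def hierarchical_rank_groups(ranks, hierarchical_group_sizes):
    """Return the rank lists of every sub-group at every hierarchy level."""
    group_size = len(ranks)
    groups = []
    accumulated = 1  # product of level sizes up to and including the current level
    processed = 1    # product of level sizes of the levels already handled
    for level_size in hierarchical_group_sizes:
        accumulated *= level_size
        for k in range(group_size // accumulated):
            for j in range(processed):
                groups.append(
                    [ranks[j + i * processed + k * accumulated] for i in range(level_size)]
                )
        processed *= level_size
    return groups


if __name__ == "__main__":
    # Docstring example: 16 ranks with levels [2, 2, 4] reproduce
    # [g0, g1] ..., then [g0, g2] ..., then [g0, g4, g8, g12] ...
    for group in hierarchical_rank_groups(list(range(16)), [2, 2, 4]):
        print(group)
    # The 1N8G test case: cp size 4 split as [2, 2] (a2a inside adjacent pairs,
    # p2p across ranks of the same parity).
    print(hierarchical_rank_groups(list(range(4)), [2, 2]))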
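A minimal sketch of how the new CLI flags are expected to interact, based on the argparse definitions and validate_args() / core_transformer_config_from_args() changes above. It assumes a standalone parser with only the three flags touched by this patch, and uses math.prod instead of the numpy.prod imported in the patch.

import argparse
from math import prod

parser = argparse.ArgumentParser()
parser.add_argument('--context-parallel-size', type=int, default=1)
parser.add_argument('--cp-comm-type', nargs='+', type=str, default=["p2p"])
parser.add_argument('--hierarchical-context-parallel-sizes', nargs='+', type=int, default=None)

# Mirrors the test config: cp size 4 split into a 2-way a2a level and a 2-way p2p level.
args = parser.parse_args(
    '--context-parallel-size 4 --cp-comm-type a2a+p2p '
    '--hierarchical-context-parallel-sizes 2 2'.split()
)

# Consistency checks corresponding to the new validate_args() logic.
if "a2a+p2p" in args.cp_comm_type:
    assert args.hierarchical_context_parallel_sizes is not None
    assert args.context_parallel_size == prod(args.hierarchical_context_parallel_sizes)

# A single --cp-comm-type value is collapsed to a plain string before building the
# TransformerConfig, as in core_transformer_config_from_args().
cp_comm_type = args.cp_comm_type[0] if len(args.cp_comm_type) == 1 else args.cp_comm_type
print(cp_comm_type)  # -> "a2a+p2p"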
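A hedged usage sketch of how a caller might wire the new option through megatron.core directly, rather than via the training arguments. It assumes 8 ranks launched with torchrun (as in the 1N8G test) and an available NCCL backend; it is not a complete training setup.

import torch
from megatron.core import parallel_state

# Assumes torchrun has set the usual rank/world-size environment variables.
torch.distributed.init_process_group(backend="nccl")
torch.cuda.set_device(torch.distributed.get_rank() % torch.cuda.device_count())

parallel_state.initialize_model_parallel(
    tensor_model_parallel_size=1,
    pipeline_model_parallel_size=2,
    context_parallel_size=4,
    hierarchical_context_parallel_sizes=[2, 2],  # product must equal context_parallel_size
)

# With cp_comm_type "a2a+p2p", TEDotProductAttention passes these groups to
# Transformer Engine (>= 1.12.0) as its cp_group argument.
hierarchical_cp_groups = parallel_state.get_hierarchical_context_parallel_groups()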