Merge branch 'moe_core' into 'main'
Mixture of Experts + Expert Parallel support for core

See merge request ADLR/megatron-lm!717
jaredcasper committed Oct 5, 2023
2 parents 5e7a824 + 2e30ced commit 0d609ce
Showing 20 changed files with 637 additions and 97 deletions.
18 changes: 17 additions & 1 deletion megatron/arguments.py
@@ -371,6 +371,19 @@ def validate_args(args, defaults={}):
# don't allow it to keep things simple
if not args.add_position_embedding and args.position_embedding_type != 'rope':
raise RuntimeError('--no-position-embedding is deprecated, use --position-embedding-type')

# MoE Spec check
if args.num_experts is not None:
assert args.model_spec is None, "Model Spec must be None when using MoEs"

# Expert parallelism check
if args.expert_parallel:
assert args.num_experts is not None, "--num-experts must be set to use --expert-parallel"
assert args.num_experts % args.data_parallel_size == 0, \
"Number of experts must be a multiple of data_parallel_size."
if args.tensor_model_parallel_size > 1:
assert args.sequence_parallel, \
"When using expert parallelism and tensor parallelism, sequence parallelism must be used."

# Print arguments.
_print_args("arguments", args)
@@ -412,6 +425,7 @@ def core_transformer_config_from_args(args):
kw_args['deallocate_pipeline_outputs'] = True
kw_args['pipeline_dtype'] = args.params_dtype
kw_args['batch_p2p_comm'] = not args.overlap_p2p_comm
kw_args['num_moe_experts'] = args.num_experts
if args.swiglu:
kw_args['activation_func'] = F.silu
kw_args['gated_linear_unit'] = True
@@ -841,6 +855,8 @@ def _add_training_args(parser):
help='Disable fusing gradient accumulation to weight '
'gradient computation of linear layers',
dest='gradient_accumulation_fusion')
group.add_argument('--expert-parallel', action='store_true',
help='Enable expert parallel optimization.')
return parser


@@ -1299,4 +1315,4 @@ def _add_experimental_args(parser):
'layer implementation. For more details, check the'
'`transformer_layer.py` file that details the use '
'of spec based customization.')
return parser
return parser
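
Taken together, the new checks in validate_args tie the MoE flags to the parallel layout: --num-experts must be set before --expert-parallel can be used, the expert count has to divide evenly across the data-parallel ranks, and combining expert parallelism with tensor parallelism requires sequence parallelism. A minimal standalone sketch of the same constraints (illustrative helper, not part of this commit):

    from types import SimpleNamespace

    def check_moe_args(args):
        # Mirrors the expert-parallel checks added to validate_args (illustration only).
        if args.expert_parallel:
            assert args.num_experts is not None, \
                "--num-experts must be set to use --expert-parallel"
            assert args.num_experts % args.data_parallel_size == 0, \
                "number of experts must be divisible by the data-parallel size"
            if args.tensor_model_parallel_size > 1:
                assert args.sequence_parallel, \
                    "expert parallelism + tensor parallelism requires sequence parallelism"

    # e.g. 8 experts over 4 data-parallel ranks, TP=2 with sequence parallelism enabled
    check_moe_args(SimpleNamespace(expert_parallel=True, num_experts=8,
                                   data_parallel_size=4,
                                   tensor_model_parallel_size=2,
                                   sequence_parallel=True))
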
51 changes: 39 additions & 12 deletions megatron/checkpointing.py
@@ -79,7 +79,8 @@ def ensure_directory_exists(filename):

def get_checkpoint_name(checkpoints_path, iteration, release=False,
pipeline_parallel=None,
tensor_rank=None, pipeline_rank=None):
tensor_rank=None, pipeline_rank=None,
expert_parallel=None):
"""Determine the directory name for this rank's checkpoint."""
if release:
directory = 'release'
@@ -93,6 +94,11 @@ def get_checkpoint_name(checkpoints_path, iteration, release=False,
tensor_rank = mpu.get_tensor_model_parallel_rank()
if pipeline_rank is None:
pipeline_rank = mpu.get_pipeline_model_parallel_rank()
if expert_parallel is None:
args = get_args()
expert_parallel = args.expert_parallel

data_rank = mpu.get_data_parallel_rank()

# Use both the tensor and pipeline MP rank. If using the distributed
# optimizer, then the optimizer's path must additionally include the
@@ -102,7 +108,10 @@ def get_checkpoint_name(checkpoints_path, iteration, release=False,
f'mp_rank_{tensor_rank:02d}')
else:
common_path = os.path.join(checkpoints_path, directory,
f'mp_rank_{tensor_rank:02d}_{pipeline_rank:03d}')
f'mp_rank_{tensor_rank:02d}_{pipeline_rank:03d}')

if expert_parallel:
common_path = common_path + f'_{data_rank:03d}'

return os.path.join(common_path, "model_optim_rng.pt")

@@ -114,24 +123,42 @@ def get_distributed_optimizer_checkpoint_name(model_checkpoint_name):

def find_checkpoint_rank_0(checkpoints_path, iteration, release=False):
"""Finds the checkpoint for rank 0 without knowing if we are using
pipeline parallelism or not.
pipeline parallelism/expert parallelism or not.
Since the checkpoint naming scheme changes if pipeline parallelism
is present, we need to look for both naming schemes if we don't
know if the checkpoint has pipeline parallelism.
Since the checkpoint naming scheme changes if pipeline or expert
parallelism is present, we need to look for both naming schemes if
we don't know if the checkpoint has pipeline or expert parallelism.
"""

# Look for checkpoint with no pipelining
# Look for checkpoint with no pipelining and no expert parallelism
filename = get_checkpoint_name(checkpoints_path, iteration, release,
pipeline_parallel=False,
tensor_rank=0, pipeline_rank=0,
expert_parallel=False)
if os.path.isfile(filename):
return filename

# Look for checkpoint with no pipelining and expert parallelism
filename = get_checkpoint_name(checkpoints_path, iteration, release,
pipeline_parallel=False,
tensor_rank=0, pipeline_rank=0)
tensor_rank=0, pipeline_rank=0,
expert_parallel=True)
if os.path.isfile(filename):
return filename

# Look for checkpoint with pipelining
# Look for checkpoint with pipelining and no expert parallelism
filename = get_checkpoint_name(checkpoints_path, iteration, release,
pipeline_parallel=True,
tensor_rank=0, pipeline_rank=0)
tensor_rank=0, pipeline_rank=0,
expert_parallel=False)
if os.path.isfile(filename):
return filename

# Look for checkpoint with pipelining and expert parallelism
filename = get_checkpoint_name(checkpoints_path, iteration, release,
pipeline_parallel=True,
tensor_rank=0, pipeline_rank=0,
expert_parallel=True)
if os.path.isfile(filename):
return filename

@@ -237,7 +264,8 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler):

# Collect args, model, RNG.
if not torch.distributed.is_initialized() \
or mpu.get_data_parallel_rank() == 0:
or mpu.get_data_parallel_rank() == 0 \
or args.expert_parallel:

# Arguments, iteration, and model.
state_dict = {}
@@ -606,7 +634,6 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
if 'rng_state' in state_dict:
# access rng_state for data parallel rank
if args.data_parallel_random_init:

rng_state = state_dict['rng_state'][mpu.get_data_parallel_rank()]
else:
rng_state = state_dict['rng_state'][0]
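
The practical effect on checkpoint layout: with expert parallelism enabled, each data-parallel rank owns different expert weights, so the data-parallel rank is appended to the usual mp_rank directory and every such rank writes its own model file (hence the relaxed data-parallel-rank-0 condition in save_checkpoint). A small sketch of the resulting paths, reconstructed from get_checkpoint_name and decoupled from Megatron's mpu state; the iteration directory format is assumed from the surrounding code:

    import os

    def checkpoint_path(root, iteration, tp_rank, pp_rank, dp_rank,
                        pipeline_parallel, expert_parallel):
        # Illustrative reconstruction of the directory naming; the
        # 'iter_{:07d}' prefix is assumed from the rest of checkpointing.py.
        directory = 'iter_{:07d}'.format(iteration)
        if not pipeline_parallel:
            common = os.path.join(root, directory, f'mp_rank_{tp_rank:02d}')
        else:
            common = os.path.join(root, directory,
                                  f'mp_rank_{tp_rank:02d}_{pp_rank:03d}')
        if expert_parallel:
            common += f'_{dp_rank:03d}'
        return os.path.join(common, 'model_optim_rng.pt')

    # TP rank 1, PP rank 2, DP rank 3 with pipeline + expert parallelism:
    #   /ckpts/iter_0001000/mp_rank_01_002_003/model_optim_rng.pt
    print(checkpoint_path('/ckpts', 1000, 1, 2, 3,
                          pipeline_parallel=True, expert_parallel=True))

find_checkpoint_rank_0 then simply probes all four combinations of pipeline_parallel and expert_parallel until one of these candidate paths exists on disk.
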
5 changes: 5 additions & 0 deletions megatron/core/fusions/fused_layer_norm.py
@@ -32,10 +32,15 @@ def __init__(
persist_layer_norm=True,
sequence_parallel=False,
zero_centered_gamma=False,
normalization="LayerNorm",
):
super().__init__()

self.zero_centered_gamma = zero_centered_gamma
self.normalization = normalization
assert normalization == "LayerNorm", '({}) is not supported in ' 'FusedLayerNorm'.format(
normalization
)

# List of hiddens sizes supported in the persistent layer norm kernel
# If the hidden size is not supported, fall back to the non-persistent
9 changes: 9 additions & 0 deletions megatron/core/model_parallel_config.py
@@ -28,6 +28,8 @@ class ModelParallelConfig:
parallelizing layer norms and dropout sequentially. See Reducing Activation Recomputation in Large Transformer
Models: https://arxiv.org/abs/2205.05198 for more details. Defaults to False.
expert_parallel (bool): Distributes MoE experts across the data parallel dimension. Defaults to False.

Initialization
--------------
@@ -115,6 +117,7 @@ class ModelParallelConfig:
pipeline_model_parallel_size: int = 1
virtual_pipeline_model_parallel_size: Optional[int] = None
sequence_parallel: bool = False
expert_parallel: bool = False

# Initialization
perform_initialization: bool = True
@@ -165,3 +168,9 @@ def __post_init__(self):

if self.autocast_dtype is None:
self.autocast_dtype = self.params_dtype

if self.expert_parallel and self.tensor_model_parallel_size > 1:
if self.sequence_parallel is False:
raise ValueError(
"When using expert parallelism and tensor parallelism, sequence parallelism must be used"
)
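
The same constraint is enforced again at the config level in __post_init__, so a misconfigured ModelParallelConfig fails fast no matter how it was constructed. A hedged sketch of what that looks like from user code, assuming megatron.core is importable and the remaining constructor defaults are left untouched:

    from megatron.core import ModelParallelConfig

    # Valid: expert parallelism + tensor parallelism with sequence parallelism enabled.
    ok = ModelParallelConfig(tensor_model_parallel_size=2,
                             sequence_parallel=True,
                             expert_parallel=True)

    # Invalid: the new __post_init__ check raises because sequence parallelism is off.
    try:
        ModelParallelConfig(tensor_model_parallel_size=2,
                            sequence_parallel=False,
                            expert_parallel=True)
    except ValueError as err:
        print(err)  # "When using expert parallelism and tensor parallelism, ..."
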
52 changes: 52 additions & 0 deletions megatron/core/models/gpt/gpt_layer_specs.py
@@ -11,6 +11,7 @@
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.mlp import MLP, MLPSubmodules
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.switch_mlp import SwitchMLP
from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules

# Use this spec to use lower level Transformer Engine modules (required for fp8 training)
@@ -62,3 +63,54 @@
mlp_bda=get_bias_dropout_add,
),
)

# Use this spec to use lower level Transformer Engine modules and SwitchMLP based MoE
gpt_layer_with_transformer_engine_spec_moe = ModuleSpec(
module=TransformerLayer,
submodules=TransformerLayerSubmodules(
self_attention=ModuleSpec(
module=SelfAttention,
params={"attn_mask_type": AttnMaskType.causal},
submodules=SelfAttentionSubmodules(
linear_qkv=TELayerNormColumnParallelLinear,
dot_product_attention=TEDotProductAttention,
linear_proj=TERowParallelLinear,
),
),
self_attn_bda=get_bias_dropout_add,
pre_mlp_layernorm=FusedLayerNorm,
mlp=ModuleSpec(
module=SwitchMLP, # MOE
submodules=MLPSubmodules(
linear_fc1=ColumnParallelLinear, linear_fc2=RowParallelLinear,
),
),
mlp_bda=get_bias_dropout_add,
),
)

# Use this spec for an implementation using only modules in megatron core for MoE models
gpt_layer_local_spec_moe = ModuleSpec(
module=TransformerLayer,
submodules=TransformerLayerSubmodules(
input_layernorm=FusedLayerNorm,
self_attention=ModuleSpec(
module=SelfAttention,
params={"attn_mask_type": AttnMaskType.causal},
submodules=SelfAttentionSubmodules(
linear_qkv=ColumnParallelLinear,
dot_product_attention=DotProductAttention,
linear_proj=RowParallelLinear,
),
),
self_attn_bda=get_bias_dropout_add,
pre_mlp_layernorm=FusedLayerNorm,
mlp=ModuleSpec(
module=SwitchMLP, # MOE
submodules=MLPSubmodules(
linear_fc1=ColumnParallelLinear, linear_fc2=RowParallelLinear,
),
),
mlp_bda=get_bias_dropout_add,
),
)
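
To build an MoE transformer from these specs, the spec is handed to the core GPT model together with a TransformerConfig whose num_moe_experts is set (matching the num_moe_experts kwarg wired into core_transformer_config_from_args above). A rough usage sketch; it assumes torch.distributed and megatron.core.parallel_state have already been initialized, and the constructor argument names reflect the core API at this point in history:

    from megatron.core.models.gpt import GPTModel
    from megatron.core.models.gpt.gpt_layer_specs import gpt_layer_local_spec_moe
    from megatron.core.transformer.transformer_config import TransformerConfig

    # 4 experts per MoE layer; SwitchMLP routes each token to one of them.
    config = TransformerConfig(num_layers=2, hidden_size=128, num_attention_heads=4,
                               num_moe_experts=4, use_cpu_initialization=True)

    # Assumes initialize_model_parallel() has already run on every rank.
    model = GPTModel(config=config,
                     transformer_layer_spec=gpt_layer_local_spec_moe,
                     vocab_size=1024, max_sequence_length=512)
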
56 changes: 31 additions & 25 deletions megatron/core/parallel_state.py
@@ -22,8 +22,9 @@
# Data parallel group that the current rank belongs to.
_DATA_PARALLEL_GROUP = None
_DATA_PARALLEL_GROUP_GLOO = None
# FP8 amax reduction group.
_AMAX_REDUCTION_GROUP = None
# tensor model parallel group and data parallel group combined
# used for fp8 and moe training
_TENSOR_AND_DATA_PARALLEL_GROUP = None

_VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = None
_VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
@@ -58,7 +59,6 @@ def initialize_model_parallel(
pipeline_model_parallel_size: int = 1,
virtual_pipeline_model_parallel_size: Optional[int] = None,
pipeline_model_parallel_split_rank: Optional[int] = None,
use_fp8: bool = False,
use_sharp: bool = False,
) -> None:
"""Initialize model data parallel groups.
@@ -99,11 +99,6 @@ def initialize_model_parallel(
pipeline_model_parallel_split_rank is 3, then ranks 0-2
will be the encoder and ranks 3-7 will be the decoder.
use_fp8 (bool, default = False):
Construct GPU groups needed for FP8 training, namely for
amax reduction across the product of the data-parallel and
tensor-parallel groups.
use_sharp (bool, default = False):
Set the use of SHARP for the collective communications of
data-parallel process groups. When `True`, run barrier
@@ -271,19 +266,20 @@ def initialize_model_parallel(
if rank in ranks:
_POSITION_EMBEDDING_GLOBAL_RANKS = position_embedding_ranks

# Build the FP8 groups.
global _AMAX_REDUCTION_GROUP
assert _AMAX_REDUCTION_GROUP is None, 'FP8 amax reduction group is already initialized'
if use_fp8:
amax_group_size: int = tensor_model_parallel_size * data_parallel_size
num_amax_groups: int = world_size // amax_group_size
for i in range(num_amax_groups):
start_rank = i * amax_group_size
end_rank = (i + 1) * amax_group_size
ranks = range(start_rank, end_rank)
group = torch.distributed.new_group(ranks)
if rank in ranks:
_AMAX_REDUCTION_GROUP = group
# Build the tensor + data parallel groups.
global _TENSOR_AND_DATA_PARALLEL_GROUP
assert (
_TENSOR_AND_DATA_PARALLEL_GROUP is None
), 'Tensor + data parallel group is already initialized'
tensor_and_data_group_size: int = tensor_model_parallel_size * data_parallel_size
num_tensor_and_data_groups: int = world_size // tensor_and_data_group_size
for i in range(num_tensor_and_data_groups):
start_rank = i * tensor_and_data_group_size
end_rank = (i + 1) * tensor_and_data_group_size
ranks = range(start_rank, end_rank)
group = torch.distributed.new_group(ranks)
if rank in ranks:
_TENSOR_AND_DATA_PARALLEL_GROUP = group

# Initialize global memory buffer
# This isn't really "parallel state" but there isn't another good place to
@@ -357,8 +353,18 @@ def get_position_embedding_group():

def get_amax_reduction_group():
"""Get the FP8 amax reduction group the caller rank belongs to."""
assert _AMAX_REDUCTION_GROUP is not None, 'FP8 amax reduction group is not initialized'
return _AMAX_REDUCTION_GROUP
assert (
_TENSOR_AND_DATA_PARALLEL_GROUP is not None
), 'FP8 amax reduction group is not initialized'
return _TENSOR_AND_DATA_PARALLEL_GROUP


def get_tensor_and_data_parallel_group():
"""Get the tensor and data parallel group the caller rank belongs to."""
assert (
_TENSOR_AND_DATA_PARALLEL_GROUP is not None
), 'tensor and data parallel group is not initialized'
return _TENSOR_AND_DATA_PARALLEL_GROUP


def set_tensor_model_parallel_world_size(world_size):
@@ -633,8 +639,8 @@ def destroy_model_parallel():
_EMBEDDING_GROUP = None
global _POSITION_EMBEDDING_GROUP
_POSITION_EMBEDDING_GROUP = None
global _AMAX_REDUCTION_GROUP
_AMAX_REDUCTION_GROUP = None
global _TENSOR_AND_DATA_PARALLEL_GROUP
_TENSOR_AND_DATA_PARALLEL_GROUP = None
global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
_VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = None
global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
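
The replacement group is simply the contiguous block of ranks covering one tensor-parallel × data-parallel slab, which is why a single _TENSOR_AND_DATA_PARALLEL_GROUP can now serve both FP8 amax reduction and MoE token exchange. A quick sketch of the rank grouping for a hypothetical 16-GPU job with tensor parallel 2, pipeline parallel 2, data parallel 4:

    # Recreates the grouping arithmetic from initialize_model_parallel (illustration only).
    world_size = 16
    tensor_model_parallel_size = 2
    data_parallel_size = 4   # world_size // (tensor_parallel * pipeline_parallel)

    tensor_and_data_group_size = tensor_model_parallel_size * data_parallel_size
    num_tensor_and_data_groups = world_size // tensor_and_data_group_size
    groups = [list(range(i * tensor_and_data_group_size,
                         (i + 1) * tensor_and_data_group_size))
              for i in range(num_tensor_and_data_groups)]
    print(groups)  # [[0, 1, 2, 3, 4, 5, 6, 7], [8, 9, 10, 11, 12, 13, 14, 15]]
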
4 changes: 4 additions & 0 deletions megatron/core/tensor_parallel/__init__.py
@@ -13,7 +13,9 @@
from .mappings import (
copy_to_tensor_model_parallel_region,
gather_from_sequence_parallel_region,
gather_from_sequence_parallel_region_to_moe,
gather_from_tensor_model_parallel_region,
reduce_scatter_to_sequence_parallel_region_from_moe,
scatter_to_sequence_parallel_region,
scatter_to_tensor_model_parallel_region,
)
@@ -53,4 +55,6 @@
"split_tensor_along_last_dim",
"split_tensor_into_1d_equal_chunks",
"gather_split_1d_tensor",
"gather_from_sequence_parallel_region_to_moe",
"reduce_scatter_to_sequence_parallel_region_from_moe",
]
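
The two new mappings are the communication halves of the SwitchMLP forward pass: sequence-parallel shards of hidden states are all-gathered across the combined tensor + data parallel group so each rank can route tokens to its local experts, and the expert outputs are reduce-scattered back to the original shards. The sketch below illustrates that round trip with raw torch.distributed collectives rather than the new Megatron wrappers; the function and tensor names are placeholders, not the actual API:

    import torch
    import torch.distributed as dist

    def moe_dispatch_and_combine(local_tokens, local_experts_fn, group):
        # Pattern illustration: all-gather -> route/compute locally -> reduce-scatter.
        #   local_tokens:     this rank's sequence-parallel shard, shape [s_local, h]
        #   local_experts_fn: applies this rank's experts to the tokens routed to them,
        #                     returning zeros for tokens handled elsewhere
        #   group:            the combined tensor + data parallel process group
        world = dist.get_world_size(group=group)

        # 1. All-gather the shards so every rank sees the full token set.
        gathered = [torch.empty_like(local_tokens) for _ in range(world)]
        dist.all_gather(gathered, local_tokens, group=group)
        full_tokens = torch.cat(gathered, dim=0)            # [s_local * world, h]

        # 2. Run only the locally owned experts on their routed tokens.
        expert_output = local_experts_fn(full_tokens)

        # 3. Reduce-scatter: sum the partial outputs across ranks and hand each
        #    rank back its original sequence-parallel shard.
        combined = torch.empty_like(local_tokens)
        dist.reduce_scatter(combined, list(expert_output.chunk(world, dim=0)),
                            group=group)
        return combined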