
Commit 7ab6a29

dont use model_spec arg + assert changes
Signed-off-by: Abhinav Khattar <[email protected]>
aklife97 committed Oct 4, 2023
1 parent 9992794 commit 7ab6a29
Showing 3 changed files with 22 additions and 6 deletions.
11 changes: 9 additions & 2 deletions megatron/arguments.py
@@ -371,12 +371,19 @@ def validate_args(args, defaults={}):
     # don't allow it to keep things simple
     if not args.add_position_embedding and args.position_embedding_type != 'rope':
         raise RuntimeError('--no-position-embedding is deprecated, use --position-embedding-type')
 
+    # MoE Spec check
+    if args.num_experts is not None:
+        assert args.model_spec is None, "Model Spec must be None when using MoEs"
+
     # Expert parallelism check
-    if args.expert_parallel and args.tensor_model_parallel_size > 1:
+    if args.expert_parallel:
+        assert args.num_experts is not None, "num_experts must be non None to use expert-parallel"
         assert args.num_experts % args.data_parallel_size == 0, \
             "Number of experts should be a multiple of data parallel_size."
-        args.sequence_parallel = True
+        if args.tensor_model_parallel_size > 1:
+            assert args.sequence_parallel, \
+                "When using expert parallelism and tensor parallelism, sequence parallelism must be used."
 
     # Print arguments.
     _print_args("arguments", args)
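Editor's note: the net effect of this hunk is that the MoE check no longer allows a custom model spec, and the expert-parallel checks no longer depend on tensor parallelism being enabled. A minimal standalone sketch of the resulting validation logic, using a mock args namespace instead of Megatron's real argument parser (the helper name validate_moe_args is invented for illustration and is not part of the commit):

from types import SimpleNamespace

def validate_moe_args(args):
    # MoE spec check: a custom model spec cannot be combined with MoE.
    if args.num_experts is not None:
        assert args.model_spec is None, "Model Spec must be None when using MoEs"

    # Expert parallelism check: experts must be set and divide evenly across
    # data-parallel ranks; with tensor parallelism, sequence parallelism is required.
    if args.expert_parallel:
        assert args.num_experts is not None, "num_experts must be non None to use expert-parallel"
        assert args.num_experts % args.data_parallel_size == 0, \
            "Number of experts should be a multiple of data parallel_size."
        if args.tensor_model_parallel_size > 1:
            assert args.sequence_parallel, \
                "When using expert parallelism and tensor parallelism, sequence parallelism must be used."

# Example: 8 experts, expert parallelism on, TP=2 with sequence parallelism enabled.
args = SimpleNamespace(
    num_experts=8,
    model_spec=None,
    expert_parallel=True,
    data_parallel_size=4,
    tensor_model_parallel_size=2,
    sequence_parallel=True,
)
validate_moe_args(args)  # passes; setting sequence_parallel=False would trip the last assert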
7 changes: 5 additions & 2 deletions megatron/core/transformer/transformer_config.py
@@ -45,7 +45,7 @@ class TransformerConfig(ModelParallelConfig):
     activation_func (Callable): Activation function to use for the non-linearity in the MLP. Defaults to F.gelu.
     num_moe_experts (int): Number of experts to use for Mixture of Experts.
-                           When >1, it replaces MLP with Switch MLP. Defaults to 1 (no MoE).
+                           When set, it replaces MLP with Switch MLP. Defaults to None (no MoE).
     # initialization
     init_method (Callable): Method to initialize weights. Note that bias is always set to
@@ -147,7 +147,7 @@ class TransformerConfig(ModelParallelConfig):
     add_bias_linear: bool = True
     gated_linear_unit: bool = False
     activation_func: Callable = F.gelu
-    num_moe_experts: int = 1
+    num_moe_experts: int = None
 
     # initialization
     init_method: Callable = None
@@ -217,6 +217,9 @@ def __post_init__(self):
         if self.apply_query_key_layer_scaling:
             self.attention_softmax_in_fp32 = True
 
+        if self.expert_parallel and self.num_moe_experts is None:
+            raise ValueError(f'num_moe_experts must be non None to use expert-parallel.')
+
         if self.recompute_granularity is not None:
             if not self.recompute_granularity in ['full', 'selective']:
                 raise ValueError(
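Editor's note: the new default and the __post_init__ guard behave like the minimal standalone dataclass below. The class name MoEConfigSketch is invented for illustration; the real TransformerConfig carries many more required fields.

from dataclasses import dataclass

@dataclass
class MoEConfigSketch:
    # Mirrors the changed fields: MoE is now off by default (None instead of 1).
    num_moe_experts: int = None
    expert_parallel: bool = False

    def __post_init__(self):
        # Expert parallelism only makes sense when a number of experts is set.
        if self.expert_parallel and self.num_moe_experts is None:
            raise ValueError('num_moe_experts must be non None to use expert-parallel.')

MoEConfigSketch(num_moe_experts=8, expert_parallel=True)   # ok
# MoEConfigSketch(expert_parallel=True)                    # raises ValueError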
10 changes: 8 additions & 2 deletions pretrain_gpt_core.py
@@ -11,7 +11,10 @@
 from megatron.core import tensor_parallel
 from megatron.core.enums import ModelType
 from megatron.core.models.gpt import GPTModel
-from megatron.core.models.gpt.gpt_layer_specs import gpt_layer_with_transformer_engine_spec
+from megatron.core.models.gpt.gpt_layer_specs import (
+    gpt_layer_with_transformer_engine_spec,
+    gpt_layer_with_transformer_engine_spec_moe
+)
 from megatron.core.transformer.spec_utils import import_module
 from megatron.data.gpt_dataset import build_train_valid_test_datasets
 from megatron.training import pretrain
@@ -31,7 +34,10 @@ def model_provider(pre_process=True, post_process=True):
     if args.model_spec is not None:
         transformer_layer_spec = import_module(args.model_spec)
     else:
-        transformer_layer_spec = gpt_layer_with_transformer_engine_spec
+        if args.num_experts is None:
+            transformer_layer_spec = gpt_layer_with_transformer_engine_spec
+        else:
+            transformer_layer_spec = gpt_layer_with_transformer_engine_spec_moe
 
     print_rank_0('building GPT model ...')
     model = GPTModel(
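Editor's note: model_provider now falls back to the MoE layer spec whenever experts are requested and no explicit spec module is given. A standalone sketch of that precedence is shown below; choose_layer_spec is an invented helper and the string values are stand-ins for the real Megatron spec objects.

def choose_layer_spec(model_spec, num_experts,
                      dense_spec="gpt_layer_with_transformer_engine_spec",
                      moe_spec="gpt_layer_with_transformer_engine_spec_moe"):
    # An explicitly imported spec module always wins (and, per the new check in
    # megatron/arguments.py, it cannot be combined with MoE).
    if model_spec is not None:
        return model_spec
    # Otherwise pick the MoE spec only when experts are requested.
    return dense_spec if num_experts is None else moe_spec

assert choose_layer_spec(None, None) == "gpt_layer_with_transformer_engine_spec"
assert choose_layer_spec(None, 8) == "gpt_layer_with_transformer_engine_spec_moe"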
