From 4e87b4c06da7a11751d51119e55d07027156e0b8 Mon Sep 17 00:00:00 2001
From: Oliver Koenig
Date: Fri, 17 Jan 2025 08:51:07 -0800
Subject: [PATCH] ADLR/megatron-lm!2534 - refactor: Make `get_mlp_module_spec` public

---
 .../modelopt_support/gpt/model_specs.py       |  8 ++--
 megatron/core/models/gpt/gpt_layer_specs.py   | 25 ++++++++++-
 megatron/core/models/multimodal/llava_spec.py | 14 ++++---
 .../core/models/vision/vit_layer_specs.py     |  2 +-
 tests/unit_tests/models/test_gpt_model.py     | 41 ++++++++++++++++++-
 .../models/test_multimodal_projector.py       |  4 +-
 6 files changed, 78 insertions(+), 16 deletions(-)

diff --git a/megatron/core/inference/modelopt_support/gpt/model_specs.py b/megatron/core/inference/modelopt_support/gpt/model_specs.py
index 4d422bc2f3..91eada8c42 100644
--- a/megatron/core/inference/modelopt_support/gpt/model_specs.py
+++ b/megatron/core/inference/modelopt_support/gpt/model_specs.py
@@ -1,8 +1,10 @@
 # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
 
+from typing import Optional
+
 from megatron.core.extensions.transformer_engine import TEDotProductAttention, TENorm
 from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
-from megatron.core.models.gpt.gpt_layer_specs import _get_mlp_module_spec
+from megatron.core.models.gpt.gpt_layer_specs import get_mlp_module_spec
 from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
 from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules
 from megatron.core.transformer.enums import AttnMaskType
@@ -13,7 +15,7 @@
 
 # Use this spec for ModelOpt PTQ and TensorRT-LLM export
 def get_gpt_layer_modelopt_spec(
-    num_experts: int = None,
+    num_experts: Optional[int] = None,
     moe_grouped_gemm: bool = False,
     remap_te_layernorm: bool = False,
     qk_layernorm: bool = False,
@@ -24,7 +26,7 @@ def get_gpt_layer_modelopt_spec(
     is using TENorm from Transformer-Engine. The issue is that FusedLayerNorm from apex has
     stopped supporting RMSNorm needed by llama.
     """
-    mlp = _get_mlp_module_spec(
+    mlp = get_mlp_module_spec(
         use_te=False, num_experts=num_experts, moe_grouped_gemm=moe_grouped_gemm, fp8=False
     )
     sharded_state_dict_keys_map = {}
diff --git a/megatron/core/models/gpt/gpt_layer_specs.py b/megatron/core/models/gpt/gpt_layer_specs.py
index d0e48c190c..fef7549b80 100755
--- a/megatron/core/models/gpt/gpt_layer_specs.py
+++ b/megatron/core/models/gpt/gpt_layer_specs.py
@@ -79,7 +79,7 @@ def get_gpt_layer_with_transformer_engine_spec(
             ' and will be removed soon. Please update your code accordingly.'
         )
 
-    mlp = _get_mlp_module_spec(
+    mlp = get_mlp_module_spec(
         use_te=True,
         num_experts=num_experts,
         moe_grouped_gemm=moe_grouped_gemm,
@@ -169,7 +169,7 @@ def get_gpt_layer_local_spec(
             ' and will be removed soon. Please update your code accordingly.'
         )
 
-    mlp = _get_mlp_module_spec(
+    mlp = get_mlp_module_spec(
         use_te=False,
         num_experts=num_experts,
         moe_grouped_gemm=moe_grouped_gemm,
@@ -236,6 +236,27 @@ def _get_mlp_module_spec(
     moe_grouped_gemm: Optional[bool] = False,
     fp8: Optional[str] = None,  # pylint: disable=unused-arguments
     moe_use_legacy_grouped_gemm: Optional[bool] = False,
+):
+    warnings.warn(
+        """This private function is on a deprecation track. Please switch to `get_mlp_module_spec`
+        since it will be removed in a future release."""
+    )
+
+    return get_mlp_module_spec(
+        use_te=use_te,
+        num_experts=num_experts,
+        moe_grouped_gemm=moe_grouped_gemm,
+        fp8=fp8,
+        moe_use_legacy_grouped_gemm=moe_use_legacy_grouped_gemm,
+    )
+
+
+def get_mlp_module_spec(
+    use_te: Optional[bool] = True,
+    num_experts: Optional[int] = None,
+    moe_grouped_gemm: Optional[bool] = False,
+    fp8: Optional[str] = None,  # pylint: disable=unused-arguments
+    moe_use_legacy_grouped_gemm: Optional[bool] = False,
 ) -> ModuleSpec:
     """Helper function to get module spec for MLP/MoE"""
     if fp8 is not None:
diff --git a/megatron/core/models/multimodal/llava_spec.py b/megatron/core/models/multimodal/llava_spec.py
index 09831c6e25..7ffb162a0e 100644
--- a/megatron/core/models/multimodal/llava_spec.py
+++ b/megatron/core/models/multimodal/llava_spec.py
@@ -1,4 +1,6 @@
 # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+from typing import Optional
+
 from megatron.core.extensions.transformer_engine import (
     TEDotProductAttention,
     TELayerNormColumnParallelLinear,
@@ -6,7 +8,7 @@
     TERowParallelLinear,
 )
 from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
-from megatron.core.models.gpt.gpt_layer_specs import _get_mlp_module_spec
+from megatron.core.models.gpt.gpt_layer_specs import get_mlp_module_spec
 from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
 from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules
 from megatron.core.transformer.dot_product_attention import DotProductAttention
@@ -27,15 +29,15 @@
 
     from megatron.core.transformer.torch_norm import WrappedTorchNorm
 
-    warnings.warn(f'Apex is not installed. Falling back to Torch Norm')
+    warnings.warn('Apex is not installed. Falling back to Torch Norm')
     LNImpl = WrappedTorchNorm
 
 
 def decoder_model_with_transformer_engine_default_spec(
-    num_experts: int = None, moe_grouped_gemm: bool = False, qk_layernorm: bool = False
+    num_experts: Optional[int] = None, moe_grouped_gemm: bool = False, qk_layernorm: bool = False
 ) -> ModuleSpec:
     """LLava decoder TE spec (uses Transformer Engine components)."""
-    mlp = _get_mlp_module_spec(
+    mlp = get_mlp_module_spec(
         use_te=True, num_experts=num_experts, moe_grouped_gemm=moe_grouped_gemm
     )
     return ModuleSpec(
@@ -60,10 +62,10 @@ def decoder_model_with_transformer_engine_default_spec(
 
 
 def decoder_model_with_local_default_spec(
-    num_experts: int = None, moe_grouped_gemm: bool = False, qk_layernorm: bool = False
+    num_experts: Optional[int] = None, moe_grouped_gemm: bool = False, qk_layernorm: bool = False
 ) -> ModuleSpec:
     """LLava decoder local spec."""
-    mlp = _get_mlp_module_spec(
+    mlp = get_mlp_module_spec(
         use_te=False, num_experts=num_experts, moe_grouped_gemm=moe_grouped_gemm
     )
     return ModuleSpec(
diff --git a/megatron/core/models/vision/vit_layer_specs.py b/megatron/core/models/vision/vit_layer_specs.py
index 5b39efe79f..003293513d 100644
--- a/megatron/core/models/vision/vit_layer_specs.py
+++ b/megatron/core/models/vision/vit_layer_specs.py
@@ -27,7 +27,7 @@
 
     from megatron.core.transformer.torch_norm import WrappedTorchNorm
 
-    warnings.warn(f'Apex is not installed. Falling back to Torch Norm')
+    warnings.warn('Apex is not installed. Falling back to Torch Norm')
 
     LNImpl = WrappedTorchNorm
 
diff --git a/tests/unit_tests/models/test_gpt_model.py b/tests/unit_tests/models/test_gpt_model.py
index 4894c8efe8..da756dfb64 100644
--- a/tests/unit_tests/models/test_gpt_model.py
+++ b/tests/unit_tests/models/test_gpt_model.py
@@ -1,11 +1,15 @@
 # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
 
+import inspect
 import os
 
 import pytest
 import torch
 
-from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec
+from megatron.core.models.gpt.gpt_layer_specs import (
+    get_gpt_layer_with_transformer_engine_spec,
+    get_mlp_module_spec,
+)
 from megatron.core.models.gpt.gpt_model import GPTModel
 from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed
 from megatron.core.transformer.transformer_config import TransformerConfig
@@ -59,7 +63,7 @@ def test_set_input_tensor(self):
 
     @pytest.mark.internal
     def test_post_process_forward(self):
-        config: TransformerConfig = self.gpt_model.config
+        _ = self.gpt_model.config
         sequence_length = self.gpt_model.max_sequence_length
         micro_batch_size = 2
 
@@ -79,3 +83,36 @@ def test_post_process_forward(self):
         assert logits.shape[0] == micro_batch_size
         assert logits.shape[1] == sequence_length
         assert logits.shape[2] == self.gpt_model.vocab_size
+
+
+def test_get_mlp_module_spec_interface():
+    # Get the function signature
+    sig = inspect.signature(get_mlp_module_spec)
+
+    # Define the expected signature
+    expected_params = {
+        "use_te": inspect.Parameter.POSITIONAL_OR_KEYWORD,
+        "num_experts": inspect.Parameter.POSITIONAL_OR_KEYWORD,
+        "moe_grouped_gemm": inspect.Parameter.POSITIONAL_OR_KEYWORD,
+        "fp8": inspect.Parameter.POSITIONAL_OR_KEYWORD,
+        "moe_use_legacy_grouped_gemm": inspect.Parameter.POSITIONAL_OR_KEYWORD,
+    }
+
+    expected_defaults = {
+        "use_te": True,
+        "num_experts": None,
+        "moe_grouped_gemm": False,
+        "fp8": None,
+        "moe_use_legacy_grouped_gemm": False,
+    }
+
+    # Check parameter kinds
+    for param_name, param in sig.parameters.items():
+        assert param_name in expected_params.keys(), f"Unexpected parameter: {param_name}"
+        assert param.kind is expected_params[param_name], f"Wrong kind for parameter: {param_name}"
+
+    # Check default values
+    defaults = {
+        k: v.default for k, v in sig.parameters.items() if v.default is not inspect.Parameter.empty
+    }
+    assert defaults == expected_defaults, "Default values do not match the expected ones."
diff --git a/tests/unit_tests/models/test_multimodal_projector.py b/tests/unit_tests/models/test_multimodal_projector.py
index 976dc489da..33cdbbfe2d 100644
--- a/tests/unit_tests/models/test_multimodal_projector.py
+++ b/tests/unit_tests/models/test_multimodal_projector.py
@@ -3,7 +3,7 @@
 import pytest
 import torch
 
-from megatron.core.models.gpt.gpt_layer_specs import _get_mlp_module_spec
+from megatron.core.models.gpt.gpt_layer_specs import get_mlp_module_spec
 from megatron.core.models.vision.multimodal_projector import MultimodalProjector
 from megatron.core.tensor_parallel.layers import ColumnParallelLinear
 from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed
@@ -20,7 +20,7 @@ def setup_method(self, method):
         transformer_config = TransformerConfig(
             num_layers=1, hidden_size=64, num_attention_heads=4, use_cpu_initialization=True
         )
-        mlp_layer_spec = _get_mlp_module_spec().submodules
+        mlp_layer_spec = get_mlp_module_spec().submodules
 
         affine_layer_spec = MLPSubmodules(linear_fc1=ColumnParallelLinear, linear_fc2=None)
         self.mlp = MultimodalProjector(
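
Usage note (illustrative sketch, not part of the commit): the snippet below shows how downstream code can call the helper made public by this patch, assuming a Megatron-LM checkout that already contains the change; the MoE argument values are example settings, not values required by the patch.

    # Import the helper promoted to the public API by this refactor.
    from megatron.core.models.gpt.gpt_layer_specs import get_mlp_module_spec

    # Dense (non-MoE) MLP spec built on Transformer Engine modules (use_te defaults to True).
    dense_mlp_spec = get_mlp_module_spec(use_te=True)

    # MoE variant; num_experts=8 and moe_grouped_gemm=True are illustrative settings.
    moe_mlp_spec = get_mlp_module_spec(use_te=True, num_experts=8, moe_grouped_gemm=True)

    # The old private name, _get_mlp_module_spec, still works but now emits a warning
    # and simply delegates to get_mlp_module_spec, as the gpt_layer_specs.py hunk above shows.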