diff --git a/modelopt/torch/nas/plugins/megatron.py b/modelopt/torch/nas/plugins/megatron.py
index a1c34291e..a25012709 100644
--- a/modelopt/torch/nas/plugins/megatron.py
+++ b/modelopt/torch/nas/plugins/megatron.py
@@ -710,10 +710,10 @@ class MambaNumHeadsHp(TracedHp):
 
     Need special handling for active_slice property to trim heads within each group.
     """
 
-    def __init__(
-        self, choices: Sequence[HPType], original: HPType | None = None, ngroups: int = 1
-    ) -> None:
-        super().__init__(choices, original)
+    def __init__(self, nheads: int, ngroups: int = 1) -> None:
+        """Initialize choices as multiples of ngroups."""
+        choices = [h * ngroups for h in range(1, nheads // ngroups + 1)]
+        super().__init__(choices)
         self._ngroups = ngroups
 
     @property
@@ -892,7 +892,7 @@ def _setup(self, *, hidden_size: TracedHp):
         assert self.d_inner == self.nheads * self.headdim, "d_inner must be nheads * headdim"
 
         # Register hyperparameters for Mamba heads and head dimensions
-        mamba_num_heads = MambaNumHeadsHp(list(range(1, self.nheads + 1)), ngroups=self.ngroups)
+        mamba_num_heads = MambaNumHeadsHp(self.nheads, self.ngroups)
         mamba_head_dim = TracedHp(list(range(1, self.headdim + 1)))
         d_inner = MambaDInnerHp(mamba_num_heads, mamba_head_dim)
         bc = TracedHp([2 * self.ngroups * self.d_state])  # not configurable
@@ -967,19 +967,14 @@ def _setup(self, *, hidden_size: TracedHp):
     def modify(
         self,
         *,
-        mamba_num_heads_divisor: int = 1,
         mamba_head_dim_divisor: int = 1,
         **kwargs,  # Unused hparams
     ) -> None:
         """Modify Mamba hyperparameters."""
         # Modify MambaMixer hparams
-        for hp_name, divisor in [
-            ("mamba_num_heads", mamba_num_heads_divisor),
-            ("mamba_head_dim", mamba_head_dim_divisor),
-        ]:
-            hp = self.mixer.get_hparam(hp_name)
-            choices = {int(make_divisible(c, divisor)) for c in hp.choices}
-            hp.choices = list(set(hp.choices) & choices | {hp.original})
+        hp = self.mixer.get_hparam("mamba_head_dim")
+        choices = {int(make_divisible(c, mamba_head_dim_divisor)) for c in hp.choices}
+        hp.choices = list(set(hp.choices) & choices | {hp.original})
 
     def export(self):
         """Export the dynamic module to a torch.nn.Module."""
diff --git a/modelopt/torch/prune/plugins/mcore_minitron.py b/modelopt/torch/prune/plugins/mcore_minitron.py
index eb6f9051e..db6769b7b 100644
--- a/modelopt/torch/prune/plugins/mcore_minitron.py
+++ b/modelopt/torch/prune/plugins/mcore_minitron.py
@@ -98,6 +98,7 @@
     "MCoreMinitronModeDescriptor",
     "MCoreMinitronSearcher",
     "drop_mcore_language_model_layers",
+    "get_mcore_minitron_config",
 ]
 
 
@@ -310,7 +311,6 @@ def run_search(self) -> None:
     "megatron.core.models.mamba.MambaModel": {
         "hidden_size_divisor": 64,
         "ffn_hidden_size_divisor": 64,
-        "mamba_num_heads_divisor": 4,
         "mamba_head_dim_divisor": 4,
         "num_moe_experts_divisor": 1,
     }
 }
@@ -324,6 +324,29 @@ def run_search(self) -> None:
     )
 
+
+def get_mcore_minitron_config(
+    channel_divisor: int = 64,
+    mamba_head_dim_divisor: int = 4,
+    num_moe_experts_divisor: int = 1,
+) -> ModeloptBaseConfig:
+    """Get a MCoreMinitronConfig with the given channel divisor instead of default."""
+    config = MCoreMinitronConfig()
+
+    def _set_divisors(c):
+        for k, v in c.items():
+            if isinstance(v, dict):
+                _set_divisors(v)
+            elif k in ["hidden_size_divisor", "ffn_hidden_size_divisor"]:
+                c[k] = channel_divisor
+            elif k == "mamba_head_dim_divisor":
+                c[k] = mamba_head_dim_divisor
+            elif k == "num_moe_experts_divisor":
+                c[k] = num_moe_experts_divisor
+
+    _set_divisors(config)
+    return config
+
 
 def _convert_model_to_dynamic_space(
     model: nn.Module, config: ModeloptBaseConfig | None = None
 ) -> DynamicSpace:
diff --git a/tests/_test_utils/torch/nas_prune/minitron_common.py b/tests/_test_utils/torch/nas_prune/minitron_common.py
new file mode 100644
index 000000000..856edd38c
--- /dev/null
+++ b/tests/_test_utils/torch/nas_prune/minitron_common.py
@@ -0,0 +1,26 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import modelopt.torch.prune as mtp
+
+
+def prune_minitron(model, export_config, config, channel_divisor=64):
+    return mtp.prune(
+        model,
+        mode=[("mcore_minitron", mtp.mcore_minitron.get_mcore_minitron_config(channel_divisor))],
+        constraints={"export_config": export_config},
+        dummy_input=None,  # Not used
+        config=config,
+    )
diff --git a/tests/_test_utils/torch_sparsity/sparse_attention_common.py b/tests/_test_utils/torch/sparsity/sparse_attention_common.py
similarity index 100%
rename from tests/_test_utils/torch_sparsity/sparse_attention_common.py
rename to tests/_test_utils/torch/sparsity/sparse_attention_common.py
diff --git a/tests/gpu/torch/nas/plugins/test_megatron_gpt_dynamic_modules.py b/tests/gpu/torch/nas/plugins/test_megatron_gpt_dynamic_modules.py
index b675b73f0..2679d3090 100644
--- a/tests/gpu/torch/nas/plugins/test_megatron_gpt_dynamic_modules.py
+++ b/tests/gpu/torch/nas/plugins/test_megatron_gpt_dynamic_modules.py
@@ -48,6 +48,7 @@
     expand_head_indices,
 )
 from modelopt.torch.opt.utils import named_dynamic_modules, search_space_size
+from modelopt.torch.prune.plugins.mcore_minitron import get_mcore_minitron_config
 from modelopt.torch.utils.random import centroid
 
 SEED = 1234
@@ -56,13 +57,13 @@
 def _test_gpt_search_space(
     num_attention_heads, num_query_groups, activation_func, normalization, rank, size
 ):
-    channel_divisor = 64
+    channel_divisor = 4
 
     num_layers = min(size * 2, 8)
-    hidden_size = 256
-    ffn_hidden_size = 128
-    max_sequence_length = 16
-    vocab_size = 64
+    hidden_size = channel_divisor * 4
+    ffn_hidden_size = channel_divisor * 2
+    max_sequence_length = 8
+    vocab_size = 32
     batch_size = 2
 
     model = get_mcore_gpt_model(
@@ -80,7 +81,7 @@ def _test_gpt_search_space(
         normalization=normalization,
     ).cuda()
 
-    model = mtn.convert(model, "mcore_minitron")
+    model = mtn.convert(model, [("mcore_minitron", get_mcore_minitron_config(channel_divisor))])
     assert isinstance(model, _DynamicMCoreLanguageModel)
 
     for m in model.modules():
@@ -153,17 +154,17 @@ def test_expand_head_indices():
 
 
 def _test_gpt_moe_search_space(rank, size):
-    channel_divisor = 64
+    channel_divisor = 4
     num_layers = min(size * 2, 8)
-    hidden_size = 256
+    hidden_size = channel_divisor * 4
     num_attention_heads = 8
     num_query_groups = 4
-    moe_ffn_hidden_size = 128
+    moe_ffn_hidden_size = channel_divisor * 2
     num_moe_experts = 4
-    moe_shared_expert_intermediate_size = 256
-    max_sequence_length = 16
-    vocab_size = 64
+    moe_shared_expert_intermediate_size = channel_divisor * 4
+    max_sequence_length = 8
+    vocab_size = 32
     batch_size = 2
 
     model = get_mcore_gpt_model(
@@ -182,7 +183,7 @@ def _test_gpt_moe_search_space(rank, size):
         moe_shared_expert_intermediate_size=moe_shared_expert_intermediate_size,
     ).cuda()
 
-    model = mtn.convert(model, "mcore_minitron")
+    model = mtn.convert(model, [("mcore_minitron", get_mcore_minitron_config(channel_divisor))])
 
     moe = model.decoder.layers[0].mlp
     assert isinstance(moe, _DynamicMoELayer)
diff --git a/tests/gpu/torch/nas/plugins/test_megatron_mamba_dynamic_modules.py b/tests/gpu/torch/nas/plugins/test_megatron_mamba_dynamic_modules.py
index 1f6f44eb2..430b5e261 100644
--- a/tests/gpu/torch/nas/plugins/test_megatron_mamba_dynamic_modules.py
+++ b/tests/gpu/torch/nas/plugins/test_megatron_mamba_dynamic_modules.py
@@ -40,23 +40,23 @@
 )
 from modelopt.torch.nas.traced_hp import TracedHp
 from modelopt.torch.opt.utils import named_dynamic_modules, search_space_size
+from modelopt.torch.prune.plugins.mcore_minitron import get_mcore_minitron_config
 from modelopt.torch.utils.random import centroid
 
 SEED = 1234
 
 
 def _test_mamba_search_space(rank, size):
-    channel_divisor = 64
-    mamba_num_heads_divisor = 4
+    channel_divisor = 4
     mamba_head_dim_divisor = 4
 
     num_layers = size
     hybrid_override_pattern = "M" * size
-    hidden_size = 256
-    mamba_state_dim = 64
-    mamba_head_dim = 16
+    hidden_size = channel_divisor * 4
+    mamba_state_dim = channel_divisor
+    mamba_head_dim = mamba_head_dim_divisor * 2
     mamba_num_groups = 2
-    max_sequence_length = 16
+    max_sequence_length = 8
     vocab_size = 32
     batch_size = 2
 
@@ -75,7 +75,7 @@ def _test_mamba_search_space(rank, size):
     ).cuda()
 
     mamba_num_heads = model.decoder.layers[0].mixer.nheads
-    model = mtn.convert(model, "mcore_minitron")
+    model = mtn.convert(model, [("mcore_minitron", get_mcore_minitron_config(channel_divisor))])
     assert isinstance(model, _DynamicMCoreLanguageModel)
 
     if is_pipeline_first_stage():
@@ -94,7 +94,7 @@ def _test_mamba_search_space(rank, size):
     # NOTE: `search_space_size` does not reduce across TP/PP groups
     ss_size_per_pp = search_space_size(model)
 
-    num_heads_choices = mamba_num_heads // mamba_num_heads_divisor
+    num_heads_choices = mamba_num_heads // mamba_num_groups
     head_dim_choices = mamba_head_dim // mamba_head_dim_divisor
     hidden_size_choices = hidden_size // channel_divisor
     num_layers_per_pp = num_layers // size
@@ -125,7 +125,7 @@ def test_mamba_search_space():
 
 
 def test_mamba_num_heads_hp():
-    num_heads = MambaNumHeadsHp([2, 4, 6, 8], ngroups=2)  # 4 heads per group
+    num_heads = MambaNumHeadsHp(8, ngroups=2)  # 4 heads per group
     assert num_heads.choices == [2, 4, 6, 8]
     assert num_heads.active_slice == slice(8)
 
diff --git a/tests/gpu/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py b/tests/gpu/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py
index 1afcdadb3..a0d4877bb 100644
--- a/tests/gpu/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py
+++ b/tests/gpu/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py
@@ -28,11 +28,11 @@
     run_mcore_inference_with_dummy_input,
 )
 from _test_utils.torch.misc import compare_outputs, set_seed
+from _test_utils.torch.nas_prune.minitron_common import prune_minitron
 from megatron.core.parallel_state import destroy_model_parallel
 from megatron.core.transformer.identity_op import IdentityOp
 
 import modelopt.torch.nas as mtn
-import modelopt.torch.prune as mtp
 from modelopt.torch.nas.conversion import export_searchspace
 from modelopt.torch.nas.plugins.megatron import (
     NumAttentionHeadsHp,
@@ -44,19 +44,23 @@
 from modelopt.torch.prune.plugins.mcore_minitron import (
     ImportanceEstimatorRegistry,
     _convert_model_to_dynamic_space,
+    get_mcore_minitron_config,
 )
 
 SEED = 1234
 
 
 def _test_mcore_gpt_parameter_sorting(activation_func, rank, size):
+    # Use relatively bigger model here for more accurate test for sorting
+    channel_divisor = 64
+
     num_layers = size
-    hidden_size = 128
+    hidden_size = channel_divisor * 2
     num_attention_heads = 8
     num_query_groups = 4
-    ffn_hidden_size = 64
+    ffn_hidden_size = channel_divisor * 2
     max_sequence_length = 32
-    vocab_size = 128
+    vocab_size = channel_divisor * 2
     batch_size = 2
 
     model = get_mcore_gpt_model(
@@ -80,7 +84,9 @@ def _test_mcore_gpt_parameter_sorting(activation_func, rank, size):
             m.weight.data = torch.randn_like(m.weight)
     model.eval()
 
-    dynamic_space = _convert_model_to_dynamic_space(model)
+    dynamic_space = _convert_model_to_dynamic_space(
+        model, get_mcore_minitron_config(channel_divisor)
+    )
     registry = ImportanceEstimatorRegistry(model)  # register imp estimators and forward hooks
 
     # Compute activations for sorting
@@ -109,7 +115,7 @@ def _test_mcore_gpt_parameter_sorting(activation_func, rank, size):
 
 
 @pytest.mark.parametrize("activation_func", ["swiglu"])
-def test_mcore_gpt_parameter_sorting(activation_func, need_2_gpus):
+def test_mcore_gpt_parameter_sorting(activation_func):
     set_seed(SEED)
     spawn_multiprocess_job(
         size=torch.cuda.device_count(),
@@ -202,10 +208,12 @@ def _test_mcore_gpt_pruning(
     rank,
     size,
 ):
-    hidden_size = 256
-    ffn_hidden_size = 256
-    max_sequence_length = 16
-    vocab_size = 64
+    channel_divisor = 4
+
+    hidden_size = channel_divisor * 4
+    ffn_hidden_size = channel_divisor * 4
+    max_sequence_length = 8
+    vocab_size = 16
     batch_size = 2
 
     num_layers = min(size * 2, 8)
@@ -275,13 +283,7 @@ def forward_loop(m):
         assert ckpt_path is None
     else:
         config["forward_loop"] = forward_loop
-    model, pruning_scores = mtp.prune(
-        model,
-        mode="mcore_minitron",
-        constraints={"export_config": export_config},
-        dummy_input=None,  # Not used
-        config=config,
-    )
+    model, pruning_scores = prune_minitron(model, export_config, config, channel_divisor)
     if not skip_sorting:
         assert pruning_scores["layer_scores"]
         assert pruning_scores["activations_per_rank"]
@@ -317,12 +319,8 @@ def forward_loop(m):
     if ckpt_path:
         model_rerun = _get_model(initialize_megatron=False)
         model_rerun.load_state_dict(sd)
-        mtp.prune(
-            model_rerun,
-            mode="mcore_minitron",
-            constraints={"export_config": export_config},
-            dummy_input=None,  # Not used
-            config={"scores_path": ckpt_path},
+        model_rerun, pruning_scores = prune_minitron(
+            model_rerun, export_config, {"scores_path": ckpt_path}, channel_divisor
         )
 
         output_rerun = run_mcore_inference(model_rerun, prompt_tokens, pruned_hidden_size)
@@ -394,13 +392,16 @@ def test_mcore_gpt_pruning(
 
 
 def _test_mcore_gpt_moe_parameter_sorting(rank, size):
+    # Use relatively bigger model here for more accurate test for sorting
+    channel_divisor = 64
+
     num_layers = min(size * 2, 8)
-    hidden_size = 256
+    hidden_size = channel_divisor * 4
     num_attention_heads = 8
     num_query_groups = 4
-    moe_ffn_hidden_size = 128
+    moe_ffn_hidden_size = channel_divisor * 2
     num_moe_experts = 4
-    moe_shared_expert_intermediate_size = 256
+    moe_shared_expert_intermediate_size = channel_divisor * 4
     max_sequence_length = 16
     vocab_size = 64
     batch_size = 2
@@ -428,7 +429,9 @@ def _test_mcore_gpt_moe_parameter_sorting(rank, size):
             m.weight.data = torch.randn_like(m.weight)
     model.eval()
 
-    dynamic_space = _convert_model_to_dynamic_space(model)
+    dynamic_space = _convert_model_to_dynamic_space(
+        model, get_mcore_minitron_config(channel_divisor)
+    )
     registry = ImportanceEstimatorRegistry(model)  # register imp estimators and forward hooks
 
     # Compute activations for sorting
@@ -459,7 +462,7 @@ def _test_mcore_gpt_moe_parameter_sorting(rank, size):
     compare_outputs(y1, y2, rtol=1e-5, atol=1e-3)
 
 
-def test_mcore_gpt_moe_parameter_sorting(need_2_gpus):
+def test_mcore_gpt_moe_parameter_sorting():
     set_seed(SEED)
     spawn_multiprocess_job(
         size=torch.cuda.device_count(),
@@ -469,13 +472,15 @@ def _test_mcore_gpt_pruning_moe(ckpt_path, rank, size):
+    channel_divisor = 4
+
     num_layers = size
-    hidden_size = 128
-    moe_ffn_hidden_size = 128
+    hidden_size = channel_divisor * 4
+    moe_ffn_hidden_size = channel_divisor * 2
     num_moe_experts = 4
-    moe_shared_expert_intermediate_size = 256
-    max_sequence_length = 16
-    vocab_size = 64
+    moe_shared_expert_intermediate_size = channel_divisor * 4
+    max_sequence_length = 8
+    vocab_size = 16
     batch_size = 2
 
     def _get_model(initialize_megatron=True):
@@ -513,12 +518,11 @@ def forward_loop(m):
         "num_moe_experts": pruned_num_moe_experts,
     }
 
-    mtp.prune(
+    prune_minitron(
         model,
-        mode="mcore_minitron",
-        constraints={"export_config": export_config},
-        dummy_input=None,  # Not used
-        config={"scores_path": ckpt_path, "forward_loop": forward_loop},
+        export_config,
+        {"scores_path": ckpt_path, "forward_loop": forward_loop},
+        channel_divisor,
     )
 
     # Assert weights are pruned correctly
@@ -554,13 +558,7 @@ def forward_loop(m):
     # Assert re-pruning from scores_path works without running the forward loop again
     model_rerun = _get_model(initialize_megatron=False)
     model_rerun.load_state_dict(sd)
-    mtp.prune(
-        model_rerun,
-        mode="mcore_minitron",
-        constraints={"export_config": export_config},
-        dummy_input=None,  # Not used
-        config={"scores_path": ckpt_path},
-    )
+    prune_minitron(model_rerun, export_config, {"scores_path": ckpt_path}, channel_divisor)
 
     output_rerun = run_mcore_inference(model_rerun, prompt_tokens, pruned_hidden_size)
     assert torch.allclose(output, output_rerun, atol=1e-5)
diff --git a/tests/gpu/torch/prune/plugins/test_mcore_mamba_minitron_pruning.py b/tests/gpu/torch/prune/plugins/test_mcore_mamba_minitron_pruning.py
index 79a22809e..d6fa9400b 100644
--- a/tests/gpu/torch/prune/plugins/test_mcore_mamba_minitron_pruning.py
+++ b/tests/gpu/torch/prune/plugins/test_mcore_mamba_minitron_pruning.py
@@ -28,24 +28,28 @@
     run_mcore_inference_with_dummy_input,
 )
 from _test_utils.torch.misc import compare_outputs, set_seed
+from _test_utils.torch.nas_prune.minitron_common import prune_minitron
 from megatron.core.ssm.mamba_layer import MambaLayer
 from megatron.core.transformer.identity_op import IdentityOp
 
 import modelopt.torch.nas as mtn
-import modelopt.torch.prune as mtp
 from modelopt.torch.prune.plugins.mcore_minitron import (
     ImportanceEstimatorRegistry,
     _convert_model_to_dynamic_space,
+    get_mcore_minitron_config,
 )
 
 SEED = 1234
 
 
 def _test_mcore_mamba_parameter_sorting(rank, size):
+    # Use relatively bigger model here for more accurate test for sorting
+    channel_divisor = 64
+
     num_layers = size
     hybrid_override_pattern = "M" * size
-    hidden_size = 256
-    mamba_state_dim = 64
+    hidden_size = channel_divisor * 4
+    mamba_state_dim = channel_divisor
     mamba_head_dim = 16
     mamba_num_groups = 2
     max_sequence_length = 32
@@ -73,7 +77,9 @@ def _test_mcore_mamba_parameter_sorting(rank, size):
             m.weight.data = torch.randn_like(m.weight)
     model.eval()
 
-    dynamic_space = _convert_model_to_dynamic_space(model)
+    dynamic_space = _convert_model_to_dynamic_space(
+        model, get_mcore_minitron_config(channel_divisor)
+    )
     registry = ImportanceEstimatorRegistry(model)  # register imp estimators and forward hooks
 
     # Compute activations for sorting
@@ -101,7 +107,7 @@ def _test_mcore_mamba_parameter_sorting(rank, size):
     compare_outputs(y1, y2, rtol=1e-5, atol=1e-3)
 
 
-def test_mcore_mamba_parameter_sorting(need_2_gpus):
+def test_mcore_mamba_parameter_sorting():
     set_seed(SEED)
     spawn_multiprocess_job(
         size=torch.cuda.device_count(),
@@ -111,15 +117,18 @@ def _test_mcore_mamba_hybrid_pruning(ckpt_path, rank, size):
+    channel_divisor = 4
+
     num_layers = min(size * 2, 8)
-    hidden_size = 256
-    ffn_hidden_size = 128
+    hidden_size = channel_divisor * 8
+    ffn_hidden_size = channel_divisor * 2
     num_attention_heads = 8
     num_query_groups = 4
-    mamba_state_dim = 64
-    mamba_head_dim = 16
+    mamba_state_dim = channel_divisor * 2
+    mamba_head_dim = channel_divisor * 2
     mamba_num_groups = 2
     num_moe_experts = 8
+    vocab_size = 32
     batch_size = 2
 
     def _get_model(initialize_megatron=True):
@@ -138,6 +147,7 @@ def _get_model(initialize_megatron=True):
            moe_ffn_hidden_size=ffn_hidden_size,
            moe_shared_expert_intermediate_size=ffn_hidden_size,
            num_moe_experts=num_moe_experts,
+           vocab_size=vocab_size,
        ).cuda()
        return model
 
@@ -176,12 +186,11 @@ def forward_loop(m):
         "moe_shared_expert_intermediate_size": pruned_ffn_hidden_size,
         "num_moe_experts": pruned_num_moe_experts,
     }
-    mtp.prune(
+    prune_minitron(
         model,
-        mode="mcore_minitron",
-        constraints={"export_config": export_config},
-        dummy_input=None,  # Not used
-        config={"forward_loop": forward_loop, "scores_path": ckpt_path},
+        export_config,
+        {"forward_loop": forward_loop, "scores_path": ckpt_path},
+        channel_divisor,
     )
 
     # Assert weights are pruned correctly
@@ -211,13 +220,7 @@ def forward_loop(m):
 
     # Assert re-pruning from scores_path works without running the forward loop again
     model = _get_model(initialize_megatron=False)
-    mtp.prune(
-        model,
-        mode="mcore_minitron",
-        constraints={"export_config": export_config},
-        dummy_input=None,  # Not used
-        config={"scores_path": ckpt_path},
-    )
+    prune_minitron(model, export_config, {"scores_path": ckpt_path}, channel_divisor)
 
 
 def test_mcore_mamba_hybrid_pruning(tmp_path):
diff --git a/tests/gpu/torch/sparsity/attention_sparsity/test_attention_sparsity_gpu.py b/tests/gpu/torch/sparsity/attention_sparsity/test_attention_sparsity_gpu.py
index d437282d6..ce762267a 100644
--- a/tests/gpu/torch/sparsity/attention_sparsity/test_attention_sparsity_gpu.py
+++ b/tests/gpu/torch/sparsity/attention_sparsity/test_attention_sparsity_gpu.py
@@ -17,7 +17,7 @@
 import pytest
 import torch
 
-from _test_utils.torch_sparsity.sparse_attention_common import (
+from _test_utils.torch.sparsity.sparse_attention_common import (
     FLASH_SKIP_SOFTMAX_DEFAULT_CFG,
     SimpleAttentionModel,
     SimpleTransformerEncoder,
diff --git a/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py b/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py
index d93e929dc..6fcad9bb8 100644
--- a/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py
+++ b/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py
@@ -20,7 +20,7 @@
 pytest.importorskip("transformers")
 
 import torch.nn as nn
-from _test_utils.torch_sparsity.sparse_attention_common import (
+from _test_utils.torch.sparsity.sparse_attention_common import (
     FLASH_SKIP_SOFTMAX_DEFAULT_CFG,
     SimpleAttentionModel,
     SimpleTransformerEncoderLayer,