modelopt/torch/nas/plugins/megatron.py (21 changes: 8 additions & 13 deletions)
@@ -710,10 +710,10 @@ class MambaNumHeadsHp(TracedHp):
     Need special handling for active_slice property to trim heads within each group.
     """
 
-    def __init__(
-        self, choices: Sequence[HPType], original: HPType | None = None, ngroups: int = 1
-    ) -> None:
-        super().__init__(choices, original)
+    def __init__(self, nheads: int, ngroups: int = 1) -> None:
+        """Initialize choices as multiples of ngroups."""
+        choices = [h * ngroups for h in range(1, nheads // ngroups + 1)]
+        super().__init__(choices)
         self._ngroups = ngroups
 
     @property
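
Note: the head-count search space is now derived inside the hyperparameter itself, so every candidate is a multiple of ngroups and each group keeps an equal number of heads. A minimal standalone sketch of the resulting choices:

    # Illustrative computation, mirroring the new __init__ above.
    nheads, ngroups = 8, 2
    choices = [h * ngroups for h in range(1, nheads // ngroups + 1)]
    assert choices == [2, 4, 6, 8]  # matches test_mamba_num_heads_hp below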
@@ -892,7 +892,7 @@ def _setup(self, *, hidden_size: TracedHp):
         assert self.d_inner == self.nheads * self.headdim, "d_inner must be nheads * headdim"
 
         # Register hyperparameters for Mamba heads and head dimensions
-        mamba_num_heads = MambaNumHeadsHp(list(range(1, self.nheads + 1)), ngroups=self.ngroups)
+        mamba_num_heads = MambaNumHeadsHp(self.nheads, self.ngroups)
         mamba_head_dim = TracedHp(list(range(1, self.headdim + 1)))
         d_inner = MambaDInnerHp(mamba_num_heads, mamba_head_dim)
         bc = TracedHp([2 * self.ngroups * self.d_state])  # not configurable
@@ -967,19 +967,14 @@ def _setup(self, *, hidden_size: TracedHp):
     def modify(
         self,
         *,
-        mamba_num_heads_divisor: int = 1,
         mamba_head_dim_divisor: int = 1,
         **kwargs,  # Unused hparams
     ) -> None:
         """Modify Mamba hyperparameters."""
         # Modify MambaMixer hparams
-        for hp_name, divisor in [
-            ("mamba_num_heads", mamba_num_heads_divisor),
-            ("mamba_head_dim", mamba_head_dim_divisor),
-        ]:
-            hp = self.mixer.get_hparam(hp_name)
-            choices = {int(make_divisible(c, divisor)) for c in hp.choices}
-            hp.choices = list(set(hp.choices) & choices | {hp.original})
+        hp = self.mixer.get_hparam("mamba_head_dim")
+        choices = {int(make_divisible(c, mamba_head_dim_divisor)) for c in hp.choices}
+        hp.choices = list(set(hp.choices) & choices | {hp.original})
 
     def export(self):
         """Export the dynamic module to a torch.nn.Module."""
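The surviving mamba_head_dim branch keeps the existing divisor-filtering idiom: round each choice to the divisor, intersect with the original choices, and always retain the original value. A standalone sketch, with make_divisible stubbed as a simplified stand-in for the real helper:

    # Hedged illustration of the filtering in modify(); make_divisible is a
    # simplified stand-in, not the actual modelopt helper.
    def make_divisible(value: int, divisor: int) -> int:
        return max(divisor, (value // divisor) * divisor)

    original = 16  # hp.original
    hp_choices = list(range(1, 17))  # e.g. head-dim choices 1..16
    rounded = {int(make_divisible(c, 4)) for c in hp_choices}
    kept = sorted(set(hp_choices) & rounded | {original})
    assert kept == [4, 8, 12, 16]

mamba_num_heads no longer takes a divisor here because MambaNumHeadsHp already constrains its choices to multiples of ngroups at construction time.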
modelopt/torch/prune/plugins/mcore_minitron.py (25 changes: 24 additions & 1 deletion)
@@ -98,6 +98,7 @@
"MCoreMinitronModeDescriptor",
"MCoreMinitronSearcher",
"drop_mcore_language_model_layers",
"get_mcore_minitron_config",
]


@@ -310,7 +311,6 @@ def run_search(self) -> None:
"megatron.core.models.mamba.MambaModel": {
"hidden_size_divisor": 64,
"ffn_hidden_size_divisor": 64,
"mamba_num_heads_divisor": 4,
"mamba_head_dim_divisor": 4,
"num_moe_experts_divisor": 1,
}
@@ -324,6 +324,29 @@ def run_search(self) -> None:
 )
 
 
+def get_mcore_minitron_config(
+    channel_divisor: int = 64,
+    mamba_head_dim_divisor: int = 4,
+    num_moe_experts_divisor: int = 1,
+) -> ModeloptBaseConfig:
+    """Get an MCoreMinitronConfig with the given divisors instead of the defaults."""
+    config = MCoreMinitronConfig()
+
+    def _set_divisors(c):
+        for k, v in c.items():
+            if isinstance(v, dict):
+                _set_divisors(v)
+            elif k in ["hidden_size_divisor", "ffn_hidden_size_divisor"]:
+                c[k] = channel_divisor
+            elif k == "mamba_head_dim_divisor":
+                c[k] = mamba_head_dim_divisor
+            elif k == "num_moe_experts_divisor":
+                c[k] = num_moe_experts_divisor
+
+    _set_divisors(config)
+    return config
+
+
 def _convert_model_to_dynamic_space(
     model: nn.Module, config: ModeloptBaseConfig | None = None
 ) -> DynamicSpace:
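This helper is what lets the tests below shrink models without touching the mode defaults. A sketch of the intended call pattern, mirroring the test changes in this PR:

    # Usage as exercised by the tests below.
    from modelopt.torch.prune.plugins.mcore_minitron import get_mcore_minitron_config

    config = get_mcore_minitron_config(channel_divisor=4)
    model = mtn.convert(model, [("mcore_minitron", config)])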
tests/_test_utils/torch/nas_prune/minitron_common.py (26 changes: 26 additions & 0 deletions)
@@ -0,0 +1,26 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import modelopt.torch.prune as mtp
+
+
+def prune_minitron(model, export_config, config, channel_divisor=64):
+    return mtp.prune(
+        model,
+        mode=[("mcore_minitron", mtp.mcore_minitron.get_mcore_minitron_config(channel_divisor))],
+        constraints={"export_config": export_config},
+        dummy_input=None,  # Not used
+        config=config,
+    )
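
A hedged example of how a test might call this helper; the export_config keys and searcher config shown are illustrative assumptions, not taken from this diff:

    # Hypothetical invocation; the key names below are assumptions.
    pruned_model, _ = prune_minitron(
        model,
        export_config={"hidden_size": 128, "ffn_hidden_size": 64},  # assumed keys
        config={"forward_loop": forward_loop},  # assumed calibration loop
        channel_divisor=4,
    )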
tests/gpu/torch/nas/plugins/test_megatron_gpt_dynamic_modules.py (27 changes: 14 additions & 13 deletions)
@@ -48,6 +48,7 @@
     expand_head_indices,
 )
 from modelopt.torch.opt.utils import named_dynamic_modules, search_space_size
+from modelopt.torch.prune.plugins.mcore_minitron import get_mcore_minitron_config
 from modelopt.torch.utils.random import centroid
 
 SEED = 1234
@@ -56,13 +57,13 @@
 def _test_gpt_search_space(
     num_attention_heads, num_query_groups, activation_func, normalization, rank, size
 ):
-    channel_divisor = 64
+    channel_divisor = 4
 
     num_layers = min(size * 2, 8)
-    hidden_size = 256
-    ffn_hidden_size = 128
-    max_sequence_length = 16
-    vocab_size = 64
+    hidden_size = channel_divisor * 4
+    ffn_hidden_size = channel_divisor * 2
+    max_sequence_length = 8
+    vocab_size = 32
     batch_size = 2
 
     model = get_mcore_gpt_model(
@@ -80,7 +81,7 @@ def _test_gpt_search_space(
         normalization=normalization,
     ).cuda()
 
-    model = mtn.convert(model, "mcore_minitron")
+    model = mtn.convert(model, [("mcore_minitron", get_mcore_minitron_config(channel_divisor))])
 
     assert isinstance(model, _DynamicMCoreLanguageModel)
     for m in model.modules():
@@ -153,17 +154,17 @@ def test_expand_head_indices():


 def _test_gpt_moe_search_space(rank, size):
-    channel_divisor = 64
+    channel_divisor = 4
 
     num_layers = min(size * 2, 8)
-    hidden_size = 256
+    hidden_size = channel_divisor * 4
     num_attention_heads = 8
     num_query_groups = 4
-    moe_ffn_hidden_size = 128
+    moe_ffn_hidden_size = channel_divisor * 2
     num_moe_experts = 4
-    moe_shared_expert_intermediate_size = 256
-    max_sequence_length = 16
-    vocab_size = 64
+    moe_shared_expert_intermediate_size = channel_divisor * 4
+    max_sequence_length = 8
+    vocab_size = 32
     batch_size = 2
 
     model = get_mcore_gpt_model(
@@ -182,7 +183,7 @@ def _test_gpt_moe_search_space(rank, size):
         moe_shared_expert_intermediate_size=moe_shared_expert_intermediate_size,
     ).cuda()
 
-    model = mtn.convert(model, "mcore_minitron")
+    model = mtn.convert(model, [("mcore_minitron", get_mcore_minitron_config(channel_divisor))])
 
     moe = model.decoder.layers[0].mlp
     assert isinstance(moe, _DynamicMoELayer)
@@ -40,23 +40,23 @@
 )
 from modelopt.torch.nas.traced_hp import TracedHp
 from modelopt.torch.opt.utils import named_dynamic_modules, search_space_size
+from modelopt.torch.prune.plugins.mcore_minitron import get_mcore_minitron_config
 from modelopt.torch.utils.random import centroid
 
 SEED = 1234
 
 
 def _test_mamba_search_space(rank, size):
-    channel_divisor = 64
-    mamba_num_heads_divisor = 4
+    channel_divisor = 4
     mamba_head_dim_divisor = 4
 
     num_layers = size
     hybrid_override_pattern = "M" * size
-    hidden_size = 256
-    mamba_state_dim = 64
-    mamba_head_dim = 16
+    hidden_size = channel_divisor * 4
+    mamba_state_dim = channel_divisor
+    mamba_head_dim = mamba_head_dim_divisor * 2
     mamba_num_groups = 2
-    max_sequence_length = 16
+    max_sequence_length = 8
     vocab_size = 32
     batch_size = 2
 
@@ -75,7 +75,7 @@ def _test_mamba_search_space(rank, size):
     ).cuda()
     mamba_num_heads = model.decoder.layers[0].mixer.nheads
 
-    model = mtn.convert(model, "mcore_minitron")
+    model = mtn.convert(model, [("mcore_minitron", get_mcore_minitron_config(channel_divisor))])
 
     assert isinstance(model, _DynamicMCoreLanguageModel)
     if is_pipeline_first_stage():
@@ -94,7 +94,7 @@ def _test_mamba_search_space(rank, size):

     # NOTE: `search_space_size` does not reduce across TP/PP groups
     ss_size_per_pp = search_space_size(model)
-    num_heads_choices = mamba_num_heads // mamba_num_heads_divisor
+    num_heads_choices = mamba_num_heads // mamba_num_groups
     head_dim_choices = mamba_head_dim // mamba_head_dim_divisor
     hidden_size_choices = hidden_size // channel_divisor
     num_layers_per_pp = num_layers // size
@@ -125,7 +125,7 @@ def test_mamba_search_space():


 def test_mamba_num_heads_hp():
-    num_heads = MambaNumHeadsHp([2, 4, 6, 8], ngroups=2)  # 4 heads per group
+    num_heads = MambaNumHeadsHp(8, ngroups=2)  # 4 heads per group
     assert num_heads.choices == [2, 4, 6, 8]
     assert num_heads.active_slice == slice(8)

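On the slice(8) assertion: with all 8 heads active, active_slice can degenerate to a plain slice. The class docstring, however, requires trimming within each group once fewer heads are active; that property is not shown in this diff, so the following is only a hedged sketch of the per-group selection it implies:

    # Illustrative only: per-group head selection for nheads=8, ngroups=2, active=4.
    # The real active_slice implementation is not part of this diff.
    nheads, ngroups, active = 8, 2, 4
    heads_per_group = nheads // ngroups   # 4
    active_per_group = active // ngroups  # 2
    indices = [
        g * heads_per_group + h
        for g in range(ngroups)
        for h in range(active_per_group)
    ]
    assert indices == [0, 1, 4, 5]  # first 2 heads kept in each group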