Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions examples/llm_ptq/example_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,13 +312,13 @@ def get_inlined_mtp_prefixes(config: Any) -> list[str]:
def _keys_to_prefixes(keys: Iterable[str]) -> set[str]:
"""Invert separate-file MTP keys into the prefixes the exporter needs for exclude_modules.
``"mtp.fc.weight"`` → ``{"mtp"}``; ``"mtp.layers.0.q_proj.weight"`` →
``{"mtp", "mtp.layers.0"}``. Caller must filter out inlined keys; otherwise
``"model.layers.78.eh_proj.weight"`` would emit ``"model"`` as a prefix.
``{"mtp", "mtp.layers.0"}``. ``"model"`` top-level is dropped to avoid the
``"model*"`` wildcard covering the whole backbone.
"""
prefixes: set[str] = set()
for key in keys:
parts = key.split(".")
if parts:
if parts and parts[0] != "model":
prefixes.add(parts[0])
for i, part in enumerate(parts):
if part == "layers" and i + 1 < len(parts) and parts[i + 1].isdigit():
Expand Down
10 changes: 10 additions & 0 deletions tests/examples/llm_ptq/test_example_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,16 @@ def test_load_mtp_weights_separate_indexed_shard(tmp_path):
assert set(orphans) == set(mtp_tensors)


def test_keys_to_prefixes_drops_model_top_level():
# nvbug 6108133: inlined keys like "model.layers.92.X" must NOT emit "model"
# as a top-level prefix (would become "model*" excluding the whole backbone).
out = example_utils._keys_to_prefixes(
["model.layers.92.eh_proj.weight", "mtp.fc.weight", "mtp.layers.0.q_proj.weight"]
)
assert "model" not in out
assert out == {"mtp", "mtp.layers.0", "model.layers.92"}


def test_load_mtp_weights_no_mtp_returns_empty(tmp_path):
# Also pins the ``num_nextn_predict_layers=None`` regression: some configs
# set the field explicitly to None, which must not crash ``int(None)``.
Expand Down
Loading