Add YAML recipe system for model optimization #1013
Open: sungsooha wants to merge 18 commits into NVIDIA:main from sungsooha:add-recipe-system (base: main)
Changes from all commits (18 commits):
5ee6d3d  Add recipe system: schema, pipeline planner, experiment controller (sungsooha)
167a14b  Validate QAT recipes require training config (sungsooha)
ba73484  Narrow preset fallback to ModuleNotFoundError only (sungsooha)
60bb0eb  Narrow conftest import check to ModuleNotFoundError (sungsooha)
e8c983d  Address PR review feedback (sungsooha)
ace206a  Use model_dump for pipeline planners to preserve extra fields (sungsooha)
f2c868a  Replace Python _CFG dependency with YAML-based preset loading (sungsooha)
0622941  Address PR review feedback and improve test coverage (sungsooha)
0e7d561  Reorganize tests to mirror source file structure (sungsooha)
ee0ab8d  Mirror source directory structure in tests (sungsooha)
db59117  Fix ruff format: wrap long function signature in presets.py (sungsooha)
6acdb3b  Forward-compatible PR #1000 alignment: load_config support, list type… (sungsooha)
afcf409  Refactor: extract shared utils (make_serializable, try_import_load_co… (sungsooha)
b066c60  Move make_serializable test to test_utils.py, add coverage (sungsooha)
97f9234  Improve test quality: public API tests, remove redundancy, add coverage (sungsooha)
7d957f0  Fix flaky distributed test exit by destroying process group (sungsooha)
66a2e6e  Revert "Fix flaky distributed test exit by destroying process group" (sungsooha)
76c7614  Narrow PR to recipe schema only: remove pipeline/CLI/examples, add YA… (sungsooha)
@@ -0,0 +1,83 @@ (new file; from its contents, the recipes package `__init__.py`)

```python
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Recipe system for NVIDIA Model Optimizer.

Usage:
    from modelopt.torch.recipes import load_recipe

    # Load a recipe YAML file
    result = load_recipe("path/to/recipe.yaml")

    # For quantization recipes:
    config = result["quantize_config"]  # dict for mtq.quantize()
    model = mtq.quantize(model, config, forward_loop=forward_loop)

    # For auto-quantize recipes:
    kwargs = result["auto_quantize_kwargs"]  # kwargs for mtq.auto_quantize()
    model, state = mtq.auto_quantize(model, **kwargs, data_loader=..., ...)
"""

from __future__ import annotations

from pathlib import Path
from typing import Any

import yaml

from .schema import (
    FORMAT_REGISTRY,
    KV_FORMAT_REGISTRY,
    RecipeConfig,
    get_preset,
    get_preset_info,
    get_preset_source,
    list_presets,
    resolve_recipe,
)


def load_recipe(path: str | Path) -> dict[str, Any]:
    """Load a YAML recipe and resolve it to mtq-compatible config dicts.

    Args:
        path: Path to the recipe YAML file.

    Returns:
        A dict with keys depending on the recipe type:
        - "quantize_config": config dict for mtq.quantize()
        - "auto_quantize_kwargs": kwargs dict for mtq.auto_quantize()
        - "calibration": calibration params dict (if specified)
        - "export": export params dict (if specified)
    """
    path = Path(path)
    with open(path) as f:
        raw = yaml.safe_load(f)

    recipe = RecipeConfig.model_validate(raw)
    return resolve_recipe(recipe)


__all__ = [
    "FORMAT_REGISTRY",
    "KV_FORMAT_REGISTRY",
    "RecipeConfig",
    "get_preset",
    "get_preset_info",
    "get_preset_source",
    "list_presets",
    "load_recipe",
    "resolve_recipe",
]
```
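To make the `load_recipe()` flow concrete, here is a dependency-free sketch of what it returns for a simple quantization recipe. The recipe field names (`format`, `calibration`) and the `MiniRecipe` stand-in are illustrative assumptions, not the actual `RecipeConfig` schema; only the int8 weight/activation values are taken from the PR's fallback registry.

```python
from dataclasses import dataclass
from typing import Any, Optional

# What yaml.safe_load might return for a small quantization recipe.
# The field names here are illustrative assumptions, not the real schema.
raw = {
    "format": "int8",
    "calibration": {"num_samples": 512},
}


@dataclass
class MiniRecipe:
    """Stand-in for RecipeConfig.model_validate(): typed fields with defaults."""

    format: str
    calibration: Optional[dict] = None


def resolve(recipe: MiniRecipe) -> dict[str, Any]:
    """Mimic the shape load_recipe() returns for a quantization recipe."""
    # int8 entry from the PR's fallback registry: per-channel (axis=0)
    # weights, per-tensor (axis=None) activations.
    weight = {"num_bits": 8, "axis": 0}
    activation = {"num_bits": 8, "axis": None}
    result: dict[str, Any] = {
        "quantize_config": {
            "quant_cfg": {
                "*weight_quantizer": weight,
                "*input_quantizer": activation,
            }
        }
    }
    # Optional sections pass through only when the recipe specifies them.
    if recipe.calibration is not None:
        result["calibration"] = recipe.calibration
    return result


recipe = MiniRecipe(**raw)
result = resolve(recipe)
print(sorted(result))  # ['calibration', 'quantize_config']
```

The `quantize_config` value would then be handed to `mtq.quantize()` as shown in the module docstring above.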
@@ -0,0 +1,41 @@ (new file; from its contents, the schema subpackage `__init__.py`)

```python
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Recipe schema — models, formats, presets, and resolution."""

from .formats import FORMAT_REGISTRY, KV_FORMAT_REGISTRY
from .models import RecipeConfig
from .presets import (
    PRESET_YAML_MAP,
    get_preset,
    get_preset_info,
    get_preset_source,
    list_presets,
    load_recipe_from_yaml,
)
from .resolver import resolve_recipe

__all__ = [
    "FORMAT_REGISTRY",
    "KV_FORMAT_REGISTRY",
    "PRESET_YAML_MAP",
    "RecipeConfig",
    "get_preset",
    "get_preset_info",
    "get_preset_source",
    "list_presets",
    "load_recipe_from_yaml",
    "resolve_recipe",
]
```
@@ -0,0 +1,250 @@ (new file; from its contents, the format registry module, likely `formats.py`)

```python
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Format registry: maps human-readable format names to quantizer attribute kwargs.

Each entry has separate "weight" and "activation" defaults since they sometimes differ
(e.g., int8 weights use axis=0, activations use axis=None).

When PR #1000's load_config() is available, registries are loaded from YAML fragments
for automatic forward compatibility. Otherwise, falls back to inline definitions.
"""

from __future__ import annotations

import copy
import logging
from typing import Any

logger = logging.getLogger(__name__)

# Mapping from our format names to PR #1000's YAML fragment paths (without extension).
_FORMAT_YAML_MAP: dict[str, str] = {
    "fp8": "configs/ptq/w8a8_fp8_fp8",
    "nvfp4": "configs/ptq/w4a4_nvfp4_nvfp4",
    "int8": "configs/ptq/w8a8_int8_per_channel_int8",
    "int4": "configs/ptq/w4_int4_blockwise",
    "mxfp8": "configs/ptq/w8a8_mxfp8_mxfp8",
    "mxfp6": "configs/ptq/w6a6_mxfp6_mxfp6",
    "mxfp4": "configs/ptq/w4a4_mxfp4_mxfp4",
}

_KV_FORMAT_YAML_MAP: dict[str, str] = {
    "fp8": "configs/ptq/kv_fp8",
    "nvfp4": "configs/ptq/kv_nvfp4",
    "fp8_affine": "configs/ptq/kv_fp8_affine",
    "nvfp4_affine": "configs/ptq/kv_nvfp4_affine",
    "nvfp4_rotate": "configs/ptq/kv_nvfp4_rotate",
}

# Fallback values when PR #1000's load_config is not available.
# Uses lists (not tuples) to match PR #1000's OmegaConf output convention.
_FALLBACK_FORMAT_REGISTRY: dict[str, dict[str, dict[str, Any]]] = {
    "fp8": {
        "weight": {"num_bits": [4, 3], "axis": None},
        "activation": {"num_bits": [4, 3], "axis": None},
    },
    "nvfp4": {
        "weight": {
            "num_bits": [2, 1],
            "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": [4, 3]},
            "axis": None,
            "enable": True,
        },
        "activation": {
            "num_bits": [2, 1],
            "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": [4, 3]},
            "axis": None,
            "enable": True,
        },
    },
    "int8": {
        "weight": {"num_bits": 8, "axis": 0},
        "activation": {"num_bits": 8, "axis": None},
    },
    "int4": {
        "weight": {"num_bits": 4, "block_sizes": {-1: 128, "type": "static"}, "enable": True},
        "activation": {"enable": False},
    },
    "mxfp8": {
        "weight": {
            "num_bits": [4, 3],
            "block_sizes": {-1: 32, "type": "dynamic", "scale_bits": [8, 0]},
            "enable": True,
        },
        "activation": {
            "num_bits": [4, 3],
            "block_sizes": {-1: 32, "type": "dynamic", "scale_bits": [8, 0]},
            "enable": True,
        },
    },
    "mxfp6": {
        "weight": {
            "num_bits": [3, 2],
            "block_sizes": {-1: 32, "type": "dynamic", "scale_bits": [8, 0]},
            "enable": True,
        },
        "activation": {
            "num_bits": [3, 2],
            "block_sizes": {-1: 32, "type": "dynamic", "scale_bits": [8, 0]},
            "enable": True,
        },
    },
    "mxfp4": {
        "weight": {
            "num_bits": [2, 1],
            "block_sizes": {-1: 32, "type": "dynamic", "scale_bits": [8, 0]},
            "enable": True,
        },
        "activation": {
            "num_bits": [2, 1],
            "block_sizes": {-1: 32, "type": "dynamic", "scale_bits": [8, 0]},
            "enable": True,
        },
    },
}


_FALLBACK_KV_FORMAT_REGISTRY: dict[str, dict[str, Any]] = {
    "fp8": {
        "*[kv]_bmm_quantizer": {"num_bits": [4, 3], "axis": None, "enable": True},
        "default": {"enable": False},
    },
    "nvfp4": {
        "*[kv]_bmm_quantizer": {
            "num_bits": [2, 1],
            "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": [4, 3]},
            "axis": None,
            "enable": True,
        },
        "default": {"enable": False},
    },
    "fp8_affine": {
        "*[kv]_bmm_quantizer": {
            "num_bits": [4, 3],
            "axis": None,
            "enable": True,
            "bias": {-2: None, -4: None, "type": "static"},
        },
        "default": {"enable": False},
    },
    "nvfp4_affine": {
        "*[kv]_bmm_quantizer": {
            "num_bits": [2, 1],
            "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": [4, 3]},
            "axis": None,
            "enable": True,
            "bias": {-2: None, -4: None, "type": "static"},
        },
        "default": {"enable": False},
    },
    "nvfp4_rotate": {
        "*q_bmm_quantizer": {"enable": False, "rotate": True},
        "*k_bmm_quantizer": {
            "num_bits": [2, 1],
            "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": [4, 3]},
            "axis": None,
            "enable": True,
            "rotate": True,
        },
        "*v_bmm_quantizer": {
            "num_bits": [2, 1],
            "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": [4, 3]},
            "axis": None,
            "enable": True,
        },
        "default": {"enable": False},
    },
}


def _try_load_format_registry_from_yaml() -> dict[str, dict[str, dict[str, Any]]] | None:
    """Try to load FORMAT_REGISTRY from PR #1000's YAML fragments via load_config."""
    from modelopt.torch.recipes.utils import try_import_load_config

    load_config = try_import_load_config()
    if load_config is None:
        return None

    try:
        registry: dict[str, dict[str, dict[str, Any]]] = {}
        for name, yaml_path in _FORMAT_YAML_MAP.items():
            cfg = load_config(yaml_path)
            qcfg = cfg.get("quant_cfg", {})
            registry[name] = {
                "weight": qcfg.get("*weight_quantizer", {}),
                "activation": qcfg.get("*input_quantizer", {}),
            }
        logger.debug("Loaded FORMAT_REGISTRY from %d YAML fragments", len(registry))
        return registry
    except (ValueError, KeyError, TypeError) as exc:
        logger.debug("Failed to load FORMAT_REGISTRY from YAML: %s", exc)
        return None


def _try_load_kv_format_registry_from_yaml() -> dict[str, dict[str, Any]] | None:
    """Try to load KV_FORMAT_REGISTRY from PR #1000's YAML fragments via load_config."""
    from modelopt.torch.recipes.utils import try_import_load_config

    load_config = try_import_load_config()
    if load_config is None:
        return None

    try:
        registry: dict[str, dict[str, Any]] = {}
        for name, yaml_path in _KV_FORMAT_YAML_MAP.items():
            cfg = load_config(yaml_path)
            registry[name] = cfg.get("quant_cfg", cfg)
        logger.debug("Loaded KV_FORMAT_REGISTRY from %d YAML fragments", len(registry))
        return registry
    except (ValueError, KeyError, TypeError) as exc:
        logger.debug("Failed to load KV_FORMAT_REGISTRY from YAML: %s", exc)
        return None


def _build_format_registry() -> dict[str, dict[str, dict[str, Any]]]:
    """Build FORMAT_REGISTRY: prefer YAML fragments, fall back to inline."""
    registry = _try_load_format_registry_from_yaml()
    if registry is not None:
        return registry
    return copy.deepcopy(_FALLBACK_FORMAT_REGISTRY)


def _build_kv_format_registry() -> dict[str, dict[str, Any]]:
    """Build KV_FORMAT_REGISTRY: prefer YAML fragments, fall back to inline."""
    registry = _try_load_kv_format_registry_from_yaml()
    if registry is not None:
        return registry
    return copy.deepcopy(_FALLBACK_KV_FORMAT_REGISTRY)


# Module-level registries — loaded at import time with graceful fallback.
FORMAT_REGISTRY: dict[str, dict[str, dict[str, Any]]] = _build_format_registry()
KV_FORMAT_REGISTRY: dict[str, dict[str, Any]] = _build_kv_format_registry()


def get_format(name: str) -> dict[str, dict[str, Any]]:
    """Look up a format by name. Raises KeyError if not found."""
    if name not in FORMAT_REGISTRY:
        available = sorted(FORMAT_REGISTRY.keys())
        raise KeyError(f"Unknown format '{name}'. Available: {available}")
    return FORMAT_REGISTRY[name]


def get_kv_format(name: str) -> dict[str, Any]:
    """Look up a KV cache format by name. Raises KeyError if not found."""
    if name not in KV_FORMAT_REGISTRY:
        available = sorted(KV_FORMAT_REGISTRY.keys())
        raise KeyError(f"Unknown KV cache format '{name}'. Available: {available}")
    return KV_FORMAT_REGISTRY[name]
```