diff --git a/modelopt/torch/recipes/__init__.py b/modelopt/torch/recipes/__init__.py new file mode 100644 index 000000000..c630becea --- /dev/null +++ b/modelopt/torch/recipes/__init__.py @@ -0,0 +1,83 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Recipe system for NVIDIA Model Optimizer. + +Usage: + from modelopt.torch.recipes import load_recipe + + # Load a recipe YAML file + result = load_recipe("path/to/recipe.yaml") + + # For quantization recipes: + config = result["quantize_config"] # dict for mtq.quantize() + model = mtq.quantize(model, config, forward_loop=forward_loop) + + # For auto-quantize recipes: + kwargs = result["auto_quantize_kwargs"] # kwargs for mtq.auto_quantize() + model, state = mtq.auto_quantize(model, **kwargs, data_loader=..., ...) +""" + +from __future__ import annotations + +from pathlib import Path +from typing import Any + +import yaml + +from .schema import ( + FORMAT_REGISTRY, + KV_FORMAT_REGISTRY, + RecipeConfig, + get_preset, + get_preset_info, + get_preset_source, + list_presets, + resolve_recipe, +) + + +def load_recipe(path: str | Path) -> dict[str, Any]: + """Load a YAML recipe and resolve it to mtq-compatible config dicts. + + Args: + path: Path to the recipe YAML file. + + Returns: + A dict with keys depending on the recipe type: + - "quantize_config": config dict for mtq.quantize() + - "auto_quantize_kwargs": kwargs dict for mtq.auto_quantize() + - "calibration": calibration params dict (if specified) + - "export": export params dict (if specified) + """ + path = Path(path) + with open(path) as f: + raw = yaml.safe_load(f) + + recipe = RecipeConfig.model_validate(raw) + return resolve_recipe(recipe) + + +__all__ = [ + "FORMAT_REGISTRY", + "KV_FORMAT_REGISTRY", + "RecipeConfig", + "get_preset", + "get_preset_info", + "get_preset_source", + "list_presets", + "load_recipe", + "resolve_recipe", +] diff --git a/modelopt/torch/recipes/schema/__init__.py b/modelopt/torch/recipes/schema/__init__.py new file mode 100644 index 000000000..f499ee023 --- /dev/null +++ b/modelopt/torch/recipes/schema/__init__.py @@ -0,0 +1,41 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Recipe schema — models, formats, presets, and resolution.""" + +from .formats import FORMAT_REGISTRY, KV_FORMAT_REGISTRY +from .models import RecipeConfig +from .presets import ( + PRESET_YAML_MAP, + get_preset, + get_preset_info, + get_preset_source, + list_presets, + load_recipe_from_yaml, +) +from .resolver import resolve_recipe + +__all__ = [ + "FORMAT_REGISTRY", + "KV_FORMAT_REGISTRY", + "PRESET_YAML_MAP", + "RecipeConfig", + "get_preset", + "get_preset_info", + "get_preset_source", + "list_presets", + "load_recipe_from_yaml", + "resolve_recipe", +] diff --git a/modelopt/torch/recipes/schema/formats.py b/modelopt/torch/recipes/schema/formats.py new file mode 100644 index 000000000..9c033bb27 --- /dev/null +++ b/modelopt/torch/recipes/schema/formats.py @@ -0,0 +1,250 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Format registry: maps human-readable format names to quantizer attribute kwargs. + +Each entry has separate "weight" and "activation" defaults since they sometimes differ +(e.g., int8 weights use axis=0, activations use axis=None). + +When PR #1000's load_config() is available, registries are loaded from YAML fragments +for automatic forward compatibility. Otherwise, falls back to inline definitions. +""" + +from __future__ import annotations + +import copy +import logging +from typing import Any + +logger = logging.getLogger(__name__) + +# Mapping from our format names to PR #1000's YAML fragment paths (without extension). +_FORMAT_YAML_MAP: dict[str, str] = { + "fp8": "configs/ptq/w8a8_fp8_fp8", + "nvfp4": "configs/ptq/w4a4_nvfp4_nvfp4", + "int8": "configs/ptq/w8a8_int8_per_channel_int8", + "int4": "configs/ptq/w4_int4_blockwise", + "mxfp8": "configs/ptq/w8a8_mxfp8_mxfp8", + "mxfp6": "configs/ptq/w6a6_mxfp6_mxfp6", + "mxfp4": "configs/ptq/w4a4_mxfp4_mxfp4", +} + +_KV_FORMAT_YAML_MAP: dict[str, str] = { + "fp8": "configs/ptq/kv_fp8", + "nvfp4": "configs/ptq/kv_nvfp4", + "fp8_affine": "configs/ptq/kv_fp8_affine", + "nvfp4_affine": "configs/ptq/kv_nvfp4_affine", + "nvfp4_rotate": "configs/ptq/kv_nvfp4_rotate", +} + +# Fallback values when PR #1000's load_config is not available. +# Uses lists (not tuples) to match PR #1000's OmegaConf output convention. +_FALLBACK_FORMAT_REGISTRY: dict[str, dict[str, dict[str, Any]]] = { + "fp8": { + "weight": {"num_bits": [4, 3], "axis": None}, + "activation": {"num_bits": [4, 3], "axis": None}, + }, + "nvfp4": { + "weight": { + "num_bits": [2, 1], + "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": [4, 3]}, + "axis": None, + "enable": True, + }, + "activation": { + "num_bits": [2, 1], + "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": [4, 3]}, + "axis": None, + "enable": True, + }, + }, + "int8": { + "weight": {"num_bits": 8, "axis": 0}, + "activation": {"num_bits": 8, "axis": None}, + }, + "int4": { + "weight": {"num_bits": 4, "block_sizes": {-1: 128, "type": "static"}, "enable": True}, + "activation": {"enable": False}, + }, + "mxfp8": { + "weight": { + "num_bits": [4, 3], + "block_sizes": {-1: 32, "type": "dynamic", "scale_bits": [8, 0]}, + "enable": True, + }, + "activation": { + "num_bits": [4, 3], + "block_sizes": {-1: 32, "type": "dynamic", "scale_bits": [8, 0]}, + "enable": True, + }, + }, + "mxfp6": { + "weight": { + "num_bits": [3, 2], + "block_sizes": {-1: 32, "type": "dynamic", "scale_bits": [8, 0]}, + "enable": True, + }, + "activation": { + "num_bits": [3, 2], + "block_sizes": {-1: 32, "type": "dynamic", "scale_bits": [8, 0]}, + "enable": True, + }, + }, + "mxfp4": { + "weight": { + "num_bits": [2, 1], + "block_sizes": {-1: 32, "type": "dynamic", "scale_bits": [8, 0]}, + "enable": True, + }, + "activation": { + "num_bits": [2, 1], + "block_sizes": {-1: 32, "type": "dynamic", "scale_bits": [8, 0]}, + "enable": True, + }, + }, +} + +_FALLBACK_KV_FORMAT_REGISTRY: dict[str, dict[str, Any]] = { + "fp8": { + "*[kv]_bmm_quantizer": {"num_bits": [4, 3], "axis": None, "enable": True}, + "default": {"enable": False}, + }, + "nvfp4": { + "*[kv]_bmm_quantizer": { + "num_bits": [2, 1], + "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": [4, 3]}, + "axis": None, + "enable": True, + }, + "default": {"enable": False}, + }, + "fp8_affine": { + "*[kv]_bmm_quantizer": { + "num_bits": [4, 3], + "axis": None, + "enable": True, + "bias": {-2: None, -4: None, "type": "static"}, + }, + "default": {"enable": False}, + }, + "nvfp4_affine": { + "*[kv]_bmm_quantizer": { + "num_bits": [2, 1], + "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": [4, 3]}, + "axis": None, + "enable": True, + "bias": {-2: None, -4: None, "type": "static"}, + }, + "default": {"enable": False}, + }, + "nvfp4_rotate": { + "*q_bmm_quantizer": {"enable": False, "rotate": True}, + "*k_bmm_quantizer": { + "num_bits": [2, 1], + "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": [4, 3]}, + "axis": None, + "enable": True, + "rotate": True, + }, + "*v_bmm_quantizer": { + "num_bits": [2, 1], + "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": [4, 3]}, + "axis": None, + "enable": True, + }, + "default": {"enable": False}, + }, +} + + +def _try_load_format_registry_from_yaml() -> dict[str, dict[str, dict[str, Any]]] | None: + """Try to load FORMAT_REGISTRY from PR #1000's YAML fragments via load_config.""" + from modelopt.torch.recipes.utils import try_import_load_config + + load_config = try_import_load_config() + if load_config is None: + return None + + try: + registry: dict[str, dict[str, dict[str, Any]]] = {} + for name, yaml_path in _FORMAT_YAML_MAP.items(): + cfg = load_config(yaml_path) + qcfg = cfg.get("quant_cfg", {}) + registry[name] = { + "weight": qcfg.get("*weight_quantizer", {}), + "activation": qcfg.get("*input_quantizer", {}), + } + logger.debug("Loaded FORMAT_REGISTRY from %d YAML fragments", len(registry)) + return registry + except (ValueError, KeyError, TypeError) as exc: + logger.debug("Failed to load FORMAT_REGISTRY from YAML: %s", exc) + return None + + +def _try_load_kv_format_registry_from_yaml() -> dict[str, dict[str, Any]] | None: + """Try to load KV_FORMAT_REGISTRY from PR #1000's YAML fragments via load_config.""" + from modelopt.torch.recipes.utils import try_import_load_config + + load_config = try_import_load_config() + if load_config is None: + return None + + try: + registry: dict[str, dict[str, Any]] = {} + for name, yaml_path in _KV_FORMAT_YAML_MAP.items(): + cfg = load_config(yaml_path) + registry[name] = cfg.get("quant_cfg", cfg) + logger.debug("Loaded KV_FORMAT_REGISTRY from %d YAML fragments", len(registry)) + return registry + except (ValueError, KeyError, TypeError) as exc: + logger.debug("Failed to load KV_FORMAT_REGISTRY from YAML: %s", exc) + return None + + +def _build_format_registry() -> dict[str, dict[str, dict[str, Any]]]: + """Build FORMAT_REGISTRY: prefer YAML fragments, fall back to inline.""" + registry = _try_load_format_registry_from_yaml() + if registry is not None: + return registry + return copy.deepcopy(_FALLBACK_FORMAT_REGISTRY) + + +def _build_kv_format_registry() -> dict[str, dict[str, Any]]: + """Build KV_FORMAT_REGISTRY: prefer YAML fragments, fall back to inline.""" + registry = _try_load_kv_format_registry_from_yaml() + if registry is not None: + return registry + return copy.deepcopy(_FALLBACK_KV_FORMAT_REGISTRY) + + +# Module-level registries — loaded at import time with graceful fallback. +FORMAT_REGISTRY: dict[str, dict[str, dict[str, Any]]] = _build_format_registry() +KV_FORMAT_REGISTRY: dict[str, dict[str, Any]] = _build_kv_format_registry() + + +def get_format(name: str) -> dict[str, dict[str, Any]]: + """Look up a format by name. Raises KeyError if not found.""" + if name not in FORMAT_REGISTRY: + available = sorted(FORMAT_REGISTRY.keys()) + raise KeyError(f"Unknown format '{name}'. Available: {available}") + return FORMAT_REGISTRY[name] + + +def get_kv_format(name: str) -> dict[str, Any]: + """Look up a KV cache format by name. Raises KeyError if not found.""" + if name not in KV_FORMAT_REGISTRY: + available = sorted(KV_FORMAT_REGISTRY.keys()) + raise KeyError(f"Unknown KV cache format '{name}'. Available: {available}") + return KV_FORMAT_REGISTRY[name] diff --git a/modelopt/torch/recipes/schema/models.py b/modelopt/torch/recipes/schema/models.py new file mode 100644 index 000000000..8a4da1a32 --- /dev/null +++ b/modelopt/torch/recipes/schema/models.py @@ -0,0 +1,257 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Pydantic schema models for recipe YAML validation. + +These models define the structure of recipe YAML files. The resolver +(resolver.py) translates validated schema objects into the config dicts +that ModelOpt APIs accept (mtq.quantize, distill.convert, etc.). +""" + +from __future__ import annotations + +from typing import Any, Literal + +from pydantic import BaseModel, model_validator + + +class CalibrationConfig(BaseModel): + """Calibration data configuration.""" + + dataset: str | list[str] = "cnn_dailymail" + num_samples: int | list[int] = 512 + max_sequence_length: int = 512 + batch_size: int = 1 + + +class KVCacheConfig(BaseModel): + """KV cache quantization configuration.""" + + format: str # fp8, nvfp4 + + +class AlgorithmConfig(BaseModel): + """Quantization algorithm configuration. + + Extra fields (e.g., fp8_scale_sweep) are passed through to the algorithm dict. + """ + + method: str # max, smoothquant, awq_lite, awq_clip, awq_full, mse, etc. + alpha_step: float | None = None + max_co_batch_size: int | None = None + + model_config = {"extra": "allow"} + + +class QuantizerSpec(BaseModel): + """Specifies quantization for weights or activations.""" + + format: str | None = None # human-readable: fp8, nvfp4, int4, int8 + num_bits: int | list[int] | None = None # expert-mode escape hatch + axis: int | None = None + block_sizes: dict[str, Any] | None = None + enable: bool = True + stages: list[QuantizerSpec] | None = None # for multi-stage (W4A8) + + +class OverrideEntry(BaseModel): + """Per-layer or per-module-class override.""" + + pattern: str | None = None # glob: "*lm_head*" + module_class: str | None = None # class: "nn.LayerNorm" + enable: bool | None = None + format: str | None = None + scale_type: Literal["static", "dynamic"] | None = None # shorthand for block_sizes.type + weights: QuantizerSpec | None = None + activations: QuantizerSpec | None = None + num_bits: int | list[int] | None = None + axis: int | None = None + + @model_validator(mode="after") + def validate_has_selector(self): + """Ensure at least one of pattern or module_class is set.""" + if not self.pattern and not self.module_class: + raise ValueError("Override must specify 'pattern' or 'module_class' to target.") + return self + + +class QuantizationSection(BaseModel): + """Quantization technique configuration. + + Covers both PTQ and QAT. For QAT, set mode="qat" and provide training config. + PTQ: calibrate → quantize (minutes). QAT: quantize → fine-tune (hours, better accuracy). + """ + + mode: Literal["ptq", "qat"] = "ptq" + preset: str | None = None + weights: QuantizerSpec | None = None + activations: QuantizerSpec | None = None + algorithm: AlgorithmConfig | str | None = None + kv_cache: KVCacheConfig | None = None + calibration: CalibrationConfig | None = None + training: TrainingConfig | None = None # QAT training config + overrides: list[OverrideEntry] = [] + disabled_patterns: list[str] = [] + + @model_validator(mode="after") + def validate_preset_or_custom(self): + """Ensure preset and explicit quantizer specs are not both set.""" + if self.preset and (self.weights or self.activations): + raise ValueError( + "Cannot specify both 'preset' and 'weights'/'activations'. " + "Use preset with overrides, or specify weights/activations from scratch." + ) + if self.mode == "qat" and self.training is None: + raise ValueError("QAT mode requires a 'training' configuration.") + return self + + +class AutoQuantizeFormatEntry(BaseModel): + """A candidate format for auto-quantize search.""" + + preset: str # e.g., "nvfp4_awq", "fp8" + + +class AutoQuantizeSection(BaseModel): + """Auto-quantize configuration for per-layer format search.""" + + effective_bits: float + formats: list[AutoQuantizeFormatEntry] + method: str = "gradient" + num_calib_steps: int = 512 + num_score_steps: int = 128 + disabled_patterns: list[str] = [] + kv_cache: KVCacheConfig | None = None + calibration: CalibrationConfig | None = None + + +class TrainingConfig(BaseModel): + """Training configuration for QAT and distillation.""" + + learning_rate: float = 1e-5 + num_epochs: int = 1 + batch_size: int = 1 + max_steps: int | None = None + warmup_steps: int = 0 + weight_decay: float = 0.0 + gradient_accumulation_steps: int = 1 + + model_config = {"extra": "allow"} + + +class DistillationSection(BaseModel): + """Knowledge distillation configuration. + + Maps to modelopt.torch.distill.convert() with KDLossConfig. + """ + + teacher: str # teacher model path (e.g., "meta-llama/Llama-3-70B") + criterion: str = "kl_div" # kl_div, mse, cross_entropy + kd_loss_weight: float = 0.5 # weight for KD loss vs student loss + layer_pairs: list[dict[str, str]] | None = None # layer-wise distillation + training: TrainingConfig | None = None + calibration: CalibrationConfig | None = None + + model_config = {"extra": "allow"} + + +class SparsitySection(BaseModel): + """Sparsity configuration. + + Maps to modelopt.torch.sparsity APIs. + """ + + method: str # sparse_gpt, magnitude, wanda + sparsity: float = 0.5 # target sparsity ratio + pattern: str = "unstructured" # unstructured, 2:4 + calibration: CalibrationConfig | None = None + + model_config = {"extra": "allow"} + + @model_validator(mode="after") + def validate_sparsity_range(self): + """Ensure sparsity is in valid range (0, 1].""" + if not (0.0 < self.sparsity <= 1.0): + raise ValueError(f"Sparsity must be in (0.0, 1.0], got {self.sparsity}") + return self + + +class ExportConfig(BaseModel): + """Export configuration.""" + + format: Literal["hf", "tensorrt_llm"] = "hf" + output_dir: str = "./output" + tensor_parallel: int = 1 + pipeline_parallel: int = 1 + + +class ModelConfig(BaseModel): + """Model specification.""" + + path: str + trust_remote_code: bool = False + attn_implementation: str | None = None + + +class RecipeMetadata(BaseModel): + """Optional recipe metadata.""" + + name: str | None = None + description: str | None = None + author: str | None = None + tags: list[str] = [] + + +class PruningSection(BaseModel): + """Pruning configuration. + + Maps to modelopt.torch.prune.prune(model, mode, constraints, dummy_input, config). + Modes: fastnas, gradnas, mcore_minitron. + """ + + mode: str # fastnas, gradnas, mcore_minitron + constraints: dict[str, Any] = {} # flops, params, export_config + calibration: CalibrationConfig | None = None + training: TrainingConfig | None = None + + model_config = {"extra": "allow"} + + +class RecipeConfig(BaseModel): + """Top-level recipe schema. + + Techniques are composable — a recipe can combine quantization + distillation, + sparsity + quantization, etc. Each technique owns its own calibration/training config. + Execution order: pruning → sparsity → quantization → distillation. + """ + + version: str = "1.0" + metadata: RecipeMetadata | None = None + model: ModelConfig | None = None + pruning: PruningSection | None = None + quantization: QuantizationSection | None = None + auto_quantize: AutoQuantizeSection | None = None + distillation: DistillationSection | None = None + sparsity: SparsitySection | None = None + export: ExportConfig | None = None + + @model_validator(mode="after") + def validate_exclusive_sections(self): + """Ensure quantization and auto_quantize are mutually exclusive.""" + if self.quantization and self.auto_quantize: + raise ValueError( + "'quantization' and 'auto_quantize' are mutually exclusive. Use one or the other." + ) + return self diff --git a/modelopt/torch/recipes/schema/presets.py b/modelopt/torch/recipes/schema/presets.py new file mode 100644 index 000000000..b58a7de2d --- /dev/null +++ b/modelopt/torch/recipes/schema/presets.py @@ -0,0 +1,389 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Preset registry with tiered resolution. + +Tier 1a (preferred): Use PR #1000's load_config() when available + - Canonical OmegaConf-based __base__ resolution + - Forward-compatible: auto-adopts when PR #1000 merges + - Falls through gracefully if load_config is not yet available + +Tier 1b: Load from modelopt_recipes/ YAML fragments with our own loader + - Lightweight __base__ resolution (no OmegaConf dependency) + - Used when modelopt_recipes is installed but load_config is not available + +Tier 2 (fallback): Live import from modelopt.torch.quantization.config + - Gets preset dicts from Python constants (deprecated — team removing these) + - Both tiers are in the same repo, so at least one is always available +""" + +from __future__ import annotations + +import copy +import logging +from pathlib import Path +from typing import Any + +import yaml + +from modelopt.torch.recipes.utils import load_yaml_with_bases + +logger = logging.getLogger(__name__) + +_PRESET_REGISTRY: dict[str, dict[str, Any]] | None = None +_PRESET_METADATA: dict[str, dict[str, str]] = {} +_PRESET_SOURCE: str = "unknown" + +# Mapping from preset name to the composed recipe directory in modelopt_recipes/. +# Each entry is a directory under general/ptq/ containing model_quant.yml + kv_quant.yml. +# The model_quant.yml uses __base__ to compose atomic fragments (base + quantizer + algorithm). +PRESET_YAML_MAP: dict[str, str] = { + # Core formats + "fp8": "general/ptq/fp8_default-fp8_kv", + "fp8_pc_pt": "general/ptq/fp8_per_channel_per_token-fp8_kv", + "fp8_pb_wo": "general/ptq/fp8_2d_blockwise_weight_only-fp8_kv", + "int8": "general/ptq/int8_default-fp8_kv", + "int8_sq": "general/ptq/int8_smoothquant-fp8_kv", + "int8_wo": "general/ptq/int8_weight_only-fp8_kv", + "int4": "general/ptq/int4_blockwise_weight_only-fp8_kv", + "int4_awq": "general/ptq/int4_awq-fp8_kv", + # NVFP4 family + "nvfp4": "general/ptq/nvfp4_default-fp8_kv", + "nvfp4_awq": "general/ptq/nvfp4_awq_lite-fp8_kv", + "nvfp4_awq_lite": "general/ptq/nvfp4_awq_lite-fp8_kv", + "nvfp4_awq_clip": "general/ptq/nvfp4_awq_clip-fp8_kv", + "nvfp4_awq_full": "general/ptq/nvfp4_awq_full-fp8_kv", + "nvfp4_mse": "general/ptq/nvfp4_w4a4_weight_mse_fp8_sweep-fp8_kv", + "nvfp4_local_hessian": "general/ptq/nvfp4_w4a4_weight_local_hessian-fp8_kv", + "nvfp4_fp8_mha": "general/ptq/nvfp4_fp8_mha-fp8_kv", + "nvfp4_svdquant": "general/ptq/nvfp4_svdquant_default-fp8_kv", + "nvfp4_mlp_only": "general/ptq/nvfp4_mlp_only-fp8_kv", + "nvfp4_mlp_wo": "general/ptq/nvfp4_mlp_weight_only-fp8_kv", + "nvfp4_omlp_only": "general/ptq/nvfp4_omlp_only-fp8_kv", + # W4A8 variants + "w4a8_awq": "general/ptq/w4a8_awq_beta-fp8_kv", + "w4a8_nvfp4_fp8": "general/ptq/w4a8_nvfp4_fp8-fp8_kv", + "w4a8_mxfp4_fp8": "general/ptq/w4a8_mxfp4_fp8-fp8_kv", + # MX formats + "mxfp8": "general/ptq/mxfp8_default-fp8_kv", + "mxfp6": "general/ptq/mxfp6_default-fp8_kv", + "mxfp4": "general/ptq/mxfp4_default-fp8_kv", + "mxint8": "general/ptq/mxint8_default-fp8_kv", + "mxfp4_mlp_wo": "general/ptq/mxfp4_mlp_weight_only-fp8_kv", + # Mamba MOE — mamba_moe_fp8_aggressive is Tier-2-only (no YAML directory in PR #1000) + "mamba_moe_fp8_aggressive": "general/ptq/mamba_moe_fp8_aggressive-fp8_kv", + "mamba_moe_fp8_conservative": "general/ptq/mamba_moe_fp8_conservative-fp8_kv", + "mamba_moe_nvfp4_aggressive": "general/ptq/mamba_moe_nvfp4_aggressive-fp8_kv", + "mamba_moe_nvfp4_conservative": "general/ptq/mamba_moe_nvfp4_conservative-fp8_kv", +} + +# Presets that only exist as Python constants (no YAML directory in PR #1000). +# These are skipped during YAML loading and filled from Tier 2 instead. +_TIER2_ONLY_PRESETS: set[str] = {"mamba_moe_fp8_aggressive"} + + +def _get_load_config(): # pragma: no cover + """Try to import PR #1000's load_config and verify it works. + + Returns the function or None. Unlike try_import_load_config() (which only + checks importability), this also verifies the YAML fragments are present. + """ + from modelopt.torch.recipes.utils import try_import_load_config + + load_config = try_import_load_config() + if load_config is None: + return None + try: + load_config("configs/ptq/base") + return load_config + except (ValueError, AttributeError, TypeError): + return None + + +def load_recipe_from_yaml( + recipe_dir: str, recipes_root: Path +) -> dict[str, Any]: # pragma: no cover + """Load a composed recipe directory into a preset config dict. + + A recipe directory contains: + - model_quant.yml: main quantizer config (__base__ inheritance) + - kv_quant.yml (optional): KV cache config (__base__ inheritance) + - recipe.yml: metadata (recipe_type, description) + + Returns a dict matching the format of Python *_CFG constants: + {"quant_cfg": {...}, "algorithm": "..."} + """ + recipe_path = recipes_root / recipe_dir + + # Load model quantizer config + model_quant_path = recipe_path / "model_quant.yml" + if not model_quant_path.is_file(): + raise FileNotFoundError(f"model_quant.yml not found in {recipe_path}") + config = load_yaml_with_bases(model_quant_path, recipes_root) + + # Load KV cache config if present and merge + kv_quant_path = recipe_path / "kv_quant.yml" + if kv_quant_path.is_file(): + kv_config = load_yaml_with_bases(kv_quant_path, recipes_root) + if "quant_cfg" in kv_config: + config.setdefault("quant_cfg", {}).update(kv_config["quant_cfg"]) + + return config + + +def _load_recipe_metadata( + recipe_dir: str, recipes_root: Path +) -> dict[str, str] | None: # pragma: no cover + """Load recipe.yml metadata from a composed recipe directory.""" + recipe_yml = recipes_root / recipe_dir / "recipe.yml" + if not recipe_yml.is_file(): + return None + try: + with open(recipe_yml) as f: + data = yaml.safe_load(f) or {} + return { + "description": data.get("description", ""), + "recipe_type": data.get("recipe_type", "ptq"), + } + except (yaml.YAMLError, OSError): + return None + + +def _try_load_yaml_registry_via_load_config( + load_config_fn, +) -> dict[str, dict[str, Any]] | None: # pragma: no cover + """Tier 1a: Load presets using PR #1000's load_config (canonical OmegaConf merge).""" + registry: dict[str, dict[str, Any]] = {} + for preset_name, recipe_dir in PRESET_YAML_MAP.items(): + if preset_name in _TIER2_ONLY_PRESETS: + continue + try: + config = load_config_fn(f"{recipe_dir}/model_quant") + except (ValueError, FileNotFoundError): + logger.debug("load_config failed for preset '%s' — aborting Tier 1a", preset_name) + return None + + # Load KV quant if present + try: + kv_config = load_config_fn(f"{recipe_dir}/kv_quant") + if "quant_cfg" in kv_config: + config.setdefault("quant_cfg", {}).update(kv_config["quant_cfg"]) + except (ValueError, FileNotFoundError): + pass + + # Load recipe.yml metadata + try: + meta = load_config_fn(f"{recipe_dir}/recipe") + _PRESET_METADATA[preset_name] = { + "description": meta.get("description", ""), + "recipe_type": meta.get("recipe_type", "ptq"), + } + except (ValueError, FileNotFoundError): + pass + + registry[preset_name] = config + + return registry + + +def _try_load_yaml_registry() -> dict[str, dict[str, Any]] | None: # pragma: no cover + """Attempt to load presets from YAML fragments. + + Tries PR #1000's load_config first (Tier 1a), then our own loader (Tier 1b). + Returns the complete registry dict, or None if neither approach works. + """ + # Tier 1a: Use PR #1000's load_config (canonical, OmegaConf merge) + load_config_fn = _get_load_config() + if load_config_fn is not None: + registry = _try_load_yaml_registry_via_load_config(load_config_fn) + if registry is not None: + return registry + + # Tier 1b: Our lightweight YAML loader + try: + from importlib.resources import files + + recipes_pkg = files("modelopt_recipes") + except (ModuleNotFoundError, TypeError): + return None + + recipes_root = Path(str(recipes_pkg)) + if not recipes_root.is_dir(): + return None + + yaml_registry: dict[str, dict[str, Any]] = {} + for preset_name, recipe_dir in PRESET_YAML_MAP.items(): + if preset_name in _TIER2_ONLY_PRESETS: + continue + config = _try_load_single_yaml_preset(preset_name, recipe_dir, recipes_root) + if config is None: + return None # Partial load is worse than no load — fall through to next tier + + # Load metadata + meta = _load_recipe_metadata(recipe_dir, recipes_root) + if meta: + _PRESET_METADATA[preset_name] = meta + + yaml_registry[preset_name] = config + + return yaml_registry + + +def _try_load_single_yaml_preset( + preset_name: str, recipe_dir: str, recipes_root: Path +) -> dict[str, Any] | None: # pragma: no cover + """Load a single preset from YAML, returning None on failure.""" + try: + return load_recipe_from_yaml(recipe_dir, recipes_root) + except (FileNotFoundError, yaml.YAMLError) as exc: + logger.debug("Failed to load YAML preset '%s': %s", preset_name, exc) + return None + + +def _fill_tier2_only_presets(registry: dict[str, dict[str, Any]]) -> None: # pragma: no cover + """Load Tier-2-only presets from Python constants.""" + try: + import modelopt.torch.quantization.config as _cfg_mod + except (ImportError, ModuleNotFoundError): + return + + _tier2_attr_map: dict[str, str] = { + "mamba_moe_fp8_aggressive": "MAMBA_MOE_FP8_AGGRESSIVE_CFG", + } + for name, attr in _tier2_attr_map.items(): + if name not in registry: + try: + registry[name] = getattr(_cfg_mod, attr) + except AttributeError: + logger.debug("Tier-2-only preset '%s' not found as %s", name, attr) + + +def _try_load_python_registry() -> dict[str, dict[str, Any]] | None: # pragma: no cover + """Attempt to load presets from Python constants (deprecated). + + Returns the registry dict, or None if the constants are not available. + This tier will be removed when the team removes *_CFG constants. + """ + try: + import modelopt.torch.quantization.config as _cfg_mod + except ModuleNotFoundError: + return None + + # Build registry from Python constants. If any attribute is missing + # (e.g., team has started removing constants), fall through gracefully. + try: + return { + # Core formats + "fp8": _cfg_mod.FP8_DEFAULT_CFG, + "fp8_pc_pt": _cfg_mod.FP8_PER_CHANNEL_PER_TOKEN_CFG, + "fp8_pb_wo": _cfg_mod.FP8_2D_BLOCKWISE_WEIGHT_ONLY_CFG, + "int8": _cfg_mod.INT8_DEFAULT_CFG, + "int8_sq": _cfg_mod.INT8_SMOOTHQUANT_CFG, + "int8_wo": _cfg_mod.INT8_WEIGHT_ONLY_CFG, + "int4": _cfg_mod.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, + "int4_awq": _cfg_mod.INT4_AWQ_CFG, + # NVFP4 family + "nvfp4": _cfg_mod.NVFP4_DEFAULT_CFG, + "nvfp4_awq": _cfg_mod.NVFP4_AWQ_LITE_CFG, + "nvfp4_awq_lite": _cfg_mod.NVFP4_AWQ_LITE_CFG, + "nvfp4_awq_clip": _cfg_mod.NVFP4_AWQ_CLIP_CFG, + "nvfp4_awq_full": _cfg_mod.NVFP4_AWQ_FULL_CFG, + "nvfp4_mse": _cfg_mod.NVFP4_W4A4_WEIGHT_MSE_FP8_SWEEP_CFG, + "nvfp4_local_hessian": _cfg_mod.NVFP4_W4A4_WEIGHT_LOCAL_HESSIAN_CFG, + "nvfp4_fp8_mha": _cfg_mod.NVFP4_FP8_MHA_CONFIG, + "nvfp4_svdquant": _cfg_mod.NVFP4_SVDQUANT_DEFAULT_CFG, + "nvfp4_mlp_only": _cfg_mod.NVFP4_MLP_ONLY_CFG, + "nvfp4_mlp_wo": _cfg_mod.NVFP4_MLP_WEIGHT_ONLY_CFG, + "nvfp4_omlp_only": _cfg_mod.NVFP4_OMLP_ONLY_CFG, + # W4A8 variants + "w4a8_awq": _cfg_mod.W4A8_AWQ_BETA_CFG, + "w4a8_nvfp4_fp8": _cfg_mod.W4A8_NVFP4_FP8_CFG, + "w4a8_mxfp4_fp8": _cfg_mod.W4A8_MXFP4_FP8_CFG, + # MX formats + "mxfp8": _cfg_mod.MXFP8_DEFAULT_CFG, + "mxfp6": _cfg_mod.MXFP6_DEFAULT_CFG, + "mxfp4": _cfg_mod.MXFP4_DEFAULT_CFG, + "mxint8": _cfg_mod.MXINT8_DEFAULT_CFG, + "mxfp4_mlp_wo": _cfg_mod.MXFP4_MLP_WEIGHT_ONLY_CFG, + # Mamba MOE + "mamba_moe_fp8_aggressive": _cfg_mod.MAMBA_MOE_FP8_AGGRESSIVE_CFG, + "mamba_moe_fp8_conservative": _cfg_mod.MAMBA_MOE_FP8_CONSERVATIVE_CFG, + "mamba_moe_nvfp4_aggressive": _cfg_mod.MAMBA_MOE_NVFP4_AGGRESSIVE_CFG, + "mamba_moe_nvfp4_conservative": _cfg_mod.MAMBA_MOE_NVFP4_CONSERVATIVE_CFG, + } + except AttributeError: + # Some constants have been removed — this tier is no longer usable + return None + + +def _load_registry() -> dict[str, dict[str, Any]]: # pragma: no cover + """Lazily load preset configs with tiered fallback.""" + global _PRESET_REGISTRY, _PRESET_SOURCE + if _PRESET_REGISTRY is not None: + return _PRESET_REGISTRY + + # Tier 1: Load from YAML fragments (1a: load_config, 1b: our own loader) + registry = _try_load_yaml_registry() + if registry is not None: + # Fill in Tier-2-only presets from Python constants + _fill_tier2_only_presets(registry) + _PRESET_REGISTRY = registry + _PRESET_SOURCE = "yaml" + logger.debug("Loaded %d presets from YAML fragments", len(registry)) + return _PRESET_REGISTRY + + # Tier 2: Live import from Python constants (deprecated, will be removed) + registry = _try_load_python_registry() + if registry is not None: + _PRESET_REGISTRY = registry + _PRESET_SOURCE = "live" + logger.debug("Loaded %d presets from Python constants (deprecated path)", len(registry)) + return _PRESET_REGISTRY + + raise RuntimeError( + "Cannot load preset registry. Neither modelopt_recipes YAML fragments " + "nor modelopt.torch.quantization.config Python constants are available. " + "Run 'pip install -e .' from the Model-Optimizer repo root." + ) + + +def get_preset(name: str) -> dict[str, Any]: + """Return a deep copy of the preset config dict.""" + registry = _load_registry() + if name not in registry: + available = sorted(registry.keys()) + raise KeyError(f"Unknown preset '{name}'. Available: {available}") + return copy.deepcopy(registry[name]) + + +def get_preset_info(name: str) -> dict[str, str]: + """Return metadata for a preset (description, recipe_type). + + Metadata is loaded from recipe.yml in the composed recipe directory. + Returns empty dict if no metadata is available. + """ + _load_registry() # Ensure metadata is loaded + return _PRESET_METADATA.get(name, {}) + + +def get_preset_source() -> str: + """Return 'yaml' or 'live' indicating which tier is active.""" + _load_registry() + return _PRESET_SOURCE + + +def list_presets() -> list[str]: + """Return sorted list of available preset names.""" + return sorted(_load_registry().keys()) diff --git a/modelopt/torch/recipes/schema/resolver.py b/modelopt/torch/recipes/schema/resolver.py new file mode 100644 index 000000000..e2d8bdffc --- /dev/null +++ b/modelopt/torch/recipes/schema/resolver.py @@ -0,0 +1,283 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Resolver: transforms validated RecipeConfig into mtq-compatible config dicts. + +The resolver is the core translation layer between human-readable YAML +and the internal config dicts that mtq.quantize() and mtq.auto_quantize() accept. +""" + +from __future__ import annotations + +import copy +from typing import Any + +from .formats import get_format, get_kv_format +from .models import ( + AlgorithmConfig, + AutoQuantizeSection, + OverrideEntry, + QuantizationSection, + QuantizerSpec, + RecipeConfig, +) +from .presets import get_preset + +# Fallback disabled quantizer patterns — used when PR #1000's load_config is not available. +# Matches PR #1000's configs/ptq/base.yml (torch.nn. prefix, not nn.). +_FALLBACK_DISABLED_QUANTIZER_CFG: dict[str, Any] = { + "torch.nn.BatchNorm1d": {"*": {"enable": False}}, + "torch.nn.BatchNorm2d": {"*": {"enable": False}}, + "torch.nn.BatchNorm3d": {"*": {"enable": False}}, + "torch.nn.LeakyReLU": {"*": {"enable": False}}, + "*lm_head*": {"enable": False}, + "*proj_out.*": {"enable": False}, + "*block_sparse_moe.gate*": {"enable": False}, + "*router*": {"enable": False}, + "*mlp.gate.*": {"enable": False}, + "*mlp.shared_expert_gate.*": {"enable": False}, + "*linear_attn.conv1d*": {"enable": False}, + "*mixer.conv1d*": {"enable": False}, + "*output_layer*": {"enable": False}, + "output.*": {"enable": False}, + "default": {"enable": False}, +} + + +def _load_disabled_quantizer_cfg() -> dict[str, Any]: + """Load disabled quantizer config: prefer PR #1000's base.yml, fall back to inline.""" + from modelopt.torch.recipes.utils import try_import_load_config + + load_config = try_import_load_config() + if load_config is not None: + try: + cfg = load_config("configs/ptq/base") + return cfg["quant_cfg"] + except (ValueError, KeyError, TypeError): + pass + return _FALLBACK_DISABLED_QUANTIZER_CFG + + +_DEFAULT_DISABLED_QUANTIZER_CFG: dict[str, Any] = _load_disabled_quantizer_cfg() + + +def _update_quant_cfg_with_kv_cache( + quant_cfg: dict[str, Any], kv_cache_quant_cfg: dict[str, Any] +) -> dict[str, Any]: + """Merge KV cache quantizer patterns into the main config. + + Equivalent to modelopt.torch.quantization.utils.update_quant_cfg_with_kv_cache_quant(). + + Uses dict.update() so the user's explicit kv_cache section in the recipe + always wins over any KV patterns baked into the preset (e.g., from Tier 1 + YAML presets that include kv_quant.yml). + """ + quant_cfg["quant_cfg"] = quant_cfg.get("quant_cfg", {"default": {"enable": False}}) + quant_cfg["quant_cfg"].update(kv_cache_quant_cfg) + if not quant_cfg.get("algorithm"): + quant_cfg["algorithm"] = "max" + return quant_cfg + + +def resolve_recipe(recipe: RecipeConfig) -> dict[str, Any]: + """Resolve a RecipeConfig into output dict(s) for mtq APIs. + + Returns a dict with keys: + - "quantize_config": config dict for mtq.quantize() (if quantization section present) + - "auto_quantize_kwargs": kwargs dict for mtq.auto_quantize() (if auto_quantize section) + - "calibration": calibration params dict (if specified) + - "export": export params dict (if specified) + """ + result: dict[str, Any] = {} + + if recipe.quantization: + result["quantize_config"] = _resolve_quantization(recipe.quantization) + if recipe.quantization.calibration: + result["calibration"] = recipe.quantization.calibration.model_dump() + + if recipe.auto_quantize: + result["auto_quantize_kwargs"] = _resolve_auto_quantize(recipe.auto_quantize) + if recipe.auto_quantize.calibration: + result["calibration"] = recipe.auto_quantize.calibration.model_dump() + + if recipe.export: + result["export"] = recipe.export.model_dump() + + return result + + +def _resolve_quantization(section: QuantizationSection) -> dict[str, Any]: + """Produce the config dict for mtq.quantize().""" + # Step 1: Start from preset or build from scratch + if section.preset: + config = get_preset(section.preset) + else: + quant_cfg: dict[str, Any] = {} + if section.weights: + quant_cfg["*weight_quantizer"] = _resolve_quantizer_spec(section.weights, "weight") + if section.activations: + quant_cfg["*input_quantizer"] = _resolve_quantizer_spec( + section.activations, "activation" + ) + quant_cfg.update(_DEFAULT_DISABLED_QUANTIZER_CFG) + config = {"quant_cfg": quant_cfg, "algorithm": "max"} + + # Step 2: Apply algorithm override + if section.algorithm is not None: + if isinstance(section.algorithm, str): + config["algorithm"] = section.algorithm + elif isinstance(section.algorithm, AlgorithmConfig): + algo_dict = section.algorithm.model_dump(exclude_none=True) + config["algorithm"] = algo_dict + + # Step 3: Apply overrides + for override in section.overrides: + _apply_override(config["quant_cfg"], override) + + # Step 4: Apply disabled_patterns + for pattern in section.disabled_patterns: + config["quant_cfg"][pattern] = {"enable": False} + + # Step 5: Merge KV cache + if section.kv_cache: + kv_cfg = copy.deepcopy(get_kv_format(section.kv_cache.format)) + config = _update_quant_cfg_with_kv_cache(config, kv_cfg) + + return config + + +def _resolve_quantizer_spec(spec: QuantizerSpec, target: str) -> dict[str, Any] | list[dict]: + """Convert a QuantizerSpec to quantizer attribute dict(s). + + Args: + spec: The quantizer specification from the recipe YAML. + target: "weight" or "activation" — used to pick format defaults. + + Returns: + A dict of quantizer attributes, or a list of dicts for staged quantization. + """ + if spec.stages: + return [_resolve_single_quantizer(stage, target) for stage in spec.stages] + return _resolve_single_quantizer(spec, target) + + +def _resolve_single_quantizer(spec: QuantizerSpec, target: str) -> dict[str, Any]: + """Resolve a single (non-staged) quantizer spec to attribute dict.""" + result: dict[str, Any] = {} + + if spec.format: + fmt = get_format(spec.format) + result.update(copy.deepcopy(fmt[target])) + + # Expert-mode overrides + if spec.num_bits is not None: + result["num_bits"] = spec.num_bits + if spec.axis is not None: + result["axis"] = spec.axis + if spec.block_sizes is not None: + result["block_sizes"] = _resolve_block_sizes(spec.block_sizes) + if not spec.enable: + result["enable"] = False + + return result + + +def _resolve_block_sizes(bs: dict[str, Any]) -> dict: + """Convert block_sizes from YAML-friendly format to internal format. + + YAML uses string keys like "last_dim"; internal uses integer keys like -1. + Also supports passing through raw dicts with integer keys directly. + """ + result: dict = {} + key_map = {"last_dim": -1, "second_last_dim": -2} + + for k, v in bs.items(): + if k in key_map: + result[key_map[k]] = v + elif k == "scale_bits": + result["scale_bits"] = v + else: + # Pass through: "type", "scale_bits" (tuple), integer keys, etc. + try: + result[int(k)] = v + except (ValueError, TypeError): + result[k] = v + + return result + + +def _apply_override(quant_cfg: dict, override: OverrideEntry) -> None: + """Apply a single override entry to the quant_cfg dict. + + Pattern overrides merge into existing entries (preserving preset values). + Module-class overrides also merge to avoid dropping defaults like disabled BatchNorm. + """ + if override.pattern: + # Start from existing entry if present (to preserve preset values for merging) + entry: dict[str, Any] = copy.deepcopy(quant_cfg.get(override.pattern, {})) + if override.enable is not None: + entry["enable"] = override.enable + if override.format: + fmt = get_format(override.format) + entry.update(copy.deepcopy(fmt["weight"])) + if override.weights: + entry.update(_resolve_single_quantizer(override.weights, "weight")) + if override.activations: + entry.update(_resolve_single_quantizer(override.activations, "activation")) + if override.scale_type: + # Merge scale_type into block_sizes.type, preserving existing block_sizes + bs = entry.get("block_sizes", {}) + bs["type"] = override.scale_type + entry["block_sizes"] = bs + if override.num_bits is not None: + entry["num_bits"] = override.num_bits + if override.axis is not None: + entry["axis"] = override.axis + quant_cfg[override.pattern] = entry + + elif override.module_class: + # Merge into existing entry to preserve defaults (e.g., disabled BatchNorm) + mc_cfg: dict[str, Any] = copy.deepcopy(quant_cfg.get(override.module_class, {})) + if override.weights: + mc_cfg["*weight_quantizer"] = _resolve_quantizer_spec(override.weights, "weight") + if override.activations: + mc_cfg["*input_quantizer"] = _resolve_quantizer_spec(override.activations, "activation") + if override.enable is not None and not override.weights and not override.activations: + mc_cfg = {"*": {"enable": override.enable}} + quant_cfg[override.module_class] = mc_cfg + + +def _resolve_auto_quantize(section: AutoQuantizeSection) -> dict[str, Any]: + """Produce kwargs dict for mtq.auto_quantize().""" + format_configs = [get_preset(fmt_entry.preset) for fmt_entry in section.formats] + + kwargs: dict[str, Any] = { + "constraints": {"effective_bits": section.effective_bits}, + "quantization_formats": format_configs, + "num_calib_steps": section.num_calib_steps, + "num_score_steps": section.num_score_steps, + "method": section.method, + } + + if section.disabled_patterns: + kwargs["disabled_layers"] = section.disabled_patterns + + if section.kv_cache: + kv_cfg = copy.deepcopy(get_kv_format(section.kv_cache.format)) + # Apply KV cache quantization to each candidate format + for fmt_cfg in format_configs: + _update_quant_cfg_with_kv_cache(fmt_cfg, kv_cfg) + + return kwargs diff --git a/modelopt/torch/recipes/utils.py b/modelopt/torch/recipes/utils.py new file mode 100644 index 000000000..2f6c04d51 --- /dev/null +++ b/modelopt/torch/recipes/utils.py @@ -0,0 +1,95 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Shared utilities for the recipe system.""" + +from __future__ import annotations + +import copy +from typing import TYPE_CHECKING, Any + +import yaml + +if TYPE_CHECKING: + from pathlib import Path + + +def deep_merge(base: dict, override: dict) -> dict: + """Recursively merge override into base dict (like OmegaConf.merge but lightweight).""" + result = copy.deepcopy(base) + for key, value in override.items(): + if key in result and isinstance(result[key], dict) and isinstance(value, dict): + result[key] = deep_merge(result[key], value) + else: + result[key] = copy.deepcopy(value) + return result + + +def load_yaml_with_bases(yaml_path: Path, recipes_root: Path) -> dict[str, Any]: + """Load a YAML file resolving __base__ inheritance. + + Implements the same __base__ merging as PR #1000's load_config(): + reads __base__ list, recursively loads each base, merges in order. + """ + with open(yaml_path) as f: + data = yaml.safe_load(f) or {} + + bases = data.pop("__base__", []) + if not bases: + return data + + # Resolve each base file (path without .yml extension) + merged: dict[str, Any] = {} + for base_ref in bases: + base_path = recipes_root / f"{base_ref}.yml" + if not base_path.is_file(): + base_path = recipes_root / f"{base_ref}.yaml" + if not base_path.is_file(): + raise FileNotFoundError(f"Base config not found: {base_ref} (tried .yml and .yaml)") + base_data = load_yaml_with_bases(base_path, recipes_root) + merged = deep_merge(merged, base_data) + + # Current file overrides bases + merged = deep_merge(merged, data) + return merged + + +def make_serializable(obj: Any) -> Any: + """Convert tuples and other non-JSON-safe types for serialization/display. + + Recursively converts dicts (with any key type), lists, and tuples into + JSON-compatible structures. Dict keys are stringified. + """ + if isinstance(obj, dict): + return {str(k): make_serializable(v) for k, v in obj.items()} + elif isinstance(obj, (list, tuple)): + return [make_serializable(item) for item in obj] + elif isinstance(obj, (int, float, str, bool, type(None))): + return obj + return str(obj) + + +def try_import_load_config(): + """Try to import PR #1000's load_config function. + + Returns the function if available, None otherwise. This is the forward-compatible + import point for load_config() from modelopt.torch.opt.config. + """ + try: + from modelopt.torch.opt.config import load_config # type: ignore[attr-defined] + + return load_config + except (ImportError, ModuleNotFoundError): + return None diff --git a/pyproject.toml b/pyproject.toml index 92ba9cb5e..888d4434b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,6 +39,7 @@ dependencies = [ # modelopt.torch "pulp", "pydantic>=2.0", + "pyyaml>=5.0", "regex", "rich", "safetensors", diff --git a/tests/unit/torch/recipes/__init__.py b/tests/unit/torch/recipes/__init__.py new file mode 100644 index 000000000..47f1c65a1 --- /dev/null +++ b/tests/unit/torch/recipes/__init__.py @@ -0,0 +1,15 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/tests/unit/torch/recipes/conftest.py b/tests/unit/torch/recipes/conftest.py new file mode 100644 index 000000000..b779d25b5 --- /dev/null +++ b/tests/unit/torch/recipes/conftest.py @@ -0,0 +1,35 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Shared test fixtures.""" + +import pytest + + +def pytest_configure(config): + config.addinivalue_line("markers", "requires_modelopt: needs nvidia-modelopt installed") + + +def pytest_collection_modifyitems(config, items): + try: + import modelopt.torch.quantization.config # noqa: F401 + + return # modelopt available, run all tests + except ModuleNotFoundError: + pass + skip = pytest.mark.skip(reason="nvidia-modelopt not installed") + for item in items: + if "requires_modelopt" in item.keywords: + item.add_marker(skip) diff --git a/tests/unit/torch/recipes/fixtures/auto_quantize_basic.yaml b/tests/unit/torch/recipes/fixtures/auto_quantize_basic.yaml new file mode 100644 index 000000000..31a13f764 --- /dev/null +++ b/tests/unit/torch/recipes/fixtures/auto_quantize_basic.yaml @@ -0,0 +1,27 @@ +# Auto-quantize resolves to kwargs dict +_test: + description: "Auto-quantize resolves to kwargs with constraints and formats" + + validate: + expect_success: true + + resolve: + check_has_key: auto_quantize_kwargs + check_no_key: quantize_config + check_config: + auto_quantize_kwargs: + constraints: + effective_bits: 4.5 + num_calib_steps: 256 + num_score_steps: 64 + method: gradient + check_format_count: 2 + +version: "1.0" +auto_quantize: + effective_bits: 4.5 + formats: + - preset: fp8 + - preset: int8 + num_calib_steps: 256 + num_score_steps: 64 diff --git a/tests/unit/torch/recipes/fixtures/auto_quantize_disabled.yaml b/tests/unit/torch/recipes/fixtures/auto_quantize_disabled.yaml new file mode 100644 index 000000000..5282bee3f --- /dev/null +++ b/tests/unit/torch/recipes/fixtures/auto_quantize_disabled.yaml @@ -0,0 +1,20 @@ +# Auto-quantize maps disabled_patterns to disabled_layers +_test: + description: "Auto-quantize disabled_patterns map to disabled_layers" + + validate: + expect_success: true + + resolve: + check_config: + auto_quantize_kwargs: + disabled_layers: ["*lm_head*", "*embed*"] + +version: "1.0" +auto_quantize: + effective_bits: 4.5 + formats: + - preset: fp8 + disabled_patterns: + - "*lm_head*" + - "*embed*" diff --git a/tests/unit/torch/recipes/fixtures/auto_quantize_kv_cache.yaml b/tests/unit/torch/recipes/fixtures/auto_quantize_kv_cache.yaml new file mode 100644 index 000000000..dbd33b60b --- /dev/null +++ b/tests/unit/torch/recipes/fixtures/auto_quantize_kv_cache.yaml @@ -0,0 +1,22 @@ +# Auto-quantize with kv_cache and disabled_patterns (existing test coverage) +_test: + description: "Auto-quantize with KV cache and disabled_patterns" + + validate: + expect_success: true + + resolve: + check_has_key: auto_quantize_kwargs + check_config: + auto_quantize_kwargs: + disabled_layers: ["*lm_head*"] + +version: "1.0" +auto_quantize: + effective_bits: 4.5 + formats: + - preset: fp8 + kv_cache: + format: fp8 + disabled_patterns: + - "*lm_head*" diff --git a/tests/unit/torch/recipes/fixtures/calibration_included.yaml b/tests/unit/torch/recipes/fixtures/calibration_included.yaml new file mode 100644 index 000000000..916a24828 --- /dev/null +++ b/tests/unit/torch/recipes/fixtures/calibration_included.yaml @@ -0,0 +1,22 @@ +# Calibration section resolved and included +_test: + description: "Calibration section resolves to result dict" + + validate: + expect_success: true + + resolve: + check_has_key: calibration + check_config: + calibration: + dataset: wikitext + num_samples: 128 + max_sequence_length: 2048 + +version: "1.0" +quantization: + preset: fp8 + calibration: + dataset: wikitext + num_samples: 128 + max_sequence_length: 2048 diff --git a/tests/unit/torch/recipes/fixtures/error_exclusive_sections.yaml b/tests/unit/torch/recipes/fixtures/error_exclusive_sections.yaml new file mode 100644 index 000000000..259e0b7b4 --- /dev/null +++ b/tests/unit/torch/recipes/fixtures/error_exclusive_sections.yaml @@ -0,0 +1,14 @@ +# Schema validation error: quantization + auto_quantize mutually exclusive +_test: + description: "Schema error: quantization and auto_quantize are mutually exclusive" + + validate: + expect_error: "mutually exclusive" + +version: "1.0" +quantization: + preset: fp8 +auto_quantize: + effective_bits: 4.5 + formats: + - preset: fp8 diff --git a/tests/unit/torch/recipes/fixtures/error_preset_and_custom.yaml b/tests/unit/torch/recipes/fixtures/error_preset_and_custom.yaml new file mode 100644 index 000000000..37390de86 --- /dev/null +++ b/tests/unit/torch/recipes/fixtures/error_preset_and_custom.yaml @@ -0,0 +1,12 @@ +# Schema validation error: preset + weights mutually exclusive +_test: + description: "Schema error: preset and custom weights are mutually exclusive" + + validate: + expect_error: "Cannot specify both" + +version: "1.0" +quantization: + preset: fp8 + weights: + format: int8 diff --git a/tests/unit/torch/recipes/fixtures/error_qat_no_training.yaml b/tests/unit/torch/recipes/fixtures/error_qat_no_training.yaml new file mode 100644 index 000000000..878583eca --- /dev/null +++ b/tests/unit/torch/recipes/fixtures/error_qat_no_training.yaml @@ -0,0 +1,11 @@ +# Schema validation error: QAT mode requires training config +_test: + description: "Schema error: QAT mode requires training configuration" + + validate: + expect_error: "QAT mode requires" + +version: "1.0" +quantization: + mode: qat + preset: fp8 diff --git a/tests/unit/torch/recipes/fixtures/error_sparsity_over_one.yaml b/tests/unit/torch/recipes/fixtures/error_sparsity_over_one.yaml new file mode 100644 index 000000000..48f978348 --- /dev/null +++ b/tests/unit/torch/recipes/fixtures/error_sparsity_over_one.yaml @@ -0,0 +1,11 @@ +# Schema validation error: sparsity > 1.0 out of range +_test: + description: "Schema error: sparsity 1.5 rejected (must be <= 1.0)" + + validate: + expect_error: "Sparsity must be in" + +version: "1.0" +sparsity: + method: sparse_gpt + sparsity: 1.5 diff --git a/tests/unit/torch/recipes/fixtures/error_sparsity_zero.yaml b/tests/unit/torch/recipes/fixtures/error_sparsity_zero.yaml new file mode 100644 index 000000000..a36815b8e --- /dev/null +++ b/tests/unit/torch/recipes/fixtures/error_sparsity_zero.yaml @@ -0,0 +1,11 @@ +# Schema validation error: sparsity = 0.0 out of range +_test: + description: "Schema error: sparsity 0.0 rejected (must be > 0)" + + validate: + expect_error: "Sparsity must be in" + +version: "1.0" +sparsity: + method: sparse_gpt + sparsity: 0.0 diff --git a/tests/unit/torch/recipes/fixtures/error_unknown_format.yaml b/tests/unit/torch/recipes/fixtures/error_unknown_format.yaml new file mode 100644 index 000000000..b4cac8f58 --- /dev/null +++ b/tests/unit/torch/recipes/fixtures/error_unknown_format.yaml @@ -0,0 +1,16 @@ +# Resolve error: unknown format name in weights +_test: + description: "Resolve error: unknown format raises KeyError" + + validate: + expect_success: true + + resolve: + expect_error: "Unknown format" + +version: "1.0" +quantization: + weights: + format: nonexistent_format + activations: + enable: false diff --git a/tests/unit/torch/recipes/fixtures/error_unknown_kv_format.yaml b/tests/unit/torch/recipes/fixtures/error_unknown_kv_format.yaml new file mode 100644 index 000000000..258507959 --- /dev/null +++ b/tests/unit/torch/recipes/fixtures/error_unknown_kv_format.yaml @@ -0,0 +1,15 @@ +# Resolve error: unknown KV cache format name +_test: + description: "Resolve error: unknown KV cache format raises KeyError" + + validate: + expect_success: true + + resolve: + expect_error: "Unknown KV cache format" + +version: "1.0" +quantization: + preset: fp8 + kv_cache: + format: nonexistent_kv_format diff --git a/tests/unit/torch/recipes/fixtures/export_included.yaml b/tests/unit/torch/recipes/fixtures/export_included.yaml new file mode 100644 index 000000000..2666d99ed --- /dev/null +++ b/tests/unit/torch/recipes/fixtures/export_included.yaml @@ -0,0 +1,22 @@ +# Export section resolved and included +_test: + description: "Export section resolves to result dict" + + validate: + expect_success: true + + resolve: + check_has_key: export + check_config: + export: + format: hf + output_dir: ./my_output + tensor_parallel: 4 + +version: "1.0" +quantization: + preset: fp8 +export: + format: hf + output_dir: ./my_output + tensor_parallel: 4 diff --git a/tests/unit/torch/recipes/fixtures/no_export.yaml b/tests/unit/torch/recipes/fixtures/no_export.yaml new file mode 100644 index 000000000..cebf0928f --- /dev/null +++ b/tests/unit/torch/recipes/fixtures/no_export.yaml @@ -0,0 +1,14 @@ +# Recipe without export — no export key in result +_test: + description: "No export section means no export key in result" + + validate: + expect_success: true + + resolve: + check_has_key: quantize_config + check_no_key: export + +version: "1.0" +quantization: + preset: fp8 diff --git a/tests/unit/torch/recipes/fixtures/ptq_algorithm_dict.yaml b/tests/unit/torch/recipes/fixtures/ptq_algorithm_dict.yaml new file mode 100644 index 000000000..07dfc2f38 --- /dev/null +++ b/tests/unit/torch/recipes/fixtures/ptq_algorithm_dict.yaml @@ -0,0 +1,20 @@ +# Algorithm override as dict with method + extra fields +_test: + description: "Algorithm dict override with method and alpha_step" + + validate: + expect_success: true + + resolve: + check_config: + quantize_config: + algorithm: + method: awq_lite + alpha_step: 0.1 + +version: "1.0" +quantization: + preset: fp8 + algorithm: + method: awq_lite + alpha_step: 0.1 diff --git a/tests/unit/torch/recipes/fixtures/ptq_algorithm_string.yaml b/tests/unit/torch/recipes/fixtures/ptq_algorithm_string.yaml new file mode 100644 index 000000000..682f592e8 --- /dev/null +++ b/tests/unit/torch/recipes/fixtures/ptq_algorithm_string.yaml @@ -0,0 +1,14 @@ +# Algorithm override as plain string +_test: + description: "Algorithm string override replaces preset algorithm" + + validate: + expect_success: true + + resolve: + check_algorithm: awq_lite + +version: "1.0" +quantization: + preset: fp8 + algorithm: awq_lite diff --git a/tests/unit/torch/recipes/fixtures/ptq_all_core_formats_fp8.yaml b/tests/unit/torch/recipes/fixtures/ptq_all_core_formats_fp8.yaml new file mode 100644 index 000000000..a50026f61 --- /dev/null +++ b/tests/unit/torch/recipes/fixtures/ptq_all_core_formats_fp8.yaml @@ -0,0 +1,17 @@ +# Verify FP8 format resolves with weight and activation quantizers +_test: + description: "FP8 format has weight and activation quantizer configs" + + validate: + expect_success: true + + resolve: + check_has_key: quantize_config + check_quant_cfg_has_keys: ["*weight_quantizer", "*input_quantizer"] + +version: "1.0" +quantization: + weights: + format: fp8 + activations: + format: fp8 diff --git a/tests/unit/torch/recipes/fixtures/ptq_block_sizes_passthrough.yaml b/tests/unit/torch/recipes/fixtures/ptq_block_sizes_passthrough.yaml new file mode 100644 index 000000000..b9fc5bb99 --- /dev/null +++ b/tests/unit/torch/recipes/fixtures/ptq_block_sizes_passthrough.yaml @@ -0,0 +1,22 @@ +# Block sizes with integer string keys and type passthrough +_test: + description: "Block sizes integer string keys and type field pass through" + + validate: + expect_success: true + + resolve: + check_block_sizes: + path: quantize_config.quant_cfg.*weight_quantizer.block_sizes + expected: + type: dynamic + 0: 128 + +version: "1.0" +quantization: + weights: + block_sizes: + type: dynamic + "0": 128 + activations: + enable: false diff --git a/tests/unit/torch/recipes/fixtures/ptq_block_sizes_string_keys.yaml b/tests/unit/torch/recipes/fixtures/ptq_block_sizes_string_keys.yaml new file mode 100644 index 000000000..d581d8a3e --- /dev/null +++ b/tests/unit/torch/recipes/fixtures/ptq_block_sizes_string_keys.yaml @@ -0,0 +1,24 @@ +# Block sizes with string dimension keys +_test: + description: "Block sizes string keys (last_dim, second_last_dim) resolve to integers" + + validate: + expect_success: true + + resolve: + check_block_sizes: + path: quantize_config.quant_cfg.*weight_quantizer.block_sizes + expected: + -1: 16 + -2: 32 + +version: "1.0" +quantization: + weights: + format: nvfp4 + block_sizes: + last_dim: 16 + second_last_dim: 32 + scale_bits: [4, 3] + activations: + enable: false diff --git a/tests/unit/torch/recipes/fixtures/ptq_custom_w8a8.yaml b/tests/unit/torch/recipes/fixtures/ptq_custom_w8a8.yaml new file mode 100644 index 000000000..d08eb3982 --- /dev/null +++ b/tests/unit/torch/recipes/fixtures/ptq_custom_w8a8.yaml @@ -0,0 +1,22 @@ +# Custom weights/activations without preset +_test: + description: "Custom INT8 weights + INT8 activations (no preset)" + + validate: + expect_success: true + + resolve: + check_has_key: quantize_config + check_quant_cfg_has_keys: ["*weight_quantizer", "*input_quantizer"] + check_config: + quantize_config: + quant_cfg: + "*weight_quantizer": + num_bits: 8 + +version: "1.0" +quantization: + weights: + format: int8 + activations: + format: int8 diff --git a/tests/unit/torch/recipes/fixtures/ptq_disabled_patterns.yaml b/tests/unit/torch/recipes/fixtures/ptq_disabled_patterns.yaml new file mode 100644 index 000000000..cf6a877a7 --- /dev/null +++ b/tests/unit/torch/recipes/fixtures/ptq_disabled_patterns.yaml @@ -0,0 +1,23 @@ +# Disabled patterns create enable=false entries +_test: + description: "disabled_patterns injects enable:false entries" + + validate: + expect_success: true + + resolve: + check_has_key: quantize_config + check_config: + quantize_config: + quant_cfg: + "*layers.0*": + enable: false + "*layers.1*": + enable: false + +version: "1.0" +quantization: + preset: fp8 + disabled_patterns: + - "*layers.0*" + - "*layers.1*" diff --git a/tests/unit/torch/recipes/fixtures/ptq_format_int4.yaml b/tests/unit/torch/recipes/fixtures/ptq_format_int4.yaml new file mode 100644 index 000000000..601aed495 --- /dev/null +++ b/tests/unit/torch/recipes/fixtures/ptq_format_int4.yaml @@ -0,0 +1,17 @@ +# Verify INT4 format resolves (core format registry coverage) +_test: + description: "INT4 format resolves with weight quantizer" + + validate: + expect_success: true + + resolve: + check_has_key: quantize_config + check_quant_cfg_has_keys: ["*weight_quantizer"] + +version: "1.0" +quantization: + weights: + format: int4 + activations: + enable: false diff --git a/tests/unit/torch/recipes/fixtures/ptq_format_nvfp4.yaml b/tests/unit/torch/recipes/fixtures/ptq_format_nvfp4.yaml new file mode 100644 index 000000000..06bf6a2b9 --- /dev/null +++ b/tests/unit/torch/recipes/fixtures/ptq_format_nvfp4.yaml @@ -0,0 +1,17 @@ +# Verify NVFP4 format resolves (core format registry coverage) +_test: + description: "NVFP4 format resolves with weight quantizer" + + validate: + expect_success: true + + resolve: + check_has_key: quantize_config + check_quant_cfg_has_keys: ["*weight_quantizer"] + +version: "1.0" +quantization: + weights: + format: nvfp4 + activations: + enable: false diff --git a/tests/unit/torch/recipes/fixtures/ptq_fp8_preset.yaml b/tests/unit/torch/recipes/fixtures/ptq_fp8_preset.yaml new file mode 100644 index 000000000..bd0452ee3 --- /dev/null +++ b/tests/unit/torch/recipes/fixtures/ptq_fp8_preset.yaml @@ -0,0 +1,15 @@ +# FP8 PTQ — simplest preset recipe +_test: + description: "FP8 preset resolves to quant_cfg with max algorithm" + + validate: + expect_success: true + + resolve: + check_has_key: quantize_config + check_no_key: auto_quantize_kwargs + check_algorithm: max + +version: "1.0" +quantization: + preset: fp8 diff --git a/tests/unit/torch/recipes/fixtures/ptq_kv_cache_fp8.yaml b/tests/unit/torch/recipes/fixtures/ptq_kv_cache_fp8.yaml new file mode 100644 index 000000000..cdea21300 --- /dev/null +++ b/tests/unit/torch/recipes/fixtures/ptq_kv_cache_fp8.yaml @@ -0,0 +1,16 @@ +# KV cache adds quantizer patterns +_test: + description: "KV cache section adds bmm/kv quantizer patterns" + + validate: + expect_success: true + + resolve: + check_has_key: quantize_config + check_kv_patterns_present: true + +version: "1.0" +quantization: + preset: fp8 + kv_cache: + format: fp8 diff --git a/tests/unit/torch/recipes/fixtures/ptq_kv_cache_preserves_algorithm.yaml b/tests/unit/torch/recipes/fixtures/ptq_kv_cache_preserves_algorithm.yaml new file mode 100644 index 000000000..8f7e88972 --- /dev/null +++ b/tests/unit/torch/recipes/fixtures/ptq_kv_cache_preserves_algorithm.yaml @@ -0,0 +1,16 @@ +# KV cache merge does not overwrite explicit algorithm +_test: + description: "KV cache preserves explicit algorithm override" + + validate: + expect_success: true + + resolve: + check_algorithm: awq_lite + +version: "1.0" +quantization: + preset: fp8 + algorithm: awq_lite + kv_cache: + format: fp8 diff --git a/tests/unit/torch/recipes/fixtures/ptq_kv_cache_sets_algorithm.yaml b/tests/unit/torch/recipes/fixtures/ptq_kv_cache_sets_algorithm.yaml new file mode 100644 index 000000000..2f45c8f79 --- /dev/null +++ b/tests/unit/torch/recipes/fixtures/ptq_kv_cache_sets_algorithm.yaml @@ -0,0 +1,18 @@ +# KV cache merge sets algorithm to 'max' when not specified +_test: + description: "KV cache sets default algorithm to max" + + validate: + expect_success: true + + resolve: + check_algorithm: max + +version: "1.0" +quantization: + weights: + format: int8 + activations: + format: int8 + kv_cache: + format: fp8 diff --git a/tests/unit/torch/recipes/fixtures/ptq_override_format_num_bits.yaml b/tests/unit/torch/recipes/fixtures/ptq_override_format_num_bits.yaml new file mode 100644 index 000000000..706348dd2 --- /dev/null +++ b/tests/unit/torch/recipes/fixtures/ptq_override_format_num_bits.yaml @@ -0,0 +1,23 @@ +# Pattern override with format, num_bits, and axis +_test: + description: "Pattern override with format + num_bits + axis" + + validate: + expect_success: true + + resolve: + check_config: + quantize_config: + quant_cfg: + "*mlp*weight_quantizer": + num_bits: 4 + axis: 0 + +version: "1.0" +quantization: + preset: fp8 + overrides: + - pattern: "*mlp*weight_quantizer" + format: int4 + num_bits: 4 + axis: 0 diff --git a/tests/unit/torch/recipes/fixtures/ptq_override_module_class.yaml b/tests/unit/torch/recipes/fixtures/ptq_override_module_class.yaml new file mode 100644 index 000000000..93c106223 --- /dev/null +++ b/tests/unit/torch/recipes/fixtures/ptq_override_module_class.yaml @@ -0,0 +1,24 @@ +# Module-class override +_test: + description: "Module-class override adds weight quantizer for nn.Linear" + + validate: + expect_success: true + + resolve: + check_has_key: quantize_config + check_quant_cfg_has_keys: ["nn.Linear"] + check_config: + quantize_config: + quant_cfg: + "nn.Linear": + "*weight_quantizer": + num_bits: 8 + +version: "1.0" +quantization: + preset: fp8 + overrides: + - module_class: "nn.Linear" + weights: + format: int8 diff --git a/tests/unit/torch/recipes/fixtures/ptq_override_module_disable.yaml b/tests/unit/torch/recipes/fixtures/ptq_override_module_disable.yaml new file mode 100644 index 000000000..fbc2762b1 --- /dev/null +++ b/tests/unit/torch/recipes/fixtures/ptq_override_module_disable.yaml @@ -0,0 +1,21 @@ +# Module-class override with enable=false (no weights/activations) +_test: + description: "Module-class disable override uses wildcard pattern" + + validate: + expect_success: true + + resolve: + check_config: + quantize_config: + quant_cfg: + "nn.Embedding": + "*": + enable: false + +version: "1.0" +quantization: + preset: fp8 + overrides: + - module_class: "nn.Embedding" + enable: false diff --git a/tests/unit/torch/recipes/fixtures/ptq_scale_type_new_pattern.yaml b/tests/unit/torch/recipes/fixtures/ptq_scale_type_new_pattern.yaml new file mode 100644 index 000000000..7bc3e1681 --- /dev/null +++ b/tests/unit/torch/recipes/fixtures/ptq_scale_type_new_pattern.yaml @@ -0,0 +1,21 @@ +# scale_type on a new pattern creates fresh entry +_test: + description: "scale_type on new pattern creates entry with block_sizes.type" + + validate: + expect_success: true + + resolve: + check_config: + quantize_config: + quant_cfg: + "*self_attn*weight_quantizer": + block_sizes: + type: dynamic + +version: "1.0" +quantization: + preset: nvfp4_local_hessian + overrides: + - pattern: "*self_attn*weight_quantizer" + scale_type: dynamic diff --git a/tests/unit/torch/recipes/fixtures/ptq_scale_type_override.yaml b/tests/unit/torch/recipes/fixtures/ptq_scale_type_override.yaml new file mode 100644 index 000000000..fb636dc2a --- /dev/null +++ b/tests/unit/torch/recipes/fixtures/ptq_scale_type_override.yaml @@ -0,0 +1,21 @@ +# scale_type shorthand merges into block_sizes.type +_test: + description: "scale_type override merges into block_sizes preserving existing values" + + validate: + expect_success: true + + resolve: + check_config: + quantize_config: + quant_cfg: + "*weight_quantizer": + block_sizes: + type: dynamic + +version: "1.0" +quantization: + preset: nvfp4_local_hessian + overrides: + - pattern: "*weight_quantizer" + scale_type: dynamic diff --git a/tests/unit/torch/recipes/fixtures/sparsity_boundary.yaml b/tests/unit/torch/recipes/fixtures/sparsity_boundary.yaml new file mode 100644 index 000000000..9314c1d13 --- /dev/null +++ b/tests/unit/torch/recipes/fixtures/sparsity_boundary.yaml @@ -0,0 +1,11 @@ +# Sparsity at boundary value 1.0 (valid) +_test: + description: "Sparsity boundary: 1.0 is valid" + + validate: + expect_success: true + +version: "1.0" +sparsity: + method: sparse_gpt + sparsity: 1.0 diff --git a/tests/unit/torch/recipes/fixtures/sparsity_valid.yaml b/tests/unit/torch/recipes/fixtures/sparsity_valid.yaml new file mode 100644 index 000000000..1c4fb0ca8 --- /dev/null +++ b/tests/unit/torch/recipes/fixtures/sparsity_valid.yaml @@ -0,0 +1,11 @@ +# Valid sparsity config +_test: + description: "Valid sparsity section (0.5)" + + validate: + expect_success: true + +version: "1.0" +sparsity: + method: sparse_gpt + sparsity: 0.5 diff --git a/tests/unit/torch/recipes/schema/__init__.py b/tests/unit/torch/recipes/schema/__init__.py new file mode 100644 index 000000000..a08b2c204 --- /dev/null +++ b/tests/unit/torch/recipes/schema/__init__.py @@ -0,0 +1,14 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/unit/torch/recipes/schema/test_models.py b/tests/unit/torch/recipes/schema/test_models.py new file mode 100644 index 000000000..ab4f4a7dc --- /dev/null +++ b/tests/unit/torch/recipes/schema/test_models.py @@ -0,0 +1,40 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for recipe schema validation. + +Most schema tests are now YAML fixture-driven (see test_recipe_fixtures.py). +This file keeps tests that don't fit the fixture pattern (e.g., file iteration). +""" + +from pathlib import Path + +import yaml + +from modelopt.torch.recipes.schema.models import RecipeConfig + +FIXTURES_DIR = Path(__file__).parents[1] / "fixtures" + + +def test_all_fixture_recipes_parseable(): + """All non-error YAML fixtures parse as valid RecipeConfig.""" + for yaml_file in sorted(FIXTURES_DIR.glob("*.yaml")): + if yaml_file.stem.startswith("error_"): + continue + with open(yaml_file) as f: + raw = yaml.safe_load(f) + recipe_dict = {k: v for k, v in raw.items() if k != "_test"} + recipe = RecipeConfig.model_validate(recipe_dict) + assert recipe.version == "1.0", f"Failed: {yaml_file.name}" diff --git a/tests/unit/torch/recipes/schema/test_presets.py b/tests/unit/torch/recipes/schema/test_presets.py new file mode 100644 index 000000000..a88958848 --- /dev/null +++ b/tests/unit/torch/recipes/schema/test_presets.py @@ -0,0 +1,450 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for preset registry (schema/presets.py).""" + +from collections import Counter +from pathlib import Path + +import pytest +import yaml + +from modelopt.torch.recipes.schema.presets import ( + PRESET_YAML_MAP, + get_preset, + get_preset_info, + get_preset_source, + list_presets, + load_recipe_from_yaml, +) +from modelopt.torch.recipes.utils import load_yaml_with_bases + +# ── Helpers ── + + +def _write_yaml(path: Path, data: dict): + """Write a YAML file, creating parent dirs as needed.""" + path.parent.mkdir(parents=True, exist_ok=True) + with open(path, "w") as f: + yaml.dump(data, f, default_flow_style=False, sort_keys=True) + + +# Common base fragment shared by all presets. +_COMMON_BASE = { + "quant_cfg": { + "default": {"enable": False}, + "*lm_head*": {"enable": False}, + "*output_layer*": {"enable": False}, + "*router*": {"enable": False}, + "*block_sparse_moe.gate*": {"enable": False}, + "*mlp.gate.*": {"enable": False}, + "*mlp.shared_expert_gate.*": {"enable": False}, + "*proj_out.*": {"enable": False}, + "*linear_attn.conv1d*": {"enable": False}, + "*mixer.conv1d*": {"enable": False}, + "output.*": {"enable": False}, + "nn.BatchNorm1d": {"*": {"enable": False}}, + "nn.BatchNorm2d": {"*": {"enable": False}}, + "nn.BatchNorm3d": {"*": {"enable": False}}, + "nn.LeakyReLU": {"*": {"enable": False}}, + } +} + +_FP8_KV_FRAGMENT = { + "quant_cfg": { + "*k_proj*input_quantizer": {"axis": None, "num_bits": [4, 3]}, + "*v_proj*input_quantizer": {"axis": None, "num_bits": [4, 3]}, + } +} + + +def _setup_base_fragments(recipes_root: Path): + """Create the shared base and KV cache fragments.""" + _write_yaml(recipes_root / "fragments" / "base.yml", _COMMON_BASE) + _write_yaml(recipes_root / "fragments" / "fp8_kv.yml", _FP8_KV_FRAGMENT) + _write_yaml(recipes_root / "fragments" / "algo_max.yml", {"algorithm": "max"}) + _write_yaml( + recipes_root / "fragments" / "algo_awq_lite.yml", + {"algorithm": {"method": "awq_lite", "alpha_step": 0.1}}, + ) + _write_yaml( + recipes_root / "fragments" / "algo_smoothquant.yml", + {"algorithm": "smoothquant"}, + ) + + +def _setup_preset( + recipes_root: Path, + recipe_dir: str, + model_quant_bases: list[str], + model_quant_override: dict, + kv_quant_bases: list[str] | None = None, + kv_quant_override: dict | None = None, +): + """Create a composed recipe directory with model_quant.yml and optional kv_quant.yml.""" + recipe_path = recipes_root / recipe_dir + recipe_path.mkdir(parents=True, exist_ok=True) + + model_data = {"__base__": model_quant_bases} + model_data.update(model_quant_override) + _write_yaml(recipe_path / "model_quant.yml", model_data) + + if kv_quant_bases or kv_quant_override: + kv_data = {} + if kv_quant_bases: + kv_data["__base__"] = kv_quant_bases + if kv_quant_override: + kv_data.update(kv_quant_override) + _write_yaml(recipe_path / "kv_quant.yml", kv_data) + + +@pytest.fixture +def recipes_root(tmp_path): + """Set up a fake modelopt_recipes/ directory tree with 5 presets.""" + root = tmp_path + _setup_base_fragments(root) + + # fp8 + _write_yaml( + root / "fragments" / "fp8_quantizer.yml", + { + "quant_cfg": { + "*weight_quantizer": {"axis": None, "num_bits": [4, 3]}, + "*input_quantizer": {"axis": None, "num_bits": [4, 3]}, + } + }, + ) + _setup_preset( + root, + "general/ptq/fp8_default-fp8_kv", + model_quant_bases=["fragments/base", "fragments/fp8_quantizer", "fragments/algo_max"], + model_quant_override={}, + kv_quant_bases=["fragments/fp8_kv"], + ) + + # int8 + _write_yaml( + root / "fragments" / "int8_quantizer.yml", + { + "quant_cfg": { + "*weight_quantizer": {"axis": 0, "num_bits": 8}, + "*input_quantizer": {"axis": None, "num_bits": 8}, + } + }, + ) + _setup_preset( + root, + "general/ptq/int8_default-fp8_kv", + model_quant_bases=["fragments/base", "fragments/int8_quantizer", "fragments/algo_max"], + model_quant_override={}, + kv_quant_bases=["fragments/fp8_kv"], + ) + + # int4_awq + _write_yaml( + root / "fragments" / "int4_wo_quantizer.yml", + { + "quant_cfg": { + "*weight_quantizer": { + "enable": True, + "num_bits": 4, + "block_sizes": {"-1": 128, "type": "static"}, + }, + "*input_quantizer": {"enable": False}, + } + }, + ) + _setup_preset( + root, + "general/ptq/int4_awq-fp8_kv", + model_quant_bases=[ + "fragments/base", + "fragments/int4_wo_quantizer", + "fragments/algo_awq_lite", + ], + model_quant_override={}, + kv_quant_bases=["fragments/fp8_kv"], + ) + + # nvfp4 + _write_yaml( + root / "fragments" / "nvfp4_quantizer.yml", + { + "quant_cfg": { + "*weight_quantizer": { + "axis": None, + "enable": True, + "num_bits": [2, 1], + "block_sizes": {"-1": 16, "scale_bits": [4, 3], "type": "dynamic"}, + }, + "*input_quantizer": { + "axis": None, + "enable": True, + "num_bits": [2, 1], + "block_sizes": {"-1": 16, "scale_bits": [4, 3], "type": "dynamic"}, + }, + } + }, + ) + _setup_preset( + root, + "general/ptq/nvfp4_default-fp8_kv", + model_quant_bases=["fragments/base", "fragments/nvfp4_quantizer", "fragments/algo_max"], + model_quant_override={}, + kv_quant_bases=["fragments/fp8_kv"], + ) + + # int8_sq + _setup_preset( + root, + "general/ptq/int8_smoothquant-fp8_kv", + model_quant_bases=[ + "fragments/base", + "fragments/int8_quantizer", + "fragments/algo_smoothquant", + ], + model_quant_override={}, + kv_quant_bases=["fragments/fp8_kv"], + ) + + return root + + +# ── Unit tests: __base__ inheritance ── + + +class TestYamlBaseInheritance: + def test_base_resolution(self, tmp_path): + """Test __base__ inheritance with temp YAML files.""" + base_yml = tmp_path / "configs" / "base.yml" + base_yml.parent.mkdir(parents=True) + base_yml.write_text( + yaml.dump({"quant_cfg": {"default": {"enable": False}, "*lm_head*": {"enable": False}}}) + ) + quant_yml = tmp_path / "configs" / "fp8.yml" + quant_yml.write_text( + yaml.dump( + { + "quant_cfg": { + "*weight_quantizer": {"num_bits": [4, 3], "axis": None}, + "*input_quantizer": {"num_bits": [4, 3], "axis": None}, + } + } + ) + ) + algo_yml = tmp_path / "configs" / "algo_max.yml" + algo_yml.write_text(yaml.dump({"algorithm": "max"})) + + recipe_yml = tmp_path / "recipe.yml" + recipe_yml.write_text( + yaml.dump({"__base__": ["configs/base", "configs/fp8", "configs/algo_max"]}) + ) + + result = load_yaml_with_bases(recipe_yml, tmp_path) + assert result["algorithm"] == "max" + assert result["quant_cfg"]["default"] == {"enable": False} + assert result["quant_cfg"]["*weight_quantizer"]["num_bits"] == [4, 3] + + def test_base_override_order(self, tmp_path): + """Later bases override earlier bases.""" + _write_yaml(tmp_path / "a.yml", {"algorithm": "max", "x": 1}) + _write_yaml(tmp_path / "b.yml", {"algorithm": "awq_lite", "y": 2}) + _write_yaml(tmp_path / "c.yml", {"__base__": ["a", "b"]}) + result = load_yaml_with_bases(tmp_path / "c.yml", tmp_path) + assert result["algorithm"] == "awq_lite" + assert result["x"] == 1 + assert result["y"] == 2 + + def test_leaf_overrides_bases(self, tmp_path): + """Leaf file values override all bases.""" + _write_yaml(tmp_path / "base.yml", {"algorithm": "max", "extra": "keep"}) + _write_yaml( + tmp_path / "leaf.yml", + {"__base__": ["base"], "algorithm": "awq_lite"}, + ) + result = load_yaml_with_bases(tmp_path / "leaf.yml", tmp_path) + assert result["algorithm"] == "awq_lite" + assert result["extra"] == "keep" + + def test_deep_merge_preserves_nested(self, tmp_path): + """Deep merge combines nested quant_cfg entries.""" + _write_yaml( + tmp_path / "base.yml", + {"quant_cfg": {"*weight_quantizer": {"num_bits": 8}}}, + ) + _write_yaml( + tmp_path / "extra.yml", + {"quant_cfg": {"*weight_quantizer": {"axis": 0}}}, + ) + _write_yaml(tmp_path / "composed.yml", {"__base__": ["base", "extra"]}) + result = load_yaml_with_bases(tmp_path / "composed.yml", tmp_path) + wq = result["quant_cfg"]["*weight_quantizer"] + assert wq["num_bits"] == 8 + assert wq["axis"] == 0 + + def test_missing_base_raises(self, tmp_path): + recipe_yml = tmp_path / "recipe.yml" + recipe_yml.write_text(yaml.dump({"__base__": ["nonexistent"]})) + with pytest.raises(FileNotFoundError, match="nonexistent"): + load_yaml_with_bases(recipe_yml, tmp_path) + + +# ── E2E tests: fake modelopt_recipes/ → load_recipe_from_yaml ── + + +class TestYamlPresetE2E: + def test_fp8_loads_correctly(self, recipes_root): + """FP8 preset resolves __base__ chain and produces expected config.""" + config = load_recipe_from_yaml("general/ptq/fp8_default-fp8_kv", recipes_root) + assert config["algorithm"] == "max" + qcfg = config["quant_cfg"] + assert qcfg["default"] == {"enable": False} + assert qcfg["*lm_head*"] == {"enable": False} + assert qcfg["*weight_quantizer"]["num_bits"] == [4, 3] + assert qcfg["*input_quantizer"]["num_bits"] == [4, 3] + assert "*k_proj*input_quantizer" in qcfg + + def test_int8_loads_correctly(self, recipes_root): + """INT8 preset: integer num_bits, per-axis weights.""" + config = load_recipe_from_yaml("general/ptq/int8_default-fp8_kv", recipes_root) + assert config["algorithm"] == "max" + qcfg = config["quant_cfg"] + assert qcfg["*weight_quantizer"]["num_bits"] == 8 + assert qcfg["*weight_quantizer"]["axis"] == 0 + assert qcfg["*input_quantizer"]["num_bits"] == 8 + + def test_int4_awq_loads_correctly(self, recipes_root): + """INT4 AWQ: dict algorithm, weight-only, block_sizes.""" + config = load_recipe_from_yaml("general/ptq/int4_awq-fp8_kv", recipes_root) + assert config["algorithm"]["method"] == "awq_lite" + qcfg = config["quant_cfg"] + assert qcfg["*weight_quantizer"]["num_bits"] == 4 + assert qcfg["*input_quantizer"] == {"enable": False} + + def test_nvfp4_loads_correctly(self, recipes_root): + """NVFP4: list num_bits [2,1], block_sizes with scale_bits list.""" + config = load_recipe_from_yaml("general/ptq/nvfp4_default-fp8_kv", recipes_root) + assert config["algorithm"] == "max" + wq = config["quant_cfg"]["*weight_quantizer"] + assert wq["num_bits"] == [2, 1] + assert wq["block_sizes"]["scale_bits"] == [4, 3] + assert wq["block_sizes"]["type"] == "dynamic" + + def test_int8_sq_loads_correctly(self, recipes_root): + """INT8 SmoothQuant: same quantizers as INT8 but different algorithm.""" + config = load_recipe_from_yaml("general/ptq/int8_smoothquant-fp8_kv", recipes_root) + assert config["algorithm"] == "smoothquant" + assert config["quant_cfg"]["*weight_quantizer"]["num_bits"] == 8 + + def test_kv_cache_merging(self, recipes_root): + """KV cache entries are merged into model config.""" + config = load_recipe_from_yaml("general/ptq/fp8_default-fp8_kv", recipes_root) + qcfg = config["quant_cfg"] + assert "*k_proj*input_quantizer" in qcfg + assert "*v_proj*input_quantizer" in qcfg + + def test_yaml_preserves_lists(self, recipes_root): + """YAML-loaded configs preserve lists (no tuple conversion).""" + config = load_recipe_from_yaml("general/ptq/fp8_default-fp8_kv", recipes_root) + nb = config["quant_cfg"]["*weight_quantizer"]["num_bits"] + assert isinstance(nb, list), f"Expected list, got {type(nb)}" + + +# ── Preset map consistency ── + + +class TestPresetMapConsistency: + def test_yaml_map_covers_all_live_presets(self): + """Every live preset should have a corresponding YAML map entry.""" + live_presets = set(list_presets()) + yaml_presets = set(PRESET_YAML_MAP.keys()) + missing = live_presets - yaml_presets + assert not missing, f"Live presets missing from PRESET_YAML_MAP: {sorted(missing)}" + + def test_yaml_map_paths_follow_convention(self): + """All YAML map paths should follow general/ptq/-fp8_kv pattern.""" + for name, path in PRESET_YAML_MAP.items(): + assert path.startswith("general/ptq/"), f"'{name}' has non-standard path: {path}" + assert path.endswith("-fp8_kv"), f"'{name}' path doesn't end with -fp8_kv: {path}" + + def test_no_duplicate_paths(self): + """No two presets should map to the same directory (except known aliases).""" + counts = Counter(PRESET_YAML_MAP.values()) + allowed_aliases = {"general/ptq/nvfp4_awq_lite-fp8_kv": {"nvfp4_awq", "nvfp4_awq_lite"}} + for path, count in counts.items(): + if count > 1: + names = {n for n, p in PRESET_YAML_MAP.items() if p == path} + if path in allowed_aliases: + assert names == allowed_aliases[path], f"Unexpected aliases for {path}: {names}" + else: + pytest.fail(f"Unexpected duplicate path {path}: {names}") + + +# ── Public API ── + + +class TestYamlKvCacheConsistency: + """Verify that YAML-loaded presets produce the same structure as bundled ones.""" + + def test_yaml_preset_has_kv_patterns(self, recipes_root): + """YAML presets include KV patterns from kv_quant.yml.""" + config = load_recipe_from_yaml("general/ptq/fp8_default-fp8_kv", recipes_root) + qcfg = config["quant_cfg"] + # Should have KV patterns from kv_quant.yml + kv_keys = [k for k in qcfg if "k_proj" in k or "v_proj" in k] + assert len(kv_keys) > 0, "YAML preset should include KV patterns from kv_quant.yml" + + def test_yaml_preset_no_kv_without_kv_file(self, recipes_root): + """YAML preset without kv_quant.yml should NOT have KV patterns.""" + # Create a preset without kv_quant.yml + _setup_preset( + recipes_root, + "general/ptq/fp8_no_kv", + model_quant_bases=["fragments/base", "fragments/fp8_quantizer", "fragments/algo_max"], + model_quant_override={}, + ) + config = load_recipe_from_yaml("general/ptq/fp8_no_kv", recipes_root) + qcfg = config["quant_cfg"] + kv_keys = [k for k in qcfg if "k_proj" in k or "v_proj" in k] + assert len(kv_keys) == 0, "Preset without kv_quant.yml should not have KV patterns" + + +class TestPresetAPI: + def test_source_is_valid(self): + source = get_preset_source() + assert source in ("yaml", "live") + + def test_list_presets_nonempty(self): + presets = list_presets() + assert len(presets) > 0 + assert "fp8" in presets + + def test_get_preset_returns_deep_copy(self): + p1 = get_preset("fp8") + p2 = get_preset("fp8") + assert p1 == p2 + assert p1 is not p2 + + def test_unknown_preset_raises(self): + with pytest.raises(KeyError, match="Unknown preset"): + get_preset("nonexistent_preset_xyz") + + def test_get_preset_info_returns_dict(self): + info = get_preset_info("fp8") + assert isinstance(info, dict) + + def test_nvfp4_omlp_only_in_presets(self): + """New PR #1000 preset nvfp4_omlp_only should be in the map.""" + assert "nvfp4_omlp_only" in PRESET_YAML_MAP diff --git a/tests/unit/torch/recipes/schema/test_resolver.py b/tests/unit/torch/recipes/schema/test_resolver.py new file mode 100644 index 000000000..d6839f993 --- /dev/null +++ b/tests/unit/torch/recipes/schema/test_resolver.py @@ -0,0 +1,20 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for resolver (schema/resolver.py). + +Most resolver tests are now YAML fixture-driven (see test_recipe_fixtures.py). +This file keeps tests that require programmatic assertions not expressible in YAML. +""" diff --git a/tests/unit/torch/recipes/test_recipe_fixtures.py b/tests/unit/torch/recipes/test_recipe_fixtures.py new file mode 100644 index 000000000..d30164d62 --- /dev/null +++ b/tests/unit/torch/recipes/test_recipe_fixtures.py @@ -0,0 +1,216 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""YAML fixture-driven tests for the recipe system. + +Each YAML file in tests/unit/torch/recipes/fixtures/ is a complete test case: +the recipe itself plus a ``_test:`` metadata block with per-API assertions. + +Supported ``_test:`` sections: + + validate: # RecipeConfig.model_validate() + expect_success: true # validation should succeed + expect_error: "substring" # validation should fail with this message + + resolve: # resolve_recipe() + check_has_key: quantize_config # result must have this key + check_no_key: export # result must NOT have this key + check_algorithm: max # shorthand for result[quantize_config][algorithm] + check_quant_cfg_has_keys: [...] # keys present in quant_cfg + check_kv_patterns_present: true # bmm/kv quantizer patterns exist + check_format_count: N # number of quantization_formats + check_config: # nested dict value checks + quantize_config: + algorithm: max + check_block_sizes: # block_sizes specific checks + path: dotted.path.to.block_sizes + expected: {-1: 16} + + # Future: plan, execute sections added when those APIs land in PR +""" + +from __future__ import annotations + +from pathlib import Path + +import pytest +import yaml +from pydantic import ValidationError + +from modelopt.torch.recipes.schema.models import RecipeConfig +from modelopt.torch.recipes.schema.resolver import resolve_recipe + +# --------------------------------------------------------------------------- +# Discover YAML fixtures +# --------------------------------------------------------------------------- + +FIXTURES_DIR = Path(__file__).parent / "fixtures" + + +def _discover_fixtures(): + """Load all YAML files with _test metadata.""" + cases = [] + for path in sorted(FIXTURES_DIR.glob("*.yaml")): + with open(path) as f: + data = yaml.safe_load(f) + if data and data.get("_test"): + cases.append(pytest.param(data, id=path.stem)) + return cases + + +FIXTURE_CASES = _discover_fixtures() + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _get_nested(d: dict, dotted_path: str): + """Navigate a dict by dotted path (e.g., 'quantize_config.quant_cfg.*weight_quantizer').""" + for key in dotted_path.split("."): + d = d[key] + return d + + +def _assert_dict_subset(actual: dict, expected: dict, path: str = ""): + """Assert that expected is a subset of actual (recursive).""" + for key, exp_val in expected.items(): + full_path = f"{path}.{key}" if path else str(key) + assert key in actual, f"Missing key '{full_path}' in {list(actual.keys())}" + act_val = actual[key] + if isinstance(exp_val, dict) and isinstance(act_val, dict): + _assert_dict_subset(act_val, exp_val, full_path) + elif isinstance(exp_val, list): + assert act_val == exp_val, f"At '{full_path}': expected {exp_val}, got {act_val}" + else: + assert act_val == exp_val, f"At '{full_path}': expected {exp_val}, got {act_val}" + + +# --------------------------------------------------------------------------- +# Validate tests — RecipeConfig.model_validate() +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("fixture_data", FIXTURE_CASES) +def test_validate(fixture_data): + """Test schema validation for each YAML fixture.""" + test_meta = fixture_data["_test"] + validate_meta = test_meta.get("validate") + if validate_meta is None: + pytest.skip("No validate section in _test") + + recipe_dict = {k: v for k, v in fixture_data.items() if k != "_test"} + + expected_error = validate_meta.get("expect_error") + if expected_error: + with pytest.raises(ValidationError) as exc_info: + RecipeConfig.model_validate(recipe_dict) + assert expected_error in str(exc_info.value), ( + f"Expected '{expected_error}' in error, got: {exc_info.value}" + ) + return + + # expect_success (default) + recipe = RecipeConfig.model_validate(recipe_dict) + assert recipe is not None + + +# --------------------------------------------------------------------------- +# Resolve tests — resolve_recipe() +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("fixture_data", FIXTURE_CASES) +def test_resolve(fixture_data): + """Test recipe resolution for each YAML fixture.""" + test_meta = fixture_data["_test"] + resolve_meta = test_meta.get("resolve") + if resolve_meta is None: + pytest.skip("No resolve section in _test") + + # Validation errors don't reach resolver + validate_meta = test_meta.get("validate", {}) + if validate_meta.get("expect_error"): + pytest.skip("Validation error fixture — no resolve") + + recipe_dict = {k: v for k, v in fixture_data.items() if k != "_test"} + recipe = RecipeConfig.model_validate(recipe_dict) + + # resolve.expect_error — resolver should raise + expected_error = resolve_meta.get("expect_error") + if expected_error: + with pytest.raises(Exception) as exc_info: + resolve_recipe(recipe) + assert expected_error in str(exc_info.value), ( + f"Expected '{expected_error}' in error, got: {exc_info.value}" + ) + return + + result = resolve_recipe(recipe) + + # check_has_key + has_key = resolve_meta.get("check_has_key") + if has_key: + assert has_key in result, f"Expected key '{has_key}' in result, got {list(result.keys())}" + + # check_no_key + no_key = resolve_meta.get("check_no_key") + if no_key: + assert no_key not in result, f"Key '{no_key}' should not be in result" + + # check_algorithm + check_algo = resolve_meta.get("check_algorithm") + if check_algo: + assert result["quantize_config"]["algorithm"] == check_algo, ( + f"Expected algorithm '{check_algo}', got {result['quantize_config']['algorithm']}" + ) + + # check_quant_cfg_has_keys + cfg_keys = resolve_meta.get("check_quant_cfg_has_keys") + if cfg_keys: + qcfg = result["quantize_config"]["quant_cfg"] + for key in cfg_keys: + assert key in qcfg, f"Expected '{key}' in quant_cfg, got {list(qcfg.keys())}" + + # check_kv_patterns_present + if resolve_meta.get("check_kv_patterns_present"): + qcfg = result["quantize_config"]["quant_cfg"] + kv_keys = [k for k in qcfg if "bmm_quantizer" in k or "kv" in k.lower()] + assert len(kv_keys) > 0, f"Expected KV cache patterns, got none in {list(qcfg.keys())}" + + # check_format_count + fmt_count = resolve_meta.get("check_format_count") + if fmt_count is not None: + formats = result["auto_quantize_kwargs"]["quantization_formats"] + assert len(formats) == fmt_count, f"Expected {fmt_count} formats, got {len(formats)}" + + # check_config (nested dict subset) + check_config = resolve_meta.get("check_config") + if check_config: + _assert_dict_subset(result, check_config) + + # check_block_sizes (special path-based check) + check_bs = resolve_meta.get("check_block_sizes") + if check_bs: + bs = _get_nested(result, check_bs["path"]) + for key, expected_val in check_bs["expected"].items(): + # YAML loads integer keys as int, but block_sizes may have int or str keys + actual_val = bs.get(key) + if actual_val is None: + actual_val = bs.get(str(key)) + assert actual_val == expected_val, ( + f"block_sizes[{key}]: expected {expected_val}, got {actual_val}" + ) diff --git a/tests/unit/torch/recipes/test_utils.py b/tests/unit/torch/recipes/test_utils.py new file mode 100644 index 000000000..951782462 --- /dev/null +++ b/tests/unit/torch/recipes/test_utils.py @@ -0,0 +1,70 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for recipes/utils.py.""" + +from modelopt.torch.recipes.utils import deep_merge, make_serializable + + +def test_make_serializable_tuples_to_lists(): + """Tuples and nested structures are converted to lists.""" + result = make_serializable({"a": (1, 2), "b": {"c": (3,)}, "d": [4, 5]}) + assert result == {"a": [1, 2], "b": {"c": [3]}, "d": [4, 5]} + + +def test_make_serializable_int_keys_to_strings(): + """Dict keys (including ints) are stringified.""" + result = make_serializable({-1: 16, "type": "dynamic"}) + assert result == {"-1": 16, "type": "dynamic"} + + +def test_make_serializable_primitives_passthrough(): + """Primitive types pass through unchanged.""" + assert make_serializable(42) == 42 + assert make_serializable("hello") == "hello" + assert make_serializable(True) is True + assert make_serializable(None) is None + + +def test_make_serializable_non_json_types(): + """Non-JSON types are converted to strings.""" + result = make_serializable({"key": {1, 2, 3}}) + assert isinstance(result["key"], str) + + +def test_deep_merge_simple(): + result = deep_merge({"a": 1, "b": 2}, {"b": 3, "c": 4}) + assert result == {"a": 1, "b": 3, "c": 4} + + +def test_deep_merge_nested(): + base = {"quant_cfg": {"default": {"enable": False}, "*weight*": {"num_bits": 8}}} + override = {"quant_cfg": {"*weight*": {"axis": 0}}} + result = deep_merge(base, override) + assert result["quant_cfg"]["default"] == {"enable": False} + assert result["quant_cfg"]["*weight*"] == {"num_bits": 8, "axis": 0} + + +def test_deep_merge_replaces_non_dict(): + result = deep_merge({"algorithm": "max"}, {"algorithm": "awq_lite"}) + assert result["algorithm"] == "awq_lite" + + +def test_deep_merge_no_mutation(): + base = {"a": {"b": 1}} + override = {"a": {"c": 2}} + result = deep_merge(base, override) + assert "c" not in base["a"] + assert result["a"] == {"b": 1, "c": 2}