diff --git a/docs/source/guides/1_quantization.rst b/docs/source/guides/1_quantization.rst index a838bfb106..ae0da29c27 100644 --- a/docs/source/guides/1_quantization.rst +++ b/docs/source/guides/1_quantization.rst @@ -19,6 +19,8 @@ Below, you can find the documentation for the quantization toolkit in ModelOpt: ./_basic_quantization.rst ./_choosing_quant_methods.rst ./_pytorch_quantization.rst + ./_quant_cfg.rst + ./_recipes.rst ./_customized_model_quantization.rst ./_compress_quantized_models.rst ./_onnx_quantization.rst diff --git a/docs/source/guides/_pytorch_quantization.rst b/docs/source/guides/_pytorch_quantization.rst index 15a7da9f16..3121f51d9b 100644 --- a/docs/source/guides/_pytorch_quantization.rst +++ b/docs/source/guides/_pytorch_quantization.rst @@ -237,14 +237,16 @@ For debugging purposes or simple customizations, you can modify an existing conf .. code-block:: python - # Create a copy of the default INT8 configuration - config = mtq.INT8_DEFAULT_CFG.copy() + import copy - # Disable input quantizers for all layers - config["quant_cfg"]["*input_quantizer"]["enable"] = False + # Create a deep copy of the default INT8 configuration + config = copy.deepcopy(mtq.INT8_DEFAULT_CFG) + + # Disable input quantizers for all layers (appended last, so it takes precedence) + config["quant_cfg"].append({"quantizer_path": "*input_quantizer", "enable": False}) # Disable all quantizers for layers matching the pattern "layer1.*" - config["quant_cfg"]["*layer1.*"] = {"enable": False} + config["quant_cfg"].append({"quantizer_path": "*layer1.*", "enable": False}) Advanced Configuration Creation ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -253,18 +255,23 @@ For exploring new quantization recipes, you can compose a completely new configu .. 
code-block:: python + from modelopt.torch.quantization.config import _default_disabled_quantizer_cfg + # Custom configuration for INT4 block-wise weights and INT8 dynamic activations MY_CUSTOM_CONFIG = { - "quant_cfg": { + "quant_cfg": [ + # Disable all quantizers by default, then enable selectively + {"quantizer_path": "*", "enable": False}, + # Configure weight quantizers with 4-bit precision and 128-element blocks - "*weight_quantizer": {"num_bits": 4, "block_sizes": {-1: 128}, "enable": True}, + {"quantizer_path": "*weight_quantizer", "cfg": {"num_bits": 4, "block_sizes": {-1: 128}}, "enable": True}, # Configure input quantizers with 8-bit dynamic quantization - "*input_quantizer": {"num_bits": 8, "type": "dynamic", "block_sizes": {-1: None}}, + {"quantizer_path": "*input_quantizer", "cfg": {"num_bits": 8, "type": "dynamic", "block_sizes": {-1: None}}}, # Include default disabled quantizer configurations - **_default_disabled_quantizer_cfg, - }, + *_default_disabled_quantizer_cfg, + ], "algorithm": "max", } @@ -394,8 +401,10 @@ You can specify ``custom_calib`` as ``algorithm`` in ``quant_cfg`` to use it. He # create quantization configuration with "custom_calib" method quant_cfg = { - 'quant_cfg': {'*weight_quantizer': ..}, - 'algorithm': {"method": 'custom_calib'}, + 'quant_cfg': [ + {"quantizer_path": "*weight_quantizer", "cfg": {...}}, + ], + 'algorithm': {"method": 'custom_calib'}, } diff --git a/docs/source/guides/_quant_cfg.rst b/docs/source/guides/_quant_cfg.rst new file mode 100644 index 0000000000..0b5d9cf771 --- /dev/null +++ b/docs/source/guides/_quant_cfg.rst @@ -0,0 +1,393 @@ +.. _quant-cfg: + +====================================== +Quantization Configuration (quant_cfg) +====================================== + +The ``quant_cfg`` field is the primary mechanism for controlling which quantizers are active in a +model and how they are configured. 
This guide explains the format, ordering semantics, and common +patterns for composing quantization configurations. + +.. tip:: + + For the list of built-in configs and supported formats, see :any:`quantization-formats`. + For how to apply a config to a model, see :any:`_pytorch_quantization`. + +---------- + +Overview +======== + +A quantization config is a Python dictionary with two top-level keys: + +.. code-block:: python + + config = { + "quant_cfg": [...], # ordered list of QuantizerCfgEntry dicts + "algorithm": "max", # calibration algorithm + } + +The ``quant_cfg`` value is an **ordered list** of :class:`QuantizerCfgEntry +` dicts. Each entry targets a set of +quantizer modules in the model and specifies their configuration. + +---------- + +Entry Format +============ + +Each entry in the list is a dictionary with the following fields: + +.. list-table:: + :header-rows: 1 + :widths: 20 15 65 + + * - Field + - Required + - Description + * - ``quantizer_path`` + - Yes + - Wildcard string matched against quantizer module names (e.g. ``"*weight_quantizer"``). + Uses :func:`fnmatch` rules. + * - ``parent_class`` + - No + - Restricts matching to quantizers whose immediate parent module is of this PyTorch class + (e.g. ``"nn.Linear"``). If omitted, all modules are targeted regardless of class. + * - ``cfg`` + - No + - A dict of quantizer attributes as defined by :class:`QuantizerAttributeConfig + `, or a list of such dicts + for sequential quantization (see :ref:`sequential-quantizers`). + * - ``enable`` + - No + - ``True`` or ``False``. Toggles matched quantizers on or off, independently of ``cfg``. + When ``cfg`` is absent, **only** the enabled/disabled state is changed — all other + attributes remain untouched. When ``cfg`` is present, ``enable`` sets the enabled state + of the newly-configured quantizer. When ``cfg`` is present and ``enable`` is omitted, + the quantizer is implicitly enabled (``True``). + +.. 
note:: + + Every entry must specify at least one of ``cfg`` or ``enable`` in addition to + ``quantizer_path``. An entry with only ``quantizer_path`` and no other keys is **invalid** + and will raise a ``ValueError`` at config-processing time. This prevents subtle bugs where + a bare ``{"quantizer_path": "*"}`` would silently behave as ``enable=True`` for all + quantizers. + +---------- + +Default Quantizer Configuration +================================ + +When a quantizer is enabled but has never been touched by a ``cfg`` entry — either because no +entry in the list matched it, or because it was only reached by enable-only entries — it operates +with the default attributes of +:class:`QuantizerAttributeConfig `: + +.. code-block:: python + + { + "num_bits": 8, # 8-bit integer quantization + "axis": None, # per-tensor scale (no per-channel axis) + "fake_quant": True, # simulate quantization in forward pass (PTQ / QAT) + "unsigned": False, # signed integer range, e.g. [-128, 127] for INT8 + "narrow_range": False, # full range; True would restrict to [-127, 127] for INT8 + "type": "static", # static calibration (not dynamic per-inference) + "block_sizes": None, # no block quantization; set for NF4 / MXFP formats + "bias": None, # no affine bias correction + "calibrator": "max", # use max-abs calibration to determine amax + "rotate": False, # no Hadamard rotation (QuaRot / SpinQuant) + "pass_through_bwd": True, # straight-through estimator for QAT gradients + "trt_high_precision_dtype": "Float", # cast QDQ nodes to fp32 for TRT StronglyType export + "backend": None, # use the built-in quantization backend + "backend_extra_args": None, # no extra args for custom backends + "use_constant_amax": False, # calibrate amax; True hard-codes FP8 E4M3 max (448.0) + } + +In practice this means an un-configured but enabled quantizer performs **INT8 per-tensor static +fake-quantization** with a max-calibrated scale. 
This is rarely the intended behavior — every +quantizer you want active should be explicitly configured with a ``cfg`` entry. + +---------- + +Ordering and Precedence +======================= + +Entries are applied **in list order**. Later entries override earlier ones for any quantizer they +match. This gives a clear, composable precedence model: + +- Put broad rules (e.g. deny-all) **first**. +- Put format-specific enable rules **after**. +- Put fine-grained exclusions (specific layers, classes) **last**. + +The recommended pattern used by all built-in configs is: + +.. code-block:: python + + "quant_cfg": [ + # 1. Deny all quantizers by default + {"quantizer_path": "*", "enable": False}, + + # 2. Enable and configure the target quantizers + {"quantizer_path": "*weight_quantizer", "cfg": {"num_bits": 8, "axis": 0}}, + {"quantizer_path": "*input_quantizer", "cfg": {"num_bits": 8, "axis": None}}, + + # 3. Apply standard exclusions last (BatchNorm, LM head, MoE routers, etc.) + *mtq.config._default_disabled_quantizer_cfg, + ] + +.. note:: + + The deny-all entry ``{"quantizer_path": "*", "enable": False}`` is available as + :data:`modelopt.torch.quantization.config._base_disable_all` and is prepended to every + built-in config. This ensures quantizers not explicitly targeted remain disabled. + +---------- + +Entry Atomicity +=============== + +Each ``cfg``-bearing entry in ``quant_cfg`` is a **complete, self-contained configuration unit**. +When an entry with ``cfg`` matches a quantizer, it **completely replaces** that quantizer's +configuration — it does not merge with or incrementally update settings left by earlier entries. + +Concretely, if an entry specifies only a subset of quantizer attributes (e.g. only ``num_bits``), +all unspecified attributes are filled in with their default values from +:class:`QuantizerAttributeConfig `. +The resulting *complete* config is then written to the quantizer, discarding whatever any prior +matching entry had set. 
+ +This means: + +- **Last cfg-entry wins, fully.** If two entries both match ``*weight_quantizer`` and both carry + a ``cfg``, the second entry does not inherit the first entry's settings — it replaces them entirely. +- **No hidden state accumulation.** The final configuration of a quantizer depends only on the + *last* ``cfg``-bearing entry in the list that matched it, making behavior easy to reason about. +- **Changing one field requires a full spec.** Because each ``cfg`` entry is a complete replacement, + to change only one attribute of a quantizer that was already configured, you must reproduce the + full desired config in the new entry. Any attribute omitted from the entry will revert to its + default, not to the value set by an earlier entry. + +**Enable-only entries are the exception.** An entry with no ``cfg`` (only ``enable``) is *not* a +full replacement — it solely flips the on/off state of matched quantizers, leaving all other +attributes unchanged: + +- ``{"quantizer_path": "*", "enable": False}`` disables all quantizers without touching their + configured attributes. Use this as the first step in a deny-all-then-configure pattern. +- ``{"quantizer_path": "*weight_quantizer", "enable": True}`` (no ``cfg``) re-enables weight + quantizers using whatever attributes they currently carry (or their defaults if they were never + configured by a ``cfg`` entry). + +For example, given the following two entries both matching ``*weight_quantizer``: + +.. code-block:: python + + # Entry 1 — sets FP8 per-channel + {"quantizer_path": "*weight_quantizer", "cfg": {"num_bits": (4, 3), "axis": 0}}, + + # Entry 2 — sets INT4 blockwise (axis is NOT inherited from Entry 1) + {"quantizer_path": "*weight_quantizer", "cfg": {"num_bits": 4, "block_sizes": {-1: 128}}}, + +After Entry 2 is applied, the quantizer has ``num_bits=4``, ``block_sizes={-1: 128}``, and +``axis=None`` (the default). The ``axis=0`` set by Entry 1 is gone. + +.. 
note:: + + The deny-all-then-configure pattern is safe and predictable precisely because + ``{"quantizer_path": "*", "enable": False}`` **only** disables quantizers without resetting + their attributes. Subsequent ``cfg`` entries then configure targets from a known default state. + +---------- + +Common Patterns +=============== + +Skipping Specific Layers +------------------------ + +Append a disable entry after the existing config to exclude layers matched by a path pattern. +Because it is appended last, it takes precedence over all earlier entries: + +.. code-block:: python + + import copy + import modelopt.torch.quantization as mtq + + config = copy.deepcopy(mtq.FP8_DEFAULT_CFG) + + # Skip the final projection layer + config["quant_cfg"].append({"quantizer_path": "*lm_head*", "enable": False}) + + model = mtq.quantize(model, config, forward_loop) + +Skipping Layers by Module Class +-------------------------------- + +Use ``parent_class`` to target quantizers only within a specific type of layer, leaving the +same quantizer path in other layer types unaffected: + +.. code-block:: python + + config["quant_cfg"].append({ + "quantizer_path": "*input_quantizer", + "parent_class": "nn.LayerNorm", + "enable": False, + }) + +Overriding Quantizer Precision for Specific Layers +--------------------------------------------------- + +A later entry with a matching ``quantizer_path`` replaces the configuration set by an earlier +entry. This allows per-layer precision overrides without restructuring the entire config: + +.. code-block:: python + + config = copy.deepcopy(mtq.FP8_DEFAULT_CFG) + + # Quantize attention output projections in higher-precision INT8 instead of FP8 + config["quant_cfg"].append({ + "quantizer_path": "*o_proj*weight_quantizer", + "cfg": {"num_bits": 8, "axis": 0}, + }) + +Building a Config from Scratch +------------------------------- + +For entirely custom recipes, compose the list directly: + +.. 
code-block:: python + + from modelopt.torch.quantization.config import _base_disable_all, _default_disabled_quantizer_cfg + + MY_CUSTOM_CFG = { + "quant_cfg": [ + *_base_disable_all, + {"quantizer_path": "*weight_quantizer", "cfg": {"num_bits": 4, "block_sizes": {-1: 128}}}, + {"quantizer_path": "*input_quantizer", "cfg": {"num_bits": 8, "axis": None}}, + *_default_disabled_quantizer_cfg, + ], + "algorithm": "max", + } + + model = mtq.quantize(model, MY_CUSTOM_CFG, forward_loop) + +---------- + +.. _sequential-quantizers: + +Sequential Quantization +======================= + +When ``cfg`` is a **list** of attribute dicts, the matched +:class:`TensorQuantizer ` +is replaced with a +:class:`SequentialQuantizer ` +that applies each format in sequence. This is used, for example, in W4A8 quantization where weights +are quantized first in INT4 and then in FP8: + +.. code-block:: python + + { + "quantizer_path": "*weight_quantizer", + "cfg": [ + {"num_bits": 4, "block_sizes": {-1: 128, "type": "static"}}, + {"num_bits": (4, 3)}, # FP8 + ], + } + +---------- + +.. _migrating-from-dict-format: + +Migrating from Dict Format +=========================== + +Earlier versions of ModelOpt used a flat dictionary for ``quant_cfg``. The new list format is +preferred because it provides explicit ordering and unambiguous precedence. Existing dict-based +configs continue to work — the normalization layer converts them automatically — but new code +should use the list format. + +The table below shows common patterns and their list equivalents: + +.. list-table:: + :header-rows: 1 + :widths: 50 50 + + * - Legacy dict format + - New list format + * - .. code-block:: python + + "quant_cfg": { + "*weight_quantizer": { + "num_bits": 8, + "axis": 0, + }, + "*input_quantizer": { + "num_bits": 8, + "axis": None, + }, + "default": {"enable": False}, + } + + - .. 
code-block:: python + + "quant_cfg": [ + {"quantizer_path": "*", + "enable": False}, + {"quantizer_path": "*weight_quantizer", + "cfg": {"num_bits": 8, "axis": 0}}, + {"quantizer_path": "*input_quantizer", + "cfg": {"num_bits": 8, "axis": None}}, + ] + + * - .. code-block:: python + + # Disable by key assignment + config["quant_cfg"]["*lm_head*"] = { + "enable": False, + } + + - .. code-block:: python + + # Append to the end (last entry wins) + config["quant_cfg"].append( + {"quantizer_path": "*lm_head*", + "enable": False} + ) + + * - .. code-block:: python + + # Class-scoped entry + "quant_cfg": { + "nn.Linear": { + "*input_quantizer": { + "enable": False, + }, + }, + } + + - .. code-block:: python + + "quant_cfg": [ + {"quantizer_path": "*input_quantizer", + "parent_class": "nn.Linear", + "enable": False}, + ] + +Key differences to keep in mind: + +- The ``"default"`` key becomes ``{"quantizer_path": "*", "enable": False}`` placed at the + **start** of the list (deny-all-then-configure pattern). +- Dict key assignment (``config["quant_cfg"]["*lm_head*"] = ...``) becomes ``list.append()``. + Because later entries override earlier ones, appending achieves the same override effect. +- ``nn.*``-scoped dict keys become entries with a ``parent_class`` field. + +---------- + +Reference +========= + +- :class:`QuantizerCfgEntry ` +- :class:`QuantizerAttributeConfig ` +- :class:`QuantizeConfig ` +- :func:`set_quantizer_by_cfg ` diff --git a/docs/source/guides/_recipes.rst b/docs/source/guides/_recipes.rst new file mode 100644 index 0000000000..246d70b20c --- /dev/null +++ b/docs/source/guides/_recipes.rst @@ -0,0 +1,423 @@ +.. _recipes: + +Recipes +####### + +A **recipe** is a declarative YAML specification that fully describes how to optimize a model. +Recipes decouple optimization settings from Python code, enabling reuse, sharing, version +control, and reproducibility. 
Instead of editing Python scripts to change quantization +parameters, you author (or select) a recipe file and pass it to the ModelOpt tooling. + +.. contents:: On this page + :local: + :depth: 2 + + +Motivation +========== + +Without recipes, optimization settings are scattered across command-line arguments, Python +constants, and ad-hoc code edits. This makes it difficult to: + +* **Reproduce** a published result -- the exact configuration is buried in script arguments. +* **Share** a configuration -- there is no single artifact to hand off. +* **Version-control** changes -- diffs are mixed in with unrelated code changes. +* **Onboard new models** -- inference engineers must read source code to discover which + settings to tweak. + +Recipes solve these problems by capturing **all** the configuration needed to optimize a +model in a single YAML file (or a small directory of files). + + +Design overview +=============== + +The recipe system is part of the :mod:`modelopt.recipe` package and consists of three +layers: + +1. **Recipe files** -- YAML documents stored in the ``modelopt_recipes/`` directory (shipped + with the package) or on the user's filesystem. +2. **Config loader** -- :func:`~modelopt.recipe.load_config` reads YAML files, resolves + paths, and performs automatic ``ExMy`` floating-point notation conversion. +3. **Recipe loader** -- :func:`~modelopt.recipe.load_recipe` validates the YAML against + Pydantic models and returns a typed recipe object ready for use. + + +Recipe file format +================== + +A recipe is a YAML file with two top-level sections: ``metadata`` and a +type-specific configuration section (currently ``ptq_cfg`` for PTQ recipes). + +Single-file format +------------------ + +The simplest form is a single ``.yml`` or ``.yaml`` file: + +.. 
code-block:: yaml + + # modelopt_recipes/general/ptq/fp8_default-fp8_kv.yml + + metadata: + recipe_type: ptq + description: FP8 per-tensor weight and activation (W8A8), FP8 KV cache, max calibration. + + ptq_cfg: + algorithm: max + quant_cfg: + - quantizer_path: '*' + enable: false + - quantizer_path: '*input_quantizer' + cfg: + num_bits: e4m3 + axis: + - quantizer_path: '*weight_quantizer' + cfg: + num_bits: e4m3 + axis: + - quantizer_path: '*[kv]_bmm_quantizer' + enable: true + cfg: + num_bits: e4m3 + # ... standard exclusions omitted for brevity + +Directory format +---------------- + +For larger recipes or when you want to keep metadata separate from the +quantization configuration, use a directory with two files: + +.. code-block:: text + + my_recipe/ + recipe.yml # metadata section + ptq_cfg.yml # ptq_cfg section (quant_cfg + algorithm) + +``recipe.yml``: + +.. code-block:: yaml + + metadata: + recipe_type: ptq + description: My custom NVFP4 recipe. + +``ptq_cfg.yml``: + +.. code-block:: yaml + + algorithm: max + quant_cfg: + - quantizer_path: '*' + enable: false + - quantizer_path: '*weight_quantizer' + cfg: + num_bits: e2m1 + block_sizes: {-1: 16, type: dynamic, scale_bits: e4m3} + - quantizer_path: '*input_quantizer' + cfg: + num_bits: e4m3 + axis: + + +Metadata section +================ + +Every recipe file must contain a ``metadata`` mapping with at least a ``recipe_type`` field: + +.. list-table:: + :header-rows: 1 + :widths: 20 15 65 + + * - Field + - Required + - Description + * - ``recipe_type`` + - Yes + - The optimization category. Currently only ``"ptq"`` is supported. + * - ``description`` + - No + - A human-readable summary of what the recipe does. + + +PTQ configuration section +========================= + +For PTQ recipes (``recipe_type: ptq``), the ``ptq_cfg`` mapping contains: + +.. 
list-table:: + :header-rows: 1 + :widths: 20 15 65 + + * - Field + - Required + - Description + * - ``quant_cfg`` + - Yes + - An ordered list of :class:`~modelopt.torch.quantization.config.QuantizerCfgEntry` + dicts. See :ref:`quant-cfg` for the full specification of entries, ordering + semantics, and atomicity rules. + * - ``algorithm`` + - No + - The calibration algorithm: ``"max"`` (default), ``"mse"``, ``"smoothquant"``, + ``"awq_lite"``, ``"awq_full"``, ``"awq_clip"``, ``"gptq"``, or ``null`` for + formats that need no calibration (e.g. MX formats). + + +ExMy floating-point notation +============================= + +Recipe files support a convenient shorthand for floating-point bit formats in +``num_bits`` and ``scale_bits`` fields. Instead of writing a Python tuple, you +write the format name directly: + +.. code-block:: yaml + + num_bits: e4m3 # automatically converted to (4, 3) + scale_bits: e8m0 # automatically converted to (8, 0) + +The notation is case-insensitive (``E4M3``, ``e4m3``, ``E4m3`` all work). The +conversion is performed by :func:`~modelopt.recipe.load_config` when loading any +YAML file, so it works in both recipe files and standalone config files. + +Common formats: + +.. list-table:: + :header-rows: 1 + :widths: 15 15 70 + + * - Notation + - Tuple + - Description + * - ``e4m3`` + - ``(4, 3)`` + - FP8 E4M3 -- standard FP8 weight/activation format + * - ``e5m2`` + - ``(5, 2)`` + - FP8 E5M2 -- wider dynamic range, used for gradients + * - ``e2m1`` + - ``(2, 1)`` + - FP4 E2M1 -- NVFP4 weight format + * - ``e8m0`` + - ``(8, 0)`` + - E8M0 -- MX block scaling format + + +Built-in recipes +================ + +ModelOpt ships a library of built-in recipes under the ``modelopt_recipes/`` package. +These are bundled with the Python distribution and can be referenced by their relative +path (without the ``modelopt_recipes/`` prefix). 
General PTQ recipes ------------------- General recipes are model-agnostic and apply to any supported architecture: .. list-table:: :header-rows: 1 :widths: 40 60 * - Recipe path - Description * - ``general/ptq/fp8_default-fp8_kv`` - FP8 per-tensor W8A8, FP8 KV cache, max calibration * - ``general/ptq/nvfp4_default-fp8_kv`` - NVFP4 W4A4 with FP8 KV cache, max calibration * - ``general/ptq/nvfp4_mlp_only-fp8_kv`` - NVFP4 for MLP layers only, FP8 KV cache * - ``general/ptq/nvfp4_experts_only-fp8_kv`` - NVFP4 for MoE expert layers only, FP8 KV cache * - ``general/ptq/nvfp4_omlp_only-fp8_kv`` - NVFP4 for output projection + MLP layers, FP8 KV cache Model-specific recipes ---------------------- Model-specific recipes are tuned for a particular architecture and live under ``models/<model_name>/``: .. list-table:: :header-rows: 1 :widths: 40 60 * - Recipe path - Description * - ``models/Step3.5-Flash/nvfp4-mlp-only`` - NVFP4 MLP-only for Step 3.5 Flash MoE model Loading recipes =============== Python API ---------- Use :func:`~modelopt.recipe.load_recipe` to load a recipe. The path is resolved against the built-in library first, then the filesystem: .. code-block:: python from modelopt.recipe import load_recipe, ModelOptPTQRecipe # Load a built-in recipe by relative path (suffix optional) recipe = load_recipe("general/ptq/fp8_default-fp8_kv") assert isinstance(recipe, ModelOptPTQRecipe) # The ptq_cfg dict can be passed directly to mtq.quantize() import modelopt.torch.quantization as mtq model = mtq.quantize(model, recipe.ptq_cfg, forward_loop) .. code-block:: python # Load a custom recipe from the filesystem recipe = load_recipe("/path/to/my_custom_recipe.yml") model = mtq.quantize(model, recipe.ptq_cfg, forward_loop) Command-line usage ------------------ The ``hf_ptq.py`` example accepts a ``--recipe`` flag: ..
code-block:: bash + + python examples/llm_ptq/hf_ptq.py \ + --model Qwen/Qwen3-8B \ + --recipe general/ptq/fp8_default-fp8_kv \ + --export_path build/fp8 \ + --calib_size 512 \ + --export_fmt hf + +When ``--recipe`` is provided, the script loads the recipe and uses its ``ptq_cfg`` +directly, bypassing the ``--qformat`` / ``--kv_cache_qformat`` flags. + + +Loading standalone configs +-------------------------- + +:func:`~modelopt.recipe.load_config` loads arbitrary YAML config files with +automatic ``ExMy`` conversion and built-in path resolution. This is useful +for loading shared configuration fragments: + +.. code-block:: python + + from modelopt.recipe import load_config + + cfg = load_config("configs/some_shared_config") + + +Path resolution +=============== + +Both :func:`~modelopt.recipe.load_recipe` and :func:`~modelopt.recipe.load_config` +resolve paths using the same strategy: + +1. If the path is absolute, use it directly. +2. If relative, check the **built-in recipes library** first + (``modelopt_recipes/``), probing ``.yml`` and ``.yaml`` suffixes. +3. Then check the **filesystem**, probing the same suffixes. + +This means built-in recipes can be referenced without any prefix: + +.. code-block:: python + + # These are all equivalent: + load_recipe("general/ptq/fp8_default-fp8_kv") + load_recipe("general/ptq/fp8_default-fp8_kv.yml") + + +Writing a custom recipe +======================= + +To create a custom recipe: + +1. Start from an existing recipe that is close to your target configuration. +2. Copy it and modify the ``quant_cfg`` entries as needed (see :ref:`quant-cfg` + for entry format details). +3. Update the ``metadata.description`` to describe your changes. +4. Save the file and pass its path to ``load_recipe()`` or ``--recipe``. + +Example -- creating an INT8 per-channel recipe: + +.. code-block:: yaml + + # my_int8_recipe.yml + metadata: + recipe_type: ptq + description: INT8 per-channel weight, per-tensor activation. 
+ + ptq_cfg: + algorithm: max + quant_cfg: + - quantizer_path: '*' + enable: false + - quantizer_path: '*weight_quantizer' + cfg: + num_bits: 8 + axis: 0 + - quantizer_path: '*input_quantizer' + cfg: + num_bits: 8 + axis: + - quantizer_path: '*lm_head*' + enable: false + - quantizer_path: '*output_layer*' + enable: false + + +Recipe repository layout +======================== + +The ``modelopt_recipes/`` package is organized as follows: + +.. code-block:: text + + modelopt_recipes/ + +-- __init__.py + +-- general/ # Model-agnostic recipes + | +-- ptq/ + | +-- fp8_default-fp8_kv.yml + | +-- nvfp4_default-fp8_kv.yml + | +-- nvfp4_mlp_only-fp8_kv.yml + | +-- nvfp4_experts_only-fp8_kv.yml + | +-- nvfp4_omlp_only-fp8_kv.yml + +-- models/ # Model-specific recipes + | +-- Step3.5-Flash/ + | +-- nvfp4-mlp-only.yaml + +-- configs/ # Shared configuration fragments + + +Recipe data model +================= + +Recipes are validated at load time using Pydantic models: + +:class:`~modelopt.recipe.config.ModelOptRecipeBase` + Base class for all recipe types. Contains ``recipe_type`` and ``description``. + +:class:`~modelopt.recipe.config.ModelOptPTQRecipe` + PTQ-specific recipe. Adds the ``ptq_cfg`` field (a dict with ``quant_cfg`` and + ``algorithm``). + +:class:`~modelopt.recipe.config.RecipeType` + Enum of supported recipe types. Currently only ``PTQ``. + + +Future directions +================= + +The recipe system is designed to grow: + +* **QAT recipes** -- ``recipe_type: qat`` with training hyperparameters, distillation + settings, and dataset configuration. +* **Sparsity recipes** -- structured and unstructured pruning configurations. +* **Speculative decoding recipes** -- draft model and vocabulary calibration settings. +* **Composite recipes** -- chaining multiple optimization stages + (e.g., quantize then prune) in a single recipe. +* **Dataset configuration** -- standardized ``dataset`` section for calibration data + specification. 
+* **Recipe merging and override utilities** -- programmatic tools to compose and + customize recipes. +* **Unified entry point** -- a ``nv-modelopt`` CLI that accepts ``--recipe`` as the + primary configuration mechanism, replacing per-example scripts. diff --git a/examples/deepseek/ptq.py b/examples/deepseek/ptq.py index bcfd9de409..b4ae18711c 100644 --- a/examples/deepseek/ptq.py +++ b/examples/deepseek/ptq.py @@ -306,41 +306,80 @@ def calibrate_loop(model): dist.barrier() ## quant config - mtq_cfg = getattr(mtq, quant_cfg) + import copy + + mtq_cfg = copy.deepcopy(getattr(mtq, quant_cfg)) # disable head that corresponds to lm_head (for the huggingface checkpoint) - mtq_cfg["quant_cfg"]["*head*"] = {"enable": False} + mtq_cfg["quant_cfg"].append({"quantizer_path": "*head*", "enable": False}) allowed_mla_quant = [None, "per_tensor_fp8", "nvfp4"] assert mla_quant in allowed_mla_quant, f"mla_quant must be {allowed_mla_quant}" if not mla_quant: - mtq_cfg["quant_cfg"]["*attn*"] = {"enable": False} + mtq_cfg["quant_cfg"].append({"quantizer_path": "*attn*", "enable": False}) elif mla_quant == "per_tensor_fp8": - mtq_cfg["quant_cfg"]["*attn*weight_quantizer"] = {"num_bits": (4, 3), "axis": None} - mtq_cfg["quant_cfg"]["*attn*input_quantizer"] = {"num_bits": (4, 3), "axis": None} + mtq_cfg["quant_cfg"].extend( + [ + { + "quantizer_path": "*attn*weight_quantizer", + "cfg": {"num_bits": (4, 3), "axis": None}, + }, + { + "quantizer_path": "*attn*input_quantizer", + "cfg": {"num_bits": (4, 3), "axis": None}, + }, + ] + ) elif mla_quant == "nvfp4": # for DeepSeek-R1-0528-NVFP4-Turbo mla_linear_layers = ["*wq_a*", "*wq_b*", "*wkv_a*", "*wkv_b*", "*wo*"] mla_nvfp4_linear_layers = ["*wq_a*", "*wkv_a*", "*wq_b*", "*wo*"] for layer in mla_linear_layers: if layer in mla_nvfp4_linear_layers: # wq_a, wkv_a, wq_b, wo use NVFP4 quantization - mtq_cfg["quant_cfg"][layer + "_quantizer"] = { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, - 
"axis": None, - "enable": True, - } + mtq_cfg["quant_cfg"].append( + { + "quantizer_path": layer + "_quantizer", + "cfg": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, + "axis": None, + }, + "enable": True, + } + ) else: - mtq_cfg["quant_cfg"][layer + "_quantizer"] = {"enable": False} + mtq_cfg["quant_cfg"].append( + {"quantizer_path": layer + "_quantizer", "enable": False} + ) # Disable BMM quantizers - mtq_cfg["quant_cfg"]["*attn.kv_bmm_quantizer*"] = {"enable": False} - mtq_cfg["quant_cfg"]["*attn.pe_bmm_quantizer*"] = {"enable": False} + mtq_cfg["quant_cfg"].extend( + [ + {"quantizer_path": "*attn.kv_bmm_quantizer*", "enable": False}, + {"quantizer_path": "*attn.pe_bmm_quantizer*", "enable": False}, + ] + ) if not args.disable_wo_quant and "FP4" in quant_cfg: - mtq_cfg["quant_cfg"]["*wo*weight_quantizer"] = mtq_cfg["quant_cfg"]["*input_quantizer"] - mtq_cfg["quant_cfg"]["*wo*input_quantizer"] = mtq_cfg["quant_cfg"]["*weight_quantizer"] + # Find the default input/weight quantizer cfgs to swap for wo layers. + # cfg may be a list (SequentialQuantizer); use the first element in that case. 
+ input_cfg = mtq.find_quant_cfg_entry_by_path(mtq_cfg["quant_cfg"], "*input_quantizer")[ + "cfg" + ] + weight_cfg = mtq.find_quant_cfg_entry_by_path(mtq_cfg["quant_cfg"], "*weight_quantizer")[ + "cfg" + ] + if isinstance(input_cfg, list): + input_cfg = input_cfg[0] + if isinstance(weight_cfg, list): + weight_cfg = weight_cfg[0] + mtq_cfg["quant_cfg"].extend( + [ + {"quantizer_path": "*wo*weight_quantizer", "cfg": input_cfg}, + {"quantizer_path": "*wo*input_quantizer", "cfg": weight_cfg}, + ] + ) ## ptq transformer = mtq.quantize(transformer, mtq_cfg, calibrate_loop) diff --git a/examples/diffusers/quantization/config.py b/examples/diffusers/quantization/config.py index 94063ffd9c..9f24ec15f8 100644 --- a/examples/diffusers/quantization/config.py +++ b/examples/diffusers/quantization/config.py @@ -17,82 +17,79 @@ from calib.plugin_calib import PercentileCalibrator FP8_DEFAULT_CONFIG = { - "quant_cfg": { - "*weight_quantizer": {"num_bits": (4, 3), "axis": None}, - "*input_quantizer": {"num_bits": (4, 3), "axis": None}, - "*output_quantizer": {"enable": False}, - "*softmax_quantizer": { - "num_bits": (4, 3), - "axis": None, - }, - "default": {"enable": False}, - }, + "quant_cfg": [ + {"quantizer_path": "*", "enable": False}, + {"quantizer_path": "*weight_quantizer", "cfg": {"num_bits": (4, 3), "axis": None}}, + {"quantizer_path": "*input_quantizer", "cfg": {"num_bits": (4, 3), "axis": None}}, + {"quantizer_path": "*output_quantizer", "enable": False}, + {"quantizer_path": "*softmax_quantizer", "cfg": {"num_bits": (4, 3), "axis": None}}, + ], "algorithm": "max", } INT8_DEFAULT_CONFIG = { - "quant_cfg": { - "*weight_quantizer": {"num_bits": 8, "axis": 0}, - "*input_quantizer": {"num_bits": 8, "axis": None}, - "*output_quantizer": {"enable": False}, - "default": {"enable": False}, - }, + "quant_cfg": [ + {"quantizer_path": "*", "enable": False}, + {"quantizer_path": "*weight_quantizer", "cfg": {"num_bits": 8, "axis": 0}}, + {"quantizer_path": "*input_quantizer", "cfg": 
{"num_bits": 8, "axis": None}}, + {"quantizer_path": "*output_quantizer", "enable": False}, + ], "algorithm": "max", } NVFP4_DEFAULT_CONFIG = { - "quant_cfg": { - "*weight_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, - "axis": None, + "quant_cfg": [ + {"quantizer_path": "*", "enable": False}, + { + "quantizer_path": "*weight_quantizer", + "cfg": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, + "axis": None, + }, "enable": True, }, - "*input_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, - "axis": None, + { + "quantizer_path": "*input_quantizer", + "cfg": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, + "axis": None, + }, "enable": True, }, - "*output_quantizer": {"enable": False}, - "*softmax_quantizer": { - "num_bits": (4, 3), - "axis": None, - }, - "default": {"enable": False}, - }, + {"quantizer_path": "*output_quantizer", "enable": False}, + {"quantizer_path": "*softmax_quantizer", "cfg": {"num_bits": (4, 3), "axis": None}}, + ], "algorithm": "max", } NVFP4_FP8_MHA_CONFIG = { - "quant_cfg": { - "**weight_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, - "axis": None, + "quant_cfg": [ + {"quantizer_path": "*", "enable": False}, + { + "quantizer_path": "**weight_quantizer", + "cfg": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, + "axis": None, + }, "enable": True, }, - "**input_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, - "axis": None, + { + "quantizer_path": "**input_quantizer", + "cfg": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, + "axis": None, + }, "enable": True, }, - "*output_quantizer": {"enable": False}, - "*[qkv]_bmm_quantizer": { - "num_bits": 
(4, 3), - "axis": None, - }, - "*softmax_quantizer": { - "num_bits": (4, 3), - "axis": None, - }, - "*bmm2_output_quantizer": { - "num_bits": (4, 3), - "axis": None, - }, - "default": {"enable": False}, - }, + {"quantizer_path": "*output_quantizer", "enable": False}, + {"quantizer_path": "*[qkv]_bmm_quantizer", "cfg": {"num_bits": (4, 3), "axis": None}}, + {"quantizer_path": "*softmax_quantizer", "cfg": {"num_bits": (4, 3), "axis": None}}, + {"quantizer_path": "*bmm2_output_quantizer", "cfg": {"num_bits": (4, 3), "axis": None}}, + ], "algorithm": {"method": "svdquant", "lowrank": 32}, } @@ -106,8 +103,9 @@ def set_quant_config_attr(quant_config, trt_high_precision_dtype, quant_algo, ** algo_cfg["lowrank"] = kwargs["lowrank"] quant_config["algorithm"] = algo_cfg - for p in quant_config["quant_cfg"].values(): - if "num_bits" in p and "trt_high_precision_dtype" not in p: + for entry in quant_config["quant_cfg"]: + p = entry.get("cfg", {}) + if isinstance(p, dict) and "num_bits" in p and "trt_high_precision_dtype" not in p: p["trt_high_precision_dtype"] = trt_high_precision_dtype @@ -127,18 +125,23 @@ def reset_set_int8_config(quant_config, percentile, n_steps, collect_method, bac for name, module in backbone.named_modules(): if isinstance(module, nn.Conv2d): aq_name = f"*{name}*input_quantizer*" - quant_config["quant_cfg"][aq_name] = { - "num_bits": 8, - "axis": None, - "calibrator": ( - PercentileCalibrator, - (), - { + quant_config["quant_cfg"].append( + { + "quantizer_path": aq_name, + "cfg": { "num_bits": 8, "axis": None, - "percentile": percentile, - "total_step": n_steps, - "collect_method": collect_method, + "calibrator": ( + PercentileCalibrator, + (), + { + "num_bits": 8, + "axis": None, + "percentile": percentile, + "total_step": n_steps, + "collect_method": collect_method, + }, + ), }, - ), - } + } + ) diff --git a/examples/diffusers/quantization/quantize.py b/examples/diffusers/quantization/quantize.py index 612357f6ea..cb4b1e0032 100644 --- 
a/examples/diffusers/quantization/quantize.py +++ b/examples/diffusers/quantization/quantize.py @@ -137,7 +137,12 @@ def get_quant_config(self, n_steps: int, backbone: torch.nn.Module) -> Any: else: raise NotImplementedError(f"Unknown format {self.config.format}") if self.config.quantize_mha: - quant_config["quant_cfg"]["*[qkv]_bmm_quantizer"] = {"num_bits": (4, 3), "axis": None} # type: ignore[index] + quant_config["quant_cfg"].append( + { + "quantizer_path": "*[qkv]_bmm_quantizer", + "cfg": {"num_bits": (4, 3), "axis": None}, + } + ) set_quant_config_attr( quant_config, self.model_config.trt_high_precision_dtype.value, diff --git a/examples/llm_autodeploy/run_auto_quantize.py b/examples/llm_autodeploy/run_auto_quantize.py index e9ecb0731f..73308ed7f7 100644 --- a/examples/llm_autodeploy/run_auto_quantize.py +++ b/examples/llm_autodeploy/run_auto_quantize.py @@ -100,11 +100,21 @@ def loss_func(output, data): if enable_kv_cache_quantization: mtq.set_quantizer_by_cfg( model, - quant_cfg={"*output_quantizer": {"num_bits": (4, 3), "axis": None, "enable": True}}, + quant_cfg=[ + { + "quantizer_path": "*output_quantizer", + "cfg": {"num_bits": (4, 3), "axis": None}, + "enable": True, + } + ], ) # Lets calibrate only the output quantizer this time. Let's disable all other quantizers. with mtq.set_quantizer_by_cfg_context( - model, {"*": {"enable": False}, "*output_quantizer": {"enable": True}} + model, + [ + {"quantizer_path": "*", "enable": False}, + {"quantizer_path": "*output_quantizer", "enable": True}, + ], ): mtq.calibrate(model, algorithm="max", forward_loop=calibrate_loop) return model diff --git a/examples/llm_eval/quantization_utils.py b/examples/llm_eval/quantization_utils.py index 3df44115a2..466f65ceda 100644 --- a/examples/llm_eval/quantization_utils.py +++ b/examples/llm_eval/quantization_utils.py @@ -33,12 +33,20 @@ # Modify your custom config for debugging or research purposes. 
CUSTOM_CONFIG = { "MY_QUANT_CONFIG": { - "quant_cfg": { - "*weight_quantizer": {"num_bits": 4, "block_sizes": {-1: 128}, "enable": True}, - "*input_quantizer": {"num_bits": 8, "type": "dynamic", "block_sizes": {-1: None}}, + "quant_cfg": [ + *mtq.config._base_disable_all, + { + "quantizer_path": "*weight_quantizer", + "cfg": {"num_bits": 4, "block_sizes": {-1: 128}}, + "enable": True, + }, + { + "quantizer_path": "*input_quantizer", + "cfg": {"num_bits": 8, "type": "dynamic", "block_sizes": {-1: None}}, + }, # Disable sensitive layers such as `lm_head`, gate layers in MoE etc. - **mtq.config._default_disabled_quantizer_cfg, - }, + *mtq.config._default_disabled_quantizer_cfg, + ], "algorithm": "max", }, } diff --git a/examples/llm_ptq/example_utils.py b/examples/llm_ptq/example_utils.py index 58eb676111..ad2f7ca09b 100755 --- a/examples/llm_ptq/example_utils.py +++ b/examples/llm_ptq/example_utils.py @@ -205,7 +205,12 @@ def build_quant_cfg( ) -> dict[str, Any]: quant_cfg = copy.deepcopy(quant_cfg) if "awq" in str(quant_cfg.get("algorithm")): - weight_quantizer = quant_cfg["quant_cfg"]["*weight_quantizer"] + from modelopt.torch.quantization.config import find_quant_cfg_entry_by_path + + weight_quantizer_entry = find_quant_cfg_entry_by_path( + quant_cfg["quant_cfg"], "*weight_quantizer" + ) + weight_quantizer = weight_quantizer_entry.get("cfg") or {} if isinstance(weight_quantizer, list): weight_quantizer = weight_quantizer[0] # If awq_block_size argument is provided, update weight_quantizer @@ -236,10 +241,10 @@ def build_quant_cfg( if model_type == "phi4mm": # Only quantize the language model - quant_cfg["quant_cfg"]["*speech*"] = {"enable": False} - quant_cfg["quant_cfg"]["*audio*"] = {"enable": False} - quant_cfg["quant_cfg"]["*image*"] = {"enable": False} - quant_cfg["quant_cfg"]["*vision*"] = {"enable": False} + quant_cfg["quant_cfg"].append({"quantizer_path": "*speech*", "enable": False}) + quant_cfg["quant_cfg"].append({"quantizer_path": "*audio*", "enable": 
False}) + quant_cfg["quant_cfg"].append({"quantizer_path": "*image*", "enable": False}) + quant_cfg["quant_cfg"].append({"quantizer_path": "*vision*", "enable": False}) return quant_cfg diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index b81dc60c01..c72b0bd81e 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -78,16 +78,23 @@ RAND_SEED = 1234 -def _set_kv_cache_constant_amax(quant_cfg: dict) -> None: +def _set_kv_cache_constant_amax(quant_cfg: list) -> None: """Set use_constant_amax on KV cache quantizers. Creates a new dict for the KV bmm quantizer config to avoid mutating shared references. """ - if "*[kv]_bmm_quantizer" in quant_cfg: - quant_cfg["*[kv]_bmm_quantizer"] = { - **quant_cfg["*[kv]_bmm_quantizer"], - "use_constant_amax": True, + for i, entry in enumerate(quant_cfg): + if entry.get("quantizer_path") != "*[kv]_bmm_quantizer": + continue + assert isinstance(entry.get("cfg", {}), dict) + new_entry = { + "quantizer_path": "*[kv]_bmm_quantizer", + "cfg": {**entry.get("cfg", {}), "use_constant_amax": True}, } + if entry.get("enable") is not None: + new_entry["enable"] = entry["enable"] + quant_cfg[i] = new_entry + break QUANT_CFG_CHOICES: dict[str, dict[str, Any]] = { @@ -145,7 +152,7 @@ def extract_and_prepare_language_model_from_vl(full_model): # Apply disabled quant to all modules that are not part of language_model # This excludes them during HF export disabled_quant_cfg = { - "quant_cfg": {"default": {"enable": False}}, + "quant_cfg": [{"quantizer_path": "*", "enable": False}], "algorithm": "max", } @@ -319,7 +326,11 @@ def forward_step(model, batch): ), verbose=True, # Disable all default disabled layers such as lm_head, mlp.gate, router etc. 
- disabled_layers=list(_default_disabled_quantizer_cfg.keys()), + disabled_layers=[ + entry["quantizer_path"] + for entry in _default_disabled_quantizer_cfg + if "parent_class" not in entry + ], method=auto_quantize_method, checkpoint=auto_quantize_checkpoint, ) @@ -332,7 +343,9 @@ def forward_step(model, batch): kv_cache_quant_cfg = copy.deepcopy( getattr(mtq, KV_QUANT_CFG_CHOICES[args.kv_cache_qformat])["quant_cfg"] ) - kv_cache_quant_cfg.pop("default", None) # keep other quantizers from auto_quantize + kv_cache_quant_cfg = [ + e for e in kv_cache_quant_cfg if e["quantizer_path"] != "*" + ] # keep other quantizers from auto_quantize if args.kv_cache_qformat in _KV_CAST_FORMATS: _set_kv_cache_constant_amax(kv_cache_quant_cfg) @@ -341,7 +354,8 @@ def forward_step(model, batch): if args.kv_cache_qformat not in _KV_CAST_FORMATS: # Calibrate only the KV cache quantizers; disable all others. with mtq.set_quantizer_by_cfg_context( - language_model, {"*": {"enable": False}, **kv_cache_quant_cfg} + language_model, + [{"quantizer_path": "*", "enable": False}, *kv_cache_quant_cfg], ): mtq.calibrate(language_model, algorithm="max", forward_loop=calibrate_loop) return language_model @@ -546,13 +560,17 @@ def mono_quantize( # For Nemotron VL models, disable quantization of vision components if is_nemotron_vl_model: print("Disabling quantization for vision components in Nemotron VL model") - quant_cfg["quant_cfg"]["*vision*"] = {"enable": False} - quant_cfg["quant_cfg"]["*image*"] = {"enable": False} + quant_cfg["quant_cfg"].append({"quantizer_path": "*vision*", "enable": False}) + quant_cfg["quant_cfg"].append({"quantizer_path": "*image*", "enable": False}) # Also disable radio model components specifically (for Nemotron-Parse) - quant_cfg["quant_cfg"]["*radio*"] = {"enable": False} - quant_cfg["quant_cfg"]["*visual*"] = {"enable": False} - quant_cfg["quant_cfg"]["*encoder*"] = {"enable": False} # Disable encoder - quant_cfg["quant_cfg"]["*model_encoder*"] = {"enable": False} 
# Nemotron-Parse specific + quant_cfg["quant_cfg"].append({"quantizer_path": "*radio*", "enable": False}) + quant_cfg["quant_cfg"].append({"quantizer_path": "*visual*", "enable": False}) + quant_cfg["quant_cfg"].append( + {"quantizer_path": "*encoder*", "enable": False} + ) # Disable encoder + quant_cfg["quant_cfg"].append( + {"quantizer_path": "*model_encoder*", "enable": False} + ) # Nemotron-Parse specific print("Quantization will only be applied to the decoder (text generation) component") if not model_is_already_quantized or calibration_only: @@ -981,7 +999,7 @@ def quantize_main( for prefix in mtp_layer_prefixes: # Add exclusion pattern for this MTP layer (e.g., "*layers.92*") pattern = f"*{prefix.split('.')[-2]}.{prefix.split('.')[-1]}*" - quant_cfg["quant_cfg"][pattern] = {"enable": False} + quant_cfg["quant_cfg"].append({"quantizer_path": pattern, "enable": False}) print(f"Excluding MTP layer from quantization: {pattern}") # Use constant amax for KV quantizers when a cast format is selected. 
diff --git a/examples/llm_ptq/notebooks/2_PTQ_AWQ_Calibration.ipynb b/examples/llm_ptq/notebooks/2_PTQ_AWQ_Calibration.ipynb index fc055cf848..88599f2aac 100644 --- a/examples/llm_ptq/notebooks/2_PTQ_AWQ_Calibration.ipynb +++ b/examples/llm_ptq/notebooks/2_PTQ_AWQ_Calibration.ipynb @@ -189,17 +189,7 @@ "id": "a3ce3b47-48ac-4a27-a5ed-351a10c104a9", "metadata": {}, "outputs": [], - "source": [ - "# Get default AWQ config and optionally adjust block size\n", - "quant_cfg = mtq.INT4_AWQ_CFG\n", - "weight_quantizer = quant_cfg[\"quant_cfg\"][\"*weight_quantizer\"]\n", - "if isinstance(weight_quantizer, list):\n", - " weight_quantizer = weight_quantizer[0]\n", - "weight_quantizer[\"block_sizes\"][-1] = 128 # Optional: override block size\n", - "\n", - "# Apply AWQ quantization\n", - "model = mtq.quantize(model, quant_cfg, forward_loop=forward_loop)" - ] + "source": "import copy\n\nfrom modelopt.torch.quantization.config import find_quant_cfg_entry_by_path\n\n# Get default AWQ config and optionally adjust block size\nquant_cfg = copy.deepcopy(mtq.INT4_AWQ_CFG)\nweight_quantizer_entry = find_quant_cfg_entry_by_path(quant_cfg[\"quant_cfg\"], \"*weight_quantizer\")\nweight_quantizer = weight_quantizer_entry.get(\"cfg\", {})\nif isinstance(weight_quantizer, list):\n weight_quantizer = weight_quantizer[0]\nweight_quantizer[\"block_sizes\"][-1] = 128 # Optional: override block size\n\n# Apply AWQ quantization\nmodel = mtq.quantize(model, quant_cfg, forward_loop=forward_loop)" }, { "cell_type": "markdown", @@ -308,4 +298,4 @@ }, "nbformat": 4, "nbformat_minor": 5 -} +} \ No newline at end of file diff --git a/examples/llm_ptq/notebooks/3_PTQ_AutoQuantization.ipynb b/examples/llm_ptq/notebooks/3_PTQ_AutoQuantization.ipynb index 122569489e..9634c615d9 100644 --- a/examples/llm_ptq/notebooks/3_PTQ_AutoQuantization.ipynb +++ b/examples/llm_ptq/notebooks/3_PTQ_AutoQuantization.ipynb @@ -288,7 +288,9 @@ " mtq.set_quantizer_by_cfg(model, quant_cfg=kv_cfg)\n", "\n", " # Calibrate 
**only** those quantizers\n", - " with mtq.set_quantizer_by_cfg_context(model, {\"*\": {\"enable\": False}, **kv_cfg}):\n", + " with mtq.set_quantizer_by_cfg_context(\n", + " model, [{\"quantizer_path\": \"*\", \"enable\": False}, *kv_cfg]\n", + " ):\n", " mtq.calibrate(model, algorithm=\"max\", forward_loop=forward_loop)\n", "else:\n", " print(\"KV cache left unquantized.\")" @@ -427,4 +429,4 @@ }, "nbformat": 4, "nbformat_minor": 5 -} +} \ No newline at end of file diff --git a/examples/llm_qat/main.py b/examples/llm_qat/main.py index 9435157259..14d5a5c829 100644 --- a/examples/llm_qat/main.py +++ b/examples/llm_qat/main.py @@ -54,12 +54,20 @@ CUSTOM_QUANT_CFG = { "INT4_WEIGHT_INT8_ACTIVATIONS": { - "quant_cfg": { - "*weight_quantizer": {"num_bits": 4, "block_sizes": {-1: 128}, "enable": True}, - "*input_quantizer": {"num_bits": 8, "axis": None, "enable": True}, - "*lm_head*": {"enable": False}, - "default": {"enable": False}, - }, + "quant_cfg": [ + {"quantizer_path": "*", "enable": False}, + { + "quantizer_path": "*weight_quantizer", + "cfg": {"num_bits": 4, "block_sizes": {-1: 128}}, + "enable": True, + }, + { + "quantizer_path": "*input_quantizer", + "cfg": {"num_bits": 8, "axis": None}, + "enable": True, + }, + {"quantizer_path": "*lm_head*", "enable": False}, + ], "algorithm": "max", } } diff --git a/examples/vllm_serve/vllm_ptq_utils.py b/examples/vllm_serve/vllm_ptq_utils.py index d6c055709d..e31f552000 100644 --- a/examples/vllm_serve/vllm_ptq_utils.py +++ b/examples/vllm_serve/vllm_ptq_utils.py @@ -102,7 +102,7 @@ def calibrate_loop(model: Any) -> None: return calibrate_loop -def update_kv_cfg_for_mla(model: torch.nn.Module, kv_quant_cfg: dict[str, Any]) -> dict[str, Any]: +def update_kv_cfg_for_mla(model: torch.nn.Module, kv_quant_cfg: list) -> list: """Update KV cache quantization config for MLA models. 
MLA uses `kv_c_bmm_quantizer` (compressed KV) instead of separate @@ -117,18 +117,37 @@ def update_kv_cfg_for_mla(model: torch.nn.Module, kv_quant_cfg: dict[str, Any]) if not any(isinstance(m, MLAAttention) for m in model.modules()): return kv_quant_cfg - if kv_config := kv_quant_cfg.get("*[kv]_bmm_quantizer"): - kv_quant_cfg["*kv_c_bmm_quantizer"] = kv_config - kv_quant_cfg["*k_pe_bmm_quantizer"] = kv_config + kv_entry = next( + ( + e + for e in kv_quant_cfg + if isinstance(e, dict) and e.get("quantizer_path") == "*[kv]_bmm_quantizer" + ), + None, + ) + if kv_entry is not None: + kv_config = kv_entry.get("cfg", {}) + kv_quant_cfg.append( + {"quantizer_path": "*kv_c_bmm_quantizer", "cfg": kv_config, "enable": True} + ) + kv_quant_cfg.append( + {"quantizer_path": "*k_pe_bmm_quantizer", "cfg": kv_config, "enable": True} + ) print("MLA detected: added *kv_c_bmm_quantizer and k_pe_bmm_quantizer config") return kv_quant_cfg def get_quant_config(quant_config: dict[str, Any], model: Any) -> dict[str, Any]: - quant_cfg = getattr(mtq, quant_config["quant_cfg"]) if quant_config["quant_cfg"] else {} + import copy + + quant_cfg = ( + copy.deepcopy(getattr(mtq, quant_config["quant_cfg"])) if quant_config["quant_cfg"] else {} + ) quant_kv_cfg = ( - getattr(mtq, quant_config["kv_quant_cfg"]) if quant_config["kv_quant_cfg"] else {} + copy.deepcopy(getattr(mtq, quant_config["kv_quant_cfg"])) + if quant_config["kv_quant_cfg"] + else {} ) # Check if model has MLA and update KV config accordingly diff --git a/examples/windows/torch_onnx/diffusers/qad_example/sample_example_qad_diffusers.py b/examples/windows/torch_onnx/diffusers/qad_example/sample_example_qad_diffusers.py index a861493b37..237531f7d7 100644 --- a/examples/windows/torch_onnx/diffusers/qad_example/sample_example_qad_diffusers.py +++ b/examples/windows/torch_onnx/diffusers/qad_example/sample_example_qad_diffusers.py @@ -257,26 +257,20 @@ def build_quant_config( if exclude_blocks is None: exclude_blocks = [0, 1, 46, 47] - 
quant_cfg = { - "*weight_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, - "axis": None, - "enable": True, - }, - "*input_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, - "axis": None, - "enable": True, - }, + _nvfp4_cfg = { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, + "axis": None, } - - for pattern in SENSITIVE_LAYER_PATTERNS: - quant_cfg[pattern] = {"enable": False} - - for block_idx in exclude_blocks: - quant_cfg[f"*transformer_blocks.{block_idx}.*"] = {"enable": False} + quant_cfg = [ + {"quantizer_path": "*weight_quantizer", "cfg": _nvfp4_cfg, "enable": True}, + {"quantizer_path": "*input_quantizer", "cfg": _nvfp4_cfg, "enable": True}, + *[{"quantizer_path": pattern, "enable": False} for pattern in SENSITIVE_LAYER_PATTERNS], + *[ + {"quantizer_path": f"*transformer_blocks.{i}.*", "enable": False} + for i in exclude_blocks + ], + ] return { "quant_cfg": quant_cfg, diff --git a/modelopt/onnx/llm_export_utils/quantization_utils.py b/modelopt/onnx/llm_export_utils/quantization_utils.py index 61f551b634..54f27d3c4d 100644 --- a/modelopt/onnx/llm_export_utils/quantization_utils.py +++ b/modelopt/onnx/llm_export_utils/quantization_utils.py @@ -15,6 +15,7 @@ """Quantization utilities for LLM models.""" +import copy import time import modelopt.torch.quantization as mtq @@ -57,35 +58,58 @@ def calibrate_loop(model): def get_quant_config(precision, lm_head_precision="fp16"): """Get the quantization configuration.""" if precision == "fp8": - quant_cfg = mtq.FP8_DEFAULT_CFG + quant_cfg = copy.deepcopy(mtq.FP8_DEFAULT_CFG) elif precision == "nvfp4": - quant_cfg = mtq.NVFP4_DEFAULT_CFG + quant_cfg = copy.deepcopy(mtq.NVFP4_DEFAULT_CFG) elif precision == "int4_awq": - quant_cfg = mtq.INT4_AWQ_CFG + quant_cfg = copy.deepcopy(mtq.INT4_AWQ_CFG) # type: ignore[arg-type] else: raise ValueError(f"Unsupported precision: 
{precision}") - config_dict = quant_cfg["quant_cfg"] # type: dict + quant_cfg_list: list = [ + e for e in quant_cfg["quant_cfg"] if isinstance(e, dict) and "quantizer_path" in e + ] if lm_head_precision == "fp8": - config_dict["*lm_head.input_quantizer"] = {"num_bits": (4, 3), "axis": None} - config_dict["*lm_head.weight_quantizer"] = {"num_bits": (4, 3), "axis": None} + quant_cfg_list.append( + { + "quantizer_path": "*lm_head.input_quantizer", + "cfg": {"num_bits": (4, 3), "axis": None}, + } + ) + quant_cfg_list.append( + { + "quantizer_path": "*lm_head.weight_quantizer", + "cfg": {"num_bits": (4, 3), "axis": None}, + } + ) elif lm_head_precision == "nvfp4": - config_dict["*lm_head.input_quantizer"] = { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, - "axis": None, - "enable": True, - } - config_dict["*lm_head.weight_quantizer"] = { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, - "axis": None, - "enable": True, - } + quant_cfg_list.append( + { + "quantizer_path": "*lm_head.input_quantizer", + "cfg": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, + "axis": None, + }, + "enable": True, + } + ) + quant_cfg_list.append( + { + "quantizer_path": "*lm_head.weight_quantizer", + "cfg": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, + "axis": None, + }, + "enable": True, + } + ) + quant_cfg["quant_cfg"] = quant_cfg_list return quant_cfg diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py index 14a12bcdf3..3433fe5f7b 100644 --- a/modelopt/torch/export/unified_export_hf.py +++ b/modelopt/torch/export/unified_export_hf.py @@ -218,7 +218,10 @@ def _output_hook(module, input, output): # Run dummy forward pass to collect modules sharing same input try: - with torch.no_grad(), set_quantizer_by_cfg_context(model, {"*": {"enable": False}}): + with ( + 
torch.no_grad(), + set_quantizer_by_cfg_context(model, [{"quantizer_path": "*", "enable": False}]), + ): dummy_forward_fn() finally: # Always remove hooks diff --git a/modelopt/torch/quantization/algorithms.py b/modelopt/torch/quantization/algorithms.py index 7907c79bd6..85911e6165 100644 --- a/modelopt/torch/quantization/algorithms.py +++ b/modelopt/torch/quantization/algorithms.py @@ -62,9 +62,22 @@ def estimate_quant_compression(quant_cfg: QuantizeConfig) -> float: def estimate_quant_compression_for_quantizer(quantizer_attr_cfg): if isinstance(quantizer_attr_cfg, list): + if not quantizer_attr_cfg: + return 1.0 return min(estimate_quant_compression_for_quantizer(q) for q in quantizer_attr_cfg) if isinstance(quantizer_attr_cfg, dict): - return estimate_quant_compression_for_quantizer(list(quantizer_attr_cfg.values())) + # Handle raw quantizer cfg dicts (e.g. {"num_bits": (4, 3), "axis": None}) + if not quantizer_attr_cfg.get("enable", True): + return 1.0 + num_bits = quantizer_attr_cfg.get("num_bits") + if num_bits is None: + return 1.0 + if isinstance(num_bits, tuple): + return (sum(num_bits) + 1) / 16 + elif isinstance(num_bits, int): + return num_bits / 16 + else: + raise ValueError(f"Unknown quantization config {num_bits}") if isinstance(quantizer_attr_cfg, QuantizerAttributeConfig): if not quantizer_attr_cfg.enable: @@ -80,7 +93,14 @@ def estimate_quant_compression_for_quantizer(quantizer_attr_cfg): raise ValueError(f"Unknown type {type(quantizer_attr_cfg)}, {quantizer_attr_cfg}") - return estimate_quant_compression_for_quantizer(list(quant_cfg.quant_cfg.values())) + cfgs = [] + for e in quant_cfg.quant_cfg: + if e.get("enable", True) is False: + continue + c = e.get("cfg") + if c is not None: + cfgs.append(c) + return estimate_quant_compression_for_quantizer(cfgs) if cfgs else 1.0 class QuantRecipe(CustomHPType): @@ -97,7 +117,7 @@ def __init__(self, quant_cfg: str | dict[str, Any] | None = None, name: str | No name = 
self.get_auto_name_for_config(quant_cfg) or name if quant_cfg is None: - quant_cfg = {"quant_cfg": {"*": {"enable": False}}} + quant_cfg = {"quant_cfg": [{"quantizer_path": "*", "enable": False}]} elif isinstance(quant_cfg, str): assert hasattr(mtq_config, quant_cfg), f"Unknown quantization format {quant_cfg}" quant_cfg = getattr(mtq_config, quant_cfg) @@ -109,9 +129,7 @@ def __init__(self, quant_cfg: str | dict[str, Any] | None = None, name: str | No # Disable KV Cache quantization # Currently KV Cache quantization is enabled for some quantization formats and disabled for others # This breaks the monotonicity of the quantization formats in terms of weight compression Vs accuracy - self.config.quant_cfg["*output_quantizer"] = mtq_config.QuantizerAttributeConfig( - enable=False - ) + self.config.quant_cfg.append({"quantizer_path": "*output_quantizer", "enable": False}) self.compression = estimate_quant_compression(self.config) @@ -1300,21 +1318,9 @@ def get_auto_quantize_config(search_state, constraints=None, verbose=False): else: best_recipe = search_state["best"]["recipe"] - quant_cfg: dict[str, Any] = {"*": {"enable": False}} - for hparam_name, recipe in best_recipe.items(): - if recipe == QuantRecipe(quant_cfg=None): - continue - module_names = search_state["candidate_stats"][hparam_name]["module_names"] - for module_name in module_names: - for quantizer_attr in ("input_quantizer", "weight_quantizer"): - matched_cfg = _match_quantizer_cfg(recipe.config.quant_cfg, quantizer_attr) - if matched_cfg is not None: - quant_cfg[f"{module_name}.{quantizer_attr}"] = matched_cfg - def _cfg_to_dict(v): if isinstance(v, mtq_config.QuantizerAttributeConfig): return { - "enable": v.enable, "num_bits": v.num_bits, **v.model_dump(exclude_defaults=True), } @@ -1322,7 +1328,24 @@ def _cfg_to_dict(v): return [_cfg_to_dict(c) for c in v] return v - quant_cfg = {k: _cfg_to_dict(v) for k, v in quant_cfg.items()} + quant_cfg: list[dict] = [{"quantizer_path": "*", "enable": False}] + 
for hparam_name, recipe in best_recipe.items(): + if recipe == QuantRecipe(quant_cfg=None): + continue + module_names = search_state["candidate_stats"][hparam_name]["module_names"] + for module_name in module_names: + for quantizer_attr in ("input_quantizer", "weight_quantizer"): + matched_cfg, matched_enable = _match_quantizer_cfg( + recipe.config.quant_cfg, quantizer_attr + ) + if matched_cfg is not None: + quant_cfg.append( + { + "quantizer_path": f"{module_name}.{quantizer_attr}", + "cfg": _cfg_to_dict(matched_cfg), + "enable": matched_enable, + } + ) warnings.warn( "get_auto_quantize_config: returned config uses algorithm='max'. " "Per-recipe calibration algorithms (e.g. smoothquant, awq) are not preserved. " @@ -1362,9 +1385,19 @@ def _resolve_best_recipe(search_state, constraints, verbose=False): def _match_quantizer_cfg(quant_cfg, quantizer_attr): - # Last-match-wins to mirror set_quantizer_by_cfg behavior + # Last-match-wins to mirror set_quantizer_by_cfg behavior. + # Patterns may be path-scoped (e.g. "*mlp*weight_quantizer") while quantizer_attr + # is a bare name like "weight_quantizer". We match if the bare name matches directly + # OR if the pattern ends with the bare quantizer_attr (path-scoped match). matched = None - for pattern, cfg in quant_cfg.items(): - if fnmatch.fnmatch(quantizer_attr, pattern): + matched_enable = None + for entry in quant_cfg: + pattern = entry["quantizer_path"] + cfg = entry.get("cfg") + enable = entry.get("enable", True) + # Direct match: the bare quantizer_attr matches the whole pattern (e.g. 
"*weight_quantizer") + if fnmatch.fnmatch(quantizer_attr, pattern) or pattern.endswith(quantizer_attr): matched = cfg - return matched + matched_enable = enable + + return matched, matched_enable diff --git a/modelopt/torch/quantization/backends/fp8_per_tensor_gemm.py b/modelopt/torch/quantization/backends/fp8_per_tensor_gemm.py index cc5be9d564..d89ed35c6c 100644 --- a/modelopt/torch/quantization/backends/fp8_per_tensor_gemm.py +++ b/modelopt/torch/quantization/backends/fp8_per_tensor_gemm.py @@ -15,13 +15,11 @@ """This module provides a GEMM function for fp8 per tensor quantization.""" -from typing import Any - import torch from torch.autograd import Function from modelopt.torch.quantization.backends.gemm_registry import gemm_registry -from modelopt.torch.quantization.config import FP8_DEFAULT_CFG +from modelopt.torch.quantization.config import FP8_DEFAULT_CFG, find_quant_cfg_entry_by_path from modelopt.torch.quantization.nn.modules.quant_linear import RealQuantLinear from modelopt.torch.quantization.qtensor import FP8QTensor, QTensorWrapper from modelopt.torch.quantization.utils import reduce_amax @@ -99,9 +97,16 @@ def fp8_per_tensor_gemm(quant_module, input, bias=None): def _fp8_availability_check(module, input, args, kwargs): """Comprehensive check for FP8 GEMM availability.""" # Quantizer configs - quant_cfg: dict[str, Any] = FP8_DEFAULT_CFG["quant_cfg"] - input_cfg = quant_cfg["*input_quantizer"] - weight_cfg = quant_cfg["*weight_quantizer"] + quant_cfg_list: list = FP8_DEFAULT_CFG["quant_cfg"] + input_cfg = find_quant_cfg_entry_by_path(quant_cfg_list, "*input_quantizer").get("cfg", {}) + weight_cfg = find_quant_cfg_entry_by_path(quant_cfg_list, "*weight_quantizer").get("cfg", {}) + # cfg may be a list (SequentialQuantizer); fall back to the first element. 
+ if isinstance(input_cfg, list): + input_cfg = input_cfg[0] + if isinstance(weight_cfg, list): + weight_cfg = weight_cfg[0] + if not isinstance(input_cfg, dict) or not isinstance(weight_cfg, dict): + return False # Check hardware support if not torch.cuda.is_available() or not fp8_compatible(): diff --git a/modelopt/torch/quantization/backends/nvfp4_gemm.py b/modelopt/torch/quantization/backends/nvfp4_gemm.py index ffc18fea33..7734390168 100644 --- a/modelopt/torch/quantization/backends/nvfp4_gemm.py +++ b/modelopt/torch/quantization/backends/nvfp4_gemm.py @@ -15,8 +15,6 @@ """This module provides a GEMM function for nvfp4 quantization.""" -from typing import Any - import torch from torch.autograd import Function @@ -213,10 +211,21 @@ def _nvfp4_availability_check(module, input, args, kwargs): if not hasattr(module, "input_quantizer") or not hasattr(module, "weight_quantizer"): return False - quant_cfg: dict[str, Any] = mtq.NVFP4_DEFAULT_CFG["quant_cfg"] + quant_cfg_list: list = mtq.NVFP4_DEFAULT_CFG["quant_cfg"] # Quantizer configs - input_cfg = quant_cfg["*input_quantizer"] - weight_cfg = quant_cfg["*weight_quantizer"] + input_cfg = mtq.config.find_quant_cfg_entry_by_path(quant_cfg_list, "*input_quantizer").get( + "cfg", {} + ) + weight_cfg = mtq.config.find_quant_cfg_entry_by_path(quant_cfg_list, "*weight_quantizer").get( + "cfg", {} + ) + # cfg may be a list (SequentialQuantizer); fall back to the first element. 
+ if isinstance(input_cfg, list): + input_cfg = input_cfg[0] + if isinstance(weight_cfg, list): + weight_cfg = weight_cfg[0] + if not isinstance(input_cfg, dict) or not isinstance(weight_cfg, dict): + return False # Check input quantizer config for key, value in input_cfg.items(): diff --git a/modelopt/torch/quantization/compress.py b/modelopt/torch/quantization/compress.py index 5477d0b61e..2a5cbbee9f 100644 --- a/modelopt/torch/quantization/compress.py +++ b/modelopt/torch/quantization/compress.py @@ -30,7 +30,7 @@ from .backends.gemm_registry import disable_real_quant_gemm, enable_real_quant_gemm from .config import CompressCfgType, CompressConfig -from .conversion import _replace_quant_module, set_quantizer_attribute +from .conversion import _replace_quant_module, set_quantizer_attributes_partial from .nn.modules.quant_linear import RealQuantLinear from .qtensor import QTensorWrapper, pack_real_quantize_weight from .utils import is_quantized_linear @@ -87,7 +87,7 @@ def compress_convert( compress_cfg = config.compress if "default" in compress_cfg and isinstance(compress_cfg["default"], bool): - set_quantizer_attribute( + set_quantizer_attributes_partial( model, "*weight_quantizer*", {"fake_quant": not compress_cfg["default"]} ) @@ -99,7 +99,7 @@ def compress_convert( def filter_func(name): return fnmatch.fnmatch(name, pattern) and "weight_quantizer" in name - set_quantizer_attribute(model, filter_func, {"fake_quant": not to_compress}) + set_quantizer_attributes_partial(model, filter_func, {"fake_quant": not to_compress}) else: raise ValueError( f"Invalid compression configuration: {to_compress}, expected a boolean as value." 
diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py index cf2336bf4a..6e51f863ea 100644 --- a/modelopt/torch/quantization/config.py +++ b/modelopt/torch/quantization/config.py @@ -50,40 +50,52 @@ Quantization Configs ================================ -Quantization config is dictionary specifying the values for keys ``"quant_cfg"`` and -``"algorithm"``. The ``"quant_cfg"`` key specifies the quantization configurations. The -``"algorithm"`` key specifies the ``algorithm`` argument to -:meth:`calibrate `. Please see :class:`QuantizeConfig` -for the quantization config definition. - -'Quantization configurations' is a dictionary mapping wildcards or filter functions -to its 'quantizer attributes'. The wildcards or filter functions are matched -against the quantizer module names. The quantizer modules have names ending with -``weight_quantizer`` and ``input_quantizer`` and they perform weight quantization and -input quantization (or activation quantization) respectively. The quantizer modules are generally -instances of -:class:`TensorQuantizer `. -The quantizer attributes are defined by :class:`QuantizerAttributeConfig`. See :class:`QuantizerAttributeConfig` -for details on the quantizer attributes and their values. - -The key `"default"` from the quantization configuration dictionary is applied if no other wildcard or filter functions -match the quantizer module name. - -The quantizer attributes are applied in the order they are specified. For the missing attributes, the default attributes -as defined by :class:`QuantizerAttributeConfig` are used. - -Quantizer attributes can also be a list of dictionaries. In this case, the matched quantizer module -is replaced with a -:class:`SequentialQuantizer ` -module which is used to quantize a tensor in multiple formats sequentially. Each quantizer attribute -dictionary in the list specifies the quantization formats for each quantization step of the -sequential quantizer. 
For example, `SequentialQuantizer` is used in 'INT4 Weights, FP8 Activations' -quantization in which the weights are quantized in INT4 followed by FP8. - -In addition, the dictionary entries could also be pytorch module class names mapping the class specific -quantization configurations. The pytorch modules should have a quantized equivalent. - -To get the string representation of a module class, do: +Quantization config is a dictionary with two top-level keys: + +- ``"quant_cfg"``: an ordered list of :class:`QuantizerCfgEntry` dicts that specify which + quantizers to configure and how. +- ``"algorithm"``: the calibration algorithm passed to + :meth:`calibrate `. + +Please see :class:`QuantizeConfig` for the full config schema. + +``quant_cfg`` — Entry Format +----------------------------- + +Each entry in the ``quant_cfg`` list is a :class:`QuantizerCfgEntry` with the following fields: + +- ``quantizer_path`` *(required)*: a wildcard string matched against quantizer module names. + Quantizer modules are instances of + :class:`TensorQuantizer ` + and have names ending with ``weight_quantizer``, ``input_quantizer``, etc. +- ``parent_class`` *(optional)*: restricts matching to quantizers whose immediate parent module is + of this PyTorch class (e.g. ``"nn.Linear"``). If omitted, all matching quantizers are targeted + regardless of their parent class. +- ``cfg`` *(optional)*: a dict of quantizer attributes as defined by + :class:`QuantizerAttributeConfig`, or a list of such dicts. When a list is given, the matched + :class:`TensorQuantizer ` + is replaced with a + :class:`SequentialQuantizer ` + that applies each format in sequence. This is used for example in W4A8 quantization where weights + are quantized first in INT4 and then in FP8. +- ``enable`` *(optional)*: toggles matched quantizers on (``True``) or off (``False``), + independently of ``cfg``. When ``cfg`` is present and ``enable`` is absent, the quantizer is + implicitly enabled. 
When ``enable`` is the only field (no ``cfg``), it only flips the on/off + state — all other attributes remain unchanged. + +``quant_cfg`` — Ordering and Precedence +----------------------------------------- + +Entries are applied **in list order**; later entries override earlier ones for any quantizer they +match. The recommended pattern is: + +1. Start with a deny-all entry ``{"quantizer_path": "*", "enable": False}`` (provided as + :data:`_base_disable_all`) to disable every quantizer by default. +2. Follow with format-specific entries that selectively enable and configure the desired quantizers. +3. Append :data:`_default_disabled_quantizer_cfg` to enforce standard exclusions (e.g. BatchNorm + layers, LM head, MoE routers). + +To get the string representation of a module class for use in ``parent_class``, do: .. code-block:: @@ -97,15 +109,17 @@ .. code-block:: MY_QUANT_CFG = { - "quant_cfg": { - # Quantizer wildcard strings mapping to quantizer attributes - "*weight_quantizer": {"num_bits": 8, "axis": 0}, - "*input_quantizer": {"num_bits": 8, "axis": None}, + "quant_cfg": [ + # Deny all quantizers by default + {"quantizer_path": "*", "enable": False}, - # Module class names mapping to quantizer configurations - "nn.LeakyReLU": {"*input_quantizer": {"enable": False}}, + # Enable and configure weight and input quantizers + {"quantizer_path": "*weight_quantizer", "cfg": {"num_bits": 8, "axis": 0}}, + {"quantizer_path": "*input_quantizer", "cfg": {"num_bits": 8, "axis": None}}, - } + # Disable input quantizers specifically for LeakyReLU layers + {"quantizer_path": "*input_quantizer", "parent_class": "nn.LeakyReLU", "enable": False}, + ] } .. 
_example-quantization-configs: @@ -129,157 +143,250 @@ # Create custom config CUSTOM_INT4_AWQ_CFG = copy.deepcopy(mtq.INT4_AWQ_CFG) - CUSTOM_INT4_AWQ_CFG["quant_cfg"]["*lm_head*"] = {"enable": False} + CUSTOM_INT4_AWQ_CFG["quant_cfg"].append({"quantizer_path": "*lm_head*", "enable": False}) # quantize model model = mtq.quantize(model, CUSTOM_INT4_AWQ_CFG, forward_loop) """ -from collections.abc import Callable -from typing import Literal +import copy +from typing import Any, Literal, cast from pydantic import ValidationInfo, field_validator, model_validator +from typing_extensions import Required, TypedDict from modelopt.torch.opt.config import ModeloptBaseConfig, ModeloptField from modelopt.torch.utils.network import ConstructorLike -_default_disabled_quantizer_cfg = { - "nn.BatchNorm1d": {"*": {"enable": False}}, - "nn.BatchNorm2d": {"*": {"enable": False}}, - "nn.BatchNorm3d": {"*": {"enable": False}}, - "nn.LeakyReLU": {"*": {"enable": False}}, - "*lm_head*": {"enable": False}, - "*proj_out.*": {"enable": False}, # In Whisper model, lm_head has key name proj_out - "*block_sparse_moe.gate*": {"enable": False}, # Skip the MOE router - "*router*": {"enable": False}, # Skip the MOE router - "*mlp.gate.*": {"enable": False}, # Skip the MOE router - "*mlp.shared_expert_gate.*": {"enable": False}, # Skip the MOE router - "*linear_attn.conv1d*": {"enable": False}, - "*mixer.conv1d*": {"enable": False}, # Skip mamba conv1d - "*output_layer*": {"enable": False}, - "output.*": {"enable": False}, - "default": {"enable": False}, -} -_mamba_moe_disabled_quantizer_cfg = { - "*fc1_latent_proj*": {"enable": False}, # Skip Latent MOE - "*fc2_latent_proj*": {"enable": False}, # Skip Latent MOE - "*q_proj*": {"enable": False}, # Skip QKV Linear - "*k_proj*": {"enable": False}, # Skip QKV Linear - "*v_proj*": {"enable": False}, # Skip QKV Linear - "*o_proj*": {"enable": False}, # Skip QKV Output Projection -} +class QuantizerCfgEntry(TypedDict, total=False): + """A single entry in 
a ``quant_cfg`` list.""" + + quantizer_path: Required[str] # matched against quantizer module names + parent_class: str | None # optional; filters by pytorch module class name (e.g. "nn.Linear") + cfg: dict[str, Any] | list[dict[str, Any]] | None # quantizer attribute config(s) + enable: bool | None # toggles matched quantizers on/off; independent of cfg + + +def find_quant_cfg_entry_by_path( + quant_cfg_list: list[QuantizerCfgEntry], quantizer_path: str +) -> QuantizerCfgEntry: + """Find the last entry in a ``quant_cfg`` list whose ``quantizer_path`` key equals the query. + + This performs an **exact string comparison** against the ``quantizer_path`` field of each + entry — it does *not* apply ``fnmatch`` pattern matching. For example, passing + ``"*input_quantizer"`` will only match entries whose ``quantizer_path`` is literally + ``"*input_quantizer"``, not entries with a different wildcard that would match the same + module names at apply time. + + Returns the *last* match because entries are applied in list order and later entries + override earlier ones, so the last match represents the effective configuration. + + Args: + quant_cfg_list: A list of :class:`QuantizerCfgEntry` dicts. + quantizer_path: The exact ``quantizer_path`` string to search for. + + Returns: + The last entry whose ``quantizer_path`` equals *quantizer_path*. + + Raises: + KeyError: If no entry with the given ``quantizer_path`` is found. 
+ """ + result = None + for entry in quant_cfg_list: + if isinstance(entry, dict) and entry.get("quantizer_path") == quantizer_path: + result = entry + if result is None: + raise KeyError(f"No quant_cfg entry with quantizer_path={quantizer_path!r}") + return result + + +_base_disable_all: list[QuantizerCfgEntry] = [ + {"quantizer_path": "*", "enable": False}, +] + +_default_disabled_quantizer_cfg: list[QuantizerCfgEntry] = [ + {"parent_class": "nn.BatchNorm1d", "quantizer_path": "*", "enable": False}, + {"parent_class": "nn.BatchNorm2d", "quantizer_path": "*", "enable": False}, + {"parent_class": "nn.BatchNorm3d", "quantizer_path": "*", "enable": False}, + {"parent_class": "nn.LeakyReLU", "quantizer_path": "*", "enable": False}, + {"quantizer_path": "*lm_head*", "enable": False}, + { + "quantizer_path": "*proj_out.*", + "enable": False, + }, # In Whisper model, lm_head has key name proj_out + { + "quantizer_path": "*block_sparse_moe.gate*", + "enable": False, + }, # Skip the MOE router + {"quantizer_path": "*router*", "enable": False}, # Skip the MOE router + {"quantizer_path": "*mlp.gate.*", "enable": False}, # Skip the MOE router + { + "quantizer_path": "*mlp.shared_expert_gate.*", + "enable": False, + }, # Skip the MOE router + {"quantizer_path": "*linear_attn.conv1d*", "enable": False}, + {"quantizer_path": "*mixer.conv1d*", "enable": False}, # Skip mamba conv1d + {"quantizer_path": "*output_layer*", "enable": False}, + {"quantizer_path": "output.*", "enable": False}, +] + +_mamba_moe_disabled_quantizer_cfg: list[QuantizerCfgEntry] = [ + {"quantizer_path": "*fc1_latent_proj*", "enable": False}, # Skip Latent MOE + {"quantizer_path": "*fc2_latent_proj*", "enable": False}, # Skip Latent MOE + {"quantizer_path": "*q_proj*", "enable": False}, # Skip QKV Linear + {"quantizer_path": "*k_proj*", "enable": False}, # Skip QKV Linear + {"quantizer_path": "*v_proj*", "enable": False}, # Skip QKV Linear + {"quantizer_path": "*o_proj*", "enable": False}, # Skip QKV Output 
Projection +] INT8_DEFAULT_CFG = { - "quant_cfg": { - "*weight_quantizer": {"num_bits": 8, "axis": 0}, - "*input_quantizer": {"num_bits": 8, "axis": None}, - **_default_disabled_quantizer_cfg, - }, + "quant_cfg": [ + *_base_disable_all, + {"quantizer_path": "*weight_quantizer", "cfg": {"num_bits": 8, "axis": 0}}, + {"quantizer_path": "*input_quantizer", "cfg": {"num_bits": 8, "axis": None}}, + *_default_disabled_quantizer_cfg, + ], "algorithm": "max", } INT8_SMOOTHQUANT_CFG = { - "quant_cfg": { - "*weight_quantizer": {"num_bits": 8, "axis": 0}, - "*input_quantizer": {"num_bits": 8, "axis": None}, - **_default_disabled_quantizer_cfg, - }, + "quant_cfg": [ + *_base_disable_all, + {"quantizer_path": "*weight_quantizer", "cfg": {"num_bits": 8, "axis": 0}}, + {"quantizer_path": "*input_quantizer", "cfg": {"num_bits": 8, "axis": None}}, + *_default_disabled_quantizer_cfg, + ], "algorithm": "smoothquant", } INT8_WEIGHT_ONLY_CFG = { - "quant_cfg": { - "*weight_quantizer": {"num_bits": 8, "axis": 0}, - "*input_quantizer": {"enable": False}, - **_default_disabled_quantizer_cfg, - }, + "quant_cfg": [ + *_base_disable_all, + {"quantizer_path": "*weight_quantizer", "cfg": {"num_bits": 8, "axis": 0}}, + {"quantizer_path": "*input_quantizer", "enable": False}, + *_default_disabled_quantizer_cfg, + ], "algorithm": "max", } FP8_DEFAULT_CFG = { - "quant_cfg": { - "*weight_quantizer": {"num_bits": (4, 3), "axis": None}, - "*input_quantizer": {"num_bits": (4, 3), "axis": None}, - **_default_disabled_quantizer_cfg, - }, + "quant_cfg": [ + *_base_disable_all, + { + "quantizer_path": "*weight_quantizer", + "cfg": {"num_bits": (4, 3), "axis": None}, + }, + { + "quantizer_path": "*input_quantizer", + "cfg": {"num_bits": (4, 3), "axis": None}, + }, + *_default_disabled_quantizer_cfg, + ], "algorithm": "max", } MAMBA_MOE_FP8_AGGRESSIVE_CFG = { - "quant_cfg": { - "*weight_quantizer": {"num_bits": (4, 3), "axis": None}, - "*input_quantizer": {"num_bits": (4, 3), "axis": None}, - 
**_default_disabled_quantizer_cfg, - **_mamba_moe_disabled_quantizer_cfg, - }, + "quant_cfg": [ + *_base_disable_all, + { + "quantizer_path": "*weight_quantizer", + "cfg": {"num_bits": (4, 3), "axis": None}, + }, + { + "quantizer_path": "*input_quantizer", + "cfg": {"num_bits": (4, 3), "axis": None}, + }, + *_default_disabled_quantizer_cfg, + *_mamba_moe_disabled_quantizer_cfg, + ], "algorithm": "max", } MAMBA_MOE_FP8_CONSERVATIVE_CFG = { - "quant_cfg": { - "*weight_quantizer": {"num_bits": (4, 3), "axis": None}, - "*input_quantizer": {"num_bits": (4, 3), "axis": None}, - **_default_disabled_quantizer_cfg, - **_mamba_moe_disabled_quantizer_cfg, - "*mixer.in_proj*": {"enable": False}, # Skip mamba linear - "*mixer.out_proj*": {"enable": False}, # Skip mamba linear - }, + "quant_cfg": [ + *_base_disable_all, + { + "quantizer_path": "*weight_quantizer", + "cfg": {"num_bits": (4, 3), "axis": None}, + }, + { + "quantizer_path": "*input_quantizer", + "cfg": {"num_bits": (4, 3), "axis": None}, + }, + *_default_disabled_quantizer_cfg, + *_mamba_moe_disabled_quantizer_cfg, + {"quantizer_path": "*mixer.in_proj*", "enable": False}, # Skip mamba linear + {"quantizer_path": "*mixer.out_proj*", "enable": False}, # Skip mamba linear + ], "algorithm": "max", } FP8_PER_CHANNEL_PER_TOKEN_CFG = { - "quant_cfg": { - "*weight_quantizer": {"num_bits": (4, 3), "axis": 0}, - "*input_quantizer": { - "num_bits": (4, 3), - "type": "dynamic", - "block_sizes": {-1: None}, + "quant_cfg": [ + *_base_disable_all, + {"quantizer_path": "*weight_quantizer", "cfg": {"num_bits": (4, 3), "axis": 0}}, + { + "quantizer_path": "*input_quantizer", + "cfg": { + "num_bits": (4, 3), + "type": "dynamic", + "block_sizes": {-1: None}, + }, }, - **_default_disabled_quantizer_cfg, - }, + *_default_disabled_quantizer_cfg, + ], "algorithm": "max", } # FP8 2D blockwise fake quantization config for deepseek models FP8_2D_BLOCKWISE_WEIGHT_ONLY_CFG = { - "quant_cfg": { - "*weight_quantizer": { - "num_bits": (4, 3), - 
"block_sizes": {-1: 128, -2: 128}, - "enable": True, + "quant_cfg": [ + *_base_disable_all, + { + "quantizer_path": "*weight_quantizer", + "cfg": { + "num_bits": (4, 3), + "block_sizes": {-1: 128, -2: 128}, + }, }, - "*input_quantizer": {"enable": False}, - **_default_disabled_quantizer_cfg, - }, + {"quantizer_path": "*input_quantizer", "enable": False}, + *_default_disabled_quantizer_cfg, + ], "algorithm": "max", } INT4_BLOCKWISE_WEIGHT_ONLY_CFG = { - "quant_cfg": { - "*weight_quantizer": { - "num_bits": 4, - "block_sizes": {-1: 128}, - "enable": True, + "quant_cfg": [ + *_base_disable_all, + { + "quantizer_path": "*weight_quantizer", + "cfg": { + "num_bits": 4, + "block_sizes": {-1: 128}, + }, }, - "*input_quantizer": {"enable": False}, - **_default_disabled_quantizer_cfg, - }, + {"quantizer_path": "*input_quantizer", "enable": False}, + *_default_disabled_quantizer_cfg, + ], "algorithm": "max", } INT4_AWQ_CFG = { - "quant_cfg": { - "*weight_quantizer": { - "num_bits": 4, - "block_sizes": {-1: 128, "type": "static"}, - "enable": True, + "quant_cfg": [ + *_base_disable_all, + { + "quantizer_path": "*weight_quantizer", + "cfg": { + "num_bits": 4, + "block_sizes": {-1: 128, "type": "static"}, + }, }, - "*input_quantizer": {"enable": False}, - **_default_disabled_quantizer_cfg, - }, + {"quantizer_path": "*input_quantizer", "enable": False}, + *_default_disabled_quantizer_cfg, + ], "algorithm": {"method": "awq_lite", "alpha_step": 0.1}, # "algorithm": {"method": "awq_full", "alpha_step": 0.1, "max_co_batch_size": 1024}, # "algorithm": {"method": "awq_clip", "max_co_batch_size": 2048}, @@ -288,172 +395,211 @@ # W4A8 currently uses INT4 blockwise quantization (block size = 128) followed by FP8 quantization # for weights. 
This could change in the future W4A8_AWQ_BETA_CFG = { - "quant_cfg": { - "*weight_quantizer": [ - { - "num_bits": 4, - "block_sizes": {-1: 128, "type": "static"}, - "enable": True, - }, - { + "quant_cfg": [ + *_base_disable_all, + { + "quantizer_path": "*weight_quantizer", + "cfg": [ + { + "num_bits": 4, + "block_sizes": {-1: 128, "type": "static"}, + }, + { + "num_bits": (4, 3), + }, + ], + }, + { + "quantizer_path": "*input_quantizer", + "cfg": { "num_bits": (4, 3), - "enable": True, }, - ], - "*input_quantizer": { - "num_bits": (4, 3), - "enable": True, }, - **_default_disabled_quantizer_cfg, - }, + *_default_disabled_quantizer_cfg, + ], "algorithm": "awq_lite", } MXFP8_DEFAULT_CFG = { - "quant_cfg": { - "*weight_quantizer": { - "num_bits": (4, 3), - "block_sizes": {-1: 32, "type": "dynamic", "scale_bits": (8, 0)}, - "enable": True, + "quant_cfg": [ + *_base_disable_all, + { + "quantizer_path": "*weight_quantizer", + "cfg": { + "num_bits": (4, 3), + "block_sizes": {-1: 32, "type": "dynamic", "scale_bits": (8, 0)}, + }, }, - "*input_quantizer": { - "num_bits": (4, 3), - "block_sizes": {-1: 32, "type": "dynamic", "scale_bits": (8, 0)}, - "enable": True, + { + "quantizer_path": "*input_quantizer", + "cfg": { + "num_bits": (4, 3), + "block_sizes": {-1: 32, "type": "dynamic", "scale_bits": (8, 0)}, + }, }, - **_default_disabled_quantizer_cfg, - }, + *_default_disabled_quantizer_cfg, + ], "algorithm": None, } MXFP6_DEFAULT_CFG = { - "quant_cfg": { - "*weight_quantizer": { - "num_bits": (3, 2), - "block_sizes": {-1: 32, "type": "dynamic", "scale_bits": (8, 0)}, - "enable": True, + "quant_cfg": [ + *_base_disable_all, + { + "quantizer_path": "*weight_quantizer", + "cfg": { + "num_bits": (3, 2), + "block_sizes": {-1: 32, "type": "dynamic", "scale_bits": (8, 0)}, + }, }, - "*input_quantizer": { - "num_bits": (3, 2), - "block_sizes": {-1: 32, "type": "dynamic", "scale_bits": (8, 0)}, - "enable": True, + { + "quantizer_path": "*input_quantizer", + "cfg": { + "num_bits": (3, 
2), + "block_sizes": {-1: 32, "type": "dynamic", "scale_bits": (8, 0)}, + }, }, - **_default_disabled_quantizer_cfg, - }, + *_default_disabled_quantizer_cfg, + ], "algorithm": None, } MXFP4_DEFAULT_CFG = { - "quant_cfg": { - "*weight_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 32, "type": "dynamic", "scale_bits": (8, 0)}, - "enable": True, + "quant_cfg": [ + *_base_disable_all, + { + "quantizer_path": "*weight_quantizer", + "cfg": { + "num_bits": (2, 1), + "block_sizes": {-1: 32, "type": "dynamic", "scale_bits": (8, 0)}, + }, }, - "*input_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 32, "type": "dynamic", "scale_bits": (8, 0)}, - "enable": True, + { + "quantizer_path": "*input_quantizer", + "cfg": { + "num_bits": (2, 1), + "block_sizes": {-1: 32, "type": "dynamic", "scale_bits": (8, 0)}, + }, }, - **_default_disabled_quantizer_cfg, - }, + *_default_disabled_quantizer_cfg, + ], "algorithm": None, } W4A8_MXFP4_FP8_CFG = { - "quant_cfg": { - "*weight_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 32, "type": "dynamic", "scale_bits": (8, 0)}, - "enable": True, + "quant_cfg": [ + *_base_disable_all, + { + "quantizer_path": "*weight_quantizer", + "cfg": { + "num_bits": (2, 1), + "block_sizes": {-1: 32, "type": "dynamic", "scale_bits": (8, 0)}, + }, }, - "*input_quantizer": {"num_bits": (4, 3), "axis": None}, - **_default_disabled_quantizer_cfg, - }, + { + "quantizer_path": "*input_quantizer", + "cfg": {"num_bits": (4, 3), "axis": None}, + }, + *_default_disabled_quantizer_cfg, + ], "algorithm": None, } MXINT8_DEFAULT_CFG = { - "quant_cfg": { - "*weight_quantizer": { - "num_bits": 8, - "block_sizes": {-1: 32, "type": "dynamic", "scale_bits": (8, 0)}, - "enable": True, + "quant_cfg": [ + *_base_disable_all, + { + "quantizer_path": "*weight_quantizer", + "cfg": { + "num_bits": 8, + "block_sizes": {-1: 32, "type": "dynamic", "scale_bits": (8, 0)}, + }, }, - "*input_quantizer": { - "num_bits": 8, - "block_sizes": {-1: 32, "type": 
"dynamic", "scale_bits": (8, 0)}, - "enable": True, + { + "quantizer_path": "*input_quantizer", + "cfg": { + "num_bits": 8, + "block_sizes": {-1: 32, "type": "dynamic", "scale_bits": (8, 0)}, + }, }, - **_default_disabled_quantizer_cfg, - }, + *_default_disabled_quantizer_cfg, + ], "algorithm": None, } +# KV-cache configs are designed to be merged with a primary quantization config (e.g. +# FP8_DEFAULT_CFG) that already contains _base_disable_all. They intentionally omit both +# _base_disable_all and "algorithm" because these are provided by the primary config. FP8_KV_CFG = { - "quant_cfg": { - "*[kv]_bmm_quantizer": { - "num_bits": (4, 3), - "enable": True, + "quant_cfg": [ + { + "quantizer_path": "*[kv]_bmm_quantizer", + "cfg": {"num_bits": (4, 3)}, }, - "default": {"enable": False}, - }, - "algorithm": "max", + ] } FP8_AFFINE_KV_CFG = { - "quant_cfg": { - "*[kv]_bmm_quantizer": { - "num_bits": (4, 3), - "bias": {-2: None, -4: None, "type": "static"}, + "quant_cfg": [ + { + "quantizer_path": "*[kv]_bmm_quantizer", + "cfg": { + "num_bits": (4, 3), + "bias": {-2: None, -4: None, "type": "static"}, + }, }, - "default": {"enable": False}, - }, - "algorithm": "max", + ] } -_nvfp4_quantizer = { +_nvfp4_cfg = { "num_bits": (2, 1), "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, - "enable": True, } -_nvfp4_quantizer_bs32 = { +_nvfp4_cfg_bs32 = { "num_bits": (2, 1), "block_sizes": {-1: 32, "type": "dynamic", "scale_bits": (4, 3)}, - "enable": True, } def _nvfp4_selective_quant_cfg( layer_patterns: list[str], *, - quantizer: dict = _nvfp4_quantizer, + quantizer: dict = _nvfp4_cfg, weight_only: bool = False, algorithm: str | dict = "max", ) -> dict: """Build an NVFP4 config that quantizes only the specified layer patterns.""" - quant_cfg: dict[str, object] = {} + quant_cfg: list[QuantizerCfgEntry] = [] + quant_cfg.extend(_base_disable_all) for pattern in layer_patterns: - quant_cfg[f"{pattern}weight_quantizer"] = quantizer + # Deep-copy the quantizer dict 
so each config constant gets its own instance. + quant_cfg.append( + {"quantizer_path": f"{pattern}weight_quantizer", "cfg": copy.deepcopy(quantizer)} + ) if not weight_only: - quant_cfg[f"{pattern}input_quantizer"] = quantizer - quant_cfg.update(_default_disabled_quantizer_cfg) + quant_cfg.append( + {"quantizer_path": f"{pattern}input_quantizer", "cfg": copy.deepcopy(quantizer)} + ) + quant_cfg.extend(_default_disabled_quantizer_cfg) return {"quant_cfg": quant_cfg, "algorithm": algorithm} NVFP4_DEFAULT_CFG = _nvfp4_selective_quant_cfg(["*"]) NVFP4_W4A4_WEIGHT_MSE_FP8_SWEEP_CFG = { - "quant_cfg": { - "*weight_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)}, - "enable": True, + "quant_cfg": [ + *_base_disable_all, + { + "quantizer_path": "*weight_quantizer", + "cfg": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)}, + }, }, - "*input_quantizer": _nvfp4_quantizer, - **_default_disabled_quantizer_cfg, - }, + {"quantizer_path": "*input_quantizer", "cfg": _nvfp4_cfg}, + *_default_disabled_quantizer_cfg, + ], "algorithm": { "method": "mse", "fp8_scale_sweep": True, @@ -461,15 +607,18 @@ def _nvfp4_selective_quant_cfg( } NVFP4_W4A4_WEIGHT_LOCAL_HESSIAN_CFG = { - "quant_cfg": { - "*weight_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)}, - "enable": True, + "quant_cfg": [ + *_base_disable_all, + { + "quantizer_path": "*weight_quantizer", + "cfg": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)}, + }, }, - "*input_quantizer": _nvfp4_quantizer, - **_default_disabled_quantizer_cfg, - }, + {"quantizer_path": "*input_quantizer", "cfg": _nvfp4_cfg}, + *_default_disabled_quantizer_cfg, + ], "algorithm": { "method": "local_hessian", "fp8_scale_sweep": True, @@ -477,27 +626,28 @@ def _nvfp4_selective_quant_cfg( } MAMBA_MOE_NVFP4_AGGRESSIVE_CFG = { - "quant_cfg": { - "*weight_quantizer": 
_nvfp4_quantizer, - "*input_quantizer": _nvfp4_quantizer, - **_default_disabled_quantizer_cfg, - **_mamba_moe_disabled_quantizer_cfg, - }, + "quant_cfg": [ + *_base_disable_all, + {"quantizer_path": "*weight_quantizer", "cfg": _nvfp4_cfg}, + {"quantizer_path": "*input_quantizer", "cfg": _nvfp4_cfg}, + *_default_disabled_quantizer_cfg, + *_mamba_moe_disabled_quantizer_cfg, + ], "algorithm": "max", } MAMBA_MOE_NVFP4_CONSERVATIVE_CFG = { - "quant_cfg": { - "*weight_quantizer": _nvfp4_quantizer, - "*input_quantizer": _nvfp4_quantizer, - **_default_disabled_quantizer_cfg, - **_mamba_moe_disabled_quantizer_cfg, - "*mixer.in_proj*": {"enable": False}, # Skip mamba linear - "*mixer.out_proj*": {"enable": False}, # Skip mamba linear - }, + "quant_cfg": [ + *_base_disable_all, + {"quantizer_path": "*weight_quantizer", "cfg": _nvfp4_cfg}, + {"quantizer_path": "*input_quantizer", "cfg": _nvfp4_cfg}, + *_default_disabled_quantizer_cfg, + *_mamba_moe_disabled_quantizer_cfg, + {"quantizer_path": "*mixer.in_proj*", "enable": False}, # Skip mamba linear + {"quantizer_path": "*mixer.out_proj*", "enable": False}, # Skip mamba linear + ], "algorithm": "max", } - NVFP4_AWQ_LITE_CFG = _nvfp4_selective_quant_cfg(["*"], algorithm="awq_lite") NVFP4_AWQ_CLIP_CFG = _nvfp4_selective_quant_cfg(["*"], algorithm={"method": "awq_clip"}) @@ -506,64 +656,87 @@ def _nvfp4_selective_quant_cfg( ["*"], algorithm={"method": "awq_full", "alpha_step": 0.1} ) - +# See comment above FP8_KV_CFG — KV-cache configs omit _base_disable_all and "algorithm". 
NVFP4_AFFINE_KV_CFG = { - "quant_cfg": { - "*[kv]_bmm_quantizer": { - **_nvfp4_quantizer, - "bias": {-2: None, -4: None, "type": "static"}, + "quant_cfg": [ + { + "quantizer_path": "*[kv]_bmm_quantizer", + "cfg": { + **_nvfp4_cfg, + "bias": {-2: None, -4: None, "type": "static"}, + }, }, - "default": {"enable": False}, - }, - "algorithm": "max", + ] } NVFP4_KV_CFG = { - "quant_cfg": { - "*[kv]_bmm_quantizer": _nvfp4_quantizer, - "default": {"enable": False}, - }, - "algorithm": "max", + "quant_cfg": [ + {"quantizer_path": "*[kv]_bmm_quantizer", "cfg": _nvfp4_cfg}, + ] } # Moved from examples/diffusers/quantization/config.py to here NVFP4_FP8_MHA_CONFIG = { - "quant_cfg": { - "*weight_quantizer": _nvfp4_quantizer, - "*input_quantizer": _nvfp4_quantizer, - "*output_quantizer": {"enable": False}, - "*q_bmm_quantizer": { - "num_bits": (4, 3), + "quant_cfg": [ + *_base_disable_all, + {"quantizer_path": "*weight_quantizer", "cfg": _nvfp4_cfg}, + {"quantizer_path": "*input_quantizer", "cfg": _nvfp4_cfg}, + {"quantizer_path": "*output_quantizer", "enable": False}, + { + "quantizer_path": "*q_bmm_quantizer", + "cfg": { + "num_bits": (4, 3), + }, }, - "*k_bmm_quantizer": { - "num_bits": (4, 3), + { + "quantizer_path": "*k_bmm_quantizer", + "cfg": { + "num_bits": (4, 3), + }, }, - "*v_bmm_quantizer": { - "num_bits": (4, 3), + { + "quantizer_path": "*v_bmm_quantizer", + "cfg": { + "num_bits": (4, 3), + }, }, - "*softmax_quantizer": { - "num_bits": (4, 3), + { + "quantizer_path": "*softmax_quantizer", + "cfg": { + "num_bits": (4, 3), + }, }, - "transformer_blocks*bmm2_output_quantizer": { - "num_bits": (4, 3), + { + "quantizer_path": "transformer_blocks*bmm2_output_quantizer", + "cfg": { + "num_bits": (4, 3), + }, }, - "default": {"enable": False}, - }, + ], "algorithm": "max", } +# See comment above FP8_KV_CFG — KV-cache configs omit _base_disable_all and "algorithm". 
NVFP4_KV_ROTATE_CFG = { - "quant_cfg": { - "*q_bmm_quantizer": { + "quant_cfg": [ + { + # q_bmm is disabled but pre-configured with rotate=True so that downstream + # code can inspect the rotate flag even while the quantizer is off. + "quantizer_path": "*q_bmm_quantizer", + "cfg": { + "rotate": True, + }, "enable": False, - "rotate": True, }, - "*k_bmm_quantizer": { - **_nvfp4_quantizer, - "rotate": True, + { + "quantizer_path": "*k_bmm_quantizer", + "cfg": { + **_nvfp4_cfg, + "rotate": True, + }, }, - "*v_bmm_quantizer": _nvfp4_quantizer, - }, + {"quantizer_path": "*v_bmm_quantizer", "cfg": _nvfp4_cfg}, + ], "algorithm": "max", } @@ -572,40 +745,50 @@ def _nvfp4_selective_quant_cfg( ) W4A8_NVFP4_FP8_CFG = { - "quant_cfg": { - "*weight_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 32, "type": "dynamic", "scale_bits": (4, 3)}, - "enable": True, + "quant_cfg": [ + *_base_disable_all, + { + "quantizer_path": "*weight_quantizer", + "cfg": { + "num_bits": (2, 1), + "block_sizes": {-1: 32, "type": "dynamic", "scale_bits": (4, 3)}, + }, }, - "*input_quantizer": { - "num_bits": (4, 3), - "enable": True, + { + "quantizer_path": "*input_quantizer", + "cfg": { + "num_bits": (4, 3), + }, }, - **_default_disabled_quantizer_cfg, - }, + *_default_disabled_quantizer_cfg, + ], "algorithm": "max", } MXFP4_MLP_WEIGHT_ONLY_CFG = { - "quant_cfg": { - "*mlp*weight_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 32, "type": "dynamic", "scale_bits": (8, 0)}, - "enable": True, + "quant_cfg": [ + *_base_disable_all, + { + "quantizer_path": "*mlp*weight_quantizer", + "cfg": { + "num_bits": (2, 1), + "block_sizes": {-1: 32, "type": "dynamic", "scale_bits": (8, 0)}, + }, }, - "*block_sparse_moe*weight_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 32, "type": "dynamic", "scale_bits": (8, 0)}, - "enable": True, + { + "quantizer_path": "*block_sparse_moe*weight_quantizer", + "cfg": { + "num_bits": (2, 1), + "block_sizes": {-1: 32, "type": "dynamic", "scale_bits": 
(8, 0)}, + }, }, - **_default_disabled_quantizer_cfg, - }, + *_default_disabled_quantizer_cfg, + ], "algorithm": None, } NVFP4_MLP_WEIGHT_ONLY_CFG = _nvfp4_selective_quant_cfg( - ["*mlp*", "*block_sparse_moe*"], quantizer=_nvfp4_quantizer_bs32, weight_only=True + ["*mlp*", "*block_sparse_moe*"], quantizer=_nvfp4_cfg_bs32, weight_only=True ) NVFP4_EXPERTS_ONLY_CFG = _nvfp4_selective_quant_cfg(["*mlp.experts*", "*block_sparse_moe*"]) NVFP4_MLP_ONLY_CFG = _nvfp4_selective_quant_cfg(["*mlp*", "*block_sparse_moe*"]) @@ -1346,23 +1529,131 @@ class GPTQLiteConfig(QuantizeAlgorithmConfig): ) -QuantizeQuantCfgType = dict[ - str | Callable, - QuantizerAttributeConfig - | list[QuantizerAttributeConfig] - | dict[str | Callable, QuantizerAttributeConfig | list[QuantizerAttributeConfig]], -] +QuantizeQuantCfgType = list[QuantizerCfgEntry] _QuantizeAlgoCfgType = str | dict | QuantizeAlgorithmConfig | None QuantizeAlgoCfgType = _QuantizeAlgoCfgType | list[_QuantizeAlgoCfgType] | None +def normalize_quant_cfg_list(v: dict | list) -> list[QuantizerCfgEntry]: + """Normalize a raw quant_cfg into a list of :class:`QuantizerCfgEntry` dicts. + + Supports the following input forms: + + - A ``list`` of entries in any of the per-entry forms below. + - A legacy flat ``dict`` (``{"*": ..., "*weight_quantizer": ...}``) — each key/value pair is + converted to a single-key dict entry and then normalized. + + Per-entry forms (when input is a list): + + - New format: ``{"quantizer_path": ..., "enable": ..., "cfg": ...}`` — passed through. + - Legacy single-key format: ``{"": }`` — converted to new format. + - Legacy ``nn.*``-scoped format: ``{"nn.": {"": }}`` — converted + to a new-format entry with ``parent_class`` set. + + **Validation** — an entry is rejected if it carries no instruction, i.e. it specifies neither + ``cfg`` nor ``enable``. Concretely, the following are invalid: + + - An empty entry ``{}``. 
+ - An entry with only ``quantizer_path`` and no other keys — the only effect would be an + implicit ``enable=True``, which must be stated explicitly. + + **Normalization** — after conversion and validation every entry is put into canonical form: + + - ``enable`` is set to ``True`` if not explicitly specified. + - ``cfg`` is set to ``None`` if not present in the entry. + + Every returned entry is therefore guaranteed to have the keys ``quantizer_path``, ``enable``, + and ``cfg`` (plus optionally ``parent_class``). + + Args: + v: A list of raw quant_cfg entries in any supported format, or a legacy flat dict. + + Returns: + A list of :class:`QuantizerCfgEntry` dicts in canonical normalized form. + + Raises: + ValueError: If any entry has only ``quantizer_path`` with neither ``cfg`` nor ``enable``, + or if the entry format is not recognized. + """ + # Legacy flat-dict format: {"*": {...}, "*weight_quantizer": {...}} → list of single-key dicts. + if isinstance(v, dict): + v = [{k: val} for k, val in v.items()] + + def _dict_to_entry(key: str, value) -> list[QuantizerCfgEntry]: + """Convert a single legacy key-value pair to one or more QuantizerCfgEntry dicts.""" + # Legacy "default" key was a catch-all applied as "*" in the old conversion code. + if key == "default": + key = "*" + + if isinstance(key, str) and key.startswith("nn."): + if not isinstance(value, dict): + raise ValueError(f"For 'nn.*' scoped format, value must be a dict, got {value!r}") + # Support multi-key nn.*-scoped dicts by emitting one entry per sub-key. 
+ entries: list[QuantizerCfgEntry] = [] + for q_path, sub_cfg in value.items(): + sub_cfg = dict(sub_cfg) + enable = sub_cfg.pop("enable", None) + cfg = sub_cfg or None + entry: QuantizerCfgEntry = { + "parent_class": key, + "quantizer_path": q_path, + "cfg": cfg, + } + if enable is not None: + entry["enable"] = enable + entries.append(entry) + return entries + else: + if isinstance(value, dict): + cfg = {k: val for k, val in value.items() if k != "enable"} or None + enable = value.get("enable") + else: + cfg = value + enable = None + entry = {"quantizer_path": key, "cfg": cfg} + if enable is not None: + entry["enable"] = enable + return [entry] + + result: list[QuantizerCfgEntry] = [] + for raw in v: + if isinstance(raw, dict) and "quantizer_path" in raw: + entries = [dict(raw)] # copy to avoid mutating caller's data + elif isinstance(raw, dict) and len(raw) == 1: + key, val = next(iter(raw.items())) + entries = [dict(e) for e in _dict_to_entry(key, val)] + elif isinstance(raw, dict) and len(raw) > 1 and any(k.startswith("nn.") for k in raw): + # Legacy flat dict with nn.*-scoped keys mixed with other keys — expand all pairs. + entries = [] + for k, val in raw.items(): + entries.extend(dict(e) for e in _dict_to_entry(k, val)) + else: + raise ValueError(f"Invalid quant_cfg entry: {raw!r}.") + + for entry in entries: + # Validate: must carry at least one instruction beyond the path selector. + if "cfg" not in entry and "enable" not in entry: + raise ValueError( + f"Invalid quant_cfg entry: {raw!r} — each entry must specify 'cfg', 'enable', " + "or both. An entry with only 'quantizer_path' has no effect (implicit " + "enable=True is not allowed; set it explicitly)." + ) + + # Normalize: make enable and cfg always explicit. 
+ entry.setdefault("enable", True) + entry.setdefault("cfg", None) + + result.append(cast("QuantizerCfgEntry", entry)) + return result + + class QuantizeConfig(ModeloptBaseConfig): """Default configuration for ``quantize`` mode.""" quant_cfg: QuantizeQuantCfgType = ModeloptField( - default={"default": {"num_bits": 8, "axis": None}}, + default=[{"quantizer_path": "*", "cfg": {"num_bits": 8, "axis": None}}], title="Quantization configuration", validate_default=True, ) @@ -1374,6 +1665,29 @@ class QuantizeConfig(ModeloptBaseConfig): validate_default=True, ) + @field_validator("quant_cfg", mode="before") + @classmethod + def normalize_quant_cfg(cls, v): + """Normalize quant_cfg entries: convert dict and tuple forms to QuantizerCfgEntry dicts.""" + if not isinstance(v, (list, dict)): + return v + return normalize_quant_cfg_list(v) + + @field_validator("quant_cfg", mode="after") + @classmethod + def validate_quant_cfg_entries(cls, v): + """Validate quantizer attribute configs to surface errors (e.g. 
invalid axis/block_sizes).""" + qac_fields = set(QuantizerAttributeConfig.model_fields.keys()) + for entry in v: + cfg = entry.get("cfg") + if cfg is None: + continue + cfgs = cfg if isinstance(cfg, list) else [cfg] + for c in cfgs: + if isinstance(c, dict) and qac_fields & set(c.keys()): + QuantizerAttributeConfig.model_validate(c) + return v + class CompressConfig(ModeloptBaseConfig): """Default configuration for ``compress`` mode.""" @@ -1404,24 +1718,26 @@ def need_calibration(config): return True def _not_dynamic(cfg): - return ( - cfg.get("enable", True) - and cfg.get("type", "") != "dynamic" - and cfg.get("*", {}).get("enable", True) - ) + return cfg.get("enable", True) and cfg.get("type", "") != "dynamic" - for name, cfg in config.get("quant_cfg", {}).items(): + quant_cfg: list = config.get("quant_cfg") or [] + quant_cfg = normalize_quant_cfg_list(quant_cfg) + for entry in quant_cfg: + name = entry["quantizer_path"] + raw_cfg = entry.get("cfg") if "weight_quantizer" in name: # We don't calibrate weight quantizer continue - # quantization like W4A8 has a list of weight quantizers - if isinstance(cfg, list): - for _config in cfg: + # Sequential quantizers (e.g. 
W4A8) have a list of cfg dicts + if isinstance(raw_cfg, list): + for _config in raw_cfg: if _not_dynamic(_config): - print(f"{cfg}: True") return True - elif _not_dynamic(cfg): - print(f"{cfg}: True") + continue + cfg = dict(raw_cfg or {}) + if "enable" in entry: + cfg["enable"] = entry["enable"] + if _not_dynamic(cfg): return True return False diff --git a/modelopt/torch/quantization/conversion.py b/modelopt/torch/quantization/conversion.py index 472252e1c7..6684db7de6 100644 --- a/modelopt/torch/quantization/conversion.py +++ b/modelopt/torch/quantization/conversion.py @@ -19,7 +19,7 @@ import warnings from collections.abc import Callable from contextlib import contextmanager -from typing import Any +from typing import Any, cast import torch.nn as nn @@ -33,6 +33,7 @@ QuantizeQuantCfgType, QuantizerAttributeConfig, _QuantizeExportConfig, + normalize_quant_cfg_list, ) from .nn import ( NVFP4StaticQuantizer, @@ -48,6 +49,8 @@ "register", "replace_quant_module", "set_quantizer_attribute", + "set_quantizer_attributes_full", + "set_quantizer_attributes_partial", "set_quantizer_by_cfg", "set_quantizer_by_cfg_context", "unregister", @@ -60,7 +63,7 @@ def convert_to_quantized_model(model: ModelLikeModule, config: QuantizeConfig) - model = model.init_modellike() if isinstance(model, ModelLikeModule) else model replace_quant_module(model, version=ModeloptStateManager(model).state_version) - set_quantizer_by_cfg(model, config.get("quant_cfg", {})) + set_quantizer_by_cfg(model, config.get("quant_cfg", [])) metadata = {} update_quantize_metadata(model, config, metadata) @@ -76,7 +79,7 @@ def convert_to_quantized_model_svdquant( model = model.init_modellike() if isinstance(model, ModelLikeModule) else model create_and_replace_svdquant_linear_on_the_fly(model) - set_quantizer_by_cfg(model, config.get("quant_cfg", {})) + set_quantizer_by_cfg(model, config.get("quant_cfg", [])) metadata = {} update_quantize_metadata(model, config, metadata) @@ -211,127 +214,330 @@ def 
_replace_quant_module(model: nn.Module, version=None, registry=QuantModuleRe _replace_quant_module(getattr(model, name), version=version, registry=registry) -def set_quantizer_by_cfg(quant_model: nn.Module, quant_cfg: QuantizeQuantCfgType | dict): - """Update the quantizer attributes based on the specified `quant_cfg`. +def set_quantizer_by_cfg(quant_model: nn.Module, quant_cfg: QuantizeQuantCfgType): + """Apply a quantization config list to the quantizers in ``quant_model``. - `quant_cfg` is a dictionary mapping wildcards or filter functions - to its quantizer attributes which are defined in - :class:`QuantizerAttributeConfig <.config.QuantizerAttributeConfig>`. - The wildcards or filter functions are matched against the quantizer module names. - The specified quantizer attributes of the matched quantizer modules are set accordingly. - The key ``"default"`` is a special key that sets the quantizer attributes of all the quantizers for which - no other wildcard or filter functions match the quantizer module name. + ``quant_cfg`` is an **ordered list** of :class:`QuantizerCfgEntry <.config.QuantizerCfgEntry>` + dicts. Each entry has the following fields: - In addition, the dictionary entries could also be pytorch module class names mapping the class specific - quantization configuration. The pytorch modules should have a quantized equivalent. + - ``quantizer_path`` *(required)*: wildcard matched against quantizer module names via + :func:`fnmatch`. + - ``cfg`` *(optional)*: a dict of :class:`QuantizerAttributeConfig <.config.QuantizerAttributeConfig>` + fields, or a list of such dicts for sequential quantization. + - ``enable`` *(optional)*: ``True`` or ``False`` to toggle matched quantizers on or off. + When omitted but ``cfg`` is present, defaults to ``True``. Every entry must specify at + least one of ``cfg`` or ``enable`` — an entry with only ``quantizer_path`` is invalid. 
+ - ``parent_class`` *(optional)*: restricts matching to quantizers whose immediate parent + module is of this PyTorch class name. - See :meth:`set_quantizer_attribute ` - for more details. + **Ordering and atomicity:** entries are applied in list order; later entries override earlier + ones for any quantizer they match. Each entry with a ``cfg`` is a **complete replacement** — + unspecified attributes revert to their defaults rather than inheriting from a prior entry. + The typical pattern is to deny all first (``{"quantizer_path": "*", "enable": False}``), then + selectively enable and configure target quantizers in subsequent entries. + + **``enable`` and ``cfg`` are independent:** + + - An entry with ``cfg`` (and optionally ``enable``) fully replaces the matched quantizer's + attributes. If ``enable`` is omitted, the quantizer is implicitly enabled. + - ``{"enable": False}`` without ``cfg`` **only** toggles the matched quantizers off, leaving + all other attributes unchanged. + - ``{"enable": True}`` without ``cfg`` **only** toggles the matched quantizers on, using + whatever attributes they currently have (or their defaults if never configured). + + See :ref:`quant-cfg` for the full format reference and common patterns. """ - quant_cfg = quant_cfg.copy() - if "default" in quant_cfg: - set_quantizer_attribute(quant_model, "*", quant_cfg["default"]) - quant_cfg.pop("default") - - for pattern, cfg in quant_cfg.items(): - if str(pattern) in QuantModuleRegistry: - parent_class = QuantModuleRegistry[str(pattern)] - assert isinstance(cfg, dict), ( - f"Expected a dictionary for quantizer configuration for child tensor quantizers of {parent_class}." 
+ quant_cfg = normalize_quant_cfg_list(quant_cfg) + + for entry in quant_cfg: + quantizer_path: str = entry["quantizer_path"] + cfg = entry["cfg"] # None, dict, or list — always explicit after normalization + enable: bool = entry["enable"] # always explicit after normalization + parent_class_name = entry.get("parent_class") + if parent_class_name: + try: + parent_class = QuantModuleRegistry[parent_class_name] + except KeyError: + raise ValueError( + f"parent_class {parent_class_name!r} not found in QuantModuleRegistry. " + "Make sure the class has a registered quantized equivalent." + ) from None + else: + parent_class = None + + if cfg is None: + # No cfg: only toggle the enable state, leave all other attributes unchanged. + set_quantizer_attributes_partial( + quant_model, quantizer_path, {"enable": enable}, parent_class ) - for sub_pattern, sub_cfg in cfg.items(): - set_quantizer_attribute(quant_model, sub_pattern, sub_cfg, parent_class) - continue - set_quantizer_attribute(quant_model, pattern, cfg) + else: + # Has cfg: apply full replacement with the explicit enable value. 
+ if isinstance(cfg, QuantizerAttributeConfig): + attributes = cfg.model_copy(update={"enable": enable}) + elif isinstance(cfg, dict): + attributes = QuantizerAttributeConfig(**cfg, enable=enable) + else: + attributes = [ + c.model_copy(update={"enable": enable}) + if isinstance(c, QuantizerAttributeConfig) + else QuantizerAttributeConfig(**c, enable=enable) + for c in cfg + ] + set_quantizer_attributes_full(quant_model, quantizer_path, attributes, parent_class) -def set_quantizer_attribute( +def _match_quantizer( + wildcard_or_filter_func: str | Callable, + name: str, + module: nn.Module, + parent_class: type[nn.Module] | None, + full_model: nn.Module, +): + if not isinstance(module, (TensorQuantizer, SequentialQuantizer)): + return False + if isinstance(wildcard_or_filter_func, str): + if not fnmatch.fnmatch(name, wildcard_or_filter_func): + return False + elif callable(wildcard_or_filter_func): + if not wildcard_or_filter_func(name): + return False + else: + raise NotImplementedError(f"Unsupported type {type(wildcard_or_filter_func)}") + + # Get the parent module of this quantizer. When name has no dots (root-level quantizer), + # ".".join([]) == "" and get_submodule("") returns the model itself (PyTorch convention). + return parent_class is None or isinstance( + full_model.get_submodule(".".join(name.split(".")[:-1])), parent_class + ) + + +def set_quantizer_attributes_full( quant_model: nn.Module, wildcard_or_filter_func: str | Callable, - attribute: QuantizerAttributeConfig - | list[QuantizerAttributeConfig] - | dict[ - str | Callable, - QuantizerAttributeConfig | list[QuantizerAttributeConfig], - ] - | dict - | list[dict], - parent_class: type | None = None, + attributes: QuantizerAttributeConfig | list[QuantizerAttributeConfig], + parent_class: type[nn.Module] | None = None, ): - """Finegrained adjustment of quantizer attribute by wildcard or filter function. 
+ """Set quantizer attributes by wildcard or filter function, fully overwriting existing attributes. + + Unlike :func:`set_quantizer_attributes_partial`, this function requires a complete + :class:`QuantizerAttributeConfig <.config.QuantizerAttributeConfig>` and **replaces** the + matched quantizer's attributes entirely rather than merging with existing ones. Args: - quant_model: A pytorch model - wildcard_or_filter_func: a wildcard string or a filter function. The wildcard string is matched - against the quantizer module names. The quantizer modules are - instances of + quant_model: A pytorch model. + wildcard_or_filter_func: A wildcard string or a filter function. The wildcard string is + matched against the quantizer module names. The quantizer modules are instances of :class:`TensorQuantizer `. - The filter function takes a quantized module name as input and returns ``True`` if the + The filter function takes a quantizer module name as input and returns ``True`` if the quantizer should be adjusted and ``False`` otherwise. - attribute: An instance of :class:`QuantizerAttributeConfig <.config.QuantizerAttributeConfig>` or an equivalent - dictionary or a list of these two types. - If ``attribute`` is a list, the matched + attributes: A :class:`QuantizerAttributeConfig <.config.QuantizerAttributeConfig>` (or a + list of them) that **fully replaces** the matched quantizer's current attributes. All + fields of the config are applied — unspecified fields revert to their defaults. + If ``attributes`` is a list, the matched :class:`TensorQuantizer ` - modules will be replaced with :class:`SequentialQuantizer ` - modules having one quantizer for each attribute instance from the list. + modules will be replaced with + :class:`SequentialQuantizer ` + modules having one quantizer per attribute instance in the list. See :meth:`set_from_attribute_config() ` - for more details on the supported attributes and their types. 
- parent_class: (Optional) The parent class of the quantizer modules matching ``wildcard_or_filter_func`` which - should be adjusted. If ``None``, all the matching quantizer modules will be adjusted. + for details on supported attributes and their types. + parent_class: (Optional) Restrict matching to quantizers whose immediate parent module is + an instance of this class. If ``None``, all quantizers matching + ``wildcard_or_filter_func`` are adjusted. """ + if not isinstance(attributes, (QuantizerAttributeConfig, list)): + raise ValueError( + f"Invalid type for attributes: {type(attributes)}, " + "expected QuantizerAttributeConfig or list of QuantizerAttributeConfig." + ) + if isinstance(attributes, list) and not all( + isinstance(attr, QuantizerAttributeConfig) for attr in attributes + ): + raise ValueError( + "All elements in attributes list must be of type QuantizerAttributeConfig." + ) for name, module in quant_model.named_modules(): - if isinstance(module, (TensorQuantizer, SequentialQuantizer)): - if isinstance(wildcard_or_filter_func, str): - if not fnmatch.fnmatch(name, wildcard_or_filter_func): - continue - elif callable(wildcard_or_filter_func): - if not wildcard_or_filter_func(name): - continue + if _match_quantizer(wildcard_or_filter_func, name, module, parent_class, quant_model): + if isinstance(attributes, list): + if not isinstance(module, SequentialQuantizer): + parent_module = quant_model.get_submodule(name.rpartition(".")[0]) + module = SequentialQuantizer( + *(TensorQuantizer() for _ in range(len(attributes))) + ) + setattr(parent_module, name.split(".")[-1], module) + elif len(attributes) != len(module): + warnings.warn( + f"The number of attributes ({len(attributes)}) does not match the number of " + f"quantizers of {module} leading to partial assignment.", + ) + module.set_from_attribute_config(attributes) else: - raise NotImplementedError(f"Unsupported type {type(wildcard_or_filter_func)}") + if isinstance(module, SequentialQuantizer): + # 
Downgrade SequentialQuantizer back to TensorQuantizer when the + # new entry provides a single (non-list) config. + parent_module = quant_model.get_submodule(name.rpartition(".")[0]) + module = TensorQuantizer() + setattr(parent_module, name.split(".")[-1], module) + cast("TensorQuantizer", module).set_from_attribute_config(attributes) - if parent_class is not None and not isinstance( - quant_model.get_submodule(".".join(name.split(".")[:-1])), parent_class - ): - continue - if isinstance(attribute, list) and not isinstance(module, SequentialQuantizer): - parent_module = quant_model.get_submodule(name.rpartition(".")[0]) - module = SequentialQuantizer(*(TensorQuantizer() for _ in range(len(attribute)))) - setattr(parent_module, name.split(".")[-1], module) - elif isinstance(attribute, list) and len(attribute) != len(module): - warnings.warn( - f"The number of attributes ({len(attribute)}) does not match the number of " - f"quantizers of {module} leading to partial assignment.", - ) - module.set_from_attribute_config(attribute) +def set_quantizer_attributes_partial( + quant_model: nn.Module, + wildcard_or_filter_func: str | Callable, + partial_attributes: dict[str, Any] | list[dict[str, Any]], + parent_class: type[nn.Module] | None = None, +): + """Update a subset of quantizer attributes by wildcard or filter function, merging with existing attributes. + + Unlike :func:`set_quantizer_attributes_full`, this function accepts an arbitrary subset of + quantizer attributes as a plain ``dict`` and **merges** them into the matched quantizer's + current attributes, leaving unspecified attributes unchanged. + + Args: + quant_model: A pytorch model. + wildcard_or_filter_func: A wildcard string or a filter function. The wildcard string is + matched against the quantizer module names. The quantizer modules are instances of + :class:`TensorQuantizer `. 
+ The filter function takes a quantizer module name as input and returns ``True`` if the + quantizer should be adjusted and ``False`` otherwise. + partial_attributes: A ``dict`` (or a list of ``dict``) containing only the attributes to + update. Keys must be valid fields of + :class:`QuantizerAttributeConfig <.config.QuantizerAttributeConfig>`. Only the + specified keys are written; all other attributes on the quantizer remain unchanged. + When a ``dict`` is passed and the matched module is a + :class:`SequentialQuantizer `, + the dict is broadcast to every sub-quantizer. + When a ``list`` is passed, the matched module must already be a + :class:`SequentialQuantizer ` — + unlike :func:`set_quantizer_attributes_full`, this function will **not** replace a + :class:`TensorQuantizer ` with a + ``SequentialQuantizer``. + See + :meth:`set_from_attribute_config() ` + for details on supported attributes and their types. + parent_class: (Optional) Restrict matching to quantizers whose immediate parent module is + an instance of this class. If ``None``, all quantizers matching + ``wildcard_or_filter_func`` are adjusted. + """ + if not isinstance(partial_attributes, (dict, list)): + raise ValueError( + f"Invalid type for attributes: {type(partial_attributes)}, expected dictionary or list of dict." + ) + if isinstance(partial_attributes, list) and not all( + isinstance(attr, dict) for attr in partial_attributes + ): + raise ValueError("All elements in attributes list must be of type dict.") + + for name, module in quant_model.named_modules(): + if _match_quantizer(wildcard_or_filter_func, name, module, parent_class, quant_model): + module = cast("TensorQuantizer | SequentialQuantizer", module) # for type checker + if isinstance(partial_attributes, list): + if not isinstance(module, SequentialQuantizer): + raise ValueError( + f"Attributes is a list but {module} is not a SequentialQuantizer." 
+ ) + module.set_from_attribute_config(partial_attributes) + elif isinstance(module, SequentialQuantizer): + # Broadcast the dict to all sub-quantizers. + module.set_from_attribute_config([partial_attributes] * len(module)) + else: + module.set_from_attribute_config(partial_attributes) @contextmanager -def set_quantizer_by_cfg_context(quant_model: nn.Module, quant_cfg: QuantizeQuantCfgType | dict): - """Context manager for setting quantizer attributes using `quant_cfg`. +def set_quantizer_by_cfg_context(quant_model: nn.Module, quant_cfg: QuantizeQuantCfgType): + """Context manager that temporarily applies a quantization config and restores the original state on exit. - The set attributes will be reset to the original attributes after exiting the context manager. - See :meth:`set_quantizer_by_cfg` for more details. + Calls :func:`set_quantizer_by_cfg` on entry and reverts every + :class:`TensorQuantizer ` in + ``quant_model`` to its original attributes on exit. + + .. caution:: + Changing stateful attributes such as ``calibrator`` inside this context may produce + unexpected behavior because those objects are not deep-copied during save/restore. + + Args: + quant_model: A quantized PyTorch model whose quantizers will be temporarily reconfigured. + quant_cfg: A quantization config (or list of + :class:`QuantizerCfgEntry <.config.QuantizerCfgEntry>` dicts) passed directly to + :func:`set_quantizer_by_cfg`. Sequential ``cfg`` lists are not allowed. - Use this context manager with caution. Changing certain attributes of the quantizer such as - `calibrator` can lead to unexpected behavior. + Yields: + None — the context body runs with the new quantizer attributes active. """ - assert not any(cfg for cfg in quant_cfg.values() if isinstance(cfg, (list, tuple))), ( - "list of config not support." 
- ) + quant_cfg = normalize_quant_cfg_list(quant_cfg) + + for entry in quant_cfg: + if isinstance(entry.get("cfg"), list): + raise ValueError( + "Sequential cfg lists are not allowed in set_quantizer_by_cfg_context. " + "Use only single-dict cfg entries." + ) - original_attributes = {} + original_attributes: dict[str, dict] = {} + original_types: dict[str, type] = {} for name, module in quant_model.named_modules(): - if isinstance(module, TensorQuantizer): + if isinstance(module, SequentialQuantizer): + # SequentialQuantizer.get_modelopt_state does not support properties_only; + # save per-sub-quantizer state so we can fully reconstruct on restore. + original_attributes[name] = { + "is_sequential_quantizer": True, + "sub_states": [tq.get_modelopt_state(properties_only=True) for tq in module], + } + original_types[name] = SequentialQuantizer + elif isinstance(module, TensorQuantizer): original_attributes[name] = module.get_modelopt_state(properties_only=True) + original_types[name] = TensorQuantizer set_quantizer_by_cfg(quant_model, quant_cfg) yield - for name, module in quant_model.named_modules(): - if isinstance(module, TensorQuantizer): + + # Restore original quantizer types and attributes. If set_quantizer_by_cfg downgraded a + # SequentialQuantizer to a TensorQuantizer (or vice-versa), we need to re-create the + # original module type before restoring attributes. 
+ for name, module in list(quant_model.named_modules()): + if name not in original_attributes: + continue + orig_type = original_types[name] + if orig_type is SequentialQuantizer and not isinstance(module, SequentialQuantizer): + # Restore the SequentialQuantizer that was downgraded + saved = original_attributes[name] + parent_name, _, attr_name = name.rpartition(".") + parent_module = quant_model.get_submodule(parent_name) if parent_name else quant_model + module = SequentialQuantizer(*(TensorQuantizer() for _ in saved["sub_states"])) + setattr(parent_module, attr_name, module) + for tq, sub_state in zip(module, saved["sub_states"]): + tq.set_from_modelopt_state(sub_state, properties_only=True) + elif orig_type is TensorQuantizer and not isinstance(module, TensorQuantizer): + parent_name, _, attr_name = name.rpartition(".") + parent_module = quant_model.get_submodule(parent_name) if parent_name else quant_model + module = TensorQuantizer() + setattr(parent_module, attr_name, module) + module.set_from_modelopt_state(original_attributes[name], properties_only=True) + elif orig_type is TensorQuantizer: module.set_from_modelopt_state(original_attributes[name], properties_only=True) + elif orig_type is SequentialQuantizer: + saved = original_attributes[name] + for tq, sub_state in zip(module, saved["sub_states"]): + tq.set_from_modelopt_state(sub_state, properties_only=True) + + +def set_quantizer_attribute( + quant_model: nn.Module, + wildcard_or_filter_func: str | Callable, + attribute: Any, + parent_class: type[nn.Module] | None = None, +): + """Deprecated: use :func:`set_quantizer_attributes_partial` instead.""" + warnings.warn( + "set_quantizer_attribute is deprecated, use set_quantizer_attributes_partial " + "or set_quantizer_attributes_full instead.", + DeprecationWarning, + stacklevel=2, + ) + return set_quantizer_attributes_partial( + quant_model, wildcard_or_filter_func, attribute, parent_class + ) def register(original_cls: nn.Module, quantized_cls: 
nn.Module): diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 89097fd32c..23cd2594ac 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -1101,7 +1101,9 @@ def forward(self, input, *args, **kwargs): self.awq_lite.num_cache_steps += 1 self.awq_lite.num_tokens += input.numel() / input.shape[-1] if self.awq_lite.is_input_quantized: - with set_quantizer_by_cfg_context(self.input_quantizer, {"*": {"enable": True}}): + with set_quantizer_by_cfg_context( + self.input_quantizer, [{"quantizer_path": "*", "enable": True}] + ): max_calibrate(self.input_quantizer, lambda quantizer: quantizer(input), False) return out_actual diff --git a/modelopt/torch/quantization/model_quant.py b/modelopt/torch/quantization/model_quant.py index 4aa1ff46b4..8e0dddd620 100644 --- a/modelopt/torch/quantization/model_quant.py +++ b/modelopt/torch/quantization/model_quant.py @@ -30,13 +30,15 @@ from modelopt.torch.opt.searcher import ForwardLoop from modelopt.torch.opt.utils import forward_with_reshard from modelopt.torch.quantization.config import QuantizeConfig -from modelopt.torch.quantization.conversion import set_quantizer_by_cfg +from modelopt.torch.quantization.conversion import ( + set_quantizer_attributes_partial, + set_quantizer_by_cfg, +) from modelopt.torch.utils import atomic_print from .algorithms import AutoQuantizeGradientSearcher, AutoQuantizeKLDivSearcher, QuantRecipe from .algorithms import get_auto_quantize_config as _get_auto_quantize_config from .config import QuantizeAlgoCfgType -from .conversion import set_quantizer_attribute from .mode import QuantizeModeRegistry, get_modelike_from_algo_cfg from .nn import QuantModule, TensorQuantizer from .utils import is_quantized @@ -159,13 +161,15 @@ def quantize( :class:`QuantizeConfig ` specifying the values for keys ``"quant_cfg"`` and ``"algorithm"``. 
It is basically a dictionary specifying the values for keys ``"quant_cfg"`` and ``"algorithm"``. - The ``"quant_cfg"`` key specifies the quantization configurations. + The ``"quant_cfg"`` key specifies the quantization configurations as an ordered list of + :class:`QuantizerCfgEntry ` dicts. The ``"algorithm"`` key specifies the ``algorithm`` argument to :meth:`calibrate `. - Quantization configurations is a dictionary mapping wildcards or filter functions - to its quantizer attributes. The wildcards or filter functions are matched - against the quantizer module names. The quantizer modules have names ending with + Each entry in the ``"quant_cfg"`` list has a ``"quantizer_path"`` wildcard matched + against quantizer module names, an optional ``"cfg"`` dict of quantizer attributes, + and an optional ``"enable"`` toggle. Entries are applied in list order; later entries + override earlier ones. The quantizer modules have names ending with ``weight_quantizer`` and ``input_quantizer`` and they perform weight quantization and input quantization (or activation quantization) respectively. The quantizer modules are instances of @@ -178,17 +182,15 @@ def quantize( .. code-block::python config = { - - "quant_cfg": { + "quant_cfg": [ + # Disable all quantizers by default + {"quantizer_path": "*", "enable": False}, # "num_bits" specifies the number of bits for quantization # "axis" specifies the axis for quantization - "*weight_quantizer": {"num_bits": 8, "axis": 0}, - "*input_quantizer": {"num_bits": 8, "axis": -1}, - - # Default quantization settings - "default": {"num_bits": 8, "axis": None}, - } - "algorithm": "max" + {"quantizer_path": "*weight_quantizer", "cfg": {"num_bits": 8, "axis": 0}}, + {"quantizer_path": "*input_quantizer", "cfg": {"num_bits": 8, "axis": -1}}, + ], + "algorithm": "max", } See :ref:`Quantization Formats ` to learn more about the supported @@ -323,10 +325,13 @@ def auto_quantize( .. 
code-block:: python INT8_CUSTOM_QUANT_CFG = { - "quant_cfg": { - "*weight_quantizer": {"num_bits": 8, "axis": 0}, - "*input_quantizer": {"num_bits": 8, "axis": None}, - }, + "quant_cfg": [ + {"quantizer_path": "*weight_quantizer", "cfg": {"num_bits": 8, "axis": 0}}, + { + "quantizer_path": "*input_quantizer", + "cfg": {"num_bits": 8, "axis": None}, + }, + ], "algorithm": "smoothquant", } @@ -527,7 +532,7 @@ def forward_backward_step(model, batch) -> None: "checkpoint": checkpoint, } # Disable all quantizers; AutoQuantize will enable the needed ones - set_quantizer_by_cfg(model, {"*": {"enable": False}}) + set_quantizer_by_cfg(model, [{"quantizer_path": "*", "enable": False}]) searcher.search(model, constraints, config=search_config) # type: ignore[arg-type] return model, searcher.state_dict() @@ -574,12 +579,12 @@ def get_auto_quantize_config(search_state, constraints=None, verbose=False): def disable_quantizer(model: nn.Module, wildcard_or_filter_func: str | Callable): """Disable quantizer by wildcard or filter function.""" - set_quantizer_attribute(model, wildcard_or_filter_func, {"enable": False}) + set_quantizer_attributes_partial(model, wildcard_or_filter_func, {"enable": False}) def enable_quantizer(model: nn.Module, wildcard_or_filter_func: str | Callable): """Enable quantizer by wildcard or filter function.""" - set_quantizer_attribute(model, wildcard_or_filter_func, {"enable": True}) + set_quantizer_attributes_partial(model, wildcard_or_filter_func, {"enable": True}) @atomic_print diff --git a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py index ec2c3cfc55..3ff7401ec3 100644 --- a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py +++ b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py @@ -203,8 +203,8 @@ def __init__( # Optional quantizer cache for caching quantizer related encoding or tensors. 
self._quantizer_cache = None - def set_from_attribute_config(self, attribute_cfg: QuantizerAttributeConfig | dict): - """Set quantizer attributes from attribute_dict. + def set_from_attribute_config(self, attribute_cfg: QuantizerAttributeConfig | dict[str, Any]): + """Set quantizer attributes from attribute_cfg. The attributes are defined in :class:`QuantizerAttributeConfig `. @@ -218,12 +218,27 @@ def _calibrator_setter(val): calib_cls, args, kwargs = standardize_constructor_args(val) return calib_cls(*args, **kwargs) + def _axis_setter(val): + if getattr(self, "_calibrator", None) is not None: + self._calibrator._axis = val + return val + + def _block_sizes_setter(val): + if val is not None: + # block_sizes and axis are mutually exclusive; clear axis when block_sizes is set + setattr(self, "_axis", None) + if getattr(self, "_calibrator", None) is not None: + self._calibrator._axis = None + return val + # Some attributes need custom handling. # By default, attributes from config are mapped to a name ``f"_{attribute}"`` _custom_setters: dict[str, tuple[str, Callable]] = { "enable": ("_disabled", lambda val: val is False), "type": ("_dynamic", lambda val: val == "dynamic"), "calibrator": ("_calibrator", _calibrator_setter), + "axis": ("_axis", _axis_setter), + "block_sizes": ("_block_sizes", _block_sizes_setter), "backend": ("backend", lambda val: val), "backend_extra_args": ("backend_extra_args", lambda val: val or {}), "use_constant_amax": ("_use_constant_amax", lambda val: val), @@ -1408,10 +1423,7 @@ def get_modelopt_state(self) -> dict[str, Any]: return {"num_quantizers": len(self), "is_sequential_quantizer": True} def set_from_attribute_config( - self, - attributes: list[dict[str, Any] | QuantizerAttributeConfig] - | dict[str, Any] - | QuantizerAttributeConfig, + self, attributes: list[QuantizerAttributeConfig] | list[dict[str, Any]] ): """Set the attributes of contained quantizers from a list of attribute_dicts.""" if not isinstance(attributes, (list, 
tuple)): diff --git a/modelopt/torch/quantization/utils/core_utils.py b/modelopt/torch/quantization/utils/core_utils.py index 4340b8dc1f..b9008a7029 100644 --- a/modelopt/torch/quantization/utils/core_utils.py +++ b/modelopt/torch/quantization/utils/core_utils.py @@ -27,6 +27,7 @@ from torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam from torch.distributed.tensor import Replicate +from modelopt.torch.quantization.config import QuantizerCfgEntry from modelopt.torch.utils import get_unwrapped_name, print_rank_0 if TYPE_CHECKING: @@ -310,11 +311,15 @@ def calibrate_with_adapters(model, args): def disable_lora_quantizers_in_config(config, layers): """Turns off input, weight, and output quantizers for LoRA weights and LoRALinear layers in config.""" - config["quant_cfg"]["*lora*"] = {"enable": False} + config["quant_cfg"].append({"quantizer_path": "*lora*", "enable": False}) for layer in layers: - config["quant_cfg"][f"*{layer}.input_quantizer"] = {"enable": False} - config["quant_cfg"][f"*{layer}.weight_quantizer"] = {"enable": False} - config["quant_cfg"][f"*{layer}.output_quantizer"] = {"enable": False} + config["quant_cfg"].append({"quantizer_path": f"*{layer}.input_quantizer", "enable": False}) + config["quant_cfg"].append( + {"quantizer_path": f"*{layer}.weight_quantizer", "enable": False} + ) + config["quant_cfg"].append( + {"quantizer_path": f"*{layer}.output_quantizer", "enable": False} + ) return config @@ -823,13 +828,25 @@ def fsdp2_aware_weight_update(root_model, modules_to_update, reshard=True): def update_quant_cfg_with_kv_cache_quant( - quant_cfg: dict[str, Any], kv_cache_quant_cfg: dict[str, Any] + quant_cfg: dict[str, Any], kv_cache_quant_cfg: list[QuantizerCfgEntry] ) -> dict[str, Any]: - """Update the quant_cfg with the kv cache quant_cfg.""" + """Update the quant_cfg with the kv cache quant_cfg. + + Args: + quant_cfg: The outer quantization config dict (with ``"quant_cfg"`` and ``"algorithm"`` keys). 
+ kv_cache_quant_cfg: A list of :class:`QuantizerCfgEntry + ` dicts for KV cache quantization, + typically ``some_kv_cfg["quant_cfg"]``. + + Returns: + A deep copy of ``quant_cfg`` with the KV cache entries appended to ``quant_cfg["quant_cfg"]``. + """ # If quant_cfg["quant_cfg"] is None, it corresponds to only kv cache quantization case quant_cfg = copy.deepcopy(quant_cfg) - quant_cfg["quant_cfg"] = quant_cfg.get("quant_cfg") or {"default": {"enable": False}} - quant_cfg["quant_cfg"].update(kv_cache_quant_cfg) + inner: list[QuantizerCfgEntry] = quant_cfg.get("quant_cfg") or [ + {"quantizer_path": "*", "enable": False} + ] + quant_cfg["quant_cfg"] = inner + list(kv_cache_quant_cfg) # Set default algorithm for kv cache quantization if not provided. if not quant_cfg.get("algorithm"): diff --git a/modelopt/torch/sparsity/attention_sparsity/conversion.py b/modelopt/torch/sparsity/attention_sparsity/conversion.py index cdc2aed948..0255caf4ed 100644 --- a/modelopt/torch/sparsity/attention_sparsity/conversion.py +++ b/modelopt/torch/sparsity/attention_sparsity/conversion.py @@ -194,7 +194,7 @@ def set_sparse_attention_attribute( ): """Set sparse attention attributes for modules matching pattern. - Similar to quantization's set_quantizer_attribute. + Similar to quantization's set_quantizer_attributes_partial. 
Args: model: Model to configure diff --git a/modelopt_recipes/general/ptq/fp8_default-fp8_kv.yml b/modelopt_recipes/general/ptq/fp8_default-fp8_kv.yml index 72630965bd..1024a60c16 100644 --- a/modelopt_recipes/general/ptq/fp8_default-fp8_kv.yml +++ b/modelopt_recipes/general/ptq/fp8_default-fp8_kv.yml @@ -19,46 +19,49 @@ metadata: ptq_cfg: algorithm: max quant_cfg: - '*input_quantizer': - num_bits: e4m3 - axis: - '*weight_quantizer': - num_bits: e4m3 - axis: - default: + - quantizer_path: '*' enable: false - '*block_sparse_moe.gate*': + - quantizer_path: '*input_quantizer' + cfg: + num_bits: e4m3 + axis: + - quantizer_path: '*weight_quantizer' + cfg: + num_bits: e4m3 + axis: + - quantizer_path: '*[kv]_bmm_quantizer' + enable: true + cfg: + num_bits: e4m3 + - quantizer_path: '*block_sparse_moe.gate*' enable: false - '*linear_attn.conv1d*': + - quantizer_path: '*linear_attn.conv1d*' enable: false - '*lm_head*': + - quantizer_path: '*lm_head*' enable: false - '*mixer.conv1d*': + - quantizer_path: '*mixer.conv1d*' enable: false - '*mlp.gate.*': + - quantizer_path: '*mlp.gate.*' enable: false - '*mlp.shared_expert_gate.*': + - quantizer_path: '*mlp.shared_expert_gate.*' enable: false - '*output_layer*': + - quantizer_path: '*output_layer*' enable: false - '*proj_out.*': + - quantizer_path: '*proj_out.*' enable: false - '*router*': + - quantizer_path: '*router*' enable: false - output.*: + - quantizer_path: 'output.*' + enable: false + - parent_class: 'nn.BatchNorm1d' + quantizer_path: '*' + enable: false + - parent_class: 'nn.BatchNorm2d' + quantizer_path: '*' + enable: false + - parent_class: 'nn.BatchNorm3d' + quantizer_path: '*' + enable: false + - parent_class: 'nn.LeakyReLU' + quantizer_path: '*' enable: false - nn.BatchNorm1d: - '*': - enable: false - nn.BatchNorm2d: - '*': - enable: false - nn.BatchNorm3d: - '*': - enable: false - nn.LeakyReLU: - '*': - enable: false - '*[kv]_bmm_quantizer': - num_bits: e4m3 - enable: true diff --git 
a/modelopt_recipes/general/ptq/nvfp4_default-fp8_kv.yml b/modelopt_recipes/general/ptq/nvfp4_default-fp8_kv.yml index 73e84b1bce..524fb6d97f 100644 --- a/modelopt_recipes/general/ptq/nvfp4_default-fp8_kv.yml +++ b/modelopt_recipes/general/ptq/nvfp4_default-fp8_kv.yml @@ -19,54 +19,57 @@ metadata: ptq_cfg: algorithm: max quant_cfg: - '*weight_quantizer': - block_sizes: - -1: 16 - type: dynamic - scale_bits: e4m3 - num_bits: e2m1 + - quantizer_path: '*' + enable: false + - quantizer_path: '*weight_quantizer' + enable: true + cfg: + block_sizes: + -1: 16 + type: dynamic + scale_bits: e4m3 + num_bits: e2m1 + - quantizer_path: '*input_quantizer' enable: true - '*input_quantizer': - block_sizes: - -1: 16 - type: dynamic - scale_bits: e4m3 - num_bits: e2m1 + cfg: + block_sizes: + -1: 16 + type: dynamic + scale_bits: e4m3 + num_bits: e2m1 + - quantizer_path: '*[kv]_bmm_quantizer' enable: true - default: + cfg: + num_bits: e4m3 + - quantizer_path: '*block_sparse_moe.gate*' enable: false - '*block_sparse_moe.gate*': + - quantizer_path: '*linear_attn.conv1d*' enable: false - '*linear_attn.conv1d*': + - quantizer_path: '*lm_head*' enable: false - '*lm_head*': + - quantizer_path: '*mixer.conv1d*' enable: false - '*mixer.conv1d*': + - quantizer_path: '*mlp.gate.*' enable: false - '*mlp.gate.*': + - quantizer_path: '*mlp.shared_expert_gate.*' enable: false - '*mlp.shared_expert_gate.*': + - quantizer_path: '*output_layer*' enable: false - '*output_layer*': + - quantizer_path: '*proj_out.*' enable: false - '*proj_out.*': + - quantizer_path: '*router*' enable: false - '*router*': + - quantizer_path: 'output.*' enable: false - output.*: + - parent_class: 'nn.BatchNorm1d' + quantizer_path: '*' + enable: false + - parent_class: 'nn.BatchNorm2d' + quantizer_path: '*' + enable: false + - parent_class: 'nn.BatchNorm3d' + quantizer_path: '*' + enable: false + - parent_class: 'nn.LeakyReLU' + quantizer_path: '*' enable: false - nn.BatchNorm1d: - '*': - enable: false - nn.BatchNorm2d: - 
'*': - enable: false - nn.BatchNorm3d: - '*': - enable: false - nn.LeakyReLU: - '*': - enable: false - '*[kv]_bmm_quantizer': - num_bits: e4m3 - enable: true diff --git a/modelopt_recipes/general/ptq/nvfp4_experts_only-fp8_kv.yml b/modelopt_recipes/general/ptq/nvfp4_experts_only-fp8_kv.yml index 2f3d6718ea..351a4f8c67 100644 --- a/modelopt_recipes/general/ptq/nvfp4_experts_only-fp8_kv.yml +++ b/modelopt_recipes/general/ptq/nvfp4_experts_only-fp8_kv.yml @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -19,68 +19,73 @@ metadata: ptq_cfg: algorithm: max quant_cfg: - '*mlp.experts*weight_quantizer': - block_sizes: - -1: 16 - type: dynamic - scale_bits: e4m3 - num_bits: e2m1 + - quantizer_path: '*' + enable: false + - quantizer_path: '*mlp.experts*weight_quantizer' + enable: true + cfg: + block_sizes: + -1: 16 + type: dynamic + scale_bits: e4m3 + num_bits: e2m1 + - quantizer_path: '*mlp.experts*input_quantizer' enable: true - '*mlp.experts*input_quantizer': - block_sizes: - -1: 16 - type: dynamic - scale_bits: e4m3 - num_bits: e2m1 + cfg: + block_sizes: + -1: 16 + type: dynamic + scale_bits: e4m3 + num_bits: e2m1 + - quantizer_path: '*block_sparse_moe*weight_quantizer' enable: true - '*block_sparse_moe*weight_quantizer': - block_sizes: - -1: 16 - type: dynamic - scale_bits: e4m3 - num_bits: e2m1 + cfg: + block_sizes: + -1: 16 + type: dynamic + scale_bits: e4m3 + num_bits: e2m1 + - quantizer_path: '*block_sparse_moe*input_quantizer' enable: true - '*block_sparse_moe*input_quantizer': - block_sizes: - -1: 16 - type: dynamic - scale_bits: e4m3 - num_bits: e2m1 + cfg: + block_sizes: + -1: 16 + type: dynamic + scale_bits: e4m3 + num_bits: e2m1 + - quantizer_path: '*[kv]_bmm_quantizer' enable: true - 
default: + cfg: + num_bits: e4m3 + - quantizer_path: '*block_sparse_moe.gate*' enable: false - '*block_sparse_moe.gate*': + - quantizer_path: '*linear_attn.conv1d*' enable: false - '*linear_attn.conv1d*': + - quantizer_path: '*lm_head*' enable: false - '*lm_head*': + - quantizer_path: '*mixer.conv1d*' enable: false - '*mixer.conv1d*': + - quantizer_path: '*mlp.gate.*' enable: false - '*mlp.gate.*': + - quantizer_path: '*mlp.shared_expert_gate.*' enable: false - '*mlp.shared_expert_gate.*': + - quantizer_path: '*output_layer*' enable: false - '*output_layer*': + - quantizer_path: '*proj_out.*' enable: false - '*proj_out.*': + - quantizer_path: '*router*' enable: false - '*router*': + - quantizer_path: 'output.*' enable: false - output.*: + - parent_class: 'nn.BatchNorm1d' + quantizer_path: '*' + enable: false + - parent_class: 'nn.BatchNorm2d' + quantizer_path: '*' + enable: false + - parent_class: 'nn.BatchNorm3d' + quantizer_path: '*' + enable: false + - parent_class: 'nn.LeakyReLU' + quantizer_path: '*' enable: false - nn.BatchNorm1d: - '*': - enable: false - nn.BatchNorm2d: - '*': - enable: false - nn.BatchNorm3d: - '*': - enable: false - nn.LeakyReLU: - '*': - enable: false - '*[kv]_bmm_quantizer': - num_bits: e4m3 - enable: true diff --git a/modelopt_recipes/general/ptq/nvfp4_mlp_only-fp8_kv.yml b/modelopt_recipes/general/ptq/nvfp4_mlp_only-fp8_kv.yml index fd502e2c30..33fee0e3e4 100644 --- a/modelopt_recipes/general/ptq/nvfp4_mlp_only-fp8_kv.yml +++ b/modelopt_recipes/general/ptq/nvfp4_mlp_only-fp8_kv.yml @@ -19,68 +19,73 @@ metadata: ptq_cfg: algorithm: max quant_cfg: - '*mlp*weight_quantizer': - block_sizes: - -1: 16 - type: dynamic - scale_bits: e4m3 - num_bits: e2m1 + - quantizer_path: '*' + enable: false + - quantizer_path: '*mlp*weight_quantizer' + enable: true + cfg: + block_sizes: + -1: 16 + type: dynamic + scale_bits: e4m3 + num_bits: e2m1 + - quantizer_path: '*mlp*input_quantizer' enable: true - '*mlp*input_quantizer': - block_sizes: - -1: 16 - 
type: dynamic - scale_bits: e4m3 - num_bits: e2m1 + cfg: + block_sizes: + -1: 16 + type: dynamic + scale_bits: e4m3 + num_bits: e2m1 + - quantizer_path: '*block_sparse_moe*weight_quantizer' enable: true - '*block_sparse_moe*weight_quantizer': - block_sizes: - -1: 16 - type: dynamic - scale_bits: e4m3 - num_bits: e2m1 + cfg: + block_sizes: + -1: 16 + type: dynamic + scale_bits: e4m3 + num_bits: e2m1 + - quantizer_path: '*block_sparse_moe*input_quantizer' enable: true - '*block_sparse_moe*input_quantizer': - block_sizes: - -1: 16 - type: dynamic - scale_bits: e4m3 - num_bits: e2m1 + cfg: + block_sizes: + -1: 16 + type: dynamic + scale_bits: e4m3 + num_bits: e2m1 + - quantizer_path: '*[kv]_bmm_quantizer' enable: true - default: + cfg: + num_bits: e4m3 + - quantizer_path: '*block_sparse_moe.gate*' enable: false - '*block_sparse_moe.gate*': + - quantizer_path: '*linear_attn.conv1d*' enable: false - '*linear_attn.conv1d*': + - quantizer_path: '*lm_head*' enable: false - '*lm_head*': + - quantizer_path: '*mixer.conv1d*' enable: false - '*mixer.conv1d*': + - quantizer_path: '*mlp.gate.*' enable: false - '*mlp.gate.*': + - quantizer_path: '*mlp.shared_expert_gate.*' enable: false - '*mlp.shared_expert_gate.*': + - quantizer_path: '*output_layer*' enable: false - '*output_layer*': + - quantizer_path: '*proj_out.*' enable: false - '*proj_out.*': + - quantizer_path: '*router*' enable: false - '*router*': + - quantizer_path: 'output.*' enable: false - output.*: + - parent_class: 'nn.BatchNorm1d' + quantizer_path: '*' + enable: false + - parent_class: 'nn.BatchNorm2d' + quantizer_path: '*' + enable: false + - parent_class: 'nn.BatchNorm3d' + quantizer_path: '*' + enable: false + - parent_class: 'nn.LeakyReLU' + quantizer_path: '*' enable: false - nn.BatchNorm1d: - '*': - enable: false - nn.BatchNorm2d: - '*': - enable: false - nn.BatchNorm3d: - '*': - enable: false - nn.LeakyReLU: - '*': - enable: false - '*[kv]_bmm_quantizer': - num_bits: e4m3 - enable: true diff --git 
a/modelopt_recipes/general/ptq/nvfp4_omlp_only-fp8_kv.yml b/modelopt_recipes/general/ptq/nvfp4_omlp_only-fp8_kv.yml index 4a19f874aa..29cb76bb50 100644 --- a/modelopt_recipes/general/ptq/nvfp4_omlp_only-fp8_kv.yml +++ b/modelopt_recipes/general/ptq/nvfp4_omlp_only-fp8_kv.yml @@ -19,82 +19,89 @@ metadata: ptq_cfg: algorithm: max quant_cfg: - '*mlp*weight_quantizer': - block_sizes: - -1: 16 - type: dynamic - scale_bits: e4m3 - num_bits: e2m1 + - quantizer_path: '*' + enable: false + - quantizer_path: '*mlp*weight_quantizer' + enable: true + cfg: + block_sizes: + -1: 16 + type: dynamic + scale_bits: e4m3 + num_bits: e2m1 + - quantizer_path: '*mlp*input_quantizer' enable: true - '*mlp*input_quantizer': - block_sizes: - -1: 16 - type: dynamic - scale_bits: e4m3 - num_bits: e2m1 + cfg: + block_sizes: + -1: 16 + type: dynamic + scale_bits: e4m3 + num_bits: e2m1 + - quantizer_path: '*block_sparse_moe*weight_quantizer' enable: true - '*block_sparse_moe*weight_quantizer': - block_sizes: - -1: 16 - type: dynamic - scale_bits: e4m3 - num_bits: e2m1 + cfg: + block_sizes: + -1: 16 + type: dynamic + scale_bits: e4m3 + num_bits: e2m1 + - quantizer_path: '*block_sparse_moe*input_quantizer' enable: true - '*block_sparse_moe*input_quantizer': - block_sizes: - -1: 16 - type: dynamic - scale_bits: e4m3 - num_bits: e2m1 + cfg: + block_sizes: + -1: 16 + type: dynamic + scale_bits: e4m3 + num_bits: e2m1 + - quantizer_path: '*o_proj*weight_quantizer' enable: true - '*o_proj*weight_quantizer': - block_sizes: - -1: 16 - type: dynamic - scale_bits: e4m3 - num_bits: e2m1 + cfg: + block_sizes: + -1: 16 + type: dynamic + scale_bits: e4m3 + num_bits: e2m1 + - quantizer_path: '*o_proj*input_quantizer' enable: true - '*o_proj*input_quantizer': - block_sizes: - -1: 16 - type: dynamic - scale_bits: e4m3 - num_bits: e2m1 + cfg: + block_sizes: + -1: 16 + type: dynamic + scale_bits: e4m3 + num_bits: e2m1 + - quantizer_path: '*[kv]_bmm_quantizer' enable: true - default: + cfg: + num_bits: e4m3 + - 
quantizer_path: '*block_sparse_moe.gate*' enable: false - '*block_sparse_moe.gate*': + - quantizer_path: '*linear_attn.conv1d*' enable: false - '*linear_attn.conv1d*': + - quantizer_path: '*lm_head*' enable: false - '*lm_head*': + - quantizer_path: '*mixer.conv1d*' enable: false - '*mixer.conv1d*': + - quantizer_path: '*mlp.gate.*' enable: false - '*mlp.gate.*': + - quantizer_path: '*mlp.shared_expert_gate.*' enable: false - '*mlp.shared_expert_gate.*': + - quantizer_path: '*output_layer*' enable: false - '*output_layer*': + - quantizer_path: '*proj_out.*' enable: false - '*proj_out.*': + - quantizer_path: '*router*' enable: false - '*router*': + - quantizer_path: 'output.*' enable: false - output.*: + - parent_class: 'nn.BatchNorm1d' + quantizer_path: '*' + enable: false + - parent_class: 'nn.BatchNorm2d' + quantizer_path: '*' + enable: false + - parent_class: 'nn.BatchNorm3d' + quantizer_path: '*' + enable: false + - parent_class: 'nn.LeakyReLU' + quantizer_path: '*' enable: false - nn.BatchNorm1d: - '*': - enable: false - nn.BatchNorm2d: - '*': - enable: false - nn.BatchNorm3d: - '*': - enable: false - nn.LeakyReLU: - '*': - enable: false - '*[kv]_bmm_quantizer': - num_bits: e4m3 - enable: true diff --git a/modelopt_recipes/models/Step3.5-Flash/nvfp4-mlp-only.yaml b/modelopt_recipes/models/Step3.5-Flash/nvfp4-mlp-only.yaml index e70160e988..bf59ac1896 100644 --- a/modelopt_recipes/models/Step3.5-Flash/nvfp4-mlp-only.yaml +++ b/modelopt_recipes/models/Step3.5-Flash/nvfp4-mlp-only.yaml @@ -19,66 +19,71 @@ metadata: ptq_cfg: algorithm: max quant_cfg: - '*moe*weight_quantizer': - block_sizes: - -1: 16 - type: dynamic - scale_bits: e4m3 - num_bits: e2m1 + - quantizer_path: '*' + enable: false + - quantizer_path: '*moe*weight_quantizer' + enable: true + cfg: + block_sizes: + -1: 16 + type: dynamic + scale_bits: e4m3 + num_bits: e2m1 + - quantizer_path: '*moe*input_quantizer' enable: true - '*moe*input_quantizer': - block_sizes: - -1: 16 - type: dynamic - scale_bits: 
e4m3 - num_bits: e2m1 + cfg: + block_sizes: + -1: 16 + type: dynamic + scale_bits: e4m3 + num_bits: e2m1 + - quantizer_path: '*mlp*weight_quantizer' enable: true - '*mlp*weight_quantizer': - block_sizes: - -1: 16 - type: dynamic - scale_bits: e4m3 - num_bits: e2m1 + cfg: + block_sizes: + -1: 16 + type: dynamic + scale_bits: e4m3 + num_bits: e2m1 + - quantizer_path: '*mlp*input_quantizer' enable: true - '*mlp*input_quantizer': - block_sizes: - -1: 16 - type: dynamic - scale_bits: e4m3 - num_bits: e2m1 + cfg: + block_sizes: + -1: 16 + type: dynamic + scale_bits: e4m3 + num_bits: e2m1 + - quantizer_path: '*[kv]_bmm_quantizer' enable: true - '*share_expert*': + cfg: + num_bits: e4m3 + - quantizer_path: '*share_expert*' enable: false - '*moe.gate.*': + - quantizer_path: '*moe.gate.*' enable: false - default: + - quantizer_path: '*linear_attn.conv1d*' enable: false - '*linear_attn.conv1d*': + - quantizer_path: '*lm_head*' enable: false - '*lm_head*': + - quantizer_path: '*mixer.conv1d*' enable: false - '*mixer.conv1d*': + - quantizer_path: '*output_layer*' enable: false - '*output_layer*': + - quantizer_path: '*proj_out.*' enable: false - '*proj_out.*': + - quantizer_path: '*router*' enable: false - '*router*': + - quantizer_path: 'output.*' enable: false - output.*: + - parent_class: 'nn.BatchNorm1d' + quantizer_path: '*' + enable: false + - parent_class: 'nn.BatchNorm2d' + quantizer_path: '*' + enable: false + - parent_class: 'nn.BatchNorm3d' + quantizer_path: '*' + enable: false + - parent_class: 'nn.LeakyReLU' + quantizer_path: '*' enable: false - nn.BatchNorm1d: - '*': - enable: false - nn.BatchNorm2d: - '*': - enable: false - nn.BatchNorm3d: - '*': - enable: false - nn.LeakyReLU: - '*': - enable: false - '*[kv]_bmm_quantizer': - num_bits: e4m3 - enable: true diff --git a/tests/_test_utils/torch/export/utils.py b/tests/_test_utils/torch/export/utils.py index 8011eb72e2..e0867bad7e 100644 --- a/tests/_test_utils/torch/export/utils.py +++ 
b/tests/_test_utils/torch/export/utils.py @@ -85,162 +85,241 @@ def forward(self, x): # Quantization configs partial_fp8_config = { - "quant_cfg": { - "*.1.weight_quantizer": {"num_bits": (4, 3), "axis": None}, - "*.1.input_quantizer": {"num_bits": (4, 3), "axis": None}, - "default": {"num_bits": 8, "enable": False}, - }, + "quant_cfg": [ + {"quantizer_path": "*", "enable": False}, + {"quantizer_path": "*.1.weight_quantizer", "cfg": {"num_bits": (4, 3), "axis": None}}, + {"quantizer_path": "*.1.input_quantizer", "cfg": {"num_bits": (4, 3), "axis": None}}, + ], "algorithm": "max", } partial_w4a8_config = { - "quant_cfg": { - "*.2.weight_quantizer": [ - {"num_bits": 4, "block_sizes": {-1: 128, "type": "static"}, "enable": True}, - {"num_bits": (4, 3), "axis": None, "enable": True}, - ], - "*.2.input_quantizer": {"num_bits": (4, 3), "axis": None, "enable": True}, - "default": {"num_bits": 8, "enable": False}, - }, + "quant_cfg": [ + {"quantizer_path": "*", "enable": False}, + { + "quantizer_path": "*.2.weight_quantizer", + "cfg": [ + {"num_bits": 4, "block_sizes": {-1: 128, "type": "static"}}, + {"num_bits": (4, 3), "axis": None}, + ], + "enable": True, + }, + { + "quantizer_path": "*.2.input_quantizer", + "cfg": {"num_bits": (4, 3), "axis": None}, + "enable": True, + }, + ], "algorithm": "awq_lite", } partial_nvfp4_config = { - "quant_cfg": { - "*.1.weight_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, - "axis": None, + "quant_cfg": [ + {"quantizer_path": "*", "enable": False}, + { + "quantizer_path": "*.1.weight_quantizer", + "cfg": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, + "axis": None, + }, "enable": True, }, - "*.1.input_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, - "axis": None, + { + "quantizer_path": "*.1.input_quantizer", + "cfg": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": 
"dynamic", "scale_bits": (4, 3)}, + "axis": None, + }, "enable": True, }, - "*.2.weight_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, - "axis": None, + { + "quantizer_path": "*.2.weight_quantizer", + "cfg": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, + "axis": None, + }, "enable": True, }, - "*.2.input_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, - "axis": None, + { + "quantizer_path": "*.2.input_quantizer", + "cfg": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, + "axis": None, + }, "enable": True, }, - "default": {"enable": False}, - }, + ], "algorithm": "max", } partial_nvfp4_awq_config = { - "quant_cfg": { - "*.2.weight_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, - "axis": None, + "quant_cfg": [ + {"quantizer_path": "*", "enable": False}, + { + "quantizer_path": "*.2.weight_quantizer", + "cfg": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, + "axis": None, + }, "enable": True, }, - "*.2.input_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, - "axis": None, + { + "quantizer_path": "*.2.input_quantizer", + "cfg": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, + "axis": None, + }, "enable": True, }, - "*.1.weight_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, - "axis": None, + { + "quantizer_path": "*.1.weight_quantizer", + "cfg": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, + "axis": None, + }, "enable": False, }, - "*.1.input_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, - "axis": None, + { + 
"quantizer_path": "*.1.input_quantizer", + "cfg": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, + "axis": None, + }, "enable": False, }, - "default": {"enable": False}, - }, + ], "algorithm": "awq_lite", } partial_int4_awq_config = { - "quant_cfg": { - "*.2.weight_quantizer": { - "num_bits": 4, - "block_sizes": {-1: 128, "type": "static"}, + "quant_cfg": [ + {"quantizer_path": "*", "enable": False}, + { + "quantizer_path": "*.2.weight_quantizer", + "cfg": {"num_bits": 4, "block_sizes": {-1: 128, "type": "static"}}, "enable": True, }, - "*.2.input_quantizer": {"enable": False}, - "default": {"enable": False}, - }, + {"quantizer_path": "*.2.input_quantizer", "enable": False}, + ], "algorithm": {"method": "awq_lite", "alpha_step": 0.1}, # "algorithm": {"method": "awq_full", "alpha_step": 0.1, "max_co_batch_size": 1024}, # "algorithm": {"method": "awq_clip", "max_co_batch_size": 2048}, } partial_fp8_kv_cache_config = { - "quant_cfg": { - "*.1.weight_quantizer": {"num_bits": (4, 3), "axis": None}, - "*.1.input_quantizer": {"num_bits": (4, 3), "axis": None}, - "*output_quantizer": {"num_bits": (4, 3), "axis": None, "enable": True}, - "default": {"enable": False}, - }, + "quant_cfg": [ + {"quantizer_path": "*", "enable": False}, + {"quantizer_path": "*.1.weight_quantizer", "cfg": {"num_bits": (4, 3), "axis": None}}, + {"quantizer_path": "*.1.input_quantizer", "cfg": {"num_bits": (4, 3), "axis": None}}, + { + "quantizer_path": "*output_quantizer", + "cfg": {"num_bits": (4, 3), "axis": None}, + "enable": True, + }, + ], "algorithm": "max", } partial_int8_kv_cache_config = { - "quant_cfg": { - "*.1.weight_quantizer": {"num_bits": (4, 3), "axis": None}, - "*.1.input_quantizer": {"num_bits": (4, 3), "axis": None}, - "*output_quantizer": {"num_bits": 8, "axis": None, "enable": True}, - "default": {"enable": False}, - }, + "quant_cfg": [ + {"quantizer_path": "*", "enable": False}, + {"quantizer_path": "*.1.weight_quantizer", "cfg": 
{"num_bits": (4, 3), "axis": None}}, + {"quantizer_path": "*.1.input_quantizer", "cfg": {"num_bits": (4, 3), "axis": None}}, + { + "quantizer_path": "*output_quantizer", + "cfg": {"num_bits": 8, "axis": None}, + "enable": True, + }, + ], "algorithm": "max", } partial_nvfp4_kv_cache_config = { - "quant_cfg": { - "*.1.weight_quantizer": {"num_bits": (4, 3), "axis": None}, - "*.1.input_quantizer": {"num_bits": (4, 3), "axis": None}, - "*[kv]_bmm_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, - "axis": None, + "quant_cfg": [ + {"quantizer_path": "*", "enable": False}, + {"quantizer_path": "*.1.weight_quantizer", "cfg": {"num_bits": (4, 3), "axis": None}}, + {"quantizer_path": "*.1.input_quantizer", "cfg": {"num_bits": (4, 3), "axis": None}}, + { + "quantizer_path": "*[kv]_bmm_quantizer", + "cfg": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, + "axis": None, + }, "enable": True, }, - "default": {"enable": False}, - }, + ], "algorithm": "max", } only_weight_quantizer_fp8_config = { - "quant_cfg": { - "*weight_quantizer": {"num_bits": (4, 3), "axis": None, "enable": True}, - "*input_quantizer": {"num_bits": (4, 3), "axis": None, "enable": False}, - "*output_quantizer": {"num_bits": (4, 3), "axis": None, "enable": False}, - "default": {"enable": False}, - }, + "quant_cfg": [ + {"quantizer_path": "*", "enable": False}, + { + "quantizer_path": "*weight_quantizer", + "cfg": {"num_bits": (4, 3), "axis": None}, + "enable": True, + }, + { + "quantizer_path": "*input_quantizer", + "cfg": {"num_bits": (4, 3), "axis": None}, + "enable": False, + }, + { + "quantizer_path": "*output_quantizer", + "cfg": {"num_bits": (4, 3), "axis": None}, + "enable": False, + }, + ], "algorithm": "max", } only_input_quantizer_fp8_config = { - "quant_cfg": { - "*weight_quantizer": {"num_bits": (4, 3), "axis": None, "enable": False}, - "*input_quantizer": {"num_bits": (4, 3), "axis": None, "enable": 
True}, - "*output_quantizer": {"num_bits": (4, 3), "axis": None, "enable": False}, - "default": {"enable": False}, - }, + "quant_cfg": [ + {"quantizer_path": "*", "enable": False}, + { + "quantizer_path": "*weight_quantizer", + "cfg": {"num_bits": (4, 3), "axis": None}, + "enable": False, + }, + { + "quantizer_path": "*input_quantizer", + "cfg": {"num_bits": (4, 3), "axis": None}, + "enable": True, + }, + { + "quantizer_path": "*output_quantizer", + "cfg": {"num_bits": (4, 3), "axis": None}, + "enable": False, + }, + ], "algorithm": "max", } only_output_quantizer_fp8_config = { - "quant_cfg": { - "*weight_quantizer": {"num_bits": (4, 3), "axis": None, "enable": False}, - "*input_quantizer": {"num_bits": (4, 3), "axis": None, "enable": False}, - "*output_quantizer": {"num_bits": (4, 3), "axis": None, "enable": True}, - "default": {"enable": False}, - }, + "quant_cfg": [ + {"quantizer_path": "*", "enable": False}, + { + "quantizer_path": "*weight_quantizer", + "cfg": {"num_bits": (4, 3), "axis": None}, + "enable": False, + }, + { + "quantizer_path": "*input_quantizer", + "cfg": {"num_bits": (4, 3), "axis": None}, + "enable": False, + }, + { + "quantizer_path": "*output_quantizer", + "cfg": {"num_bits": (4, 3), "axis": None}, + "enable": True, + }, + ], "algorithm": "max", } diff --git a/tests/_test_utils/torch/quantization/onnx_export.py b/tests/_test_utils/torch/quantization/onnx_export.py index 5c74e656cd..57ee92ad09 100644 --- a/tests/_test_utils/torch/quantization/onnx_export.py +++ b/tests/_test_utils/torch/quantization/onnx_export.py @@ -29,11 +29,11 @@ def onnx_export_tester(model, device, num_bits, per_channel_quantization, constant_folding, dtype): axis = 0 if per_channel_quantization else None config = { - "quant_cfg": { - "*weight_quantizer": {"num_bits": num_bits, "axis": axis}, - "*input_quantizer": {"num_bits": num_bits}, - "default": {"enable": False}, - }, + "quant_cfg": [ + {"quantizer_path": "*", "enable": False}, + {"quantizer_path": 
"*weight_quantizer", "cfg": {"num_bits": num_bits, "axis": axis}}, + {"quantizer_path": "*input_quantizer", "cfg": {"num_bits": num_bits}}, + ], "algorithm": "max", } diff --git a/tests/_test_utils/torch/quantization/quantize_common.py b/tests/_test_utils/torch/quantization/quantize_common.py index ae56dd299d..0a347de583 100644 --- a/tests/_test_utils/torch/quantization/quantize_common.py +++ b/tests/_test_utils/torch/quantization/quantize_common.py @@ -29,25 +29,28 @@ from modelopt.torch.quantization.utils import is_quantized_linear from modelopt.torch.utils import torch_to -INT4_AWQ_FULL_CFG = mtq.INT4_AWQ_CFG.copy() +INT4_AWQ_FULL_CFG = copy.deepcopy(mtq.INT4_AWQ_CFG) INT4_AWQ_FULL_CFG["algorithm"] = "awq_full" -INT4_AWQ_CLIP_CFG = mtq.INT4_AWQ_CFG.copy() +INT4_AWQ_CLIP_CFG = copy.deepcopy(mtq.INT4_AWQ_CFG) INT4_AWQ_CLIP_CFG["algorithm"] = "awq_clip" # SVDQuant test cfg -INT4_SVDQUANT_CFG = mtq.INT4_AWQ_CFG.copy() +INT4_SVDQUANT_CFG = copy.deepcopy(mtq.INT4_AWQ_CFG) INT4_SVDQUANT_CFG["algorithm"] = {"method": "svdquant", "lowrank": 8} # SVDQuant test cfg -FP4_SVDQUANT_CFG = mtq.NVFP4_AWQ_LITE_CFG.copy() +FP4_SVDQUANT_CFG = copy.deepcopy(mtq.NVFP4_AWQ_LITE_CFG) FP4_SVDQUANT_CFG["algorithm"] = {"method": "svdquant", "lowrank": 8} def get_awq_config(algorithm="awq_lite", block_size=8): config = copy.deepcopy(mtq.INT4_AWQ_CFG) - config["quant_cfg"]["*weight_quantizer"]["block_sizes"] = {-1: block_size} + for entry in config["quant_cfg"]: + if entry["quantizer_path"] == "*weight_quantizer": + entry.setdefault("cfg", {})["block_sizes"] = {-1: block_size} + break if "algorithm" not in config or not isinstance(config["algorithm"], dict): config["algorithm"] = {} diff --git a/tests/gpu/torch/quantization/test_hadamard.py b/tests/gpu/torch/quantization/test_hadamard.py index 93d3e8ccb8..430d7ddf68 100644 --- a/tests/gpu/torch/quantization/test_hadamard.py +++ b/tests/gpu/torch/quantization/test_hadamard.py @@ -77,7 +77,7 @@ def test_kv_rotate(rotate_fp32): model = 
nn.Sequential(SDPAAttention()) mtq.replace_quant_module(model) - set_quantizer_by_cfg(model, {"*": {"enable": False}}) + set_quantizer_by_cfg(model, [{"quantizer_path": "*", "enable": False}]) dummy_input = SDPAAttention.get_input(device="cuda") output_ref = model(dummy_input) if rotate_fp32: @@ -86,11 +86,9 @@ def test_kv_rotate(rotate_fp32): rotate = True with set_quantizer_by_cfg_context( model, - { - "*[qk]_bmm_quantizer": { - "rotate": rotate, - }, - }, + [ + {"quantizer_path": "*[qk]_bmm_quantizer", "cfg": {"rotate": rotate}}, + ], ): output_test = model(dummy_input) assert torch.allclose(output_ref, output_test, atol=0.05) @@ -98,11 +96,9 @@ def test_kv_rotate(rotate_fp32): # Test the rotation is actually applied by turning on only one of the query, key quantizers with set_quantizer_by_cfg_context( model, - { - "*k_bmm_quantizer": { - "rotate": rotate, - }, - }, + [ + {"quantizer_path": "*k_bmm_quantizer", "cfg": {"rotate": rotate}}, + ], ): output_test1 = model(dummy_input) assert not torch.allclose(output_ref, output_test1, atol=0.05) diff --git a/tests/gpu/torch/quantization/test_quant_rnn_cuda.py b/tests/gpu/torch/quantization/test_quant_rnn_cuda.py index be40de8e50..8a245336f0 100644 --- a/tests/gpu/torch/quantization/test_quant_rnn_cuda.py +++ b/tests/gpu/torch/quantization/test_quant_rnn_cuda.py @@ -21,7 +21,7 @@ import torch import torch.nn as nn -from modelopt.torch.quantization import set_quantizer_attribute +from modelopt.torch.quantization.conversion import set_quantizer_attributes_partial from modelopt.torch.quantization.nn import QuantModuleRegistry @@ -44,7 +44,7 @@ def test_no_quant_proj(original_cls, bidirectional, bias): rnn_object_original = copy.deepcopy(rnn_object) quant_rnn_object = QuantModuleRegistry.convert(rnn_object) - set_quantizer_attribute(quant_rnn_object, lambda name: True, {"enable": False}) + set_quantizer_attributes_partial(quant_rnn_object, lambda name: True, {"enable": False}) test_input = torch.randn((3, 2, 8), 
device="cuda") diff --git a/tests/gpu/torch/quantization/test_quantize_cuda.py b/tests/gpu/torch/quantization/test_quantize_cuda.py index 3e9ff4256c..715e00149b 100644 --- a/tests/gpu/torch/quantization/test_quantize_cuda.py +++ b/tests/gpu/torch/quantization/test_quantize_cuda.py @@ -15,6 +15,8 @@ """High-level tests for quantization.""" +import copy + import pytest from _test_utils.torch.quantization.models import SimpleConv, SimpleConvLinear, SimpleLinear from _test_utils.torch.quantization.quantize_common import ( @@ -29,20 +31,26 @@ from modelopt.torch.quantization.extensions import get_cuda_ext_mx NVFP4_WEIGHT_ACT_MSE_CFG = { - "quant_cfg": { - "*weight_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)}, - "axis": None, + "quant_cfg": [ + { + "quantizer_path": "*weight_quantizer", + "cfg": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)}, + "axis": None, + }, "enable": True, }, - "*input_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, - "axis": None, + { + "quantizer_path": "*input_quantizer", + "cfg": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, + "axis": None, + }, "enable": True, }, - }, + ], "algorithm": { "method": "mse", "step_size": 0.25, @@ -52,17 +60,18 @@ } NVFP4_WEIGHT_MSE_FP8_SWEEP_CFG = { - "quant_cfg": { - "*weight_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)}, - "axis": None, + "quant_cfg": [ + { + "quantizer_path": "*weight_quantizer", + "cfg": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)}, + "axis": None, + }, "enable": True, }, - "*input_quantizer": { - "enable": False, - }, - }, + {"quantizer_path": "*input_quantizer", "enable": False}, + ], "algorithm": { "method": "mse", "fp8_scale_sweep": True, @@ -123,7 +132,10 @@ def test_quantize(model_cls, 
config): if config == mtq.FP8_2D_BLOCKWISE_WEIGHT_ONLY_CFG: # reduce block sizes for simple testing models - config["quant_cfg"]["*weight_quantizer"]["block_sizes"] = {-1: 8, -2: 8} + config = copy.deepcopy(config) + for entry in config["quant_cfg"]: + if entry.get("quantizer_path") == "*weight_quantizer": + entry.setdefault("cfg", {})["block_sizes"] = {-1: 8, -2: 8} model = model_cls().cuda() calib_data = [model.get_input().cuda() for _ in range(8)] quantize_model_and_forward(model, config, calib_data) diff --git a/tests/gpu/torch/quantization/test_real_quantize_cuda.py b/tests/gpu/torch/quantization/test_real_quantize_cuda.py index 2c65128966..8afedea9ef 100644 --- a/tests/gpu/torch/quantization/test_real_quantize_cuda.py +++ b/tests/gpu/torch/quantization/test_real_quantize_cuda.py @@ -15,6 +15,7 @@ """High-level tests for real weight-only quantization.""" +import copy import fnmatch import pytest @@ -47,10 +48,14 @@ def test_real_quantize(model_cls, config): # update config to fit test cases if config == mtq.INT4_AWQ_CFG: # reduce block sizes for simple testing models - config["quant_cfg"]["*weight_quantizer"]["block_sizes"] = { - -1: 16, - "scale_bits": 8, - } + config = copy.deepcopy(config) + for entry in config["quant_cfg"]: + if entry.get("quantizer_path") == "*weight_quantizer": + entry.setdefault("cfg", {})["block_sizes"] = { + -1: 16, + "scale_bits": 8, + } + break if model_cls is SimpleConv or model_cls is SimpleConvLinear: pytest.skip( "INT4_AWQ_CFG requires even number of elements on last dimension for weights." 
@@ -101,10 +106,14 @@ def test_save_restore(model_cls, config): # update config to fit test cases if config == mtq.INT4_AWQ_CFG: # reduce block sizes for simple testing models - config["quant_cfg"]["*weight_quantizer"]["block_sizes"] = { - -1: 16, - "scale_bits": 8, - } + config = copy.deepcopy(config) + for entry in config["quant_cfg"]: + if entry.get("quantizer_path") == "*weight_quantizer": + entry.setdefault("cfg", {})["block_sizes"] = { + -1: 16, + "scale_bits": 8, + } + break if model_cls is SimpleConv or model_cls is SimpleConvLinear: pytest.skip( "INT4_AWQ_CFG requires even number of elements on last dimension for weights." diff --git a/tests/gpu_megatron/torch/peft/plugins/test_megatron_peft.py b/tests/gpu_megatron/torch/peft/plugins/test_megatron_peft.py index b71eaeb219..cfa678b1a3 100644 --- a/tests/gpu_megatron/torch/peft/plugins/test_megatron_peft.py +++ b/tests/gpu_megatron/torch/peft/plugins/test_megatron_peft.py @@ -33,23 +33,32 @@ from modelopt.torch.utils.plugins import megatron_prefill NVFP4_DEFAULT_CONFIG = { - "quant_cfg": { - "*weight_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, - "axis": None, + "quant_cfg": [ + {"quantizer_path": "*", "enable": False}, + { + "quantizer_path": "*weight_quantizer", + "cfg": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, + "axis": None, + }, "enable": True, }, - "*input_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, - "axis": None, + { + "quantizer_path": "*input_quantizer", + "cfg": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, + "axis": None, + }, "enable": True, }, - "*output_quantizer": {"enable": False}, - "*output_layer*": {"enable": False}, # Note: only output_layer is disabled. 
- "default": {"enable": False}, - }, + {"quantizer_path": "*output_quantizer", "enable": False}, + { + "quantizer_path": "*output_layer*", + "enable": False, + }, # Note: only output_layer is disabled. + ], "algorithm": "max", } diff --git a/tests/gpu_megatron/torch/quantization/plugins/test_apex.py b/tests/gpu_megatron/torch/quantization/plugins/test_apex.py index 1c9bf1ec66..144c05f6d7 100644 --- a/tests/gpu_megatron/torch/quantization/plugins/test_apex.py +++ b/tests/gpu_megatron/torch/quantization/plugins/test_apex.py @@ -84,15 +84,15 @@ def test_convert_apex_parallel_linear(distributed_setup_size_1): assert hasattr(module, "weight_quantizer") assert hasattr(module, "output_quantizer") - mtq.set_quantizer_attribute(model_test, "*", {"enable": False}) + mtq.set_quantizer_attributes_partial(model_test, "*", {"enable": False}) x = model_ref.get_dummy_input().cuda() out_1 = model_ref(x) out_2 = model_test(x) assert torch.allclose(out_1, out_2) - mtq.set_quantizer_attribute(model_test, "*input_quantizer", {"enable": True}) - mtq.set_quantizer_attribute(model_test, "*weight_quantizer", {"enable": True}) + mtq.set_quantizer_attributes_partial(model_test, "*input_quantizer", {"enable": True}) + mtq.set_quantizer_attributes_partial(model_test, "*weight_quantizer", {"enable": True}) model_ref = RegularQuantModelForTP().cuda() model_ref.load_state_dict(model_test.state_dict()) diff --git a/tests/gpu_megatron/torch/quantization/plugins/test_megatron.py b/tests/gpu_megatron/torch/quantization/plugins/test_megatron.py index d8ba6fbed7..8075ddc131 100644 --- a/tests/gpu_megatron/torch/quantization/plugins/test_megatron.py +++ b/tests/gpu_megatron/torch/quantization/plugins/test_megatron.py @@ -82,15 +82,15 @@ def test_convert_megatron_parallel_linear(distributed_setup_size_1): assert hasattr(module, "weight_quantizer") assert hasattr(module, "output_quantizer") - mtq.set_quantizer_attribute(model_test, "*", {"enable": False}) + mtq.set_quantizer_attributes_partial(model_test, 
"*", {"enable": False}) x = model_ref.get_dummy_input().cuda() out_1 = model_ref(x) out_2 = model_test(x) assert torch.allclose(out_1, out_2) - mtq.set_quantizer_attribute(model_test, "*input_quantizer", {"enable": True}) - mtq.set_quantizer_attribute(model_test, "*weight_quantizer", {"enable": True}) + mtq.set_quantizer_attributes_partial(model_test, "*input_quantizer", {"enable": True}) + mtq.set_quantizer_attributes_partial(model_test, "*weight_quantizer", {"enable": True}) model_ref = RegularQuantModelForTP().cuda() model_ref.load_state_dict(model_test.state_dict(), strict=False) @@ -304,7 +304,7 @@ def _test_sharded_state_dict( ): # Must disable output_layer quantization since output_layer amax cannot be restore via # sharded_state_dict. All output_layer quantizers state are removed. - config["quant_cfg"]["*output_layer*"] = {"enable": False} + config["quant_cfg"].append({"quantizer_path": "*output_layer*", "enable": False}) if modelopt_version is not None: mto.conversion.__version__ = modelopt_version @@ -383,36 +383,44 @@ def _test_sharded_state_dict( mixed_precision_config = copy.deepcopy(mtq.W4A8_AWQ_BETA_CFG) -mixed_precision_config["quant_cfg"].update( - { - "*.1.*": {"enable": False}, - "*.2.*weight_quantizer": {"num_bits": (4, 3), "axis": None}, - "*.2.*input_quantizer": {"num_bits": (4, 3), "axis": None}, - "*.3.*weight_quantizer.0": {"num_bits": 8, "axis": 0}, - "*.3.*weight_quantizer.1": {"enable": False}, - "*.3.*input_quantizer": {"num_bits": 8, "axis": None}, - } +mixed_precision_config["quant_cfg"].extend( + [ + {"quantizer_path": "*.1.*", "enable": False}, + {"quantizer_path": "*.2.*weight_quantizer", "cfg": {"num_bits": (4, 3), "axis": None}}, + {"quantizer_path": "*.2.*input_quantizer", "cfg": {"num_bits": (4, 3), "axis": None}}, + {"quantizer_path": "*.3.*weight_quantizer.0", "cfg": {"num_bits": 8, "axis": 0}}, + {"quantizer_path": "*.3.*weight_quantizer.1", "enable": False}, + {"quantizer_path": "*.3.*input_quantizer", "cfg": {"num_bits": 
8, "axis": None}}, + ] ) mixed_block_size_config = copy.deepcopy(mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG) -mixed_block_size_config["quant_cfg"].update( - { - "*.1.*": {"enable": False}, - "*.2.*weight_quantizer": {"num_bits": 4, "block_sizes": {-1: 64}, "enable": True}, - "*.2.*input_quantizer": {"num_bits": (4, 3), "axis": None}, - "*.3.*weight_quantizer": {"num_bits": 4, "block_sizes": {-1: 128, -2: 64}, "enable": True}, - "*.3.*input_quantizer": {"num_bits": 8, "axis": None}, - } +mixed_block_size_config["quant_cfg"].extend( + [ + {"quantizer_path": "*.1.*", "enable": False}, + { + "quantizer_path": "*.2.*weight_quantizer", + "cfg": {"num_bits": 4, "block_sizes": {-1: 64}}, + "enable": True, + }, + {"quantizer_path": "*.2.*input_quantizer", "cfg": {"num_bits": (4, 3), "axis": None}}, + { + "quantizer_path": "*.3.*weight_quantizer", + "cfg": {"num_bits": 4, "block_sizes": {-1: 128, -2: 64}}, + "enable": True, + }, + {"quantizer_path": "*.3.*input_quantizer", "cfg": {"num_bits": 8, "axis": None}}, + ] ) # Combined NVFP4 GEMM + KV cache quantization config NVFP4_GEMM_KV_CFG = copy.deepcopy(mtq.NVFP4_DEFAULT_CFG) -NVFP4_GEMM_KV_CFG["quant_cfg"].update(mtq.NVFP4_KV_CFG["quant_cfg"]) +NVFP4_GEMM_KV_CFG["quant_cfg"].extend(mtq.NVFP4_KV_CFG["quant_cfg"]) # Combined FP8 GEMM + KV cache quantization config FP8_GEMM_KV_CFG = copy.deepcopy(mtq.FP8_DEFAULT_CFG) -FP8_GEMM_KV_CFG["quant_cfg"].update(mtq.FP8_KV_CFG["quant_cfg"]) +FP8_GEMM_KV_CFG["quant_cfg"].extend(mtq.FP8_KV_CFG["quant_cfg"]) @pytest.mark.parametrize( diff --git a/tests/gpu_megatron/torch/quantization/plugins/test_transformer_engine.py b/tests/gpu_megatron/torch/quantization/plugins/test_transformer_engine.py index 288cc75193..3ef3171b8c 100644 --- a/tests/gpu_megatron/torch/quantization/plugins/test_transformer_engine.py +++ b/tests/gpu_megatron/torch/quantization/plugins/test_transformer_engine.py @@ -13,6 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the 
License. +import copy + import pytest import torch import torch.nn as nn @@ -73,7 +75,11 @@ def test_quantize(model_cls, config): if config == mtq.FP8_2D_BLOCKWISE_WEIGHT_ONLY_CFG: # reduce block sizes for simple testing models - config["quant_cfg"]["*weight_quantizer"]["block_sizes"] = {-1: 8, -2: 8} + config = copy.deepcopy(config) + for entry in config["quant_cfg"]: + if entry.get("quantizer_path") == "*weight_quantizer": + entry["cfg"]["block_sizes"] = {-1: 8, -2: 8} + break model = model_cls().cuda() calib_data = [model.get_input().cuda() for _ in range(1)] quantize_model_and_forward(model, config, calib_data) diff --git a/tests/unit/recipe/test_loader.py b/tests/unit/recipe/test_loader.py index e52617861d..251fc7fdc2 100644 --- a/tests/unit/recipe/test_loader.py +++ b/tests/unit/recipe/test_loader.py @@ -15,6 +15,8 @@ """Unit tests for modelopt.recipe.loader and modelopt.recipe.loader.load_config.""" +import re + import pytest from modelopt.recipe.config import ModelOptPTQRecipe, RecipeType @@ -164,11 +166,11 @@ def test_load_recipe_dir(tmp_path): (tmp_path / "recipe.yml").write_text( "metadata:\n recipe_type: ptq\n description: Dir test.\n" ) - (tmp_path / "ptq_cfg.yml").write_text("algorithm: max\nquant_cfg: {}\n") + (tmp_path / "ptq_cfg.yml").write_text("algorithm: max\nquant_cfg: []\n") recipe = load_recipe(tmp_path) assert recipe.recipe_type == RecipeType.PTQ assert recipe.description == "Dir test." 
- assert recipe.ptq_cfg == {"algorithm": "max", "quant_cfg": {}} + assert recipe.ptq_cfg == {"algorithm": "max", "quant_cfg": []} def test_load_recipe_dir_missing_recipe_raises(tmp_path): @@ -200,13 +202,49 @@ def test_load_recipe_dir_missing_ptq_cfg_raises(tmp_path): ], ) def test_general_ptq_yaml_matches_config_dicts(yaml_path, model_cfg_name, kv_cfg_name): - """Each general/ptq YAML's merged quant_cfg matches the corresponding config.py dicts.""" + """Each general/ptq YAML's quant_cfg list matches the merged Python config dicts.""" + import json + import modelopt.torch.quantization.config as qcfg + from modelopt.torch.quantization.config import normalize_quant_cfg_list model_cfg = getattr(qcfg, model_cfg_name) kv_cfg = getattr(qcfg, kv_cfg_name) yaml_data = load_config(yaml_path) - ptq = yaml_data["ptq_cfg"] - assert {**model_cfg["quant_cfg"], **kv_cfg["quant_cfg"]} == ptq["quant_cfg"] - assert model_cfg["algorithm"] == ptq["algorithm"] + def _normalize_fpx(val): + """Normalize FPx representations to a canonical ``[E, M]`` list. + + Python configs may use tuple form ``(E, M)`` or string alias ``"eEmM"``; + YAML always uses the string form. Both are converted to ``[E, M]`` so the + comparison is representation-agnostic. 
+ """ + if isinstance(val, str): + m = re.fullmatch(r"e(\d+)m(\d+)", val) + if m: + return [int(m.group(1)), int(m.group(2))] + if isinstance(val, tuple) and len(val) == 2 and all(isinstance(x, int) for x in val): + return list(val) + if isinstance(val, dict): + return {str(k): _normalize_fpx(v) for k, v in val.items()} + return val + + def _normalize_entries(raw_entries): + """Normalize a raw quant_cfg list to a canonical, JSON-serialisable form.""" + entries = normalize_quant_cfg_list(list(raw_entries)) + result = [] + for entry in entries: + e = {k: v for k, v in entry.items() if v is not None} + if "cfg" in e and e["cfg"] is not None: + e["cfg"] = _normalize_fpx(e["cfg"]) + result.append(e) + return result + + def _sort_key(entry): + return json.dumps(entry, sort_keys=True, default=str) + + python_entries = _normalize_entries(model_cfg["quant_cfg"] + kv_cfg["quant_cfg"]) + yaml_entries = _normalize_entries(yaml_data["ptq_cfg"]["quant_cfg"]) + + assert sorted(python_entries, key=_sort_key) == sorted(yaml_entries, key=_sort_key) + assert model_cfg["algorithm"] == yaml_data["ptq_cfg"]["algorithm"] diff --git a/tests/unit/torch/quantization/plugins/test_attention_quant.py b/tests/unit/torch/quantization/plugins/test_attention_quant.py index 9526f80ac6..302e394963 100644 --- a/tests/unit/torch/quantization/plugins/test_attention_quant.py +++ b/tests/unit/torch/quantization/plugins/test_attention_quant.py @@ -61,10 +61,10 @@ def forward(self, hidden_states, **kwargs): kv_cache_config = { - "quant_cfg": { - "*[kv]_bmm_quantizer": {"num_bits": 4, "enable": True}, - "*softmax_quantizer": {"enable": False}, - }, + "quant_cfg": [ + {"quantizer_path": "*[kv]_bmm_quantizer", "cfg": {"num_bits": 4}, "enable": True}, + {"quantizer_path": "*softmax_quantizer", "enable": False}, + ], "algorithm": "max", } diff --git a/tests/unit/torch/quantization/plugins/test_huggingface.py b/tests/unit/torch/quantization/plugins/test_huggingface.py index 33730409a6..b9db122117 100644 --- 
a/tests/unit/torch/quantization/plugins/test_huggingface.py +++ b/tests/unit/torch/quantization/plugins/test_huggingface.py @@ -87,7 +87,7 @@ def test_convert_conv1d(): assert hasattr(module, "weight_quantizer") assert hasattr(module, "output_quantizer") - mtq.set_quantizer_attribute(model_test, "*", {"enable": False}) + mtq.set_quantizer_attributes_partial(model_test, "*", {"enable": False}) x = torch.randn(2, 3) out_1 = model_ref(x) @@ -95,8 +95,8 @@ def test_convert_conv1d(): assert torch.allclose(out_1, out_2) - mtq.set_quantizer_attribute(model_test, "*input_quantizer", {"enable": True}) - mtq.set_quantizer_attribute(model_test, "*weight_quantizer", {"enable": True}) + mtq.set_quantizer_attributes_partial(model_test, "*input_quantizer", {"enable": True}) + mtq.set_quantizer_attributes_partial(model_test, "*weight_quantizer", {"enable": True}) model_ref = PytorchModel() model_ref.load_state_dict(model_test.state_dict()) @@ -136,7 +136,7 @@ def test_dbrx(): expertglu_ref.w1, ) - mtq.set_quantizer_attribute(model_test, "*", {"enable": False}) + mtq.set_quantizer_attributes_partial(model_test, "*", {"enable": False}) x = torch.randn(1, 4, 32) out_1 = model_ref(x) @@ -193,7 +193,13 @@ def test_quantized_transformers_save_restore(tmp_path, model_cls, quant_config): tiny_llama_dir = create_tiny_llama_dir(tmp_path) # update config to fit test cases if quant_config == mtq.INT4_AWQ_CFG: - quant_config["quant_cfg"]["*weight_quantizer"]["block_sizes"] = {-1: 16} + import copy + + quant_config = copy.deepcopy(quant_config) + for entry in quant_config["quant_cfg"]: + if entry["quantizer_path"] == "*weight_quantizer": + entry.setdefault("cfg", {})["block_sizes"] = {-1: 16} + break else: raise ValueError(f"Unsupported quant_config: {quant_config}") diff --git a/tests/unit/torch/quantization/plugins/test_peft.py b/tests/unit/torch/quantization/plugins/test_peft.py index c794c67bc2..fda0e3bec4 100644 --- a/tests/unit/torch/quantization/plugins/test_peft.py +++ 
b/tests/unit/torch/quantization/plugins/test_peft.py @@ -48,7 +48,7 @@ def test_convert_loralinear(): assert hasattr(module, "weight_quantizer") assert hasattr(module, "output_quantizer") - mtq.set_quantizer_attribute(model_test, "*", {"enable": False}) + mtq.set_quantizer_attributes_partial(model_test, "*", {"enable": False}) tf_output_tester(model_ref, model_test) diff --git a/tests/unit/torch/quantization/test_autoquant.py b/tests/unit/torch/quantization/test_autoquant.py index c0f049174e..d1a93a6261 100644 --- a/tests/unit/torch/quantization/test_autoquant.py +++ b/tests/unit/torch/quantization/test_autoquant.py @@ -28,7 +28,7 @@ QuantRecipeHparam, estimate_quant_compression, ) -from modelopt.torch.quantization.config import _default_disabled_quantizer_cfg +from modelopt.torch.quantization.config import _base_disable_all, _default_disabled_quantizer_cfg from modelopt.torch.utils.distributed import DistributedProcessGroup @@ -110,11 +110,12 @@ def test_quant_recipe_hparam(): # use this config to test custom quantization config INT8_CUSTOM_QUANT_TEST_CFG = { - "quant_cfg": { - "*weight_quantizer": {"num_bits": 8, "axis": 0}, - "*input_quantizer": {"num_bits": 8, "axis": None}, - **_default_disabled_quantizer_cfg, - }, + "quant_cfg": [ + *_base_disable_all, + {"quantizer_path": "*weight_quantizer", "cfg": {"num_bits": 8, "axis": 0}}, + {"quantizer_path": "*input_quantizer", "cfg": {"num_bits": 8, "axis": None}}, + *_default_disabled_quantizer_cfg, + ], "algorithm": "smoothquant", } @@ -230,14 +231,22 @@ def test_auto_quantize_disabled_layers_no_poison(): INT4INT8_AWQ_CFG = { - "quant_cfg": { - "*weight_quantizer": [ - {"num_bits": 4, "block_sizes": {-1: 128, "type": "static"}, "enable": True}, - {"num_bits": 8, "axis": None, "enable": True}, - ], - "*input_quantizer": {"num_bits": 8, "axis": None, "enable": True}, - "default": {"enable": False}, - }, + "quant_cfg": [ + {"quantizer_path": "*", "enable": False}, + { + "quantizer_path": "*weight_quantizer", + "cfg": 
[ + {"num_bits": 4, "block_sizes": {-1: 128, "type": "static"}}, + {"num_bits": 8, "axis": None}, + ], + "enable": True, + }, + { + "quantizer_path": "*input_quantizer", + "cfg": {"num_bits": 8, "axis": None}, + "enable": True, + }, + ], "algorithm": "awq_lite", } @@ -480,7 +489,11 @@ def test_get_auto_quantize_config(method): # Use stored best recipe config = mtq.get_auto_quantize_config(search_state) assert "quant_cfg" in config - assert config["quant_cfg"]["*"] == {"enable": False} + assert isinstance(config["quant_cfg"], list) + assert any( + entry["quantizer_path"] == "*" and entry.get("enable") is False + for entry in config["quant_cfg"] + ) assert config["algorithm"] == "max" # Re-solve with different constraints diff --git a/tests/unit/torch/quantization/test_compute_quantization_mse.py b/tests/unit/torch/quantization/test_compute_quantization_mse.py index 9a9a81a611..26aa7144a6 100644 --- a/tests/unit/torch/quantization/test_compute_quantization_mse.py +++ b/tests/unit/torch/quantization/test_compute_quantization_mse.py @@ -22,10 +22,10 @@ from modelopt.torch.quantization.nn import TensorQuantizer INT8_CFG = { - "quant_cfg": { - "*weight_quantizer": {"num_bits": 8, "axis": 0}, - "*input_quantizer": {"num_bits": 8, "axis": None}, - }, + "quant_cfg": [ + {"quantizer_path": "*weight_quantizer", "cfg": {"num_bits": 8, "axis": 0}}, + {"quantizer_path": "*input_quantizer", "cfg": {"num_bits": 8, "axis": None}}, + ], "algorithm": "max", } diff --git a/tests/unit/torch/quantization/test_config_validation.py b/tests/unit/torch/quantization/test_config_validation.py index 6ed0c918a8..71344e8091 100644 --- a/tests/unit/torch/quantization/test_config_validation.py +++ b/tests/unit/torch/quantization/test_config_validation.py @@ -15,6 +15,9 @@ """Test of quantization config validations.""" +import pytest +from pydantic import ValidationError + from modelopt.torch.quantization.config import ( FP8_2D_BLOCKWISE_WEIGHT_ONLY_CFG, FP8_DEFAULT_CFG, @@ -22,7 +25,10 @@ 
INT4_AWQ_CFG, NVFP4_DEFAULT_CFG, W4A8_AWQ_BETA_CFG, + QuantizeConfig, + find_quant_cfg_entry_by_path, need_calibration, + normalize_quant_cfg_list, ) @@ -33,3 +39,435 @@ def test_need_calibration(): assert need_calibration(INT4_AWQ_CFG) assert need_calibration(W4A8_AWQ_BETA_CFG) assert need_calibration(NVFP4_DEFAULT_CFG) + + +def test_need_calibration_with_list_cfg(): + """need_calibration must handle sequential (list) cfg entries without crashing.""" + # Static list-cfg on a non-weight quantizer → needs calibration + cfg_static = { + "quant_cfg": [ + { + "quantizer_path": "*input_quantizer", + "cfg": [ + {"num_bits": 4, "block_sizes": {-1: 128, "type": "static"}}, + {"num_bits": (4, 3)}, + ], + "enable": True, + }, + ], + "algorithm": "max", + } + assert need_calibration(cfg_static) + + # Dynamic list-cfg on a non-weight quantizer → no calibration needed + cfg_dynamic = { + "quant_cfg": [ + { + "quantizer_path": "*input_quantizer", + "cfg": [{"num_bits": (4, 3), "type": "dynamic"}], + "enable": True, + }, + ], + "algorithm": "max", + } + assert not need_calibration(cfg_dynamic) + + +class TestNormalizeQuantCfgList: + def test_new_format_passthrough(self): + """New-format entries are returned unchanged (only canonical defaults added).""" + raw = [{"quantizer_path": "*weight_quantizer", "cfg": {"num_bits": 8, "axis": 0}}] + result = normalize_quant_cfg_list(raw) + assert len(result) == 1 + assert result[0]["quantizer_path"] == "*weight_quantizer" + assert result[0]["cfg"] == {"num_bits": 8, "axis": 0} + assert result[0]["enable"] is True # defaulted + + def test_new_format_enable_false(self): + """Explicit enable=False is preserved.""" + raw = [{"quantizer_path": "*", "enable": False}] + result = normalize_quant_cfg_list(raw) + assert result[0]["enable"] is False + assert result[0]["cfg"] is None # defaulted + + def test_new_format_explicit_enable_true_no_cfg(self): + """Explicit enable=True with no cfg is valid and cfg defaults to None.""" + raw = 
[{"quantizer_path": "*", "enable": True}] + result = normalize_quant_cfg_list(raw) + assert result[0]["enable"] is True + assert result[0]["cfg"] is None + + def test_legacy_single_key_dict(self): + """Legacy {'*path': {attrs}} is converted to new format.""" + raw = [{"*weight_quantizer": {"num_bits": 8, "axis": 0}}] + result = normalize_quant_cfg_list(raw) + assert result[0]["quantizer_path"] == "*weight_quantizer" + assert result[0]["cfg"] == {"num_bits": 8, "axis": 0} + assert result[0]["enable"] is True # defaulted + + def test_legacy_single_key_dict_with_enable(self): + """Legacy {'*path': {'enable': False}} splits enable out from cfg.""" + raw = [{"*input_quantizer": {"enable": False}}] + result = normalize_quant_cfg_list(raw) + assert result[0]["quantizer_path"] == "*input_quantizer" + assert result[0]["enable"] is False + assert result[0]["cfg"] is None + + def test_legacy_nn_class_scoped(self): + """Legacy {'nn.Linear': {'*': {attrs}}} is converted with parent_class.""" + raw = [{"nn.Linear": {"*": {"enable": False}}}] + result = normalize_quant_cfg_list(raw) + assert result[0]["parent_class"] == "nn.Linear" + assert result[0]["quantizer_path"] == "*" + assert result[0]["enable"] is False + + def test_normalization_cfg_defaults_to_none(self): + """Entries without cfg get cfg=None after normalization.""" + raw = [{"quantizer_path": "*lm_head*", "enable": False}] + result = normalize_quant_cfg_list(raw) + assert "cfg" in result[0] + assert result[0]["cfg"] is None + + def test_normalization_enable_defaults_to_true(self): + """Entries with cfg but no enable get enable=True after normalization.""" + raw = [{"quantizer_path": "*", "cfg": {"num_bits": 4}}] + result = normalize_quant_cfg_list(raw) + assert result[0]["enable"] is True + + def test_empty_list(self): + """Empty list is returned unchanged.""" + assert normalize_quant_cfg_list([]) == [] + + def test_multiple_entries_order_preserved(self): + """The order of entries is preserved.""" + raw = [ + 
{"quantizer_path": "*", "enable": False}, + {"quantizer_path": "*weight_quantizer", "cfg": {"num_bits": 8}}, + ] + result = normalize_quant_cfg_list(raw) + assert result[0]["quantizer_path"] == "*" + assert result[1]["quantizer_path"] == "*weight_quantizer" + + def test_error_on_quantizer_path_only(self): + """Entry with only quantizer_path and no cfg or enable is rejected.""" + with pytest.raises(ValueError, match="must specify 'cfg', 'enable'"): + normalize_quant_cfg_list([{"quantizer_path": "*"}]) + + def test_error_on_empty_dict(self): + """An empty dict entry is rejected.""" + with pytest.raises(ValueError): + normalize_quant_cfg_list([{}]) + + def test_error_on_multi_key_legacy_dict(self): + """A multi-key legacy dict (no quantizer_path, no nn.* keys) is rejected.""" + with pytest.raises(ValueError): + normalize_quant_cfg_list([{"*weight_quantizer": {}, "*input_quantizer": {}}]) + + def test_new_format_with_list_cfg(self): + """cfg can be a list of dicts for SequentialQuantizer.""" + raw = [ + { + "quantizer_path": "*weight_quantizer", + "cfg": [ + {"num_bits": 4, "block_sizes": {-1: 128, "type": "static"}}, + {"num_bits": (4, 3)}, + ], + } + ] + result = normalize_quant_cfg_list(raw) + assert len(result) == 1 + assert result[0]["cfg"] == raw[0]["cfg"] + assert result[0]["enable"] is True + + def test_legacy_flat_dict_conversion(self): + """Legacy flat dict {'*': {...}, '*weight_quantizer': {...}} is converted to list.""" + raw = {"*": {"enable": False}, "*weight_quantizer": {"num_bits": 8, "axis": 0}} + result = normalize_quant_cfg_list(raw) + assert len(result) == 2 + assert result[0]["quantizer_path"] == "*" + assert result[0]["enable"] is False + assert result[0]["cfg"] is None + assert result[1]["quantizer_path"] == "*weight_quantizer" + assert result[1]["cfg"] == {"num_bits": 8, "axis": 0} + assert result[1]["enable"] is True + + def test_legacy_enable_only_produces_cfg_none(self): + """Legacy {'*': {'enable': False}} should produce cfg=None, not 
cfg={}.""" + raw = [{"*": {"enable": False}}] + result = normalize_quant_cfg_list(raw) + assert result[0]["cfg"] is None + assert result[0]["enable"] is False + + def test_legacy_nn_class_enable_only_produces_cfg_none(self): + """Legacy nn.* scoped format with only enable produces cfg=None.""" + raw = [{"nn.Linear": {"*": {"enable": False}}}] + result = normalize_quant_cfg_list(raw) + assert result[0]["cfg"] is None + assert result[0]["enable"] is False + assert result[0]["parent_class"] == "nn.Linear" + + def test_legacy_default_key(self): + """Legacy 'default' key is converted to quantizer_path='*'.""" + raw = [{"default": {"enable": False}}] + result = normalize_quant_cfg_list(raw) + assert result[0]["quantizer_path"] == "*" + assert result[0]["enable"] is False + assert result[0]["cfg"] is None + + def test_legacy_default_key_with_cfg(self): + """Legacy 'default' key with cfg attributes maps to '*'.""" + raw = [{"default": {"num_bits": 8, "axis": None}}] + result = normalize_quant_cfg_list(raw) + assert result[0]["quantizer_path"] == "*" + assert result[0]["cfg"] == {"num_bits": 8, "axis": None} + assert result[0]["enable"] is True + + def test_legacy_flat_dict_with_default_key(self): + """Legacy flat dict containing 'default' key converts it to '*'.""" + raw = {"default": {"enable": False}, "*weight_quantizer": {"num_bits": 8}} + result = normalize_quant_cfg_list(raw) + default_entries = [e for e in result if e["quantizer_path"] == "*"] + assert len(default_entries) == 1 + assert default_entries[0]["enable"] is False + + def test_legacy_nn_class_multi_key(self): + """Legacy nn.* scoped format with multiple sub-keys produces multiple entries.""" + raw = [ + { + "nn.Linear": { + "*input_quantizer": {"enable": False}, + "*weight_quantizer": {"num_bits": 4}, + } + } + ] + result = normalize_quant_cfg_list(raw) + assert len(result) == 2 + paths = {e["quantizer_path"] for e in result} + assert paths == {"*input_quantizer", "*weight_quantizer"} + for e in result: + 
assert e["parent_class"] == "nn.Linear" + + def test_legacy_nn_class_with_cfg(self): + """Legacy nn.* scoped format with actual quantizer attributes (not just enable).""" + raw = [{"nn.Linear": {"*weight_quantizer": {"num_bits": 4, "axis": 0}}}] + result = normalize_quant_cfg_list(raw) + assert len(result) == 1 + assert result[0]["parent_class"] == "nn.Linear" + assert result[0]["quantizer_path"] == "*weight_quantizer" + assert result[0]["cfg"] == {"num_bits": 4, "axis": 0} + assert result[0]["enable"] is True + + def test_legacy_list_valued_cfg(self): + """Legacy dict format with list-valued cfg (SequentialQuantizer) normalizes correctly.""" + raw = [ + { + "*weight_quantizer": [ + {"num_bits": 4, "block_sizes": {-1: 128, "type": "static"}}, + {"num_bits": 8, "axis": 0}, + ] + } + ] + result = normalize_quant_cfg_list(raw) + assert len(result) == 1 + assert result[0]["quantizer_path"] == "*weight_quantizer" + assert isinstance(result[0]["cfg"], list) + assert len(result[0]["cfg"]) == 2 + assert result[0]["cfg"][0]["num_bits"] == 4 + assert result[0]["cfg"][1]["num_bits"] == 8 + assert result[0]["enable"] is True + + +class TestFindQuantCfgEntry: + def test_finds_last_match(self): + """When multiple entries share the same quantizer_path, returns the last one.""" + entries = normalize_quant_cfg_list( + [ + {"quantizer_path": "*weight_quantizer", "cfg": {"num_bits": 8}}, + {"quantizer_path": "*input_quantizer", "cfg": {"num_bits": 4}}, + {"quantizer_path": "*weight_quantizer", "cfg": {"num_bits": 4}}, + ] + ) + result = find_quant_cfg_entry_by_path(entries, "*weight_quantizer") + assert result["cfg"] == {"num_bits": 4} + + def test_exact_match_only(self): + """Does not do fnmatch — only exact string equality on quantizer_path.""" + entries = normalize_quant_cfg_list( + [{"quantizer_path": "*weight_quantizer", "cfg": {"num_bits": 8}}] + ) + with pytest.raises(KeyError): + find_quant_cfg_entry_by_path(entries, "model.layer.weight_quantizer") + + def 
test_raises_on_missing(self): + """Raises KeyError when no entry matches.""" + entries = normalize_quant_cfg_list( + [{"quantizer_path": "*weight_quantizer", "cfg": {"num_bits": 8}}] + ) + with pytest.raises(KeyError): + find_quant_cfg_entry_by_path(entries, "*input_quantizer") + + def test_single_entry(self): + entries = normalize_quant_cfg_list([{"quantizer_path": "*", "enable": False}]) + result = find_quant_cfg_entry_by_path(entries, "*") + assert result["enable"] is False + + def test_empty_list(self): + with pytest.raises(KeyError): + find_quant_cfg_entry_by_path([], "*") + + +def test_need_calibration_with_legacy_dict_format(): + """need_calibration should accept legacy dict-format quant_cfg without crashing.""" + legacy_config = { + "quant_cfg": {"*input_quantizer": {"num_bits": 8, "axis": None}}, + "algorithm": "max", + } + assert need_calibration(legacy_config) + + +def test_need_calibration_with_legacy_list_of_single_key_dicts(): + """need_calibration should accept legacy list-of-single-key-dicts format.""" + legacy_config = { + "quant_cfg": [{"*input_quantizer": {"num_bits": 8, "axis": None}}], + "algorithm": "max", + } + assert need_calibration(legacy_config) + + +class TestMatchQuantizerCfg: + """Tests for _match_quantizer_cfg in algorithms.py.""" + + def test_wildcard_matches_bare_name(self): + """'*weight_quantizer' matches bare 'weight_quantizer'.""" + from modelopt.torch.quantization.algorithms import _match_quantizer_cfg + + quant_cfg = normalize_quant_cfg_list( + [{"quantizer_path": "*weight_quantizer", "cfg": {"num_bits": 8}}] + ) + matched, enable = _match_quantizer_cfg(quant_cfg, "weight_quantizer") + assert matched == {"num_bits": 8} + assert enable is True + + def test_star_matches_any_bare_name(self): + """'*' matches any bare quantizer name.""" + from modelopt.torch.quantization.algorithms import _match_quantizer_cfg + + quant_cfg = normalize_quant_cfg_list([{"quantizer_path": "*", "enable": False}]) + matched, enable = 
_match_quantizer_cfg(quant_cfg, "weight_quantizer") + assert matched is None # enable-only entry has cfg=None + assert enable is False + + def test_path_scoped_pattern_matches_matching_suffix(self): + """'*mlp*weight_quantizer' matches bare 'weight_quantizer' (suffix match).""" + from modelopt.torch.quantization.algorithms import _match_quantizer_cfg + + quant_cfg = normalize_quant_cfg_list( + [{"quantizer_path": "*mlp*weight_quantizer", "cfg": {"num_bits": 4}}] + ) + matched, enable = _match_quantizer_cfg(quant_cfg, "weight_quantizer") + assert matched == {"num_bits": 4} + + def test_path_scoped_pattern_does_not_match_different_suffix(self): + """'*mlp*weight_quantizer' does NOT match bare 'input_quantizer'.""" + from modelopt.torch.quantization.algorithms import _match_quantizer_cfg + + quant_cfg = normalize_quant_cfg_list( + [{"quantizer_path": "*mlp*weight_quantizer", "cfg": {"num_bits": 4}}] + ) + matched, enable = _match_quantizer_cfg(quant_cfg, "input_quantizer") + assert matched is None + assert enable is None + + def test_last_match_wins(self): + """Later entries override earlier ones.""" + from modelopt.torch.quantization.algorithms import _match_quantizer_cfg + + quant_cfg = normalize_quant_cfg_list( + [ + {"quantizer_path": "*weight_quantizer", "cfg": {"num_bits": 8}}, + {"quantizer_path": "*weight_quantizer", "cfg": {"num_bits": 4}}, + ] + ) + matched, _ = _match_quantizer_cfg(quant_cfg, "weight_quantizer") + assert matched == {"num_bits": 4} + + def test_no_match_returns_none(self): + """No matching entry returns (None, None).""" + from modelopt.torch.quantization.algorithms import _match_quantizer_cfg + + quant_cfg = normalize_quant_cfg_list( + [{"quantizer_path": "*weight_quantizer", "cfg": {"num_bits": 8}}] + ) + matched, enable = _match_quantizer_cfg(quant_cfg, "output_quantizer") + assert matched is None + assert enable is None + + def test_bracket_pattern_matches_correctly(self): + """'*[kv]_bmm_quantizer' matches 'k_bmm_quantizer' and 
'v_bmm_quantizer'.""" + from modelopt.torch.quantization.algorithms import _match_quantizer_cfg + + quant_cfg = normalize_quant_cfg_list( + [{"quantizer_path": "*[kv]_bmm_quantizer", "cfg": {"num_bits": (4, 3)}}] + ) + matched_k, _ = _match_quantizer_cfg(quant_cfg, "k_bmm_quantizer") + matched_v, _ = _match_quantizer_cfg(quant_cfg, "v_bmm_quantizer") + matched_w, _ = _match_quantizer_cfg(quant_cfg, "weight_quantizer") + assert matched_k is not None + assert matched_v is not None + assert matched_w is None + + def test_path_scoped_does_not_overmatch(self): + """'*mixer*weight_quantizer' should NOT match 'input_quantizer'. + + Regression test: the old rsplit('*') logic would strip to 'weight_quantizer' and + overmatch any quantizer ending in 'weight_quantizer', but should not match unrelated names. + """ + from modelopt.torch.quantization.algorithms import _match_quantizer_cfg + + quant_cfg = normalize_quant_cfg_list( + [ + {"quantizer_path": "*", "enable": False}, + {"quantizer_path": "*mixer*weight_quantizer", "cfg": {"num_bits": 4}}, + ] + ) + # input_quantizer should only match the disable-all, not the mixer pattern + matched, enable = _match_quantizer_cfg(quant_cfg, "input_quantizer") + assert matched is None # cfg is None (enable-only entry) + assert enable is False + + +class TestQuantizeConfigValidators: + """Tests for QuantizeConfig Pydantic field validators.""" + + def test_normalize_validator_converts_legacy_dict(self): + """The 'before' validator auto-normalizes legacy dict format.""" + cfg = QuantizeConfig( + quant_cfg={"*": {"enable": False}, "*weight_quantizer": {"num_bits": 8}}, + algorithm="max", + ) + assert isinstance(cfg.quant_cfg, list) + assert all("quantizer_path" in e for e in cfg.quant_cfg) + + def test_validate_quant_cfg_entries_catches_invalid_cfg(self): + """The 'after' validator surfaces QuantizerAttributeConfig errors early.""" + with pytest.raises(ValidationError): + QuantizeConfig( + quant_cfg=[ + { + "quantizer_path": 
"*weight_quantizer", + "cfg": {"num_bits": 8, "axis": 0, "block_sizes": {-1: 128}}, + } + ], + algorithm="max", + ) + + def test_validate_quant_cfg_entries_accepts_valid_cfg(self): + """The 'after' validator passes for valid configs.""" + cfg = QuantizeConfig( + quant_cfg=[ + {"quantizer_path": "*weight_quantizer", "cfg": {"num_bits": 8, "axis": 0}}, + {"quantizer_path": "*input_quantizer", "enable": False}, + ], + algorithm="max", + ) + assert len(cfg.quant_cfg) == 2 diff --git a/tests/unit/torch/quantization/test_custom_backend.py b/tests/unit/torch/quantization/test_custom_backend.py index f42d6a5f90..1b93085592 100644 --- a/tests/unit/torch/quantization/test_custom_backend.py +++ b/tests/unit/torch/quantization/test_custom_backend.py @@ -42,16 +42,19 @@ def dummy_backend(inputs: torch.Tensor, tq) -> torch.Tensor: model = torch.nn.Linear(16, 16, bias=False) cfg = { - "quant_cfg": { - "*weight_quantizer": { + "quant_cfg": [ + {"quantizer_path": "*", "enable": False}, + { + "quantizer_path": "*weight_quantizer", + "cfg": { + "num_bits": 8, + "axis": None, + "backend": "dummy_backend", + "backend_extra_args": {"offset": 2.5}, + }, "enable": True, - "num_bits": 8, - "axis": None, - "backend": "dummy_backend", - "backend_extra_args": {"offset": 2.5}, }, - "default": {"enable": False}, - }, + ], "algorithm": "max", } @@ -88,10 +91,14 @@ def cached_backend(inputs: torch.Tensor, tq: TensorQuantizer) -> torch.Tensor: model = torch.nn.Linear(16, 16, bias=False) cfg = { - "quant_cfg": { - "*weight_quantizer": {"enable": True, "backend": "cached_backend"}, - "default": {"enable": False}, - }, + "quant_cfg": [ + {"quantizer_path": "*", "enable": False}, + { + "quantizer_path": "*weight_quantizer", + "cfg": {"backend": "cached_backend"}, + "enable": True, + }, + ], "algorithm": "max", } inputs = torch.randn(1, 16) diff --git a/tests/unit/torch/quantization/test_quant_activations.py b/tests/unit/torch/quantization/test_quant_activations.py index afc8decceb..e27b85bb6b 100644 
--- a/tests/unit/torch/quantization/test_quant_activations.py +++ b/tests/unit/torch/quantization/test_quant_activations.py @@ -19,7 +19,7 @@ import torch.nn as nn import torch.nn.functional as F -from modelopt.torch.quantization import set_quantizer_attribute, tensor_quant +from modelopt.torch.quantization import set_quantizer_attributes_partial, tensor_quant from modelopt.torch.quantization.nn import QuantModuleRegistry @@ -42,7 +42,7 @@ def test_fake_quant_per_channel(self): negative_slope = 0.01 leaky_relu_object = nn.LeakyReLU(negative_slope=negative_slope) quant_leaky_relu_object = QuantModuleRegistry.convert(leaky_relu_object) - set_quantizer_attribute(quant_leaky_relu_object, lambda name: True, {"axis": (1)}) + set_quantizer_attributes_partial(quant_leaky_relu_object, lambda name: True, {"axis": (1)}) test_input = torch.randn(input_shape) quant_input = tensor_quant.fake_tensor_quant( diff --git a/tests/unit/torch/quantization/test_quant_batchnorm.py b/tests/unit/torch/quantization/test_quant_batchnorm.py index ee035dab13..c55b4b0b0e 100644 --- a/tests/unit/torch/quantization/test_quant_batchnorm.py +++ b/tests/unit/torch/quantization/test_quant_batchnorm.py @@ -20,7 +20,8 @@ import torch.nn as nn import torch.nn.functional as F -from modelopt.torch.quantization import set_quantizer_attribute, tensor_quant +from modelopt.torch.quantization import tensor_quant +from modelopt.torch.quantization.conversion import set_quantizer_attributes_partial from modelopt.torch.quantization.nn import QuantModuleRegistry NUM_CHANNELS = 3 @@ -90,7 +91,7 @@ def test_fake_quant_per_tensor(self, original_cls, input_shape): def test_fake_quant_per_channel(self, original_cls, input_shape): batchnorm_object = original_cls(NUM_CHANNELS, affine=True) quant_batchnorm_object = QuantModuleRegistry.convert(batchnorm_object) - set_quantizer_attribute(quant_batchnorm_object, lambda name: True, {"axis": (1)}) + set_quantizer_attributes_partial(quant_batchnorm_object, lambda name: True, 
{"axis": (1)}) test_input = torch.randn(input_shape) reduce_dims = list(range(len(test_input.shape))) diff --git a/tests/unit/torch/quantization/test_quant_rnn.py b/tests/unit/torch/quantization/test_quant_rnn.py index 6f3d054c4e..0ea6d755a4 100644 --- a/tests/unit/torch/quantization/test_quant_rnn.py +++ b/tests/unit/torch/quantization/test_quant_rnn.py @@ -21,7 +21,8 @@ import torch import torch.nn as nn -from modelopt.torch.quantization import set_quantizer_attribute, tensor_quant +from modelopt.torch.quantization import tensor_quant +from modelopt.torch.quantization.conversion import set_quantizer_attributes_partial from modelopt.torch.quantization.nn import QuantModuleRegistry from modelopt.torch.quantization.nn.modules.quant_rnn import VFRNNForward @@ -52,7 +53,7 @@ def test_no_quant(self, original_cls, bidirectional, bias): quant_rnn_object = QuantModuleRegistry.convert(rnn_object) rnn_object.eval() rnn_object_original.eval() - set_quantizer_attribute(quant_rnn_object, lambda name: True, {"enable": False}) + set_quantizer_attributes_partial(quant_rnn_object, lambda name: True, {"enable": False}) assert torch.allclose( quant_rnn_object.weight_ih_l0, rnn_object_original.weight_ih_l0, atol=1e-6 @@ -86,7 +87,7 @@ def test_no_quant_packed_sequence(self, original_cls, bidirectional, bias): quant_rnn_object = QuantModuleRegistry.convert(rnn_object) rnn_object.eval() rnn_object_original.eval() - set_quantizer_attribute(quant_rnn_object, lambda name: True, {"enable": False}) + set_quantizer_attributes_partial(quant_rnn_object, lambda name: True, {"enable": False}) assert torch.allclose( quant_rnn_object.weight_ih_l0, rnn_object_original.weight_ih_l0, atol=1e-6 @@ -124,7 +125,7 @@ def test_no_quant_proj(self, original_cls, bidirectional, bias): rnn_object_original = copy.deepcopy(rnn_object) quant_rnn_object = QuantModuleRegistry.convert(rnn_object) - set_quantizer_attribute(quant_rnn_object, lambda name: True, {"enable": False}) + 
set_quantizer_attributes_partial(quant_rnn_object, lambda name: True, {"enable": False}) test_input = torch.randn(INPUT_SHAPE) @@ -150,7 +151,7 @@ def test_no_quant_batch_first(self, original_cls, bidirectional): rnn_object_original = copy.deepcopy(rnn_object) quant_rnn_object = QuantModuleRegistry.convert(rnn_object) - set_quantizer_attribute(quant_rnn_object, lambda name: True, {"enable": False}) + set_quantizer_attributes_partial(quant_rnn_object, lambda name: True, {"enable": False}) test_input = torch.randn([INPUT_SHAPE[1], INPUT_SHAPE[0], INPUT_SHAPE[2]]) @@ -176,7 +177,7 @@ def test_fake_quant_per_tensor(self, original_cls, bidirectional): ) rnn_object_original = copy.deepcopy(rnn_object) quant_rnn_object = QuantModuleRegistry.convert(rnn_object) - set_quantizer_attribute(quant_rnn_object, lambda name: True, {"axis": None}) + set_quantizer_attributes_partial(quant_rnn_object, lambda name: True, {"axis": None}) quant_rnn_object._disable_input_quantizers() for name, weight in rnn_object_original.named_parameters(): @@ -205,7 +206,7 @@ def test_fake_quant_per_channel(self, original_cls, bidirectional): rnn_object = original_cls(HIDDEN_SIZE, HIDDEN_SIZE, NUM_LAYERS, bidirectional=bidirectional) rnn_object_original = copy.deepcopy(rnn_object) quant_rnn_object = QuantModuleRegistry.convert(rnn_object) - set_quantizer_attribute(quant_rnn_object, lambda name: True, {"axis": (0)}) + set_quantizer_attributes_partial(quant_rnn_object, lambda name: True, {"axis": (0)}) quant_rnn_object._disable_input_quantizers() for name, weight in rnn_object_original.named_parameters(): @@ -234,7 +235,7 @@ def test_input_quant_per_tensor(self, original_cls, bidirectional): HIDDEN_SIZE, HIDDEN_SIZE, NUM_LAYERS, bidirectional=bidirectional, bias=True ) quant_rnn_object = QuantModuleRegistry.convert(rnn_object) - set_quantizer_attribute(quant_rnn_object, lambda name: True, {"axis": None}) + set_quantizer_attributes_partial(quant_rnn_object, lambda name: True, {"axis": None}) 
quant_rnn_object._disable_weight_quantizers() num_directions = 2 if bidirectional else 1 diff --git a/tests/unit/torch/quantization/test_quantize_cpu.py b/tests/unit/torch/quantization/test_quantize_cpu.py index 641eafd2ff..c2a52f479e 100644 --- a/tests/unit/torch/quantization/test_quantize_cpu.py +++ b/tests/unit/torch/quantization/test_quantize_cpu.py @@ -32,41 +32,54 @@ import modelopt.torch.opt as mto import modelopt.torch.quantization as mtq from modelopt.torch.quantization.calib import MaxCalibrator +from modelopt.torch.quantization.config import QuantizerAttributeConfig +from modelopt.torch.quantization.conversion import set_quantizer_attributes_full +from modelopt.torch.quantization.nn.modules.tensor_quantizer import ( + SequentialQuantizer, + TensorQuantizer, +) # A test config with double-quant (using `SequentialQuantizers`) WINT4INT8_CFG = { - "quant_cfg": { - "*weight_quantizer": [ - {"num_bits": 4, "block_sizes": {-1: 128, "type": "static"}, "enable": True}, - {"num_bits": 8, "axis": 0, "enable": True}, - ], - "*input_quantizer": {"num_bits": 8, "axis": None, "enable": True}, - }, + "quant_cfg": [ + { + "quantizer_path": "*weight_quantizer", + "cfg": [ + {"num_bits": 4, "block_sizes": {-1: 128, "type": "static"}}, + {"num_bits": 8, "axis": 0}, + ], + "enable": True, + }, + { + "quantizer_path": "*input_quantizer", + "cfg": {"num_bits": 8, "axis": None}, + "enable": True, + }, + ], "algorithm": "awq_lite", } # Test configs for per channel MSE calibration INT8_MSE_CFG = { - "quant_cfg": { - "*weight_quantizer": {"num_bits": 8, "axis": 0}, - "*input_quantizer": {"num_bits": 8, "axis": None}, - }, + "quant_cfg": [ + {"quantizer_path": "*weight_quantizer", "cfg": {"num_bits": 8, "axis": 0}}, + {"quantizer_path": "*input_quantizer", "cfg": {"num_bits": 8, "axis": None}}, + ], "algorithm": "mse", } STATIC_WEIGHT_DYNAMIC_ACTIVATION_CFG = { - "quant_cfg": { - "*weight_quantizer": { - "num_bits": 8, - "axis": 0, + "quant_cfg": [ + {"quantizer_path": "*", 
"enable": False}, + { + "quantizer_path": "*weight_quantizer", + "cfg": {"num_bits": 8, "axis": 0}, }, # Per-channel quantization - "*input_quantizer": { - "num_bits": 8, - "axis": (0, 1), - "type": "dynamic", + { + "quantizer_path": "*input_quantizer", + "cfg": {"num_bits": 8, "axis": (0, 1), "type": "dynamic"}, }, # Dynamic per-token quantization - "default": {"enable": False}, - }, + ], "algorithm": "max", } @@ -77,14 +90,17 @@ def compute_amax(self): quant_cfg_custom_calib = { - "quant_cfg": { - "*": { - "num_bits": 4, - "axis": None, + "quant_cfg": [ + { + "quantizer_path": "*", + "cfg": { + "num_bits": 4, + "axis": None, + "calibrator": (NewMaxCalibrator, (4, None, False)), + }, "enable": True, - "calibrator": (NewMaxCalibrator, (4, None, False)), } - }, + ], "algorithm": "max", } @@ -131,7 +147,9 @@ def test_save_restore(model_cls, quant_config): def test_quantize_invalid_cfg(): model = SimpleLinear() config_invalid = { - "quant_cfg": {"*": {"num_bits": 4, "axis": 0, "block_sizes": {-1: 128}}}, + "quant_cfg": [ + {"quantizer_path": "*", "cfg": {"num_bits": 4, "axis": 0, "block_sizes": {-1: 128}}} + ], "algorithm": "max", } with pytest.raises(ValidationError, match="axis must be None when block_sizes is not None."): @@ -170,12 +188,22 @@ def test_custom_calib_config(): def test_class_wise_config(): model = SimpleConvLinear() config = { - "quant_cfg": { - "nn.Linear": {"*": {"num_bits": 4, "axis": -1, "enable": True}}, - "nn.Conv2d": {"*": {"num_bits": 8, "enable": True}}, - "nn.BatchNorm2d": {"*": {"enable": False}}, - "*output_quantizer": {"num_bits": 8, "enable": True}, - }, + "quant_cfg": [ + { + "parent_class": "nn.Linear", + "quantizer_path": "*", + "cfg": {"num_bits": 4, "axis": -1}, + "enable": True, + }, + { + "parent_class": "nn.Conv2d", + "quantizer_path": "*", + "cfg": {"num_bits": 8}, + "enable": True, + }, + {"parent_class": "nn.BatchNorm2d", "quantizer_path": "*", "enable": False}, + {"quantizer_path": "*output_quantizer", "cfg": {"num_bits": 
8}, "enable": True}, + ], "algorithm": "max", } @@ -222,33 +250,28 @@ def test_static_weight_dynamic_activations(): def test_block_sizes_axis_model(): REF_QUANT_CFG = { # noqa: N806 - "quant_cfg": { - "*weight_quantizer": { - "num_bits": 8, - "axis": 0, + "quant_cfg": [ + {"quantizer_path": "*", "enable": False}, + {"quantizer_path": "*weight_quantizer", "cfg": {"num_bits": 8, "axis": 0}}, + { + "quantizer_path": "*input_quantizer", + "cfg": {"num_bits": 8, "axis": None, "type": "dynamic"}, }, - "*input_quantizer": { - "num_bits": 8, - "axis": None, - "type": "dynamic", - }, - "default": {"enable": False}, - }, + ], "algorithm": "max", } QUANT_CFG = { # noqa: N806 - "quant_cfg": { - "*weight_quantizer": { - "num_bits": 8, - "block_sizes": {1: None}, + "quant_cfg": [ + {"quantizer_path": "*", "enable": False}, + { + "quantizer_path": "*weight_quantizer", + "cfg": {"num_bits": 8, "block_sizes": {1: None}}, }, - "*input_quantizer": { - "num_bits": 8, - "block_sizes": {0: None, 1: None}, - "type": "dynamic", + { + "quantizer_path": "*input_quantizer", + "cfg": {"num_bits": 8, "block_sizes": {0: None, 1: None}, "type": "dynamic"}, }, - "default": {"enable": False}, - }, + ], "algorithm": "max", } model_ref = SimpleLinear() @@ -283,3 +306,184 @@ def forward_loop(model): out2 = model(inputs) assert torch.allclose(out1, out2), "Re-quantization with same config should be idempotent" + + +class TestSetQuantizerAttributesFull: + """Tests for set_quantizer_attributes_full and its atomicity semantics.""" + + def _quantize(self, model): + return mtq.quantize(model, mtq.INT8_DEFAULT_CFG, lambda m: m(m.get_input())) + + def test_basic_full_replacement(self): + """set_quantizer_attributes_full replaces all attributes on matched quantizers.""" + model = self._quantize(SimpleLinear()) + attrs = QuantizerAttributeConfig(num_bits=4, axis=0) + set_quantizer_attributes_full(model, "*weight_quantizer", attrs) + for name, module in model.named_modules(): + if 
name.endswith("weight_quantizer"): + assert isinstance(module, TensorQuantizer) + assert module.num_bits == 4 + assert module.axis == 0 + + def test_atomicity_unset_fields_revert_to_defaults(self): + """A full replacement reverts unspecified fields to QuantizerAttributeConfig defaults.""" + model = self._quantize(SimpleLinear()) + # First configure with axis=0 (non-default) + set_quantizer_attributes_full( + model, "*weight_quantizer", QuantizerAttributeConfig(num_bits=8, axis=0) + ) + for name, module in model.named_modules(): + if name.endswith("weight_quantizer"): + assert module.axis == 0 + + # Now replace with only num_bits=4; axis should revert to default (None) + set_quantizer_attributes_full( + model, "*weight_quantizer", QuantizerAttributeConfig(num_bits=4) + ) + default_axis = QuantizerAttributeConfig().axis + for name, module in model.named_modules(): + if name.endswith("weight_quantizer"): + assert module.num_bits == 4 + assert module.axis == default_axis + + def test_parent_class_filter(self): + """parent_class restricts which quantizers are affected.""" + model = self._quantize(SimpleConvLinear()) + # Only set num_bits=4 for quantizers inside nn.Linear modules + set_quantizer_attributes_full( + model, + "*weight_quantizer", + QuantizerAttributeConfig(num_bits=4), + parent_class=torch.nn.Linear, + ) + for name, module in model.named_modules(): + if not name.endswith("weight_quantizer"): + continue + parent_name = name.rpartition(".")[0] + parent = model.get_submodule(parent_name) + if isinstance(parent, torch.nn.Linear): + assert module.num_bits == 4 + else: + # Conv2d weight_quantizers should be unchanged (still 8-bit from INT8_DEFAULT_CFG) + assert module.num_bits == 8 + + def test_wildcard_no_match_is_noop(self): + """A wildcard that matches nothing silently does nothing.""" + model = self._quantize(SimpleLinear()) + # Record state before + bits_before = { + n: m.num_bits for n, m in model.named_modules() if isinstance(m, TensorQuantizer) + } + 
set_quantizer_attributes_full( + model, "*nonexistent_quantizer*", QuantizerAttributeConfig(num_bits=4) + ) + bits_after = { + n: m.num_bits for n, m in model.named_modules() if isinstance(m, TensorQuantizer) + } + assert bits_before == bits_after + + def test_invalid_attributes_type_raises(self): + """Passing a plain dict instead of QuantizerAttributeConfig raises ValueError.""" + model = self._quantize(SimpleLinear()) + with pytest.raises((ValueError, AttributeError)): + set_quantizer_attributes_full(model, "*weight_quantizer", {"num_bits": 4}) # type: ignore[arg-type] + + def test_list_attributes_creates_sequential_quantizer(self): + """A list of QuantizerAttributeConfig replaces TensorQuantizer with SequentialQuantizer.""" + model = self._quantize(SimpleLinear()) + attrs = [ + QuantizerAttributeConfig(num_bits=4, block_sizes={-1: 128}), + QuantizerAttributeConfig(num_bits=8, axis=0), + ] + set_quantizer_attributes_full(model, "*weight_quantizer", attrs) + for name, module in model.named_modules(): + if name.endswith("weight_quantizer"): + assert isinstance(module, SequentialQuantizer) + assert len(module) == 2 + + +def test_ordering_later_entry_overrides_earlier(): + """Later entries in quant_cfg override earlier ones for the same quantizer.""" + model = SimpleLinear() + config = { + "quant_cfg": [ + {"quantizer_path": "*weight_quantizer", "cfg": {"num_bits": 8, "axis": 0}}, + {"quantizer_path": "*weight_quantizer", "cfg": {"num_bits": 4, "axis": 0}}, + {"quantizer_path": "*input_quantizer", "cfg": {"num_bits": 8, "axis": None}}, + ], + "algorithm": "max", + } + model = mtq.quantize(model, config, lambda m: m(m.get_input())) + for name, module in model.named_modules(): + if name.endswith("weight_quantizer"): + assert module.num_bits == 4, "Later entry (num_bits=4) should override earlier (8)" + if name.endswith("input_quantizer"): + assert module.num_bits == 8 + + +def test_enable_only_entry_preserves_attributes(): + """An enable-only entry toggles the 
quantizer without resetting its attributes.""" + model = SimpleLinear() + config = { + "quant_cfg": [ + {"quantizer_path": "*weight_quantizer", "cfg": {"num_bits": 4, "axis": 0}}, + {"quantizer_path": "*input_quantizer", "cfg": {"num_bits": 8, "axis": None}}, + # This enable-only entry should disable without resetting num_bits/axis + {"quantizer_path": "*weight_quantizer", "enable": False}, + ], + "algorithm": "max", + } + model = mtq.quantize(model, config, lambda m: m(m.get_input())) + for name, module in model.named_modules(): + if name.endswith("weight_quantizer"): + assert not module.is_enabled, "weight_quantizer should be disabled" + assert module.num_bits == 4, "num_bits should be preserved by enable-only entry" + assert module.axis == 0, "axis should be preserved by enable-only entry" + + +def test_atomicity_later_cfg_entry_does_not_inherit_earlier(): + """When two cfg-bearing entries match the same quantizer, the second fully replaces the first.""" + model = SimpleLinear() + config = { + "quant_cfg": [ + # Entry 1: set axis=0 + {"quantizer_path": "*weight_quantizer", "cfg": {"num_bits": 8, "axis": 0}}, + # Entry 2: only set num_bits=4, no axis — axis should revert to default (None), not 0 + {"quantizer_path": "*weight_quantizer", "cfg": {"num_bits": 4}}, + {"quantizer_path": "*input_quantizer", "cfg": {"num_bits": 8, "axis": None}}, + ], + "algorithm": "max", + } + model = mtq.quantize(model, config, lambda m: m(m.get_input())) + default_axis = QuantizerAttributeConfig().axis + for name, module in model.named_modules(): + if name.endswith("weight_quantizer"): + assert module.num_bits == 4 + assert module.axis == default_axis, ( + f"axis should revert to default ({default_axis}), not inherit 0 from earlier entry" + ) + + +def test_legacy_dict_format_end_to_end(): + """Old dict-format quant_cfg works end-to-end through mtq.quantize via normalization.""" + model = SimpleLinear() + # Old-style dict config with "default" key and wildcard keys + old_config = { + 
"quant_cfg": { + "default": {"enable": False}, + "*weight_quantizer": {"num_bits": 8, "axis": 0}, + "*input_quantizer": {"num_bits": 8, "axis": None}, + }, + "algorithm": "max", + } + model = mtq.quantize(model, old_config, lambda m: m(m.get_input())) + for name, module in model.named_modules(): + if isinstance(module, TensorQuantizer): + if name.endswith(("weight_quantizer", "input_quantizer")): + assert module.is_enabled + assert module.num_bits == 8 + elif name.endswith("output_quantizer"): + # "default" key → quantizer_path="*" with enable=False disables everything, + # but weight/input quantizers are re-enabled by subsequent entries. + # output_quantizer is NOT re-enabled so it stays disabled. + assert not module.is_enabled diff --git a/tests/unit/torch/quantization/test_quantize_replace.py b/tests/unit/torch/quantization/test_quantize_replace.py index 140da2b646..4b0f4edd2d 100644 --- a/tests/unit/torch/quantization/test_quantize_replace.py +++ b/tests/unit/torch/quantization/test_quantize_replace.py @@ -47,7 +47,7 @@ def test_quantize_replace(model_cls): assert not isinstance(module, nn.Conv2d) or _is_quantized_linear_conv(module) assert not isinstance(module, nn.Linear) or _is_quantized_linear_conv(module) - mtq.set_quantizer_attribute(model_atq, "*", {"enable": False}) + mtq.set_quantizer_attributes_partial(model_atq, "*", {"enable": False}) out_ref = model_ref(dummy_input) out_atq = model_atq(dummy_input) diff --git a/tests/unit/torch/quantization/test_tensor_quant_cpu.py b/tests/unit/torch/quantization/test_tensor_quant_cpu.py index d5c6479cd5..78a79bbcb4 100644 --- a/tests/unit/torch/quantization/test_tensor_quant_cpu.py +++ b/tests/unit/torch/quantization/test_tensor_quant_cpu.py @@ -89,14 +89,18 @@ def test_num_bits(self): WINT4INT8_CFG = { - "quant_cfg": { - "*weight_quantizer": [ - {"num_bits": 4, "block_sizes": {-1: 128, "type": "static"}, "enable": True}, - {"num_bits": 8, "axis": 0, "enable": True}, - ], - "*input_quantizer": {"num_bits": 8, 
"enable": True}, - "default": {"enable": False}, - }, + "quant_cfg": [ + {"quantizer_path": "*", "enable": False}, + { + "quantizer_path": "*weight_quantizer", + "cfg": [ + {"num_bits": 4, "block_sizes": {-1: 128, "type": "static"}}, + {"num_bits": 8, "axis": 0}, + ], + "enable": True, + }, + {"quantizer_path": "*input_quantizer", "cfg": {"num_bits": 8}, "enable": True}, + ], "algorithm": "awq_full", } @@ -109,10 +113,14 @@ def test_set_quantizer_cxt(): state_dict = model.state_dict() output_ref = model(inputs) - mtq.set_quantizer_by_cfg(model, {"*output_quantizer": {"enable": True}}) + mtq.set_quantizer_by_cfg(model, [{"quantizer_path": "*output_quantizer", "enable": True}]) with mtq.set_quantizer_by_cfg_context( - model, {"*": {"enable": False}, "*output_quantizer": {"enable": True}} + model, + [ + {"quantizer_path": "*", "enable": False}, + {"quantizer_path": "*output_quantizer", "enable": True}, + ], ): for name, module in model.named_modules(): if not isinstance(module, TensorQuantizer): @@ -123,7 +131,7 @@ def test_set_quantizer_cxt(): assert not module.is_enabled mtq.calibrate(model, "max", lambda model: model(inputs * 10)) - mtq.set_quantizer_by_cfg(model, {"*output_quantizer": {"enable": False}}) + mtq.set_quantizer_by_cfg(model, [{"quantizer_path": "*output_quantizer", "enable": False}]) output_test = model(inputs) assert torch.allclose(output_ref, output_test)