Skip to content
6 changes: 6 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,14 @@ NVIDIA Model Optimizer Changelog

- ONNX Runtime dependency upgraded to 1.24 to solve missing graph outputs when using the TensorRT Execution Provider.

**Backward Breaking Changes**

- Default ``--kv_cache_qformat`` in ``hf_ptq.py`` changed from ``fp8`` to ``fp8_cast``. Existing scripts that rely on the default will now skip KV cache calibration and use a constant amax instead. To restore the previous calibrated behavior, explicitly pass ``--kv_cache_qformat fp8``.
- Removed KV cache scale clamping (``clamp_(min=1.0)``) in the HF checkpoint export path. Calibrated KV cache scales below 1.0 are now exported as-is. If you observe accuracy degradation with calibrated KV cache (``--kv_cache_qformat fp8`` or ``nvfp4``), consider using the casting methods (``fp8_cast`` or ``nvfp4_cast``) instead.

**New Features**

- Add ``fp8_cast`` and ``nvfp4_cast`` modes for ``--kv_cache_qformat`` in ``hf_ptq.py``. These use a constant amax (FP8 E4M3 max, 448.0) without data-driven calibration, since the downstream engine uses FP8 attention math for both FP8 and NVFP4 quantization. A new ``use_constant_amax`` field in :class:`QuantizerAttributeConfig <modelopt.torch.quantization.config.QuantizerAttributeConfig>` controls this behavior.
- Users no longer need to manually register MoE modules to ensure expert calibration coverage in the PTQ workflow.
- ``hf_ptq.py`` now saves the quantization summary and the MoE expert token count table to the export directory.
- Add ``--moe_calib_experts_ratio`` flag in ``hf_ptq.py`` to specify the ratio of experts to calibrate during the forward pass, improving expert coverage during calibration. Defaults to all experts.
Expand Down
70 changes: 54 additions & 16 deletions examples/llm_ptq/hf_ptq.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# limitations under the License.

import argparse
import copy
import random
import time
import warnings
Expand Down Expand Up @@ -74,6 +75,19 @@

RAND_SEED = 1234


def _set_kv_cache_constant_amax(quant_cfg: dict) -> None:
"""Set use_constant_amax on KV cache quantizers.

Creates a new dict for the KV bmm quantizer config to avoid mutating shared references.
"""
if "*[kv]_bmm_quantizer" in quant_cfg:
quant_cfg["*[kv]_bmm_quantizer"] = {
**quant_cfg["*[kv]_bmm_quantizer"],
"use_constant_amax": True,
}


QUANT_CFG_CHOICES: dict[str, dict[str, Any]] = {
"int8": mtq.INT8_DEFAULT_CFG,
"int8_sq": mtq.INT8_SMOOTHQUANT_CFG,
Expand All @@ -96,13 +110,18 @@

# Maps each ``--kv_cache_qformat`` CLI choice to the name of the corresponding
# ``mtq`` KV cache quantization config attribute (resolved via ``getattr(mtq, ...)``).
# The ``*_cast`` entries reuse the same base config as their calibrated
# counterparts; the difference is only that a constant amax is applied instead
# of data-driven calibration (see ``_KV_CAST_FORMATS``).
KV_QUANT_CFG_CHOICES: dict[str, str] = {
    "none": "none",
    "fp8_cast": "FP8_KV_CFG",
    "fp8": "FP8_KV_CFG",
    "fp8_affine": "FP8_AFFINE_KV_CFG",
    "nvfp4_cast": "NVFP4_KV_CFG",
    "nvfp4": "NVFP4_KV_CFG",
    "nvfp4_affine": "NVFP4_AFFINE_KV_CFG",
    "nvfp4_rotate": "NVFP4_KV_ROTATE_CFG",
}

# KV cache formats that set ``use_constant_amax`` on the KV quantizers and
# therefore skip data-driven KV cache calibration entirely.
_KV_CAST_FORMATS = {"fp8_cast", "nvfp4_cast"}

mto.enable_huggingface_checkpointing()


Expand Down Expand Up @@ -300,22 +319,25 @@ def forward_step(model, batch):
)

calibrate_loop = create_forward_loop(dataloader=calib_dataloader)
# We need to explicitly calibrate for kv cache quantization
# We need to explicitly set up KV cache quantization after auto_quantize
enable_quant_kv_cache = args.kv_cache_qformat != "none"
print(f"{'Enable' if enable_quant_kv_cache else 'Disable'} KV cache quantization")
if enable_quant_kv_cache:
kv_cache_quant_cfg = getattr(mtq, KV_QUANT_CFG_CHOICES[args.kv_cache_qformat])["quant_cfg"]
kv_cache_quant_cfg.pop("default") # keep other quantizers from auto_quantize

mtq.set_quantizer_by_cfg(
language_model,
quant_cfg=kv_cache_quant_cfg,
kv_cache_quant_cfg = copy.deepcopy(
getattr(mtq, KV_QUANT_CFG_CHOICES[args.kv_cache_qformat])["quant_cfg"]
)
# Lets calibrate only the quantizers for kv cache quantization this time. Let's disable all others.
with mtq.set_quantizer_by_cfg_context(
language_model, {"*": {"enable": False}, **kv_cache_quant_cfg}
):
mtq.calibrate(language_model, algorithm="max", forward_loop=calibrate_loop)
kv_cache_quant_cfg.pop("default", None) # keep other quantizers from auto_quantize

if args.kv_cache_qformat in _KV_CAST_FORMATS:
_set_kv_cache_constant_amax(kv_cache_quant_cfg)

mtq.set_quantizer_by_cfg(language_model, quant_cfg=kv_cache_quant_cfg)
if args.kv_cache_qformat not in _KV_CAST_FORMATS:
# Calibrate only the KV cache quantizers; disable all others.
with mtq.set_quantizer_by_cfg_context(
language_model, {"*": {"enable": False}, **kv_cache_quant_cfg}
):
mtq.calibrate(language_model, algorithm="max", forward_loop=calibrate_loop)
return language_model


Expand All @@ -341,6 +363,13 @@ def load_model(args: argparse.Namespace):
quant_cfg,
getattr(mtq, KV_QUANT_CFG_CHOICES[args.kv_cache_qformat])["quant_cfg"],
)
# Mirror the use_constant_amax logic from quantize_main so that init_quantized_weights
# builds the KV quantizers with use_constant_amax already set. In calibration_only mode
# mtq.calibrate() does not re-apply quant_cfg, so this must happen before
# init_quantized_weights runs.
if args.kv_cache_qformat in _KV_CAST_FORMATS:
quant_cfg = copy.deepcopy(quant_cfg)
_set_kv_cache_constant_amax(quant_cfg["quant_cfg"])

# Do not use real quant GEMM so the calibration can be more accurate.
with init_quantized_weights(
Expand Down Expand Up @@ -931,15 +960,18 @@ def quantize_main(
# These layers are typically speculative decoding layers that should be exported as-is
mtp_layer_prefixes = getattr(full_model, "_mtp_layer_prefixes", None)
if mtp_layer_prefixes:
import copy

quant_cfg = copy.deepcopy(quant_cfg)
for prefix in mtp_layer_prefixes:
# Add exclusion pattern for this MTP layer (e.g., "*layers.92*")
pattern = f"*{prefix.split('.')[-2]}.{prefix.split('.')[-1]}*"
quant_cfg["quant_cfg"][pattern] = {"enable": False}
print(f"Excluding MTP layer from quantization: {pattern}")

# Use constant amax for KV quantizers when a cast format is selected.
if args.kv_cache_qformat in _KV_CAST_FORMATS:
quant_cfg = copy.deepcopy(quant_cfg)
_set_kv_cache_constant_amax(quant_cfg["quant_cfg"])

if args.qformat in QUANT_CFG_CHOICES:
mono_quantize(
args,
Expand Down Expand Up @@ -1054,9 +1086,14 @@ def parse_args() -> argparse.Namespace:
parser.add_argument(
"--kv_cache_qformat",
required=False,
default="fp8",
default="fp8_cast",
choices=KV_QUANT_CFG_CHOICES.keys(),
help="Specify KV cache quantization format, default to fp8 if not provided",
help=(
"Specify KV cache quantization format. Default: fp8_cast. "
"Formats ending in '_cast' (fp8_cast, nvfp4_cast) set the amax to FP8 range "
"without data-driven calibration. "
"Other formats (fp8, nvfp4, etc.) use data-driven calibration."
),
)
parser.add_argument(
"--export_fmt",
Expand Down Expand Up @@ -1169,6 +1206,7 @@ def parse_args() -> argparse.Namespace:
args = parser.parse_args()
if not (0.0 < args.moe_calib_experts_ratio <= 1.0):
parser.error("--moe_calib_experts_ratio must be in the range (0.0, 1.0].")

return args


Expand Down
21 changes: 9 additions & 12 deletions modelopt/torch/export/quant_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,6 +474,8 @@ def get_kv_cache_scaling_factor(self_attention_module: nn.Module) -> list[torch.
# For FP8, we recommend default kv cache scaling factor to be 1.
if get_kv_cache_dtype(self_attention_module) == KV_CACHE_FP8:
for i, factor in enumerate(scaling_factors):
if factor is None:
continue
if factor.item() > 0.5:
warn(
f"Warning: Large KV activation detected: {factor.item()}, "
Expand Down Expand Up @@ -512,23 +514,24 @@ def get_kv_cache_dtype(modules: list[nn.Module] | nn.Module) -> str | None:
num_bits_list.append(quantizer_attr.num_bits)
is_affine &= hasattr(quantizer_attr, "_bias_value")

return _compute_kv_cache_dtype(num_bits_list)
return _compute_kv_cache_dtype(num_bits_list, is_affine)


def _compute_kv_cache_dtype(num_bits_list: list[int | tuple[int, int]]) -> str | None:
def _compute_kv_cache_dtype(
num_bits_list: list[int | tuple[int, int]], is_affine: bool = False
) -> str | None:
"""Returns the kv_cache dtype.

If num_bits of output_quantizer is (4, 3) then returns FP8; if it is 8, returns int8,
otherwise returns None.

Args:
modules: The module or list of modules to inspect.
num_bits_list: The list of num_bits from quantizers.
is_affine: Whether the quantizers have bias (affine mode).

Returns:
The kv_cache dtype.
"""
is_affine = True

if (4, 3) in num_bits_list:
return KV_CACHE_FP8
elif 8 in num_bits_list:
Expand Down Expand Up @@ -1087,14 +1090,8 @@ def postprocess_state_dict(
# Warn if scale exceeds threshold
if quantization == KV_CACHE_FP8 and value.item() > 0.5:
logger.warning(
"Large KV activations detected. Quantized KV cache may lead to higher accuracy drop. "
"Setting KV cache scaling factor to at least 1."
"Large KV activations detected. Quantized KV cache may lead to higher accuracy drop."
)

# Ensure scale is at least 1 for KV_CACHE_FP8
# We export real value for KV_CACHE_NVFP4
if quantization == KV_CACHE_FP8:
value.clamp_(min=1.0)
post_state_dict[prefix + new_suffix] = value
break

Expand Down
9 changes: 9 additions & 0 deletions modelopt/torch/quantization/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1030,6 +1030,15 @@ def validate_calibrator(cls, v, info: ValidationInfo):
""",
)

use_constant_amax: bool = ModeloptField(
default=False,
title="Use constant amax for the quantizer.",
description="""If True, set the amax to FP8 E4M3 max (448.0) and skip calibration.
This is used for KV cache quantization where the downstream engine uses FP8 attention
math for both FP8 and NVFP4 quantization, so the amax is hardcoded to the FP8 range.
""",
)


class QuantizeAlgorithmConfig(ModeloptBaseConfig):
"""Calibration algorithm config base."""
Expand Down
12 changes: 11 additions & 1 deletion modelopt/torch/quantization/model_calib.py
Original file line number Diff line number Diff line change
Expand Up @@ -706,7 +706,12 @@ def enable_stats_collection(model: nn.Module):
"""Enable stats collection for all quantizers in the model."""
for name, module in model.named_modules():
if isinstance(module, TensorQuantizer) and not module._disabled:
if module._calibrator is not None:
if module._use_constant_amax:
# use_constant_amax quantizers use a fixed amax and don't need calibration.
# Disable quantization during calibration so it doesn't affect other quantizers.
module.disable_quant()
continue
elif module._calibrator is not None:
module.disable_quant()
module.enable_calib()
else:
Expand All @@ -719,6 +724,11 @@ def finish_stats_collection(model: nn.Module, method: str | None = None, **kwarg
if not isinstance(module, TensorQuantizer) or module._disabled:
continue

if module._use_constant_amax:
# Re-enable quantization for use_constant_amax quantizers disabled in enable_stats_collection.
module.enable_quant()
continue

cal = getattr(module, "_calibrator", None)
if cal and not getattr(module, "_dynamic", False):
if method in {"entropy"}:
Expand Down
4 changes: 4 additions & 0 deletions modelopt/torch/quantization/nn/modules/tensor_quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ def __init__(
if amax is not None:
self.amax = amax

self._use_constant_amax = False
self.set_from_attribute_config(quant_attribute_cfg)

self._if_quant = if_quant
Expand Down Expand Up @@ -225,6 +226,7 @@ def _calibrator_setter(val):
"calibrator": ("_calibrator", _calibrator_setter),
"backend": ("backend", lambda val: val),
"backend_extra_args": ("backend_extra_args", lambda val: val or {}),
"use_constant_amax": ("_use_constant_amax", lambda val: val),
}

for attribute, val in attribute_cfg.items():
Expand Down Expand Up @@ -613,6 +615,8 @@ def load_calib_bias(self, *args, **kwargs):

def _get_amax(self, inputs):
"""Get amax from buffer or compute it dynamically."""
if self._use_constant_amax:
return torch.tensor(torch.finfo(torch.float8_e4m3fn).max, device=inputs.device)
if hasattr(self, "_amax"):
amax = self._amax
else:
Expand Down
57 changes: 57 additions & 0 deletions tests/_test_utils/torch/quantization/tensor_quantizer_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,63 @@ def test_set_from_attribute_config(self):
tq.set_from_attribute_config({"enable": False})
assert tq._disabled

def test_use_constant_amax(self):
    """Verify use_constant_amax pins amax to the FP8 E4M3 max without calibration."""
    quantizer = TensorQuantizer(QuantizerAttributeConfig(num_bits=8, use_constant_amax=True))
    quantizer.to(self.device)
    inputs = torch.randn(4, 8).to(self.device)
    expected_amax = torch.finfo(torch.float8_e4m3fn).max  # 448.0

    # The config flag must be mirrored onto the quantizer as a plain bool.
    assert quantizer._use_constant_amax is True

    # _get_amax must return the FP8 E4M3 constant on the same device as the input.
    amax = quantizer._get_amax(inputs)
    assert amax.item() == expected_amax
    assert amax.device == inputs.device

    # A forward pass runs with the constant amax and preserves the input shape.
    quantized = quantizer(inputs)
    assert quantized.shape == inputs.shape

def test_use_constant_amax_skips_calibration(self):
    """Check that constant-amax quantizers sit out calibration and are re-enabled after."""
    import torch.nn as nn

    from modelopt.torch.quantization.model_calib import (
        enable_stats_collection,
        finish_stats_collection,
    )

    # One constant-amax quantizer alongside one ordinary (calibrated) quantizer.
    const_tq = TensorQuantizer(QuantizerAttributeConfig(num_bits=8, use_constant_amax=True))
    calib_tq = TensorQuantizer(QuantizerAttributeConfig(num_bits=8))
    container = nn.ModuleDict({"tq_const": const_tq, "tq_calib": calib_tq}).to(self.device)

    enable_stats_collection(container)

    # Constant-amax quantizer: still enabled overall, but neither calibrating nor quantizing.
    assert not container["tq_const"]._disabled
    assert not container["tq_const"]._if_calib
    assert not container["tq_const"]._if_quant

    # Ordinary quantizer with a calibrator: collecting stats, quantization paused.
    assert not container["tq_calib"]._disabled
    assert container["tq_calib"]._if_calib
    assert not container["tq_calib"]._if_quant

    finish_stats_collection(container)

    # After calibration finishes, the constant-amax quantizer quantizes again.
    assert not container["tq_const"]._disabled
    assert container["tq_const"]._if_quant

def test_modelopt_state(self):
# Test loading of amax from ref to test
tensor_quantizer_ref = TensorQuantizer(QuantizerAttributeConfig(num_bits=4), amax=10.0)
Expand Down
3 changes: 3 additions & 0 deletions tests/examples/llm_ptq/test_llm_ptq.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,9 @@ def test_ptq_whisper(self, command):
),
# kv_cache
PTQCommand(quant="nvfp4_awq", kv_cache_quant="nvfp4"),
PTQCommand(quant="fp8", kv_cache_quant="fp8_cast", min_sm=89),
PTQCommand(quant="fp8", kv_cache_quant="fp8", min_sm=89),
PTQCommand(quant="nvfp4", kv_cache_quant="nvfp4_cast"),
# autoquant_kv_cache
PTQCommand(
quant="nvfp4,fp8",
Expand Down
4 changes: 2 additions & 2 deletions tests/gpu/torch/export/test_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ def test_get_scaling_factor_from_weight(weight, group_size, expected):
KV_CACHE_FP8,
128.0,
{
"layer1.k_proj.k_scale": torch.tensor([1.0]),
"layer1.k_proj.k_scale": torch.tensor([0.001]),
"layer1.v_proj.v_scale": torch.tensor([2.0]),
"layer1.pre_quant_scale": torch.tensor([0.128]),
},
Expand All @@ -222,7 +222,7 @@ def test_get_scaling_factor_from_weight(weight, group_size, expected):
KV_CACHE_FP8,
128.0,
{
"layer1.k_proj.k_scale": torch.tensor([1.0]),
"layer1.k_proj.k_scale": torch.tensor([0.001]),
"layer1.v_proj.v_scale": torch.tensor([2.0]),
},
),
Expand Down
Loading