diff --git a/CHANGELOG.rst b/CHANGELOG.rst index ae94ef2ab3..ac14bd35ac 100755 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -8,8 +8,14 @@ NVIDIA Model Optimizer Changelog - ONNX Runtime dependency upgraded to 1.24 to solve missing graph outputs when using the TensorRT Execution Provider. +**Backward Breaking Changes** + +- Default ``--kv_cache_qformat`` in ``hf_ptq.py`` changed from ``fp8`` to ``fp8_cast``. Existing scripts that rely on the default will now skip KV cache calibration and use a constant amax instead. To restore the previous calibrated behavior, explicitly pass ``--kv_cache_qformat fp8``. +- Removed KV cache scale clamping (``clamp_(min=1.0)``) in the HF checkpoint export path. Calibrated KV cache scales below 1.0 are now exported as-is. If you observe accuracy degradation with calibrated KV cache (``--kv_cache_qformat fp8`` or ``nvfp4``), consider using the casting methods (``fp8_cast`` or ``nvfp4_cast``) instead. + **New Features** +- Add ``fp8_cast`` and ``nvfp4_cast`` modes for ``--kv_cache_qformat`` in ``hf_ptq.py``. These use a constant amax (FP8 E4M3 max, 448.0) without data-driven calibration, since the downstream engine uses FP8 attention math for both FP8 and NVFP4 quantization. A new ``use_constant_amax`` field in :class:`QuantizerAttributeConfig` controls this behavior. - User does not need to manually register MOE modules to cover experts calibration coverage in PTQ workflow. - ``hf_ptq.py`` now saves the quantization summary and moe expert token count table to the export directory. - Add ``--moe_calib_experts_ratio`` flag in ``hf_ptq.py`` to specify the ratio of experts to calibrate during forward pass to improve expert coverage during calibration. Default to all the experts. diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index fd35a53f27..7a0e0b9496 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -14,6 +14,7 @@ # limitations under the License. 
import argparse +import copy import random import time import warnings @@ -74,6 +75,19 @@ RAND_SEED = 1234 + +def _set_kv_cache_constant_amax(quant_cfg: dict) -> None: + """Set use_constant_amax on KV cache quantizers. + + Creates a new dict for the KV bmm quantizer config to avoid mutating shared references. + """ + if "*[kv]_bmm_quantizer" in quant_cfg: + quant_cfg["*[kv]_bmm_quantizer"] = { + **quant_cfg["*[kv]_bmm_quantizer"], + "use_constant_amax": True, + } + + QUANT_CFG_CHOICES: dict[str, dict[str, Any]] = { "int8": mtq.INT8_DEFAULT_CFG, "int8_sq": mtq.INT8_SMOOTHQUANT_CFG, @@ -96,13 +110,18 @@ KV_QUANT_CFG_CHOICES = { "none": "none", + "fp8_cast": "FP8_KV_CFG", "fp8": "FP8_KV_CFG", "fp8_affine": "FP8_AFFINE_KV_CFG", + "nvfp4_cast": "NVFP4_KV_CFG", "nvfp4": "NVFP4_KV_CFG", "nvfp4_affine": "NVFP4_AFFINE_KV_CFG", "nvfp4_rotate": "NVFP4_KV_ROTATE_CFG", } +# Formats that use use_constant_amax (no calibration needed). +_KV_CAST_FORMATS = {"fp8_cast", "nvfp4_cast"} + mto.enable_huggingface_checkpointing() @@ -300,22 +319,25 @@ def forward_step(model, batch): ) calibrate_loop = create_forward_loop(dataloader=calib_dataloader) - # We need to explicitly calibrate for kv cache quantization + # We need to explicitly set up KV cache quantization after auto_quantize enable_quant_kv_cache = args.kv_cache_qformat != "none" print(f"{'Enable' if enable_quant_kv_cache else 'Disable'} KV cache quantization") if enable_quant_kv_cache: - kv_cache_quant_cfg = getattr(mtq, KV_QUANT_CFG_CHOICES[args.kv_cache_qformat])["quant_cfg"] - kv_cache_quant_cfg.pop("default") # keep other quantizers from auto_quantize - - mtq.set_quantizer_by_cfg( - language_model, - quant_cfg=kv_cache_quant_cfg, + kv_cache_quant_cfg = copy.deepcopy( + getattr(mtq, KV_QUANT_CFG_CHOICES[args.kv_cache_qformat])["quant_cfg"] ) - # Lets calibrate only the quantizers for kv cache quantization this time. Let's disable all others. 
- with mtq.set_quantizer_by_cfg_context( - language_model, {"*": {"enable": False}, **kv_cache_quant_cfg} - ): - mtq.calibrate(language_model, algorithm="max", forward_loop=calibrate_loop) + kv_cache_quant_cfg.pop("default", None) # keep other quantizers from auto_quantize + + if args.kv_cache_qformat in _KV_CAST_FORMATS: + _set_kv_cache_constant_amax(kv_cache_quant_cfg) + + mtq.set_quantizer_by_cfg(language_model, quant_cfg=kv_cache_quant_cfg) + if args.kv_cache_qformat not in _KV_CAST_FORMATS: + # Calibrate only the KV cache quantizers; disable all others. + with mtq.set_quantizer_by_cfg_context( + language_model, {"*": {"enable": False}, **kv_cache_quant_cfg} + ): + mtq.calibrate(language_model, algorithm="max", forward_loop=calibrate_loop) return language_model @@ -341,6 +363,13 @@ def load_model(args: argparse.Namespace): quant_cfg, getattr(mtq, KV_QUANT_CFG_CHOICES[args.kv_cache_qformat])["quant_cfg"], ) + # Mirror the use_constant_amax logic from quantize_main so that init_quantized_weights + # builds the KV quantizers with use_constant_amax already set. In calibration_only mode + # mtq.calibrate() does not re-apply quant_cfg, so this must happen before + # init_quantized_weights runs. + if args.kv_cache_qformat in _KV_CAST_FORMATS: + quant_cfg = copy.deepcopy(quant_cfg) + _set_kv_cache_constant_amax(quant_cfg["quant_cfg"]) # Do not use real quant GEMM so the calibration can be more accurate. 
with init_quantized_weights( @@ -931,8 +960,6 @@ def quantize_main( # These layers are typically speculative decoding layers that should be exported as-is mtp_layer_prefixes = getattr(full_model, "_mtp_layer_prefixes", None) if mtp_layer_prefixes: - import copy - quant_cfg = copy.deepcopy(quant_cfg) for prefix in mtp_layer_prefixes: # Add exclusion pattern for this MTP layer (e.g., "*layers.92*") @@ -940,6 +967,11 @@ def quantize_main( quant_cfg["quant_cfg"][pattern] = {"enable": False} print(f"Excluding MTP layer from quantization: {pattern}") + # Use constant amax for KV quantizers when a cast format is selected. + if args.kv_cache_qformat in _KV_CAST_FORMATS: + quant_cfg = copy.deepcopy(quant_cfg) + _set_kv_cache_constant_amax(quant_cfg["quant_cfg"]) + if args.qformat in QUANT_CFG_CHOICES: mono_quantize( args, @@ -1054,9 +1086,14 @@ def parse_args() -> argparse.Namespace: parser.add_argument( "--kv_cache_qformat", required=False, - default="fp8", + default="fp8_cast", choices=KV_QUANT_CFG_CHOICES.keys(), - help="Specify KV cache quantization format, default to fp8 if not provided", + help=( + "Specify KV cache quantization format. Default: fp8_cast. " + "Formats ending in '_cast' (fp8_cast, nvfp4_cast) set the amax to FP8 range " + "without data-driven calibration. " + "Other formats (fp8, nvfp4, etc.) use data-driven calibration." + ), ) parser.add_argument( "--export_fmt", @@ -1169,6 +1206,7 @@ def parse_args() -> argparse.Namespace: args = parser.parse_args() if not (0.0 < args.moe_calib_experts_ratio <= 1.0): parser.error("--moe_calib_experts_ratio must be in the range (0.0, 1.0].") + return args diff --git a/modelopt/torch/export/quant_utils.py b/modelopt/torch/export/quant_utils.py index 3ea0180a10..2fc32e4727 100755 --- a/modelopt/torch/export/quant_utils.py +++ b/modelopt/torch/export/quant_utils.py @@ -474,6 +474,8 @@ def get_kv_cache_scaling_factor(self_attention_module: nn.Module) -> list[torch. 
# For FP8, we recommend default kv cache scaling factor to be 1. if get_kv_cache_dtype(self_attention_module) == KV_CACHE_FP8: for i, factor in enumerate(scaling_factors): + if factor is None: + continue if factor.item() > 0.5: warn( f"Warning: Large KV activation detected: {factor.item()}, " @@ -512,23 +514,24 @@ def get_kv_cache_dtype(modules: list[nn.Module] | nn.Module) -> str | None: num_bits_list.append(quantizer_attr.num_bits) is_affine &= hasattr(quantizer_attr, "_bias_value") - return _compute_kv_cache_dtype(num_bits_list) + return _compute_kv_cache_dtype(num_bits_list, is_affine) -def _compute_kv_cache_dtype(num_bits_list: list[int | tuple[int, int]]) -> str | None: +def _compute_kv_cache_dtype( + num_bits_list: list[int | tuple[int, int]], is_affine: bool = False +) -> str | None: """Returns the kv_cache dtype. If num_bits of output_quantizer is (4, 3) then returns FP8; if it is 8, returns int8, otherwise returns None. Args: - modules: The module or list of modules to inspect. + num_bits_list: The list of num_bits from quantizers. + is_affine: Whether the quantizers have bias (affine mode). Returns: The kv_cache dtype. """ - is_affine = True - if (4, 3) in num_bits_list: return KV_CACHE_FP8 elif 8 in num_bits_list: @@ -1087,14 +1090,8 @@ def postprocess_state_dict( # Warn if scale exceeds threshold if quantization == KV_CACHE_FP8 and value.item() > 0.5: logger.warning( - "Large KV activations detected. Quantized KV cache may lead to higher accuracy drop. " - "Setting KV cache scaling factor to at least 1." + "Large KV activations detected. Quantized KV cache may lead to higher accuracy drop." 
) - - # Ensure scale is at least 1 for KV_CACHE_FP8 - # We export real value for KV_CACHE_NVFP4 - if quantization == KV_CACHE_FP8: - value.clamp_(min=1.0) post_state_dict[prefix + new_suffix] = value break diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py index a9b3574c4d..0f98459369 100644 --- a/modelopt/torch/quantization/config.py +++ b/modelopt/torch/quantization/config.py @@ -1030,6 +1030,15 @@ def validate_calibrator(cls, v, info: ValidationInfo): """, ) + use_constant_amax: bool = ModeloptField( + default=False, + title="Use constant amax for the quantizer.", + description="""If True, set the amax to FP8 E4M3 max (448.0) and skip calibration. + This is used for KV cache quantization where the downstream engine uses FP8 attention + math for both FP8 and NVFP4 quantization, so the amax is hardcoded to the FP8 range. + """, + ) + class QuantizeAlgorithmConfig(ModeloptBaseConfig): """Calibration algorithm config base.""" diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 70f036a8d6..2771bd26ef 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -706,7 +706,12 @@ def enable_stats_collection(model: nn.Module): """Enable stats collection for all quantizers in the model.""" for name, module in model.named_modules(): if isinstance(module, TensorQuantizer) and not module._disabled: - if module._calibrator is not None: + if module._use_constant_amax: + # use_constant_amax quantizers use a fixed amax and don't need calibration. + # Disable quantization during calibration so it doesn't affect other quantizers. 
+ module.disable_quant() + continue + elif module._calibrator is not None: module.disable_quant() module.enable_calib() else: @@ -719,6 +724,11 @@ def finish_stats_collection(model: nn.Module, method: str | None = None, **kwarg if not isinstance(module, TensorQuantizer) or module._disabled: continue + if module._use_constant_amax: + # Re-enable quantization for use_constant_amax quantizers disabled in enable_stats_collection. + module.enable_quant() + continue + cal = getattr(module, "_calibrator", None) if cal and not getattr(module, "_dynamic", False): if method in {"entropy"}: diff --git a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py index 2caec25656..6482cb216d 100644 --- a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py +++ b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py @@ -188,6 +188,7 @@ def __init__( if amax is not None: self.amax = amax + self._use_constant_amax = False self.set_from_attribute_config(quant_attribute_cfg) self._if_quant = if_quant @@ -225,6 +226,7 @@ def _calibrator_setter(val): "calibrator": ("_calibrator", _calibrator_setter), "backend": ("backend", lambda val: val), "backend_extra_args": ("backend_extra_args", lambda val: val or {}), + "use_constant_amax": ("_use_constant_amax", lambda val: val), } for attribute, val in attribute_cfg.items(): @@ -613,6 +615,8 @@ def load_calib_bias(self, *args, **kwargs): def _get_amax(self, inputs): """Get amax from buffer or compute it dynamically.""" + if self._use_constant_amax: + return torch.tensor(torch.finfo(torch.float8_e4m3fn).max, device=inputs.device) if hasattr(self, "_amax"): amax = self._amax else: diff --git a/tests/_test_utils/torch/quantization/tensor_quantizer_common.py b/tests/_test_utils/torch/quantization/tensor_quantizer_common.py index f81a5f7386..ad2722dca6 100644 --- a/tests/_test_utils/torch/quantization/tensor_quantizer_common.py +++ 
b/tests/_test_utils/torch/quantization/tensor_quantizer_common.py @@ -209,6 +209,63 @@ def test_set_from_attribute_config(self): tq.set_from_attribute_config({"enable": False}) assert tq._disabled + def test_use_constant_amax(self): + """Test that use_constant_amax sets a fixed amax (FP8 E4M3 max) without calibration.""" + x = torch.randn(4, 8).to(self.device) + fp8_max = torch.finfo(torch.float8_e4m3fn).max # 448.0 + + tq = TensorQuantizer(QuantizerAttributeConfig(num_bits=8, use_constant_amax=True)) + tq.to(self.device) + + # _use_constant_amax should be stored as a boolean attribute + assert tq._use_constant_amax is True + + # _get_amax should return a tensor with FP8 E4M3 max and correct device + returned_amax = tq._get_amax(x) + assert returned_amax.item() == fp8_max + assert returned_amax.device == x.device + + # Forward pass should use the constant amax + out = tq(x) + assert out.shape == x.shape + + def test_use_constant_amax_skips_calibration(self): + """Test that use_constant_amax quantizers are disabled during calibration and re-enabled after.""" + import torch.nn as nn + + from modelopt.torch.quantization.model_calib import ( + enable_stats_collection, + finish_stats_collection, + ) + + # Build a small model with one use_constant_amax quantizer and one normal quantizer + model = nn.ModuleDict( + { + "tq_const": TensorQuantizer( + QuantizerAttributeConfig(num_bits=8, use_constant_amax=True) + ), + "tq_calib": TensorQuantizer(QuantizerAttributeConfig(num_bits=8)), + } + ).to(self.device) + + enable_stats_collection(model) + + # use_constant_amax quantizer: quant disabled during calibration, not in calib mode + assert not model["tq_const"]._disabled + assert not model["tq_const"]._if_calib + assert not model["tq_const"]._if_quant + + # normal quantizer with a calibrator should be in calib mode (quant disabled) + assert not model["tq_calib"]._disabled + assert model["tq_calib"]._if_calib + assert not model["tq_calib"]._if_quant + + 
finish_stats_collection(model) + + # After finish, use_constant_amax quantizer is re-enabled + assert not model["tq_const"]._disabled + assert model["tq_const"]._if_quant + def test_modelopt_state(self): # Test loading of amax from ref to test tensor_quantizer_ref = TensorQuantizer(QuantizerAttributeConfig(num_bits=4), amax=10.0) diff --git a/tests/examples/llm_ptq/test_llm_ptq.py b/tests/examples/llm_ptq/test_llm_ptq.py index 4fc39f5ecb..f5d0b39c1d 100644 --- a/tests/examples/llm_ptq/test_llm_ptq.py +++ b/tests/examples/llm_ptq/test_llm_ptq.py @@ -98,6 +98,7 @@ def test_ptq_whisper(self, command): ), # kv_cache PTQCommand(quant="nvfp4_awq", kv_cache_quant="nvfp4"), + PTQCommand(quant="fp8", kv_cache_quant="fp8_cast", min_sm=89), # autoquant_kv_cache PTQCommand( quant="nvfp4,fp8", diff --git a/tests/gpu/torch/export/test_export.py b/tests/gpu/torch/export/test_export.py index 7d7637f844..55eee2c138 100644 --- a/tests/gpu/torch/export/test_export.py +++ b/tests/gpu/torch/export/test_export.py @@ -208,7 +208,7 @@ def test_get_scaling_factor_from_weight(weight, group_size, expected): KV_CACHE_FP8, 128.0, { - "layer1.k_proj.k_scale": torch.tensor([1.0]), + "layer1.k_proj.k_scale": torch.tensor([0.001]), "layer1.v_proj.v_scale": torch.tensor([2.0]), "layer1.pre_quant_scale": torch.tensor([0.128]), }, @@ -222,7 +222,7 @@ def test_get_scaling_factor_from_weight(weight, group_size, expected): KV_CACHE_FP8, 128.0, { - "layer1.k_proj.k_scale": torch.tensor([1.0]), + "layer1.k_proj.k_scale": torch.tensor([0.001]), "layer1.v_proj.v_scale": torch.tensor([2.0]), }, ),