Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ NVIDIA Model Optimizer Changelog
- Add ``get_auto_quantize_config`` API to extract a flat quantization config from ``auto_quantize`` search results, enabling re-quantization at different effective bit targets without re-running calibration.
- Improve ``auto_quantize`` checkpoint/resume: calibration state is now saved and restored across runs, avoiding redundant calibration when resuming a search.
- Add support for Nemotron-3 (NemotronHForCausalLM) model quantization, and support for NemotronH MoE experts in ``auto_quantize`` grouping and scoring rules.

- Add support for block-granular RHT for non-power-of-2 dimensions.

**Misc**

- Migrated project metadata from ``setup.py`` to a fully declarative ``pyproject.toml``.
Expand Down
32 changes: 23 additions & 9 deletions modelopt/torch/quantization/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -699,6 +699,26 @@
BiasMethod = Literal["mean", "max_min"]


class RotateConfig(ModeloptBaseConfig):
    """Configuration for rotating quantizer input via Hadamard transform (RHT/QuaRot/SpinQuant).

    See :func:`normalized_hadamard_transform <modelopt.torch.quantization.nn.functional.normalized_hadamard_transform>`
    for transform details.
    """

    # Master switch for input rotation.
    enable: bool = False
    # Compute the rotation in float32 precision instead of the input dtype.
    rotate_fp32: bool = False
    # Block size for block-granular RHT; None means full-dimension/auto behavior.
    block_size: int | None = None

    @field_validator("block_size", mode="before")
    @classmethod
    def validate_block_size(cls, v):
        """Validate block_size is a positive int (mode=before to catch bool before int coercion)."""
        if v is None:
            return v
        # bool is a subclass of int, so it must be rejected explicitly.
        is_positive_int = isinstance(v, int) and not isinstance(v, bool) and v > 0
        if not is_positive_int:
            raise ValueError(f"block_size must be a positive int, got {v!r}")
        return v


class QuantizerAttributeConfig(ModeloptBaseConfig):
"""Quantizer attribute type."""

Expand Down Expand Up @@ -975,18 +995,12 @@ def validate_calibrator(cls, v, info: ValidationInfo):
assert v in ["max", "histogram"]
return v

rotate: bool | dict[str, bool] = ModeloptField(
rotate: bool | RotateConfig = ModeloptField(
default=False,
title="""Configuration for rotating the input before quantization.""",
description="""Can be a boolean or a dictionary with the following keys:
- "enable": Boolean to enable/disable rotation (default: False)
- "rotate_fp32": Boolean to compute rotation in float32 precision (default: False)

If a boolean is provided, it is treated as the "enable" value with "rotate_fp32" defaulting to False.
description="""Can be a boolean or a :class:`RotateConfig` instance (or equivalent dict).

When enabled, the input of the quantizer will be rotated with a hadamard matrix
given by scipy.linalg.hadamard, i.e.
``input = input @ scipy.linalg.hadamard(input.shape[-1]) / sqrt(input.shape[-1])``.
If a boolean, it is treated as :attr:`RotateConfig.enable` with all other fields defaulting.

This can be used for rotation based PTQ methods, e.g. QuaRot or SpinQuant.
See https://arxiv.org/abs/2404.00456 for example.""",
Expand Down
61 changes: 56 additions & 5 deletions modelopt/torch/quantization/nn/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,29 @@ def backward(ctx, grad_outputs):
return fast_hadamard_transform.hadamard_transform(grad_outputs) # type: ignore[name-defined]


def normalized_hadamard_transform(inputs, rotate_fp32=False):
"""Normalized fast hadamard transform."""
def _largest_pow2_divisor(n: int) -> int:
"""Return the largest power of 2 that divides n."""
return n & (-n)


def normalized_hadamard_transform(inputs, rotate_fp32=False, block_size=None):
    """Normalized fast hadamard transform.

    Supports block-granular RHT for dimensions that are not a power of 2.
    When block_size is used, the last dimension is split into blocks of size block_size
    (must be power of 2), and Hadamard is applied per block. This enables RHT for
    MoE expert channel dimensions (e.g. 1920, 1536, 896) that are not powers of 2.

    Args:
        inputs: Input tensor, Hadamard is applied along the last dimension.
        rotate_fp32: If True, compute rotation in float32.
        block_size: Block size for block-granular RHT. Must be power of 2 and divide
            inputs.shape[-1]. If None: use full-dimension FHT when dim is power of 2;
            otherwise auto-select the largest power-of-2 divisor of the dimension.

    Returns:
        Rotated tensor with same shape as inputs.
    """
    # Lazy import of the optional CUDA extension; `global` makes the module
    # visible to FastHadamardTransform.backward, which references it by name.
    global fast_hadamard_transform
    try:
        import fast_hadamard_transform
Expand All @@ -104,10 +125,40 @@ def normalized_hadamard_transform(inputs, rotate_fp32=False):
        # NOTE(review): the `except ImportError` clause is collapsed in this diff view;
        # only the tail of the raised install-hint message is visible below.
        "`pip install git+https://github.com/Dao-AILab/fast-hadamard-transform.git`"
        )

    dim = inputs.shape[-1]
    # Remember the original dtype so an fp32 rotation can be cast back at the end.
    dtype = inputs.dtype
    if rotate_fp32:
        inputs = inputs.to(torch.float32)
    # NOTE(review): the next three lines are the pre-change (removed) code shown by
    # the diff view; the merged implementation is the if/else block that follows.
    outputs = FastHadamardTransform.apply(inputs) / torch.sqrt(
        torch.tensor(inputs.shape[-1], dtype=torch.float32)
    )

    if block_size is None and utils.is_pow2(dim):
        # Full-dimension FHT (original behavior)
        outputs = FastHadamardTransform.apply(inputs) / torch.sqrt(
            torch.tensor(dim, dtype=torch.float32)
        )
    else:
        # Block-granular RHT
        if block_size is None:
            # Auto-select: largest power-of-2 factor of the last dimension.
            block_size = _largest_pow2_divisor(dim)
            if block_size < 2:
                # Odd dimension: no usable power-of-2 factor, so auto-selection fails.
                raise RuntimeError(
                    f"Block RHT: dimension {dim} has no power-of-2 divisor >= 2. "
                    "Set rotate.block_size explicitly (e.g. 128) or use a dimension divisible by a power of 2."
                )
        if not utils.is_pow2(block_size):
            raise ValueError(f"Block RHT: block_size must be power of 2, got {block_size}.")
        if dim % block_size != 0:
            raise RuntimeError(
                f"Block RHT: inputs.shape[-1]={dim} is not divisible by block_size={block_size}. "
                f"Use a block_size that divides {dim} (e.g. {_largest_pow2_divisor(dim)})."
            )
        n_blocks = dim // block_size
        # Reshape to (..., n_blocks, block_size)
        flat = inputs.reshape(-1, dim)
        blocks = flat.reshape(-1, n_blocks, block_size)
        # Apply FHT per block (last dim); normalize by sqrt(block_size) per block.
        rotated = FastHadamardTransform.apply(blocks) / torch.sqrt(
            torch.tensor(block_size, dtype=torch.float32)
        )
        outputs = rotated.reshape(inputs.shape)

    return outputs.to(dtype) if rotate_fp32 else outputs
35 changes: 27 additions & 8 deletions modelopt/torch/quantization/nn/modules/tensor_quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@

from ... import calib
from ... import utils as quant_utils
from ...config import QuantizerAttributeConfig
from ...config import QuantizerAttributeConfig, RotateConfig
from ...qtensor import (
BaseQuantizedTensor,
FP8QTensor,
Expand Down Expand Up @@ -532,16 +532,29 @@ def is_static_block_quant(self):
@property
def rotate_is_enabled(self):
    """Check if rotate is enabled in quant config."""
    cfg = self._rotate
    if isinstance(cfg, RotateConfig):
        return cfg.enable
    # Backward compat: old checkpoints stored a plain dict.
    if isinstance(cfg, dict):
        return cfg.get("enable", False)
    # Otherwise a plain bool.
    return cfg

@property
def rotate_is_fp32(self):
    """Check if rotation needs to be computed in float32."""
    cfg = self._rotate
    if isinstance(cfg, RotateConfig):
        # fp32 rotation is only meaningful when rotation itself is enabled.
        return cfg.enable and cfg.rotate_fp32
    # Backward compat: old checkpoints stored a plain dict.
    if isinstance(cfg, dict) and self.rotate_is_enabled:
        return cfg.get("rotate_fp32", False)
    return False

@property
def rotate_block_size(self):
    """Block size for block-granular RHT, or None for full/auto."""
    cfg = self._rotate
    if isinstance(cfg, RotateConfig):
        return cfg.block_size if cfg.enable else None
    # Backward compat: old checkpoints stored a plain dict.
    if isinstance(cfg, dict) and self.rotate_is_enabled:
        return cfg.get("block_size")
    return None

def disable_calib(self):
"""Disable calibration."""
Expand Down Expand Up @@ -1011,7 +1024,11 @@ def forward(self, inputs):

# Rotating the input
if self.rotate_is_enabled:
inputs = normalized_hadamard_transform(inputs, rotate_fp32=self.rotate_is_fp32)
inputs = normalized_hadamard_transform(
inputs,
rotate_fp32=self.rotate_is_fp32,
block_size=self.rotate_block_size,
)

if self._disabled:
# if quantizer is disabled, we still need to track the input dtype for saving the model
Expand Down Expand Up @@ -1125,6 +1142,8 @@ def extra_repr(self):
)
s += " rotated" if self.rotate_is_enabled else ""
s += " (fp32)" if self.rotate_is_fp32 else ""
if self.rotate_block_size is not None:
s += f" (block={self.rotate_block_size})"
s += (
f" calibrator={self._calibrator.__class__.__name__}"
if (self._calibrator is not None)
Expand Down
15 changes: 15 additions & 0 deletions tests/gpu/torch/quantization/test_hadamard.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,21 @@ def test_hadamard_transform(dim):
assert torch.allclose(xxt_h_fp32, xxt, atol=0.05)


@pytest.mark.parametrize(
    ("dim", "block_size"),
    [(1920, 128), (1536, 128), (1920, None), (64, 32)],
)
def test_hadamard_transform_block(dim, block_size):
    """Block-granular RHT for non-power-of-2 dimensions (e.g. MoE expert channels)."""
    x = torch.rand(4, dim, device="cuda")
    gram = x @ x.T
    rotated = normalized_hadamard_transform(x, block_size=block_size)
    gram_rotated = rotated @ rotated.T
    # The rotation is orthogonal, so the Gram matrix must be preserved.
    # A relative tolerance is used because float32 accumulated error scales with
    # value magnitude, which grows with dim; 1e-3 rtol suits float32 block RHT.
    assert torch.allclose(gram_rotated, gram, rtol=1e-3, atol=1e-6)


@pytest.mark.parametrize(
"rotate_fp32",
[True, False],
Expand Down
Loading