5 changes: 5 additions & 0 deletions CHANGELOG.rst
@@ -43,9 +43,14 @@ Changelog
- Add Nemotron-3-Super-120B-A12B PTQ recipes ``modelopt_recipes/models/Nemotron-3-Super-120B-A12B/super-nvfp4.yaml`` (MSE-mixed) and ``super-nvfp4-max-calib.yaml`` (max-calib mixed): NVFP4 W4A4 routed experts + FP8 per-tensor shared experts / Mamba in/out_proj + FP8 KV cache.
- Add quantized ``nn.Embedding`` support. ``nn.Embedding`` is now registered in ``QuantModuleRegistry`` and exposes ``weight_quantizer`` (embedding table), ``output_quantizer`` (lookup activations), and a permanently disabled ``input_quantizer`` placeholder — embedding inputs are integer indices and cannot be fake-quantized, so direct ``enable*()`` calls raise. ``export_hf_checkpoint`` packs quantized embedding weights alongside Linear layers. Embedding quantizers are opt-in (``parent_class: nn.Embedding`` disabled by default).
- Add post-training quantization (PTQ) example for the Megatron-Bridge framework: ``examples/megatron_bridge/quantize.py`` calibrates an HF model (via ``--quant_cfg`` alias / full config name or a ``--recipe`` YAML, with optional KV-cache quant, weight-only, compression, and MoE expert-ratio calibration) and saves a Megatron checkpoint (tensor / pipeline / expert parallelism supported), and ``examples/megatron_bridge/export.py`` converts that checkpoint to a deployable HuggingFace (unified) checkpoint for TensorRT-LLM / vLLM / SGLang. See `examples/megatron_bridge/README.md <https://github.com/NVIDIA/Model-Optimizer/tree/main/examples/megatron_bridge>`_ for details.
- Add ``mtsa.config.SKIP_SOFTMAX_TRITON_CALIB`` for skip-softmax attention-sparsity calibration through the fused Triton ``attention_calibrate`` kernel (HF ``modelopt_triton`` backend), measuring multi-threshold tile-skip statistics the way the Triton inference kernel actually skips tiles for both prefill and decode. Exposed as ``--sparse_attn_cfg skip_softmax_triton_calib`` in ``examples/llm_sparsity/attention_sparsity/hf_sa.py`` (with a new ``--calib_data_dir`` flag for RULER calibration data).
- Add skip-softmax calibration *through the vLLM integration*. The Triton ``attention_calibrate`` kernel now supports vLLM's paged KV cache, and ``ModelOptSparseAttentionImpl`` gains a calibration mode that measures multi-threshold tile-skip statistics over the paged cache (prefill and decode) while still returning dense attention. ``examples/vllm_serve/calibrate_sparse_attn.py`` (with ``sparse_attn_calib_worker.py``) drives calibration over prompts via ``LLM.generate``, fits the exponential ``(a, b)`` model, and writes the same ``sparse_attention_config`` block ``hf_sa.py`` produces so ``vllm_serve_sparse_attn.py`` serves it unchanged.
- Support the **FlashInfer** attention backend for skip-softmax calibration **and serving** (in addition to FlashAttention). A backend-agnostic ``_SparseCalibrationMixin`` shares both the per-request calibration measurement and the sparse-inference path; ``ModelOptSparseFlashInferImpl`` reads FlashInfer's ``[num_blocks, 2, page_size, ...]`` paged cache, and ``patch_flashinfer_metadata_builder`` exposes the dense paged metadata FlashInfer otherwise keeps only inside its planned wrappers. Both the calibration worker and the serving worker (``examples/vllm_serve/sparse_attn_worker.py``) auto-select the matching sparse impl per attention layer via ``select_sparse_impl_cls``; pass ``--attention_backend FLASHINFER`` to force it (FlashInfer needs a supported ``head_size``).

**Bug Fixes**

- Fix the PyTorch ``flash_skip_softmax`` skip-softmax calibration to exclude padded query rows when the sequence length is not a multiple of the block size. ``_reshape_to_blocks`` pads the last query-block-row with ``dtype.min``, so a fully padded row had ``block_diff == 0`` and always voted "keep" in the block reduction. As a result, the last partial block-row was never skipped, which under-counted sparsity by up to one block-row (~0.1 absolute for long prompts) and skewed the fitted ``(a, b)``. This made HF (PyTorch) calibration disagree with the Triton/vLLM kernel on real models; the two now match to <0.01 at any sequence length. The cross-validation tests now run at non-multiple-of-128 lengths to guard against the regression.
- vLLM skip-softmax calibration now averages per-threshold sparsity across layers per sample (matching the HF ``DynamicThresholdCalibrator`` aggregation) instead of pooling each ``(layer, sample)`` independently, so vLLM- and HF-calibrated ``(a, b)`` agree.
- In Megatron-Core, only sync amax across expert-parallel (EP) ranks for routed expert weights when ``sync_expert_weight_amax=True``. Previously, routed expert weights were synced across EP ranks even when ``sync_expert_weight_amax`` was ``False``.
- Fix Megatron-Core HF importer to load fused ``TELayerNormColumnParallelLinear.layer_norm_weight`` from HF for GPT-family models (Qwen3 etc.) under ``--export-default-te-spec``. Importer now prefers per-context keys ``fused_input_layernorm`` / ``fused_pre_mlp_layernorm`` (fallback ``fused_norm`` for Nemotron-H backward compatibility); ``mcore_qwen.py`` provides the new rules. Without this fix, post-prune MMLU sat at chance.
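
The padded-row effect described in the ``flash_skip_softmax`` bug-fix entry can be reproduced with a toy model. The sketch below is not ModelOpt code: `tile_skip_fraction`, `NEG`, the "all rows must vote skip" reduction, and the use of a whole-row maximum (instead of a running maximum) are simplifying assumptions chosen only to show why a padded row with `block_diff == 0` blocks the skip.

```python
import math

NEG = -1e30  # stand-in for dtype.min used as the padding value (assumption)


def tile_skip_fraction(scores, block, threshold, mask_padded_rows):
    """Fraction of (query-block, key-block) tiles that every row votes to skip."""
    n_q, n_k = len(scores), len(scores[0])
    pad = (-n_q) % block  # rows needed to reach a multiple of the block size
    rows = [list(r) for r in scores] + [[NEG] * n_k for _ in range(pad)]
    valid = [True] * n_q + [False] * pad
    log_thr = math.log(threshold)

    skipped = total = 0
    for q0 in range(0, len(rows), block):
        for k0 in range(0, n_k, block):
            votes = []
            for r in range(q0, q0 + block):
                if mask_padded_rows and not valid[r]:
                    continue  # the fix: padded rows do not vote
                block_diff = max(rows[r][k0:k0 + block]) - max(rows[r])
                votes.append(block_diff < log_thr)  # True means "skip"
            total += 1
            skipped += all(votes)
    return skipped / total


# 3 query rows with block size 2, so the last block-row holds one padded row.
scores = [[0.0, 0.0, -20.0, -20.0] for _ in range(3)]
before_fix = tile_skip_fraction(scores, block=2, threshold=1e-3, mask_padded_rows=False)
after_fix = tile_skip_fraction(scores, block=2, threshold=1e-3, mask_padded_rows=True)
```

In this toy case the padded row has `NEG - NEG == 0`, which is never below `log(threshold)`, so the second key tile of the last block-row is kept unless padded rows are excluded from the vote.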

19 changes: 19 additions & 0 deletions examples/llm_sparsity/attention_sparsity/hf_sa.py
@@ -31,6 +31,7 @@
from modelopt.torch.sparsity.attention_sparsity.config import (
    SKIP_SOFTMAX_CALIB,
    SKIP_SOFTMAX_CALIB_SPARSE24,
    SKIP_SOFTMAX_TRITON_CALIB,
    SPARSE_SOFTMAX_DEFAULT,
)
from modelopt.torch.utils.memory_monitor import launch_memory_monitor
@@ -44,6 +45,7 @@
SPARSE_ATTN_CFG_CHOICES = {
    "skip_softmax_calib": SKIP_SOFTMAX_CALIB,
    "skip_softmax_calib_sparse24": SKIP_SOFTMAX_CALIB_SPARSE24,
    "skip_softmax_triton_calib": SKIP_SOFTMAX_TRITON_CALIB,
    "sparse_softmax": SPARSE_SOFTMAX_DEFAULT,
}

@@ -186,6 +188,15 @@ def main(args):
    calib["max_seqlen"] = args.calib_max_seqlen
if args.calib_chunk_size is not None:
    calib["chunk_size"] = args.calib_chunk_size
# Point RULER calibration at the data downloaded by download_ruler_data.sh
# (next to this script) unless the user overrides it. The NIAH essay
# haystack requires this directory.
calib.setdefault(
    "data_dir",
    args.calib_data_dir
    if args.calib_data_dir is not None
    else str(Path(__file__).parent / "data"),
)

model = mtsa.sparsify(model, config=sparse_config)
print("Sparse attention applied successfully!")
@@ -302,6 +313,14 @@ def main(args):
    default=None,
    help="Chunk size for calibration prefill. Overrides config value.",
)
parser.add_argument(
    "--calib_data_dir",
    type=str,
    default=None,
    help="Path to RULER calibration data (contains an 'essays' subdir). "
    "Defaults to the 'data' directory next to this script "
    "(populated by download_ruler_data.sh).",
)

args = parser.parse_args()
main(args)
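
The precedence implied by the `calib.setdefault("data_dir", ...)` call in the hunk above can be checked in isolation. This is a minimal sketch, not part of `hf_sa.py`: `resolve_calib_data_dir` and the example paths are hypothetical, and `PurePosixPath` replaces `Path(__file__)` only so the result is the same on every platform.

```python
from pathlib import PurePosixPath


def resolve_calib_data_dir(calib, cli_value, script_file):
    """Fill calib['data_dir'] the same way the setdefault call does."""
    default_dir = str(PurePosixPath(script_file).parent / "data")
    # setdefault only writes when the key is absent from the calibration config.
    calib.setdefault("data_dir", cli_value if cli_value is not None else default_dir)
    return calib["data_dir"]


script = "/repo/attention_sparsity/hf_sa.py"  # hypothetical location

# Neither the config nor the flag sets a directory: script-adjacent default.
from_default = resolve_calib_data_dir({}, None, script)
# Only the flag is given.
from_flag = resolve_calib_data_dir({}, "/mnt/ruler", script)
# The config already carries data_dir: setdefault leaves it untouched.
from_config = resolve_calib_data_dir({"data_dir": "/cfg/ruler"}, "/mnt/ruler", script)
```

Because `setdefault` never overwrites an existing key, a `data_dir` already present in the calibration config would take precedence over `--calib_data_dir`, unlike the other `--calib_*` flags whose help text says they override the config value.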
20 changes: 18 additions & 2 deletions examples/vllm_serve/README.md
@@ -101,9 +101,9 @@ QUANT_CFG=<quant_cfg> QUANT_FILE_PATH=<quantizer_state.pth> python vllm_serve_fa

## Serve a model with sparse attention in vLLM

-Apply ModelOpt sparse attention at serve time. The launcher replaces vLLM's `FlashAttentionImpl` with `ModelOptSparseAttentionImpl` (Triton kernel with paged KV cache support) on every attention layer right after model load.
+Apply ModelOpt sparse attention at serve time. The launcher swaps the ModelOpt sparse impl (Triton kernel with paged KV cache support) onto every attention layer right after model load: `ModelOptSparseAttentionImpl` for the **FlashAttention** backend and `ModelOptSparseFlashInferImpl` for the **FlashInfer** backend (auto-selected per layer; pass `--attention-backend FLASHINFER` to force FlashInfer, which needs a supported `head_size`).

-The configuration is read from the checkpoint's `config.json` `sparse_attention_config` block, written by ModelOpt's HF export. The launcher restores calibrated skip-softmax metadata and N:M sparse-softmax metadata (`sparsity_n`, `sparsity_m`, `dense_sink_tokens`, `dense_recent_tokens`). Checkpoints exported with both metadata entries use ModelOpt Triton for sparse prefill launches; decode-only launches and launches without active sparse work delegate back to vLLM FlashAttention.
+The configuration is read from the checkpoint's `config.json` `sparse_attention_config` block, written by ModelOpt's HF export. The launcher restores calibrated skip-softmax metadata and N:M sparse-softmax metadata (`sparsity_n`, `sparsity_m`, `dense_sink_tokens`, `dense_recent_tokens`). Checkpoints exported with both metadata entries use ModelOpt Triton for sparse prefill launches; decode-only launches and launches without active sparse work delegate back to the native backend.

Workflow:

@@ -121,6 +121,22 @@ Limitations:
- vLLM V1 chunked prefill and prefix-cache suffix attention are supported by offsetting query positions into the longer KV span.
- CUDA graph capture is not validated yet — use `--enforce-eager`.

### Calibrate skip-softmax thresholds in vLLM

Step 1 above (calibrating with `hf_sa.py`) runs in HuggingFace. You can also calibrate the skip-softmax threshold **directly through vLLM**, using the same Triton calibration kernel over vLLM's paged KV cache. `calibrate_sparse_attn.py` force-swaps `ModelOptSparseAttentionImpl` (in calibration mode) onto every attention layer, runs your prompts through `LLM.generate`, fits the exponential model `scale_factor = a * exp(b * sparsity)` for both prefill and decode, and writes the resulting `sparse_attention_config` — the same block `hf_sa.py` produces:

```bash
python calibrate_sparse_attn.py <CKPT> \
    --prompts_file prompts.txt \
    --target_sparse_ratio 0.5 \
    --decode_tokens 32 \
    --update_checkpoint_config
```

`--prompts_file` points to a text file with one prompt per line (longer, varied-length prompts give a better fit). `--update_checkpoint_config` merges the fitted config into `<CKPT>/config.json` in place; without it, the config is only dumped to `sparse_attention_config.json` for inspection. The calibration kernel computes full (dense) attention while it measures, so generated tokens are unaffected and only tile-skip statistics are recorded. Afterward, serve the checkpoint with `vllm_serve_sparse_attn.py` as above.
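
The two output modes can be pictured with a small sketch. It is an assumption-laden illustration, not the script's implementation: `write_sparse_attention_config`, the `dump_dir` parameter, and the `{"prefill": {"a": ..., "b": ...}}` layout are hypothetical, and "merge" is modeled here as replacing only the `sparse_attention_config` key while leaving the rest of `config.json` untouched.

```python
import json
import tempfile
from pathlib import Path


def write_sparse_attention_config(ckpt_dir, fitted, update_checkpoint_config, dump_dir="."):
    """Write the fitted block into config.json, or dump it to a side file."""
    if update_checkpoint_config:
        target = Path(ckpt_dir) / "config.json"
        config = json.loads(target.read_text())
        config["sparse_attention_config"] = fitted  # all other keys are preserved
        target.write_text(json.dumps(config, indent=2) + "\n")
    else:
        target = Path(dump_dir) / "sparse_attention_config.json"
        target.write_text(json.dumps({"sparse_attention_config": fitted}, indent=2) + "\n")
    return target


with tempfile.TemporaryDirectory() as tmp:
    ckpt = Path(tmp) / "ckpt"
    ckpt.mkdir()
    (ckpt / "config.json").write_text(json.dumps({"model_type": "llama"}))

    fitted = {"prefill": {"a": 1.0, "b": 2.0}}  # hypothetical block layout

    write_sparse_attention_config(ckpt, fitted, update_checkpoint_config=True)
    merged = json.loads((ckpt / "config.json").read_text())

    dumped = write_sparse_attention_config(ckpt, fitted, False, dump_dir=tmp)
    dumped_name = dumped.name
```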

Both the **FlashAttention** and **FlashInfer** backends are supported; the worker auto-selects the matching impl per attention layer (and prints the active impl, e.g. `{'ModelOptSparseFlashInferImpl': N}`, so you can confirm the backend in use). Models that default to FlashInfer (e.g. NemotronH) need no override; to force it on others, pass `--attention_backend FLASHINFER` (this vLLM version takes the backend via the engine arg, **not** a `VLLM_ATTENTION_BACKEND` env var). FlashInfer requires a supported `head_size` (64/128/...); unsupported sizes either fall back or raise an error at load. The fitted `(a, b)` are backend-independent (they measure attention scores at a fixed 128×128 tile granularity), so a checkpoint calibrated under one backend serves correctly under the other.
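
To make the exponential model `scale_factor = a * exp(b * sparsity)` concrete, here is a minimal sketch of one way to fit `(a, b)`. It uses ordinary least squares on `log(scale_factor)`, which is a stand-in technique; the fitting procedure actually used by `calibrate_sparse_attn.py` is not shown in this change and may differ. The function names and the synthetic data points are hypothetical.

```python
import math


def fit_exponential(sparsities, scale_factors):
    """Fit scale_factor = a * exp(b * sparsity) by linear regression in log space."""
    ys = [math.log(s) for s in scale_factors]
    n = len(sparsities)
    mean_x = sum(sparsities) / n
    mean_y = sum(ys) / n
    sxx = sum((x - mean_x) ** 2 for x in sparsities)
    sxy = sum((x - mean_x) * (y - mean_y) for x, y in zip(sparsities, ys))
    b = sxy / sxx                      # slope of log(scale_factor) vs sparsity
    a = math.exp(mean_y - b * mean_x)  # intercept, mapped back from log space
    return a, b


def scale_factor_for(a, b, target_sparse_ratio):
    """Evaluate the fitted model at a target sparsity, e.g. 0.5."""
    return a * math.exp(b * target_sparse_ratio)


# Synthetic measurements generated from a = 2, b = 3 (not real calibration data).
xs = [0.1, 0.3, 0.5, 0.7, 0.9]
ys = [2.0 * math.exp(3.0 * x) for x in xs]
a, b = fit_exponential(xs, ys)
sf = scale_factor_for(a, b, 0.5)
```

A log-space fit weights relative error rather than absolute error, which is one reason a direct nonlinear fit could produce slightly different `(a, b)` on noisy measurements.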

## Known Problems

1. **MCore reload does not use `MODELOPT_STATE_PATH`**; use `QUANT_FILE_PATH` and make sure `QUANT_CFG` matches the quantization recipe used for the original MCore model (otherwise quantizer keys/config won’t align).