diff --git a/CHANGELOG.rst b/CHANGELOG.rst index b6a3a979dd5..92de38fe2f2 100755 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -43,9 +43,14 @@ Changelog - Add Nemotron-3-Super-120B-A12B PTQ recipes ``modelopt_recipes/models/Nemotron-3-Super-120B-A12B/super-nvfp4.yaml`` (MSE-mixed) and ``super-nvfp4-max-calib.yaml`` (max-calib mixed): NVFP4 W4A4 routed experts + FP8 per-tensor shared experts / Mamba in/out_proj + FP8 KV cache. - Add quantized ``nn.Embedding`` support. ``nn.Embedding`` is now registered in ``QuantModuleRegistry`` and exposes ``weight_quantizer`` (embedding table), ``output_quantizer`` (lookup activations), and a permanently disabled ``input_quantizer`` placeholder — embedding inputs are integer indices and cannot be fake-quantized, so direct ``enable*()`` calls raise. ``export_hf_checkpoint`` packs quantized embedding weights alongside Linear layers. Embedding quantizers are opt-in (``parent_class: nn.Embedding`` disabled by default). - Add post-training quantization (PTQ) example for the Megatron-Bridge framework: ``examples/megatron_bridge/quantize.py`` calibrates an HF model (via ``--quant_cfg`` alias / full config name or a ``--recipe`` YAML, with optional KV-cache quant, weight-only, compression, and MoE expert-ratio calibration) and saves a Megatron checkpoint (tensor / pipeline / expert parallelism supported), and ``examples/megatron_bridge/export.py`` converts that checkpoint to a deployable HuggingFace (unified) checkpoint for TensorRT-LLM / vLLM / SGLang. See `examples/megatron_bridge/README.md `_ for details. +- Add ``mtsa.config.SKIP_SOFTMAX_TRITON_CALIB`` for skip-softmax attention-sparsity calibration through the fused Triton ``attention_calibrate`` kernel (HF ``modelopt_triton`` backend), measuring multi-threshold tile-skip statistics the way the Triton inference kernel actually skips tiles for both prefill and decode. Exposed as ``--sparse_attn_cfg skip_softmax_triton_calib`` in ``examples/llm_sparsity/attention_sparsity/hf_sa.py`` (with a new ``--calib_data_dir`` flag for RULER calibration data). +- Add skip-softmax calibration *through the vLLM integration*. The Triton ``attention_calibrate`` kernel now supports vLLM's paged KV cache, and ``ModelOptSparseAttentionImpl`` gains a calibration mode that measures multi-threshold tile-skip statistics over the paged cache (prefill and decode) while still returning dense attention. ``examples/vllm_serve/calibrate_sparse_attn.py`` (with ``sparse_attn_calib_worker.py``) drives calibration over prompts via ``LLM.generate``, fits the exponential ``(a, b)`` model, and writes the same ``sparse_attention_config`` block ``hf_sa.py`` produces so ``vllm_serve_sparse_attn.py`` serves it unchanged. +- Support the **FlashInfer** attention backend for skip-softmax calibration **and serving** (in addition to FlashAttention). A backend-agnostic ``_SparseCalibrationMixin`` shares both the per-request calibration measurement and the sparse-inference path; ``ModelOptSparseFlashInferImpl`` reads FlashInfer's ``[num_blocks, 2, page_size, ...]`` paged cache, and ``patch_flashinfer_metadata_builder`` exposes the dense paged metadata FlashInfer otherwise keeps only inside its planned wrappers. Both the calibration worker and the serving worker (``examples/vllm_serve/sparse_attn_worker.py``) auto-select the matching sparse impl per attention layer via ``select_sparse_impl_cls``; pass ``--attention_backend FLASHINFER`` to force it (FlashInfer needs a supported ``head_size``). **Bug Fixes** +- Fix the PyTorch ``flash_skip_softmax`` skip-softmax calibration to exclude padded query rows when the sequence length is not a multiple of the block size. ``_reshape_to_blocks`` pads the last query-block-row with ``dtype.min``, so a fully-padded row had ``block_diff == 0`` and always voted "keep" in the block reduction — the last partial block-row was never skipped, under-counting sparsity by up to one block-row (~0.1 absolute for long prompts) and skewing the fitted ``(a, b)``. This made HF (PyTorch) calibration disagree with the Triton/vLLM kernel on real models; the two now match to <0.01 at any sequence length. The cross-validation tests now run at non-multiple-of-128 lengths to guard the regression. +- vLLM skip-softmax calibration now averages per-threshold sparsity across layers per sample (matching the HF ``DynamicThresholdCalibrator`` aggregation) instead of pooling each ``(layer, sample)`` independently, so vLLM- and HF-calibrated ``(a, b)`` agree. - In Megatron-Core only do EP amax sync for routed expert weights if ``sync_expert_weight_amax=True``. Previously EP amax sync would sync routed expert weights across EP ranks even when ``sync_expert_weight_amax`` was False. - Fix Megatron-Core HF importer to load fused ``TELayerNormColumnParallelLinear.layer_norm_weight`` from HF for GPT-family models (Qwen3 etc.) under ``--export-default-te-spec``. Importer now prefers per-context keys ``fused_input_layernorm`` / ``fused_pre_mlp_layernorm`` (fallback ``fused_norm`` for Nemotron-H backward compatibility); ``mcore_qwen.py`` provides the new rules. Without this fix, post-prune MMLU sat at chance. diff --git a/examples/llm_sparsity/attention_sparsity/hf_sa.py b/examples/llm_sparsity/attention_sparsity/hf_sa.py index 5eae54ba6ee..1eacc6f18b3 100644 --- a/examples/llm_sparsity/attention_sparsity/hf_sa.py +++ b/examples/llm_sparsity/attention_sparsity/hf_sa.py @@ -31,6 +31,7 @@ from modelopt.torch.sparsity.attention_sparsity.config import ( SKIP_SOFTMAX_CALIB, SKIP_SOFTMAX_CALIB_SPARSE24, + SKIP_SOFTMAX_TRITON_CALIB, SPARSE_SOFTMAX_DEFAULT, ) from modelopt.torch.utils.memory_monitor import launch_memory_monitor @@ -44,6 +45,7 @@ SPARSE_ATTN_CFG_CHOICES = { "skip_softmax_calib": SKIP_SOFTMAX_CALIB, "skip_softmax_calib_sparse24": SKIP_SOFTMAX_CALIB_SPARSE24, + "skip_softmax_triton_calib": SKIP_SOFTMAX_TRITON_CALIB, "sparse_softmax": SPARSE_SOFTMAX_DEFAULT, } @@ -186,6 +188,15 @@ def main(args): calib["max_seqlen"] = args.calib_max_seqlen if args.calib_chunk_size is not None: calib["chunk_size"] = args.calib_chunk_size + # Point RULER calibration at the data downloaded by download_ruler_data.sh + # (next to this script) unless the user overrides it. The NIAH essay + # haystack requires this directory. + calib.setdefault( + "data_dir", + args.calib_data_dir + if args.calib_data_dir is not None + else str(Path(__file__).parent / "data"), + ) model = mtsa.sparsify(model, config=sparse_config) print("Sparse attention applied successfully!") @@ -302,6 +313,14 @@ def main(args): default=None, help="Chunk size for calibration prefill. Overrides config value.", ) + parser.add_argument( + "--calib_data_dir", + type=str, + default=None, + help="Path to RULER calibration data (contains an 'essays' subdir). " + "Defaults to the 'data' directory next to this script " + "(populated by download_ruler_data.sh).", + ) args = parser.parse_args() main(args) diff --git a/examples/vllm_serve/README.md b/examples/vllm_serve/README.md index 4e23a56a288..96eb522ca89 100644 --- a/examples/vllm_serve/README.md +++ b/examples/vllm_serve/README.md @@ -101,9 +101,9 @@ QUANT_CFG= QUANT_FILE_PATH= python vllm_serve_fa ## Serve a model with sparse attention in vLLM -Apply ModelOpt sparse attention at serve time. The launcher replaces vLLM's `FlashAttentionImpl` with `ModelOptSparseAttentionImpl` (Triton kernel with paged KV cache support) on every attention layer right after model load. +Apply ModelOpt sparse attention at serve time. The launcher swaps the ModelOpt sparse impl (Triton kernel with paged KV cache support) onto every attention layer right after model load — `ModelOptSparseAttentionImpl` for the **FlashAttention** backend and `ModelOptSparseFlashInferImpl` for the **FlashInfer** backend (auto-selected per layer; pass `--attention-backend FLASHINFER` to force FlashInfer, which needs a supported `head_size`). -The configuration is read from the checkpoint's `config.json` `sparse_attention_config` block, written by ModelOpt's HF export. The launcher restores calibrated skip-softmax metadata and N:M sparse-softmax metadata (`sparsity_n`, `sparsity_m`, `dense_sink_tokens`, `dense_recent_tokens`). Checkpoints exported with both metadata entries use ModelOpt Triton for sparse prefill launches; decode-only launches and launches without active sparse work delegate back to vLLM FlashAttention. +The configuration is read from the checkpoint's `config.json` `sparse_attention_config` block, written by ModelOpt's HF export. The launcher restores calibrated skip-softmax metadata and N:M sparse-softmax metadata (`sparsity_n`, `sparsity_m`, `dense_sink_tokens`, `dense_recent_tokens`). Checkpoints exported with both metadata entries use ModelOpt Triton for sparse prefill launches; decode-only launches and launches without active sparse work delegate back to the native backend. Workflow: @@ -121,6 +121,22 @@ Limitations: - vLLM V1 chunked prefill and prefix-cache suffix attention are supported by offsetting query positions into the longer KV span. - CUDA graph capture is not validated yet — use `--enforce-eager`. +### Calibrate skip-softmax thresholds in vLLM + +Step 1 above (calibrating with `hf_sa.py`) runs in HuggingFace. You can also calibrate the skip-softmax threshold **directly through vLLM**, using the same Triton calibration kernel over vLLM's paged KV cache. `calibrate_sparse_attn.py` force-swaps `ModelOptSparseAttentionImpl` (in calibration mode) onto every attention layer, runs your prompts through `LLM.generate`, fits the exponential model `scale_factor = a * exp(b * sparsity)` for both prefill and decode, and writes the resulting `sparse_attention_config` — the same block `hf_sa.py` produces: + +```bash +python calibrate_sparse_attn.py \ + --prompts_file prompts.txt \ + --target_sparse_ratio 0.5 \ + --decode_tokens 32 \ + --update_checkpoint_config +``` + +`--prompts_file` is one prompt per line (longer, varied-length prompts give a better fit). `--update_checkpoint_config` merges the fitted config into `/config.json` in place; without it, the config is only dumped to `sparse_attention_config.json` for inspection. The calibration kernel computes full (dense) attention while it measures, so generated tokens are unaffected — only tile-skip statistics are recorded. Afterward, serve the checkpoint with `vllm_serve_sparse_attn.py` as above. + +Both the **FlashAttention** and **FlashInfer** backends are supported; the worker auto-selects the matching impl per attention layer (and prints the active impl, e.g. `{'ModelOptSparseFlashInferImpl': N}`, so you can confirm the backend in use). Models that default to FlashInfer (e.g. NemotronH) need no override; to force it on others, pass `--attention_backend FLASHINFER` (this vLLM version takes the backend via the engine arg, **not** a `VLLM_ATTENTION_BACKEND` env var). FlashInfer requires a supported `head_size` (64/128/...); unsupported sizes fall back / error at load. The fitted `(a, b)` are backend-independent (they measure attention scores at a fixed 128×128 tile granularity), so a checkpoint calibrated under one backend serves correctly under the other. + ## Known Problems 1. **MCore reload does not use `MODELOPT_STATE_PATH`**; use `QUANT_FILE_PATH` and make sure `QUANT_CFG` matches the quantization recipe used for the original MCore model (otherwise quantizer keys/config won’t align). diff --git a/examples/vllm_serve/calibrate_sparse_attn.py b/examples/vllm_serve/calibrate_sparse_attn.py new file mode 100644 index 00000000000..57c6020f6d8 --- /dev/null +++ b/examples/vllm_serve/calibrate_sparse_attn.py @@ -0,0 +1,219 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Calibrate skip-softmax thresholds *through vLLM* and write the serving config. + +Runs calibration prompts through a vLLM ``LLM`` whose attention layers use +``ModelOptSparseAttentionImpl`` in calibration mode (see +``sparse_attn_calib_worker.py``). The ModelOpt Triton calibration kernel +measures, per candidate threshold, how many KV tiles would be skipped — over +the paged KV cache, for both prefill and decode — then fits the exponential +model ``scale_factor = a * exp(b * sparsity)``. + +The fitted ``(a, b)`` per phase are written as a ``sparse_attention_config`` +block, the same one ``hf_sa.py`` produces, so ``vllm_serve_sparse_attn.py`` can +serve the calibrated model directly. + +Usage: + python calibrate_sparse_attn.py \ + --prompts_file prompts.txt \ + --target_sparse_ratio 0.5 \ + --decode_tokens 32 \ + --update_checkpoint_config + +``--prompts_file`` is one prompt per line; longer, varied-length prompts give a +better fit. With no file, a tiny built-in demo set is used (fine for a smoke +test, not for a real fit). +""" + +import argparse +import json +import os +import sys +from pathlib import Path + +_DEMO_PROMPTS = [ + "Summarize the history of computing in a few paragraphs. " * 40, + "Explain how attention works in transformer models. " * 60, + "Write a detailed essay about renewable energy sources. " * 80, +] + + +def _load_prompts(prompts_file: str | None) -> list[str]: + if prompts_file is None: + print( + "[ModelOpt] No --prompts_file given; using a tiny built-in demo set. " + "Pass real, varied-length prompts for a usable fit." + ) + return _DEMO_PROMPTS + lines = [ln.strip() for ln in Path(prompts_file).read_text().splitlines() if ln.strip()] + if not lines: + raise ValueError(f"No prompts found in {prompts_file}") + print(f"[ModelOpt] Loaded {len(lines)} calibration prompts from {prompts_file}") + return lines + + +def _write_config(ckpt: str, sparse_config: dict, update_checkpoint: bool) -> None: + """Dump the sparse_attention_config and optionally merge into config.json.""" + out_path = Path("sparse_attention_config.json") + out_path.write_text(json.dumps(sparse_config, indent=2)) + print(f"[ModelOpt] Wrote calibrated config to {out_path.resolve()}") + + if not update_checkpoint: + print( + "[ModelOpt] Re-run with --update_checkpoint_config to merge this into " + f"{ckpt}/config.json (required for vllm_serve_sparse_attn.py to pick it up)." + ) + return + + config_json = Path(ckpt) / "config.json" + config = json.loads(config_json.read_text()) + config["sparse_attention_config"] = sparse_config + config_json.write_text(json.dumps(config, indent=2)) + print(f"[ModelOpt] Merged sparse_attention_config into {config_json}") + + +def main(): + parser = argparse.ArgumentParser(description="Calibrate skip-softmax thresholds via vLLM") + parser.add_argument("model", type=str, help="Path to the HF checkpoint to calibrate") + parser.add_argument("--prompts_file", type=str, default=None, help="One prompt per line") + parser.add_argument( + "--target_sparse_ratio", + type=float, + default=0.5, + help="Target sparsity baked into the exported config (applied to both phases)", + ) + parser.add_argument( + "--decode_tokens", + type=int, + default=32, + help="Decode tokens to generate per prompt (drives decode-phase calibration)", + ) + parser.add_argument( + "--max_model_len", type=int, default=None, help="vLLM max_model_len override" + ) + parser.add_argument( + "--tensor_parallel_size", type=int, default=1, help="vLLM tensor-parallel size" + ) + parser.add_argument( + "--gpu_memory_utilization", + type=float, + default=None, + help="vLLM GPU memory utilization fraction", + ) + parser.add_argument( + "--trust_remote_code", + action="store_true", + help="Trust remote code for custom model classes (e.g. NemotronH)", + ) + parser.add_argument("--dtype", type=str, default=None, help="Model dtype, e.g. bfloat16") + parser.add_argument( + "--attention_backend", + type=str, + default=None, + help="Force the vLLM attention backend, e.g. FLASH_ATTN or FLASHINFER. " + "Default: let vLLM choose (the worker supports whichever of FlashAttention " + "/ FlashInfer is selected). FlashInfer needs a supported head_size (64/128/...).", + ) + parser.add_argument( + "--fit_logspace", + action="store_true", + help="Fit the exponential model in log space (wide scale_factor ranges)", + ) + parser.add_argument( + "--update_checkpoint_config", + action="store_true", + help="Merge the calibrated config into /config.json in place", + ) + args = parser.parse_args() + + # Workers run in separate processes and must import the calibration worker. + repo_root = str(Path(__file__).resolve().parent) + if repo_root not in sys.path: + sys.path.insert(0, repo_root) + current = os.environ.get("PYTHONPATH") + os.environ["PYTHONPATH"] = os.pathsep.join([current, repo_root]) if current else repo_root + + from vllm import LLM, SamplingParams + + prompts = _load_prompts(args.prompts_file) + + llm_kwargs = { + "model": args.model, + "worker_cls": "sparse_attn_calib_worker.SparseAttnCalibWorker", + # Calibration swaps the attention impl per layer; eager avoids CUDA-graph + # capture of the (now Triton) attention path. + "enforce_eager": True, + } + if args.max_model_len is not None: + llm_kwargs["max_model_len"] = args.max_model_len + if args.tensor_parallel_size and args.tensor_parallel_size > 1: + llm_kwargs["tensor_parallel_size"] = args.tensor_parallel_size + if args.gpu_memory_utilization is not None: + llm_kwargs["gpu_memory_utilization"] = args.gpu_memory_utilization + if args.trust_remote_code: + llm_kwargs["trust_remote_code"] = True + if args.dtype is not None: + llm_kwargs["dtype"] = args.dtype + # The calib worker auto-detects the per-layer backend (FlashAttention / + # FlashInfer) and swaps in the matching sparse impl. Pass --attention_backend + # only to force a choice: models that default to FlashInfer (e.g. NemotronH) + # need no override; others (e.g. Llama, which defaults to FlashAttention) + # need ``--attention_backend FLASHINFER`` to calibrate under FlashInfer. + # NOTE: this vLLM version takes the backend via this engine arg, not the + # (removed) VLLM_ATTENTION_BACKEND env var. + if args.attention_backend is not None: + llm_kwargs["attention_backend"] = args.attention_backend + llm = LLM(**llm_kwargs) + + n_layers = llm.collective_rpc("sparse_calib_enable")[0] + print(f"[ModelOpt] Calibration enabled on {n_layers} attention layers") + # Surface which sparse impl is active so the backend in use is verifiable + # (e.g. {'ModelOptSparseFlashInferImpl': N} confirms the FlashInfer path). + status = llm.collective_rpc("sparse_calib_status")[0] + print(f"[ModelOpt] Active sparse impls: {status['impl_types']}") + if n_layers == 0: + print( + "[ModelOpt] No layers were swapped — the model's attention backend is " + "unsupported. Try --attention_backend FLASH_ATTN or FLASHINFER." + ) + return + + # generate() drives prefill (prefill-phase stats) then decode_tokens decode + # steps (decode-phase stats). The calibration kernel computes full attention, + # so the generated text is unaffected — only tile-skip counts are recorded. + sampling = SamplingParams(temperature=0.0, max_tokens=args.decode_tokens) + llm.generate(prompts, sampling) + + sparse_config = llm.collective_rpc( + "sparse_calib_fit", + args=({"prefill": args.target_sparse_ratio, "decode": args.target_sparse_ratio},), + kwargs={"fit_logspace": args.fit_logspace}, + )[0] + + if sparse_config is None: + print( + "[ModelOpt] Calibration produced no valid fit — try more/longer prompts " + "so observed sparsity spans the (10%, 90%) fitting window." + ) + return + + print("[ModelOpt] Calibrated threshold_scale_factor:") + print(json.dumps(sparse_config["threshold_scale_factor"], indent=2)) + _write_config(args.model, sparse_config, args.update_checkpoint_config) + + +if __name__ == "__main__": + main() diff --git a/examples/vllm_serve/sparse_attn_calib_worker.py b/examples/vllm_serve/sparse_attn_calib_worker.py new file mode 100644 index 00000000000..b92fbc61bd4 --- /dev/null +++ b/examples/vllm_serve/sparse_attn_calib_worker.py @@ -0,0 +1,209 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Custom vLLM worker that calibrates skip-softmax thresholds in-engine. + +Unlike ``SparseAttnWorker`` (which reads an already-calibrated +``sparse_attention_config`` from the checkpoint and serves), this worker +*produces* that config. It force-swaps ``ModelOptSparseAttentionImpl`` onto +every attention layer and exposes RPC methods the driver +(``calibrate_sparse_attn.py``) calls via ``LLM.collective_rpc``: + +- ``sparse_calib_enable``: put every layer's impl in calibration mode. +- ``sparse_calib_fit``: stop measuring, fit the exponential ``(a, b)`` model + from the tile-skip stats collected during ``llm.generate``, and return an + export-format ``sparse_attention_config`` dict. + +Calibration uses the ModelOpt Triton calibration kernel through the paged KV +cache (see ``modelopt.torch.sparsity.attention_sparsity.plugins.vllm``), so the +numbers match the HF calibration path and drop straight into the existing +serving path. +""" + +import importlib +from typing import Any + +try: + _has_legacy_attention_layer = importlib.util.find_spec("vllm.attention.layer") is not None +except (ModuleNotFoundError, ValueError): + _has_legacy_attention_layer = False + +if _has_legacy_attention_layer: + from vllm.attention.layer import Attention as VLLMAttention +else: + from vllm.model_executor.layers.attention import Attention as VLLMAttention + +from vllm.v1.worker.gpu_worker import Worker as BaseWorker + +import modelopt +from modelopt.torch.sparsity.attention_sparsity.plugins.vllm import ( + _clone_sparse_impl, + disable_calibration, + enable_calibration, + fit_calibration, + iter_sparse_impls, + select_sparse_impl_cls, +) + +# Default threshold sweep — should span sparsities from ~10% to ~95%. +DEFAULT_THRESHOLD_TRIALS = [ + 1e-4, + 1e-3, + 5e-3, + 1e-2, + 3e-2, + 5e-2, + 1e-1, + 2e-1, + 3e-1, + 5e-1, + 7e-1, + 9e-1, +] + + +def _force_replace_attention_impls(worker) -> int: + """Swap the ModelOpt sparse impl onto every supported attention layer. + + Calibration has no checkpoint metadata to match against, so every attention + layer is converted unconditionally (with empty ``sparse_kw``; calibration + mode is toggled separately by ``sparse_calib_enable``). Supports the + FlashAttention and FlashInfer backends; other backends are left untouched. + """ + model = worker.model_runner.model + if hasattr(model, "unwrap"): + model = model.unwrap() + + patched = 0 + skipped_backends: set[str] = set() + for _, module in model.named_modules(): + if not isinstance(module, VLLMAttention): + continue + impl = module.impl + new_cls = select_sparse_impl_cls(impl) + if new_cls is None: + if type(impl).__name__ not in ( + "ModelOptSparseAttentionImpl", + "ModelOptSparseFlashInferImpl", + ): + skipped_backends.add(type(impl).__name__) + continue + try: + new_impl = _clone_sparse_impl(impl, new_cls) + except NotImplementedError: + # e.g. FlashAttention sinks — leave those layers on vLLM's impl. + skipped_backends.add(type(impl).__name__) + continue + new_impl.sparse_kw = {} + module.impl = new_impl + patched += 1 + print(f"[ModelOpt] Calibration: swapped impl on {patched} attention layers") + if skipped_backends: + print( + f"[ModelOpt] Calibration: left {sorted(skipped_backends)} layers unchanged " + "(unsupported backend — calibrate under FLASH_ATTN or FLASHINFER)." + ) + return patched + + +class SparseAttnCalibWorker(BaseWorker): + """vLLM worker that calibrates skip-softmax thresholds through the engine.""" + + def load_model(self, *args, **kwargs) -> None: + """Load the model, then force the sparse impl onto every attention layer.""" + super().load_model(*args, **kwargs) + _force_replace_attention_impls(self) + + # -- RPC methods (invoked via LLM.collective_rpc) ---------------------- + + def sparse_calib_status(self) -> dict[str, Any]: + """Report which sparse impls are active and how many records each holds. + + Lets the driver confirm calibration actually routes through the expected + backend (e.g. ``ModelOptSparseFlashInferImpl``) rather than a fallback. + """ + impls = list(iter_sparse_impls(self.model_runner.model)) + impl_types: dict[str, int] = {} + total_records = 0 + for impl in impls: + impl_types[type(impl).__name__] = impl_types.get(type(impl).__name__, 0) + 1 + total_records += len(getattr(impl, "_calib_records", [])) + return { + "num_sparse_layers": len(impls), + "impl_types": impl_types, + "calibrating": any(getattr(i, "_calibrate", False) for i in impls), + "total_records": total_records, + } + + def sparse_calib_enable(self, threshold_trials: list[float] | None = None) -> int: + """Enter calibration mode on all sparse impls; returns layer count.""" + trials = threshold_trials or DEFAULT_THRESHOLD_TRIALS + impls = list(iter_sparse_impls(self.model_runner.model)) + enable_calibration(impls, trials) + return len(impls) + + def sparse_calib_fit( + self, + target_sparse_ratio: dict[str, float] | float = 0.5, + threshold_trials: list[float] | None = None, + fit_logspace: bool = False, + ) -> dict[str, Any] | None: + """Stop measuring, fit ``(a, b)``, and return an export-format config. + + Returns ``None`` if no phase produced a valid fit (e.g. too little data). + """ + trials = threshold_trials or DEFAULT_THRESHOLD_TRIALS + impls = list(iter_sparse_impls(self.model_runner.model)) + disable_calibration(impls) + calibration_params = fit_calibration(impls, trials, fit_logspace=fit_logspace) + if not calibration_params: + return None + return _build_sparse_attention_config(calibration_params, target_sparse_ratio) + + +def _normalize_target(target_sparse_ratio: dict[str, float] | float) -> dict[str, float]: + if isinstance(target_sparse_ratio, (int, float)): + return {"prefill": float(target_sparse_ratio), "decode": float(target_sparse_ratio)} + return { + "prefill": float(target_sparse_ratio.get("prefill", 0.5)), + "decode": float(target_sparse_ratio.get("decode", 0.5)), + } + + +def _build_sparse_attention_config( + calibration_params: dict[str, dict[str, float]], + target_sparse_ratio: dict[str, float] | float, +) -> dict[str, Any]: + """Build the ``sparse_attention_config`` block consumed at serving time. + + Matches ``modelopt.torch.sparsity.attention_sparsity.conversion. + export_sparse_attention_config`` so ``load_from_checkpoint_metadata`` (the + serving path) recognizes it without changes. + """ + threshold_scale_factor: dict[str, Any] = {"formula": "a * exp(b * target_sparsity)"} + for phase in ("prefill", "decode"): + if phase in calibration_params: + threshold_scale_factor[phase] = { + "a": calibration_params[phase]["a"], + "b": calibration_params[phase]["b"], + } + return { + "config_groups": { + "group_0": {"sparse_algo": "softmax_skip", "targets": ["Attention"]}, + }, + "threshold_scale_factor": threshold_scale_factor, + "target_sparse_ratio": _normalize_target(target_sparse_ratio), + "producer": {"name": "modelopt", "version": modelopt.__version__}, + } diff --git a/examples/vllm_serve/sparse_attn_worker.py b/examples/vllm_serve/sparse_attn_worker.py index 1057baa870e..6a147b9749e 100644 --- a/examples/vllm_serve/sparse_attn_worker.py +++ b/examples/vllm_serve/sparse_attn_worker.py @@ -54,14 +54,16 @@ from modelopt.torch.sparsity.attention_sparsity.plugins.vllm import ( _build_sparse_kw, _clone_sparse_impl, + select_sparse_impl_cls, ) def _replace_attention_impl(worker): - """Replace FlashAttentionImpl with ModelOptSparseAttentionImpl on all Attention layers. + """Replace the attention impl with the ModelOpt sparse impl on all Attention layers. - The sole configuration source is the checkpoint's ``sparse_attention_config`` - metadata. No-op if the checkpoint has no such block. + Supports the FlashAttention and FlashInfer backends (the matching sparse impl + is selected per layer). The sole configuration source is the checkpoint's + ``sparse_attention_config`` metadata. No-op if the checkpoint has no such block. """ hf_config = getattr(worker.model_runner.model_config, "hf_config", None) detected = load_from_checkpoint_metadata(hf_config) @@ -81,6 +83,7 @@ def _replace_attention_impl(worker): model = model.unwrap() patched = 0 + skipped_backends: set[str] = set() for name, module in model.named_modules(): if not isinstance(module, VLLMAttention): continue @@ -94,11 +97,26 @@ def _replace_attention_impl(worker): # Keep vLLM's original impl when the exported layer config does not # enable any sparse feature. continue - new_impl = _clone_sparse_impl(module.impl) + new_cls = select_sparse_impl_cls(module.impl) + if new_cls is None: + # Unsupported backend (not FlashAttention / FlashInfer) — leave it on + # vLLM's native impl rather than mis-cloning into the wrong base. + skipped_backends.add(type(module.impl).__name__) + continue + try: + new_impl = _clone_sparse_impl(module.impl, new_cls) + except NotImplementedError: + skipped_backends.add(type(module.impl).__name__) + continue new_impl.sparse_kw = sparse_kw module.impl = new_impl patched += 1 print(f"[ModelOpt] Sparse attention: replaced impl on {patched} attention layers") + if skipped_backends: + print( + f"[ModelOpt] Sparse attention: left {sorted(skipped_backends)} layers unchanged " + "(unsupported backend — serve under FLASH_ATTN or FLASHINFER)." + ) # --------------------------------------------------------------------------- diff --git a/modelopt/torch/kernels/common/attention/hf_triton_attention.py b/modelopt/torch/kernels/common/attention/hf_triton_attention.py index 860c65d6621..10b77f60d1b 100644 --- a/modelopt/torch/kernels/common/attention/hf_triton_attention.py +++ b/modelopt/torch/kernels/common/attention/hf_triton_attention.py @@ -27,6 +27,10 @@ from modelopt.torch.kernels.common.attention.triton_fa import attention +# Skip-softmax calibration config and counters live on the module's +# ``_sparse_method_instance`` (HF passes the owning module to +# ``triton_attention_forward``), so no separate thread-local state is needed. + def _seq_lens_from_mask( attention_mask: torch.Tensor | None, @@ -105,9 +109,35 @@ def triton_attention_forward( kw["b_seq_len_k"] = torch.full((batch,), seq_k, device=device, dtype=torch.int32) kw["max_input_len_k"] = seq_k - # Sparse attention params + # Sparse-attention method instance. It carries the inference threshold and, + # during calibration, both the calibration config and the accumulated + # tile-skip counters. Available here because HF passes the owning module. method = getattr(module, "_sparse_method_instance", None) + # Calibration mode: run the calibration kernel, which computes full attention + # while counting, per candidate threshold, how many KV tiles would be skipped. + # The sparse-attention kwargs below are intentionally not added in this branch. + if method is not None and getattr(method, "_calibration_mode", False): + trials = getattr(method, "_threshold_trials", None) + # Deferred: the package __init__ imports this module, so importing + # attention_calibrate at module top would be circular. + from modelopt.torch.kernels.common.attention import attention_calibrate + + if trials and attention_calibrate is not None: + o, counters = attention_calibrate(q, k, v, **kw, threshold_trials=trials) + + # Accumulate counters across all attention calls in this forward pass. + # The method instance is per-module so the accumulator stays on one + # device, but guard the add against a device mismatch just in case. + prev = getattr(method, "_hf_calibration_counters", None) + method._hf_calibration_counters = ( + counters if prev is None else prev + counters.to(prev.device) + ) + method._hf_calibration_seq_k = seq_k + method._hf_calibration_is_decode = is_decode + + return (o.view(batch, seq_len, num_heads, head_dim), None) + # N:M sparse softmax: prefill only (no perf benefit for decode) if method is not None and not is_decode and getattr(module, "_apply_sparse_nm", False): kw["sparsity_n"] = method.sparsity_n @@ -115,10 +145,13 @@ def triton_attention_forward( kw["dense_sink_tokens"] = method.dense_sink_tokens kw["dense_recent_tokens"] = method.dense_recent_tokens - # Skip-softmax: applies to both prefill and decode + # Skip-softmax: applies to both prefill and decode. Prefer the method's + # per-phase calibrated dynamic threshold (scale_factor / seq_k); fall back + # to the static threshold when uncalibrated. if method is not None and getattr(module, "_apply_skip_softmax", False): - if method.skip_softmax_threshold: - kw["skip_softmax_threshold"] = method.skip_softmax_threshold + threshold = method.get_inference_threshold(seq_len, seq_k) + if threshold: + kw["skip_softmax_threshold"] = threshold o = attention(q, k, v, **kw) diff --git a/modelopt/torch/kernels/common/attention/triton_fa.py b/modelopt/torch/kernels/common/attention/triton_fa.py index 0b481e93558..8a1a521fea6 100644 --- a/modelopt/torch/kernels/common/attention/triton_fa.py +++ b/modelopt/torch/kernels/common/attention/triton_fa.py @@ -80,7 +80,10 @@ def _load_sparsity_helpers() -> None: _FWD_CONFIGS = [triton.Config({"BLOCK_M": 128, "BLOCK_N": 64}, num_stages=1, num_warps=4)] _MEASURE_BLOCK_M = 128 -_MEASURE_BLOCK_N = 64 +# 128 so the kernel sparsity-measurement block matches the PyTorch +# flash_skip_softmax calibration block (br = bc = 128) and the Triton +# calibration kernel; otherwise the two measure at different granularities. +_MEASURE_BLOCK_N = 128 _MEASURE_NUM_STAGES = 1 _MEASURE_NUM_WARPS = 4 @@ -363,6 +366,8 @@ def _attn_fwd( skip_tile = _skip_softmax_decision( scores, row_max, + q_pos, + seq_len_q, SKIP_THRESHOLD_LOG2, Sparsity_total, Sparsity_skipped, @@ -919,23 +924,29 @@ def forward( def grid(META): return (batch, num_q_heads, triton.cdiv(max_input_len, META["BLOCK_M"])) - if do_measure: - # Runtime counters mutate global tensors, so do not run them through - # autotune candidate trials. Use one stable config for measurement. - _attn_fwd.fn[grid]( - *fwd_args, - **fwd_kwargs, - BLOCK_M=_MEASURE_BLOCK_M, - BLOCK_N=_MEASURE_BLOCK_N, - num_warps=_MEASURE_NUM_WARPS, - num_stages=_MEASURE_NUM_STAGES, - ) - else: - _attn_fwd[grid]( - *fwd_args, - **fwd_kwargs, - # BLOCK_M, BLOCK_N, num_warps, num_stages chosen by autotune - ) + # Triton launches on torch.cuda.current_device(), which is not + # necessarily the device the tensors live on (e.g. under accelerate + # device_map="auto" sharding). Activate the tensor's device so the + # kernel dereferences the right pointers instead of triggering an + # illegal memory access. + with torch.cuda.device(q.device): + if do_measure: + # Runtime counters mutate global tensors, so do not run them through + # autotune candidate trials. Use one stable config for measurement. + _attn_fwd.fn[grid]( + *fwd_args, + **fwd_kwargs, + BLOCK_M=_MEASURE_BLOCK_M, + BLOCK_N=_MEASURE_BLOCK_N, + num_warps=_MEASURE_NUM_WARPS, + num_stages=_MEASURE_NUM_STAGES, + ) + else: + _attn_fwd[grid]( + *fwd_args, + **fwd_kwargs, + # BLOCK_M, BLOCK_N, num_warps, num_stages chosen by autotune + ) # Store sparsity counters on the output tensor for retrieval by callers if do_measure: @@ -970,23 +981,30 @@ def backward(ctx, grad_output): do = grad_output.contiguous() num_warps = 4 + # Triton launches on torch.cuda.current_device(), which is not + # necessarily the device the tensors live on (e.g. under accelerate + # device_map="auto" sharding). Activate the tensor's device for each + # launch so the kernels dereference the right pointers instead of + # triggering an illegal memory access. + # Phase 1: delta = rowsum(O * dO) delta = torch.empty_like(lse) - _attn_bwd_preprocess[(ctx.num_q_heads, triton.cdiv(q.shape[0], BLOCK))]( - o, - do, - delta, - o.stride(0), - o.stride(1), - do.stride(0), - do.stride(1), - delta.stride(0), - delta.stride(1), - q.shape[0], - HEAD_DIM=HEAD_DIM, - BLOCK_D=BLOCK_D, - BLOCK_M=BLOCK, - ) + with torch.cuda.device(q.device): + _attn_bwd_preprocess[(ctx.num_q_heads, triton.cdiv(q.shape[0], BLOCK))]( + o, + do, + delta, + o.stride(0), + o.stride(1), + do.stride(0), + do.stride(1), + delta.stride(0), + delta.stride(1), + q.shape[0], + HEAD_DIM=HEAD_DIM, + BLOCK_D=BLOCK_D, + BLOCK_M=BLOCK, + ) dq = torch.zeros_like(q) dk = torch.zeros_like(k) @@ -1016,57 +1034,59 @@ def backward(ctx, grad_output): ) # Phase 2: dK, dV - _attn_bwd_dkdv[(ctx.batch, ctx.num_kv_heads, triton.cdiv(ctx.max_input_len_k, BLOCK))]( - *bwd_args[:4], - dk, - dv, - *bwd_args[4:], - dk.stride(0), - dk.stride(1), - dv.stride(0), - dv.stride(1), - lse.stride(0), - lse.stride(1), - kv_group_num=ctx.kv_group_num, - BLOCK_M=BLOCK, - BLOCK_D=BLOCK_D, - BLOCK_N=BLOCK, - IS_CAUSAL=ctx.is_causal, - HEAD_DIM=HEAD_DIM, - SPARSITY_N=ctx.sparsity_n, - SPARSITY_M=ctx.sparsity_m, - DENSE_SINK_TOKENS=ctx.dense_sink_tokens, - DENSE_RECENT_TOKENS=ctx.dense_recent_tokens, - APPLY_SKIP_SOFTMAX=ctx.apply_skip, - SKIP_THRESHOLD_LOG2=ctx.skip_threshold_log2, - num_warps=num_warps, - num_stages=1, - ) + with torch.cuda.device(q.device): + _attn_bwd_dkdv[(ctx.batch, ctx.num_kv_heads, triton.cdiv(ctx.max_input_len_k, BLOCK))]( + *bwd_args[:4], + dk, + dv, + *bwd_args[4:], + dk.stride(0), + dk.stride(1), + dv.stride(0), + dv.stride(1), + lse.stride(0), + lse.stride(1), + kv_group_num=ctx.kv_group_num, + BLOCK_M=BLOCK, + BLOCK_D=BLOCK_D, + BLOCK_N=BLOCK, + IS_CAUSAL=ctx.is_causal, + HEAD_DIM=HEAD_DIM, + SPARSITY_N=ctx.sparsity_n, + SPARSITY_M=ctx.sparsity_m, + DENSE_SINK_TOKENS=ctx.dense_sink_tokens, + DENSE_RECENT_TOKENS=ctx.dense_recent_tokens, + APPLY_SKIP_SOFTMAX=ctx.apply_skip, + SKIP_THRESHOLD_LOG2=ctx.skip_threshold_log2, + num_warps=num_warps, + num_stages=1, + ) # Phase 3: dQ - _attn_bwd_dq[(ctx.batch, ctx.num_q_heads, triton.cdiv(ctx.max_input_len, BLOCK))]( - *bwd_args[:4], - dq, - *bwd_args[4:], - dq.stride(0), - dq.stride(1), - lse.stride(0), - lse.stride(1), - kv_group_num=ctx.kv_group_num, - BLOCK_M=BLOCK, - BLOCK_D=BLOCK_D, - BLOCK_N=BLOCK, - IS_CAUSAL=ctx.is_causal, - HEAD_DIM=HEAD_DIM, - SPARSITY_N=ctx.sparsity_n, - SPARSITY_M=ctx.sparsity_m, - DENSE_SINK_TOKENS=ctx.dense_sink_tokens, - DENSE_RECENT_TOKENS=ctx.dense_recent_tokens, - APPLY_SKIP_SOFTMAX=ctx.apply_skip, - SKIP_THRESHOLD_LOG2=ctx.skip_threshold_log2, - num_warps=num_warps, - num_stages=1, - ) + with torch.cuda.device(q.device): + _attn_bwd_dq[(ctx.batch, ctx.num_q_heads, triton.cdiv(ctx.max_input_len, BLOCK))]( + *bwd_args[:4], + dq, + *bwd_args[4:], + dq.stride(0), + dq.stride(1), + lse.stride(0), + lse.stride(1), + kv_group_num=ctx.kv_group_num, + BLOCK_M=BLOCK, + BLOCK_D=BLOCK_D, + BLOCK_N=BLOCK, + IS_CAUSAL=ctx.is_causal, + HEAD_DIM=HEAD_DIM, + SPARSITY_N=ctx.sparsity_n, + SPARSITY_M=ctx.sparsity_m, + DENSE_SINK_TOKENS=ctx.dense_sink_tokens, + DENSE_RECENT_TOKENS=ctx.dense_recent_tokens, + APPLY_SKIP_SOFTMAX=ctx.apply_skip, + SKIP_THRESHOLD_LOG2=ctx.skip_threshold_log2, + num_warps=num_warps, + num_stages=1, + ) return ( dq, diff --git a/modelopt/torch/kernels/sparsity/attention/calibrate.py b/modelopt/torch/kernels/sparsity/attention/calibrate.py index 971c423f711..3cbd2197e8a 100644 --- a/modelopt/torch/kernels/sparsity/attention/calibrate.py +++ b/modelopt/torch/kernels/sparsity/attention/calibrate.py @@ -28,7 +28,12 @@ import triton import triton.language as tl -from modelopt.torch.kernels.common.attention.triton_fa import LOG2E, _apply_mask +from modelopt.torch.kernels.common.attention.triton_fa import ( + LOG2E, + _apply_mask, + _load_paged_k_tile, + _load_paged_v_tile, +) # --------------------------------------------------------------------------- @@ -64,6 +69,18 @@ def _attn_fwd_calibrate( HEAD_DIM: tl.constexpr, NUM_THRESHOLDS: tl.constexpr, PADDED_THRESHOLDS: tl.constexpr, # next_power_of_2(NUM_THRESHOLDS) for tl.arange + IS_PAGED: tl.constexpr = False, # Whether K/V are read from a paged KV cache + K_cache=None, # [num_blocks, page_size, num_kv_heads, head_dim] paged K + V_cache=None, # [num_blocks, page_size, num_kv_heads, head_dim] paged V + Block_table=None, # [batch, max_blocks_per_seq] page table + stride_kc_block=0, + stride_kc_pos=0, + stride_kc_head=0, + stride_vc_block=0, + stride_vc_pos=0, + stride_vc_head=0, + PAGE_SIZE: tl.constexpr = 16, + max_blocks_per_seq=0, ): """Forward kernel with multi-threshold sparsity measurement. @@ -111,19 +128,53 @@ def _attn_fwd_calibrate( local_skipped = tl.zeros([PADDED_THRESHOLDS], dtype=tl.int32) num_tiles = 0 - kv_bound = seq_len_kv if not IS_CAUSAL else tl.minimum((tile_q + 1) * BLOCK_M, seq_len_kv) + # Causal bound: when Q is a suffix of KV (decode: seq_len_q == 1 against a + # long cache; or chunked prefill), the visible KV extends to + # causal_offset + (tile_q + 1) * BLOCK_M. Without the offset the loop stops + # at the first BLOCK_M KV tokens, so decode would only ever measure the + # start of the cache instead of the whole thing. + causal_offset = seq_len_kv - seq_len_q + kv_bound = ( + seq_len_kv + if not IS_CAUSAL + else tl.minimum(causal_offset + (tile_q + 1) * BLOCK_M, seq_len_kv) + ) for kv_start in range(0, kv_bound, BLOCK_N): kv_start = tl.multiple_of(kv_start, BLOCK_N) - k_offs = (kv_offset + kv_start + kv_pos[None, :]) * stride_kbs + dim_pos[:, None] - k = tl.load( - k_base + k_offs, - mask=((kv_start + kv_pos[None, :]) < seq_len_kv) & d_mask[:, None], - other=0.0, - ) - - scores = tl.dot(q, k) * qk_scale + # Load K^T [BLOCK_D, BLOCK_N] from paged cache or contiguous K. + if IS_PAGED: + k = _load_paged_k_tile( + K_cache, + Block_table, + batch_idx, + kv_head_idx, + kv_start, + kv_pos, + dim_pos, + seq_len_kv, + stride_kc_block, + stride_kc_pos, + stride_kc_head, + PAGE_SIZE, + BLOCK_N, + BLOCK_D, + HEAD_DIM, + max_blocks_per_seq, + ) + else: + k_offs = (kv_offset + kv_start + kv_pos[None, :]) * stride_kbs + dim_pos[:, None] + k = tl.load( + k_base + k_offs, + mask=((kv_start + kv_pos[None, :]) < seq_len_kv) & d_mask[:, None], + other=0.0, + ) + + # Upcast to bf16 before the matmul so fp8 (e4m3) Q/K — used by fp8-attention + # models — are handled; tl.dot does not accept fp8 operands. For bf16 inputs + # this is a no-op. + scores = tl.dot(q.to(tl.bfloat16), k.to(tl.bfloat16)) * qk_scale scores = _apply_mask(scores, q_pos, kv_pos, seq_len_q, seq_len_kv, kv_start, IS_CAUSAL) tile_row_max = tl.max(scores, 1) @@ -132,7 +183,16 @@ def _attn_fwd_calibrate( # A tile is skipped iff ALL Q rows satisfy: tile_row_max < row_max + thresh. # Equivalently: max(tile_row_max - row_max) < thresh (worst-case row # must still be below threshold for the tile to be skippable). - max_gap = tl.max(tile_row_max - row_max) # scalar + # + # Exclude padding Q rows (q_pos >= seq_len_q) from the reduction. Their Q is + # loaded as zeros, so their tile_row_max is ~0 (not -inf), which would + # otherwise dominate the max and force max_gap >= 0 — making every tile + # un-skippable. This matters most for decode (seq_len_q == 1, so 127/128 + # rows are padding) and also fixes the last partial Q tile in prefill when + # seq_len_q is not a multiple of BLOCK_M. + gap = tile_row_max - row_max + gap = tl.where(q_pos < seq_len_q, gap, -float("inf")) + max_gap = tl.max(gap) # scalar skip_mask = (max_gap < thresholds).to(tl.int32) # [PADDED_THRESHOLDS] local_skipped += skip_mask num_tiles += 1 @@ -145,13 +205,33 @@ def _attn_fwd_calibrate( row_sum = row_sum * correction + l_new acc = acc * correction[:, None] - v_offs = (kv_offset + kv_start + kv_pos[:, None]) * stride_vbs + dim_pos[None, :] - v = tl.load( - v_base + v_offs, - mask=((kv_start + kv_pos[:, None]) < seq_len_kv) & d_mask[None, :], - other=0.0, - ) - acc = tl.dot(p.to(v.dtype), v, acc) + if IS_PAGED: + v = _load_paged_v_tile( + V_cache, + Block_table, + batch_idx, + kv_head_idx, + kv_start, + kv_pos, + dim_pos, + seq_len_kv, + stride_vc_block, + stride_vc_pos, + stride_vc_head, + PAGE_SIZE, + BLOCK_N, + BLOCK_D, + HEAD_DIM, + max_blocks_per_seq, + ) + else: + v_offs = (kv_offset + kv_start + kv_pos[:, None]) * stride_vbs + dim_pos[None, :] + v = tl.load( + v_base + v_offs, + mask=((kv_start + kv_pos[:, None]) < seq_len_kv) & d_mask[None, :], + other=0.0, + ) + acc = tl.dot(p.to(tl.bfloat16), v.to(tl.bfloat16), acc) row_max = m_new # --- Write per-program counters (no atomics, just stores) --- @@ -193,6 +273,10 @@ def attention_calibrate( max_input_len_k: int | None = None, *, threshold_trials: list[float] | None = None, + k_cache: torch.Tensor | None = None, + v_cache: torch.Tensor | None = None, + block_table: torch.Tensor | None = None, + page_size: int = 16, ) -> tuple[torch.Tensor, torch.Tensor]: """Flash attention with multi-threshold skip-softmax sparsity measurement. @@ -206,6 +290,15 @@ def attention_calibrate( Same as :func:`modelopt.torch.kernels.common.attention.attention`. threshold_trials: List of threshold values to measure sparsity for. Each value is converted to log2-scaled space for the kernel. + k_cache: Paged K cache ``[num_blocks, page_size, num_kv_heads, head_dim]``. + When provided, K/V are read from the paged cache via ``block_table`` + (vLLM layout) instead of from the contiguous ``k``/``v`` tensors. + ``k``/``v`` are then dummies whose only meaningful dimension is + ``shape[1] == num_kv_heads`` (used to compute the GQA ratio). + v_cache: Paged V cache ``[num_blocks, page_size, num_kv_heads, head_dim]``. + block_table: Page table ``[batch, max_blocks_per_seq]`` mapping each + sequence's block indices to global page IDs. + page_size: Number of tokens per page in the KV cache. Returns: Tuple of (output, sparsity_counters): @@ -217,6 +310,10 @@ def attention_calibrate( if threshold_trials is None or len(threshold_trials) == 0: raise ValueError("threshold_trials must be a non-empty list") + is_paged = k_cache is not None + if is_paged and block_table is None: + raise ValueError("block_table is required when k_cache/v_cache are provided.") + # Calibration has only been validated with uniform-length batches (current # diffusion + RULER paths). Varlen inputs would exercise code paths in the # kernel that have not been tested — fail loudly rather than silently @@ -252,13 +349,20 @@ def attention_calibrate( sm_scale = 1.0 / (HEAD_DIM**0.5) if softmax_scale is None else softmax_scale qk_scale = sm_scale * LOG2E BLOCK_D = triton.next_power_of_2(HEAD_DIM) + # 128x128 to match the PyTorch flash_skip_softmax calibration block (br = bc = 128), + # so Triton-kernel and PyTorch calibration measure sparsity at the same granularity. BLOCK_M = 128 - BLOCK_N = 64 + BLOCK_N = 128 if b_seq_len_k is None: b_seq_len_k = b_seq_len b_start_loc_k = b_start_loc + # Paged mode: KV positions come from block_table, so the contiguous KV + # offsets are unused. Provide a dummy so Triton can compile the tl.load. + if b_start_loc_k is None: + b_start_loc_k = torch.zeros_like(b_start_loc) + num_thresholds = len(threshold_trials) # Scores already include sm_scale and LOG2E; convert lambda to log2 space only. @@ -282,38 +386,67 @@ def attention_calibrate( num_programs * num_thresholds, dtype=torch.int32, device=q.device ) - _attn_fwd_calibrate[grid]( - q, - k, - v, - qk_scale, - b_start_loc, - b_seq_len, - b_start_loc_k, - b_seq_len_k, - o, - q.stride(0), - q.stride(1), - k.stride(0), - k.stride(1), - v.stride(0), - v.stride(1), - o.stride(0), - o.stride(1), - threshold_tensor, - per_program_totals, - per_program_skipped, - kv_group_num=kv_group_num, - BLOCK_M=BLOCK_M, - BLOCK_D=BLOCK_D, - BLOCK_N=BLOCK_N, - IS_CAUSAL=is_causal, - HEAD_DIM=HEAD_DIM, - NUM_THRESHOLDS=num_thresholds, - PADDED_THRESHOLDS=triton.next_power_of_2(num_thresholds), - num_warps=4, - num_stages=1, - ) + # Paged KV cache strides (zeros when not paged; computed here so the type + # narrowing of k_cache/v_cache/block_table is explicit for the kernel call). + if is_paged: + assert k_cache is not None and v_cache is not None and block_table is not None + kc_strides = (k_cache.stride(0), k_cache.stride(1), k_cache.stride(2)) + vc_strides = (v_cache.stride(0), v_cache.stride(1), v_cache.stride(2)) + max_blocks_per_seq = block_table.shape[1] + else: + kc_strides = (0, 0, 0) + vc_strides = (0, 0, 0) + max_blocks_per_seq = 0 + + # Triton launches on torch.cuda.current_device(), which is not necessarily + # the device the tensors live on (e.g. under accelerate device_map="auto" + # sharding). Activate the tensor's device so the kernel dereferences the + # right pointers instead of triggering an illegal memory access. + with torch.cuda.device(q.device): + _attn_fwd_calibrate[grid]( + q, + k, + v, + qk_scale, + b_start_loc, + b_seq_len, + b_start_loc_k, + b_seq_len_k, + o, + q.stride(0), + q.stride(1), + k.stride(0), + k.stride(1), + v.stride(0), + v.stride(1), + o.stride(0), + o.stride(1), + threshold_tensor, + per_program_totals, + per_program_skipped, + kv_group_num=kv_group_num, + BLOCK_M=BLOCK_M, + BLOCK_D=BLOCK_D, + BLOCK_N=BLOCK_N, + IS_CAUSAL=is_causal, + HEAD_DIM=HEAD_DIM, + NUM_THRESHOLDS=num_thresholds, + PADDED_THRESHOLDS=triton.next_power_of_2(num_thresholds), + IS_PAGED=is_paged, + K_cache=k_cache, + V_cache=v_cache, + Block_table=block_table, + stride_kc_block=kc_strides[0], + stride_kc_pos=kc_strides[1], + stride_kc_head=kc_strides[2], + stride_vc_block=vc_strides[0], + stride_vc_pos=vc_strides[1], + stride_vc_head=vc_strides[2], + PAGE_SIZE=page_size, + max_blocks_per_seq=max_blocks_per_seq, + num_warps=4, + num_stages=1, + ) # Reduce across programs: sum per-program counts → [num_thresholds] totals = per_program_totals.view(num_programs, num_thresholds).sum(dim=0).to(torch.int64) diff --git a/modelopt/torch/kernels/sparsity/attention/skip_softmax_helpers.py b/modelopt/torch/kernels/sparsity/attention/skip_softmax_helpers.py index aa65fd50a12..044e54b2e8e 100644 --- a/modelopt/torch/kernels/sparsity/attention/skip_softmax_helpers.py +++ b/modelopt/torch/kernels/sparsity/attention/skip_softmax_helpers.py @@ -142,6 +142,8 @@ def _apply_sparse_nm_to_qk_tile( def _skip_softmax_decision( scores, row_max, + q_pos, + seq_len_q, SKIP_THRESHOLD_LOG2: tl.constexpr, Sparsity_total, Sparsity_skipped, @@ -159,16 +161,25 @@ def _skip_softmax_decision( The threshold is converted to the kernel's scaled log2 score space by the Python wrapper so it can be compared directly against ``scores``. + ``q_pos`` (``[BLOCK_M]`` absolute query positions) and the scalar + ``seq_len_q`` identify padding rows. When a tile has fewer than ``BLOCK_M`` + valid queries — decode has one valid query plus ``BLOCK_M - 1`` padding + rows, and the last prefill tile is partial when ``seq_q`` is not a multiple + of ``BLOCK_M`` — the padding rows carry zero scores that are never + negligible versus their own running max and would otherwise veto every + skip. They are forced skippable so the decision reflects only valid rows. + Returns: - True when *all* Q rows in the tile satisfy the skip criterion. + True when *all valid* Q rows in the tile satisfy the skip criterion. When ``MEASURE_SPARSITY`` is set, also records total/skipped tile counts via atomic adds on ``Sparsity_total`` / ``Sparsity_skipped``. """ tile_row_max = tl.max(scores, 1) # [BLOCK_M] — ~m_i^(j) (scaled) - # Per-row: True if row's tile max is negligible vs running max - can_skip = tile_row_max < (row_max + SKIP_THRESHOLD_LOG2) - # Per-tile: skip entire tile only if ALL rows are negligible + # Per-row: True if the row's tile max is negligible vs running max, OR the + # row is padding (q_pos >= seq_len_q) so it must not veto the tile decision. + can_skip = (tile_row_max < (row_max + SKIP_THRESHOLD_LOG2)) | (q_pos >= seq_len_q) + # Per-tile: skip entire tile only if ALL valid rows are negligible skip_tile = tl.min(can_skip.to(tl.int32)) == 1 if MEASURE_SPARSITY: diff --git a/modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py b/modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py index 51df5bb4d4a..840f757a8c6 100644 --- a/modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py +++ b/modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py @@ -153,9 +153,14 @@ def create_decode_calibration_forward_loop( ) -> Callable: """Create forward loop for decode phase calibration. - Uses SDPA for fast prefill, then switches to eager attention for decode - token generation with softmax hook measurement. (Previously used - ``flash_attention_2`` for prefill, but transformers>=5.0's FA2 path + Uses SDPA for fast prefill (no measurement), then switches to the model's + configured sparse-attention backend for the decode steps so measurement + happens there: ``eager`` for the pytorch backend (F.softmax hook) or + ``modelopt_triton`` for the triton backend (Triton calibration kernel). + The backend is read from ``model.config._attn_implementation``, which + ``sparsify`` already set for the chosen backend. + + (SDPA is used for prefill because transformers>=5.0's FA2 path unconditionally calls ``s_aux.to(query.dtype)`` on the attention-sinks tensor and crashes for models without sinks. SDPA is just as fast for prefill, has no softmax hook, and is version-stable.) @@ -179,7 +184,8 @@ def forward_loop(model: nn.Module) -> None: ) input_ids = inputs["input_ids"].to(device) - # Save original attention implementation + # Save original attention implementation (the sparse-attention backend + # set by sparsify: "eager" for pytorch, "modelopt_triton" for triton). original_attn_impl = getattr(model.config, "_attn_implementation", "eager") with torch.no_grad(): @@ -191,8 +197,10 @@ def forward_loop(model: nn.Module) -> None: next_token = outputs.logits[:, -1:, :].argmax(dim=-1) del outputs # Free large prefill logits [B, seqlen, vocab] before decode loop - # Step 2: Switch to eager for decode (enables softmax hook) - model.config._attn_implementation = "eager" + # Step 2: Switch to the sparse backend for decode so measurement + # happens there (eager -> F.softmax hook; modelopt_triton -> + # Triton calibration kernel). + model.config._attn_implementation = original_attn_impl # Step 3: Manual decode loop for explicit control over token generation # model.generate() method is not used here because it doesn't allow explicit control over KV cache diff --git a/modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py b/modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py index d3ed3303256..31f5d67ea50 100644 --- a/modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py +++ b/modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py @@ -130,8 +130,6 @@ def calibrate(self, model: nn.Module, forward_loop: Callable, phase: str) -> dic # with one entry per threshold, eliminating the need for repeated forward passes. print(f"\nStage 1: Collecting {phase} sparsity data for all thresholds in one pass...") - all_data_points = [] # List of {"threshold", "length", "scale_factor", "sparsity"} - self._set_thresholds(attention_modules, self.threshold_trials) self._enable_calibration_mode(attention_modules) with torch.no_grad(): @@ -139,6 +137,29 @@ def calibrate(self, model: nn.Module, forward_loop: Callable, phase: str) -> dic per_sample_stats = self._extract_calibration_stats(attention_modules, phase=phase) self._disable_calibration_mode(attention_modules) + return self.calibrate_from_stats(per_sample_stats, phase) + + def calibrate_from_stats(self, per_sample_stats: list[dict], phase: str) -> dict[str, Any]: + """Fit the exponential model from already-collected per-sample stats. + + This is the backend-agnostic Stage 2/3 of :meth:`calibrate`. The HF and + diffusion paths reach it through :meth:`calibrate` (which runs a + ``forward_loop`` to collect the stats first); the vLLM path collects the + stats itself — one record per scheduled request — and calls this directly + so both paths share the same exponential fit. + + Args: + per_sample_stats: List of ``{"sparsity": [s_0, ..., s_n], "sample_length": L}`` + records, one per calibration sample. ``sparsity`` holds the + skipped-tile fraction at each threshold in ``threshold_trials`` + (same order, same length). + phase: Phase being calibrated ('prefill' or 'decode'). + + Returns: + Dict with calibration results including a, b, r_squared, and num_data_points. + """ + all_data_points = [] # List of {"threshold", "length", "scale_factor", "sparsity"} + for sample_stat in per_sample_stats: length = sample_stat["sample_length"] sparsity_list = sample_stat["sparsity"] @@ -153,6 +174,12 @@ def calibrate(self, model: nn.Module, forward_loop: Callable, phase: str) -> dic } ) + # Per-sample measured sparsity (one row per calibration sample: its + # skipped-tile fraction at every threshold). Printed before the fit- + # validity guard so the raw per-sample data is visible even when the fit + # bails (e.g. degenerate near-zero sparsity). + self._print_per_sample_sparsity(per_sample_stats, phase) + if len(all_data_points) < 10: warnings.warn( f"Not enough data points for {phase} calibration. " @@ -285,8 +312,29 @@ def exponential(sparsity, a, b): "calibration_type": "exponential", "min_observed_sparsity": min_observed_sparsity, "max_observed_sparsity": max_observed_sparsity, + # Raw per-sample measured sparsity, so callers can audit the spread + # across samples (not just the fitted average). + "per_sample_sparsity": [ + { + "sample_length": s.get("sample_length", 0), + "sparsity": list(s.get("sparsity", [])), + } + for s in per_sample_stats + ], } + def _print_per_sample_sparsity(self, per_sample_stats: list[dict], phase: str) -> None: + """Print each sample's measured skipped-tile fraction at every threshold.""" + if not per_sample_stats: + return + print(f"\nPer-sample {phase} sparsity (skipped-tile fraction per threshold):") + header = " ".join(f"{t:>7.0e}" for t in self.threshold_trials) + print(f" {'sample':>6} {'length':>8} {header}") + for idx, stat in enumerate(per_sample_stats): + sparsity = stat.get("sparsity", []) + row = " ".join(f"{s:>7.2%}" for s in sparsity) + print(f" {idx:>6} {stat.get('sample_length', 0):>8} {row}") + def _enable_calibration_mode(self, modules: list[nn.Module]): """Enable calibration mode on sparse attention modules.""" for idx, module in enumerate(modules): diff --git a/modelopt/torch/sparsity/attention_sparsity/config.py b/modelopt/torch/sparsity/attention_sparsity/config.py index 32a49f02e34..c064fd0014d 100644 --- a/modelopt/torch/sparsity/attention_sparsity/config.py +++ b/modelopt/torch/sparsity/attention_sparsity/config.py @@ -546,6 +546,35 @@ class FlashSkipSoftmaxConfig(SparseAttentionConfig): } +# RULER calibration via the fused Triton calibration kernel (prefill + decode). +# Computes the same exponential-model calibration as SKIP_SOFTMAX_CALIB but +# measures tile-skip statistics with the Triton ``attention_calibrate`` kernel +# (the way the Triton inference kernel actually skips tiles) instead of the +# PyTorch F.softmax-patching block path. Faster on GPU since it avoids +# materializing per-block tensors. +SKIP_SOFTMAX_TRITON_CALIB = { + "sparse_cfg": { + "calibration": { + # Prefill calibration uses full-prefill forwards; decode calibration + # runs SDPA prefill followed by Triton-backend decode steps. + "target_sparse_ratio": {"prefill": 0.5, "decode": 0.5}, + "samples": 64, + "max_seqlen": 16384, + # Full prefill (seq_q == seq_k, uniform batch=1) — what + # attention_calibrate was validated against. Chunked prefill would + # exercise an untested KV-cache causal-offset path in the kernel. + "chunk_size": -1, + }, + "*attn*": { + "method": "triton_skip_softmax", + "backend": "triton", + "enable": True, + }, + "default": {"enable": False}, + }, +} + + class VSAAttributeConfig(ModeloptBaseConfig): """Video Sparse Attention (VSA) attribute configuration. @@ -738,6 +767,7 @@ class VSAConfig(SparseAttentionConfig): "SKIP_SOFTMAX_CALIB", "SKIP_SOFTMAX_CALIB_SPARSE24", "SKIP_SOFTMAX_DEFAULT", + "SKIP_SOFTMAX_TRITON_CALIB", "SKIP_SOFTMAX_TRITON_DEFAULT", "SPARSE_SOFTMAX_DEFAULT", "VSA_DEFAULT", diff --git a/modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py b/modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py index c1d6465ba66..baefbd7058d 100644 --- a/modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py +++ b/modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py @@ -193,6 +193,16 @@ def calc_correction_factor_and_p( dense_blocks_list = [] block_mask_0 = None block_diff = block_max - block_max_cummax + # Exclude padded query rows from the keep decision. _reshape_to_blocks + # pads the last block row with ``dtype.min``; a fully-padded row then + # has ``block_diff == 0`` (min - min), which passes ``> log_threshold`` + # and forces every block in the last partial block row to be kept + # (never skipped) — under-counting sparsity by up to one block row. + # Mask those rows to -inf so they vote "skip", matching the Triton + # kernel, which drops padding rows from its tile-skip reduction. + pad_q = padded_seq_q - seq_q + if pad_q > 0: + block_diff[:, :, -1, self.br - pad_q :, :] = float("-inf") for i, log_threshold in enumerate(log_thresholds): block_mask = (block_diff > log_threshold).any(dim=-2) diff --git a/modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py b/modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py index c0a183787dd..a3109d56b73 100644 --- a/modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py +++ b/modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py @@ -49,6 +49,13 @@ def __init__(self, method_config=None): self.skip_softmax_threshold = method_config.get("skip_softmax_threshold", 0.1) # Calibration state self._threshold_trials: list[float] | None = None + # HF (modelopt_triton) backend calibration outputs, accumulated across + # attention calls in one forward pass and read back in + # ``_collect_calibration_stats``. The HF backend reads/writes these + # directly on the method instance (no thread-local needed). + self._hf_calibration_counters: torch.Tensor | None = None + self._hf_calibration_seq_k: int | None = None + self._hf_calibration_is_decode: bool = False # Runtime sparsity measurement self._measure_sparsity: bool = False self._sparsity_total: int = 0 @@ -111,6 +118,11 @@ def _triton_inference_context(self, module): def _triton_calibration_context(self, module): """Calibration: collect multi-threshold sparsity stats via Triton kernel.""" module._apply_skip_softmax = True + # Reset the HF-backend calibration accumulators for this forward pass. + # (The diffusers/LTX backends reset their own state in ``_set_triton_backends``.) + self._hf_calibration_counters = None + self._hf_calibration_seq_k = None + self._hf_calibration_is_decode = False self._set_triton_backends(calibration_mode=True, threshold_trials=self._threshold_trials) with self._get_diffusers_backend_context(): try: @@ -121,20 +133,20 @@ def _triton_calibration_context(self, module): module._apply_skip_softmax = False self._clear_triton_backends() - def _get_scale_factor(self) -> float | None: - """Compute scale_factor from calibration params, or None if uncalibrated. + def _get_scale_factor(self, phase: str = "prefill") -> float | None: + """Compute the scale_factor for ``phase`` from calibration params, or None. - The scale_factor is sequence-length-independent. Backends divide by the + The scale_factor is sequence-length-independent. Callers divide by the actual ``seq_k`` at call time: ``threshold = scale_factor / seq_k``. """ if self.calibration_params and self.target_sparse_ratio: import math import warnings - params = self.calibration_params.get("prefill", {}) + params = self.calibration_params.get(phase, {}) a = params.get("a", 0) b = params.get("b", 0) - target = self.target_sparse_ratio.get("prefill", 0.5) + target = self.target_sparse_ratio.get(phase, 0.5) if a > 0 and b > 0: # Warn if target is outside the calibrated range min_s = params.get("min_observed_sparsity") @@ -155,6 +167,22 @@ def _get_scale_factor(self) -> float | None: return a * math.exp(b * target) return None + def get_inference_threshold(self, seq_q: int, seq_k: int) -> float | None: + """Return the skip threshold to apply for this call's phase. + + Picks the phase from the query length (``decode`` when ``seq_q == 1``, + else ``prefill``) and returns the calibrated dynamic threshold + ``scale_factor(phase) / seq_k`` when the phase is calibrated, otherwise + the static ``skip_softmax_threshold`` (or ``None`` to disable). This is + what the HF backend applies; it keeps prefill and decode on their own + calibrated ``(a, b)`` instead of forcing decode onto prefill's. + """ + phase = "decode" if seq_q <= 1 else "prefill" + scale_factor = self._get_scale_factor(phase) + if scale_factor is not None and seq_k > 0: + return scale_factor / seq_k + return self.skip_softmax_threshold or None + @staticmethod @contextmanager def _get_diffusers_backend_context(): @@ -170,7 +198,12 @@ def _get_diffusers_backend_context(): yield def _set_triton_backends(self, **kwargs): - """Set config on both diffusers and LTX Triton backends.""" + """Set config on the diffusers and LTX Triton backends. + + The HF (modelopt_triton) backend reads its calibration config directly + from this method instance during ``triton_attention_forward``, so it + needs no separate configuration here. + """ try: from modelopt.torch.kernels.sparsity.attention.diffusers_triton_attention import ( set_triton_skip_softmax_config, @@ -189,7 +222,7 @@ def _set_triton_backends(self, **kwargs): pass def _clear_triton_backends(self): - """Clear config on both Triton backends.""" + """Clear config on the diffusers and LTX Triton backends.""" try: from modelopt.torch.kernels.sparsity.attention.diffusers_triton_attention import ( clear_triton_skip_softmax_config, @@ -211,6 +244,9 @@ def _collect_calibration_stats(self, module): """Read Triton calibration counters and store as stats on the module.""" counters = None seq_k = None + # Diffusers/LTX (video) backends are prefill-only; only the HF backend + # reports a phase, for decode-step calibration. + phase = "prefill" try: from modelopt.torch.kernels.sparsity.attention.diffusers_triton_attention import ( @@ -235,6 +271,14 @@ def _collect_calibration_stats(self, module): except ImportError: pass + if counters is None: + # HF (modelopt_triton) backend accumulates counters on this method + # instance (``module._sparse_method_instance is self``). + counters = self._hf_calibration_counters + seq_k = self._hf_calibration_seq_k + if counters is not None and self._hf_calibration_is_decode: + phase = "decode" + if counters is None or self._threshold_trials is None: return @@ -251,7 +295,7 @@ def _collect_calibration_stats(self, module): module._last_stats = { "sparsity": sparsity_list, "sample_length": sample_length, - "phase": "prefill", + "phase": phase, } def get_threshold_info(self) -> dict: diff --git a/modelopt/torch/sparsity/attention_sparsity/plugins/vllm.py b/modelopt/torch/sparsity/attention_sparsity/plugins/vllm.py index 734755e3bba..d5c7f63e62c 100644 --- a/modelopt/torch/sparsity/attention_sparsity/plugins/vllm.py +++ b/modelopt/torch/sparsity/attention_sparsity/plugins/vllm.py @@ -27,7 +27,10 @@ live in ``plugins/sparse_attn_config.py`` and are unit-testable without vLLM. """ +import functools +import inspect import math +import os import warnings import torch @@ -38,6 +41,7 @@ ) from modelopt.torch.kernels.common.attention.triton_fa import attention as triton_attention +from modelopt.torch.kernels.sparsity.attention.calibrate import attention_calibrate def _target_sparse_ratio_for_phase(target_sparse_ratio, phase: str) -> float: @@ -113,7 +117,198 @@ def _build_sparse_kw(layer_cfg: dict) -> dict: return sparse_kw -class ModelOptSparseAttentionImpl(FlashAttentionImpl): +class _SparseCalibrationMixin: + """Backend-agnostic skip-softmax calibration shared by the sparse impls. + + A backend-specific impl extracts the dense paged metadata (per-request query + offsets/lengths, KV lengths, block table) and the K/V caches from its own + attention-metadata and cache layout, then calls :meth:`_forward_calibrate`. + The per-request measurement, dense-output write, and stats recording are + identical across backends (FlashAttention, FlashInfer, ...), so only the + extraction differs. ``iter_sparse_impls`` recognizes any impl that mixes + this in. + + Calibration state (``_calibrate``, ``_calib_threshold_trials``, + ``_calib_records``) is attached by :func:`enable_calibration`. + """ + + # Provided at runtime by the vLLM AttentionImpl base class. + scale: float + num_kv_heads: int + head_size: int + # Per-layer sparse kwargs set by the worker (empty during calibration). + sparse_kw: dict + # Attached by enable_calibration(). + _calib_threshold_trials: list[float] | None + _calib_records: list[dict] + + def _forward_calibrate( + self, + *, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + page_size: int, + b_start_loc: torch.Tensor, + b_seq_len: torch.Tensor, + seq_lens: torch.Tensor, + block_table: torch.Tensor, + num_actual_tokens: int, + output: torch.Tensor, + ) -> torch.Tensor: + """Measure per-request tile-skip stats via the paged Triton calibration kernel. + + Each scheduled request is calibrated independently (batch=1) so its KV + length is the per-sample length the exponential fit needs, and so the + kernel keeps the uniform-length contract it was validated against. The + kernel computes full attention, so ``output`` is written densely and the + forward pass is numerically unchanged. + + Phase and causality are decided per request: ``q_len == 1`` is a decode + step (full-cache, non-causal); ``q_len > 1`` is (chunked) prefill (causal + — the kernel offsets the query into the KV span). A mixed prefill/decode + batch therefore contributes correctly to both phase fits. + """ + # FlashInfer stores the fp8 KV cache as raw uint8 bytes. Reinterpret it as + # fp8 (e4m3) so the kernel upcasts the true fp8 values; a plain cast of the + # uint8 bytes would read them as integers 0-255 and collapse the scores + # (over-reporting sparsity). KV/Q fp8 scales are 1.0 here, so the bitcast is + # the complete dequant. No-op for bf16 KV caches. + if key_cache.dtype == torch.uint8: + key_cache = key_cache.view(torch.float8_e4m3fn) + if value_cache.dtype == torch.uint8: + value_cache = value_cache.view(torch.float8_e4m3fn) + + trials = self._calib_threshold_trials + batch = seq_lens.shape[0] + + q = query[:num_actual_tokens].contiguous() + # Dummy K/V: in paged mode KV is read from the cache via block_table. + # Only shape[1] (num_kv_heads) is consulted, to compute the GQA ratio. + k_dummy = torch.empty(0, self.num_kv_heads, self.head_size, device=q.device, dtype=q.dtype) + + for i in range(batch): + q_len = int(b_seq_len[i].item()) + if q_len <= 0: + continue + q_start = int(b_start_loc[i].item()) + seq_k = int(seq_lens[i].item()) + phase = "decode" if q_len <= 1 else "prefill" + + oi, counters = attention_calibrate( + q[q_start : q_start + q_len], + k_dummy, + k_dummy, + b_start_loc=torch.zeros(1, device=q.device, dtype=torch.int32), + b_seq_len=b_seq_len[i : i + 1].to(torch.int32), + max_input_len=q_len, + is_causal=q_len > 1, + softmax_scale=self.scale, + b_seq_len_k=seq_lens[i : i + 1].to(torch.int32), + max_input_len_k=seq_k, + threshold_trials=trials, + k_cache=key_cache, + v_cache=value_cache, + block_table=block_table[i : i + 1], + page_size=page_size, + ) + output[q_start : q_start + q_len] = oi + + total = counters[:, 0].float() + skipped = counters[:, 1].float() + sparsity = (skipped / total.clamp(min=1)).tolist() + self._calib_records.append( + {"phase": phase, "sample_length": seq_k, "sparsity": sparsity} + ) + + return output + + def _forward_sparse( + self, + *, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + page_size: int, + b_start_loc: torch.Tensor, + b_seq_len: torch.Tensor, + seq_lens: torch.Tensor, + block_table: torch.Tensor, + num_actual_tokens: int, + max_query_len: int, + max_seq_len: int, + is_decode_only: bool, + is_causal: bool, + output: torch.Tensor, + dense_fallback, + ) -> torch.Tensor: + """Run the ModelOpt sparse Triton kernel over the paged cache, or delegate. + + Shared inference path across backends. The backend impl extracts the + per-request query offsets/lengths, KV lengths, and block table from its + own metadata and the K/V caches from its own layout, then calls this. + ``dense_fallback`` is a zero-arg callable that runs the backend's native + (dense) attention; it is used when no sparse feature applies to the + launch (decode-only skip-softmax, or a launch where dynamic calibration + disabled sparsity). + """ + sparse_kw = dict(getattr(self, "sparse_kw", {})) + _resolve_skip_softmax_calibration( + sparse_kw, is_prefill=not is_decode_only, max_seq_len=max_seq_len + ) + if is_decode_only: + # N:M sparse softmax is prefill-only. + for name in ("sparsity_n", "sparsity_m", "dense_sink_tokens", "dense_recent_tokens"): + sparse_kw.pop(name, None) + if set(sparse_kw) <= {"skip_softmax_threshold"}: + # Decode-only skip-softmax is not validated on the paged kernel + # yet; keep decode on the backend's native attention. + return dense_fallback() + if not sparse_kw: + # Dynamic calibration can disable sparse work for a launch (e.g. a + # short-prefill threshold outside the valid lambda range). + return dense_fallback() + + q = query[:num_actual_tokens].contiguous() + # Dummy K/V: paged mode reads KV from the cache via block_table; only + # shape[1] (num_kv_heads) is consulted, for the GQA ratio. + k_dummy = torch.empty(0, self.num_kv_heads, self.head_size, device=q.device, dtype=q.dtype) + # Opt-in: count skipped tiles so the realized sparsity at serve time is + # observable (the kernel attaches _sparsity_total / _sparsity_skipped). + measure = bool(os.environ.get("MODELOPT_MEASURE_SPARSITY")) + triton_out = triton_attention( + q, + k=k_dummy, + v=k_dummy, + b_start_loc=b_start_loc, + b_seq_len=b_seq_len, + max_input_len=max_query_len, + is_causal=is_causal, + softmax_scale=self.scale, + b_start_loc_k=None, # paged mode: KV offsets not needed + b_seq_len_k=seq_lens, + max_input_len_k=max_seq_len, + k_cache=key_cache, + v_cache=value_cache, + block_table=block_table, + page_size=page_size, + measure_sparsity=measure, + **sparse_kw, + ) + if measure and hasattr(triton_out, "_sparsity_total"): + total = triton_out._sparsity_total + skipped = triton_out._sparsity_skipped + frac = skipped / max(total, 1) + phase = "decode" if is_decode_only else "prefill" + print( + f"[ModelOpt] skip-softmax {phase}: {skipped}/{total} tiles skipped " + f"({frac:.1%}), seqlen={max_seq_len}, threshold={sparse_kw.get('skip_softmax_threshold')}" + ) + output[:num_actual_tokens] = triton_out + return output + + +class ModelOptSparseAttentionImpl(_SparseCalibrationMixin, FlashAttentionImpl): """Attention implementation that uses the ModelOpt Triton kernel. Inherits from FlashAttentionImpl to reuse: @@ -198,38 +393,42 @@ def forward( key_cache, value_cache = kv_cache.unbind(0) page_size = key_cache.shape[1] - # Per-layer sparse kwargs (set by _replace_attention_impl in the worker) - sparse_kw = dict(getattr(self, "sparse_kw", {})) - _resolve_skip_softmax_calibration( - sparse_kw, - is_prefill=not is_decode_only, + # Calibration mode: measure multi-threshold tile-skip statistics with the + # Triton calibration kernel (full attention + counting) instead of running + # the sparse inference kernel. Output stays dense so generation proceeds + # normally and decode-step calibration sees a correct cache. + if getattr(self, "_calibrate", False) and getattr(self, "_calib_threshold_trials", None): + return self._forward_calibrate( + query=query, + key_cache=key_cache, + value_cache=value_cache, + page_size=page_size, + b_start_loc=b_start_loc, + b_seq_len=b_seq_len, + seq_lens=seq_lens, + block_table=attn_metadata.block_table, + num_actual_tokens=num_actual_tokens, + output=output, + ) + + # Sparse prefill via the ModelOpt Triton kernel; delegate non-sparse and + # decode-only-skip-softmax launches back to vLLM FlashAttention. + return self._forward_sparse( + query=query, + key_cache=key_cache, + value_cache=value_cache, + page_size=page_size, + b_start_loc=b_start_loc, + b_seq_len=b_seq_len, + seq_lens=seq_lens, + block_table=attn_metadata.block_table, + num_actual_tokens=num_actual_tokens, + max_query_len=attn_metadata.max_query_len, max_seq_len=attn_metadata.max_seq_len, - ) - if is_decode_only: - # N:M sparse softmax is prefill-only. - for name in ("sparsity_n", "sparsity_m", "dense_sink_tokens", "dense_recent_tokens"): - sparse_kw.pop(name, None) - if set(sparse_kw) <= {"skip_softmax_threshold"}: - # The current ModelOpt paged kernel is only validated for - # sparse prefill in vLLM. Decode-only skip-softmax would route - # through the dense Triton path for every non-skipped tile, so - # keep decode on vLLM FlashAttention until that path is covered. - return self._forward_vllm_flash_attn( - layer, - query, - key, - value, - kv_cache, - attn_metadata, - output, - output_scale, - output_block_scale, - ) - if not sparse_kw: - # Dynamic calibration can disable sparse work for a launch, e.g. - # short-prefill thresholds outside the valid lambda range. Avoid - # swapping in the ModelOpt dense kernel when no sparse feature is active. - return self._forward_vllm_flash_attn( + is_decode_only=is_decode_only, + is_causal=is_causal, + output=output, + dense_fallback=lambda: self._forward_vllm_flash_attn( layer, query, key, @@ -239,44 +438,9 @@ def forward( output, output_scale, output_block_scale, - ) - - # Prepare metadata for our kernel - q = query[:num_actual_tokens].contiguous() - # Dummy K/V for paged mode: not used by the kernel (KV are read from - # k_cache/v_cache via block_table), but shape[1] must be num_kv_heads - # so the kernel computes the correct GQA ratio (num_q_heads // num_kv_heads). - k_dummy = torch.empty(0, self.num_kv_heads, self.head_size, device=q.device, dtype=q.dtype) - - # Call ModelOpt Triton kernel with paged KV. - # b_seq_len is the query length (e.g., 6 for prefill, 1 for decode). - # b_seq_len_k is the total KV length including cache (e.g., 6 for first - # prefill, 7/8/... for subsequent decode steps). - triton_out = triton_attention( - q, - k=k_dummy, - v=k_dummy, - # Query metadata - b_start_loc=b_start_loc, - b_seq_len=b_seq_len, - max_input_len=attn_metadata.max_query_len, - is_causal=is_causal, - softmax_scale=self.scale, - # KV metadata - b_start_loc_k=None, # paged mode: KV offsets not needed - b_seq_len_k=seq_lens, # total KV length per sequence - max_input_len_k=attn_metadata.max_seq_len, - # Paged KV cache - k_cache=key_cache, # [num_blocks, page_size, num_kv_heads, head_dim] - v_cache=value_cache, # [num_blocks, page_size, num_kv_heads, head_dim] - block_table=attn_metadata.block_table, # [batch, max_blocks] - page_size=page_size, # tokens per page in the KV cache - **sparse_kw, + ), ) - output[:num_actual_tokens] = triton_out - return output - class ModelOptSparseAttentionBackend(FlashAttentionBackend): """Attention backend that uses ModelOpt's sparse Triton kernel. @@ -295,12 +459,18 @@ def get_impl_cls() -> type: return ModelOptSparseAttentionImpl -def _clone_sparse_impl(old_impl): - """Create a sparse impl while preserving vLLM's initialized runtime state.""" +def _clone_sparse_impl(old_impl, new_cls: type = ModelOptSparseAttentionImpl): + """Re-class a vLLM attention impl into ``new_cls``, preserving its state. + + The new impl shares the backend impl's initialized runtime state (config, + scales, kv-cache dtype) so ``do_kv_cache_update`` and the dense-fallback + ``super().forward()`` keep working. ``new_cls`` selects the backend-specific + sparse impl (FlashAttention vs FlashInfer). + """ if getattr(old_impl, "sinks", None) is not None: # vLLM passes sinks to FlashAttention as s_aux; our Triton path does not support sinks yet. raise NotImplementedError( - "ModelOptSparseAttentionImpl does not support vLLM FlashAttention sinks yet." + f"{new_cls.__name__} does not support vLLM FlashAttention sinks yet." ) try: @@ -310,6 +480,322 @@ def _clone_sparse_impl(old_impl): "Cannot clone vLLM attention impl state: old impl does not expose __dict__." ) from err - new_impl = ModelOptSparseAttentionImpl.__new__(ModelOptSparseAttentionImpl) + new_impl = object.__new__(new_cls) new_impl.__dict__.update(old_state) return new_impl + + +# --------------------------------------------------------------------------- +# FlashInfer backend support +# --------------------------------------------------------------------------- +# FlashInfer's per-step metadata only retains planned wrappers, not the dense +# block_table / seq_lens / query_start_loc the calibration kernel needs. Those +# live on the CommonAttentionMetadata the builder consumes, so we stash them onto +# the produced FlashInferMetadata (``_modelopt_*``) and read them back in +# forward. The KV cache is ``[num_blocks, 2, page_size, num_kv_heads, head_dim]`` +# (``[:, 0]`` = K, ``[:, 1]`` = V); strides are passed through, so this is correct +# for both NHD and HND physical layouts. + +_FLASHINFER_PATCHED = False +_FLASHINFER_IMPL_CLS: type | None = None + + +def patch_flashinfer_metadata_builder() -> bool: + """Stash the dense common metadata onto ``FlashInferMetadata`` at build time. + + Idempotent. Returns ``True`` if the FlashInfer builder is now patched, + ``False`` if the FlashInfer backend is unavailable. + """ + global _FLASHINFER_PATCHED + if _FLASHINFER_PATCHED: + return True + try: + from vllm.v1.attention.backends.flashinfer import FlashInferMetadataBuilder + except ImportError: + return False + + orig_build = FlashInferMetadataBuilder.build + # Locate ``common_attn_metadata`` by parameter name so the wrapper is robust + # to the builder's positional signature (this vLLM build is + # ``build(self, common_prefix_len, common_attn_metadata, fast_build=False)``). + # Pass ``*args``/``**kwargs`` straight through to avoid re-binding positional + # args (re-passing common_attn_metadata first collided with common_prefix_len). + build_sig = inspect.signature(orig_build) + + @functools.wraps(orig_build) + def build(*args, **kwargs): + metadata = orig_build(*args, **kwargs) + common = build_sig.bind(*args, **kwargs).arguments["common_attn_metadata"] + metadata._modelopt_block_table = common.block_table_tensor + metadata._modelopt_seq_lens = common.seq_lens + metadata._modelopt_query_start_loc = common.query_start_loc + metadata._modelopt_num_actual_tokens = common.num_actual_tokens + metadata._modelopt_max_query_len = common.max_query_len + metadata._modelopt_max_seq_len = common.max_seq_len + return metadata + + FlashInferMetadataBuilder.build = build + _FLASHINFER_PATCHED = True + return True + + +def get_flashinfer_sparse_impl_cls() -> type: + """Build (once) and return ``ModelOptSparseFlashInferImpl``. + + Defined lazily so importing this module does not require the FlashInfer + backend (and its ``flashinfer`` dependency) to be installed. + """ + global _FLASHINFER_IMPL_CLS + if _FLASHINFER_IMPL_CLS is not None: + return _FLASHINFER_IMPL_CLS + + from vllm.v1.attention.backends.flashinfer import FlashInferImpl + + class ModelOptSparseFlashInferImpl(_SparseCalibrationMixin, FlashInferImpl): + """FlashInfer attention impl with ModelOpt skip-softmax calibration + serving. + + With the dense paged metadata stashed by + ``patch_flashinfer_metadata_builder`` available, it either: + + - **calibration mode** (``enable_calibration``): measures multi-threshold + tile-skip stats over the paged cache via the Triton calibration kernel + (dense output), or + - **inference**: runs the ModelOpt sparse Triton kernel for sparse prefill + launches, reading FlashInfer's ``[num_blocks, 2, page, ...]`` cache + (``[:, 0]`` = K, ``[:, 1]`` = V). + + Profiling (``attn_metadata is None``), cascade, an unpatched builder, or a + launch with no active sparse feature fall back to native FlashInfer + (``super().forward``) — mirroring ``ModelOptSparseAttentionImpl``. + """ + + def forward( + self, + layer, + query, + key, + value, + kv_cache, + attn_metadata, + output=None, + output_scale=None, + output_block_scale=None, + ): + """Calibrate / sparse-serve via the Triton kernel; delegate otherwise.""" + + def dense(): + return FlashInferImpl.forward( + self, + layer, + query, + key, + value, + kv_cache, + attn_metadata, + output, + output_scale, + output_block_scale, + ) + + # Native FlashInfer for profiling, cascade, or an unpatched builder + # (the dense paged metadata the Triton kernel needs is unavailable). + if ( + attn_metadata is None + or getattr(attn_metadata, "use_cascade", False) + or not hasattr(attn_metadata, "_modelopt_block_table") + ): + return dense() + + assert output is not None, "Output tensor must be provided." + key_cache = kv_cache[:, 0] + value_cache = kv_cache[:, 1] + page_size = key_cache.shape[1] + seq_lens = attn_metadata._modelopt_seq_lens + cu_seqlens_q = attn_metadata._modelopt_query_start_loc + batch = seq_lens.shape[0] + b_start_loc = cu_seqlens_q[:batch] + b_seq_len = cu_seqlens_q[1 : batch + 1] - cu_seqlens_q[:batch] + block_table = attn_metadata._modelopt_block_table + num_actual_tokens = attn_metadata._modelopt_num_actual_tokens + + if getattr(self, "_calibrate", False) and getattr( + self, "_calib_threshold_trials", None + ): + return self._forward_calibrate( + query=query, + key_cache=key_cache, + value_cache=value_cache, + page_size=page_size, + b_start_loc=b_start_loc, + b_seq_len=b_seq_len, + seq_lens=seq_lens, + block_table=block_table, + num_actual_tokens=num_actual_tokens, + output=output, + ) + + max_query_len = attn_metadata._modelopt_max_query_len + is_decode_only = max_query_len <= 1 + return self._forward_sparse( + query=query, + key_cache=key_cache, + value_cache=value_cache, + page_size=page_size, + b_start_loc=b_start_loc, + b_seq_len=b_seq_len, + seq_lens=seq_lens, + block_table=block_table, + num_actual_tokens=num_actual_tokens, + max_query_len=max_query_len, + max_seq_len=attn_metadata._modelopt_max_seq_len, + is_decode_only=is_decode_only, + is_causal=not is_decode_only, + output=output, + dense_fallback=dense, + ) + + _FLASHINFER_IMPL_CLS = ModelOptSparseFlashInferImpl + return _FLASHINFER_IMPL_CLS + + +def select_sparse_impl_cls(impl) -> type | None: + """Return the ModelOpt sparse impl class for a vLLM attention impl's backend. + + ``None`` if ``impl`` is already a ModelOpt sparse impl or its backend is + unsupported. For FlashInfer it also installs the metadata-builder patch that + exposes the dense paged metadata the Triton kernel needs. Used by both the + serving and calibration workers to swap the right impl per attention layer. + """ + if isinstance(impl, _SparseCalibrationMixin): + return None # already swapped (idempotent across reloads) + name = type(impl).__name__ + if name == "FlashAttentionImpl": + return ModelOptSparseAttentionImpl + if name == "FlashInferImpl": + return get_flashinfer_sparse_impl_cls() if patch_flashinfer_metadata_builder() else None + return None + + +# --------------------------------------------------------------------------- +# Calibration driver helpers +# --------------------------------------------------------------------------- +# These run skip-softmax calibration *through* the vLLM integration: the model +# is loaded under vLLM with ModelOptSparseAttentionImpl on each attention layer +# (see examples/vllm_serve/sparse_attn_worker.py), calibration mode is turned on, +# a few prompts are generated, and the collected per-threshold tile-skip counts +# are fit to the same exponential model (a, b) the HF path produces — so the +# result drops straight into the existing export/inference path. + + +def iter_sparse_impls(model): + """Yield every ModelOpt sparse attention impl reachable from a vLLM model. + + Walks ``model.named_modules()`` and returns the swapped ``impl`` of each + attention layer (any backend — FlashAttention, FlashInfer — that mixes in + ``_SparseCalibrationMixin``). Used by the calibration driver to toggle + calibration mode and harvest stats without knowing vLLM's module layout. + """ + for _, module in model.named_modules(): + impl = getattr(module, "impl", None) + if isinstance(impl, _SparseCalibrationMixin): + yield impl + + +def enable_calibration(impls, threshold_trials: list[float]) -> None: + """Put a set of sparse impls into calibration mode and clear prior records.""" + if not threshold_trials: + raise ValueError("threshold_trials must be a non-empty list for calibration.") + for impl in impls: + impl._calibrate = True + impl._calib_threshold_trials = list(threshold_trials) + impl._calib_records = [] + + +def disable_calibration(impls) -> None: + """Turn off calibration mode (collected records are left intact).""" + for impl in impls: + impl._calibrate = False + + +def collect_calibration_stats(impls) -> dict[str, list[dict]]: + """Aggregate per-request records into per-sample stats, matching the HF path. + + Mirrors :meth:`DynamicThresholdCalibrator._extract_calibration_stats`: the + per-threshold sparsity is **averaged across layers** for each sample, yielding + one record per sample (not per ``(layer, sample)``). During calibration every + attention layer processes the same launches in the same order, so each layer's + ``_calib_records`` are aligned by index — the k-th record of every layer is the + same ``(launch, request)`` sample. Records are grouped by phase first, so + prefill and decode samples aggregate separately; chunked prefill is supported + (each chunk launch is its own sample, exactly as in HF chunked calibration). + + Returns ``{"prefill": [...], "decode": [...]}`` where each entry is a + ``{"sample_length", "sparsity"}`` record ready for + :meth:`DynamicThresholdCalibrator.calibrate_from_stats`. + """ + # Per phase, gather each layer's ordered record list. + per_phase_layers: dict[str, list[list[dict]]] = {"prefill": [], "decode": []} + for impl in impls: + split: dict[str, list[dict]] = {"prefill": [], "decode": []} + for record in getattr(impl, "_calib_records", []): + split.setdefault(record["phase"], []).append(record) + for phase, records in split.items(): + if records: + per_phase_layers.setdefault(phase, []).append(records) + + out: dict[str, list[dict]] = {"prefill": [], "decode": []} + for phase, layer_lists in per_phase_layers.items(): + if not layer_lists: + continue + # Align by sample index across layers; guard against ragged layers. + num_samples = min(len(records) for records in layer_lists) + for i in range(num_samples): + per_layer_sparsity = [records[i]["sparsity"] for records in layer_lists] + num_thresholds = len(per_layer_sparsity[0]) + avg_sparsity = [ + sum(s[t] for s in per_layer_sparsity) / len(per_layer_sparsity) + for t in range(num_thresholds) + ] + out.setdefault(phase, []).append( + { + "sparsity": avg_sparsity, + "sample_length": layer_lists[0][i]["sample_length"], + } + ) + return out + + +def fit_calibration( + impls, + threshold_trials: list[float], + *, + fit_logspace: bool = False, +) -> dict[str, dict[str, float]]: + """Fit the exponential skip-softmax model from collected vLLM stats. + + Reuses :class:`DynamicThresholdCalibrator` so the vLLM-calibrated ``(a, b)`` + are identical in form to the HF path and export unchanged via + ``threshold_scale_factor``. + + Returns: + ``{phase: {"a", "b", "min_observed_sparsity", "max_observed_sparsity"}}`` + for each phase that produced a valid fit. + """ + from ..calibration.calibrator import DynamicThresholdCalibrator + + per_phase = collect_calibration_stats(impls) + calibration_params: dict[str, dict[str, float]] = {} + for phase, stats in per_phase.items(): + if not stats: + continue + calibrator = DynamicThresholdCalibrator( + threshold_trials=list(threshold_trials), fit_logspace=fit_logspace + ) + result = calibrator.calibrate_from_stats(stats, phase=phase) + if "a" in result and "b" in result: + params = {"a": result["a"], "b": result["b"]} + for key in ("min_observed_sparsity", "max_observed_sparsity"): + if key in result: + params[key] = result[key] + calibration_params[phase] = params + return calibration_params diff --git a/tests/gpu/torch/kernels/conftest.py b/tests/gpu/torch/kernels/conftest.py index fa4f6177143..75b585e604d 100644 --- a/tests/gpu/torch/kernels/conftest.py +++ b/tests/gpu/torch/kernels/conftest.py @@ -36,6 +36,39 @@ def make_varlen_meta(seq_lens, device="cuda"): return b_start_loc, b_seq_len +def scatter_to_paged_cache(k, v, b_start_loc, b_seq_len, num_kv_heads, head_dim, page_size): + """Scatter contiguous K/V into a paged KV cache + block table. + + Returns ``(k_cache, v_cache, block_table)`` where the caches are shaped + ``[num_blocks, page_size, num_kv_heads, head_dim]`` and ``block_table`` is + ``[batch, max_blocks_per_seq]`` — the layout the paged Triton kernels read. + """ + batch = b_seq_len.shape[0] + device, dtype = k.device, k.dtype + + blocks_per_seq = [(int(b_seq_len[b].item()) + page_size - 1) // page_size for b in range(batch)] + num_blocks = sum(blocks_per_seq) + max_blocks = max(blocks_per_seq) + + k_cache = torch.zeros(num_blocks, page_size, num_kv_heads, head_dim, device=device, dtype=dtype) + v_cache = torch.zeros_like(k_cache) + block_table = torch.zeros(batch, max_blocks, device=device, dtype=torch.int32) + + g = 0 + for b in range(batch): + start = int(b_start_loc[b].item()) + slen = int(b_seq_len[b].item()) + for blk in range(blocks_per_seq[b]): + block_table[b, blk] = g + ts = blk * page_size + te = min(ts + page_size, slen) + n = te - ts + k_cache[g, :n] = k[start + ts : start + te] + v_cache[g, :n] = v[start + ts : start + te] + g += 1 + return k_cache, v_cache, block_table + + def sdpa_reference(q, k, v, b_start_loc, b_seq_len, is_causal=True): """SDPA reference. Supports GQA. Returns [total_tokens, num_heads, dim].""" batch = b_seq_len.shape[0] diff --git a/tests/gpu/torch/kernels/sparsity/attention/test_triton_fa_calibrate.py b/tests/gpu/torch/kernels/sparsity/attention/test_triton_fa_calibrate.py index fe16559a187..8cb8774f24d 100644 --- a/tests/gpu/torch/kernels/sparsity/attention/test_triton_fa_calibrate.py +++ b/tests/gpu/torch/kernels/sparsity/attention/test_triton_fa_calibrate.py @@ -25,7 +25,7 @@ import pytest import torch -from conftest import make_qkv, make_varlen_meta +from conftest import make_qkv, make_varlen_meta, scatter_to_paged_cache pytestmark = [ pytest.mark.filterwarnings("ignore::UserWarning"), @@ -132,6 +132,43 @@ def test_different_seq_q_seq_k(self): assert out.shape == q.shape assert counters.shape == (2, 2) + def test_decode_skips_padding_rows(self): + """Decode (seq_q=1) skips real KV tiles once padding Q rows are excluded. + + With BLOCK_M=128, 127/128 query rows are padding. Before the padding-row + fix their ~0 gap forced zero skips; after it the largest threshold skips a + meaningful number of KV tiles. + """ + seq_q, seq_k, num_heads, head_dim = 1, 512, 4, 64 + scale = 1.0 / (head_dim**0.5) + torch.manual_seed(0) + q = torch.randn(seq_q, num_heads, head_dim, device="cuda", dtype=torch.float16) + k = torch.randn(seq_k, num_heads, head_dim, device="cuda", dtype=torch.float16) + v = torch.randn(seq_k, num_heads, head_dim, device="cuda", dtype=torch.float16) + b_start_loc = torch.zeros(1, device="cuda", dtype=torch.int32) + b_seq_len = torch.ones(1, device="cuda", dtype=torch.int32) + b_start_loc_k = torch.zeros(1, device="cuda", dtype=torch.int32) + b_seq_len_k = torch.full((1,), seq_k, device="cuda", dtype=torch.int32) + + _, counters = attention_calibrate( + q, + k, + v, + b_start_loc, + b_seq_len, + seq_q, + softmax_scale=scale, + is_causal=False, + b_start_loc_k=b_start_loc_k, + b_seq_len_k=b_seq_len_k, + max_input_len_k=seq_k, + threshold_trials=[1e-2, 1e-1, 5e-1, 9e-1], + ) + skipped = counters[:, 1] + assert (skipped[1:] >= skipped[:-1]).all() # monotonic non-decreasing + assert (skipped <= counters[:, 0]).all() + assert skipped[-1] > 0 # padding-row fix makes this non-zero + def test_threshold_order_doesnt_affect_counts(self): """Skipped counts at the same threshold are independent of trial ordering.""" q, k, v, locs, lens = self._make_inputs() @@ -208,6 +245,385 @@ def test_threshold_semantics_match_runtime_counts(self): assert counters[0, 1].item() == out._sparsity_skipped +@pytest.mark.skipif(not TRITON_KERNEL_AVAILABLE, reason="Need CUDA + triton") +class TestAttentionCalibratePaged: + """Paged KV cache calibration must match the contiguous reference exactly. + + This is the path the vLLM integration calibrates through: KV lives in a + paged cache addressed by a block table rather than in contiguous tensors. + """ + + def test_prefill_paged_matches_contiguous(self): + """Causal prefill: paged counters and output equal the contiguous run.""" + seq, num_heads, num_kv_heads, head_dim, page_size = 384, 4, 2, 64, 16 + scale = 1.0 / (head_dim**0.5) + trials = [1e-4, 1e-3, 1e-2, 1e-1, 5e-1, 9e-1] + + torch.manual_seed(0) + # A dominant sink at position 0 (q·k[0] huge, all other scores ~0) makes + # later KV tiles negligible, so later query tiles skip them — gives nonzero + # counters to compare, beyond the trivially-equal all-dense case. + q = torch.ones(seq, num_heads, head_dim, device="cuda", dtype=torch.float16) + k = torch.zeros(seq, num_kv_heads, head_dim, device="cuda", dtype=torch.float16) + k[0] = 20.0 + v = torch.randn(seq, num_kv_heads, head_dim, device="cuda", dtype=torch.float16) + locs, lens = make_varlen_meta([seq]) + + out_ref, c_ref = attention_calibrate( + q, k, v, locs, lens, seq, softmax_scale=scale, is_causal=True, threshold_trials=trials + ) + + k_cache, v_cache, block_table = scatter_to_paged_cache( + k, v, locs, lens, num_kv_heads, head_dim, page_size + ) + k_dummy = torch.empty(0, num_kv_heads, head_dim, device="cuda", dtype=torch.float16) + out_pg, c_pg = attention_calibrate( + q, + k_dummy, + k_dummy, + locs, + lens, + seq, + softmax_scale=scale, + is_causal=True, + b_seq_len_k=lens, + max_input_len_k=seq, + threshold_trials=trials, + k_cache=k_cache, + v_cache=v_cache, + block_table=block_table, + page_size=page_size, + ) + assert torch.equal(c_ref, c_pg), (c_ref.tolist(), c_pg.tolist()) + assert c_pg[-1, 1] > 0 # the sink makes some tiles skippable + torch.testing.assert_close(out_pg, out_ref, rtol=5e-3, atol=5e-3) + + def test_decode_paged_matches_contiguous(self): + """Decode (seq_q=1) against a long paged cache equals the contiguous run.""" + seq_k, num_heads, num_kv_heads, head_dim, page_size = 2048, 4, 2, 64, 16 + scale = 1.0 / (head_dim**0.5) + trials = [1e-3, 1e-2, 1e-1, 5e-1, 9e-1] + + torch.manual_seed(1) + q = torch.randn(1, num_heads, head_dim, device="cuda", dtype=torch.float16) + k = torch.randn(seq_k, num_kv_heads, head_dim, device="cuda", dtype=torch.float16) + v = torch.randn(seq_k, num_kv_heads, head_dim, device="cuda", dtype=torch.float16) + locs_q = torch.zeros(1, device="cuda", dtype=torch.int32) + len_q = torch.ones(1, device="cuda", dtype=torch.int32) + locs_k = torch.zeros(1, device="cuda", dtype=torch.int32) + len_k = torch.full((1,), seq_k, device="cuda", dtype=torch.int32) + + out_ref, c_ref = attention_calibrate( + q, + k, + v, + locs_q, + len_q, + 1, + softmax_scale=scale, + is_causal=False, + b_start_loc_k=locs_k, + b_seq_len_k=len_k, + max_input_len_k=seq_k, + threshold_trials=trials, + ) + + k_cache, v_cache, block_table = scatter_to_paged_cache( + k, v, locs_k, len_k, num_kv_heads, head_dim, page_size + ) + k_dummy = torch.empty(0, num_kv_heads, head_dim, device="cuda", dtype=torch.float16) + out_pg, c_pg = attention_calibrate( + q, + k_dummy, + k_dummy, + locs_q, + len_q, + 1, + softmax_scale=scale, + is_causal=False, + b_seq_len_k=len_k, + max_input_len_k=seq_k, + threshold_trials=trials, + k_cache=k_cache, + v_cache=v_cache, + block_table=block_table, + page_size=page_size, + ) + assert torch.equal(c_ref, c_pg), (c_ref.tolist(), c_pg.tolist()) + # Full cache scanned: total == num_heads * ceil(seq_k / 128). + assert int(c_pg[0, 0]) == num_heads * (seq_k // 128) + torch.testing.assert_close(out_pg, out_ref, rtol=5e-3, atol=5e-3) + + def test_paged_requires_block_table(self): + """Passing a cache without a block table is a hard error, not a silent run.""" + q, k, v = make_qkv(256, 4, 4, 64, dtype=torch.float16) + locs, lens = make_varlen_meta([256]) + k_cache, v_cache, _ = scatter_to_paged_cache(k, v, locs, lens, 4, 64, 16) + with pytest.raises(ValueError, match="block_table"): + attention_calibrate( + q, + k, + v, + locs, + lens, + 256, + softmax_scale=1.0 / (64**0.5), + is_causal=False, + threshold_trials=[1e-2], + k_cache=k_cache, + v_cache=v_cache, + ) + + +@pytest.mark.skipif(not TRITON_KERNEL_AVAILABLE, reason="Need CUDA + triton") +class TestCalibrateVsPytorchReference: + """The Triton calibration kernel must measure the same sparsity as PyTorch. + + ``attention_calibrate`` (contiguous and paged) and the PyTorch + ``flash_skip_softmax`` calibration both use 128x128 block-level skip logic + (keep a block iff some query row's block-max stays within ``log(threshold)`` + of the running max). This is the contract that lets vLLM calibration produce + the same ``(a, b)`` as the established PyTorch path — assert the per-threshold + skipped-tile fractions agree on identical inputs. + """ + + _TRIALS = [1e-3, 1e-2, 5e-2, 1e-1, 3e-1, 5e-1, 7e-1, 9e-1] + + @staticmethod + def _pytorch_sparsity(q4, k4, v4, trials, scale, is_causal): + """Per-threshold skipped-block fraction from PyTorch flash_skip_softmax.""" + from modelopt.torch.sparsity.attention_sparsity.methods.flash_skip_softmax import ( + FlashSkipSoftmax, + ) + + seq_q, seq_k = q4.shape[2], k4.shape[2] + scores = torch.matmul(q4, k4.transpose(-2, -1)) * scale + if is_causal: + causal_mask = torch.triu(torch.ones(seq_q, seq_k, device=q4.device), diagonal=1).bool() + scores = scores.masked_fill(causal_mask[None, None], float("-inf")) + method = FlashSkipSoftmax( + method_config={ + "thresholds": {"prefill": trials, "decode": trials}, + "br": 128, + "bc": 128, + "backend": "pytorch", + "is_causal": is_causal, + } + ) + method._calibration_mode = True + method.thresholds = trials + _, stats = method.calc_correction_factor_and_p(scores, "prefill" if seq_q > 1 else "decode") + return stats["sparsity"] + + @staticmethod + def _triton_sparsity(counters): + return (counters[:, 1].float() / counters[:, 0].clamp(min=1)).tolist() + + @staticmethod + def _graded_qkv(seq, num_heads, head_dim, seed): + """Localized-decay attention (sink + distance decay) -> graded sparsity.""" + torch.manual_seed(seed) + q4 = torch.randn(1, num_heads, seq, head_dim, device="cuda", dtype=torch.float16) + pos = torch.arange(seq, device="cuda").float() + decay = torch.exp(-pos / (seq * 0.15))[None, None, :, None] + k4 = (torch.randn(1, num_heads, seq, head_dim, device="cuda") * decay).to(torch.float16) + k4[:, :, 0] = 8.0 + v4 = torch.randn(1, num_heads, seq, head_dim, device="cuda", dtype=torch.float16) + return q4, k4, v4 + + def _triton_paged_sparsity(self, q4, k4, v4, trials, scale): + seq, num_heads, head_dim = q4.shape[2], q4.shape[1], q4.shape[3] + qf = q4.permute(0, 2, 1, 3).reshape(seq, num_heads, head_dim).contiguous() + kf = k4.permute(0, 2, 1, 3).reshape(seq, num_heads, head_dim).contiguous() + vf = v4.permute(0, 2, 1, 3).reshape(seq, num_heads, head_dim).contiguous() + locs, lens = make_varlen_meta([seq]) + k_cache, v_cache, block_table = scatter_to_paged_cache( + kf, vf, locs, lens, num_heads, head_dim, 16 + ) + k_dummy = torch.empty(0, num_heads, head_dim, device="cuda", dtype=torch.float16) + _, counters = attention_calibrate( + qf, + k_dummy, + k_dummy, + locs, + lens, + seq, + softmax_scale=scale, + is_causal=True, + b_seq_len_k=lens, + max_input_len_k=seq, + threshold_trials=trials, + k_cache=k_cache, + v_cache=v_cache, + block_table=block_table, + page_size=16, + ) + return self._triton_sparsity(counters) + + def test_fitted_ab_matches_pytorch(self): + """End-to-end: the fitted exponential (a, b) is the same for both paths. + + Measures per-length sparsity over several lengths with PyTorch + flash_skip_softmax and with the paged (vLLM) Triton kernel, fits each set + through DynamicThresholdCalibrator, and asserts the calibration results + (a, b) agree — the property that lets vLLM-calibrated checkpoints serve + identically to HF-calibrated ones. + """ + from modelopt.torch.sparsity.attention_sparsity.calibration.calibrator import ( + DynamicThresholdCalibrator, + ) + + num_heads, head_dim = 4, 64 + scale = 1.0 / (head_dim**0.5) + trials = [1e-3, 3e-3, 1e-2, 3e-2, 5e-2, 1e-1, 2e-1, 3e-1, 5e-1, 7e-1, 9e-1] + + pt_stats, triton_stats = [], [] + # Non-128-multiple lengths so the partial last block-row exercises the + # padding-row handling (where flash previously diverged from the kernel). + for seed, seq in enumerate([500, 776, 1000, 1500, 2000]): + q4, k4, v4 = self._graded_qkv(seq, num_heads, head_dim, seed) + pt_stats.append( + { + "sparsity": self._pytorch_sparsity(q4, k4, v4, trials, scale, True), + "sample_length": seq, + } + ) + triton_stats.append( + { + "sparsity": self._triton_paged_sparsity(q4, k4, v4, trials, scale), + "sample_length": seq, + } + ) + + pt_fit = DynamicThresholdCalibrator(threshold_trials=trials).calibrate_from_stats( + pt_stats, "prefill" + ) + triton_fit = DynamicThresholdCalibrator(threshold_trials=trials).calibrate_from_stats( + triton_stats, "prefill" + ) + + # Both fits must succeed and agree (same measured sparsity -> same fit). + assert pt_fit and triton_fit, (pt_fit, triton_fit) + assert pt_fit["a"] == pytest.approx(triton_fit["a"], rel=1e-3) + assert pt_fit["b"] == pytest.approx(triton_fit["b"], rel=1e-3) + # Sanity: a real (non-degenerate) exponential fit on enough valid points. + assert pt_fit["a"] > 0 and pt_fit["b"] > 0 + assert pt_fit["num_data_points"] >= 10 + + def test_graded_prefill_matches_pytorch(self): + """Localized-decay attention sweeps sparsity 0->~0.7; all three paths agree. + + Compares PyTorch flash_skip_softmax, contiguous ``attention_calibrate``, + and the paged (vLLM) ``attention_calibrate`` — the graded sweep makes this + a discriminating test rather than a trivially-0% / 100% one. + + ``seq`` is deliberately *not* a multiple of the 128 block size: the last + query-block-row is partial, exercising the padding-row handling where the + flash block method previously diverged from the kernel (it counted + dtype-min-padded rows as "keep" and never skipped the last block row). + """ + num_heads, head_dim, seq, page_size = 4, 64, 1000, 16 + scale = 1.0 / (head_dim**0.5) + + torch.manual_seed(2) + q4 = torch.randn(1, num_heads, seq, head_dim, device="cuda", dtype=torch.float16) + # Key norm decays with position (+ a sink at 0) so distant tiles fall below + # the threshold gradually as it grows — a smooth sparsity sweep. + pos = torch.arange(seq, device="cuda").float() + decay = torch.exp(-pos / (seq * 0.15))[None, None, :, None] + k4 = (torch.randn(1, num_heads, seq, head_dim, device="cuda") * decay).to(torch.float16) + k4[:, :, 0] = 8.0 + v4 = torch.randn(1, num_heads, seq, head_dim, device="cuda", dtype=torch.float16) + + pt = self._pytorch_sparsity(q4, k4, v4, self._TRIALS, scale, is_causal=True) + + qf = q4.permute(0, 2, 1, 3).reshape(seq, num_heads, head_dim).contiguous() + kf = k4.permute(0, 2, 1, 3).reshape(seq, num_heads, head_dim).contiguous() + vf = v4.permute(0, 2, 1, 3).reshape(seq, num_heads, head_dim).contiguous() + locs, lens = make_varlen_meta([seq]) + + _, c_contig = attention_calibrate( + qf, + kf, + vf, + locs, + lens, + seq, + softmax_scale=scale, + is_causal=True, + threshold_trials=self._TRIALS, + ) + + k_cache, v_cache, block_table = scatter_to_paged_cache( + kf, vf, locs, lens, num_heads, head_dim, page_size + ) + k_dummy = torch.empty(0, num_heads, head_dim, device="cuda", dtype=torch.float16) + _, c_paged = attention_calibrate( + qf, + k_dummy, + k_dummy, + locs, + lens, + seq, + softmax_scale=scale, + is_causal=True, + b_seq_len_k=lens, + max_input_len_k=seq, + threshold_trials=self._TRIALS, + k_cache=k_cache, + v_cache=v_cache, + block_table=block_table, + page_size=page_size, + ) + + triton_contig = self._triton_sparsity(c_contig) + triton_paged = self._triton_sparsity(c_paged) + + # The sweep must actually exercise the intermediate (fit-relevant) range. + assert any(0.1 < s < 0.9 for s in pt), pt + # Paged is the vLLM path; it must equal the contiguous kernel exactly. + assert triton_paged == triton_contig, (triton_paged, triton_contig) + # Triton (both layouts) matches PyTorch flash_skip_softmax block-for-block. + for s_pt, s_tr in zip(pt, triton_contig): + assert abs(s_pt - s_tr) <= 0.02, (pt, triton_contig) + + def test_dominant_sink_matches_pytorch_exactly(self): + """A dominant sink puts gaps far from any threshold boundary -> exact match.""" + num_heads, head_dim, seq = 4, 64, 512 + scale = 1.0 / (head_dim**0.5) + + torch.manual_seed(0) + q4 = torch.ones(1, num_heads, seq, head_dim, device="cuda", dtype=torch.float16) + k4 = torch.zeros(1, num_heads, seq, head_dim, device="cuda", dtype=torch.float16) + k4[:, :, 0] = 20.0 + v4 = torch.randn(1, num_heads, seq, head_dim, device="cuda", dtype=torch.float16) + + pt = self._pytorch_sparsity(q4, k4, v4, self._TRIALS, scale, is_causal=True) + + qf = q4.permute(0, 2, 1, 3).reshape(seq, num_heads, head_dim).contiguous() + kf = k4.permute(0, 2, 1, 3).reshape(seq, num_heads, head_dim).contiguous() + vf = v4.permute(0, 2, 1, 3).reshape(seq, num_heads, head_dim).contiguous() + locs, lens = make_varlen_meta([seq]) + _, counters = attention_calibrate( + qf, + kf, + vf, + locs, + lens, + seq, + softmax_scale=scale, + is_causal=True, + threshold_trials=self._TRIALS, + ) + triton = self._triton_sparsity(counters) + assert max(pt) > 0.0 # the sink makes blocks skippable + # Gaps are far from any threshold boundary, so the skipped-block counts + # are identical; only the fraction's fp repr differs (fp64 vs fp32), so a + # single-block disagreement (>= 1/40 = 0.025 here) would still fail. + for s_pt, s_tr in zip(pt, triton): + assert abs(s_pt - s_tr) < 1e-5, (pt, triton) + + @pytest.mark.skipif(not TRITON_KERNEL_AVAILABLE, reason="Need CUDA + triton") class TestMeasureSparsity: """Runtime sparsity counters during inference.""" @@ -282,7 +698,9 @@ def test_first_measured_call_has_real_tile_count_with_autotune(self): assert result.returncode == 0, result.stderr totals = [line for line in result.stdout.splitlines() if line.startswith("TOTAL=")] assert totals, result.stdout - assert int(totals[-1].split("=", maxsplit=1)[1]) == 8 + # seq_len=256, _MEASURE_BLOCK_M = _MEASURE_BLOCK_N = 128, non-causal: + # Q tiles = ceil(256/128) = 2, KV tiles = ceil(256/128) = 2, total = 4. + assert int(totals[-1].split("=", maxsplit=1)[1]) == 4 def test_measure_sparsity_without_skip_is_noop(self): """Without skip-softmax, measure_sparsity doesn't attach counters.""" diff --git a/tests/gpu/torch/sparsity/attention_sparsity/test_triton_calibration_gpu.py b/tests/gpu/torch/sparsity/attention_sparsity/test_triton_calibration_gpu.py new file mode 100644 index 00000000000..949e67b2cd8 --- /dev/null +++ b/tests/gpu/torch/sparsity/attention_sparsity/test_triton_calibration_gpu.py @@ -0,0 +1,181 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""GPU tests for skip-softmax calibration via the Triton backend on HF models. + +These exercise the HuggingFace (``modelopt_triton``) wiring that routes the +calibration forward pass through the fused ``attention_calibrate`` kernel and +feeds the collected multi-threshold tile-skip statistics into the same +exponential-model fit used by the PyTorch path. +""" + +import copy +import itertools + +import pytest +import torch +from _test_utils.torch.transformers_models import create_tiny_llama_dir +from transformers import AutoModelForCausalLM + +import modelopt.torch.sparsity.attention_sparsity as mtsa +from modelopt.torch.kernels.common.attention import IS_AVAILABLE as TRITON_KERNEL_AVAILABLE +from modelopt.torch.kernels.common.attention.hf_triton_attention import triton_attention_forward +from modelopt.torch.sparsity.attention_sparsity.config import SKIP_SOFTMAX_TRITON_CALIB +from modelopt.torch.sparsity.attention_sparsity.methods.triton_skip_softmax import ( + TritonSkipSoftmaxMethod, +) + +pytestmark = [ + pytest.mark.filterwarnings("ignore::UserWarning"), + pytest.mark.filterwarnings("ignore::RuntimeWarning"), +] + +THRESHOLD_TRIALS = [1e-6, 1e-5, 1e-4, 1e-3, 1e-2, 5e-2, 1e-1, 3e-1, 5e-1, 7e-1, 9e-1] + + +@pytest.fixture(scope="module") +def tiny_llama_dir(tmp_path_factory): + """Create a minimal Llama model directory.""" + return create_tiny_llama_dir( + tmp_path_factory.mktemp("tiny_llama_triton_calib"), + num_hidden_layers=2, + hidden_size=64, + intermediate_size=128, + num_attention_heads=4, + num_key_value_heads=2, + max_position_embeddings=1024, + ) + + +def _load_eager(tiny_llama_dir): + return AutoModelForCausalLM.from_pretrained( + tiny_llama_dir, attn_implementation="eager", device_map="cuda" + ) + + +def _make_forward_loop(vocab_size, lengths=(128, 256, 384, 512)): + """Forward loop that runs several full-prefill passes of varying length. + + Each pass triggers one ``attention_calibrate`` call per layer, producing one + per-sample calibration record per length. + """ + + def forward_loop(model): + torch.manual_seed(0) + for seq_len in lengths: + input_ids = torch.randint(0, vocab_size, (1, seq_len), device="cuda") + with torch.no_grad(): + model(input_ids, use_cache=False) + + return forward_loop + + +def _calibration_module(threshold_trials): + """Build a bare module whose ``_sparse_method_instance`` is in calibration mode. + + The HF backend reads its calibration config from (and writes counters back + to) ``module._sparse_method_instance``, so this is the minimal stand-in for + driving ``triton_attention_forward`` through the calibration branch. + """ + method = TritonSkipSoftmaxMethod() + method.set_calibration_mode(True) + method._threshold_trials = threshold_trials + + module = torch.nn.Module() + module._sparse_method_instance = method + return module + + +@pytest.mark.skipif(not TRITON_KERNEL_AVAILABLE, reason="Need CUDA + triton") +class TestTritonCalibrationHF: + """End-to-end calibration via the Triton backend on a tiny HF model.""" + + def test_calibrated_model_inference(self, tiny_llama_dir): + """SKIP_SOFTMAX_TRITON_CALIB dispatches to the Triton backend and the + calibrated model runs inference cleanly.""" + model = _load_eager(tiny_llama_dir) + config = copy.deepcopy(SKIP_SOFTMAX_TRITON_CALIB) + # Prefill-only (custom forward_loop can't drive RULER decode calibration). + config["sparse_cfg"]["calibration"]["target_sparse_ratio"] = {"prefill": 0.5} + + forward_loop = _make_forward_loop(model.config.vocab_size) + sparse_model = mtsa.sparsify(model, config, forward_loop=forward_loop) + assert sparse_model.config._attn_implementation == "modelopt_triton" + + sparse_model.eval() + input_ids = torch.randint(0, model.config.vocab_size, (1, 64), device="cuda") + with torch.no_grad(): + out = sparse_model(input_ids, use_cache=False) + assert out.logits is not None + assert not torch.isnan(out.logits).any() + + def test_decode_branch_reports_decode_phase(self): + """The HF calibration branch routes decode-shaped calls through the kernel + and surfaces its counters as a ``decode``-phase stats record. + + This is the HF-only counter path in ``_collect_calibration_stats``; the + kernel's skip-count behavior itself is covered in the kernel test suite. + """ + num_heads, seq_k, head_dim = 4, 512, 64 + torch.manual_seed(0) + q = torch.randn(1, num_heads, 1, head_dim, device="cuda", dtype=torch.float16) + k = torch.randn(1, num_heads, seq_k, head_dim, device="cuda", dtype=torch.float16) + v = torch.randn(1, num_heads, seq_k, head_dim, device="cuda", dtype=torch.float16) + + module = _calibration_module(THRESHOLD_TRIALS) + method = module._sparse_method_instance + triton_attention_forward(module, q, k, v, attention_mask=None, scaling=1.0 / head_dim**0.5) + assert method._hf_calibration_is_decode is True + assert method._hf_calibration_counters is not None + + method._collect_calibration_stats(module) + assert module._last_stats["phase"] == "decode" + assert len(module._last_stats["sparsity"]) == len(THRESHOLD_TRIALS) + + def test_decode_calibration_measures_full_cache_with_sink(self): + """Decode calibration must scan the whole KV cache and report real sparsity. + + A dominant sink at position 0 makes the distant KV tiles negligible, so a + correct decode measurement skips almost all of them. This guards the two + decode bugs that random inputs don't expose: + * causal-offset ``kv_bound`` — without it the loop stops after the first + ``BLOCK_M`` tokens, so ``total`` would be a fraction of the cache. + * padding-row exclusion — without it the 127 padding rows veto every + tile and sparsity is 0%. + """ + num_heads, seq_k, head_dim = 4, 2048, 64 + block_n = 128 # the calibration kernel measures at 128x128 + q = torch.ones(1, num_heads, 1, head_dim, device="cuda", dtype=torch.float16) + k = torch.zeros(1, num_heads, seq_k, head_dim, device="cuda", dtype=torch.float16) + k[:, :, 0] = 20.0 # attention sink dominates every query + v = torch.randn(1, num_heads, seq_k, head_dim, device="cuda", dtype=torch.float16) + + module = _calibration_module(THRESHOLD_TRIALS) + method = module._sparse_method_instance + triton_attention_forward(module, q, k, v, attention_mask=None, scaling=1.0 / head_dim**0.5) + + counters = method._hf_calibration_counters + total = int(counters[0, 0]) + # Full cache scanned (not truncated to the first block). + assert total == num_heads * (seq_k // block_n), total + sparsity = (counters[:, 1].float() / counters[:, 0].clamp(min=1)).tolist() + # Sink => the vast majority of tiles are negligible and skippable (not 0%). + assert max(sparsity) > 0.8, sparsity + # Skipped-tile fraction is non-decreasing as the threshold grows. + assert all(later >= earlier for earlier, later in itertools.pairwise(sparsity)), sparsity + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/gpu/torch/sparsity/attention_sparsity/test_vllm_plugin.py b/tests/gpu/torch/sparsity/attention_sparsity/test_vllm_plugin.py index 4c029da751c..6369ad1fcd6 100644 --- a/tests/gpu/torch/sparsity/attention_sparsity/test_vllm_plugin.py +++ b/tests/gpu/torch/sparsity/attention_sparsity/test_vllm_plugin.py @@ -36,7 +36,15 @@ from vllm.v1.attention.backends.flash_attn import FlashAttentionImpl from modelopt.torch.kernels.common.attention import IS_AVAILABLE as TRITON_KERNEL_AVAILABLE -from modelopt.torch.sparsity.attention_sparsity.plugins.vllm import ModelOptSparseAttentionImpl +from modelopt.torch.sparsity.attention_sparsity.plugins.vllm import ( + ModelOptSparseAttentionImpl, + collect_calibration_stats, + disable_calibration, + enable_calibration, + fit_calibration, + get_flashinfer_sparse_impl_cls, + patch_flashinfer_metadata_builder, +) if TRITON_KERNEL_AVAILABLE: from modelopt.torch.kernels.common.attention import attention as triton_attention @@ -374,3 +382,559 @@ def test_page_size_inferred_from_k_cache(self): output=output, ) torch.testing.assert_close(out_paged, out_ref, rtol=1e-2, atol=1e-2) + + +@pytest.mark.skipif(not TRITON_KERNEL_AVAILABLE, reason="Need CUDA + triton") +class TestModelOptSparseAttentionCalibration: + """Calibration mode: ``forward`` measures tile-skip stats via the paged kernel. + + Output must stay dense (calibration computes full attention) while per-request + records accumulate, ready to fit the exponential ``(a, b)`` model. + """ + + _TRIALS = [1e-4, 1e-3, 1e-2, 1e-1, 5e-1, 9e-1] + + def test_prefill_calibration_records_and_dense_output(self): + """Prefill: output equals dense attention; one record per request.""" + lengths = [128, 256] + total = sum(lengths) + num_heads, num_kv_heads, head_dim, page_size = 4, 2, 64, 16 + dtype = torch.float16 + + torch.manual_seed(0) + q = torch.randn(total, num_heads, head_dim, device="cuda", dtype=dtype) + k = torch.randn(total, num_kv_heads, head_dim, device="cuda", dtype=dtype) + v = torch.randn(total, num_kv_heads, head_dim, device="cuda", dtype=dtype) + + seq_lens = torch.tensor(lengths, device="cuda", dtype=torch.int32) + query_start_loc = torch.tensor([0, lengths[0], total], device="cuda", dtype=torch.int32) + kv_cache, block_table = _make_paged_cache( + k, v, query_start_loc[:2], seq_lens, num_kv_heads, head_dim, page_size + ) + attn_metadata = SimpleNamespace( + num_actual_tokens=total, + max_query_len=max(lengths), + max_seq_len=max(lengths), + query_start_loc=query_start_loc, + seq_lens=seq_lens, + block_table=block_table, + ) + + impl = _make_impl(num_heads, head_dim, num_kv_heads) + impl.sparse_kw = {} + enable_calibration([impl], self._TRIALS) + output = torch.empty_like(q) + out = impl.forward( + layer=None, + query=q, + key=k, + value=v, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + output=output, + ) + + # Output is dense per-request causal attention (full attention, no skip). + for i, length in enumerate(lengths): + start = int(query_start_loc[i].item()) + locs = torch.zeros(1, device="cuda", dtype=torch.int32) + lens = torch.tensor([length], device="cuda", dtype=torch.int32) + ref = triton_attention( + q[start : start + length], + k[start : start + length], + v[start : start + length], + locs, + lens, + length, + softmax_scale=1.0 / (head_dim**0.5), + is_causal=True, + ) + torch.testing.assert_close(out[start : start + length], ref, rtol=5e-3, atol=5e-3) + + stats = collect_calibration_stats([impl]) + assert len(stats["prefill"]) == len(lengths) + assert [r["sample_length"] for r in stats["prefill"]] == lengths + assert all(len(r["sparsity"]) == len(self._TRIALS) for r in stats["prefill"]) + assert not stats["decode"] + + def test_decode_calibration_records_decode_phase(self): + """Decode (seq_q=1, long cache): a decode record with real sparsity.""" + seq_k = 2048 + num_heads, num_kv_heads, head_dim, page_size = 4, 2, 64, 16 + dtype = torch.float16 + + # A dominant sink at position 0 makes the distant cache skippable. + q = torch.ones(1, num_heads, head_dim, device="cuda", dtype=dtype) + k = torch.zeros(seq_k, num_kv_heads, head_dim, device="cuda", dtype=dtype) + k[0] = 20.0 + v = torch.randn(seq_k, num_kv_heads, head_dim, device="cuda", dtype=dtype) + kv_start = torch.zeros(1, device="cuda", dtype=torch.int32) + kv_len = torch.tensor([seq_k], device="cuda", dtype=torch.int32) + kv_cache, block_table = _make_paged_cache( + k, v, kv_start, kv_len, num_kv_heads, head_dim, page_size + ) + attn_metadata = SimpleNamespace( + num_actual_tokens=1, + max_query_len=1, + max_seq_len=seq_k, + query_start_loc=torch.tensor([0, 1], device="cuda", dtype=torch.int32), + seq_lens=kv_len, + block_table=block_table, + ) + + impl = _make_impl(num_heads, head_dim, num_kv_heads) + impl.sparse_kw = {} + enable_calibration([impl], self._TRIALS) + impl.forward( + layer=None, + query=q, + key=q, + value=q, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + output=torch.empty_like(q), + ) + + stats = collect_calibration_stats([impl]) + assert not stats["prefill"] + assert len(stats["decode"]) == 1 + record = stats["decode"][0] + assert record["sample_length"] == seq_k + assert max(record["sparsity"]) > 0.8 # sink => most tiles skippable + + _FIT_TRIALS = [1e-4, 1e-3, 1e-2, 1e-1, 3e-1, 5e-1, 7e-1, 9e-1] + + def test_fit_calibration_produces_exponential_params(self): + """Multiple lengths feed fit_calibration into a usable (a, b) per phase.""" + num_heads, num_kv_heads, head_dim, page_size = 4, 2, 64, 16 + dtype = torch.float16 + impl = _make_impl(num_heads, head_dim, num_kv_heads) + impl.sparse_kw = {} + enable_calibration([impl], self._FIT_TRIALS) + + torch.manual_seed(0) + for length in (512, 1024, 2048): + # Localized attention (key norm decays with distance + a sink at 0) + # gives a graded skip sweep across thresholds, so the exponential + # fit's (10%, 90%) window has enough data points — unlike uniform + # random keys, which skip all-or-nothing. + q = torch.randn(length, num_heads, head_dim, device="cuda", dtype=dtype) + pos = torch.arange(length, device="cuda").float() + decay = torch.exp(-pos / (length * 0.15))[:, None, None] + k = (torch.randn(length, num_kv_heads, head_dim, device="cuda") * decay).to(dtype) + k[0] = 8.0 # sink + v = torch.randn(length, num_kv_heads, head_dim, device="cuda", dtype=dtype) + locs = torch.zeros(1, device="cuda", dtype=torch.int32) + lens = torch.tensor([length], device="cuda", dtype=torch.int32) + qsl = torch.tensor([0, length], device="cuda", dtype=torch.int32) + kv_cache, block_table = _make_paged_cache( + k, v, locs, lens, num_kv_heads, head_dim, page_size + ) + attn_metadata = SimpleNamespace( + num_actual_tokens=length, + max_query_len=length, + max_seq_len=length, + query_start_loc=qsl, + seq_lens=lens, + block_table=block_table, + ) + impl.forward( + layer=None, + query=q, + key=k, + value=v, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + output=torch.empty_like(q), + ) + + disable_calibration([impl]) + params = fit_calibration([impl], self._FIT_TRIALS) + assert "prefill" in params + assert params["prefill"]["a"] > 0.0 + assert params["prefill"]["b"] > 0.0 + + +# FlashInfer backend import requires the `flashinfer` package; skip if absent. +flashinfer_backend = pytest.importorskip("vllm.v1.attention.backends.flashinfer") +FlashInferImpl = flashinfer_backend.FlashInferImpl + + +def _make_flashinfer_impl(num_heads, head_dim, num_kv_heads): + """Build a bare ModelOptSparseFlashInferImpl with the attrs forward() reads. + + Bypasses FlashInferImpl.__init__ (which needs a full vLLM config); the + calibration path only consults scale / num_kv_heads / head_size. + """ + cls = get_flashinfer_sparse_impl_cls() + impl = cls.__new__(cls) + impl.scale = 1.0 / (head_dim**0.5) + impl.num_kv_heads = num_kv_heads + impl.head_size = head_dim + return impl + + +@pytest.mark.skipif(not TRITON_KERNEL_AVAILABLE, reason="Need CUDA + triton") +class TestModelOptSparseFlashInferCalibration: + """FlashInfer-backend calibration must match the FlashAttention path. + + Same logical K/V, different cache layout: FlashAttention is ``[2, ...]``; + FlashInfer is ``[num_blocks, 2, page_size, ...]`` with the dense paged + metadata stashed on the metadata object (not a `block_table` field). The + per-request measurement is shared, so the records must be identical. + """ + + _TRIALS = [1e-4, 1e-3, 1e-2, 1e-1, 5e-1, 9e-1] + + @staticmethod + def _to_flashinfer_cache(fa_kv_cache): + """[2, num_blocks, page, H, D] -> FlashInfer [num_blocks, 2, page, H, D].""" + return fa_kv_cache.transpose(0, 1).contiguous() + + def test_matches_flashattention_prefill(self): + """FlashInfer and FlashAttention calibration give identical records + output.""" + lengths = [128, 256] + total = sum(lengths) + num_heads, num_kv_heads, head_dim, page_size = 4, 2, 64, 16 + dtype = torch.float16 + + torch.manual_seed(0) + q = torch.randn(total, num_heads, head_dim, device="cuda", dtype=dtype) + k = torch.randn(total, num_kv_heads, head_dim, device="cuda", dtype=dtype) + v = torch.randn(total, num_kv_heads, head_dim, device="cuda", dtype=dtype) + seq_lens = torch.tensor(lengths, device="cuda", dtype=torch.int32) + qsl = torch.tensor([0, lengths[0], total], device="cuda", dtype=torch.int32) + fa_kv, block_table = _make_paged_cache( + k, v, qsl[:2], seq_lens, num_kv_heads, head_dim, page_size + ) + + # FlashAttention reference. + fa_impl = _make_impl(num_heads, head_dim, num_kv_heads) + fa_impl.sparse_kw = {} + enable_calibration([fa_impl], self._TRIALS) + out_fa = torch.empty_like(q) + fa_impl.forward( + layer=None, + query=q, + key=k, + value=v, + kv_cache=fa_kv, + attn_metadata=SimpleNamespace( + num_actual_tokens=total, + max_query_len=max(lengths), + max_seq_len=max(lengths), + query_start_loc=qsl, + seq_lens=seq_lens, + block_table=block_table, + ), + output=out_fa, + ) + + # FlashInfer: [num_blocks, 2, page, H, D] cache + stashed dense metadata. + fi_impl = _make_flashinfer_impl(num_heads, head_dim, num_kv_heads) + enable_calibration([fi_impl], self._TRIALS) + out_fi = torch.empty_like(q) + fi_impl.forward( + layer=None, + query=q, + key=k, + value=v, + kv_cache=self._to_flashinfer_cache(fa_kv), + attn_metadata=SimpleNamespace( + use_cascade=False, + _modelopt_block_table=block_table, + _modelopt_seq_lens=seq_lens, + _modelopt_query_start_loc=qsl, + _modelopt_num_actual_tokens=total, + ), + output=out_fi, + ) + + torch.testing.assert_close(out_fi, out_fa, rtol=5e-3, atol=5e-3) + assert collect_calibration_stats([fi_impl]) == collect_calibration_stats([fa_impl]) + + def test_sparse_inference_matches_flashattention(self): + """FlashInfer sparse *serving* output equals the FlashAttention path. + + Not calibrating: with ``sparse_kw`` set, the FlashInfer impl runs the + ModelOpt sparse Triton kernel over its ``[num_blocks, 2, page, ...]`` + cache and must produce the same output as ``ModelOptSparseAttentionImpl`` + on the same logical K/V. + """ + batch, seq_len = 2, 64 + num_heads, num_kv_heads, head_dim, page_size = 4, 2, 64, 16 + total = batch * seq_len + dtype = torch.float16 + + torch.manual_seed(0) + q = torch.randn(total, num_heads, head_dim, device="cuda", dtype=dtype) + k = torch.randn(total, num_kv_heads, head_dim, device="cuda", dtype=dtype) + v = torch.randn(total, num_kv_heads, head_dim, device="cuda", dtype=dtype) + seq_lens = torch.tensor([seq_len, seq_len], device="cuda", dtype=torch.int32) + qsl = torch.tensor([0, seq_len, total], device="cuda", dtype=torch.int32) + fa_kv, block_table = _make_paged_cache( + k, v, qsl[:batch], seq_lens, num_kv_heads, head_dim, page_size + ) + + # FlashAttention sparse inference (reference, already == contiguous Triton). + fa = _make_impl(num_heads, head_dim, num_kv_heads) + fa.sparse_kw = _ACTIVE_PREFILL_SPARSE_KW + out_fa = torch.empty_like(q) + fa.forward( + layer=None, + query=q, + key=k, + value=v, + kv_cache=fa_kv, + attn_metadata=SimpleNamespace( + num_actual_tokens=total, + max_query_len=seq_len, + max_seq_len=seq_len, + query_start_loc=qsl, + seq_lens=seq_lens, + block_table=block_table, + ), + output=out_fa, + ) + + # FlashInfer sparse inference (not calibrating; sparse_kw active). + fi = _make_flashinfer_impl(num_heads, head_dim, num_kv_heads) + fi.sparse_kw = _ACTIVE_PREFILL_SPARSE_KW + out_fi = torch.empty_like(q) + fi.forward( + layer=None, + query=q, + key=k, + value=v, + kv_cache=self._to_flashinfer_cache(fa_kv), + attn_metadata=SimpleNamespace( + use_cascade=False, + _modelopt_block_table=block_table, + _modelopt_seq_lens=seq_lens, + _modelopt_query_start_loc=qsl, + _modelopt_num_actual_tokens=total, + _modelopt_max_query_len=seq_len, + _modelopt_max_seq_len=seq_len, + ), + output=out_fi, + ) + + torch.testing.assert_close(out_fi, out_fa, rtol=5e-3, atol=5e-3) + + def test_decode_records_decode_phase(self): + """FlashInfer decode (seq_q=1, long cache) yields a decode record.""" + seq_k = 2048 + num_heads, num_kv_heads, head_dim, page_size = 4, 2, 64, 16 + dtype = torch.float16 + + q = torch.ones(1, num_heads, head_dim, device="cuda", dtype=dtype) + k = torch.zeros(seq_k, num_kv_heads, head_dim, device="cuda", dtype=dtype) + k[0] = 20.0 + v = torch.randn(seq_k, num_kv_heads, head_dim, device="cuda", dtype=dtype) + kv_start = torch.zeros(1, device="cuda", dtype=torch.int32) + kv_len = torch.tensor([seq_k], device="cuda", dtype=torch.int32) + fa_kv, block_table = _make_paged_cache( + k, v, kv_start, kv_len, num_kv_heads, head_dim, page_size + ) + + fi_impl = _make_flashinfer_impl(num_heads, head_dim, num_kv_heads) + enable_calibration([fi_impl], self._TRIALS) + fi_impl.forward( + layer=None, + query=q, + key=q, + value=q, + kv_cache=self._to_flashinfer_cache(fa_kv), + attn_metadata=SimpleNamespace( + use_cascade=False, + _modelopt_block_table=block_table, + _modelopt_seq_lens=kv_len, + _modelopt_query_start_loc=torch.tensor([0, 1], device="cuda", dtype=torch.int32), + _modelopt_num_actual_tokens=1, + ), + output=torch.empty_like(q), + ) + + stats = collect_calibration_stats([fi_impl]) + assert not stats["prefill"] + assert len(stats["decode"]) == 1 + assert stats["decode"][0]["sample_length"] == seq_k + assert max(stats["decode"][0]["sparsity"]) > 0.8 + + def test_delegates_to_flashinfer_when_not_calibrating(self, monkeypatch): + """Outside calibration (and for cascade), forward defers to native FlashInfer.""" + called = {} + + def fake_forward(self, layer, query, key, value, kv_cache, attn_metadata, *a, **k): + called["n"] = called.get("n", 0) + 1 + return query + + monkeypatch.setattr(FlashInferImpl, "forward", fake_forward) + impl = _make_flashinfer_impl(2, 64, 2) + q = torch.zeros(1, 2, 64, device="cuda", dtype=torch.float16) + + # (a) not calibrating -> delegate + impl.forward(None, q, q, q, torch.empty(0), SimpleNamespace(use_cascade=False)) + # (b) calibrating but cascade -> delegate + enable_calibration([impl], self._TRIALS) + impl.forward(None, q, q, q, torch.empty(0), SimpleNamespace(use_cascade=True)) + # (c) calibrating but builder not patched (no stashed metadata) -> delegate + impl.forward(None, q, q, q, torch.empty(0), SimpleNamespace(use_cascade=False)) + assert called["n"] == 3 + + def test_metadata_builder_patch_is_idempotent(self): + """patch_flashinfer_metadata_builder is safe to call repeatedly.""" + assert patch_flashinfer_metadata_builder() is True + assert patch_flashinfer_metadata_builder() is True + + _FIT_TRIALS = [1e-3, 3e-3, 1e-2, 3e-2, 5e-2, 1e-1, 2e-1, 3e-1, 5e-1, 7e-1, 9e-1] + + def _pytorch_sparsity(self, q4, k4, trials, scale): + """Per-threshold skipped-block fraction from PyTorch flash_skip_softmax.""" + from modelopt.torch.sparsity.attention_sparsity.methods.flash_skip_softmax import ( + FlashSkipSoftmax, + ) + + seq_q, seq_k = q4.shape[2], k4.shape[2] + scores = torch.matmul(q4, k4.transpose(-2, -1)) * scale + scores = scores.masked_fill( + torch.triu(torch.ones(seq_q, seq_k, device="cuda"), 1).bool()[None, None], + float("-inf"), + ) + method = FlashSkipSoftmax( + method_config={ + "thresholds": {"prefill": trials, "decode": trials}, + "br": 128, + "bc": 128, + "backend": "pytorch", + "is_causal": True, + } + ) + method._calibration_mode = True + method.thresholds = trials + return method.calc_correction_factor_and_p(scores, "prefill")[1]["sparsity"] + + def _flashinfer_sparsity(self, impl, q4, k4, v4, page_size): + """Per-threshold sparsity from ModelOptSparseFlashInferImpl.forward.""" + seq, num_heads, head_dim = q4.shape[2], q4.shape[1], q4.shape[3] + qf = q4.permute(0, 2, 1, 3).reshape(seq, num_heads, head_dim).contiguous() + kf = k4.permute(0, 2, 1, 3).reshape(seq, num_heads, head_dim).contiguous() + vf = v4.permute(0, 2, 1, 3).reshape(seq, num_heads, head_dim).contiguous() + locs = torch.zeros(1, device="cuda", dtype=torch.int32) + lens = torch.tensor([seq], device="cuda", dtype=torch.int32) + fa_kv, block_table = _make_paged_cache(kf, vf, locs, lens, num_heads, head_dim, page_size) + impl._calib_records = [] + impl.forward( + layer=None, + query=qf, + key=qf, + value=qf, + kv_cache=self._to_flashinfer_cache(fa_kv), + attn_metadata=SimpleNamespace( + use_cascade=False, + _modelopt_block_table=block_table, + _modelopt_seq_lens=lens, + _modelopt_query_start_loc=torch.tensor([0, seq], device="cuda", dtype=torch.int32), + _modelopt_num_actual_tokens=seq, + ), + output=torch.empty_like(qf), + ) + return impl._calib_records[-1]["sparsity"] + + def test_matches_pytorch_calibration(self): + """FlashInfer-layout calibration == PyTorch flash_skip_softmax, incl. fitted (a, b). + + Direct (not transitive) check: per-threshold sparsity measured through + ``ModelOptSparseFlashInferImpl.forward`` over a FlashInfer-layout cache, + and the exponential ``(a, b)`` fit, match the PyTorch path on identical + graded-attention inputs. + """ + from modelopt.torch.sparsity.attention_sparsity.calibration.calibrator import ( + DynamicThresholdCalibrator, + ) + + num_heads, head_dim, page_size = 4, 64, 16 # MHA (num_kv_heads == num_heads) + scale = 1.0 / (head_dim**0.5) + impl = _make_flashinfer_impl(num_heads, head_dim, num_heads) + enable_calibration([impl], self._FIT_TRIALS) + + pt_stats, fi_stats = [], [] + # Non-128-multiple lengths exercise the partial last block-row (padding + # rows), where flash_skip_softmax previously diverged from the kernel. + for seed, seq in enumerate((500, 1000, 2000)): + torch.manual_seed(seed) + # Localized-decay attention -> graded sparsity spanning the fit window. + q4 = torch.randn(1, num_heads, seq, head_dim, device="cuda", dtype=torch.float16) + pos = torch.arange(seq, device="cuda").float() + decay = torch.exp(-pos / (seq * 0.15))[None, None, :, None] + k4 = (torch.randn(1, num_heads, seq, head_dim, device="cuda") * decay).to(torch.float16) + k4[:, :, 0] = 8.0 # sink + v4 = torch.randn(1, num_heads, seq, head_dim, device="cuda", dtype=torch.float16) + + pt = self._pytorch_sparsity(q4, k4, self._FIT_TRIALS, scale) + fi = self._flashinfer_sparsity(impl, q4, k4, v4, page_size) + # Same 128x128 block skip logic. Differences are at most ~1 block out + # of hundreds, where an fp16 (Triton) vs fp64 (PyTorch) score lands on + # a threshold boundary — far below a real divergence. (Same tolerance + # as the kernel-level cross-check test_graded_prefill_matches_pytorch.) + for s_pt, s_fi in zip(pt, fi): + assert abs(s_pt - s_fi) <= 0.02, (seq, pt, fi) + pt_stats.append({"sparsity": pt, "sample_length": seq}) + fi_stats.append({"sparsity": fi, "sample_length": seq}) + + # The sweep must actually reach the fit-relevant (10%, 90%) window. + assert any(0.1 < s < 0.9 for record in pt_stats for s in record["sparsity"]) + + pt_fit = DynamicThresholdCalibrator(threshold_trials=self._FIT_TRIALS).calibrate_from_stats( + pt_stats, "prefill" + ) + fi_fit = DynamicThresholdCalibrator(threshold_trials=self._FIT_TRIALS).calibrate_from_stats( + fi_stats, "prefill" + ) + assert pt_fit and fi_fit, (pt_fit, fi_fit) + # Fitted (a, b) agree to well under 1% (measured ~6e-4); the loose bound + # absorbs the occasional boundary block flipping across hardware. + assert pt_fit["a"] == pytest.approx(fi_fit["a"], rel=5e-3) + assert pt_fit["b"] == pytest.approx(fi_fit["b"], rel=5e-3) + + +def test_collect_calibration_stats_averages_across_layers(): + """Aggregation averages sparsity across layers per sample, like the HF path. + + This is the alignment that makes vLLM calibration produce the same fitted + (a, b) as ``DynamicThresholdCalibrator._extract_calibration_stats``: one + record per sample with the cross-layer mean, not one record per (layer, + sample). Layers are aligned by record index (every layer sees the same + launches in the same order during calibration). + """ + from types import SimpleNamespace + + layer0 = SimpleNamespace( + _calib_records=[ + {"phase": "prefill", "sample_length": 100, "sparsity": [0.2, 0.4]}, + {"phase": "prefill", "sample_length": 200, "sparsity": [0.6, 0.8]}, + {"phase": "decode", "sample_length": 512, "sparsity": [0.5, 0.9]}, + ] + ) + layer1 = SimpleNamespace( + _calib_records=[ + {"phase": "prefill", "sample_length": 100, "sparsity": [0.4, 0.6]}, + {"phase": "prefill", "sample_length": 200, "sparsity": [0.8, 1.0]}, + {"phase": "decode", "sample_length": 512, "sparsity": [0.7, 0.5]}, + ] + ) + + stats = collect_calibration_stats([layer0, layer1]) + + # One record per sample (not per layer-sample): 2 prefill, 1 decode. + assert len(stats["prefill"]) == 2 + assert len(stats["decode"]) == 1 + # Sparsity is the per-threshold mean across the two layers. + assert stats["prefill"][0]["sparsity"] == pytest.approx([0.3, 0.5]) + assert stats["prefill"][0]["sample_length"] == 100 + assert stats["prefill"][1]["sparsity"] == pytest.approx([0.7, 0.9]) + assert stats["decode"][0]["sparsity"] == pytest.approx([0.6, 0.7]) + # A single layer is the identity (averaging over one layer). + assert collect_calibration_stats([layer0])["prefill"][0]["sparsity"] == pytest.approx( + [0.2, 0.4] + ) diff --git a/tests/unit/torch/sparsity/attention_sparsity/test_calibrator_fitting.py b/tests/unit/torch/sparsity/attention_sparsity/test_calibrator_fitting.py index c7c6f56928b..1c7d27c404a 100644 --- a/tests/unit/torch/sparsity/attention_sparsity/test_calibrator_fitting.py +++ b/tests/unit/torch/sparsity/attention_sparsity/test_calibrator_fitting.py @@ -155,6 +155,71 @@ def test_calibrate_empty_stats_returns_empty(self): assert result == {} +class TestCalibrateFromStats: + """Backend-agnostic stats->fit path shared by HF and vLLM calibration.""" + + @staticmethod + def _synthetic_stats(trials, lengths, a_true, b_true): + """Build per-sample stats consistent with scale_factor = a * exp(b * S).""" + stats = [] + for length in lengths: + sparsity = [] + for t in trials: + sf = t * length + s = np.log(sf / a_true) / b_true if sf > 0 else 0.0 + sparsity.append(max(0.0, min(1.0, s))) + stats.append({"sparsity": sparsity, "sample_length": length}) + return stats + + def test_recovers_known_parameters(self): + """calibrate_from_stats must recover the (a, b) used to synthesize stats.""" + trials = [1e-4, 1e-3, 1e-2, 1e-1, 5e-1, 9e-1] + a_true, b_true = 2.0, 8.0 + stats = self._synthetic_stats(trials, [256, 512, 1024, 2048, 4096], a_true, b_true) + + cal = DynamicThresholdCalibrator(threshold_trials=trials) + result = cal.calibrate_from_stats(stats, phase="decode") + + assert result["phase"] == "decode" + assert result["a"] == pytest.approx(a_true, rel=0.1) + assert result["b"] == pytest.approx(b_true, rel=0.1) + assert result["r_squared"] > 0.99 + + def test_matches_calibrate_forward_loop(self): + """calibrate() and calibrate_from_stats() agree given the same stats.""" + trials = [1e-3, 1e-2, 1e-1, 5e-1, 9e-1] + stats = self._synthetic_stats(trials, [2048, 4096, 8192], a_true=0.1, b_true=5.0) + cal = DynamicThresholdCalibrator(threshold_trials=trials) + result = cal.calibrate_from_stats(stats, phase="prefill") + assert result["a"] == pytest.approx(0.1, rel=0.1) + assert result["b"] == pytest.approx(5.0, rel=0.1) + + def test_too_few_points_returns_empty(self): + """Fewer than 10 (scale_factor, sparsity) pairs yields an empty fit.""" + cal = DynamicThresholdCalibrator(threshold_trials=[0.1]) + result = cal.calibrate_from_stats([{"sparsity": [0.5], "sample_length": 1024}], "prefill") + assert result == {} + + def test_reports_per_sample_sparsity(self, capsys): + """The result exposes (and prints) each sample's measured sparsity.""" + trials = [1e-4, 1e-3, 1e-2, 1e-1, 5e-1, 9e-1] + lengths = [256, 512, 1024, 2048, 4096] + stats = self._synthetic_stats(trials, lengths, a_true=2.0, b_true=8.0) + + result = DynamicThresholdCalibrator(threshold_trials=trials).calibrate_from_stats( + stats, "prefill" + ) + + per_sample = result["per_sample_sparsity"] + assert len(per_sample) == len(lengths) + assert [s["sample_length"] for s in per_sample] == lengths + assert all(len(s["sparsity"]) == len(trials) for s in per_sample) + # Values match the input measurements (not the fitted average). + assert per_sample[0]["sparsity"] == pytest.approx(stats[0]["sparsity"]) + # And a per-sample table is printed. + assert "Per-sample prefill sparsity" in capsys.readouterr().out + + class TestSetThresholds: """Test _set_thresholds for both method types."""