Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
226 changes: 226 additions & 0 deletions examples/vllm_serve/sparse_attn_worker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,226 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Custom vLLM workers for sparse attention.

``SparseAttnWorker``: Replaces ``FlashAttentionImpl`` with
``ModelOptSparseAttentionImpl`` on each Attention module after model loading.
The sparse impl uses the ModelOpt Triton kernel for both prefill and decode.

``SparseQuantWorker``: Applies quantization first, then sparse attention via
direct module walk (registry stacking does not work due to ``_DMRegistryCls``
forward identity check).

Usage:
SPARSE_ATTN_CFG=SPARSE_SOFTMAX_DEFAULT python vllm_serve_sparse_attn.py \\
meta-llama/Llama-3.1-8B --enforce-eager
"""

import fnmatch
import json
import os
from typing import Any

from fakequant_worker import disable_compilation
from vllm.attention.layer import Attention as VLLMAttention
from vllm.v1.worker.gpu_worker import Worker as BaseWorker

import modelopt.torch.sparsity.attention_sparsity as mtsa
from modelopt.torch.sparsity.attention_sparsity.plugins.vllm import ModelOptSparseAttentionImpl

# ---------------------------------------------------------------------------
# Configuration from environment variables
# ---------------------------------------------------------------------------

# Sparse attention settings, read once at import time from the environment:
#   SPARSE_ATTN_CFG         - name of a sparse attention preset (e.g. "SPARSE_SOFTMAX_DEFAULT")
#   SPARSE_CALIB_CONFIG_PATH - path to an offline calibration config JSON
# Either value is None when the corresponding variable is unset.
sparse_config: dict[str, Any] = {
    "sparse_cfg": os.environ.get("SPARSE_ATTN_CFG"),
    "calib_config_path": os.environ.get("SPARSE_CALIB_CONFIG_PATH"),
}


# ---------------------------------------------------------------------------
# Helper functions
# ---------------------------------------------------------------------------


# Built-in fallback config, used when SPARSE_ATTN_CFG is "default" or
# "SPARSE_SOFTMAX_DEFAULT" and no matching preset is exported by mtsa.
_DEFAULT_SPARSE_CFG = {
    "sparse_cfg": {
        # Enable sparsity on every module whose name contains "attn".
        "*attn*": {
            "sparsity_n": 2,  # presumably N of an N:M sparsity pattern — TODO confirm against kernel
            "sparsity_m": 4,  # presumably M of an N:M sparsity pattern — TODO confirm against kernel
            "num_sink_tokens": 0,  # NOTE(review): looks like always-dense sink tokens — confirm
            "dense_window_size": 1,  # NOTE(review): semantics defined by the Triton kernel — confirm
            "enable": True,
        },
        # Modules not matched by any pattern above stay dense.
        "default": {"enable": False},
    },
}


def _build_sparse_config(env_config: dict[str, Any]) -> dict | None:
"""Build sparse_cfg dict from env vars."""
cfg_name = env_config["sparse_cfg"]
if cfg_name is None:
return None
# Try looking up preset from mtsa, fall back to default
cfg = getattr(mtsa, cfg_name, None)
if cfg is not None:
return cfg
Comment on lines +78 to +80
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Validate preset object type before returning it

If SPARSE_ATTN_CFG matches a non-dict symbol in modelopt.torch.sparsity.attention_sparsity, cfg is returned as-is and later consumed as a mapping, which can crash at runtime. Add an explicit isinstance(cfg, dict) guard and fail fast with a clear error.

Proposed fix
     cfg = getattr(mtsa, cfg_name, None)
     if cfg is not None:
-        return cfg
+        if not isinstance(cfg, dict):
+            raise ValueError(
+                f"Invalid sparse config preset '{cfg_name}': expected dict, got {type(cfg).__name__}."
+            )
+        return cfg
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/vllm_serve/sparse_attn_worker.py` around lines 81 - 83, The code
currently returns whatever getattr(mtsa, cfg_name, None) yields (cfg) and may
hand back non-mapping objects; update the getter to validate that cfg is a dict
before returning (use isinstance(cfg, dict)) and otherwise raise a clear error
(e.g., ValueError) indicating that the symbol named by cfg_name in
modelopt.torch.sparsity.attention_sparsity must be a dict; reference the
getattr(mtsa, cfg_name, None) call and the cfg variable to locate the change.

# Use built-in default if name matches
if cfg_name in ("SPARSE_SOFTMAX_DEFAULT", "default"):
return _DEFAULT_SPARSE_CFG
raise ValueError(
f"Unknown sparse config: {cfg_name}. Set SPARSE_ATTN_CFG to 'default' or a valid preset name."
)


def _load_sparse_config(path: str) -> dict:
    """Load and validate an offline calibration config JSON.

    The file comes from the env-driven SPARSE_CALIB_CONFIG_PATH, so its
    structure is validated before the values are handed to the kernels:
    the top level must be an object mapping module-name patterns to layer
    configs, and known sparsity parameters must be non-negative integers.

    Args:
        path: Filesystem path to the calibration JSON.

    Returns:
        A dict of the form ``{"sparse_cfg": {...}}`` with per-layer defaults
        applied and a catch-all ``"default": {"enable": False}`` entry.

    Raises:
        ValueError: If the JSON does not match the expected schema.
    """
    with open(path, encoding="utf-8") as f:
        calib_cfg = json.load(f)
    if not isinstance(calib_cfg, dict):
        raise ValueError("Calibration config must be a JSON object mapping pattern -> layer config.")

    sparse_cfg = {}
    for pattern, layer_cfg in calib_cfg.items():
        if pattern == "calibration":
            # Calibration metadata is passed through untouched.
            sparse_cfg[pattern] = layer_cfg
            continue
        if not isinstance(layer_cfg, dict):
            raise ValueError(f"Layer config for pattern '{pattern}' must be a JSON object.")
        # Bounds-check the integer sparsity parameters before they can reach
        # kernel calls (note: bool is a subclass of int, so reject it too).
        for int_key in ("sparsity_n", "sparsity_m", "num_sink_tokens", "dense_window_size"):
            if int_key in layer_cfg:
                value = layer_cfg[int_key]
                if not isinstance(value, int) or isinstance(value, bool) or value < 0:
                    raise ValueError(f"Invalid '{int_key}' for pattern '{pattern}': {value!r}")
        layer_cfg.setdefault("method", "triton_sparse_softmax")
        layer_cfg.setdefault("backend", "triton")
        layer_cfg.setdefault("enable", True)
        sparse_cfg[pattern] = layer_cfg
    # Layers not matched by any calibrated pattern stay dense.
    sparse_cfg["default"] = {"enable": False}

    return {"sparse_cfg": sparse_cfg}


def _match_sparse_config(module_name: str, sparse_cfg: dict) -> dict | None:
"""Match a module name against sparse_cfg patterns."""
cfg = sparse_cfg.get("sparse_cfg", sparse_cfg)
for pattern, layer_cfg in cfg.items():
if pattern in ("default", "calibration"):
continue
if fnmatch.fnmatch(module_name, pattern):
return layer_cfg
return None


def _replace_attention_impl(worker, config: dict):
    """Replace FlashAttentionImpl with ModelOptSparseAttentionImpl on all Attention layers.

    Shared by SparseAttnWorker and SparseQuantWorker.

    Args:
        worker: vLLM worker whose ``model_runner.model`` holds the loaded model.
        config: Env-derived settings with "calib_config_path" and "sparse_cfg"
            keys (see module-level ``sparse_config``).
    """
    # The offline calibration JSON takes precedence over a named preset.
    if config["calib_config_path"]:
        cfg = _load_sparse_config(config["calib_config_path"])
    else:
        cfg = _build_sparse_config(config)

    if cfg is None:
        return  # No sparse attention requested; leave the model untouched.

    model = worker.model_runner.model
    if hasattr(model, "unwrap"):
        model = model.unwrap()

    patched = 0
    for name, module in model.named_modules():
        if not isinstance(module, VLLMAttention):
            continue

        # Match per-layer sparse config using name-based patterns
        layer_cfg = _match_sparse_config(name, cfg)
        if layer_cfg is None or not layer_cfg.get("enable", True):
            continue

        # Build per-layer sparse kwargs; N:M parameters are only forwarded
        # when sparsity is actually enabled (sparsity_n > 0).
        sparse_kw = {}
        sparsity_n = layer_cfg.get("sparsity_n", 0)
        if sparsity_n > 0:
            sparse_kw["sparsity_n"] = sparsity_n
            sparse_kw["sparsity_m"] = layer_cfg.get("sparsity_m", 4)
            sparse_kw["num_sink_tokens"] = layer_cfg.get("num_sink_tokens", 0)
            sparse_kw["dense_window_size"] = layer_cfg.get("dense_window_size", 1)
        threshold = layer_cfg.get("skip_softmax_threshold")
        # Explicit None check: a threshold of 0/0.0 is a legitimate value and
        # must not be silently dropped by truthiness.
        if threshold is not None:
            sparse_kw["skip_softmax_threshold"] = threshold

        # Replace impl and store per-layer config
        old_impl = module.impl
        new_impl = ModelOptSparseAttentionImpl(
            num_heads=old_impl.num_heads,
            head_size=old_impl.head_size,
            scale=old_impl.scale,
            num_kv_heads=old_impl.num_kv_heads,
            alibi_slopes=old_impl.alibi_slopes,
            sliding_window=None,  # overwritten below
            kv_cache_dtype=old_impl.kv_cache_dtype,
            logits_soft_cap=old_impl.logits_soft_cap,
            attn_type=old_impl.attn_type,
            kv_sharing_target_layer_name=old_impl.kv_sharing_target_layer_name,
        )
        # Copy the already-transformed sliding_window tuple directly,
        # since __init__ transforms int -> (sw-1, 0) and we can't reverse it.
        new_impl.sliding_window = old_impl.sliding_window
        # Store per-layer sparse kwargs on the impl for forward() to read
        new_impl.sparse_kw = sparse_kw
        module.impl = new_impl
        patched += 1
    print(f"[ModelOpt] Sparse attention: replaced impl on {patched} attention layers")


# ---------------------------------------------------------------------------
# Workers
# ---------------------------------------------------------------------------


class SparseAttnWorker(BaseWorker):
    """vLLM worker that uses the ModelOpt sparse attention backend.

    Replaces FlashAttentionImpl with ModelOptSparseAttentionImpl on each
    Attention module right after model loading — before any forward pass
    (including determine_available_memory profiling).
    """

    def load_model(self, *args, **kwargs) -> None:
        """Load model, then replace attention impl with sparse variant."""
        super().load_model(*args, **kwargs)
        # Patch immediately after load so memory profiling and warm-up both
        # run through the sparse impl; config comes from module-level env vars.
        _replace_attention_impl(self, sparse_config)


class SparseQuantWorker(BaseWorker):
    """vLLM worker that applies quantization + sparse attention.

    Quantization uses the standard registry-based ``mtq.quantize()``.
    Sparse attention replaces FlashAttentionImpl with ModelOptSparseAttentionImpl
    (same approach as SparseAttnWorker).
    """

    def load_model(self, *args, **kwargs) -> None:
        """Load model, then replace attention impl with sparse variant."""
        super().load_model(*args, **kwargs)
        _replace_attention_impl(self, sparse_config)

    def compile_or_warm_up_model(self) -> None:
        """Apply quantization before warm-up."""
        # Imported lazily so this worker is usable even when fakequant env
        # vars are unset at module import time.
        from fakequant_worker import _fakequant_run_prolog_worker, quant_config

        model = self.model_runner.model
        if hasattr(model, "unwrap"):
            model = model.unwrap()

        # Quantize with torch.compile disabled so the calibration prolog sees
        # the eager model; only runs when a quant config was requested.
        with disable_compilation(model):
            if quant_config["quant_cfg"] or quant_config["kv_quant_cfg"]:
                _fakequant_run_prolog_worker(self)

        super().compile_or_warm_up_model()
97 changes: 97 additions & 0 deletions examples/vllm_serve/vllm_serve_sparse_attn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Launch vLLM with sparse attention.

Usage:
SPARSE_ATTN_CFG=SPARSE_SOFTMAX_DEFAULT python vllm_serve_sparse_attn.py \\
meta-llama/Llama-3.1-8B --max-model-len 8192

Combined with quantization:
QUANT_CFG=INT8_SMOOTHQUANT_CFG SPARSE_ATTN_CFG=SPARSE_SOFTMAX_DEFAULT \\
python vllm_serve_sparse_attn.py meta-llama/Llama-3.1-8B
"""

import os
import sys
from pathlib import Path

import uvloop
import vllm
from packaging import version
from vllm.entrypoints.openai.api_server import run_server
from vllm.entrypoints.openai.cli_args import make_arg_parser

# FlexibleArgumentParser moved to a new module after vLLM 0.11.0.
vllm_version = version.parse(vllm.__version__)
if vllm_version <= version.parse("0.11.0"):
    from vllm.utils import FlexibleArgumentParser
else:
    from vllm.utils.argparse_utils import FlexibleArgumentParser

# Pass sparse attention env vars to ray workers (if supported by this vLLM version)
additional_env_vars = {
    "SPARSE_ATTN_CFG",
    "SPARSE_CALIB_CONFIG_PATH",
    "QUANT_DATASET",
    "QUANT_CALIB_SIZE",
    "QUANT_CFG",
    "AMAX_FILE_PATH",
    "KV_QUANT_CFG",
}

try:
    # The Ray executor also moved between modules after vLLM 0.11.0.
    if vllm_version <= version.parse("0.11.0"):
        from vllm.executor.ray_distributed_executor import RayDistributedExecutor
    else:
        from vllm.v1.executor.ray_executor import RayDistributedExecutor
    # Forward our config env vars to remote Ray worker processes, when the
    # executor exposes the hook for it.
    if hasattr(RayDistributedExecutor, "ADDITIONAL_ENV_VARS"):
        RayDistributedExecutor.ADDITIONAL_ENV_VARS.update(additional_env_vars)
except ImportError:
    pass  # Ray not installed, single-node only


def main():
    """Launch vLLM with sparse attention worker.

    Parses standard vLLM OpenAI-server CLI args, selects a custom worker class
    based on the SPARSE_*/QUANT_* env vars, and runs the API server.
    """
    parser = FlexibleArgumentParser(description="vLLM model server with sparse attention")
    parser.add_argument("model", type=str, help="The path or name of the model to serve")
    parser = make_arg_parser(parser)

    # Ensure workers can import our custom worker module
    repo_root = str(Path(__file__).resolve().parent)
    if repo_root not in sys.path:
        sys.path.insert(0, repo_root)
    # Use os.pathsep (":" on POSIX, ";" on Windows) and avoid prepending an
    # empty entry (which Python treats as the CWD) when PYTHONPATH is unset.
    existing = os.environ.get("PYTHONPATH", "")
    os.environ["PYTHONPATH"] = (
        f"{existing}{os.pathsep}{repo_root}" if existing else repo_root
    )

    # Select worker based on env vars
    has_quant = os.environ.get("QUANT_CFG") or os.environ.get("KV_QUANT_CFG")
    has_sparse = os.environ.get("SPARSE_ATTN_CFG") or os.environ.get("SPARSE_CALIB_CONFIG_PATH")

    if has_quant and has_sparse:
        worker_cls = "sparse_attn_worker.SparseQuantWorker"
    elif has_sparse:
        worker_cls = "sparse_attn_worker.SparseAttnWorker"
    else:
        print("Warning: No SPARSE_ATTN_CFG or QUANT_CFG set. Running standard vLLM.")
        worker_cls = None

    if worker_cls:
        parser.set_defaults(worker_cls=worker_cls)

    args = parser.parse_args()
    uvloop.run(run_server(args))


# Script entry point: only launch the server when executed directly.
if __name__ == "__main__":
    main()
Loading
Loading