Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 4 additions & 15 deletions tensorrt_llm/_torch/attention_backend/fmha/fallback.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,6 @@
import torch

from tensorrt_llm._torch.attention_backend.interface import AttentionForwardArgs
from tensorrt_llm._torch.attention_backend.sparse.skip_softmax import (
SkipSoftmaxKernelParams,
SkipSoftmaxParams,
)
from tensorrt_llm.bindings.internal import thop

from .interface import Fmha
Expand Down Expand Up @@ -60,18 +56,11 @@ def forward(
forward_args: AttentionForwardArgs,
) -> None:
attn = self.attn
sparse_params = attn.sparse_params
skip_softmax_kernel_params = (
sparse_params.scheduler.get_kernel_params(timestep=forward_args.timestep)
if isinstance(sparse_params, SkipSoftmaxParams)
else SkipSoftmaxKernelParams()
)

# Every kwarg sources from ``attn`` / ``metadata`` / ``forward_args``
# (with ``forward_args.sparse_prediction`` for sparse-attn inputs),
# ``skip_softmax_kernel_params``, or a literal allowlisted in
# ``_THOP_LITERALS``. ``test_attention_op_sync.py`` enforces this
# statically.
# or a literal allowlisted in ``_THOP_LITERALS``.
# ``test_attention_op_sync.py`` enforces this statically.
thop.attention(
q=q,
k=k,
Expand Down Expand Up @@ -146,6 +135,8 @@ def forward(
cross_kv=forward_args.cross_kv,
relative_attention_bias=forward_args.relative_attention_bias,
relative_attention_max_distance=forward_args.relative_attention_max_distance,
skip_softmax_threshold_scale_factor_prefill=forward_args.skip_softmax_kernel_params.threshold_scale_factor_prefill,
skip_softmax_threshold_scale_factor_decode=forward_args.skip_softmax_kernel_params.threshold_scale_factor_decode,
# --- Module config (TrtllmAttention) ---
rotary_inv_freq=attn.rotary_inv_freq,
rotary_cos_sin=attn.rotary_cos_sin,
Expand Down Expand Up @@ -173,8 +164,6 @@ def forward(
v_head_dim=attn.v_head_dim,
rope_append=attn.rope_append,
attention_chunk_size=attn.attention_chunk_size,
skip_softmax_threshold_scale_factor_prefill=skip_softmax_kernel_params.threshold_scale_factor_prefill,
skip_softmax_threshold_scale_factor_decode=skip_softmax_kernel_params.threshold_scale_factor_decode,
skip_softmax_stat=attn.skip_softmax_stat,
# --- Sparse-specific (AttentionForwardArgs.sparse_prediction) ---
sparse_kv_indices=forward_args.sparse_prediction.sparse_kv_indices,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -608,7 +608,7 @@ def _is_supported_with_reason(
) -> Tuple[bool, str]:
is_mla_enable = attn.is_mla_enable
sparse_params = attn.sparse_params
has_skip_softmax = getattr(sparse_params, "algorithm", None) == "skip_softmax"
has_skip_softmax = sparse_params is not None and sparse_params.algorithm == "skip_softmax"
has_sparse_attention = sparse_params is not None and not has_skip_softmax
if (
fwd.sage_attn_num_elts_per_blk_q > 0
Expand Down
4 changes: 3 additions & 1 deletion tensorrt_llm/_torch/attention_backend/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from ..pyexecutor.mamba_cache_manager import BaseMambaCacheManager
from ..pyexecutor.resource_manager import KVCacheManager
from ..utils import get_model_extra_attrs
from .sparse.params import SparseMetadataParams
from .sparse.params import SkipSoftmaxKernelParams, SparseMetadataParams

try:
# Transformers v5
Expand Down Expand Up @@ -890,6 +890,8 @@ class AttentionForwardArgs:

sparse_prediction: SparsePrediction = field(
default_factory=SparsePrediction)
skip_softmax_kernel_params: SkipSoftmaxKernelParams = field(
default_factory=SkipSoftmaxKernelParams)

@property
def mask_type(self) -> int:
Expand Down
18 changes: 17 additions & 1 deletion tensorrt_llm/_torch/attention_backend/sparse/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,13 @@
ModelConfig. They lower through ``to_sparse_params()`` for per-backend runtime
state and ``to_sparse_metadata_params()`` for metadata allocation/update state.
Concrete parameter classes live with their backend implementations.
Concrete sparse algorithm params live with their backend implementations;
shared kernel-facing carriers live here when they are part of the generic
attention-forward contract.
"""

from dataclasses import dataclass


class SparseParams:
"""Base for per AttentionBackend instance sparse runtime parameters.
Expand All @@ -42,3 +46,15 @@ class SparseMetadataParams:
metadata owns batch/runtime buffers rather than per-layer
``AttentionBackend`` behavior.
"""


@dataclass
class SkipSoftmaxKernelParams:
"""Skip-softmax thresholds passed to attention backend kernels."""

# The kernel divides this by the context length to get the skip threshold;
# zero turns skip-softmax off.
threshold_scale_factor_prefill: float = 0.0
# Only autoregressive (LLM) decoding has a decode phase; diffusion and
# visual generation leave this at zero.
threshold_scale_factor_decode: float = 0.0
24 changes: 3 additions & 21 deletions tensorrt_llm/_torch/attention_backend/sparse/skip_softmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,8 @@
The kernel consumes a scalar ``threshold_scale_factor`` (combined with
the sequence length at runtime) to decide which KV blocks to skip.
This module owns the calibration side that produces that scalar from a
semantic ``target_sparsity`` via a formula shipped with the checkpoint
(the formula model and checkpoint-parsing helpers), plus the
kernel-facing :class:`SkipSoftmaxKernelParams` carrier that the shared
``TrtllmAttention`` reads. All of it is shared by the LLM and VisualGen
pipelines.
semantic ``target_sparsity`` via a formula shipped with the checkpoint.
The helpers are shared by the LLM and VisualGen pipelines.
"""

from dataclasses import dataclass, field
Expand All @@ -34,7 +31,7 @@

from tensorrt_llm.llmapi.utils import StrictBaseModel

from .params import SparseParams
from .params import SkipSoftmaxKernelParams, SparseParams

_RESERVED_FORMULA_KEYS = frozenset({"formula", "target_sparsity"})
_SKIP_SOFTMAX_ALGORITHMS = frozenset({"skip_softmax", "softmax_skip"})
Expand Down Expand Up @@ -272,21 +269,6 @@ def skip_softmax_formula_from_ckpt_sparse_attention_config(
return None


@dataclass
class SkipSoftmaxKernelParams:
"""Skip-softmax thresholds passed to attention backend kernels.

Produced by ``SkipSoftmaxScheduler.get_kernel_params()``.
"""

# The kernel divides this by the context length to get the skip threshold;
# zero turns skip-softmax off.
threshold_scale_factor_prefill: float = 0.0
# Only autoregressive (LLM) decoding has a decode phase; diffusion and
# visual generation leave this at zero.
threshold_scale_factor_decode: float = 0.0


class SkipSoftmaxScheduler:
"""Layer runtime scheduler for skip-softmax kernel thresholds."""

Expand Down
35 changes: 19 additions & 16 deletions tensorrt_llm/_torch/attention_backend/sparse/utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
from typing import TYPE_CHECKING, Union
from typing import TYPE_CHECKING, Type, Union

if TYPE_CHECKING:
from tensorrt_llm._torch.attention_backend.interface import AttentionBackend
from tensorrt_llm.llmapi.llm_args import \
SparseAttentionConfig as LlmSparseAttentionConfig
from tensorrt_llm.visual_gen.args import \
SparseAttentionConfig as VisualGenSparseAttentionConfig

from .params import SparseParams

SparseAttentionConfig = Union[LlmSparseAttentionConfig,
VisualGenSparseAttentionConfig]

Expand All @@ -16,7 +19,7 @@


def get_sparse_attn_kv_cache_manager(
sparse_attention_config: "SparseAttentionConfig"):
sparse_attention_config: "SparseAttentionConfig") -> type:
from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager

from .deepseek_v4 import DeepseekV4CacheManager
Expand All @@ -40,36 +43,36 @@ def get_sparse_attn_kv_cache_manager(


def get_vanilla_sparse_attn_attention_backend(
sparse_attention_config: "SparseAttentionConfig"):
sparse_params: "SparseParams") -> Type["AttentionBackend"]:
from .minimax_m3 import get_minimax_m3_attention_backend_cls
from .rocket import RocketVanillaAttention
if sparse_attention_config.algorithm == "rocket":
if sparse_params.algorithm == "rocket":
return RocketVanillaAttention
elif sparse_attention_config.algorithm == "minimax_m3":
elif sparse_params.algorithm == "minimax_m3":
return get_minimax_m3_attention_backend_cls()
else:
raise ValueError(
f"Unsupported sparse attention algorithm in vanilla attention backend: {sparse_attention_config.algorithm}"
f"Unsupported sparse attention algorithm in vanilla attention backend: {sparse_params.algorithm}"
)


def get_trtllm_sparse_attn_attention_backend(
sparse_attention_config: "SparseAttentionConfig"):
sparse_params: "SparseParams") -> Type["AttentionBackend"]:
from tensorrt_llm._torch.attention_backend.trtllm import TrtllmAttention

from .deepseek_v4 import DeepseekV4TrtllmAttention
from .dsa import DSATrtllmAttention
from .minimax_m3 import get_minimax_m3_attention_backend_cls
from .rocket import RocketTrtllmAttention
if sparse_attention_config.algorithm == "rocket":
if sparse_params.algorithm == "rocket":
return RocketTrtllmAttention
elif sparse_attention_config.algorithm == "dsa":
elif sparse_params.algorithm == "dsa":
return DSATrtllmAttention
elif sparse_attention_config.algorithm == "deepseek_v4":
elif sparse_params.algorithm == "deepseek_v4":
return DeepseekV4TrtllmAttention
elif sparse_attention_config.algorithm == "skip_softmax":
elif sparse_params.algorithm == "skip_softmax":
return TrtllmAttention
elif sparse_attention_config.algorithm == "minimax_m3":
elif sparse_params.algorithm == "minimax_m3":
# The MiniMax-M3 sparse algorithm runs in Python through the
# model-layer override; this backend exists so the standard
# `create_attention(...)` dispatch in `Attention.__init__`
Expand All @@ -78,15 +81,15 @@ def get_trtllm_sparse_attn_attention_backend(
return get_minimax_m3_attention_backend_cls()
else:
raise ValueError(
f"Unsupported sparse attention algorithm in trtllm attention backend: {sparse_attention_config.algorithm}"
f"Unsupported sparse attention algorithm in trtllm attention backend: {sparse_params.algorithm}"
)


def get_flashinfer_sparse_attn_attention_backend(
sparse_attention_config: "SparseAttentionConfig"):
sparse_params: "SparseParams") -> Type["AttentionBackend"]:
from .minimax_m3 import get_minimax_m3_attention_backend_cls
if sparse_attention_config.algorithm == "minimax_m3":
if sparse_params.algorithm == "minimax_m3":
return get_minimax_m3_attention_backend_cls()
raise ValueError(
f"Unsupported sparse attention algorithm in flashinfer attention backend: {sparse_attention_config.algorithm}"
f"Unsupported sparse attention algorithm in flashinfer attention backend: {sparse_params.algorithm}"
)
12 changes: 12 additions & 0 deletions tensorrt_llm/_torch/attention_backend/trtllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -1471,6 +1471,12 @@ def forward(
# Cross-attention uses the THOP path; the trtllm-gen backend API does
# not carry encoder K/V tensors yet.

# cpp/tensorrt_llm/thop/attentionOp.cpp enables mFP8ContextFMHA for an
# FP8 KV cache only when use_paged_context_fmha is true. Force paged
# context so QKV preprocessing and context FMHA use the FP8 path.
if self.has_fp8_kv_cache:
metadata.use_paged_context_fmha = True

# SM90 forces `use_paged_context_fmha` on for correctness
# (https://nvbugs/5624818).
if get_sm_version() == 90:
Expand Down Expand Up @@ -1703,6 +1709,12 @@ def forward(
if forward_args.kv_scale_quant_orig is None:
forward_args.kv_scale_quant_orig = self.kv_scale_quant_orig

sparse_params = self.sparse_params
if isinstance(sparse_params, SkipSoftmaxParams):
forward_args.skip_softmax_kernel_params = (
sparse_params.scheduler.get_kernel_params(
timestep=forward_args.timestep))

# max_context_q_len_override is only set when encoder CUDA graphs are enabled.
if metadata.max_context_q_len_override is not None:
assert metadata.is_cuda_graph
Expand Down
38 changes: 10 additions & 28 deletions tensorrt_llm/_torch/attention_backend/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Optional, Sequence, Type, Union
from typing import Optional, Sequence, Type

import torch

Expand All @@ -14,36 +14,24 @@
from .trtllm import TrtllmAttention
from .vanilla import VanillaAttention

if TYPE_CHECKING:
from tensorrt_llm.llmapi.llm_args import \
SparseAttentionConfig as LlmSparseAttentionConfig
from tensorrt_llm.visual_gen.args import \
SparseAttentionConfig as VisualGenSparseAttentionConfig

SparseAttentionConfig = Union[LlmSparseAttentionConfig,
VisualGenSparseAttentionConfig]


def get_attention_backend(
backend_name: str,
sparse_attention_config: Optional["SparseAttentionConfig"] = None,
sparse_params: Optional[SparseParams] = None,
) -> Type[AttentionBackend]:
backend_name = backend_name.upper()
if backend_name == "VANILLA":
if sparse_attention_config is not None:
return get_vanilla_sparse_attn_attention_backend(
sparse_attention_config)
if sparse_params is not None:
return get_vanilla_sparse_attn_attention_backend(sparse_params)
return VanillaAttention
elif backend_name == "TRTLLM":
if sparse_attention_config is not None:
return get_trtllm_sparse_attn_attention_backend(
sparse_attention_config)
if sparse_params is not None:
return get_trtllm_sparse_attn_attention_backend(sparse_params)
return TrtllmAttention
elif backend_name == "FLASHINFER" and IS_FLASHINFER_AVAILABLE:
from .flashinfer import FlashInferAttention
if sparse_attention_config is not None:
return get_flashinfer_sparse_attn_attention_backend(
sparse_attention_config)
if sparse_params is not None:
return get_flashinfer_sparse_attn_attention_backend(sparse_params)
return FlashInferAttention
elif backend_name == "FLASHINFER_STAR_ATTENTION" and IS_FLASHINFER_AVAILABLE:
from .star_flashinfer import StarAttention
Expand Down Expand Up @@ -73,19 +61,14 @@ def create_attention(
predicted_tokens_per_seq: Optional[int] = 1,
skip_create_weights_in_init: bool = False,
attention_chunk_size: Optional[int] = None,
attn_cls: Optional[Type[AttentionBackend]] = None,
sparse_params: Optional[SparseParams] = None,
dtype: Optional[torch.dtype] = None,
aux_stream: Optional[torch.cuda.Stream] = None,
):
if attention_chunk_size is not None and backend_name.upper() != "TRTLLM":
raise ValueError(
f"Backend {backend_name} does not support chunked attention.")
if sparse_params is not None and attn_cls is None:
raise ValueError("attn_cls is required when sparse_params is set.")

if attn_cls is None:
attn_cls = get_attention_backend(backend_name)
attn_cls = get_attention_backend(backend_name, sparse_params=sparse_params)

if is_mla_enable:
assert attn_cls.support_mla(
Expand Down Expand Up @@ -114,9 +97,8 @@ def create_attention(
attention_chunk_size=attention_chunk_size,
dtype=dtype,
aux_stream=aux_stream,
sparse_params=sparse_params,
)
if sparse_params is not None:
kwargs["sparse_params"] = sparse_params

return attn_cls(
layer_idx,
Expand Down
21 changes: 3 additions & 18 deletions tensorrt_llm/_torch/modules/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,19 +26,6 @@
from .rotary_embedding import MRotaryEmbedding, RotaryEmbedding


def _lower_sparse_attention_params(sparse_attn_cfg,
pretrained_config=None,
layer_idx: Optional[int] = None):
if getattr(sparse_attn_cfg, "algorithm", None) == "deepseek_v4":
from tensorrt_llm._torch.attention_backend.sparse.deepseek_v4 import \
make_deepseek_v4_sparse_params

return make_deepseek_v4_sparse_params(
sparse_attn_cfg, pretrained_config=pretrained_config)
return sparse_attn_cfg.to_sparse_params(pretrained_config=pretrained_config,
layer_idx=layer_idx)


def extract_extra_attrs(layer_idx: str, attn_type: str):
assert attn_type in ["mla", "attn"], "Invalid attention type"
extra_attrs = get_model_extra_attrs()
Expand Down Expand Up @@ -609,13 +596,12 @@ def __init__(
self.attn_backend = config.attn_backend

sparse_attn_cfg = config.sparse_attention_config
sparse_params = (_lower_sparse_attention_params(
sparse_attn_cfg,
sparse_params = (sparse_attn_cfg.to_sparse_params(
pretrained_config=config.pretrained_config,
layer_idx=self.layer_idx) if sparse_attn_cfg is not None else None)

attn_cls = get_attention_backend(
self.attn_backend, sparse_attention_config=sparse_attn_cfg)
attn_cls = get_attention_backend(self.attn_backend,
sparse_params=sparse_params)

self.is_marlin_enabled: bool = is_nvfp4_marlin_enabled()

Expand Down Expand Up @@ -686,7 +672,6 @@ def __init__(
skip_create_weights_in_init=config.skip_create_weights_in_init,
q_scaling=self.q_scaling,
attention_chunk_size=self.attention_chunk_size,
attn_cls=attn_cls,
sparse_params=sparse_params,
)

Expand Down
Loading
Loading