diff --git a/tensorrt_llm/_torch/attention_backend/fmha/fallback.py b/tensorrt_llm/_torch/attention_backend/fmha/fallback.py index c13a3298daf..4413adaa0d5 100644 --- a/tensorrt_llm/_torch/attention_backend/fmha/fallback.py +++ b/tensorrt_llm/_torch/attention_backend/fmha/fallback.py @@ -18,10 +18,6 @@ import torch from tensorrt_llm._torch.attention_backend.interface import AttentionForwardArgs -from tensorrt_llm._torch.attention_backend.sparse.skip_softmax import ( - SkipSoftmaxKernelParams, - SkipSoftmaxParams, -) from tensorrt_llm.bindings.internal import thop from .interface import Fmha @@ -60,18 +56,11 @@ def forward( forward_args: AttentionForwardArgs, ) -> None: attn = self.attn - sparse_params = attn.sparse_params - skip_softmax_kernel_params = ( - sparse_params.scheduler.get_kernel_params(timestep=forward_args.timestep) - if isinstance(sparse_params, SkipSoftmaxParams) - else SkipSoftmaxKernelParams() - ) # Every kwarg sources from ``attn`` / ``metadata`` / ``forward_args`` # (with ``forward_args.sparse_prediction`` for sparse-attn inputs), - # ``skip_softmax_kernel_params``, or a literal allowlisted in - # ``_THOP_LITERALS``. ``test_attention_op_sync.py`` enforces this - # statically. + # or a literal allowlisted in ``_THOP_LITERALS``. + # ``test_attention_op_sync.py`` enforces this statically. thop.attention( q=q, k=k, @@ -146,6 +135,8 @@ def forward( cross_kv=forward_args.cross_kv, relative_attention_bias=forward_args.relative_attention_bias, relative_attention_max_distance=forward_args.relative_attention_max_distance, + skip_softmax_threshold_scale_factor_prefill=forward_args.skip_softmax_kernel_params.threshold_scale_factor_prefill, + skip_softmax_threshold_scale_factor_decode=forward_args.skip_softmax_kernel_params.threshold_scale_factor_decode, # --- Module config (TrtllmAttention) --- rotary_inv_freq=attn.rotary_inv_freq, rotary_cos_sin=attn.rotary_cos_sin, @@ -173,8 +164,6 @@ def forward( v_head_dim=attn.v_head_dim, rope_append=attn.rope_append, attention_chunk_size=attn.attention_chunk_size, - skip_softmax_threshold_scale_factor_prefill=skip_softmax_kernel_params.threshold_scale_factor_prefill, - skip_softmax_threshold_scale_factor_decode=skip_softmax_kernel_params.threshold_scale_factor_decode, skip_softmax_stat=attn.skip_softmax_stat, # --- Sparse-specific (AttentionForwardArgs.sparse_prediction) --- sparse_kv_indices=forward_args.sparse_prediction.sparse_kv_indices, diff --git a/tensorrt_llm/_torch/attention_backend/fmha/flashinfer_trtllm_gen.py b/tensorrt_llm/_torch/attention_backend/fmha/flashinfer_trtllm_gen.py index b39173f8b67..795e1766edc 100644 --- a/tensorrt_llm/_torch/attention_backend/fmha/flashinfer_trtllm_gen.py +++ b/tensorrt_llm/_torch/attention_backend/fmha/flashinfer_trtllm_gen.py @@ -608,7 +608,7 @@ def _is_supported_with_reason( ) -> Tuple[bool, str]: is_mla_enable = attn.is_mla_enable sparse_params = attn.sparse_params - has_skip_softmax = getattr(sparse_params, "algorithm", None) == "skip_softmax" + has_skip_softmax = sparse_params is not None and sparse_params.algorithm == "skip_softmax" has_sparse_attention = sparse_params is not None and not has_skip_softmax if ( fwd.sage_attn_num_elts_per_blk_q > 0 diff --git a/tensorrt_llm/_torch/attention_backend/interface.py b/tensorrt_llm/_torch/attention_backend/interface.py index e91348f9a4e..073d2fefcde 100644 --- a/tensorrt_llm/_torch/attention_backend/interface.py +++ b/tensorrt_llm/_torch/attention_backend/interface.py @@ -26,7 +26,7 @@ from ..pyexecutor.mamba_cache_manager import BaseMambaCacheManager from ..pyexecutor.resource_manager import KVCacheManager from ..utils import get_model_extra_attrs -from .sparse.params import SparseMetadataParams +from .sparse.params import SkipSoftmaxKernelParams, SparseMetadataParams try: # Transformers v5 @@ -890,6 +890,8 @@ class AttentionForwardArgs: sparse_prediction: SparsePrediction = field( default_factory=SparsePrediction) + skip_softmax_kernel_params: SkipSoftmaxKernelParams = field( + default_factory=SkipSoftmaxKernelParams) @property def mask_type(self) -> int: diff --git a/tensorrt_llm/_torch/attention_backend/sparse/params.py b/tensorrt_llm/_torch/attention_backend/sparse/params.py index 76e2d1a1278..cafd26e53b1 100644 --- a/tensorrt_llm/_torch/attention_backend/sparse/params.py +++ b/tensorrt_llm/_torch/attention_backend/sparse/params.py @@ -18,9 +18,13 @@ ModelConfig. They lower through ``to_sparse_params()`` for per-backend runtime state and ``to_sparse_metadata_params()`` for metadata allocation/update state. -Concrete parameter classes live with their backend implementations. +Concrete sparse algorithm params live with their backend implementations; +shared kernel-facing carriers live here when they are part of the generic +attention-forward contract. """ +from dataclasses import dataclass + class SparseParams: """Base for per AttentionBackend instance sparse runtime parameters. @@ -42,3 +46,15 @@ class SparseMetadataParams: metadata owns batch/runtime buffers rather than per-layer ``AttentionBackend`` behavior. """ + + +@dataclass +class SkipSoftmaxKernelParams: + """Skip-softmax thresholds passed to attention backend kernels.""" + + # The kernel divides this by the context length to get the skip threshold; + # zero turns skip-softmax off. + threshold_scale_factor_prefill: float = 0.0 + # Only autoregressive (LLM) decoding has a decode phase; diffusion and + # visual generation leave this at zero. + threshold_scale_factor_decode: float = 0.0 diff --git a/tensorrt_llm/_torch/attention_backend/sparse/skip_softmax.py b/tensorrt_llm/_torch/attention_backend/sparse/skip_softmax.py index 49b4a4eb0ed..1290632b6a5 100644 --- a/tensorrt_llm/_torch/attention_backend/sparse/skip_softmax.py +++ b/tensorrt_llm/_torch/attention_backend/sparse/skip_softmax.py @@ -17,11 +17,8 @@ The kernel consumes a scalar ``threshold_scale_factor`` (combined with the sequence length at runtime) to decide which KV blocks to skip. This module owns the calibration side that produces that scalar from a -semantic ``target_sparsity`` via a formula shipped with the checkpoint -(the formula model and checkpoint-parsing helpers), plus the -kernel-facing :class:`SkipSoftmaxKernelParams` carrier that the shared -``TrtllmAttention`` reads. All of it is shared by the LLM and VisualGen -pipelines. +semantic ``target_sparsity`` via a formula shipped with the checkpoint. +The helpers are shared by the LLM and VisualGen pipelines. """ from dataclasses import dataclass, field @@ -34,7 +31,7 @@ from tensorrt_llm.llmapi.utils import StrictBaseModel -from .params import SparseParams +from .params import SkipSoftmaxKernelParams, SparseParams _RESERVED_FORMULA_KEYS = frozenset({"formula", "target_sparsity"}) _SKIP_SOFTMAX_ALGORITHMS = frozenset({"skip_softmax", "softmax_skip"}) @@ -272,21 +269,6 @@ def skip_softmax_formula_from_ckpt_sparse_attention_config( return None -@dataclass -class SkipSoftmaxKernelParams: - """Skip-softmax thresholds passed to attention backend kernels. - - Produced by ``SkipSoftmaxScheduler.get_kernel_params()``. - """ - - # The kernel divides this by the context length to get the skip threshold; - # zero turns skip-softmax off. - threshold_scale_factor_prefill: float = 0.0 - # Only autoregressive (LLM) decoding has a decode phase; diffusion and - # visual generation leave this at zero. - threshold_scale_factor_decode: float = 0.0 - - class SkipSoftmaxScheduler: """Layer runtime scheduler for skip-softmax kernel thresholds.""" diff --git a/tensorrt_llm/_torch/attention_backend/sparse/utils.py b/tensorrt_llm/_torch/attention_backend/sparse/utils.py index 12e9b146c4e..f438d998939 100644 --- a/tensorrt_llm/_torch/attention_backend/sparse/utils.py +++ b/tensorrt_llm/_torch/attention_backend/sparse/utils.py @@ -1,11 +1,14 @@ -from typing import TYPE_CHECKING, Union +from typing import TYPE_CHECKING, Type, Union if TYPE_CHECKING: + from tensorrt_llm._torch.attention_backend.interface import AttentionBackend from tensorrt_llm.llmapi.llm_args import \ SparseAttentionConfig as LlmSparseAttentionConfig from tensorrt_llm.visual_gen.args import \ SparseAttentionConfig as VisualGenSparseAttentionConfig + from .params import SparseParams + SparseAttentionConfig = Union[LlmSparseAttentionConfig, VisualGenSparseAttentionConfig] @@ -16,7 +19,7 @@ def get_sparse_attn_kv_cache_manager( - sparse_attention_config: "SparseAttentionConfig"): + sparse_attention_config: "SparseAttentionConfig") -> type: from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager from .deepseek_v4 import DeepseekV4CacheManager @@ -40,36 +43,36 @@ def get_sparse_attn_kv_cache_manager( def get_vanilla_sparse_attn_attention_backend( - sparse_attention_config: "SparseAttentionConfig"): + sparse_params: "SparseParams") -> Type["AttentionBackend"]: from .minimax_m3 import get_minimax_m3_attention_backend_cls from .rocket import RocketVanillaAttention - if sparse_attention_config.algorithm == "rocket": + if sparse_params.algorithm == "rocket": return RocketVanillaAttention - elif sparse_attention_config.algorithm == "minimax_m3": + elif sparse_params.algorithm == "minimax_m3": return get_minimax_m3_attention_backend_cls() else: raise ValueError( - f"Unsupported sparse attention algorithm in vanilla attention backend: {sparse_attention_config.algorithm}" + f"Unsupported sparse attention algorithm in vanilla attention backend: {sparse_params.algorithm}" ) def get_trtllm_sparse_attn_attention_backend( - sparse_attention_config: "SparseAttentionConfig"): + sparse_params: "SparseParams") -> Type["AttentionBackend"]: from tensorrt_llm._torch.attention_backend.trtllm import TrtllmAttention from .deepseek_v4 import DeepseekV4TrtllmAttention from .dsa import DSATrtllmAttention from .minimax_m3 import get_minimax_m3_attention_backend_cls from .rocket import RocketTrtllmAttention - if sparse_attention_config.algorithm == "rocket": + if sparse_params.algorithm == "rocket": return RocketTrtllmAttention - elif sparse_attention_config.algorithm == "dsa": + elif sparse_params.algorithm == "dsa": return DSATrtllmAttention - elif sparse_attention_config.algorithm == "deepseek_v4": + elif sparse_params.algorithm == "deepseek_v4": return DeepseekV4TrtllmAttention - elif sparse_attention_config.algorithm == "skip_softmax": + elif sparse_params.algorithm == "skip_softmax": return TrtllmAttention - elif sparse_attention_config.algorithm == "minimax_m3": + elif sparse_params.algorithm == "minimax_m3": # The MiniMax-M3 sparse algorithm runs in Python through the # model-layer override; this backend exists so the standard # `create_attention(...)` dispatch in `Attention.__init__` @@ -78,15 +81,15 @@ def get_trtllm_sparse_attn_attention_backend( return get_minimax_m3_attention_backend_cls() else: raise ValueError( - f"Unsupported sparse attention algorithm in trtllm attention backend: {sparse_attention_config.algorithm}" + f"Unsupported sparse attention algorithm in trtllm attention backend: {sparse_params.algorithm}" ) def get_flashinfer_sparse_attn_attention_backend( - sparse_attention_config: "SparseAttentionConfig"): + sparse_params: "SparseParams") -> Type["AttentionBackend"]: from .minimax_m3 import get_minimax_m3_attention_backend_cls - if sparse_attention_config.algorithm == "minimax_m3": + if sparse_params.algorithm == "minimax_m3": return get_minimax_m3_attention_backend_cls() raise ValueError( - f"Unsupported sparse attention algorithm in flashinfer attention backend: {sparse_attention_config.algorithm}" + f"Unsupported sparse attention algorithm in flashinfer attention backend: {sparse_params.algorithm}" ) diff --git a/tensorrt_llm/_torch/attention_backend/trtllm.py b/tensorrt_llm/_torch/attention_backend/trtllm.py index 685472d6681..0f00dbf34a3 100644 --- a/tensorrt_llm/_torch/attention_backend/trtllm.py +++ b/tensorrt_llm/_torch/attention_backend/trtllm.py @@ -1471,6 +1471,12 @@ def forward( # Cross-attention uses the THOP path; the trtllm-gen backend API does # not carry encoder K/V tensors yet. + # cpp/tensorrt_llm/thop/attentionOp.cpp enables mFP8ContextFMHA for an + # FP8 KV cache only when use_paged_context_fmha is true. Force paged + # context so QKV preprocessing and context FMHA use the FP8 path. + if self.has_fp8_kv_cache: + metadata.use_paged_context_fmha = True + # SM90 forces `use_paged_context_fmha` on for correctness # (https://nvbugs/5624818). if get_sm_version() == 90: @@ -1703,6 +1709,12 @@ def forward( if forward_args.kv_scale_quant_orig is None: forward_args.kv_scale_quant_orig = self.kv_scale_quant_orig + sparse_params = self.sparse_params + if isinstance(sparse_params, SkipSoftmaxParams): + forward_args.skip_softmax_kernel_params = ( + sparse_params.scheduler.get_kernel_params( + timestep=forward_args.timestep)) + # max_context_q_len_override is only set when encoder CUDA graphs are enabled. if metadata.max_context_q_len_override is not None: assert metadata.is_cuda_graph diff --git a/tensorrt_llm/_torch/attention_backend/utils.py b/tensorrt_llm/_torch/attention_backend/utils.py index d3487088381..ef83c99159e 100644 --- a/tensorrt_llm/_torch/attention_backend/utils.py +++ b/tensorrt_llm/_torch/attention_backend/utils.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Optional, Sequence, Type, Union +from typing import Optional, Sequence, Type import torch @@ -14,36 +14,24 @@ from .trtllm import TrtllmAttention from .vanilla import VanillaAttention -if TYPE_CHECKING: - from tensorrt_llm.llmapi.llm_args import \ - SparseAttentionConfig as LlmSparseAttentionConfig - from tensorrt_llm.visual_gen.args import \ - SparseAttentionConfig as VisualGenSparseAttentionConfig - - SparseAttentionConfig = Union[LlmSparseAttentionConfig, - VisualGenSparseAttentionConfig] - def get_attention_backend( backend_name: str, - sparse_attention_config: Optional["SparseAttentionConfig"] = None, + sparse_params: Optional[SparseParams] = None, ) -> Type[AttentionBackend]: backend_name = backend_name.upper() if backend_name == "VANILLA": - if sparse_attention_config is not None: - return get_vanilla_sparse_attn_attention_backend( - sparse_attention_config) + if sparse_params is not None: + return get_vanilla_sparse_attn_attention_backend(sparse_params) return VanillaAttention elif backend_name == "TRTLLM": - if sparse_attention_config is not None: - return get_trtllm_sparse_attn_attention_backend( - sparse_attention_config) + if sparse_params is not None: + return get_trtllm_sparse_attn_attention_backend(sparse_params) return TrtllmAttention elif backend_name == "FLASHINFER" and IS_FLASHINFER_AVAILABLE: from .flashinfer import FlashInferAttention - if sparse_attention_config is not None: - return get_flashinfer_sparse_attn_attention_backend( - sparse_attention_config) + if sparse_params is not None: + return get_flashinfer_sparse_attn_attention_backend(sparse_params) return FlashInferAttention elif backend_name == "FLASHINFER_STAR_ATTENTION" and IS_FLASHINFER_AVAILABLE: from .star_flashinfer import StarAttention @@ -73,7 +61,6 @@ def create_attention( predicted_tokens_per_seq: Optional[int] = 1, skip_create_weights_in_init: bool = False, attention_chunk_size: Optional[int] = None, - attn_cls: Optional[Type[AttentionBackend]] = None, sparse_params: Optional[SparseParams] = None, dtype: Optional[torch.dtype] = None, aux_stream: Optional[torch.cuda.Stream] = None, @@ -81,11 +68,7 @@ def create_attention( if attention_chunk_size is not None and backend_name.upper() != "TRTLLM": raise ValueError( f"Backend {backend_name} does not support chunked attention.") - if sparse_params is not None and attn_cls is None: - raise ValueError("attn_cls is required when sparse_params is set.") - - if attn_cls is None: - attn_cls = get_attention_backend(backend_name) + attn_cls = get_attention_backend(backend_name, sparse_params=sparse_params) if is_mla_enable: assert attn_cls.support_mla( @@ -114,9 +97,8 @@ def create_attention( attention_chunk_size=attention_chunk_size, dtype=dtype, aux_stream=aux_stream, + sparse_params=sparse_params, ) - if sparse_params is not None: - kwargs["sparse_params"] = sparse_params return attn_cls( layer_idx, diff --git a/tensorrt_llm/_torch/modules/attention.py b/tensorrt_llm/_torch/modules/attention.py index 95c94a02425..80e9c67d74d 100644 --- a/tensorrt_llm/_torch/modules/attention.py +++ b/tensorrt_llm/_torch/modules/attention.py @@ -26,19 +26,6 @@ from .rotary_embedding import MRotaryEmbedding, RotaryEmbedding -def _lower_sparse_attention_params(sparse_attn_cfg, - pretrained_config=None, - layer_idx: Optional[int] = None): - if getattr(sparse_attn_cfg, "algorithm", None) == "deepseek_v4": - from tensorrt_llm._torch.attention_backend.sparse.deepseek_v4 import \ - make_deepseek_v4_sparse_params - - return make_deepseek_v4_sparse_params( - sparse_attn_cfg, pretrained_config=pretrained_config) - return sparse_attn_cfg.to_sparse_params(pretrained_config=pretrained_config, - layer_idx=layer_idx) - - def extract_extra_attrs(layer_idx: str, attn_type: str): assert attn_type in ["mla", "attn"], "Invalid attention type" extra_attrs = get_model_extra_attrs() @@ -609,13 +596,12 @@ def __init__( self.attn_backend = config.attn_backend sparse_attn_cfg = config.sparse_attention_config - sparse_params = (_lower_sparse_attention_params( - sparse_attn_cfg, + sparse_params = (sparse_attn_cfg.to_sparse_params( pretrained_config=config.pretrained_config, layer_idx=self.layer_idx) if sparse_attn_cfg is not None else None) - attn_cls = get_attention_backend( - self.attn_backend, sparse_attention_config=sparse_attn_cfg) + attn_cls = get_attention_backend(self.attn_backend, + sparse_params=sparse_params) self.is_marlin_enabled: bool = is_nvfp4_marlin_enabled() @@ -686,7 +672,6 @@ def __init__( skip_create_weights_in_init=config.skip_create_weights_in_init, q_scaling=self.q_scaling, attention_chunk_size=self.attention_chunk_size, - attn_cls=attn_cls, sparse_params=sparse_params, ) diff --git a/tensorrt_llm/_torch/modules/mla.py b/tensorrt_llm/_torch/modules/mla.py index f920021b6d8..c73518f2f46 100644 --- a/tensorrt_llm/_torch/modules/mla.py +++ b/tensorrt_llm/_torch/modules/mla.py @@ -44,7 +44,7 @@ DSAtrtllmAttentionMetadata, transform_local_topk_and_prepare_pool_view, ) -from ..attention_backend.utils import create_attention, get_attention_backend +from ..attention_backend.utils import create_attention from ..distributed import AllReduceParams from ..model_config import ModelConfig from ..utils import is_torch_compiling, maybe_compiled_cat, maybe_compiled_copy_ @@ -53,7 +53,6 @@ _helix_cp_output_projection, _helix_post_process, _helix_zero_kv_mask, - _lower_sparse_attention_params, extract_extra_attrs, ) from .linear import Linear, TensorParallelMode @@ -319,8 +318,7 @@ def __init__( config = config or ModelConfig() sparse_attn_cfg = config.sparse_attention_config sparse_params = ( - _lower_sparse_attention_params( - sparse_attn_cfg, + sparse_attn_cfg.to_sparse_params( pretrained_config=config.pretrained_config, layer_idx=self.layer_idx, ) @@ -557,8 +555,8 @@ def yarn_get_mscale(scale=1, mscale=1): self.has_dsv4_indexer = ( self.is_deepseek_v4 and layer_idx is not None - and config.sparse_attention_config is not None - and config.sparse_attention_config.compress_ratios[layer_idx] == 4 + and sparse_params is not None + and sparse_params.compress_ratios[layer_idx] == 4 ) self.indexer_stream = None self.indexer_aux_stream = None @@ -569,9 +567,6 @@ def yarn_get_mscale(scale=1, mscale=1): self.indexer_aux_stream if self.indexer_aux_stream is not None else aux_stream ) - mqa_cls = get_attention_backend( - config.attn_backend, sparse_attention_config=config.sparse_attention_config - ) self.mqa = create_attention( config.attn_backend, self.layer_idx, @@ -590,7 +585,6 @@ def yarn_get_mscale(scale=1, mscale=1): hidden_size=self.hidden_size, predicted_tokens_per_seq=self.predicted_tokens_per_seq, skip_create_weights_in_init=config.skip_create_weights_in_init, - attn_cls=mqa_cls, sparse_params=sparse_params, dtype=dtype, aux_stream=mqa_aux_stream, @@ -644,10 +638,7 @@ def yarn_get_mscale(scale=1, mscale=1): self.is_dsa and self.short_seq_mha_threshold > 0 and not self.apply_rotary_emb ) if (not self.is_dsa or _short_seq_mha) and not self.is_deepseek_v4: - mha_sparse_config = None if _short_seq_mha else config.sparse_attention_config - mha_cls = get_attention_backend( - config.attn_backend, sparse_attention_config=mha_sparse_config - ) + mha_sparse_params = None if _short_seq_mha else sparse_params self.mha = create_attention( config.attn_backend, self.layer_idx, @@ -665,8 +656,7 @@ def yarn_get_mscale(scale=1, mscale=1): v_head_dim=self.v_head_dim, predicted_tokens_per_seq=self.predicted_tokens_per_seq, skip_create_weights_in_init=config.skip_create_weights_in_init, - attn_cls=mha_cls, - sparse_params=(None if _short_seq_mha else sparse_params), + sparse_params=mha_sparse_params, ) else: self.mha = None diff --git a/tensorrt_llm/_torch/pyexecutor/_util.py b/tensorrt_llm/_torch/pyexecutor/_util.py index fc8ce77fe9f..f9742e4df2a 100644 --- a/tensorrt_llm/_torch/pyexecutor/_util.py +++ b/tensorrt_llm/_torch/pyexecutor/_util.py @@ -89,16 +89,26 @@ def get_kv_cache_manager_cls( * ``TRTLLM_USE_PY_MAMBA=1`` — Mixed manager with PythonMambaCacheManager. """ config = model_config.pretrained_config - sparse_attention_config = model_config.sparse_attention_config - if sparse_attention_config is not None: - return get_sparse_attn_kv_cache_manager(sparse_attention_config) - elif is_hybrid_linear(config): + sparse_attn_config = model_config.sparse_attention_config + sparse_attn_algorithm = getattr(sparse_attn_config, "algorithm", None) + if is_hybrid_linear(config): # Degenerate case: model is flagged as hybrid but the config has zero - # mamba layers. Fall through to the standard non-hybrid manager. + # mamba layers. Fall through to the standard non-hybrid routes. if model_config.get_num_mamba_layers() == 0: logger.info("Hybrid linear model has 0 mamba layers; using " - "KVCacheManager without mamba caching") + "KV cache manager without mamba caching") + if sparse_attn_config is not None: + return get_sparse_attn_kv_cache_manager(sparse_attn_config) return _non_hybrid_kv_cache_manager_cls(config, kv_cache_config) + + if (sparse_attn_config is not None + and sparse_attn_algorithm != "skip_softmax"): + raise ValueError( + f"Sparse attention algorithm {sparse_attn_algorithm!r} is not " + "supported with hybrid Mamba / linear-attention models.") + + # Skip Softmax only changes attention kernels. Hybrid models still + # need a Mamba-capable cache manager for recurrent state. if use_py_mamba_cache_manager(): if kv_cache_config.enable_block_reuse: raise ValueError( @@ -139,6 +149,8 @@ def get_kv_cache_manager_cls( f"Expected 'CPP' or 'MIXED'. Using default {default_cls.__name__}." ) return default_cls + elif sparse_attn_config is not None: + return get_sparse_attn_kv_cache_manager(sparse_attn_config) else: return _non_hybrid_kv_cache_manager_cls(config, kv_cache_config) diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py index 9781bea589e..ab299e2b55c 100644 --- a/tensorrt_llm/_torch/pyexecutor/model_engine.py +++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py @@ -504,9 +504,11 @@ def __init__( dtype=torch.int, device='cuda') - self.attn_backend = get_attention_backend( - self.llm_args.attn_backend, - sparse_attention_config=self.sparse_attention_config) + sparse_params = (self.sparse_attention_config.to_sparse_params( + pretrained_config=self.model.model_config.pretrained_config) + if self.sparse_attention_config is not None else None) + self.attn_backend = get_attention_backend(self.llm_args.attn_backend, + sparse_params=sparse_params) self.spec_metadata = None if self.is_spec_decode: @@ -1839,7 +1841,7 @@ def _set_up_attn_metadata( # Cache the no-cache metadata. if self.encoder_attn_metadata is not None: return self.encoder_attn_metadata - metadata_kwargs = dict( + self.encoder_attn_metadata = metadata_cls( max_num_requests=self.batch_size, max_num_tokens=self.max_num_tokens, max_num_sequences=self.batch_size * self.max_beam_width, @@ -1852,7 +1854,6 @@ def _set_up_attn_metadata( cache_indirection=cache_indirection, num_heads_per_kv=num_heads_per_kv, sparse_metadata_params=sparse_metadata_params) - self.encoder_attn_metadata = metadata_cls(**metadata_kwargs) self.encoder_attn_metadata.block_ids_per_seq = None self.encoder_attn_metadata.kv_block_ids_per_seq = None return self.encoder_attn_metadata @@ -1863,7 +1864,7 @@ def _set_up_attn_metadata( assert self.attn_metadata.kv_cache_manager is kv_cache_manager return self.attn_metadata - metadata_kwargs = dict( + self.attn_metadata = metadata_cls( max_num_requests=self.batch_size, max_num_tokens=self.max_num_tokens, max_num_sequences=self.batch_size * self.max_beam_width, @@ -1877,7 +1878,6 @@ def _set_up_attn_metadata( num_heads_per_kv=num_heads_per_kv, sparse_metadata_params=sparse_metadata_params, ) - self.attn_metadata = metadata_cls(**metadata_kwargs) return self.attn_metadata diff --git a/tensorrt_llm/_torch/visual_gen/models/cosmos3/pipeline_cosmos3.py b/tensorrt_llm/_torch/visual_gen/models/cosmos3/pipeline_cosmos3.py index 1ab9acd30f6..2dac241f211 100644 --- a/tensorrt_llm/_torch/visual_gen/models/cosmos3/pipeline_cosmos3.py +++ b/tensorrt_llm/_torch/visual_gen/models/cosmos3/pipeline_cosmos3.py @@ -818,8 +818,8 @@ def forward_fn( result = self.transformer( hidden_states=latent_input, - timestep=timestep, - attention_timestep=timestep / self.scheduler.config.num_train_timesteps, + timestep=timestep / self.scheduler.config.num_train_timesteps, + raw_timestep=timestep, text_ids=extra_tensors["text_ids"], text_mask=extra_tensors["text_mask"], video_shape=video_shape, diff --git a/tensorrt_llm/_torch/visual_gen/models/cosmos3/transformer_cosmos3.py b/tensorrt_llm/_torch/visual_gen/models/cosmos3/transformer_cosmos3.py index 91f86ed978c..b3df78a06ec 100644 --- a/tensorrt_llm/_torch/visual_gen/models/cosmos3/transformer_cosmos3.py +++ b/tensorrt_llm/_torch/visual_gen/models/cosmos3/transformer_cosmos3.py @@ -972,7 +972,7 @@ def forward( self, hidden_states: torch.Tensor, timestep: Optional[torch.Tensor] = None, - attention_timestep: Optional[torch.Tensor] = None, + raw_timestep: Optional[torch.Tensor] = None, text_ids: Optional[torch.Tensor] = None, text_mask: Optional[torch.Tensor] = None, video_shape: Optional[Tuple[int, int, int]] = None, @@ -986,9 +986,9 @@ def forward( Args: hidden_states: [B, C, T, H, W] noisy latents - timestep: Raw scheduler diffusion timestep, shape [B] - attention_timestep: Normalized diffusion timestep in [0, 1], shape [B], - for attention backends that use timestep-dependent behavior. + timestep: Normalized diffusion timestep in [0, 1], shape [B]. + raw_timestep: Raw scheduler diffusion timestep, shape [B], used by + the Cosmos3 time embedding path. text_ids: [B, S_text] tokenized text input text_mask: [B, S_text] attention mask for text (1=real, 0=pad) video_shape: (T, H, W) in latent space @@ -1009,6 +1009,10 @@ def forward( provided; otherwise None. action is always None for now. """ del kwargs # Kept for diffusers API compatibility. + if timestep is None: + raise ValueError("Cosmos3VFMTransformer.forward requires normalized timestep.") + if raw_timestep is None: + raise ValueError("Cosmos3VFMTransformer.forward requires raw_timestep.") T, H, W = video_shape Hp, Wp, _, _ = self._pad_to_patch_size(H, W) max_real_len = text_mask.sum(dim=1).max().item() @@ -1017,7 +1021,7 @@ def forward( hidden_gen = self.vae2llm(self.patchify(hidden_states, T, H, W)) with torch.autocast("cuda", enabled=True, dtype=torch.float32): - time_embed = self.time_embedder((timestep * self.timestep_scale)) + time_embed = self.time_embedder((raw_timestep * self.timestep_scale)) time_embed = time_embed.to(hidden_states.dtype) if noisy_frame_mask is not None: @@ -1048,7 +1052,7 @@ def forward( text_ids, text_mask, freqs_und, - timestep=attention_timestep, + timestep=timestep, ) self.cached_freqs_gen = freqs_gen @@ -1114,7 +1118,7 @@ def forward( k_und, v_und, freqs_gen, - timestep=attention_timestep, + timestep=timestep, real_text_lens=real_text_lens, ) else: @@ -1123,7 +1127,7 @@ def forward( k_und, v_und, freqs_gen, - timestep=attention_timestep, + timestep=timestep, ) hidden_gen = self.sharder.gather(hidden_gen, dim=1, unpad_to=S_gen) diff --git a/tensorrt_llm/_torch/visual_gen/models/wan/transformer_wan.py b/tensorrt_llm/_torch/visual_gen/models/wan/transformer_wan.py index 5cf75169b08..05fccb7f5a7 100644 --- a/tensorrt_llm/_torch/visual_gen/models/wan/transformer_wan.py +++ b/tensorrt_llm/_torch/visual_gen/models/wan/transformer_wan.py @@ -670,7 +670,6 @@ def forward( Args: timestep: Normalized scheduler timestep tensor in [0, 1]. """ - del kwargs # Kept for diffusers API compatibility. original_shape = hidden_states.shape B, C, T, H, W = original_shape pt, ph, pw = self.config.patch_size diff --git a/tensorrt_llm/tools/layer_wise_benchmarks/runner.py b/tensorrt_llm/tools/layer_wise_benchmarks/runner.py index 8437e16995e..f63634f463d 100644 --- a/tensorrt_llm/tools/layer_wise_benchmarks/runner.py +++ b/tensorrt_llm/tools/layer_wise_benchmarks/runner.py @@ -605,18 +605,22 @@ def create_run_pack( ): world_size = mpi_world_size() pretrained_config = self.model_config.pretrained_config + sparse_attention_config = self.model_config.sparse_attention_config + sparse_params = ( + sparse_attention_config.to_sparse_params(pretrained_config=pretrained_config) + if sparse_attention_config is not None + else None + ) AttentionCls = get_attention_backend( - self.model_config.attn_backend, - sparse_attention_config=self.model_config.sparse_attention_config, + self.model_config.attn_backend, sparse_params=sparse_params ) metadata_cls = AttentionCls.Metadata - sparse_attention_config = self.model_config.sparse_attention_config sparse_metadata_params = ( sparse_attention_config.to_sparse_metadata_params(pretrained_config=pretrained_config) if sparse_attention_config is not None else None ) - metadata_kwargs = dict( + attn_metadata = metadata_cls( seq_lens=torch.tensor([seq_len_q] * batch_size, dtype=torch.int), request_ids=list(range(request_id_begin, request_id_begin + batch_size)), max_num_requests=kv_cache_manager.max_batch_size, @@ -641,7 +645,6 @@ def create_run_pack( mapping=self.model_config.mapping, sparse_metadata_params=sparse_metadata_params, ) - attn_metadata = metadata_cls(**metadata_kwargs) attn_metadata.all_rank_num_tokens = [batch_size * seq_len_q] * world_size attn_metadata.prepare() hidden_size = pretrained_config.hidden_size diff --git a/tests/unittest/_torch/attention/sparse/dsa/test_short_seq_mha.py b/tests/unittest/_torch/attention/sparse/dsa/test_short_seq_mha.py index 6286c3198a1..36a788d3faf 100644 --- a/tests/unittest/_torch/attention/sparse/dsa/test_short_seq_mha.py +++ b/tests/unittest/_torch/attention/sparse/dsa/test_short_seq_mha.py @@ -268,6 +268,11 @@ def _build_mla(rope_config, device, threshold): return mla, mapping, sparse_config, model_config +def _attention_cls_for_sparse_config(sparse_config, model_config): + sparse_params = sparse_config.to_sparse_params(pretrained_config=model_config.pretrained_config) + return get_attention_backend("TRTLLM", sparse_params=sparse_params) + + def _init_mla_weights(mla): """Initialize MLA weights deterministically in the loaded layout. @@ -417,7 +422,7 @@ def test_forward_context_short_mha(name: str, seq_lens: List[int], threshold_off _init_mla_weights(mla) kv_mgr = _build_kv_cache_manager(mapping, sparse_config, model_config, seq_lens, device) - attn_cls = get_attention_backend("TRTLLM", sparse_attention_config=sparse_config) + attn_cls = _attention_cls_for_sparse_config(sparse_config, model_config) q, compressed_kv, k_pe, latent_cache, position_ids = _make_inputs(seq_lens, device) metadata = _make_metadata(attn_cls, seq_lens, kv_mgr, mapping, sparse_config) @@ -457,7 +462,7 @@ def test_standard_path_when_exceeds_threshold(): _init_mla_weights(mla) kv_mgr = _build_kv_cache_manager(mapping, sparse_config, model_config, seq_lens, device) - attn_cls = get_attention_backend("TRTLLM", sparse_attention_config=sparse_config) + attn_cls = _attention_cls_for_sparse_config(sparse_config, model_config) q, compressed_kv, k_pe, latent_cache, position_ids = _make_inputs(seq_lens, device) total_tokens = sum(seq_lens) @@ -511,7 +516,7 @@ def test_agrees_with_absorption_path(): q, compressed_kv, k_pe, latent_cache, position_ids = _make_inputs(seq_lens, device) hidden_states = torch.randn(total_tokens, HIDDEN_SIZE, dtype=torch.bfloat16, device=device) qr = torch.randn(total_tokens, Q_LORA_RANK, dtype=torch.bfloat16, device=device) - attn_cls = get_attention_backend("TRTLLM", sparse_attention_config=sparse_config) + attn_cls = _attention_cls_for_sparse_config(sparse_config, model_config) def _run(mla_module): kv_mgr = _build_kv_cache_manager(mapping, sparse_config, model_config, seq_lens, device) @@ -559,7 +564,7 @@ def test_chunked_correctness(name: str, chunk_specs: List[Tuple[int, int]], chun mla, mapping, sparse_config, model_config = _build_mla(rope_config, device, threshold) _init_mla_weights(mla) - attn_cls = get_attention_backend("TRTLLM", sparse_attention_config=sparse_config) + attn_cls = _attention_cls_for_sparse_config(sparse_config, model_config) q, compressed_kv, k_pe, latent_cache, position_ids = _make_inputs(total_per_seq, device) @@ -640,7 +645,7 @@ def test_chunked_context_rejects_when_kv_exceeds_threshold(): mla, mapping, sparse_config, model_config = _build_mla(rope_config, device, threshold) _init_mla_weights(mla) - attn_cls = get_attention_backend("TRTLLM", sparse_attention_config=sparse_config) + attn_cls = _attention_cls_for_sparse_config(sparse_config, model_config) q, compressed_kv, k_pe, latent_cache, position_ids = _make_inputs(total_per_seq, device) diff --git a/tests/unittest/_torch/attention/sparse/test_sparse_mla_forward.py b/tests/unittest/_torch/attention/sparse/test_sparse_mla_forward.py index 5772aae84a3..39a344c8cc3 100644 --- a/tests/unittest/_torch/attention/sparse/test_sparse_mla_forward.py +++ b/tests/unittest/_torch/attention/sparse/test_sparse_mla_forward.py @@ -1916,8 +1916,9 @@ def yarn_get_mscale(scale=1, mscale=1): **kv_cache_manager_kwargs, ) - AttentionCls = get_attention_backend("TRTLLM", - sparse_attention_config=sparse_config) + sparse_params = sparse_config.to_sparse_params( + pretrained_config=model_config.pretrained_config) + AttentionCls = get_attention_backend("TRTLLM", sparse_params=sparse_params) # Allocate and pre-populate KV cache in batch order [context...][generation...] diff --git a/tests/unittest/_torch/attention/test_attention_op_sync.py b/tests/unittest/_torch/attention/test_attention_op_sync.py index f85667b5a85..201a8ebc92a 100644 --- a/tests/unittest/_torch/attention/test_attention_op_sync.py +++ b/tests/unittest/_torch/attention/test_attention_op_sync.py @@ -50,7 +50,6 @@ FallbackFmha, ) from tensorrt_llm._torch.attention_backend.interface import AttentionForwardArgs -from tensorrt_llm._torch.attention_backend.sparse.skip_softmax import SkipSoftmaxKernelParams from tensorrt_llm._torch.attention_backend.trtllm import TrtllmAttention, TrtllmAttentionMetadata # Roots used as the LHS of attribute chains at the call site. Match the @@ -59,11 +58,9 @@ "attn": TrtllmAttention, "metadata": TrtllmAttentionMetadata, "forward_args": AttentionForwardArgs, - "skip_softmax_kernel_params": SkipSoftmaxKernelParams, } _THOP_KWARG_SOURCE_ALIASES: dict[str, tuple[str, tuple[str, ...]]] = { - "beam_width": ("metadata", ("effective_beam_width",)), "context_lengths": ("metadata", ("prompt_lens_cuda_runtime",)), "head_size": ("attn", ("head_dim",)), "host_context_lengths": ("metadata", ("prompt_lens_cpu_runtime",)), @@ -74,6 +71,20 @@ "metadata", ("max_total_draft_tokens",), ), + "skip_softmax_threshold_scale_factor_decode": ( + "forward_args", + ( + "skip_softmax_kernel_params", + "threshold_scale_factor_decode", + ), + ), + "skip_softmax_threshold_scale_factor_prefill": ( + "forward_args", + ( + "skip_softmax_kernel_params", + "threshold_scale_factor_prefill", + ), + ), "workspace_": ("metadata", ("effective_workspace",)), } diff --git a/tests/unittest/_torch/visual_gen/multi_gpu/test_cosmos3_transformer_parallel.py b/tests/unittest/_torch/visual_gen/multi_gpu/test_cosmos3_transformer_parallel.py index 38e8675c3f0..e5b770f5dfa 100644 --- a/tests/unittest/_torch/visual_gen/multi_gpu/test_cosmos3_transformer_parallel.py +++ b/tests/unittest/_torch/visual_gen/multi_gpu/test_cosmos3_transformer_parallel.py @@ -108,6 +108,7 @@ _TEXT_LEN = 8 _MAX_TEXT_LEN = 16 _TIMESTEP = 500.0 +_NUM_TRAIN_TIMESTEPS = 1000.0 _FPS = 24.0 # Audio (sound) modality: audio tokens are appended to the gen sequence, so the @@ -387,7 +388,8 @@ def _forward(model: Cosmos3VFMTransformer, device: torch.device, text_seed: int) with torch.inference_mode(): return model( hidden_states=hs, - timestep=ts, + timestep=ts / _NUM_TRAIN_TIMESTEPS, + raw_timestep=ts, text_ids=text_ids, text_mask=text_mask, video_shape=video_shape, @@ -413,7 +415,8 @@ def _forward_with_audio( with torch.inference_mode(): out = model( hidden_states=hs, - timestep=ts, + timestep=ts / _NUM_TRAIN_TIMESTEPS, + raw_timestep=ts, text_ids=text_ids, text_mask=text_mask, video_shape=video_shape, diff --git a/tests/unittest/_torch/visual_gen/test_attention_trtllm_sage.py b/tests/unittest/_torch/visual_gen/test_attention_trtllm_sage.py index 40845a19da7..c5faafe7161 100644 --- a/tests/unittest/_torch/visual_gen/test_attention_trtllm_sage.py +++ b/tests/unittest/_torch/visual_gen/test_attention_trtllm_sage.py @@ -88,7 +88,7 @@ def _test_attention_trtllm_sage( sparse_attention_config.to_sparse_params() if sparse_attention_config is not None else None ) attention_cls = attention_backend_utils.get_attention_backend( - "TRTLLM", sparse_attention_config=sparse_attention_config + "TRTLLM", sparse_params=sparse_params ) attention = attention_cls( layer_idx=0, diff --git a/tests/unittest/_torch/visual_gen/test_cosmos3_transformer.py b/tests/unittest/_torch/visual_gen/test_cosmos3_transformer.py index 5d1100b394b..76e6c1a01a8 100644 --- a/tests/unittest/_torch/visual_gen/test_cosmos3_transformer.py +++ b/tests/unittest/_torch/visual_gen/test_cosmos3_transformer.py @@ -73,6 +73,7 @@ def _checkpoint(env_var: str, default_name: str) -> str: DEVICE = "cuda" DTYPE = torch.bfloat16 +_NUM_TRAIN_TIMESTEPS = 1000.0 COSMOS3_FP8_QUANT_CONFIG = { "quant_algo": "FP8", @@ -216,7 +217,8 @@ def test_sanity_forward(self, cosmos3_model_config): with torch.inference_mode(): out = model( hidden_states=hs, - timestep=ts, + timestep=ts / _NUM_TRAIN_TIMESTEPS, + raw_timestep=ts, text_ids=text_ids, text_mask=text_mask, video_shape=video_shape, @@ -233,14 +235,16 @@ def test_reset_cache(self, cosmos3_model_config): with torch.inference_mode(): out1 = model( hidden_states=hs, - timestep=ts, + timestep=ts / _NUM_TRAIN_TIMESTEPS, + raw_timestep=ts, text_ids=text_ids, text_mask=text_mask, video_shape=video_shape, ) out2 = model( hidden_states=hs, - timestep=ts, + timestep=ts / _NUM_TRAIN_TIMESTEPS, + raw_timestep=ts, text_ids=text_ids, text_mask=text_mask, video_shape=video_shape, @@ -261,7 +265,8 @@ def test_sanity_forward_i2v_mask(self, cosmos3_model_config): with torch.inference_mode(): out = model( hidden_states=hs, - timestep=ts, + timestep=ts / _NUM_TRAIN_TIMESTEPS, + raw_timestep=ts, text_ids=text_ids, text_mask=text_mask, video_shape=video_shape, @@ -335,7 +340,8 @@ def test_forward_with_audio(self, audio_model_config): with torch.inference_mode(): out = model( hidden_states=hs, - timestep=ts, + timestep=ts / _NUM_TRAIN_TIMESTEPS, + raw_timestep=ts, text_ids=text_ids, text_mask=text_mask, video_shape=video_shape, @@ -358,7 +364,8 @@ def test_forward_without_audio_latents_returns_none(self, audio_model_config): with torch.inference_mode(): out = model( hidden_states=hs, - timestep=ts, + timestep=ts / _NUM_TRAIN_TIMESTEPS, + raw_timestep=ts, text_ids=text_ids, text_mask=text_mask, video_shape=video_shape, @@ -378,7 +385,8 @@ def test_forward_with_audio_multiframe(self, audio_model_config): with torch.inference_mode(): out = model( hidden_states=hs, - timestep=ts, + timestep=ts / _NUM_TRAIN_TIMESTEPS, + raw_timestep=ts, text_ids=text_ids, text_mask=text_mask, video_shape=video_shape, @@ -419,7 +427,8 @@ def test_load_weights_and_forward(self, cosmos3_transformer): with torch.inference_mode(): out = transformer( hidden_states=hs, - timestep=ts, + timestep=ts / _NUM_TRAIN_TIMESTEPS, + raw_timestep=ts, text_ids=text_ids, text_mask=text_mask, video_shape=video_shape, @@ -448,7 +457,8 @@ def test_load_fp8_quantization(self, quant_algo: str): with torch.inference_mode(): out = transformer( hidden_states=hs, - timestep=ts, + timestep=ts / _NUM_TRAIN_TIMESTEPS, + raw_timestep=ts, text_ids=text_ids, text_mask=text_mask, video_shape=video_shape, diff --git a/tests/unittest/llmapi/test_llm_args.py b/tests/unittest/llmapi/test_llm_args.py index 69bab9c13e0..ceeaf8910c9 100644 --- a/tests/unittest/llmapi/test_llm_args.py +++ b/tests/unittest/llmapi/test_llm_args.py @@ -2855,6 +2855,17 @@ def test_zero_compress_ratios_are_normalized(self): assert config.compress_ratios == [1, 4, 128] + def test_lowers_to_deepseek_v4_sparse_params(self): + config = DeepSeekV4SparseAttentionConfig(compress_ratios=[0, 4, 128]) + + sparse_params = config.to_sparse_params() + sparse_metadata_params = config.to_sparse_metadata_params() + + assert sparse_params.algorithm == "deepseek_v4" + assert sparse_params.compress_ratios == [1, 4, 128] + assert sparse_metadata_params.compress_ratios == [1, 4, 128] + assert sparse_metadata_params.window_size == 128 + @pytest.mark.parametrize("compress_ratios", [[], [-1, 4, 128]]) def test_invalid_compress_ratios_raise(self, compress_ratios): with pytest.raises(ValidationError, match="compress_ratios"):