# Shrink the indexer prefill chunk size for very long requests to bound the
# fp8(_fp4)_mqa_logits activation memory (~ chunk_size * K_compressed), keyed on
# the largest compressed KV length in the batch.
#
# Each entry is (k_compressed_lower_bound_exclusive, chunk_size): the cap
# applies when K_compressed is strictly greater than the bound. The second
# bound is written as 256K - 1 so that exactly 256K is already covered:
#   K_compressed >  512K          -> 8K
#   256K <= K_compressed <= 512K  -> 16K
#   otherwise                     -> configured chunk size unchanged
_INDEXER_CHUNK_SIZE_HEURISTIC = (
    (512 * 1024, 8 * 1024),
    (256 * 1024 - 1, 16 * 1024),
)


def select_indexer_chunk_size(configured_chunk_size: int,
                              max_k_compressed: int) -> int:
    """Pick the indexer prefill chunk size from the batch's largest K_compressed.

    Only reduces ``configured_chunk_size`` (never increases it).

    Args:
        configured_chunk_size: Chunk size requested by the configuration; it
            acts as an upper limit on the returned value.
        max_k_compressed: Largest compressed KV length among the requests
            being chunked.

    Returns:
        ``configured_chunk_size`` when no heuristic bound is exceeded,
        otherwise the smaller of ``configured_chunk_size`` and the chunk size
        of the highest bound that ``max_k_compressed`` exceeds.
    """
    # Visit entries from the highest bound downwards so that the entry with
    # the highest exceeded bound wins regardless of the order in which the
    # table happens to be declared.
    for lower_bound, capped_size in sorted(_INDEXER_CHUNK_SIZE_HEURISTIC,
                                           reverse=True):
        if max_k_compressed > lower_bound:
            return min(configured_chunk_size, capped_size)
    return configured_chunk_size
+ item()) if num_contexts > 0 else 0 + effective_chunk_size = select_indexer_chunk_size( + metadata.indexer_max_chunk_size, max_k_compressed) chunk_groups = split_prefill_chunks( seq_lens[:num_contexts], - metadata.indexer_max_chunk_size, + effective_chunk_size, start_idx=0, )