Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 30 additions & 1 deletion tensorrt_llm/_torch/attention_backend/sparse/dsa.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,29 @@ def split_prefill_chunks(
return chunk_groups


# Shrink the indexer prefill chunk size for very long requests to bound the
# fp8(_fp4)_mqa_logits activation memory (~ chunk_size * K_compressed), keyed on
# the largest compressed KV length in the batch. Entries are
# (k_compressed_lower_bound_exclusive, chunk_size); >512K -> 8K, [256K, 512K]
# -> 16K, otherwise unchanged.
_INDEXER_CHUNK_SIZE_HEURISTIC = (
(512 * 1024, 8 * 1024),
(256 * 1024 - 1, 16 * 1024),
)


def select_indexer_chunk_size(configured_chunk_size: int,
max_k_compressed: int) -> int:
"""Pick the indexer prefill chunk size from the batch's largest K_compressed.

Only reduces ``configured_chunk_size`` (never increases it).
"""
for threshold, chunk_size in _INDEXER_CHUNK_SIZE_HEURISTIC:
if max_k_compressed > threshold:
return min(configured_chunk_size, chunk_size)
return configured_chunk_size


def _select_indexer_compress_ratio(compress_ratios: List[int]) -> int:
if 4 in compress_ratios:
return 4
Expand Down Expand Up @@ -1778,9 +1801,15 @@ def prepare_for_chunked_prefill(metadata: DSAtrtllmAttentionMetadata,
else:
# Use indexer's own chunking logic to prevent L^2 complexity of indexer MQA logits computation for long sequences.
# This is only used when MLA chunked prefill is not enabled.
# Adapt chunk size to the batch's largest K_compressed (see
# select_indexer_chunk_size).
max_k_compressed = int(indexer_params.kv_lens[:num_contexts].max().
item()) if num_contexts > 0 else 0
effective_chunk_size = select_indexer_chunk_size(
metadata.indexer_max_chunk_size, max_k_compressed)
chunk_groups = split_prefill_chunks(
seq_lens[:num_contexts],
metadata.indexer_max_chunk_size,
effective_chunk_size,
start_idx=0,
)

Expand Down
Loading