Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
Original file line number Diff line number Diff line change
Expand Up @@ -673,10 +673,10 @@ def drafting_loop_wrapper(model):

sm_version = get_sm_version()
if kv_cache_config.enable_block_reuse and sm_version not in [
90, 100, 103, 120
90, 100, 103, 120, 121
]:
logger.warning(
f"KV cache reuse for MLA can only be enabled on SM90/SM100/SM103/SM120, "
f"KV cache reuse for MLA can only be enabled on SM90/SM100/SM103/SM120/SM121, "
f"disable enable_block_reuse for SM{sm_version}")
kv_cache_config.enable_block_reuse = False
_set_model_engines_cache_reuse([model_engine, draft_model_engine],
Expand All @@ -693,9 +693,11 @@ def drafting_loop_wrapper(model):
kv_cache_config.enable_block_reuse = False
_set_model_engines_cache_reuse([model_engine, draft_model_engine],
False)
if enable_chunked_context and sm_version not in [90, 100, 103, 120]:
if enable_chunked_context and sm_version not in [
90, 100, 103, 120, 121
]:
logger.warning(
"Chunked Prefill for MLA can only be enabled on SM90/SM100/SM103/SM120, "
"Chunked Prefill for MLA can only be enabled on SM90/SM100/SM103/SM120/SM121, "
f"disable enable_chunked_context for SM{sm_version}")
enable_chunked_context = False
model_engine.attn_runtime_features.chunked_prefill = False
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -185,8 +185,10 @@ def _make_llm_args():
)


def _run_create_py_executor(monkeypatch, *, sm_version, kv_cache_quant_algo):
"""Execute create_py_executor with mocked dependencies and return cache reuse flags.
def _run_create_py_executor(
monkeypatch, *, sm_version, kv_cache_quant_algo, enable_chunked_prefill=False
):
"""Execute create_py_executor with mocked dependencies and return MLA runtime flags.

Mocks all external dependencies (model engine, resource managers, etc.) to isolate
executor creation logic and verify that KV cache reuse configuration is synchronized
Expand All @@ -196,11 +198,14 @@ def _run_create_py_executor(monkeypatch, *, sm_version, kv_cache_quant_algo):
monkeypatch: pytest fixture for mocking.
sm_version: CUDA SM version to simulate (e.g., 89, 90).
kv_cache_quant_algo: Quantization algorithm to use (e.g., NO_QUANT, INT8).
enable_chunked_prefill: Whether to request MLA chunked prefill support.

Returns:
Tuple of (kv_cache_reuse_flag, runtime_cache_reuse_flag) from created executor.
Tuple of (kv_cache_reuse_flag, runtime_cache_reuse_flag,
runtime_chunked_prefill_flag) from created executor.
"""
llm_args = _make_llm_args()
llm_args.enable_chunked_prefill = enable_chunked_prefill
fake_mapping = SimpleNamespace(
rank=0,
tp_size=1,
Expand Down Expand Up @@ -276,6 +281,7 @@ def _create_py_executor_instance(**kwargs):
return (
kv_cache_manager.enable_block_reuse,
py_executor.model_engine.attn_runtime_features.cache_reuse,
py_executor.model_engine.attn_runtime_features.chunked_prefill,
)


Expand All @@ -288,7 +294,7 @@ def test_mla_unsupported_sm_fallback_syncs_cache_reuse(monkeypatch):

This test ensures invariant synchronization is maintained across the fallback.
"""
kv_cache_reuse, runtime_cache_reuse = _run_create_py_executor(
kv_cache_reuse, runtime_cache_reuse, _ = _run_create_py_executor(
monkeypatch,
sm_version=89,
kv_cache_quant_algo=QuantAlgo.NO_QUANT,
Expand All @@ -308,7 +314,7 @@ def test_mla_unsupported_kv_quant_fallback_syncs_cache_reuse(monkeypatch):

This test ensures invariant synchronization is maintained across the fallback.
"""
kv_cache_reuse, runtime_cache_reuse = _run_create_py_executor(
kv_cache_reuse, runtime_cache_reuse, _ = _run_create_py_executor(
monkeypatch,
sm_version=90,
kv_cache_quant_algo=QuantAlgo.INT8,
Expand All @@ -328,11 +334,48 @@ def test_mla_supported_configuration_preserves_cache_reuse(monkeypatch):

This positive test ensures the default path does not regress.
"""
kv_cache_reuse, runtime_cache_reuse = _run_create_py_executor(
kv_cache_reuse, runtime_cache_reuse, _ = _run_create_py_executor(
monkeypatch,
sm_version=90,
kv_cache_quant_algo=QuantAlgo.NO_QUANT,
)

assert kv_cache_reuse is True
assert runtime_cache_reuse is True


def test_mla_sm121_supported_configuration_preserves_cache_reuse(monkeypatch):
"""Verify MLA support on SM121 preserves cache reuse in both config and runtime.

When SM121 is treated as a supported MLA architecture and KV quantization is
NO_QUANT, no unsupported-SM fallback should occur and:
- kv_cache_config.enable_block_reuse remains True
- model_engine.attn_runtime_features.cache_reuse remains True

This regression test protects the SM121 allowlist expansion added for MLA
cache reuse support.
"""
kv_cache_reuse, runtime_cache_reuse, _ = _run_create_py_executor(
monkeypatch,
sm_version=121,
kv_cache_quant_algo=QuantAlgo.NO_QUANT,
)

assert kv_cache_reuse is True
assert runtime_cache_reuse is True


def test_mla_sm121_supported_configuration_preserves_chunked_prefill(monkeypatch):
"""Verify MLA support on SM121 preserves chunked prefill when it is requested.

When SM121 is treated as a supported MLA architecture and chunked prefill is
enabled, the unsupported-SM fallback should not disable the runtime feature.
"""
_, _, runtime_chunked_prefill = _run_create_py_executor(
monkeypatch,
sm_version=121,
kv_cache_quant_algo=QuantAlgo.NO_QUANT,
enable_chunked_prefill=True,
)

assert runtime_chunked_prefill is True