Skip to content

[cuda backend] skip fully-masked KV blocks calculation in SDPA#20198

Open
Gasoonjia wants to merge 7 commits into
mainfrom
g4-opt-prefill-window-sdpa
Open

[cuda backend] skip fully-masked KV blocks calculation in SDPA#20198
Gasoonjia wants to merge 7 commits into
mainfrom
g4-opt-prefill-window-sdpa

Conversation

@Gasoonjia

@Gasoonjia Gasoonjia commented Jun 10, 2026

Copy link
Copy Markdown
Contributor

This PR skip attention calculations entirely masked by

  1. sliding-window
  2. causal via start_n > max_seq_pos (lower tranigle of attention calculation)

result:

prompt len before after gain (vs before) llama.cpp (cuBLAS prefill t/s) gain (vs llama.cpp)
256 729.9 1065.7 +46% 1150.5 −7.4%
512 935.2 1582.7 +69% 1609.9 −1.7%
1024 1116.9 2024.5 +81% 1601.2 +26.4%
2048 1238.0 2325.5 +88% 1573.7 +47.8%

Prefill +46-88% all lengths; decode remain the same;
Surpass llama.cpp by 48% on 2k input while on par (-7.4%) on short prompt
Profling on gemma4-31b: SDPA takes 58.1% ->18.5% of e2e prefilling time.
Numerically bf16-exact vs dense+mask (unit test).

Gasoonjia and others added 6 commits June 8, 2026 22:15
…decode

Coalesce int4 W4A8 decode-matvec scale/zero loads by baking the
[N, n_groups] layout into the weight constant at pack time. Introduces
CudaCoalescedInt4Tensor (an ExecuTorch-internal subclass) that owns the
[n_groups, N] -> [N, n_groups] transpose, registers the int4_plain_mm
dispatch on it by type, and adds the coalesced dp4a matvec kernel that
reads scale/zero row-for-row with qdata (single coalesced load vs 32
stride-N cache lines). ~29.2 -> 37.4 tok/s on gemma group_size=32.

Rebased onto main; INT8 dp4a decode op and the floor_div pass from this
branch landed separately and now live in quantize_op_dispatch/.
…ied) + benchmark rework

Summary:
At decode (L_q==1) the standard pack-GQA SDPA kernel's grid collapses to
CTA = batch * n_kv_heads, which under-occupies the SMs; split-K flash-decoding
partitions the KV sequence across many more CTAs to fill the GPU. In
ReplaceEdgeOpWithTritonOpPass._pick_sdpa_kernel, route decode to split-K when
L_q==1 and L_kv >= 256 (power-of-2 head dim required; prefill and non-pow2 head
dims keep the standard kernel).

The 256 crossover was measured under CUDA-graph timing (capture+replay, faithful
to the deployed --cuda_graph runtime). The earlier 2048 boundary was overfit to
a plain (non-cuda-graph) microbenchmark, which charged split-K a ~140us per-call
partial-buffer alloc + extra-launch overhead that the graph runtime eliminates;
under faithful timing split-K wins ~1.2-20x from L_kv ~= 256 upward.

benchmark_sdpa.py reworked: deleted run_sweep and all CSV/sentinel machinery;
run_benchmark now compares all six backends (ET-standard, ET-split-K, PyTorch,
Flash, Efficient, Math) with the PyTorch correctness check, across several
decode configs (gemma D256/CTA16, qwen D256/CTA2, D128/CTA16) over the L_kv
range, with a cuda-graph on/off toggle (--mode {cudagraph,plain,both}) timing
every backend through a small self-contained cuda-graph primitive; terminal-only
output. Each reported cell is the mean+/-std over the last 6 of 10 runs (first 4
discarded as warmup; N_RUNS=10, N_WARMUP=4).

Test Plan:
Exercised against the repo (PYTHONPATH) since the conda env's installed
executorch is stale; a lib reinstall is required for the routing to take effect
in a real export.

backends/cuda/tests/test_sdpa_splitk_replacement.py
  - L_kv=128 -> standard; L_kv=256 -> split-K; L_kv=4096 -> split-K;
    non-pow2 D=96 -> standard.
backends/cuda/tests/test_triton_sdpa_splitk.py (14) and
backends/cuda/tests/test_triton_sdpa_nan.py (3) pass. 21 tests total.

gemma4_31b long-context decode (2401-tok prompt, 256 new tokens, temp 0,
--cuda_graph, 10 runs middle-6) with split-K routing: decode 37.91 -> 43.98
tok/s (+16.0%); prefill within noise.

python backends/cuda/benchmarks/benchmark_sdpa.py --mode cudagraph (gemma
D256/CTA16, mean+/-std us): L_kv=2048 ET-std 102.4+/-0.0 / ET-split-K 24.6+/-0.2 /
PyTorch 475.1+/-0.3 / Flash 56.5+/-0.0; L_kv=16384 ET-std 785.5+/-0.0 /
ET-split-K 179.8+/-0.1 / PyTorch 3447+/-2.6. Plain-timing mode shows split-K's
per-call overhead (the artifact behind the old 2048).
…ock)

The decode-only int4_plain_mm matvec was bound by activation load-instruction
throughput, not DRAM bandwidth (already ~64% peak) or latency. Each inner
iteration issued ~15 loads per 16-byte weight chunk: 8 scalar int32 activation
loads + the same per-block scale d reloaded 4x.

Align Q8Block to 16 bytes (sizeof 36->48) so each block's qs_even/qs_odd 16B
halves are 16B-aligned, then load a whole activation block with two vectorized
uint4 loads + one d load (~4x fewer activation loads). dp4a math and
accumulation order are bit-identical; the int8 activation values and scale are
unchanged.

gemma4_31b decode (long-ctx harness, stacked on optimize_1):
  decode  43.98 -> 46.79 tok/s (+6.4%)
  prefill 1193  -> 1186     (noise; int4_plain_mm is decode-only)
nsys: int4 matvec avg 38.4 -> 34.75 us (-9.5%); quant kernel unchanged.
Unit tests test_aoti_torch_cuda_int4_plain_mm: 6/6 pass (M=1/8, gs=16/32/128).
Block-sparse early-exit in _sdpa_fwd_kernel_body: skip KV blocks that are
entirely masked (sliding-window via HAS_MASK sum==0, causal via start_n>max_seq_pos).
Exact (skipped blocks are x1,+0 no-ops). Prefill +46-88% all lengths; decode safe;
SDPA nsys 58.1%->18.5%. Numerically bf16-exact vs dense+mask (unit test).
@pytorch-bot

pytorch-bot Bot commented Jun 10, 2026

Copy link
Copy Markdown

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/20198

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEVs

There are 1 currently active SEVs. If your PR is affected, please view them below:

❌ 2 New Failures, 2 Pending

As of commit 8029941 with merge base c5bf380 (image):

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Jun 10, 2026
@linux-foundation-easycla

linux-foundation-easycla Bot commented Jun 10, 2026

Copy link
Copy Markdown

CLA Missing ID

@Gasoonjia Gasoonjia changed the title [cuda][prefill] window-aware SDPA: skip fully-masked KV blocks (idea #1) [cuda][prefill] window-aware SDPA: skip fully-masked KV blocks Jun 10, 2026
@Gasoonjia Gasoonjia changed the title [cuda][prefill] window-aware SDPA: skip fully-masked KV blocks [cuda backend]window-aware SDPA: skip fully-masked KV blocks Jun 10, 2026
@Gasoonjia Gasoonjia marked this pull request as ready for review June 10, 2026 21:35
@Gasoonjia Gasoonjia changed the title [cuda backend]window-aware SDPA: skip fully-masked KV blocks [cuda backend] skip fully-masked KV blocks calculation in SDPA Jun 10, 2026
@Gasoonjia Gasoonjia requested a review from mergennachin June 10, 2026 21:40
@mergennachin mergennachin requested a review from digantdesai June 10, 2026 21:46
Base automatically changed from g4-opt-int4-vecload to main June 12, 2026 07:37
@github-actions

Copy link
Copy Markdown

This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant