[cuda backend] skip fully-masked KV blocks calculation in SDPA#20198
[cuda backend] skip fully-masked KV blocks calculation in SDPA#20198Gasoonjia wants to merge 7 commits into
Conversation
…decode Coalesce int4 W4A8 decode-matvec scale/zero loads by baking the [N, n_groups] layout into the weight constant at pack time. Introduces CudaCoalescedInt4Tensor (an ExecuTorch-internal subclass) that owns the [n_groups, N] -> [N, n_groups] transpose, registers the int4_plain_mm dispatch on it by type, and adds the coalesced dp4a matvec kernel that reads scale/zero row-for-row with qdata (single coalesced load vs 32 stride-N cache lines). ~29.2 -> 37.4 tok/s on gemma group_size=32. Rebased onto main; INT8 dp4a decode op and the floor_div pass from this branch landed separately and now live in quantize_op_dispatch/.
…ied) + benchmark rework
Summary:
At decode (L_q==1) the standard pack-GQA SDPA kernel's grid collapses to
CTA = batch * n_kv_heads, which under-occupies the SMs; split-K flash-decoding
partitions the KV sequence across many more CTAs to fill the GPU. In
ReplaceEdgeOpWithTritonOpPass._pick_sdpa_kernel, route decode to split-K when
L_q==1 and L_kv >= 256 (power-of-2 head dim required; prefill and non-pow2 head
dims keep the standard kernel).
The 256 crossover was measured under CUDA-graph timing (capture+replay, faithful
to the deployed --cuda_graph runtime). The earlier 2048 boundary was overfit to
a plain (non-cuda-graph) microbenchmark, which charged split-K a ~140us per-call
partial-buffer alloc + extra-launch overhead that the graph runtime eliminates;
under faithful timing split-K wins ~1.2-20x from L_kv ~= 256 upward.
benchmark_sdpa.py reworked: deleted run_sweep and all CSV/sentinel machinery;
run_benchmark now compares all six backends (ET-standard, ET-split-K, PyTorch,
Flash, Efficient, Math) with the PyTorch correctness check, across several
decode configs (gemma D256/CTA16, qwen D256/CTA2, D128/CTA16) over the L_kv
range, with a cuda-graph on/off toggle (--mode {cudagraph,plain,both}) timing
every backend through a small self-contained cuda-graph primitive; terminal-only
output. Each reported cell is the mean+/-std over the last 6 of 10 runs (first 4
discarded as warmup; N_RUNS=10, N_WARMUP=4).
Test Plan:
Exercised against the repo (PYTHONPATH) since the conda env's installed
executorch is stale; a lib reinstall is required for the routing to take effect
in a real export.
backends/cuda/tests/test_sdpa_splitk_replacement.py
- L_kv=128 -> standard; L_kv=256 -> split-K; L_kv=4096 -> split-K;
non-pow2 D=96 -> standard.
backends/cuda/tests/test_triton_sdpa_splitk.py (14) and
backends/cuda/tests/test_triton_sdpa_nan.py (3) pass. 21 tests total.
gemma4_31b long-context decode (2401-tok prompt, 256 new tokens, temp 0,
--cuda_graph, 10 runs middle-6) with split-K routing: decode 37.91 -> 43.98
tok/s (+16.0%); prefill within noise.
python backends/cuda/benchmarks/benchmark_sdpa.py --mode cudagraph (gemma
D256/CTA16, mean+/-std us): L_kv=2048 ET-std 102.4+/-0.0 / ET-split-K 24.6+/-0.2 /
PyTorch 475.1+/-0.3 / Flash 56.5+/-0.0; L_kv=16384 ET-std 785.5+/-0.0 /
ET-split-K 179.8+/-0.1 / PyTorch 3447+/-2.6. Plain-timing mode shows split-K's
per-call overhead (the artifact behind the old 2048).
…ock) The decode-only int4_plain_mm matvec was bound by activation load-instruction throughput, not DRAM bandwidth (already ~64% peak) or latency. Each inner iteration issued ~15 loads per 16-byte weight chunk: 8 scalar int32 activation loads + the same per-block scale d reloaded 4x. Align Q8Block to 16 bytes (sizeof 36->48) so each block's qs_even/qs_odd 16B halves are 16B-aligned, then load a whole activation block with two vectorized uint4 loads + one d load (~4x fewer activation loads). dp4a math and accumulation order are bit-identical; the int8 activation values and scale are unchanged. gemma4_31b decode (long-ctx harness, stacked on optimize_1): decode 43.98 -> 46.79 tok/s (+6.4%) prefill 1193 -> 1186 (noise; int4_plain_mm is decode-only) nsys: int4 matvec avg 38.4 -> 34.75 us (-9.5%); quant kernel unchanged. Unit tests test_aoti_torch_cuda_int4_plain_mm: 6/6 pass (M=1/8, gs=16/32/128).
Block-sparse early-exit in _sdpa_fwd_kernel_body: skip KV blocks that are entirely masked (sliding-window via HAS_MASK sum==0, causal via start_n>max_seq_pos). Exact (skipped blocks are x1,+0 no-ops). Prefill +46-88% all lengths; decode safe; SDPA nsys 58.1%->18.5%. Numerically bf16-exact vs dense+mask (unit test).
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/20198
Note: Links to docs will display an error until the docs builds have been completed. ❗ 1 Active SEVsThere are 1 currently active SEVs. If your PR is affected, please view them below: ❌ 2 New Failures, 2 PendingAs of commit 8029941 with merge base c5bf380 ( NEW FAILURES - The following jobs have failed:
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
|
This PR needs a
|
This PR skip attention calculations entirely masked by
result:
Prefill +46-88% all lengths; decode remain the same;
Surpass llama.cpp by 48% on 2k input while on par (-7.4%) on short prompt
Profling on gemma4-31b: SDPA takes 58.1% ->18.5% of e2e prefilling time.
Numerically bf16-exact vs dense+mask (unit test).