[Perf] flash_attn gfx950: dwordx4 O stores + flash-decoding split-K#670
Draft
jhinpan wants to merge 3 commits into
Draft
[Perf] flash_attn gfx950: dwordx4 O stores + flash-decoding split-K#670jhinpan wants to merge 3 commits into
jhinpan wants to merge 3 commits into
Conversation
…swap vmem_store stall -78% (ATT), gmean +2.9% across 6 frozen configs, no regression. Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
Per-q-block KV split (num_kv_splits, splits=1 byte-identical); normalized bf16/fp16 O partials + fp32 LSE in workspace; combine skips empty splits. Low-grid shapes (B*H*qblocks < CUs): 1.4x-2.4x bf16; frozen prefill set unchanged. Wiki: technique-split-k (mla-decode flash-decoding schedule). Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
This was referenced Jun 11, 2026
4 tasks
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
What
Two complementary optimizations to the gfx950 dual-wave flash-attn fwd:
permlane32_swapfuses each lane's 4 cols with its half-wave partner's 4 into one 16B store (8 stores/wave instead of 16). rocprofv3 ATT: vmem_store stall −78%. Pays where the epilogue is exposed (grid ≈ 1 WG/CU, short-S causal).num_kv_splitsbuilder param). KV chunks across grid z; epilogue stores normalized bf16/fp16 partials + fp32 LSE; small combine kernel rescales+merges (empty splits skipped).splits=1is byte-identical to baseline (ISA diff). Dispatch:splits = clamp_even(CUs/(B·H·⌈S/256⌉), ≤4)→ 1 for all standard prefill shapes.A/B vs main (same day, same harness: test_flash_attn_fwd.py kernel-only, warmup10/iters100, MI350X, causal D128; main @3e7f66e7)
Prefill (splits=1):
Low-grid, split-K active (kernel+combine vs split1, same arm; main has no split path):
Correctness
Full
test_flash_attn_fwd.pysweep PASS (bf16+fp16, causal+non-causal, incl. B8·S512·H64). Split-K ≤3.2e-3 bf16 / 3.6e-4 fp16 vs torch fp32 SDPA — identical to split1 (partials normalized before 16-bit pack → zero drift). VGPR ≤256, scratch 0.🤖 Generated with Claude Code