[Perf] flash_attn gfx950: dwordx4 O stores + flash-decoding split-K#670

Draft
jhinpan wants to merge 3 commits into
ROCm:mainfrom
jhinpan:perf/fa-gfx950-ostore-dwordx4
@jhinpan jhinpan commented Jun 10, 2026

Contributor

What

Two complementary optimizations to the gfx950 dual-wave flash-attn fwd:

  1. O epilogue stores dwordx2 → dwordx4. permlane32_swap fuses each lane's 4 cols with its half-wave partner's 4 into one 16B store (8 stores/wave instead of 16). rocprofv3 ATT: vmem_store stall −78%. Pays where the epilogue is exposed (grid ≈ 1 WG/CU, short-S causal).
  2. Flash-decoding split-K (num_kv_splits builder param). KV chunks across grid z; epilogue stores normalized bf16/fp16 partials + fp32 LSE; small combine kernel rescales+merges (empty splits skipped). splits=1 is byte-identical to baseline (ISA diff). Dispatch: splits = clamp_even(CUs/(B·H·⌈S/256⌉), ≤4) → 1 for all standard prefill shapes.
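To make the lane exchange in item 1 concrete, here is a toy Python model, not the kernel code. It assumes `permlane32_swap` trades the upper 32 lanes of its first register with the lower 32 lanes of its second, and it uses a hypothetical layout in which every lane owns two 8-byte fragments; the fragment labels and the names `reg_a`, `reg_b`, and `stores_16b` are illustrative only.

```python
HALF = 32  # lanes per half-wave on a 64-lane wave


def permlane32_swap(vdst, src0):
    """Toy model: upper 32 lanes of vdst trade places with lower 32 lanes of src0."""
    vdst, src0 = list(vdst), list(src0)
    for lane in range(HALF):
        vdst[lane + HALF], src0[lane] = src0[lane], vdst[lane + HALF]
    return vdst, src0


# Hypothetical layout: every lane owns two 8-byte fragments (4 x 16-bit columns each).
# A fragment is labelled (owner_lane, fragment_id) so its movement can be traced.
frag0 = [(lane, 0) for lane in range(2 * HALF)]
frag1 = [(lane, 1) for lane in range(2 * HALF)]

reg_a, reg_b = permlane32_swap(frag0, frag1)

# After the swap each lane holds the same fragment_id from itself and its
# half-wave partner, so one 16-byte store can replace two 8-byte stores.
stores_16b = list(zip(reg_a, reg_b))
```

In this model lane 5 ends up holding fragment 0 of lanes 5 and 37, and lane 37 holds fragment 1 of both. Whether those two fragments are adjacent in memory depends on the real accumulator layout, which this sketch does not reproduce.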

A/B vs main (same day, same harness: test_flash_attn_fwd.py kernel-only, warmup10/iters100, MI350X, causal D128; main @3e7f66e7)

Prefill (splits=1):

| shape | main | PR | Δ |
|---|---|---|---|
| B1 S2048 H32 bf16 | 621 | 636 | +2.3% |
| B4 S2048 GQA8 bf16 | 697 | 707 | +1.5% |
| B1 S2048 H32 fp16 | 593 | 604 | +1.9% |
| B8 S1024 H32 bf16 | 521 | 523 | +0.3% |
| B1 S8192 H32 bf16 | 892 | 896 | +0.5% |
| B1 S4096 H32 bf16 | 775 | 760 | −1.9% (run noise; repeats overlap 760–780 both arms) |

Low-grid, split-K active (kernel+combine vs split1, same arm; main has no split path):

| shape | split1 | split-K | speedup |
|---|---|---|---|
| B1 S8192 H2 (sp4) | 234 | 569 | 2.43× |
| B1 S4096 H2 (sp4) | 111 | 271 | 2.43× |
| B1 S2048 H4 (sp4) | 101 | 177 | 1.76× |
| B1 S8192 H4 (sp2) | 466 | 672 | 1.44× |
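The (sp4)/(sp2) labels above can be cross-checked against the dispatch rule `splits = clamp_even(CUs/(B·H·⌈S/256⌉), ≤4)`. The sketch below is one reading of that rule, not the PR's code: the behaviour of `clamp_even` (cap at 4, round odd values down to even, never below 1), the function name `pick_num_kv_splits`, and the default of 256 CUs are all assumptions.

```python
import math


def clamp_even(x, cap=4):
    """Assumed semantics: cap at `cap`, round odd values down to even, floor at 1."""
    x = min(int(x), cap)
    if x >= 2:
        x -= x % 2
    return max(x, 1)


def pick_num_kv_splits(batch, heads, seqlen, num_cus=256, q_block=256):
    """Hypothetical helper: choose a KV split count from grid occupancy.

    num_cus=256 is an assumed CU count; q_block=256 comes from the ceil(S/256) term.
    """
    q_blocks = math.ceil(seqlen / q_block)
    workgroups = batch * heads * q_blocks  # grid size before splitting KV
    return clamp_even(num_cus // workgroups)


# Shapes from the low-grid table, followed by two prefill shapes.
low_grid = [pick_num_kv_splits(1, 2, 8192), pick_num_kv_splits(1, 2, 4096),
            pick_num_kv_splits(1, 4, 2048), pick_num_kv_splits(1, 4, 8192)]
prefill = [pick_num_kv_splits(1, 32, 2048), pick_num_kv_splits(8, 32, 1024)]
```

Under these assumptions the four low-grid shapes come out as 4, 4, 4, 2, matching the labels in the table, and the prefill shapes fall back to 1.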

Correctness

Full test_flash_attn_fwd.py sweep PASS (bf16+fp16, causal+non-causal, incl. B8·S512·H64). Split-K ≤3.2e-3 bf16 / 3.6e-4 fp16 vs torch fp32 SDPA — identical to split1 (partials normalized before 16-bit pack → zero drift). VGPR ≤256, scratch 0.
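The reason normalized partials can be merged without drift is that an LSE-weighted average of per-chunk softmax outputs equals the full softmax output in exact arithmetic. The pure-Python sketch below illustrates that identity for a single query row; it is not the combine kernel, the helper names `attention_partial` and `combine` are hypothetical, and the softmax scale and the 16-bit pack are left out for brevity.

```python
import math


def attention_partial(q, keys, values):
    """Return (normalized output, log-sum-exp) for one KV chunk; (None, -inf) if empty."""
    if not keys:
        return None, -math.inf
    scores = [sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    denom = sum(weights)
    dim = len(values[0])
    out = [sum(w * v[d] for w, v in zip(weights, values)) / denom for d in range(dim)]
    return out, m + math.log(denom)


def combine(partials):
    """Merge normalized partials by their LSE weights, skipping empty splits."""
    live = [(o, lse) for o, lse in partials if lse != -math.inf]
    lse_max = max(lse for _, lse in live)
    scales = [math.exp(lse - lse_max) for _, lse in live]  # rescale factors
    total = sum(scales)
    dim = len(live[0][0])
    return [sum(s * o[d] for s, (o, _) in zip(scales, live)) / total for d in range(dim)]


q = [0.5, -1.0]
keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [-1.0, 2.0]]
values = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]

full, _ = attention_partial(q, keys, values)           # single pass over all KV
parts = [attention_partial(q, keys[:2], values[:2]),   # split 0
         attention_partial(q, keys[2:], values[2:]),   # split 1
         attention_partial(q, [], [])]                 # empty split, skipped
merged = combine(parts)
```

Here `merged` agrees with `full` to floating-point rounding, including when one split is empty.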

🤖 Generated with Claude Code

jhinpan and others added 2 commits June 10, 2026 04:00
…swap

vmem_store stall -78% (ATT), gmean +2.9% across 6 frozen configs, no regression.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
Per-q-block KV split (num_kv_splits, splits=1 byte-identical); normalized
bf16/fp16 O partials + fp32 LSE in workspace; combine skips empty splits.
Low-grid shapes (B*H*qblocks < CUs): 1.4x-2.4x bf16; frozen prefill set
unchanged. Wiki: technique-split-k (mla-decode flash-decoding schedule).

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
@jhinpan jhinpan changed the title from "[Perf] flash_attn gfx950: widen O epilogue stores to dwordx4 via permlane32_swap" to "[Perf] flash_attn gfx950: dwordx4 O stores + flash-decoding split-K" Jun 10, 2026