[FMHA] gfx950 dualwave SWP with split-K, varlen, and arbitrary seq_len #681

Closed
yanguahe wants to merge 2 commits into main from refine_fmha

@yanguahe

Contributor
  • Add flash_attn_dualwave_swp_gfx950_kernel with lazy-rescale, s_setprio stagger, split-K combine path, and buffer_store_dwordx4 O-store
  • Support packed QKV varlen via cu_seqlens; arbitrary seq_len >= 1 on both dualwave and generic fallback paths with padding masks
  • Update flash_attn_generic dispatch, seq_len guard, and varlen routing
  • Extend test_flash_attn_fwd with split-K, varlen configs, OPUS/aiter compare
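
For readers unfamiliar with a split-K combine step, the following is a minimal pure-Python sketch of the usual log-sum-exp merge for one query row. It is not the kernel from this PR: `attend` and `combine_splits` are hypothetical names, and the assumption is that each K split produces a normalized partial output plus its log-sum-exp.

```python
import math

def attend(scores, values):
    """Softmax-weighted sum for one query row; returns (output, lse)."""
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    denom = sum(weights)
    dim = len(values[0])
    out = [sum(w * v[d] for w, v in zip(weights, values)) / denom
           for d in range(dim)]
    return out, m + math.log(denom)

def combine_splits(partials):
    """Merge per-split (output, lse) pairs into the full-row result."""
    lses = [lse for _, lse in partials]
    m = max(lses)
    total_lse = m + math.log(sum(math.exp(l - m) for l in lses))
    dim = len(partials[0][0])
    out = [0.0] * dim
    for o, lse in partials:
        scale = math.exp(lse - total_lse)  # this split's share of the softmax mass
        for d in range(dim):
            out[d] += scale * o[d]
    return out, total_lse

scores = [0.5, -1.0, 2.0, 0.1, 1.5, -0.3]
values = [[1.0, 0.0], [0.0, 1.0], [2.0, -1.0],
          [0.5, 0.5], [-1.0, 3.0], [4.0, 2.0]]

full_out, full_lse = attend(scores, values)
# Three K splits of two keys each, processed independently.
partials = [attend(scores[i:i + 2], values[i:i + 2]) for i in range(0, 6, 2)]
split_out, split_lse = combine_splits(partials)
```

Under these assumptions `split_out` matches `full_out` up to floating-point rounding, which is the property a combine pass has to preserve.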

Ported from opus_align FMHA optimization work onto rocm/main base.
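
As an illustration of the packed varlen layout, here is a small sketch of how `cu_seqlens` offsets and per-tile padding masks could be derived for arbitrary sequence lengths. The helpers `build_cu_seqlens` and `tile_masks` and the tile size of 4 are hypothetical and chosen for readability; the actual kernel tiling is not described in this thread.

```python
from itertools import accumulate

def build_cu_seqlens(seq_lens):
    """Prefix sums of sequence lengths, starting at 0."""
    return [0] + list(accumulate(seq_lens))

def tile_masks(cu_seqlens, tile):
    """Return (seq_idx, first_row, valid_mask) for every tile of every sequence.

    Rows past the end of a sequence are marked False so they can be masked
    out instead of reading into the next packed sequence.
    """
    tiles = []
    for i in range(len(cu_seqlens) - 1):
        start, end = cu_seqlens[i], cu_seqlens[i + 1]
        for t0 in range(start, end, tile):
            mask = [t0 + r < end for r in range(tile)]
            tiles.append((i, t0, mask))
    return tiles

cu_seqlens = build_cu_seqlens([1, 5, 3])
tiles = tile_masks(cu_seqlens, tile=4)
for seq_idx, first_row, mask in tiles:
    print(seq_idx, first_row, mask)
```

Each sequence gets its own tiles, so a sequence of length 1 still occupies one tile with three masked rows.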


Co-authored-by: Cursor <cursoragent@cursor.com>
@coderfeli

Collaborator

CI FAILED @yanguahe

The generic flash_attn O-store used permlane32_swap and cvt_pk_bf16_f32
(both gfx950/CDNA4-only) unconditionally. On gfx942 (CDNA3) the gfx950
dualwave fast path is disabled and flash_attn falls back to the generic
kernel, so the backend hit "Cannot select intrinsic
llvm.amdgcn.permlane32.swap" and aborted (CI: test linux-flydsl-mi325-1).

Gate the 128-bit permlane-fused store behind gfx950; gfx942 falls back to a
per-lane dwordx2 store packed via .to(elem_dtype) (arch-correct bf16/f16
conversion, same column layout, still num_records-bounded for OOB rows).
Add FLYDSL_DISABLE_DUALWAVE_SWP / FLYDSL_GENERIC_OSTORE_SCALAR env hooks to
exercise the generic kernel and its gfx942 store path on gfx950 hardware.
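
The gating described above can be sketched as a small dispatch function. This is only an illustration: `select_paths` and the returned labels are hypothetical, and treating the two environment variables as enabled when set to "1" is an assumption, since the commit message does not state their accepted values.

```python
import os

def select_paths(arch, env=None):
    """Return (kernel, o_store) labels for a given GPU arch string.

    Assumption: a hook is considered enabled when its variable equals "1".
    """
    env = os.environ if env is None else env

    def flag(name):
        return env.get(name, "0") == "1"

    is_gfx950 = arch.startswith("gfx950")
    if is_gfx950 and not flag("FLYDSL_DISABLE_DUALWAVE_SWP"):
        # Default gfx950 fast path.
        return ("dualwave_swp", "buffer_store_dwordx4")
    if is_gfx950 and not flag("FLYDSL_GENERIC_OSTORE_SCALAR"):
        # Generic kernel, 128-bit permlane-fused store (gfx950-only intrinsics).
        return ("generic", "permlane_dwordx4")
    # gfx942, or gfx950 forced onto the gfx942-style store for testing.
    return ("generic", "dwordx2")

print(select_paths("gfx942", {}))
print(select_paths("gfx950", {"FLYDSL_DISABLE_DUALWAVE_SWP": "1",
                              "FLYDSL_GENERIC_OSTORE_SCALAR": "1"}))
```

Keeping the arch check ahead of any gfx950-only intrinsic is what prevents the "Cannot select intrinsic" failure on gfx942.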

Verified on gfx950 (MI355): the permlane and scalar O-store paths both give
MaxErr 3.91e-3 vs SDPA across H8/16/64, GQA, and partial-seqlen configs; the
default gfx950 dualwave path is unchanged (PASS, MaxErr 3.91e-3).
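
For context on the reported figure, 3.91e-3 is numerically close to 2^-8 = 0.00390625, the spacing between adjacent bf16 values in [0.5, 1). The sketch below shows that spacing with a pure-Python bf16 rounding helper. `to_bf16` and `max_err` are hypothetical helpers, and whether the measured MaxErr actually comes from a single bf16 rounding step is an assumption, not something stated in the commit.

```python
import struct

def to_bf16(x):
    """Round a finite float to the nearest bfloat16 value (ties to even)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits = (bits + 0x7FFF + ((bits >> 16) & 1)) & 0xFFFFFFFF
    bits &= 0xFFFF0000  # keep sign, exponent, and the top 7 mantissa bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]

def max_err(out, ref):
    """Maximum absolute elementwise difference."""
    return max(abs(a - b) for a, b in zip(out, ref))

below = to_bf16(0.501)   # rounds down to 0.5
above = to_bf16(0.503)   # rounds up to the next bf16 value
step = max_err([above], [below])
print(below, above, step)
```

The difference between the two rounded values is exactly one bf16 step in that range.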

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
@yanguahe closed this on Jun 15, 2026