
Conversation

ch1y0q (Contributor) commented Aug 14, 2025

PR Category

Operator | OP Test

Type of Change

New Feature

Description

Add the backward pass of SDPA, which is still buggy when seq_len is not a multiple of 64.

Issue

Progress

  • Change is properly reviewed (1 reviewer required, 2 recommended).
  • Change is responded to an issue.
  • Change is fully covered by a UT.

Performance

ch1y0q (Contributor, Author) commented Aug 14, 2025

Currently, dq, dk, and dv exhibit bugs when seq_len is not a multiple of 64. We'll prioritize fixing dv first, as the first head of dv appears correct.

TODO

  • Resolve issues with dk and dv for seq_len values that are not multiples of 64.
  • Subsequently, address the bugs in dq, as the calculation of dq is more complicated than that of dk/dv.
  • Create a class ScaleDotProductAttention(torch.autograd.Function) to manage the context data generated during fwd; currently a global variable GLOBAL_ATTENTION_M is used to store the M produced in forward. (A sketch follows below.)
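
As a rough illustration of that last TODO item, here is a minimal sketch of how a torch.autograd.Function could carry M through ctx instead of a global. attention_forward_kernel and attention_backward_kernel are hypothetical placeholders for the Triton kernel launchers, not functions in this PR:

```python
import torch


class ScaleDotProductAttention(torch.autograd.Function):
    @staticmethod
    def forward(ctx, q, k, v, scale):
        # attention_forward_kernel is a placeholder launcher assumed to also
        # return the per-row statistic M produced by the forward kernel.
        out, M = attention_forward_kernel(q, k, v, scale)
        ctx.save_for_backward(q, k, v, out, M)  # no global state needed
        ctx.scale = scale
        return out

    @staticmethod
    def backward(ctx, dout):
        q, k, v, out, M = ctx.saved_tensors
        # attention_backward_kernel is a placeholder for the Triton bwd launcher.
        dq, dk, dv = attention_backward_kernel(q, k, v, out, M, dout, ctx.scale)
        return dq, dk, dv, None  # one gradient per forward input


def scaled_dot_product_attention(q, k, v, scale=None):
    scale = q.shape[-1] ** -0.5 if scale is None else scale
    return ScaleDotProductAttention.apply(q, k, v, scale)
```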

Run test

Change the head counts of q and kv, seq_len, and head_size in attention_test.py. Then:

cd path/to/flag_gems_project_root
python3 path/to/flag_gems_project_root/src/flag_gems/attention_test.py

ch1y0q (Contributor, Author) commented Aug 14, 2025

Observations of dv

  1. When seq_len is a multiple of 64, dv is consistently correct across all heads.
  2. In MHA cases (query head count $hq$ == k/v head count $hkv$):
    • When seq_len is not a multiple of 64, head $H_0$ remains correct
    • Heads $H_1$ to $H_{hq-1}$ show incorrect initial tokens
    • The number of affected tokens equals offset, i.e. the padding needed to make seq_len a multiple of 64
    • Let's define offset = (64 - seq_len % 64) % 64 (Expression 1; see the sketch after this list)
    • Mismatched elements: (hq - 1) × head_size × offset (Expression 2)
  3. In GQA cases ($hq > hkv$, with $hq$ divisible by $hkv$):
    • Head $H_0$ in dv shows errors
    • Mismatched elements exceed (hq - 1) × head_size × offset
    • Likely due to kv head $H_0$ serving multiple query heads
  4. The constant 64 in offset:
    • Independent of seq_len and head_size, verified with head_size values of 64 and 128
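
As a quick sanity check of Expressions 1 and 2 against the MHA rows of the table below, an illustrative helper (not part of the PR):

```python
def expected_mismatches(seq_len, num_heads, head_size, block=64):
    """Expression 1: offset = (block - seq_len % block) % block.
    Expression 2: mismatched elements = (num_heads - 1) * head_size * offset."""
    offset = (block - seq_len % block) % block
    return offset, (num_heads - 1) * head_size * offset


print(expected_mismatches(127, 8, 64))  # (1, 448)    -> matches row 2
print(expected_mismatches(120, 8, 64))  # (8, 3584)   -> matches row 3
print(expected_mismatches(129, 8, 64))  # (63, 28224) -> matches row 5
```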

Experiments

| # | q_num_head | kv_num_head | seq_len | head_size | elements | mismatched elements | H0 correct? | offset (seq_len padded to multiple of 64, minus seq_len) | (kv_num_head - 1) * head_size | (kv_num_head - 1) * head_size * offset |
|---|---|---|---|---|---|---|---|---|---|---|
| 1 | 8 | 8 | 128 | 64 | 65536 | 0 | y | 0 | 448 | 0 |
| 2 | 8 | 8 | 127 | 64 | 65024 | 447 | y | 1 | 448 | 448 |
| 3 | 8 | 8 | 120 | 64 | 61440 | 3554 | y | 8 | 448 | 3584 |
| 4 | 8 | 8 | 64 | 64 | 32768 | 0 | y | 0 | 448 | 0 |
| 5 | 8 | 8 | 129 | 64 | 66048 | 27299 | y | 63 | 448 | 28224 |
| 6 | 8 | 8 | 255 | 64 | 130560 | 448 | y | 1 | 448 | 448 |
| 7 | 8 | 4 | 128 | 64 | 32768 | 0 | y | 0 | 192 | 0 |
| 8 | 8 | 4 | 127 | 64 | 32512 | 256 | n | 1 | 192 | 192 |
| 9 | 8 | 4 | 120 | 64 | 30720 | 2033 | n | 8 | 192 | 1536 |
| 10 | 8 | 4 | 129 | 64 | 33024 | 15684 | n | 63 | 192 | 12096 |
| 11 | 8 | 8 | 128 | 128 | 131072 | 0 | y | 0 | 896 | 0 |
| 12 | 8 | 8 | 127 | 128 | 130048 | 895 | y | 1 | 896 | 896 |
| 13 | 8 | 8 | 120 | 128 | 122880 | 7109 | y | 8 | 896 | 7168 |
| 14 | 8 | 8 | 64 | 128 | 65536 | 0 | y | 0 | 896 | 0 |

Explanation?

Let's look at the MHA cases for the sake of simplicity.
The correctness of head $H_0$ implies that the _attn_bwd_preprocess and _attn_bwd_dkdv logic works. The errors in subsequent heads' initial tokens are aligned exactly with offset. The fact that tokens beyond the initial offset are correct suggests the pointer calculations (Q += adj, K += kv_adj, etc.) are globally valid:

    Q += adj
    K += kv_adj
    V += kv_adj
    DO += adj
    DQ += adj
    DK += adj
    DV += adj
    M += off_chz
    D += off_chz

However, only the first tokens of subsequent heads in dv are incorrect, and the number of incorrect tokens happens to coincide with offset. The remaining tokens appear correct, which makes a miscalculation of the Q, K, V, DO, DQ, ... pointers unlikely; otherwise all tokens would be incorrect.

Is it possible that Triton pads unaligned sequences internally, so some elements are not loaded correctly?
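
For reference, Triton does not pad unaligned sequences for you: out-of-range lanes must be masked explicitly on every load and store. A generic sketch of that masking pattern (not the kernel in this PR):

```python
import triton
import triton.language as tl


@triton.jit
def masked_row_copy(x_ptr, y_ptr, seq_len, BLOCK: tl.constexpr):
    # Rows beyond seq_len are masked out on both load and store, so a block
    # that straddles the sequence boundary never reads or writes stale memory.
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < seq_len
    x = tl.load(x_ptr + offs, mask=mask, other=0.0)
    tl.store(y_ptr + offs, x, mask=mask)
```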

tongxin (Contributor) left a comment

In _attn_fwd, please replace the line tl.store(m_ptrs, m_i) with
tl.store(m_ptrs, m_i, mask=q_load_mask) to avoid race conditions.

ch1y0q (Contributor, Author) commented Aug 20, 2025

bwd seems correct now, thanks @tongxin for finding the root cause, which lies in the forward pass. Now we need to decide how to store the m_i produced in fwd. Should we make sdpa a torch.nn.Module and register both forward and backward hooks?

tongxin (Contributor) commented Aug 20, 2025

> bwd seems correct now, thanks @tongxin for finding the root cause, which lies in the forward pass. Now we need to decide how to store the m_i produced in fwd. Should we make sdpa a torch.nn.Module and register both forward and backward hooks?

I think that's a good idea.

ch1y0q (Contributor, Author) commented Sep 2, 2025

src/flag_gems/ops/__init__.py:12:1: F401 'flag_gems.ops.attention.ScaleDotProductAttention' imported but unused
src/flag_gems/ops/__init__.py:12:1: F401 'flag_gems.ops.attention.scaled_dot_product_attention_backward' imported but unused

Currently, I implement ScaleDotProductAttention as a torch.autograd.Function wrapper. scaled_dot_product_attention calls ScaleDotProductAttention.apply(), and ScaleDotProductAttention.backward() calls scaled_dot_product_attention_backward. Users simply call scaled_dot_product_attention(q, k, v, ...) to get a tensor triton_result, then call triton_result.backward(dout) to get dq, dk, and dv (a usage sketch follows below). Is that OK?

  1. The bwd kernel currently assumes Q_CTX == KV_CTX; I will remove this constraint later. (Update: implemented.)
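
A short usage sketch of that interface (the tensor shapes and exact import path are illustrative and may differ from the PR):

```python
import torch
from flag_gems.ops import scaled_dot_product_attention  # import path may differ

q = torch.randn(1, 8, 128, 64, device="cuda", requires_grad=True)
k = torch.randn(1, 8, 128, 64, device="cuda", requires_grad=True)
v = torch.randn(1, 8, 128, 64, device="cuda", requires_grad=True)

triton_result = scaled_dot_product_attention(q, k, v)

dout = torch.randn_like(triton_result)
triton_result.backward(dout)  # gradients land in q.grad, k.grad, v.grad
```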

@ch1y0q ch1y0q marked this pull request as ready for review September 9, 2025 06:19
gems_assert_close(gems_result, torch_result, dtype)

# backward
dout = torch.randn_like(q)

q should be ref_q here, or there will be device mismatch in CPU baseline tests. This is the reason op-test-quick-cpu failed.
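
A minimal sketch of the suggested change (the variable name ref_q is taken from the comment above):

```python
# Use the reference tensor so dout is created on the same device as the
# reference inputs; this avoids the device mismatch in CPU baseline tests.
dout = torch.randn_like(ref_q)
```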
