[WIP] add scaled_dot_product_attention_backward #898
base: master
Conversation
Currently, TODO
Run test: change head nums of q and kv, seq_len, and head_size in
Observations of
# | q_num_head | kv_num_head | seq_len | head_size | elements | mismatched elements | h0 correct? | offset = (seq_len padded to a multiple of 64) - seq_len | (kv_num_head - 1) * head_size | (kv_num_head - 1) * head_size * offset |
---|---|---|---|---|---|---|---|---|---|---|
1 | 8 | 8 | 128 | 64 | 65536 | 0 | y | 0 | 448 | 0 |
2 | 8 | 8 | 127 | 64 | 65024 | 447 | y | 1 | 448 | 448 |
3 | 8 | 8 | 120 | 64 | 61440 | 3554 | y | 8 | 448 | 3584 |
4 | 8 | 8 | 64 | 64 | 32768 | 0 | y | 0 | 448 | 0 |
5 | 8 | 8 | 129 | 64 | 66048 | 27299 | y | 63 | 448 | 28224 |
6 | 8 | 8 | 255 | 64 | 130560 | 448 | y | 1 | 448 | 448 |
7 | 8 | 4 | 128 | 64 | 32768 | 0 | y | 0 | 192 | 0 |
8 | 8 | 4 | 127 | 64 | 32512 | 256 | n | 1 | 192 | 192 |
9 | 8 | 4 | 120 | 64 | 30720 | 2033 | n | 8 | 192 | 1536 |
10 | 8 | 4 | 129 | 64 | 33024 | 15684 | n | 63 | 192 | 12096 |
11 | 8 | 8 | 128 | 128 | 131072 | 0 | y | 0 | 896 | 0 |
12 | 8 | 8 | 127 | 128 | 130048 | 895 | y | 1 | 896 | 896 |
13 | 8 | 8 | 120 | 128 | 122880 | 7109 | y | 8 | 896 | 7168 |
14 | 8 | 8 | 64 | 128 | 65536 | 0 | y | 0 | 896 | 0 |
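The derived columns follow a simple pattern. Below is a minimal sketch of how they can be computed; treating elements as the size of dv and 64 as the block size are assumptions inferred from the table, not confirmed against the kernel.

```python
# Reproduce the derived columns of the table above (illustrative only).
def derived_columns(q_num_head, kv_num_head, seq_len, head_size, block=64):
    offset = (-seq_len) % block                    # pad seq_len up to a multiple of 64, then subtract seq_len
    per_head = (kv_num_head - 1) * head_size       # second-to-last column
    predicted_mismatch = per_head * offset         # last column
    elements = kv_num_head * seq_len * head_size   # matches the "elements" column (size of dv)
    return elements, offset, per_head, predicted_mismatch

# Row 5: q_num_head=8, kv_num_head=8, seq_len=129, head_size=64
print(derived_columns(8, 8, 129, 64))  # (66048, 63, 448, 28224)
```

In the MHA rows the observed mismatch count tracks the last column closely (e.g. 447 vs. 448, 27299 vs. 28224); the GQA rows deviate more.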
Explanation?
Let's look at the MHA cases for the sake of simplicity. Head 0 is always correct, which suggests the _attn_bwd_preprocess and _attn_bwd_dkdv logic works. Subsequent heads' initial token errors appear to be aligned exactly with offset, and the fact that tokens beyond the initial offset are correct suggests the pointer calculations (Q += adj, K += kv_adj, etc.) are globally valid:
```python
# per-(batch, head) base offsets applied to the Q/KV/grad pointers
Q += adj
K += kv_adj
V += kv_adj
DO += adj
DQ += adj
DK += adj
DV += adj
M += off_chz
D += off_chz
```
However, only the first tokens of subsequent heads in dv are incorrect, and the number of incorrect tokens happens to coincide with offset. The remaining tokens appear to be correct, which makes a miscalculation of the Q, K, V, DO, DQ, ... pointers unlikely; if those were wrong, all tokens would be incorrect.
Is it possible that Triton pads unaligned sequences internally, so some elements are not loaded correctly?
In _attn_fwd, please replace the line
tl.store(m_ptrs, m_i)
with
tl.store(m_ptrs, m_i, mask=q_load_mask)
to avoid race conditions.
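Why this matters: when seq_len is not a multiple of the block size, the last program of a head still stores a full block of row statistics, so an unmasked store spills past that head's rows into the region belonging to the next (batch, head). Because programs run concurrently, correct values written there can be overwritten, which would explain why only the first offset tokens of subsequent heads come out wrong. The following is a self-contained sketch of the failure mode and the fix; it is not the FlagGems kernel, and every name in it is illustrative.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def copy_rows(X, M, N_CTX, BLOCK_M: tl.constexpr, USE_MASK: tl.constexpr):
    # Each program handles BLOCK_M rows; the last one covers rows past N_CTX.
    offs_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
    row_mask = offs_m < N_CTX
    x = tl.load(X + offs_m, mask=row_mask, other=0.0)
    if USE_MASK:
        tl.store(M + offs_m, x, mask=row_mask)   # the suggested fix
    else:
        tl.store(M + offs_m, x)                  # writes past row N_CTX - 1


seq_len, BLOCK_M = 127, 64
x = torch.randn(seq_len, device="cuda")
# Leave extra room after the logical end so the spill is observable.
m = torch.full((seq_len + BLOCK_M,), float("nan"), device="cuda")
grid = (triton.cdiv(seq_len, BLOCK_M),)

copy_rows[grid](x, m, seq_len, BLOCK_M=BLOCK_M, USE_MASK=False)
print(torch.isnan(m[seq_len:]).all())  # False: the tail was clobbered

m.fill_(float("nan"))
copy_rows[grid](x, m, seq_len, BLOCK_M=BLOCK_M, USE_MASK=True)
print(torch.isnan(m[seq_len:]).all())  # True: the masked store stays in bounds
```

In the real kernel the clobbered region would hold the next head's per-row statistics, which is consistent with the (kv_num_head - 1) * head_size * offset pattern in the table above.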
bwd seems correct now, thanks @tongxin for finding the root issue that lies in the forward pass. Now we need to decide how to store
I think that's a good idea.
Currently, I make ScaleDotProductAttention as a
815dae8 to c88d651
tests/test_attention_ops.py (outdated)
```python
gems_assert_close(gems_result, torch_result, dtype)

# backward
dout = torch.randn_like(q)
```
q should be ref_q here, or there will be a device mismatch in the CPU baseline tests. This is the reason op-test-quick-cpu failed.
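A minimal sketch of the change in context (ref_out, out, ref_k, ref_v, and the gradient checks are assumptions about the surrounding test, not quoted from it):

```python
# Build the upstream gradient from the reference tensor so the reference
# backward runs entirely on the reference (possibly CPU) device.
dout = torch.randn_like(ref_q)   # was: torch.randn_like(q)
ref_dq, ref_dk, ref_dv = torch.autograd.grad(ref_out, (ref_q, ref_k, ref_v), dout)

# The device-under-test path gets its own copy of the upstream gradient.
dq, dk, dv = torch.autograd.grad(out, (q, k, v), dout.to(device=q.device, dtype=out.dtype))
gems_assert_close(dq, ref_dq, dtype)
gems_assert_close(dk, ref_dk, dtype)
gems_assert_close(dv, ref_dv, dtype)
```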
PR Category
Operator | OP Test
Type of Change
New Feature
Description
Add the backward of SDPA, which is still buggy when seq_len is not a multiple of 64.
Issue
Progress
Performance