
WIP: causal prefix mask with adjusted tests #2

Draft
wants to merge 24 commits into base: main

Conversation


@timt51 timt51 (Collaborator) commented Nov 18, 2022

DO NOT LAND ON MAIN BRANCH

Implementation of causal prefix mask for cross attention (see Dao-AILab#20 (comment) for more info).
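For context, here is a rough pure-PyTorch sketch of one plausible causal prefix masking rule. The exact convention (how the prefix length is defined and how the causal diagonal is aligned for cross attention with seqlen_q != seqlen_k) is spelled out in the linked issue, so treat the shift below as an illustrative assumption rather than this PR's definition:

```python
import torch

def causal_prefix_mask(seqlen_q: int, seqlen_k: int, prefix_len: int) -> torch.Tensor:
    """Boolean (seqlen_q, seqlen_k) mask; True means the query may attend to the key."""
    q_idx = torch.arange(seqlen_q).unsqueeze(1)   # (seqlen_q, 1)
    k_idx = torch.arange(seqlen_k).unsqueeze(0)   # (1, seqlen_k)
    # Keys inside the prefix are always visible; the rest is masked causally,
    # with the causal diagonal shifted right by prefix_len (assumed convention).
    return (k_idx < prefix_len) | (k_idx - prefix_len <= q_idx)
```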

The original FlashAttention tests have been partially adjusted to take the new causal prefix masking scheme into account. The modified tests should correctly verify that the output out of flash_attn_unpadded_*_func is correct, and that the gradients dq, dk, and dv are correct.
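As a rough illustration of what those checks amount to (this attention_ref is a hypothetical stand-in; the test file has its own reference implementation):

```python
import torch

def attention_ref(q, k, v, mask):
    """Plain-PyTorch attention. q: (B, Sq, H, D), k/v: (B, Sk, H, D), mask: (Sq, Sk) bool."""
    scores = torch.einsum("bqhd,bkhd->bhqk", q, k) / q.shape[-1] ** 0.5
    scores = scores.masked_fill(~mask, float("-inf"))
    probs = scores.softmax(dim=-1)
    return torch.einsum("bhqk,bkhd->bqhd", probs, v)
```

The kernel output out would then be compared against the reference within the usual max-error tolerances, and dq, dk, and dv compared after backpropagating the same upstream gradient through both paths.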

The tests have not been adjusted to properly check the output S_dmask (which contains information about the attention values and dropout), because doing so requires figuring out the format of S_dmask (a non-standard format; see convert_flash_attn_S_to_softmax in the test file). This means we cannot be sure (1) whether the returned attention values are correct, and (2) whether causal prefix masking works with dropout. My guess is that it does, assuming one can figure out how the data is formatted, but it hasn't been verified.

There may also be a performance impact: in my runs the backward pass appears to take somewhat longer, but I haven't measured it rigorously.
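For anyone who wants to quantify that, a rough timing harness along these lines (hypothetical helper, not part of this PR) would give comparable forward+backward numbers for the masked and unmasked paths:

```python
import time
import torch

def time_fwd_bwd(fn, *args, iters=30):
    """Average seconds per forward+backward call (rough, CUDA-synchronized).

    Inputs in *args that should receive gradients must have requires_grad=True.
    """
    out = fn(*args)
    grad = torch.randn_like(out)
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(iters):
        out = fn(*args)
        out.backward(grad)
    torch.cuda.synchronize()
    return (time.time() - start) / iters
```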

@timt51 timt51 marked this pull request as draft November 18, 2022 02:35
@timt51 timt51 linked an issue Nov 18, 2022 that may be closed by this pull request
Development

Successfully merging this pull request may close these issues.

[1] Release causal prefix flashattn