
[PyTorch] Add sliding window support to FlashAttention #551

Merged: 21 commits, Dec 16, 2023

Conversation

@cyanguwa (Collaborator) commented Dec 6, 2023

This PR only makes changes on the PyTorch side. It

  • integrates flash-attn 2.3+ sliding window attention into TransformerLayer, MultiHeadAttention, DotProductAttention and FlashAttention
  • adds unit tests that compare against UnfusedDotProductAttention with an arbitrary mask generated from the window size (see the sketch after this list)
  • adds a use_unfused_attention flag and raises an exception when none of the three DPA backends is available
  • adds more determinism control to the backend selection; in particular, it filters out the fused attention arbitrary-seqlen backend on non-sm90 architectures, because that backend has no workspace optimization path there and is therefore non-deterministic
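
For reference, here is a minimal sketch of how a window-based mask for an unfused reference implementation could be generated. The helper name and signature are illustrative, not the PR's actual test code; it assumes flash-attn's (left, right) window convention, where query position i attends to key positions in [i - left, i + right] and -1 means unbounded on that side:

```python
import torch

def sliding_window_mask(seqlen_q, seqlen_kv, window_size, device="cpu"):
    """Boolean mask: True where a query position may attend to a key position.

    Illustrative helper, assuming flash-attn's (left, right) convention
    with -1 meaning "no limit" on that side.
    """
    left, right = window_size
    q_idx = torch.arange(seqlen_q, device=device).unsqueeze(1)    # [sq, 1]
    kv_idx = torch.arange(seqlen_kv, device=device).unsqueeze(0)  # [1, skv]
    mask = torch.ones(seqlen_q, seqlen_kv, dtype=torch.bool, device=device)
    if left >= 0:   # limit how far back a query may look
        mask &= kv_idx >= q_idx - left
    if right >= 0:  # limit how far ahead a query may look
        mask &= kv_idx <= q_idx + right
    return mask
```

For example, window_size=(2, 0) yields a causal band in which each query sees itself and the two previous positions; comparing FlashAttention output under that window against UnfusedDotProductAttention run with such a mask is roughly the shape of check the new unit tests perform.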

@cyanguwa (Collaborator, Author) commented Dec 6, 2023

/te-ci

@cyanguwa (Collaborator, Author) commented Dec 6, 2023

/te-ci

@cyanguwa (Collaborator, Author) commented Dec 6, 2023

/te-ci pytorch

@cyanguwa closed this Dec 7, 2023

@cyanguwa reopened this Dec 8, 2023

@cyanguwa (Collaborator, Author) commented Dec 8, 2023

/te-ci pytorch

1 similar comment
@cyanguwa (Collaborator, Author) commented Dec 8, 2023

/te-ci pytorch

@cyanguwa (Collaborator, Author) commented Dec 8, 2023

Pipeline 11331817

@cyanguwa (Collaborator, Author) commented Dec 11, 2023

With newer cuDNN, pipeline 11412889 is green!

@cyanguwa (Collaborator, Author) commented

/te-ci pytorch

4 similar comments
@ptrendx (Member) left a comment

LGTM, thanks!

@ptrendx merged commit 27aa609 into NVIDIA:main on Dec 16, 2023
20 checks passed
@cyanguwa deleted the fa/sliding_window branch on February 21, 2024
@ashvinnihalani commented

I also want to ask: what happens when we use grouped-query attention with unfused attention? Right now it seems to error out.
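
For context on the question above: an unfused attention path computes attention with plain matmuls, so the key/value tensors must have as many heads as the query tensor. A minimal sketch of the head expansion such a path could apply for grouped-query attention (the helper is hypothetical, not TransformerEngine's actual code):

```python
import torch

def expand_kv_for_gqa(kv, num_q_heads):
    """Repeat each KV head so K/V match the query head count (hypothetical helper).

    kv: [batch, num_kv_heads, seqlen, head_dim]
    """
    num_kv_heads = kv.shape[1]
    assert num_q_heads % num_kv_heads == 0, "query heads must be a multiple of KV heads"
    group_size = num_q_heads // num_kv_heads
    # Each KV head is shared by group_size query heads, so repeat along the head dim.
    return kv.repeat_interleave(group_size, dim=1)  # [batch, num_q_heads, seqlen, head_dim]
```

For instance, with 16 query heads and 4 KV heads, each KV head would be repeated 4 times before the usual softmax(QKᵀ)V matmuls.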
