
Integrate cuDNN frontend v1 to fused attention #497

Merged

cyanguwa merged 79 commits into NVIDIA:main from fused_attn/graph_api_v1 on Dec 7, 2023

Conversation

@cyanguwa (Collaborator) commented Oct 30, 2023

This PR:

  • [C/PyTorch] integrates cuDNN frontend v1 into fused attention,
  • [C/PyTorch] adds support for padding and padding_causal masks, post_scale_bias and alibi biases, and MQA/GQA to fused attention,
  • [C/PyTorch] widens the layout support of the _qkvpacked APIs from qkv_interleaved to h3d/3hd, and of the _kvpacked APIs from kv_interleaved to hd_h2d/hd_2hd,
  • [C/PyTorch] removes all usage of the qkv_interleaved, kv_interleaved, and not_interleaved enums,
  • [PyTorch] fixes the FlashAttention module for the padding_causal mask,
  • [PyTorch] fixes the forward output shape for the thd format in FlashAttention,
  • [PyTorch] adds alibi support to UnfusedDPA,
  • [C] changes the backend selection to prioritize the arbitrary_seqlen backend over max512,
  • [PyTorch] changes the backend selection on sm90 to prioritize the FusedAttention arbitrary_seqlen backend over FlashAttention,
  • [C] disables the padding_causal mask for cross attention in the max512 backend due to bugs,
  • [PyTorch] makes stats_lse contiguous for future context-parallel implementations,
  • [PyTorch] completely rewrites the unit tests in test_fused_attn.py for better coverage and efficiency.
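Two of the features listed above, the padding_causal mask and ALiBi biases, follow well-known constructions. Below is a rough, dependency-free sketch of the underlying math in plain Python; it is not Transformer Engine or cuDNN code, and all function names here are hypothetical:

```python
# Illustrative sketches only; NOT Transformer Engine implementations.

def padding_causal_mask(max_seqlen: int, actual_seqlen: int) -> list[list[bool]]:
    """True = query position i may attend to key position j.
    Combines a causal constraint (j <= i) with a padding constraint
    (both i and j must fall within the real, unpadded sequence)."""
    return [
        [j <= i and i < actual_seqlen and j < actual_seqlen
         for j in range(max_seqlen)]
        for i in range(max_seqlen)
    ]

def alibi_slopes(num_heads: int) -> list[float]:
    """Per-head ALiBi slopes, closed form for power-of-two head counts:
    m_h = 2 ** (-8 * (h + 1) / num_heads)."""
    return [2.0 ** (-8.0 * (h + 1) / num_heads) for h in range(num_heads)]

def alibi_bias(slope: float, seqlen: int) -> list[list[float]]:
    """Additive attention bias -slope * (i - j): zero on the diagonal,
    growing penalty for more distant key positions."""
    return [[-slope * (i - j) for j in range(seqlen)] for i in range(seqlen)]

mask = padding_causal_mask(4, actual_seqlen=3)
# Row 2 may attend to positions 0..2; row 3 lies entirely in padding.
```

In a fused kernel these would typically not be materialized as dense matrices; the mask type and bias are passed to the backend, which applies them on the fly inside the attention computation.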

@cyanguwa cyanguwa changed the title [Draft] Integrate cuDNN frontend v1 to fused attention Integrate cuDNN frontend v1 to fused attention Nov 10, 2023
@cyanguwa cyanguwa marked this pull request as ready for review November 10, 2023 23:10
@cyanguwa (Collaborator, Author)

/te-ci

@cyanguwa (Collaborator, Author)

/te-ci

@cyanguwa (Collaborator, Author)

/te-ci

@cyanguwa cyanguwa force-pushed the fused_attn/graph_api_v1 branch from 92d2d56 to 7e3d4fe Compare November 13, 2023 23:10
@cyanguwa cyanguwa force-pushed the fused_attn/graph_api_v1 branch from c4c7c9d to 998ccb0 Compare November 14, 2023 00:14
@cyanguwa (Collaborator, Author)

/te-ci

@cyanguwa (Collaborator, Author)

/te-ci

@cyanguwa cyanguwa closed this Nov 14, 2023
@cyanguwa cyanguwa force-pushed the fused_attn/graph_api_v1 branch from a30f49e to 71e51ea Compare November 14, 2023 19:03
@cyanguwa cyanguwa reopened this Nov 14, 2023
@cyanguwa (Collaborator, Author)

/te-ci pytorch

@cyanguwa (Collaborator, Author) commented Dec 4, 2023

/te-ci jax

@cyanguwa (Collaborator, Author) commented Dec 4, 2023

Pipeline 11240636

@cyanguwa cyanguwa requested a review from zlsh80826 December 5, 2023 21:15
@cyanguwa (Collaborator, Author) commented Dec 6, 2023

Pipeline 11267171

@zlsh80826 (Collaborator) left a comment

LGTM! Thanks for making this awesome upgrade!

@cyanguwa (Collaborator, Author) commented Dec 7, 2023

Pipeline 11267171 is green. The last 5 commits are mostly fixes for L1 tests, which are not critical; 95d0820 is an exception, as it reduces the amount of printing during training.

@cyanguwa cyanguwa merged commit 32db392 into NVIDIA:main Dec 7, 2023
9 checks passed
@cyanguwa cyanguwa deleted the fused_attn/graph_api_v1 branch February 22, 2024 00:34