
[JAX] Prepare cross flash attention #525

Merged · 5 commits into NVIDIA:main on Dec 1, 2023

Conversation

@zlsh80826 (Collaborator) commented on Nov 20, 2023:

This PR does the following:

  • Prepares the I/O for the arbitrary attention backend (which requires the forward output in the backward pass)
  • Adds bias support to the JAX bridge code
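The two bullets above can be illustrated together with a small NumPy sketch: a scaled-dot-product attention forward that applies a bias *after* scaling (the "post scale bias" layout discussed later in this thread), and a backward pass whose residuals also carry the forward output, which is the I/O change flash-style backends need. This is a minimal reference, not the actual Transformer Engine kernels; all function names here are illustrative.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attn_fwd(q, k, v, bias=None):
    """Reference attention forward with post-scale bias."""
    scale = 1.0 / np.sqrt(q.shape[-1])
    logits = q @ k.T * scale          # scale first ...
    if bias is not None:
        logits = logits + bias        # ... then add the bias (post-scale)
    p = softmax(logits)
    out = p @ v
    # Residuals include `out`: flash-style backends need the forward
    # output in the backward, hence the I/O preparation in this PR.
    return out, (q, k, v, bias, p, out)

def attn_bwd(res, dout):
    """Reference backward consuming the saved residuals."""
    q, k, v, bias, p, out = res
    scale = 1.0 / np.sqrt(q.shape[-1])
    dv = p.T @ dout
    dp = dout @ v.T
    # Row-wise softmax backward
    dlogits = p * (dp - (dp * p).sum(-1, keepdims=True))
    dq = dlogits @ k * scale
    dk = dlogits.T @ q * scale
    dbias = dlogits if bias is not None else None
    return dq, dk, dv, dbias
```

Because the bias is added after scaling, its gradient is simply `dlogits`, which is why bias support and the backend I/O changes land together here.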

@zlsh80826 force-pushed the rewang/prepare-cross-flash-attn branch 2 times, most recently from 019696b to aad1127 on November 20, 2023 09:53
@zlsh80826 (Collaborator, Author) commented:

/te-ci jax

@zlsh80826 force-pushed the rewang/prepare-cross-flash-attn branch from 5c2fd09 to eb1fde4 on November 21, 2023 10:05
@zlsh80826 (Collaborator, Author) commented:

/te-ci jax

@zlsh80826 marked this pull request as ready for review November 22, 2023 10:00
@zlsh80826 requested a review from cyanguwa November 22, 2023 10:03
@zlsh80826 force-pushed the rewang/prepare-cross-flash-attn branch from eb1fde4 to 7184ce9 on November 27, 2023 10:09
@zlsh80826 (Collaborator, Author) commented:

/te-ci jax

@zlsh80826 (Collaborator, Author) commented:

/te-ci jax

@zlsh80826 requested a review from cyanguwa November 28, 2023 14:47
@cyanguwa (Collaborator) commented:

@zlsh80826 I'm seeing some failures with post scale bias and the arbitrary backend in my CI for PR 497 (pipeline 75568647). I think your CI passed because those tests are skipped: you don't have POST_SCALE_BIAS in your nvte_get_fused_attn_backend for flag_arb, which makes sense. But could you also please test your changes on top of mine to see if the bias-related tests still pass? I tried the same configs on the PyTorch side and they seem to be fine. Thanks!
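The skip behavior described in this comment follows a common pattern: backend selection returns "no backend" when the requested bias type is not advertised by the arbitrary-seqlen path, and the test harness skips configs with no backend. The sketch below is a hypothetical simplification of that dispatch logic, not the real `nvte_get_fused_attn_backend` signature; all names and thresholds are illustrative.

```python
from enum import Enum

class BiasType(Enum):
    NO_BIAS = 0
    POST_SCALE_BIAS = 1

class Backend(Enum):
    NO_BACKEND = 0          # caller skips the test config
    MAX_512_SEQLEN = 1
    ARBITRARY_SEQLEN = 2

def select_fused_attn_backend(bias_type, max_seqlen, arb_supports_bias):
    """Hypothetical backend dispatch: the arbitrary-seqlen (flash-style)
    path is only chosen if it supports the requested bias type."""
    if max_seqlen <= 512:
        return Backend.MAX_512_SEQLEN
    if bias_type == BiasType.NO_BIAS or arb_supports_bias:
        return Backend.ARBITRARY_SEQLEN
    # POST_SCALE_BIAS not listed for the arbitrary backend:
    # no backend is selected, so the corresponding tests are skipped.
    return Backend.NO_BACKEND
```

Under this model, adding POST_SCALE_BIAS to the arbitrary backend's supported set is exactly what un-skips the bias tests, which is why the two PRs needed to be tested together.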

@zlsh80826 force-pushed the rewang/prepare-cross-flash-attn branch from 4f7b8ce to b3ceefd on November 30, 2023 09:51
@zlsh80826 (Collaborator, Author) commented:

/te-ci jax

@zlsh80826 force-pushed the rewang/prepare-cross-flash-attn branch from b3ceefd to 1dd88d0 on November 30, 2023 14:58
@zlsh80826 force-pushed the rewang/prepare-cross-flash-attn branch from 1dd88d0 to 81a1565 on December 1, 2023 05:20
@zlsh80826 (Collaborator, Author) commented:

/te-ci jax

@cyanguwa (Collaborator) approved these changes:

LGTM

@cyanguwa merged commit 4d444db into NVIDIA:main on Dec 1, 2023
9 checks passed
2 participants