[JAX][Common] Support GQA #578

Merged 14 commits on Jan 16, 2024

@zlsh80826 (Collaborator) commented Dec 26, 2023

  • Support GQA/MQA (num_gqa_groups) in both the fused and unfused attention implementations.
  • Fix the kv_stride used by flash attention.
  • Refactor the fused attention tests and add GQA tests.
  • Compute seqlen before calling the primitive for better performance (avoids recomputing it in the backward pass).
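
The GQA/MQA support listed above can be illustrated with a minimal unfused sketch in plain JAX. This is not the Transformer Engine implementation: the function name gqa_attention and the [batch, seqlen, heads, head_dim] layout are assumptions made for illustration, and only the num_gqa_groups parameter name comes from this PR. Query heads are split into num_gqa_groups groups, and every head in a group attends with the same key/value head.

```python
import jax
import jax.numpy as jnp


def gqa_attention(q, k, v, num_gqa_groups):
    """Unfused grouped-query attention (illustrative sketch, no masking or dropout).

    Assumed layouts (not taken from the PR):
      q:    [batch, q_seqlen,  num_heads,      head_dim]
      k, v: [batch, kv_seqlen, num_gqa_groups, head_dim]
    """
    batch, q_seqlen, num_heads, head_dim = q.shape
    assert num_heads % num_gqa_groups == 0, "num_heads must be divisible by num_gqa_groups"
    assert k.shape[2] == num_gqa_groups and v.shape[2] == num_gqa_groups
    heads_per_group = num_heads // num_gqa_groups

    # Split the query head axis into [group, head-within-group] so that
    # query head i is paired with key/value head i // heads_per_group.
    q = q.reshape(batch, q_seqlen, num_gqa_groups, heads_per_group, head_dim)

    # Scaled dot-product scores: [batch, group, head-within-group, q_seqlen, kv_seqlen].
    scale = head_dim ** -0.5
    logits = jnp.einsum("bqgnd,bkgd->bgnqk", q, k) * scale
    probs = jax.nn.softmax(logits, axis=-1)

    # Weighted sum of values, then merge the two head axes back together.
    out = jnp.einsum("bgnqk,bkgd->bqgnd", probs, v)
    return out.reshape(batch, q_seqlen, num_heads, head_dim)
```

With num_gqa_groups=1 this reduces to MQA, and with num_gqa_groups equal to the number of query heads it is ordinary multi-head attention. The einsum formulation avoids materializing repeated copies of k and v.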

@zlsh80826 (Collaborator, Author) commented

/te-ci

@zlsh80826 force-pushed the rewang/gqa-clean branch 2 times, most recently from 0f641e6 to 000548c, December 27, 2023 04:24

@zlsh80826 (Collaborator, Author) commented

/te-ci jax

@zlsh80826 marked this pull request as ready for review December 27, 2023 04:50

@cyanguwa (Collaborator) left a comment

Please fix CI. Looks good to me. Thanks!

transformer_engine/jax/flax/transformer.py (outdated, resolved)
tests/jax/utils.py (resolved)

@zlsh80826 (Collaborator, Author) commented

/te-ci

@zlsh80826 requested a review from cyanguwa January 15, 2024 16:38

@zlsh80826 (Collaborator, Author) commented

@cyanguwa @denera @mingxu1067, all unit tests pass now. Could you please review again? Thanks.

@mingxu1067 (Collaborator) left a comment

LGTM

@cyanguwa (Collaborator) left a comment

LGTM

@cyanguwa merged commit 8f6c524 into NVIDIA:main Jan 16, 2024
28 checks passed
Wong4j pushed a commit to Wong4j/TransformerEngine that referenced this pull request Jan 22, 2024
* Support num_gqa_groups arguments
* Add GQA support on the JAX bridge code
* Fix the kv stride of the arbitrary backend
* Complete rewrite fused attention tests and add GQA coverage
* Support unfused GQA
* Calculate seqlen before the primitive for the better perf
* Add GQA layer tests
* Apply code style checks for te_jax
* Apply code style checks for tests
* Add num_gqa_groups doc
* Refine the qkv_type
* Correct the variable naming
* Handle Max512 CAUSAL
* Add WAR for the latest jax image

---------

Signed-off-by: Reese Wang <[email protected]>