-
Notifications
You must be signed in to change notification settings - Fork 345
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[JAX][Common] Support GQA #578
Conversation
/te-ci |
0f641e6
to
000548c
Compare
/te-ci jax |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please fix CI. Looks good to me. Thanks!
transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu
Outdated
Show resolved
Hide resolved
000548c
to
5372c5c
Compare
/te-ci |
26fa707
to
13b4cc0
Compare
/te-ci |
1 similar comment
/te-ci |
b538b4a
to
13b4cc0
Compare
/te-ci |
Signed-off-by: Reese Wang <[email protected]>
Signed-off-by: Reese Wang <[email protected]>
Signed-off-by: Reese Wang <[email protected]>
Signed-off-by: Reese Wang <[email protected]>
Signed-off-by: Reese Wang <[email protected]>
Signed-off-by: Reese Wang <[email protected]>
Signed-off-by: Reese Wang <[email protected]>
Signed-off-by: Reese Wang <[email protected]>
Signed-off-by: Reese Wang <[email protected]>
Signed-off-by: Reese Wang <[email protected]>
Signed-off-by: Reese Wang <[email protected]>
Signed-off-by: Reese Wang <[email protected]>
Signed-off-by: Reese Wang <[email protected]>
acc308a
to
bbe7066
Compare
/te-ci |
Signed-off-by: Reese Wang <[email protected]>
/te-ci |
@cyanguwa @denera @mingxu1067, all unit tests passed. Could you help review again? Thanks |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
* Support num_gqa_groups arguments Signed-off-by: Reese Wang <[email protected]> * Add GQA support on the JAX bridge code Signed-off-by: Reese Wang <[email protected]> * Fix the kv stride of the arbitrary backend Signed-off-by: Reese Wang <[email protected]> * Complete rewrite fused attention tests and add GQA coverage Signed-off-by: Reese Wang <[email protected]> * Support unfused GQA Signed-off-by: Reese Wang <[email protected]> * Calculate seqlen before the primitive for the better perf Signed-off-by: Reese Wang <[email protected]> * Add GQA layer tests Signed-off-by: Reese Wang <[email protected]> * Apply code style checks for te_jax Signed-off-by: Reese Wang <[email protected]> * Apply code style checks for tests Signed-off-by: Reese Wang <[email protected]> * Add num_gqa_groups doc Signed-off-by: Reese Wang <[email protected]> * Refine the qkv_type Signed-off-by: Reese Wang <[email protected]> * Correct the variable naming Signed-off-by: Reese Wang <[email protected]> * Handle Max512 CAUSAL Signed-off-by: Reese Wang <[email protected]> * Add WAR for the latest jax image Signed-off-by: Reese Wang <[email protected]> --------- Signed-off-by: Reese Wang <[email protected]>
num_gqa_groups
) for both fused attention and unfused attention implementation.kv_stride
of the flash attention