[C/PyTorch/Jax] Add support for more bias shapes (#677)
* added support for arbitrary bias shapes for fused_attn
  Signed-off-by: Alp Dener <[email protected]>
* Fix linting
  Signed-off-by: Alp Dener <[email protected]>
* Add b1ss/bhss/11ss bias shapes when not requiring dBias
  Signed-off-by: Charlene Yang <[email protected]>
* fix lint
  Signed-off-by: Charlene Yang <[email protected]>
* add bias_b/h to plan cache
  Signed-off-by: Charlene Yang <[email protected]>
* fixed compile errors after PR653 merge
  Signed-off-by: Alp Dener <[email protected]>
* updated JAX unittests for new bias shapes
  Signed-off-by: Alp Dener <[email protected]>
* fixed mismatched mask type checking
  Signed-off-by: Alp Dener <[email protected]>
* corrected skip condition
  Signed-off-by: Alp Dener <[email protected]>
* fix selection logic for A100s
  Signed-off-by: Charlene Yang <[email protected]>
* corrected skip checks for bias shapes
  Signed-off-by: Alp Dener <[email protected]>
* resolved test issues but neginf with float16 is still problematic with JAX
  Signed-off-by: Alp Dener <[email protected]>
* new bias shapes passing TE JAX CI for seqlen <= 512, seq_q == seq_kv and h_q == h_kv conditions
  Signed-off-by: Alp Dener <[email protected]>
* TE/JAX fused attn tests for new bias shapes passing with neg_inf=-2**27 for Bfloat16 and -2**15 for Float16
  Signed-off-by: Alp Dener <[email protected]>
* code style fixes and test parameter ID cleanup
  Signed-off-by: Alp Dener <[email protected]>
* fixed incorrect skip condition for backward fused attn test
  Signed-off-by: Alp Dener <[email protected]>

---------

Signed-off-by: Alp Dener <[email protected]>
Signed-off-by: Charlene Yang <[email protected]>
Co-authored-by: Alp Dener <[email protected]>
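The shape tags in the message above (b1ss, bhss, 11ss) can be read as a four-dimensional bias whose batch and head dimensions are either full-sized or 1, so that the bias broadcasts against attention scores of shape `[b, h, s_q, s_kv]`. The following is only an illustrative sketch of that naming convention in plain Python. It is not code from this commit, and `bias_layout` is a hypothetical helper, not part of the Transformer Engine API.

```python
def bias_layout(bias_shape, b, h, s_q, s_kv):
    """Return a tag such as '11ss', 'b1ss' or 'bhss' for a bias shape.

    Hypothetical helper for illustration only. It assumes the bias must
    broadcast against attention scores of shape (b, h, s_q, s_kv): the last
    two dimensions match exactly, and each of the first two dimensions is
    either 1 (broadcast) or the full batch / head count.
    """
    if len(bias_shape) != 4:
        raise ValueError(f"expected a 4-D bias shape, got {bias_shape}")

    bias_b, bias_h, bias_sq, bias_skv = bias_shape

    # The sequence dimensions are never broadcast in this sketch.
    if (bias_sq, bias_skv) != (s_q, s_kv):
        raise ValueError("bias sequence dims must equal (s_q, s_kv)")

    # Batch and head dimensions may be full-sized or broadcast from 1.
    if bias_b not in (1, b):
        raise ValueError(f"bias batch dim must be 1 or {b}, got {bias_b}")
    if bias_h not in (1, h):
        raise ValueError(f"bias head dim must be 1 or {h}, got {bias_h}")

    batch_tag = "1" if bias_b == 1 else "b"
    head_tag = "1" if bias_h == 1 else "h"
    return f"{batch_tag}{head_tag}ss"


if __name__ == "__main__":
    b, h, s = 2, 16, 128  # arbitrary example sizes
    for shape in [(1, 1, s, s), (b, 1, s, s), (b, h, s, s)]:
        print(shape, "->", bias_layout(shape, b, h, s, s))
```

The sketch only names a shape. Whether a given shape is accepted by a fused attention backend, and whether dBias can be computed for it, depends on conditions the commit message alludes to (for example seqlen, `seq_q == seq_kv`, `h_q == h_kv`) but does not spell out here.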
Showing 9 changed files with 508 additions and 266 deletions.