-
Notifications
You must be signed in to change notification settings - Fork 337
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[JAX] Expose sliding window attn to TE-JAX API (#1205)
* Expose JAX sliding window attn API Signed-off-by: Hua Huang <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * No SWA in context parallel; fix RNG seed in test Signed-off-by: Hua Huang <[email protected]> * Handle SAW API discrepancy in cuDNN and Python Signed-off-by: Hua Huang <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add SAW API for flax, all tests passed Will update tests/jax/test_praxis_layers.py next Signed-off-by: Hua Huang <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update test_praxis_layers.py for SWA, test passed Signed-off-by: Hua Huang <[email protected]> * Use tuple window_size; update for PR #1212 Signed-off-by: Hua Huang <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add and adjust some pytest.skip Signed-off-by: Hua Huang <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Revised following Reese Wang's comments Still need further debugging: FAILED test_fused_attn.py::TestFusedAttn::test_backward[NO_SWA-DROP_0.0-4-128-256-16-16-64-BF16-CROSS-KV_PACKED-NO_MASK-NO_BIAS] - AssertionError: FAILED test_fused_attn.py::TestFusedAttn::test_backward[NO_SWA-DROP_0.0-4-128-256-16-16-64-BF16-CROSS-KV_PACKED-NO_MASK-POST_SCALE_BIAS-1HSS] - AssertionError: FAILED test_fused_attn.py::TestFusedAttn::test_backward[NO_SWA-DROP_0.0-4-128-256-16-16-64-BF16-CROSS-SEPARATE-NO_MASK-NO_BIAS] - AssertionError: FAILED test_fused_attn.py::TestFusedAttn::test_backward[NO_SWA-DROP_0.0-4-128-256-16-16-64-BF16-CROSS-SEPARATE-NO_MASK-POST_SCALE_BIAS-1HSS] - AssertionError: These errors does not exist in the previous commit Signed-off-by: Hua Huang <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix no-SWA test case errors in previous commit Signed-off-by: Hua Huang <[email protected]> * Add Padding mask w/ sliding windows sanity tests Signed-off-by: Reese Wang <[email protected]> * Use float32 for the reference code softmax calculation Signed-off-by: Reese Wang <[email protected]> --------- Signed-off-by: Hua Huang <[email protected]> Signed-off-by: Reese Wang <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Reese Wang <[email protected]>
- Loading branch information
1 parent
5b6546c
commit 85e60e6
Showing
11 changed files
with
390 additions
and
69 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.