Skip to content

Commit

Permalink
[JAX] Fix JAX distributed unit tests (#521)
Browse files Browse the repository at this point in the history
* Remove assertion for NO_MASK

Signed-off-by: Reese Wang <[email protected]>

* Fix JAX distributed unit tests name

Signed-off-by: Reese Wang <[email protected]>

---------

Signed-off-by: Reese Wang <[email protected]>
  • Loading branch information
zlsh80826 authored Nov 20, 2023
1 parent 6159af4 commit ea43b18
Show file tree
Hide file tree
Showing 3 changed files with 2 additions and 5 deletions.
2 changes: 1 addition & 1 deletion qa/L0_jax_distributed_unittest/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,5 @@
set -xe

: ${TE_PATH:=/opt/transformerengine}
pytest -Wignore -v $TE_PATH/tests/jax/test_distributed_custom_ops.py
pytest -Wignore -v $TE_PATH/tests/jax/test_distributed_*

2 changes: 1 addition & 1 deletion qa/L0_jax_unittest/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
set -xe

: ${TE_PATH:=/opt/transformerengine}
pytest -Wignore -v $TE_PATH/tests/jax
pytest -Wignore -v $TE_PATH/tests/jax -k 'not distributed'

pip install -r $TE_PATH/examples/jax/mnist/requirements.txt
pip install -r $TE_PATH/examples/jax/encoder/requirements.txt
Expand Down
3 changes: 0 additions & 3 deletions transformer_engine/jax/fused_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,6 @@ def self_fused_attn(qkv: jnp.ndarray, bias: jnp.ndarray, mask: jnp.ndarray, seed
"""
Self fused attention wrapper
"""
assert attn_mask_type is not AttnMaskType.NO_MASK, \
"Currently not support AttnMaskType.NO_MASK."

output = _self_fused_attn(qkv,
bias,
mask,
Expand Down

0 comments on commit ea43b18

Please sign in to comment.