Skip to content

Commit

Permalink
fix distributed jax test
Browse files Browse the repository at this point in the history
Signed-off-by: Charlene Yang <[email protected]>
  • Loading branch information
cyanguwa committed Dec 4, 2023
1 parent 225b560 commit 7b8ff67
Showing 1 changed file with 3 additions and 0 deletions.
3 changes: 3 additions & 0 deletions tests/jax/test_distributed_fused_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,9 @@ def generate_inputs(self, shape, mesh_resource, with_bias, attn_mask_type, dtype
@pytest.mark.parametrize('dtype', DTYPES)
def test_self_attn(self, device_count, mesh_shape, mesh_axes, mesh_resource, data_shape,
attn_bias_type, attn_mask_type, dtype):
# TODO (cyang): remove this when cudnn fe v1 has included it for dbias cases
os.environ["CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT"] = "-1"

dropout_prob = 0.0
is_training = True
scaling_factor = 1.0
Expand Down

0 comments on commit 7b8ff67

Please sign in to comment.