From 7b8ff67d3330003644175b6bba2bbb514e9ee1db Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Mon, 4 Dec 2023 17:54:44 +0000 Subject: [PATCH] fix distributed jax test Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- tests/jax/test_distributed_fused_attn.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/jax/test_distributed_fused_attn.py b/tests/jax/test_distributed_fused_attn.py index 70e2765160..6c4a889f87 100644 --- a/tests/jax/test_distributed_fused_attn.py +++ b/tests/jax/test_distributed_fused_attn.py @@ -73,6 +73,9 @@ def generate_inputs(self, shape, mesh_resource, with_bias, attn_mask_type, dtype @pytest.mark.parametrize('dtype', DTYPES) def test_self_attn(self, device_count, mesh_shape, mesh_axes, mesh_resource, data_shape, attn_bias_type, attn_mask_type, dtype): + # TODO (cyang): remove this when cudnn fe v1 has included it for dbias cases + os.environ["CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT"] = "-1" + dropout_prob = 0.0 is_training = True scaling_factor = 1.0