diff --git a/tests/jax/test_distributed_fused_attn.py b/tests/jax/test_distributed_fused_attn.py
index d34afb340c..4ef56b1b10 100644
--- a/tests/jax/test_distributed_fused_attn.py
+++ b/tests/jax/test_distributed_fused_attn.py
@@ -423,21 +423,33 @@ def test_contex_parallel_self_attn(
         num_kv_heads = num_head // kv_groups
         scaling_factor = 1.0 / np.sqrt(num_head)
 
-        if not is_fused_attn_kernel_available(
-            dtype,
-            dtype,
-            qkv_layout,
-            attn_bias_type,
-            attn_mask_type,
-            dropout_prob,
-            num_head,
-            num_kv_heads,
-            seqlen,
-            seqlen,
-            hidden,
-            None,  # no window
-        ):
-            pytest.skip(f"No FusedAttn backend found")
+        def check_has_backend_for_mask(mask_type):
+            """Return whether a fused-attention backend is available for ``mask_type``."""
+            return is_fused_attn_kernel_available(
+                dtype,
+                dtype,
+                qkv_layout,
+                attn_bias_type,
+                mask_type,  # must be the argument, not the enclosing attn_mask_type
+                dropout_prob,
+                num_head,
+                num_kv_heads,
+                seqlen,
+                seqlen,
+                hidden,
+                None,  # no SWA for CP
+            )
+
+        # For causal masking we depend on having bottom right support also.
+        # The API does not check this and instead we rely on lower level checks to raise
+        # an exception if the step backend is not supported. This was a deliberate API
+        # decision to keep the CP size or flag out of the function.
+        has_backend = check_has_backend_for_mask(attn_mask_type)
+        if cp_size > 1 and attn_mask_type == AttnMaskType.CAUSAL_MASK:
+            has_backend &= check_has_backend_for_mask(AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK)
+
+        if not has_backend:
+            pytest.skip(f"No FusedAttn backend found {cp_size=} {attn_mask_type=}.")
 
         if dp_size > 1 and batch % dp_size != 0:
             pytest.skip(f"Skipping {batch=} not a multiple of {dp_size=}")