Skip to content

Commit b0c5c06

Browse files
Clean up checks in unit test.
Signed-off-by: Michael Goldfarb <[email protected]>
1 parent c4d3749 commit b0c5c06

File tree

1 file changed

+11
-4
lines changed

1 file changed

+11
-4
lines changed

tests/jax/test_distributed_fused_attn.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -423,7 +423,8 @@ def test_contex_parallel_self_attn(
423423
num_kv_heads = num_head // kv_groups
424424
scaling_factor = 1.0 / np.sqrt(num_head)
425425

426-
if not is_fused_attn_kernel_available(
426+
def check_has_backend_for_mask(mask_type):
427+
return is_fused_attn_kernel_available(
427428
dtype,
428429
dtype,
429430
qkv_layout,
@@ -435,9 +436,15 @@ def test_contex_parallel_self_attn(
435436
seqlen,
436437
seqlen,
437438
hidden,
438-
None, # no window
439-
):
440-
pytest.skip(f"No FusedAttn backend found")
439+
None) # no SWA for CP
440+
441+
# For causal masking we depend on having bottom right support also.
442+
has_backend = check_has_backend_for_mask(attn_mask_type)
443+
if mask == AttnMaskType.CAUSAL_MASK_MASK:
444+
has_backend &= check_has_backend_for_mask(AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK)
445+
446+
if not has_backend
447+
pytest.skip(f"No FusedAttn backend found {cp_size=} {attn_mask_type=}.")
441448

442449
if dp_size > 1 and batch % dp_size != 0:
443450
pytest.skip(f"Skipping {batch=} not a multiple of {dp_size=}")

0 commit comments

Comments
 (0)