Skip to content

Commit

Permalink
Clean up checks in unit test.
Browse files Browse the repository at this point in the history
Signed-off-by: Michael Goldfarb <[email protected]>
  • Loading branch information
mgoldfarb-nvidia committed Oct 30, 2024
1 parent 535af66 commit 0b302d4
Showing 1 changed file with 26 additions and 15 deletions.
41 changes: 26 additions & 15 deletions tests/jax/test_distributed_fused_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,21 +423,32 @@ def test_contex_parallel_self_attn(
num_kv_heads = num_head // kv_groups
scaling_factor = 1.0 / np.sqrt(num_head)

if not is_fused_attn_kernel_available(
dtype,
dtype,
qkv_layout,
attn_bias_type,
attn_mask_type,
dropout_prob,
num_head,
num_kv_heads,
seqlen,
seqlen,
hidden,
None, # no window
):
pytest.skip(f"No FusedAttn backend found")
def check_has_backend_for_mask(mask_type):
return is_fused_attn_kernel_available(
dtype,
dtype,
qkv_layout,
attn_bias_type,
attn_mask_type,
dropout_prob,
num_head,
num_kv_heads,
seqlen,
seqlen,
hidden,
None,
) # no SWA for CP

# For causal masking we depend on having bottom right support also.
# The API does not check this and instead we rely on lower level checks to raise
# and exception if the step backend is not supported. This was a deliberate API
# decision to keep the CP size or flag out of the function.
has_backend = check_has_backend_for_mask(attn_mask_type)
if cp_size > 1 and attn_mask_type == AttnMaskType.CAUSAL_MASK:
has_backend &= check_has_backend_for_mask(AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK)

if not has_backend:
pytest.skip(f"No FusedAttn backend found {cp_size=} {attn_mask_type=}.")

if dp_size > 1 and batch % dp_size != 0:
pytest.skip(f"Skipping {batch=} not a multiple of {dp_size=}")
Expand Down

0 comments on commit 0b302d4

Please sign in to comment.