File tree Expand file tree Collapse file tree 1 file changed +11
-4
lines changed Expand file tree Collapse file tree 1 file changed +11
-4
lines changed Original file line number Diff line number Diff line change @@ -423,7 +423,8 @@ def test_contex_parallel_self_attn(
423
423
num_kv_heads = num_head // kv_groups
424
424
scaling_factor = 1.0 / np .sqrt (num_head )
425
425
426
- if not is_fused_attn_kernel_available (
426
+ def check_has_backend_for_mask (mask_type ):
427
+ return is_fused_attn_kernel_available (
427
428
dtype ,
428
429
dtype ,
429
430
qkv_layout ,
@@ -435,9 +436,15 @@ def test_contex_parallel_self_attn(
435
436
seqlen ,
436
437
seqlen ,
437
438
hidden ,
438
- None , # no window
439
- ):
440
- pytest .skip (f"No FusedAttn backend found" )
439
+ None ) # no SWA for CP
440
+
441
+ # For causal masking we depend on having bottom right support also.
442
+ has_backend = check_has_backend_for_mask (attn_mask_type )
443
+ if mask == AttnMaskType .CAUSAL_MASK_MASK :
444
+ has_backend &= check_has_backend_for_mask (AttnMaskType .CAUSAL_BOTTOM_RIGHT_MASK )
445
+
446
+ if not has_backend
447
+ pytest .skip (f"No FusedAttn backend found { cp_size = } { attn_mask_type = } ." )
441
448
442
449
if dp_size > 1 and batch % dp_size != 0 :
443
450
pytest .skip (f"Skipping { batch = } not a multiple of { dp_size = } " )
You can’t perform that action at this time.
0 commit comments