@@ -237,11 +237,12 @@ def test_dot_product_attention(
     tols = dict(atol=1.5e-2, rtol=1.5e-2)
     config = model_configs[model]
     is_mla = config.head_dim_qk != config.head_dim_v
+    is_mqa_gqa = config.num_heads != config.num_gqa_groups
     if qkv_layout is None:
         if config.attn_type == "self":
-            qkv_layout = "sb3hd" if not is_mla else "sbhd_sbhd_sbhd"
+            qkv_layout = "sb3hd" if not is_mla and not is_mqa_gqa else "sbhd_sbhd_sbhd"
         else:
-            qkv_layout = "bshd_bs2hd" if not is_mla else "bshd_bshd_bshd"
+            qkv_layout = "bshd_bs2hd" if not is_mla and not is_mqa_gqa else "bshd_bshd_bshd"
     if "3" in qkv_layout and config.attn_type == "cross":
         pytest.skip("No need to test this layout for cross attention")
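Note on the new is_mqa_gqa guard (an editorial illustration, not part of the diff): packed layouts such as "sb3hd" interleave Q, K, and V in a single tensor, which only works when all three projections share the same number of heads and the same head dimension. Under MQA/GQA the K/V tensors carry fewer heads, so the test must fall back to the separate "sbhd_sbhd_sbhd" layout. A minimal shape sketch, assuming hypothetical sizes:

    import torch

    s, b, d = 128, 2, 64   # sequence length, batch size, head dimension (hypothetical)
    h, hg = 16, 4          # 16 query heads, 4 KV groups -> GQA

    # Packed "sb3hd": one tensor holds Q, K, V; requires h == hg and equal head dims.
    qkv_packed = torch.empty(s, b, 3, h, d)

    # Separate "sbhd_sbhd_sbhd": K/V may have fewer heads, so MQA/GQA fits here.
    q = torch.empty(s, b, h, d)
    k = torch.empty(s, b, hg, d)
    v = torch.empty(s, b, hg, d)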
@@ -258,7 +259,8 @@ def test_dot_product_attention(
         pad_between_seqs=pad_between_seqs,
     )
     flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
-    unfused_attn_supported = False
+    if swa:
+        unfused_attn_supported = False
     print(flash_attn_supported, fused_attn_supported, unfused_attn_supported)
     # FlashAttention does not support pad_between_seqs, but _run_dot_product_attention
     # manually pads and unpads the input and output of FlashAttention for testing purposes
@@ -533,18 +535,18 @@ def test_dpa_bias_shapes(dtype, model_configs, model):
 
 model_configs_swa = {
     # test: b, h, hg, d, sq, skv, p, mask, bias
-    # "swa_1_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "no_mask", "no_bias"),
-    # "swa_1_1": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "no_mask", "no_bias"),
-    # "swa_1_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "no_bias"),
-    # "swa_1_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "no_bias"),
-    # "swa_2_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "causal", "no_bias"),
-    # "swa_2_1": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "causal", "no_bias"),
-    # "swa_2_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "causal", "no_bias"),
-    # "swa_2_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal", "no_bias"),
-    # "swa_3_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "causal_bottom_right", "no_bias"),
-    # "swa_3_1": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "causal_bottom_right", "no_bias"),
-    # "swa_3_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "causal_bottom_right", "no_bias"),
-    # "swa_3_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal_bottom_right", "no_bias"),
+    "swa_1_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "no_mask", "no_bias"),
+    "swa_1_1": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "no_mask", "no_bias"),
+    "swa_1_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "no_bias"),
+    "swa_1_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "no_bias"),
+    "swa_2_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "causal", "no_bias"),
+    "swa_2_1": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "causal", "no_bias"),
+    "swa_2_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "causal", "no_bias"),
+    "swa_2_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal", "no_bias"),
+    "swa_3_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "causal_bottom_right", "no_bias"),
+    "swa_3_1": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "causal_bottom_right", "no_bias"),
+    "swa_3_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "causal_bottom_right", "no_bias"),
+    "swa_3_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal_bottom_right", "no_bias"),
     "swa_4_0": ModelConfig(4, 24, 4, 128, 2048, 2048, 0.0, "padding_causal", "no_bias"),
     "swa_4_1": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "no_bias"),
     "swa_4_2": ModelConfig(
@@ -562,9 +564,7 @@ def test_dpa_bias_shapes(dtype, model_configs, model):
 @pytest.mark.parametrize("model", model_configs_swa.keys())
 def test_dpa_sliding_window(dtype, model_configs, model):
     """Test DotProductAttention module with sliding window attention"""
-    test_dot_product_attention(
-        dtype, model_configs, model, False, True, "bshd_bshd_bshd", True, False
-    )
+    test_dot_product_attention(dtype, model_configs, model, False, True, None, True, False)
 
 
 model_configs_alibi_slopes = {
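Passing None here hands the layout choice to the defaulting logic from the first hunk instead of pinning "bshd_bshd_bshd". A sketch of how the layout would resolve for a GQA config such as "swa_4_0" (h=24, hg=4), assuming its attn_type is "self" and equal Q/V head dims; the concrete values are illustrative:

    is_mla = False       # head_dim_qk == head_dim_v
    is_mqa_gqa = True    # 24 query heads vs. 4 GQA groups
    qkv_layout = "sb3hd" if not is_mla and not is_mqa_gqa else "sbhd_sbhd_sbhd"
    # -> "sbhd_sbhd_sbhd", the separate-tensor layout required for GQA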
@@ -631,18 +631,18 @@ def test_dpa_qkv_layout(dtype, model_configs, model, qkv_layout):
 qkv_layouts_thd = ["t3hd", "th3d", "thd_t2hd", "thd_th2d", "thd_thd_thd"]
 model_configs_layout_thd = {
     # test: b, h, hg, d, sq, skv, p, mask, bias
-    # "layout_0_1": ModelConfig(3, 16, 4, 64, 128, 128, 0.0, "padding", "no_bias"),
-    # "layout_0_2": ModelConfig(8, 16, 4, 64, 128, 128, 0.0, "padding", "no_bias"),
-    # "layout_0_3": ModelConfig(1, 16, 16, 64, 128, 128, 0.0, "padding_causal", "no_bias"),
-    # "layout_0_4": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "padding_causal", "no_bias"),
-    # "layout_1_1": ModelConfig(1, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"),
-    # "layout_1_2": ModelConfig(8, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"),
-    # "layout_1_3": ModelConfig(1, 16, 1, 64, 2048, 2048, 0.0, "padding_causal", "no_bias"),
-    # "layout_1_4": ModelConfig(8, 16, 1, 64, 2048, 2048, 0.0, "padding_causal", "no_bias"),
-    # "layout_2_1": ModelConfig(1, 16, 16, 128, 128, 128, 0.0, "padding", "no_bias"),
-    # "layout_2_2": ModelConfig(1, 16, 16, 64, 128, 256, 0.0, "padding", "no_bias"),
-    # "layout_2_3": ModelConfig(1, 16, 16, 128, 2048, 2048, 0.0, "padding_causal", "no_bias"),
-    # "layout_2_4": ModelConfig(8, 16, 16, 64, 2048, 4096, 0.0, "padding_causal", "no_bias"),
+    "layout_0_1": ModelConfig(3, 16, 4, 64, 128, 128, 0.0, "padding", "no_bias"),
+    "layout_0_2": ModelConfig(8, 16, 4, 64, 128, 128, 0.0, "padding", "no_bias"),
+    "layout_0_3": ModelConfig(1, 16, 16, 64, 128, 128, 0.0, "padding_causal", "no_bias"),
+    "layout_0_4": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "padding_causal", "no_bias"),
+    "layout_1_1": ModelConfig(1, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"),
+    "layout_1_2": ModelConfig(8, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"),
+    "layout_1_3": ModelConfig(1, 16, 1, 64, 2048, 2048, 0.0, "padding_causal", "no_bias"),
+    "layout_1_4": ModelConfig(8, 16, 1, 64, 2048, 2048, 0.0, "padding_causal", "no_bias"),
+    "layout_2_1": ModelConfig(1, 16, 16, 128, 128, 128, 0.0, "padding", "no_bias"),
+    "layout_2_2": ModelConfig(1, 16, 16, 64, 128, 256, 0.0, "padding", "no_bias"),
+    "layout_2_3": ModelConfig(1, 16, 16, 128, 2048, 2048, 0.0, "padding_causal", "no_bias"),
+    "layout_2_4": ModelConfig(8, 16, 16, 64, 2048, 4096, 0.0, "padding_causal", "no_bias"),
     "layout_3_0": ModelConfig(
         2,
         16,
@@ -680,10 +680,11 @@ def test_dpa_qkv_layout_thd(dtype, model_configs, model, qkv_layout):
     config = model_configs[model]
     if config.num_heads != config.num_gqa_groups and "3" in qkv_layout:
         pytest.skip("qkv_layout not applicable for MQA/GQA")
-    # pad_between_seqs = True
-    # test_dot_product_attention(
-    #     dtype, model_configs, model, False, True, qkv_layout, False, pad_between_seqs
-    # )
+    if config.window_size[0] == -1 and config.window_size[1] in [-1, 0]:
+        pad_between_seqs = True
+        test_dot_product_attention(
+            dtype, model_configs, model, False, True, qkv_layout, False, pad_between_seqs
+        )
     if get_cudnn_version() >= (9, 3, 0):
         # cuDNN 9.3.0+ is required to run pad_between_seqs = False/True in the same run
         pad_between_seqs = False
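For context (a hedged reading of the new guard, not part of the diff): a window_size of (-1, -1) conventionally means full attention and (-1, 0) a causal mask with no finite window, so the pad_between_seqs=True pass is exercised only when no real sliding window is configured. A small predicate sketch with hypothetical values:

    def has_no_finite_window(window_size) -> bool:
        # (-1, -1): full attention; (-1, 0): causal, still no finite window.
        left, right = window_size
        return left == -1 and right in (-1, 0)

    assert has_no_finite_window((-1, -1))
    assert has_no_finite_window((-1, 0))
    assert not has_no_finite_window((128, 0))  # a real 128-token sliding window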