
Commit 956570f

WIP: fix up swa
Signed-off-by: Charlene Yang <[email protected]>
1 parent 681ffbe commit 956570f

3 files changed: +97, -82 lines changed

tests/pytorch/fused_attn/test_fused_attn.py

Lines changed: 35 additions & 34 deletions
@@ -237,11 +237,12 @@ def test_dot_product_attention(
     tols = dict(atol=1.5e-2, rtol=1.5e-2)
     config = model_configs[model]
     is_mla = config.head_dim_qk != config.head_dim_v
+    is_mqa_gqa = config.num_heads != config.num_gqa_groups
     if qkv_layout is None:
         if config.attn_type == "self":
-            qkv_layout = "sb3hd" if not is_mla else "sbhd_sbhd_sbhd"
+            qkv_layout = "sb3hd" if not is_mla and not is_mqa_gqa else "sbhd_sbhd_sbhd"
         else:
-            qkv_layout = "bshd_bs2hd" if not is_mla else "bshd_bshd_bshd"
+            qkv_layout = "bshd_bs2hd" if not is_mla and not is_mqa_gqa else "bshd_bshd_bshd"
     if "3" in qkv_layout and config.attn_type == "cross":
         pytest.skip("No need to test this layout for cross attention")

@@ -258,7 +259,8 @@ def test_dot_product_attention(
         pad_between_seqs=pad_between_seqs,
     )
     flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
-    unfused_attn_supported = False
+    if swa:
+        unfused_attn_supported = False
     print(flash_attn_supported, fused_attn_supported, unfused_attn_supported)
     # FlashAttention does not support pad_between_seqs, but _run_dot_product_attention
     # mannually pads and unpads the input and output of FlashAttention for testing purposes
@@ -533,18 +535,18 @@ def test_dpa_bias_shapes(dtype, model_configs, model):

 model_configs_swa = {
     # test: b, h, hg, d, sq, skv, p, mask, bias
-    # "swa_1_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "no_mask", "no_bias"),
-    # "swa_1_1": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "no_mask", "no_bias"),
-    # "swa_1_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "no_bias"),
-    # "swa_1_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "no_bias"),
-    # "swa_2_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "causal", "no_bias"),
-    # "swa_2_1": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "causal", "no_bias"),
-    # "swa_2_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "causal", "no_bias"),
-    # "swa_2_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal", "no_bias"),
-    # "swa_3_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "causal_bottom_right", "no_bias"),
-    # "swa_3_1": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "causal_bottom_right", "no_bias"),
-    # "swa_3_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "causal_bottom_right", "no_bias"),
-    # "swa_3_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal_bottom_right", "no_bias"),
+    "swa_1_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "no_mask", "no_bias"),
+    "swa_1_1": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "no_mask", "no_bias"),
+    "swa_1_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "no_bias"),
+    "swa_1_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "no_bias"),
+    "swa_2_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "causal", "no_bias"),
+    "swa_2_1": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "causal", "no_bias"),
+    "swa_2_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "causal", "no_bias"),
+    "swa_2_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal", "no_bias"),
+    "swa_3_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "causal_bottom_right", "no_bias"),
+    "swa_3_1": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "causal_bottom_right", "no_bias"),
+    "swa_3_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "causal_bottom_right", "no_bias"),
+    "swa_3_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal_bottom_right", "no_bias"),
     "swa_4_0": ModelConfig(4, 24, 4, 128, 2048, 2048, 0.0, "padding_causal", "no_bias"),
     "swa_4_1": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "no_bias"),
     "swa_4_2": ModelConfig(
@@ -562,9 +564,7 @@ def test_dpa_bias_shapes(dtype, model_configs, model):
 @pytest.mark.parametrize("model", model_configs_swa.keys())
 def test_dpa_sliding_window(dtype, model_configs, model):
     """Test DotProductAttention module with sliding window attention"""
-    test_dot_product_attention(
-        dtype, model_configs, model, False, True, "bshd_bshd_bshd", True, False
-    )
+    test_dot_product_attention(dtype, model_configs, model, False, True, None, True, False)


 model_configs_alibi_slopes = {
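test_dpa_sliding_window now passes qkv_layout=None, so the default-layout rule above applies per config. As a refresher on what the window means in these tests, here is a hedged sketch (not from the commit) of the (left, right) convention, where -1 denotes an unbounded side:

```python
# Hedged sketch of sliding-window visibility: window_size = (left, right),
# with -1 meaning "unbounded" on that side.
def visible(i: int, j: int, left: int, right: int) -> bool:
    lo = float("-inf") if left == -1 else i - left
    hi = float("inf") if right == -1 else i + right
    return lo <= j <= hi

# A (2, 0) window on a causal diagonal: query 5 attends to keys 3, 4, 5.
assert [j for j in range(8) if visible(5, j, 2, 0)] == [3, 4, 5]
# (-1, 0) degenerates to plain causal attention.
assert [j for j in range(8) if visible(5, j, -1, 0)] == [0, 1, 2, 3, 4, 5]
```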
@@ -631,18 +631,18 @@ def test_dpa_qkv_layout(dtype, model_configs, model, qkv_layout):
 qkv_layouts_thd = ["t3hd", "th3d", "thd_t2hd", "thd_th2d", "thd_thd_thd"]
 model_configs_layout_thd = {
     # test: b, h, hg, d, sq, skv, p, mask, bias
-    # "layout_0_1": ModelConfig(3, 16, 4, 64, 128, 128, 0.0, "padding", "no_bias"),
-    # "layout_0_2": ModelConfig(8, 16, 4, 64, 128, 128, 0.0, "padding", "no_bias"),
-    # "layout_0_3": ModelConfig(1, 16, 16, 64, 128, 128, 0.0, "padding_causal", "no_bias"),
-    # "layout_0_4": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "padding_causal", "no_bias"),
-    # "layout_1_1": ModelConfig(1, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"),
-    # "layout_1_2": ModelConfig(8, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"),
-    # "layout_1_3": ModelConfig(1, 16, 1, 64, 2048, 2048, 0.0, "padding_causal", "no_bias"),
-    # "layout_1_4": ModelConfig(8, 16, 1, 64, 2048, 2048, 0.0, "padding_causal", "no_bias"),
-    # "layout_2_1": ModelConfig(1, 16, 16, 128, 128, 128, 0.0, "padding", "no_bias"),
-    # "layout_2_2": ModelConfig(1, 16, 16, 64, 128, 256, 0.0, "padding", "no_bias"),
-    # "layout_2_3": ModelConfig(1, 16, 16, 128, 2048, 2048, 0.0, "padding_causal", "no_bias"),
-    # "layout_2_4": ModelConfig(8, 16, 16, 64, 2048, 4096, 0.0, "padding_causal", "no_bias"),
+    "layout_0_1": ModelConfig(3, 16, 4, 64, 128, 128, 0.0, "padding", "no_bias"),
+    "layout_0_2": ModelConfig(8, 16, 4, 64, 128, 128, 0.0, "padding", "no_bias"),
+    "layout_0_3": ModelConfig(1, 16, 16, 64, 128, 128, 0.0, "padding_causal", "no_bias"),
+    "layout_0_4": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "padding_causal", "no_bias"),
+    "layout_1_1": ModelConfig(1, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"),
+    "layout_1_2": ModelConfig(8, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"),
+    "layout_1_3": ModelConfig(1, 16, 1, 64, 2048, 2048, 0.0, "padding_causal", "no_bias"),
+    "layout_1_4": ModelConfig(8, 16, 1, 64, 2048, 2048, 0.0, "padding_causal", "no_bias"),
+    "layout_2_1": ModelConfig(1, 16, 16, 128, 128, 128, 0.0, "padding", "no_bias"),
+    "layout_2_2": ModelConfig(1, 16, 16, 64, 128, 256, 0.0, "padding", "no_bias"),
+    "layout_2_3": ModelConfig(1, 16, 16, 128, 2048, 2048, 0.0, "padding_causal", "no_bias"),
+    "layout_2_4": ModelConfig(8, 16, 16, 64, 2048, 4096, 0.0, "padding_causal", "no_bias"),
     "layout_3_0": ModelConfig(
         2,
         16,
@@ -680,10 +680,11 @@ def test_dpa_qkv_layout_thd(dtype, model_configs, model, qkv_layout):
     config = model_configs[model]
     if config.num_heads != config.num_gqa_groups and "3" in qkv_layout:
         pytest.skip("qkv_layout not applicable for MQA/GQA")
-    # pad_between_seqs = True
-    # test_dot_product_attention(
-    #     dtype, model_configs, model, False, True, qkv_layout, False, pad_between_seqs
-    # )
+    if config.window_size[0] == -1 and config.window_size[1] in [-1, 0]:
+        pad_between_seqs = True
+        test_dot_product_attention(
+            dtype, model_configs, model, False, True, qkv_layout, False, pad_between_seqs
+        )
     if get_cudnn_version() >= (9, 3, 0):
         # cuDNN 9.3.0+ is required to run pad_between_seqs = False/True in the same run
         pad_between_seqs = False
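The guard added above runs the pad_between_seqs=True path only when no sliding window is in effect, i.e. window_size of (-1, -1) or (-1, 0). A hedged restatement of that predicate:

```python
# Hedged restatement of the guard above; helper name is illustrative only.
def is_full_attention(window_size: tuple[int, int]) -> bool:
    left, right = window_size
    return left == -1 and right in (-1, 0)

assert is_full_attention((-1, -1)) and is_full_attention((-1, 0))
assert not is_full_attention((128, 0))
```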

transformer_engine/common/fused_attn/fused_attn.cpp

Lines changed: 33 additions & 18 deletions
@@ -152,7 +152,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
        head_dim_qk % 8 == 0 && head_dim_v <= 256 && head_dim_v % 8 == 0)) &&
       // bias type
       ((cudnn_runtime_version < 8906 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS) ||
-       ((cudnn_runtime_version >= 8906) &&
+       (cudnn_runtime_version >= 8906 &&
         (bias_type == NVTE_Bias_Type::NVTE_NO_BIAS ||
          (bias_type == NVTE_Bias_Type::NVTE_ALIBI &&
           attn_mask_type != NVTE_Mask_Type::NVTE_NO_MASK &&
@@ -161,38 +161,51 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
          attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK &&
          sm_arch_ >= 90) ||
         (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS && sm_arch_ >= 90))) ||
-       ((cudnn_runtime_version >= 90000) &&
+       (cudnn_runtime_version >= 90000 &&
         (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS && sm_arch_ >= 80))) &&
       // mask type
+      // pre-8.9.6: causal
       ((cudnn_runtime_version < 8906 && attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) ||
-       ((cudnn_runtime_version >= 8906) &&
+       // 8.9.6: {bshd, sbhd} + {no_mask, causal, padding, padding_causal}
+       (cudnn_runtime_version >= 8906 &&
+        (qkv_format == NVTE_QKV_Format::NVTE_SBHD || qkv_format == NVTE_QKV_Format::NVTE_BSHD) &&
         (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK ||
          attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK ||
          attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK ||
          attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK)) ||
-       ((cudnn_runtime_version >= 90300) &&
+       // 9.1: adds thd + {padding, padding_causal}
+       (cudnn_runtime_version >= 90100 &&
+        qkv_format == NVTE_QKV_Format::NVTE_THD &&
+        (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK ||
+         attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)) ||
+       // 9.3: adds {bshd, sbhd} + causal_bottom_right + self/cross-attn (sq <= skv)
+       (cudnn_runtime_version >= 90300 &&
+        (qkv_format == NVTE_QKV_Format::NVTE_SBHD || qkv_format == NVTE_QKV_Format::NVTE_BSHD) &&
         attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK &&
-        max_seqlen_q % 64 == 0 && max_seqlen_kv % 64 == 0 &&
-        bias_type == NVTE_Bias_Type::NVTE_NO_BIAS &&
-        (qkv_format == NVTE_QKV_Format::NVTE_SBHD || qkv_format == NVTE_QKV_Format::NVTE_BSHD) &&
-        max_seqlen_q <= max_seqlen_kv && dropout == 0.0) ||
-       ((cudnn_runtime_version >= 90500) &&
+        max_seqlen_q % 64 == 0 && max_seqlen_kv % 64 == 0 && max_seqlen_q <= max_seqlen_kv &&
+        bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && dropout == 0.0) ||
+       // 9.6: adds {bshd, sbhd} + causal_bottom_right + cross-attn (sq > skv)
+       //      and thd + padding_causal_bottom_right
+       (cudnn_runtime_version >= 90600 &&
         (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK ||
          attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK) &&
+        max_seqlen_q % 64 == 0 && max_seqlen_kv % 64 == 0 &&
         bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && dropout == 0.0)) &&
       // bias + mask combination
       (!(cudnn_runtime_version >= 8906 &&
         (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK ||
          attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) &&
         bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS)) &&
       // qkv format
-      ((qkv_format == NVTE_QKV_Format::NVTE_SBHD || qkv_format == NVTE_QKV_Format::NVTE_BSHD) ||
+      (qkv_format == NVTE_QKV_Format::NVTE_SBHD || qkv_format == NVTE_QKV_Format::NVTE_BSHD ||
       (qkv_format == NVTE_QKV_Format::NVTE_THD && sm_arch_ >= 90 &&
        ((cudnn_runtime_version >= 90100 && num_attn_heads == num_gqa_groups) ||
-         (cudnn_runtime_version >= 90600)))) &&
+        cudnn_runtime_version >= 90600))) &&
       // sliding window
+      // pre-9.2: full attn, causal
      ((cudnn_runtime_version < 90200 && window_size_left == -1 &&
        (window_size_right == -1 || window_size_right == 0)) ||
+      // 9.2: SWA (left, 0) + top-left diagonal + {bshd, sbhd}
       (cudnn_runtime_version >= 90200 &&
        ((window_size_left == -1 && (window_size_right == -1 || window_size_right == 0)) ||
         ((window_size_left >= 0 || window_size_left == -1) && window_size_right == 0 &&
@@ -202,13 +215,15 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
          dropout == 0.0 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS &&
          (qkv_format == NVTE_QKV_Format::NVTE_BSHD ||
           qkv_format == NVTE_QKV_Format::NVTE_SBHD)))) ||
-      (cudnn_runtime_version >= 90500 &&
-       ((window_size_left >= 0 || window_size_left == -1) && window_size_right == 0 &&
-        (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK ||
-         attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK ||
-         attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK ||
-         attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK) &&
-        dropout == 0.0 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS))) &&
+      // 9.6: SWA (left, 0) + top-left/bottom-right diagonal + {bshd, sbhd, thd}
+      (cudnn_runtime_version >= 90600 &&
+       ((window_size_left == -1 && (window_size_right == -1 || window_size_right == 0)) ||
+        ((window_size_left >= 0 || window_size_left == -1) && window_size_right == 0 &&
+         (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK ||
+          attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK ||
+          attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK) &&
+         max_seqlen_q % 64 == 0 && max_seqlen_kv % 64 == 0 &&
+         bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && dropout == 0.0)))) &&
      // check 64-bit ragged offset support
      (supported_ragged_offset_size)) {
    flag_arb = true;
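Read together, the comments added above encode the following sliding-window support matrix for the arbitrary-seqlen backend. The summary below is an illustrative restatement in Python, not part of the C++ source:

```python
# Illustrative restatement of the sliding-window gating commented in the diff above.
SWA_SUPPORT_BY_CUDNN = {
    "< 9.2":  "full attention only: window (-1, -1) or (-1, 0)",
    ">= 9.2": "adds (left, 0) on a top-left causal diagonal, bshd/sbhd, "
              "no dropout, no bias",
    ">= 9.6": "adds (left, 0) on top-left or bottom-right diagonals, "
              "bshd/sbhd/thd, seqlens divisible by 64, no dropout, no bias",
}
```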

transformer_engine/pytorch/attention.py

Lines changed: 29 additions & 30 deletions
@@ -618,9 +618,7 @@ def get_attention_backend(
     #     self-attention            |                                      | All
     #     cross-attention           |                                      | FusedAttention, UnfusedDotProductAttention
     # causal_bottom_right           | None                                 | All
-    # padding_causal_bottom_right   | Same as "padding"                    |
-    #     self-attention            |                                      | All
-    #     cross-attention           |                                      | FlashAttention, UnfusedDotProductAttention
+    # padding_causal_bottom_right   | Same as "padding"                    | All
     # arbitrary                     | One tensor in shape broadcastable to | UnfusedDotProductAttention
     #                               | [b, h, sq, skv]                      |
     if attn_mask_type == "arbitrary":
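For orientation, the mask and window settings this table (and the checks below) reason about are passed in through DotProductAttention. A hedged usage sketch, assuming a recent Transformer Engine build with sliding-window support and a CUDA device; the shapes and hyperparameters are illustrative, not taken from the commit:

```python
# Hedged usage sketch; values below are examples only.
import torch
from transformer_engine.pytorch import DotProductAttention

dpa = DotProductAttention(
    num_attention_heads=16,
    kv_channels=64,                        # per-head dimension
    attn_mask_type="causal_bottom_right",
    window_size=(128, 0),                  # sliding window on the causal diagonal
    attention_dropout=0.0,
)
q = torch.randn(2048, 2, 16, 64, dtype=torch.bfloat16, device="cuda")
k, v = torch.randn_like(q), torch.randn_like(q)
out = dpa(q, k, v)                         # default qkv_format is "sbhd"
```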
@@ -691,37 +689,38 @@ def get_attention_backend(
         window_size = check_set_window_size(attn_mask_type, window_size)
     else:
         if use_fused_attention and (window_size[0] != -1 or window_size[1] not in [-1, 0]):
-            if fp8 and (fp8_meta["recipe"].fp8_dpa or fp8_meta["recipe"].fp8_mha):
-                logger.debug(
-                    "Disabling FusedAttention as it does not support sliding window attention"
-                    " for FP8"
-                )
-                use_fused_attention = False
-            elif window_size[1] != 0 or attention_dropout != 0.0 or qkv_format == "thd":
-                logger.debug(
-                    "Disabling FusedAttention as it only supports sliding window attention "
-                    "with causal mask, no dropout, and qkv_format = bshd/sbhd"
-                )
-                use_fused_attention = False
-            elif max_seqlen_q != max_seqlen_kv and attn_mask_type in [
-                "no_mask",
-                "padding",
-                "causal_bottom_right",
-                "padding_causal_bottom_right",
-            ]:
+            #if fp8 and (fp8_meta["recipe"].fp8_dpa or fp8_meta["recipe"].fp8_mha):
+            #    logger.debug(
+            #        "Disabling FusedAttention as it does not support sliding window attention"
+            #        " for FP8"
+            #    )
+            #    use_fused_attention = False
+            #elif window_size[1] != 0 or attention_dropout != 0.0 or qkv_format == "thd":
+            if attention_dropout != 0.0:
                 logger.debug(
                     "Disabling FusedAttention as it does not support sliding window attention "
-                    "with attn_mask_type = %s for cross-attention",
-                    attn_mask_type,
+                    "with dropout"
                 )
                 use_fused_attention = False
-            # elif "padding" in attn_mask_type:
-            #     logger.debug(
-            #         "Disabling FusedAttention as it does not support sliding window attention "
-            #         "with attn_mask_type = %s",
-            #         attn_mask_type,
-            #     )
-            #     use_fused_attention = False
+            #elif max_seqlen_q != max_seqlen_kv and attn_mask_type in [
+            #    "no_mask",
+            #    "padding",
+            #    "causal_bottom_right",
+            #    "padding_causal_bottom_right",
+            #]:
+            #    logger.debug(
+            #        "Disabling FusedAttention as it does not support sliding window attention "
+            #        "with attn_mask_type = %s for cross-attention",
+            #        attn_mask_type,
+            #    )
+            #    use_fused_attention = False
+            #elif "padding" in attn_mask_type:
+            #    logger.debug(
+            #        "Disabling FusedAttention as it does not support sliding window attention "
+            #        "with attn_mask_type = %s",
+            #        attn_mask_type,
+            #    )
+            #    use_fused_attention = False
         if use_flash_attention and (window_size[0] != -1 or window_size[1] not in [-1, 0]):
             if _use_flash_attn_3:
                 logger.debug(
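With the earlier FP8, thd, and cross-attention checks commented out as WIP, the only condition that still disables FusedAttention for a sliding window in this revision is dropout. A minimal sketch of that gate, using a hypothetical helper name:

```python
# Minimal sketch of the WIP gate above; the helper name is hypothetical.
def fused_attention_allows_swa(window_size, attention_dropout: float) -> bool:
    has_swa = window_size[0] != -1 or window_size[1] not in (-1, 0)
    if has_swa and attention_dropout != 0.0:
        return False  # FusedAttention does not support SWA with dropout
    return True

assert fused_attention_allows_swa((128, 0), 0.0)
assert not fused_attention_allows_swa((128, 0), 0.1)
```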
