
Commit

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed Dec 14, 2024
1 parent 956570f commit 8d17e10
Showing 3 changed files with 18 additions and 19 deletions.
2 changes: 1 addition & 1 deletion tests/pytorch/fused_attn/test_fused_attn.py
@@ -683,7 +683,7 @@ def test_dpa_qkv_layout_thd(dtype, model_configs, model, qkv_layout):
if config.window_size[0] == -1 and config.window_size[1] in [-1, 0]:
pad_between_seqs = True
test_dot_product_attention(
-dtype, model_configs, model, False, True, qkv_layout, False, pad_between_seqs
+dtype, model_configs, model, False, True, qkv_layout, False, pad_between_seqs
)
if get_cudnn_version() >= (9, 3, 0):
# cuDNN 9.3.0+ is required to run pad_between_seqs = False/True in the same run
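The hunk above only re-wraps the call to test_dot_product_attention; the surrounding logic gates the THD-layout test on the installed cuDNN version, since cuDNN 9.3.0+ is needed to run pad_between_seqs = False and True in the same run. Below is a minimal Python sketch of that gating pattern, with get_cudnn_version stubbed and the test call reduced to a placeholder run_case callable; both names are illustrative stand-ins, not the actual test code.

def get_cudnn_version() -> tuple[int, int, int]:
    # Stand-in: the real helper reports the runtime cuDNN version
    # as a (major, minor, patch) tuple.
    return (9, 3, 0)


def run_thd_layout_cases(run_case) -> None:
    # In the hunk above, the padded case is exercised first.
    run_case(pad_between_seqs=True)
    if get_cudnn_version() >= (9, 3, 0):
        # cuDNN 9.3.0+ is required to also run pad_between_seqs=False
        # within the same test invocation.
        run_case(pad_between_seqs=False)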
25 changes: 12 additions & 13 deletions transformer_engine/common/fused_attn/fused_attn.cpp
@@ -164,32 +164,31 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
(cudnn_runtime_version >= 90000 &&
(bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS && sm_arch_ >= 80))) &&
// mask type
-// pre-8.9.6: causal
+// pre-8.9.6: causal
((cudnn_runtime_version < 8906 && attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) ||
-// 8.9.6: {bshd, sbhd} + {no_mask, causal, padding, padding_causal}
+// 8.9.6: {bshd, sbhd} + {no_mask, causal, padding, padding_causal}
(cudnn_runtime_version >= 8906 &&
-(qkv_format == NVTE_QKV_Format::NVTE_SBHD || qkv_format == NVTE_QKV_Format::NVTE_BSHD) &&
+(qkv_format == NVTE_QKV_Format::NVTE_SBHD || qkv_format == NVTE_QKV_Format::NVTE_BSHD) &&
(attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK ||
attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK ||
attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK ||
attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK)) ||
-// 9.1: adds thd + {padding, padding_causal}
-(cudnn_runtime_version >= 90100 &&
-qkv_format == NVTE_QKV_Format::NVTE_THD &&
+// 9.1: adds thd + {padding, padding_causal}
+(cudnn_runtime_version >= 90100 && qkv_format == NVTE_QKV_Format::NVTE_THD &&
(attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK ||
attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)) ||
-// 9.3: adds {bshd, sbhd} + causal_bottom_right + self/cross-attn (sq <= skv)
+// 9.3: adds {bshd, sbhd} + causal_bottom_right + self/cross-attn (sq <= skv)
(cudnn_runtime_version >= 90300 &&
-(qkv_format == NVTE_QKV_Format::NVTE_SBHD || qkv_format == NVTE_QKV_Format::NVTE_BSHD) &&
+(qkv_format == NVTE_QKV_Format::NVTE_SBHD || qkv_format == NVTE_QKV_Format::NVTE_BSHD) &&
attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK &&
max_seqlen_q % 64 == 0 && max_seqlen_kv % 64 == 0 && max_seqlen_q <= max_seqlen_kv &&
bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && dropout == 0.0) ||
-// 9.6: adds {bshd, sbhd} + causal_bottom_right + cross-attn (sq > skv)
-// and thd + padding_causal_bottom_right
+// 9.6: adds {bshd, sbhd} + causal_bottom_right + cross-attn (sq > skv)
+// and thd + padding_causal_bottom_right
(cudnn_runtime_version >= 90600 &&
(attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK ||
attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK) &&
-max_seqlen_q % 64 == 0 && max_seqlen_kv % 64 == 0 &&
+max_seqlen_q % 64 == 0 && max_seqlen_kv % 64 == 0 &&
bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && dropout == 0.0)) &&
// bias + mask combination
(!(cudnn_runtime_version >= 8906 &&
@@ -202,10 +201,10 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
((cudnn_runtime_version >= 90100 && num_attn_heads == num_gqa_groups) ||
cudnn_runtime_version >= 90600))) &&
// sliding window
-// pre-9.2: full attn, causal
+// pre-9.2: full attn, causal
((cudnn_runtime_version < 90200 && window_size_left == -1 &&
(window_size_right == -1 || window_size_right == 0)) ||
-// 9.2: SWA (left, 0) + top-left diagonal + {bshd, sbhd}
+// 9.2: SWA (left, 0) + top-left diagonal + {bshd, sbhd}
(cudnn_runtime_version >= 90200 &&
((window_size_left == -1 && (window_size_right == -1 || window_size_right == 0)) ||
((window_size_left >= 0 || window_size_left == -1) && window_size_right == 0 &&
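The condition in this hunk is a compact support matrix: each cuDNN release widens the set of (qkv_format, mask type) combinations the fused-attention backend accepts. For readability, here is the same mask-type check restated as a small Python helper; the function name, string constants, and parameter list are invented for illustration and are not part of the Transformer Engine API, and the bias/dropout caveats of the 9.3/9.6 branches are folded into a single flag.

def mask_type_supported(
    cudnn_version: int,  # e.g. 8906, 90100, 90300, 90600
    qkv_format: str,     # "bshd", "sbhd", or "thd"
    mask_type: str,
    seqlen_q: int,
    seqlen_kv: int,
    has_bias: bool,
    dropout: float,
) -> bool:
    bshd_sbhd = qkv_format in ("bshd", "sbhd")
    aligned = seqlen_q % 64 == 0 and seqlen_kv % 64 == 0
    no_bias_no_dropout = not has_bias and dropout == 0.0
    return (
        # pre-8.9.6: causal only
        (cudnn_version < 8906 and mask_type == "causal")
        # 8.9.6: {bshd, sbhd} + {no_mask, causal, padding, padding_causal}
        or (
            cudnn_version >= 8906
            and bshd_sbhd
            and mask_type in ("no_mask", "causal", "padding", "padding_causal")
        )
        # 9.1: adds thd + {padding, padding_causal}
        or (
            cudnn_version >= 90100
            and qkv_format == "thd"
            and mask_type in ("padding", "padding_causal")
        )
        # 9.3: adds {bshd, sbhd} + causal_bottom_right with sq <= skv
        or (
            cudnn_version >= 90300
            and bshd_sbhd
            and mask_type == "causal_bottom_right"
            and seqlen_q <= seqlen_kv
            and aligned
            and no_bias_no_dropout
        )
        # 9.6: adds causal_bottom_right for sq > skv and thd + padding_causal_bottom_right
        or (
            cudnn_version >= 90600
            and mask_type in ("causal_bottom_right", "padding_causal_bottom_right")
            and aligned
            and no_bias_no_dropout
        )
    )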
10 changes: 5 additions & 5 deletions transformer_engine/pytorch/attention.py
@@ -689,32 +689,32 @@ def get_attention_backend(
window_size = check_set_window_size(attn_mask_type, window_size)
else:
if use_fused_attention and (window_size[0] != -1 or window_size[1] not in [-1, 0]):
-#if fp8 and (fp8_meta["recipe"].fp8_dpa or fp8_meta["recipe"].fp8_mha):
+# if fp8 and (fp8_meta["recipe"].fp8_dpa or fp8_meta["recipe"].fp8_mha):
# logger.debug(
# "Disabling FusedAttention as it does not support sliding window attention"
# " for FP8"
# )
# use_fused_attention = False
-#elif window_size[1] != 0 or attention_dropout != 0.0 or qkv_format == "thd":
+# elif window_size[1] != 0 or attention_dropout != 0.0 or qkv_format == "thd":
if attention_dropout != 0.0:
logger.debug(
"Disabling FusedAttention as it does not support sliding window attention "
"with dropout"
)
use_fused_attention = False
-#elif max_seqlen_q != max_seqlen_kv and attn_mask_type in [
+# elif max_seqlen_q != max_seqlen_kv and attn_mask_type in [
# "no_mask",
# "padding",
# "causal_bottom_right",
# "padding_causal_bottom_right",
-#]:
+# ]:
# logger.debug(
# "Disabling FusedAttention as it does not support sliding window attention "
# "with attn_mask_type = %s for cross-attention",
# attn_mask_type,
# )
# use_fused_attention = False
#elif "padding" in attn_mask_type:
# elif "padding" in attn_mask_type:
# logger.debug(
# "Disabling FusedAttention as it does not support sliding window attention "
# "with attn_mask_type = %s",
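Most of the sliding-window checks in this hunk are commented out; the only active gate disables FusedAttention when a sliding window is combined with a nonzero dropout. A minimal sketch of that active check follows, with placeholder names standing in for the corresponding get_attention_backend() arguments.

def fused_attention_allows_swa(window_size, attention_dropout) -> bool:
    # A sliding window is requested when the left bound is set, or the
    # right bound is something other than -1/0 (full or causal attention).
    sliding_window = window_size[0] != -1 or window_size[1] not in (-1, 0)
    if sliding_window and attention_dropout != 0.0:
        # FusedAttention does not support sliding window attention with dropout.
        return False
    return True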

0 comments on commit 8d17e10
