From 8572c1f4db62b5de3d36bd0029549bc5310e7eb5 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Wed, 11 Dec 2024 21:49:52 -0800 Subject: [PATCH] enable more support Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- transformer_engine/common/fused_attn/fused_attn.cpp | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 2837afbd87..413c970eb0 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -177,9 +177,10 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( (qkv_format == NVTE_QKV_Format::NVTE_SBHD || qkv_format == NVTE_QKV_Format::NVTE_BSHD) && max_seqlen_q <= max_seqlen_kv && dropout == 0.0) || ((cudnn_runtime_version >= 90500) && - attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK && - max_seqlen_q % 64 == 0 && max_seqlen_kv % 64 == 0 && - bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && dropout == 0.0)) && + (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK || + attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK) && + bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && + dropout == 0.0)) && // bias + mask combination (!(cudnn_runtime_version >= 8906 && (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK ||