diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 53cf72e810..00bf02173b 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -211,8 +211,8 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK && max_seqlen_q == max_seqlen_kv)) && - max_seqlen_q <= max_seqlen_kv && - dropout == 0.0 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && + max_seqlen_q <= max_seqlen_kv && dropout == 0.0 && + bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && (qkv_format == NVTE_QKV_Format::NVTE_BSHD || qkv_format == NVTE_QKV_Format::NVTE_SBHD)))) || // 9.6: SWA (left, 0) + top-left/bottom-right diagonal + {bshd, sbhd, thd} @@ -223,8 +223,8 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK || attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK) && - max_seqlen_q <= max_seqlen_kv && - bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && dropout == 0.0)))) && + max_seqlen_q <= max_seqlen_kv && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && + dropout == 0.0)))) && // check 64-bit ragged offset support (supported_ragged_offset_size)) { flag_arb = true;