enable more support
Signed-off-by: Charlene Yang <[email protected]>
cyanguwa committed Dec 12, 2024
1 parent e2d9ffe commit 8572c1f
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions transformer_engine/common/fused_attn/fused_attn.cpp
@@ -177,9 +177,10 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
  (qkv_format == NVTE_QKV_Format::NVTE_SBHD || qkv_format == NVTE_QKV_Format::NVTE_BSHD) &&
  max_seqlen_q <= max_seqlen_kv && dropout == 0.0) ||
  ((cudnn_runtime_version >= 90500) &&
- attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK &&
- max_seqlen_q % 64 == 0 && max_seqlen_kv % 64 == 0 &&
- bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && dropout == 0.0)) &&
+ (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK ||
+ attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK) &&
+ bias_type == NVTE_Bias_Type::NVTE_NO_BIAS &&
+ dropout == 0.0)) &&
  // bias + mask combination
  (!(cudnn_runtime_version >= 8906 &&
  (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK ||
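
To make the effect of the hunk easier to read, here is a minimal sketch that isolates only the cuDNN >= 9.5 clause as two standalone predicates, one for the old condition and one for the new. The enums `MaskType` and `BiasType` and the function names `clause_before` and `clause_after` are hypothetical stand-ins, not Transformer Engine identifiers. The rest of the backend-selection expression lies outside the hunk and is not modeled.

```cpp
#include <cstddef>

// Hypothetical stand-ins for NVTE_Mask_Type / NVTE_Bias_Type (assumption:
// only the values needed for this illustration are listed).
enum class MaskType { NoMask, Padding, CausalBottomRight, PaddingCausalBottomRight };
enum class BiasType { NoBias, OtherBias };

// The cuDNN >= 9.5 clause as it reads on the removed lines: only the
// padding + bottom-right causal mask, with both sequence lengths required
// to be multiples of 64.
constexpr bool clause_before(int cudnn_runtime_version, MaskType mask, BiasType bias,
                             std::size_t max_seqlen_q, std::size_t max_seqlen_kv,
                             float dropout) {
  return cudnn_runtime_version >= 90500 &&
         mask == MaskType::PaddingCausalBottomRight &&
         max_seqlen_q % 64 == 0 && max_seqlen_kv % 64 == 0 &&
         bias == BiasType::NoBias && dropout == 0.0f;
}

// The same clause as it reads on the added lines: the non-padding
// bottom-right causal mask is accepted too, and the multiple-of-64
// sequence-length requirement is gone.
constexpr bool clause_after(int cudnn_runtime_version, MaskType mask, BiasType bias,
                            float dropout) {
  return cudnn_runtime_version >= 90500 &&
         (mask == MaskType::CausalBottomRight ||
          mask == MaskType::PaddingCausalBottomRight) &&
         bias == BiasType::NoBias && dropout == 0.0f;
}
```

Under these assumptions, a call with cuDNN 9.5, `PaddingCausalBottomRight`, no bias, zero dropout, and `max_seqlen_q = 100` fails the old clause and passes the new one, which is the kind of case the commit title "enable more support" appears to refer to.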