From 9a09edb6b0497b5cbc183b043ae4eae1b7e94f05 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Mon, 16 Dec 2024 01:22:51 -0800 Subject: [PATCH] add left_bound/right_bound Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- tests/pytorch/fused_attn/test_fused_attn.py | 2 + .../common/fused_attn/fused_attn.cpp | 68 ++++---- .../fused_attn_f16_arbitrary_seqlen.cu | 70 +++++---- .../fused_attn_f16_arbitrary_seqlen.h | 12 +- .../common/fused_attn/fused_attn_fp8.cu | 36 ++--- .../common/fused_attn/fused_attn_fp8.h | 12 +- transformer_engine/common/fused_attn/utils.h | 5 +- .../include/transformer_engine/fused_attn.h | 18 ++- .../jax/csrc/extensions/attention.cpp | 24 +-- transformer_engine/pytorch/attention.py | 146 +++++++++++------- .../pytorch/cpp_extensions/fused_attn.py | 30 ++++ transformer_engine/pytorch/csrc/extensions.h | 12 +- .../pytorch/csrc/extensions/attention.cu | 36 ++--- transformer_engine/pytorch/transformer.py | 40 +++++ 14 files changed, 319 insertions(+), 192 deletions(-) diff --git a/tests/pytorch/fused_attn/test_fused_attn.py b/tests/pytorch/fused_attn/test_fused_attn.py index 3054eaa6cd..e24f82ab63 100644 --- a/tests/pytorch/fused_attn/test_fused_attn.py +++ b/tests/pytorch/fused_attn/test_fused_attn.py @@ -350,6 +350,8 @@ def test_dot_product_attention( torch.testing.assert_close(unfused_attn_bwd[i], flash_attn_bwd[i], **tols) if fused_attn_supported and flash_attn_supported: logging.info("[test_dot_product_attention]: fused attn vs flash attn") + torch.save(fused_attn_fwd, 'fused_attn_fwd.pt') + torch.save(flash_attn_fwd, 'flash_attn_fwd.pt') torch.testing.assert_close(fused_attn_fwd, flash_attn_fwd, **tols) for i, _ in enumerate(flash_attn_bwd): torch.testing.assert_close(fused_attn_bwd[i], flash_attn_bwd[i], **tols) diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index adfe3acae1..0900917557 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -183,12 +183,10 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK && max_seqlen_q % 64 == 0 && max_seqlen_kv % 64 == 0 && max_seqlen_q <= max_seqlen_kv && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && dropout == 0.0) || - // 9.6: adds {bshd, sbhd} + causal_bottom_right + cross-attn (sq > skv) - // and thd + padding_causal_bottom_right + // 9.6: adds thd + padding_causal_bottom_right (cudnn_runtime_version >= 90600 && - (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK || - attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK) && - max_seqlen_q % 64 == 0 && max_seqlen_kv % 64 == 0 && + attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK && + max_seqlen_q % 64 == 0 && max_seqlen_kv % 64 == 0 && max_seqlen_q <= max_seqlen_kv && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && dropout == 0.0)) && // bias + mask combination (!(cudnn_runtime_version >= 8906 && @@ -206,23 +204,25 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( (window_size_right == -1 || window_size_right == 0)) || // 9.2: SWA (left, 0) + top-left diagonal + {bshd, sbhd} (cudnn_runtime_version >= 90200 && - ((window_size_left == -1 && (window_size_right == -1 || window_size_right == 0)) || - ((window_size_left >= 0 || window_size_left == -1) && window_size_right == 0 && - (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || + 
((window_size_left == -1 && window_size_right == -1 && + attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK) || + ((window_size_left == -1 || window_size_left >= 0) && window_size_right == 0 && + (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK || + attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK && max_seqlen_q == max_seqlen_kv)) && dropout == 0.0 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && (qkv_format == NVTE_QKV_Format::NVTE_BSHD || qkv_format == NVTE_QKV_Format::NVTE_SBHD)))) || - // 9.6: SWA (left, 0) + top-left/bottom-right diagonal + {bshd, sbhd, thd} + // 9.6: SWA (left, right) + top-left/bottom-right diagonal + {bshd, sbhd, thd} (cudnn_runtime_version >= 90600 && - ((window_size_left == -1 && (window_size_right == -1 || window_size_right == 0)) || - ((window_size_left >= 0 || window_size_left == -1) && window_size_right == 0 && - (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK || - attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK || - attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK) && - max_seqlen_q % 64 == 0 && max_seqlen_kv % 64 == 0 && - bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && dropout == 0.0)))) && + (window_size_left == -1 || window_size_left >= 0) && + (window_size_right == -1 || window_size_right >= 0) && + (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK || + attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || + attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK || + attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK || + attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK))) && // check 64-bit ragged offset support (supported_ragged_offset_size)) { flag_arb = true; @@ -272,7 +272,7 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, const NVTETensor rng_state, size_t max_seqlen, bool is_training, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - int64_t window_size_left, int64_t window_size_right, + int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_fwd_qkvpacked); using namespace transformer_engine; @@ -324,7 +324,7 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, #if (CUDNN_VERSION >= 8900) fused_attn_arbitrary_seqlen_fwd_qkvpacked( b, h, max_seqlen, d, t, is_training, attn_scale, dropout, qkv_layout, bias_type, - attn_mask_type, window_size_left, window_size_right, input_QKV, input_Bias, output_O, + attn_mask_type, window_size_left, window_size_right, bottom_right_diagonal, input_QKV, input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens, input_cu_seqlens_padded, input_rng_state, wkspace, stream, handle); #else @@ -334,7 +334,7 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) { #if (CUDNN_VERSION >= 8900) fused_attn_fp8_fwd_qkvpacked(b, h, max_seqlen, d, is_training, attn_scale, dropout, qkv_layout, - bias_type, attn_mask_type, input_QKV, input_output_S, output_O, + bias_type, attn_mask_type, window_size_left, window_size_right, bottom_right_diagonal, input_QKV, input_output_S, output_O, Aux_CTX_Tensors, input_cu_seqlens, input_rng_state, wkspace, stream, handle); #else @@ -352,7 +352,7 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor 
O, con const NVTETensor cu_seqlens_padded, size_t max_seqlen, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - int64_t window_size_left, int64_t window_size_right, + int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, bool deterministic, NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_bwd_qkvpacked); using namespace transformer_engine; @@ -414,7 +414,7 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con } fused_attn_arbitrary_seqlen_bwd_qkvpacked( b, h, max_seqlen, d, t, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, - window_size_left, window_size_right, deterministic, input_QKV, input_O, input_dO, + window_size_left, window_size_right, bottom_right_diagonal, deterministic, input_QKV, input_O, input_dO, input_Bias, output_S, output_dQKV, output_dBias, input_cu_seqlens, input_cu_seqlens_padded, input_rng_state, wkspace, stream, handle); #else @@ -429,7 +429,7 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con const Tensor *input_ZInv = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); const Tensor *input_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[2]); fused_attn_fp8_bwd_qkvpacked(b, h, max_seqlen, d, attn_scale, dropout, qkv_layout, bias_type, - attn_mask_type, input_QKV, input_O, input_dO, input_M, input_ZInv, + attn_mask_type, window_size_left, window_size_right, bottom_right_diagonal, input_QKV, input_O, input_dO, input_M, input_ZInv, input_S, input_output_dP, output_dQKV, input_cu_seqlens, input_rng_state, wkspace, stream, handle); #else @@ -448,7 +448,7 @@ void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - int64_t window_size_left, int64_t window_size_right, + int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_fwd_kvpacked); using namespace transformer_engine; @@ -507,7 +507,7 @@ void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const #if (CUDNN_VERSION >= 8903) fused_attn_arbitrary_seqlen_fwd_kvpacked( b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, t_q, t_kv, is_training, attn_scale, dropout, - qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, input_Q, + qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, bottom_right_diagonal, input_Q, input_KV, input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream, handle); @@ -519,7 +519,7 @@ void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const #if (CUDNN_VERSION >= 8900) fused_attn_fp8_fwd_kvpacked( b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, is_training, attn_scale, dropout, qkv_layout, - bias_type, attn_mask_type, input_Q, input_KV, input_output_S, output_O, Aux_CTX_Tensors, + bias_type, attn_mask_type, window_size_left, window_size_right, bottom_right_diagonal, input_Q, input_KV, input_output_S, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle); #else NVTE_ERROR("cuDNN 8.9.0 is required for FP8 fused attention. 
\n"); @@ -536,7 +536,7 @@ void nvte_fused_attn_bwd_kvpacked( const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - int64_t window_size_left, int64_t window_size_right, bool deterministic, NVTETensor workspace, + int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, bool deterministic, NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_bwd_kvpacked); using namespace transformer_engine; @@ -607,7 +607,7 @@ void nvte_fused_attn_bwd_kvpacked( } fused_attn_arbitrary_seqlen_bwd_kvpacked( b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, t_q, t_kv, attn_scale, dropout, qkv_layout, - bias_type, attn_mask_type, window_size_left, window_size_right, deterministic, input_Q, + bias_type, attn_mask_type, window_size_left, window_size_right, bottom_right_diagonal, deterministic, input_Q, input_KV, input_O, input_dO, input_Bias, output_S, output_dQ, output_dKV, output_dBias, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream, handle); @@ -623,7 +623,7 @@ void nvte_fused_attn_bwd_kvpacked( const Tensor *input_ZInv = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); const Tensor *input_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[2]); fused_attn_fp8_bwd_kvpacked(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, attn_scale, dropout, - qkv_layout, bias_type, attn_mask_type, input_Q, input_KV, input_O, + qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, bottom_right_diagonal, input_Q, input_KV, input_O, input_dO, input_M, input_ZInv, input_S, input_output_dP, output_dQ, output_dKV, input_cu_seqlens_q, input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle); @@ -643,7 +643,7 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - int64_t window_size_left, int64_t window_size_right, NVTETensor workspace, + int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_fwd); using namespace transformer_engine; @@ -695,7 +695,7 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso #if (CUDNN_VERSION >= 8900) fused_attn_arbitrary_seqlen_fwd( b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, t_q, t_kv, is_training, attn_scale, - dropout, qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, + dropout, qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, bottom_right_diagonal, input_Q, input_K, input_V, input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream, handle); @@ -706,7 +706,7 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) { #if (CUDNN_VERSION >= 8900) fused_attn_fp8_fwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, is_training, attn_scale, - dropout, qkv_layout, bias_type, attn_mask_type, input_Q, input_K, input_V, + dropout, qkv_layout, bias_type, attn_mask_type, 
window_size_left, window_size_right, bottom_right_diagonal, input_Q, input_K, input_V, input_output_S, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle); #else @@ -726,7 +726,7 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso size_t max_seqlen_kv, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, int64_t window_size_left, - int64_t window_size_right, bool deterministic, NVTETensor workspace, + int64_t window_size_right, bool bottom_right_diagonal, bool deterministic, NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_bwd); using namespace transformer_engine; @@ -791,7 +791,7 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso } fused_attn_arbitrary_seqlen_bwd( b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, t_q, t_kv, attn_scale, dropout, - qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, deterministic, + qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, bottom_right_diagonal, deterministic, input_Q, input_K, input_V, input_O, input_dO, input_Bias, output_S, output_dQ, output_dK, output_dV, output_dBias, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream, handle); @@ -807,7 +807,7 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso const Tensor *input_ZInv = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); const Tensor *input_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[2]); fused_attn_fp8_bwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, attn_scale, dropout, - qkv_layout, bias_type, attn_mask_type, input_Q, input_K, input_V, input_O, + qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, bottom_right_diagonal, input_Q, input_K, input_V, input_O, input_dO, input_M, input_ZInv, input_S, input_output_dP, output_dQ, output_dK, output_dV, input_cu_seqlens_q, input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle); diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index bd71e4edae..bbcbd6fc1a 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -53,7 +53,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl( int64_t max_b, int64_t max_t_q, int64_t max_t_kv, int64_t bias_b, int64_t bias_h, bool is_training, float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, - int64_t window_size_right, void *devPtrQ, void *devPtrK, void *devPtrV, void *devPtrBias, + int64_t window_size_right, bool bottom_right_diagonal, void *devPtrQ, void *devPtrK, void *devPtrV, void *devPtrBias, void *devPtrSoftmaxStats, void *devPtrO, void *devPtrDropoutSeed, void *devPtrDropoutOffset, void *devPtrCuSeqlensQ, void *devPtrCuSeqlensKV, void *devPtrSeqOffsetsQ, void *devPtrSeqOffsetsKV, cudnn_frontend::DataType_t tensorType, void *workspace, @@ -64,12 +64,16 @@ void fused_attn_arbitrary_seqlen_fwd_impl( bool is_alibi = (bias_type == NVTE_Bias_Type::NVTE_ALIBI); bool is_causal = ((mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) || (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)); - bool 
is_bottom_right = ((mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK) || + bool is_causal_bottom_right = ((mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK) || (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK)); - if (is_bottom_right && s_q == s_kv) { + if (is_causal_bottom_right && s_q == s_kv) { is_causal = true; - is_bottom_right = false; + is_causal_bottom_right = false; } + if (is_causal || is_causal_bottom_right) { + window_size_right = 0; + } + bottom_right_diagonal = is_causal_bottom_right && !is_causal && is_causal_bottom_right; bool is_padding = ((mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) || (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) || (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK)); @@ -106,7 +110,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl( bias_type, mask_type, window_size_left, - window_size_right, + window_size_right, bottom_right_diagonal, true, tensorType, tensorType}; @@ -222,12 +226,15 @@ void fused_attn_arbitrary_seqlen_fwd_impl( sdpa_options = fe::graph::SDPA_attributes() .set_name("flash_attention") .set_is_inference(false) - .set_causal_mask(is_causal) - .set_causal_mask_bottom_right(is_bottom_right) .set_attn_scale(attn_scale); + fe::DiagonalAlignment_t const& diagonal_alignment = bottom_right_diagonal ? fe::DiagonalAlignment_t::BOTTOM_RIGHT : fe::DiagonalAlignment_t::TOP_LEFT; + sdpa_options.set_diagonal_alignment(diagonal_alignment); if (cudnn_runtime_version >= 90200 && window_size_left != -1) { - sdpa_options.set_sliding_window_length(window_size_left + 1); + sdpa_options.set_left_bound(window_size_left + 1); + } + if (cudnn_runtime_version >= 90600 && window_size_right != -1) { + sdpa_options.set_right_bound(window_size_right); } sdpa_options.set_alibi_mask(is_alibi); @@ -432,7 +439,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl( int64_t max_b, int64_t max_t_q, int64_t max_t_kv, int64_t bias_b, int64_t bias_h, float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, - int64_t window_size_right, bool deterministic, void *devPtrQ, void *devPtrKTranspose, + int64_t window_size_right, bool bottom_right_diagonal, bool deterministic, void *devPtrQ, void *devPtrKTranspose, void *devPtrVTranspose, void *devPtrO, void *devPtrSoftmaxStats, void *devPtrBias, void *devPtrdQ, void *devPtrdK, void *devPtrdV, void *devPtrdO, void *devPtrdBias, void *devPtrDropoutSeed, void *devPtrDropoutOffset, void *devPtrCuSeqlensQ, @@ -445,12 +452,16 @@ void fused_attn_arbitrary_seqlen_bwd_impl( bool is_alibi = (bias_type == NVTE_Bias_Type::NVTE_ALIBI); bool is_causal = ((mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) || (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)); - bool is_bottom_right = ((mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK) || + bool is_causal_bottom_right = ((mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK) || (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK)); - if (is_bottom_right && s_q == s_kv) { + if (is_causal_bottom_right && s_q == s_kv) { is_causal = true; - is_bottom_right = false; + is_causal_bottom_right = false; } + if (is_causal || is_causal_bottom_right) { + window_size_right = 0; + } + bottom_right_diagonal = is_causal_bottom_right && !is_causal && is_causal_bottom_right; bool is_padding = ((mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) || (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) || (mask_type == 
NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK)); @@ -491,7 +502,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl( bias_type, mask_type, window_size_left, - window_size_right, + window_size_right, bottom_right_diagonal, deterministic, tensorType, tensorType}; @@ -657,16 +668,19 @@ void fused_attn_arbitrary_seqlen_bwd_impl( fe::graph::SDPA_backward_attributes sdpa_backward_options; sdpa_backward_options = fe::graph::SDPA_backward_attributes() .set_name("flash_attention_backward") - .set_causal_mask(is_causal) - .set_causal_mask_bottom_right(is_bottom_right) .set_attn_scale(attn_scale); if (is_ragged && cudnn_runtime_version >= 90600) { sdpa_backward_options.set_max_total_seq_len_q(s_q); } + fe::DiagonalAlignment_t const& diagonal_alignment = bottom_right_diagonal ? fe::DiagonalAlignment_t::BOTTOM_RIGHT : fe::DiagonalAlignment_t::TOP_LEFT; + sdpa_backward_options.set_diagonal_alignment(diagonal_alignment); if (cudnn_runtime_version >= 90200 && window_size_left != -1) { - sdpa_backward_options.set_sliding_window_length(window_size_left + 1); + sdpa_backward_options.set_left_bound(window_size_left + 1); + } + if (cudnn_runtime_version >= 90600 && window_size_right != -1) { + sdpa_backward_options.set_right_bound(window_size_right); } if (cudnn_runtime_version >= 90000) { @@ -889,7 +903,7 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked( size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, size_t num_tokens, bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, - int64_t window_size_right, const Tensor *input_QKV, const Tensor *input_Bias, Tensor *output_O, + int64_t window_size_right, bool bottom_right_diagonal, const Tensor *input_QKV, const Tensor *input_Bias, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; @@ -989,7 +1003,7 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked( fused_attn_arbitrary_seqlen_fwd_impl( batch, num_attn_heads, num_attn_heads, max_seqlen, max_seqlen, head_dim, head_dim, max_batch_size, max_tokens, max_tokens, bias_b, bias_h, is_training, attn_scale, p_dropout, - qkv_layout, bias_type, mask_type, window_size_left, window_size_right, devPtrQ, devPtrK, + qkv_layout, bias_type, mask_type, window_size_left, window_size_right, bottom_right_diagonal, devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlens, devPtrCuSeqlens, devPtrSeqOffsets, devPtrSeqOffsets, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); @@ -1012,7 +1026,7 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked( void fused_attn_arbitrary_seqlen_bwd_qkvpacked( size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, size_t num_tokens, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, + NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, bool deterministic, const Tensor *input_QKV, const Tensor *input_O, const Tensor *input_dO, const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQKV, Tensor *output_dBias, const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded, const Tensor *rng_state, @@ -1073,7 +1087,7 @@ void 
fused_attn_arbitrary_seqlen_bwd_qkvpacked( fused_attn_arbitrary_seqlen_bwd_impl( batch, num_attn_heads, num_attn_heads, max_seqlen, max_seqlen, head_dim, head_dim, max_batch_size, max_tokens, max_tokens, bias_b, bias_h, attn_scale, p_dropout, qkv_layout, - bias_type, mask_type, window_size_left, window_size_right, deterministic, devPtrQ, devPtrK, + bias_type, mask_type, window_size_left, window_size_right, bottom_right_diagonal, deterministic, devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlens, devPtrCuSeqlens, devPtrSeqOffsets, devPtrSeqOffsets, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, @@ -1098,7 +1112,7 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked( size_t max_seqlen_kv, size_t head_dim, size_t num_tokens_q, size_t num_tokens_kv, bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, - int64_t window_size_right, const Tensor *input_Q, const Tensor *input_KV, + int64_t window_size_right, bool bottom_right_diagonal, const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_Bias, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace, @@ -1205,7 +1219,7 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked( fused_attn_arbitrary_seqlen_fwd_impl( batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, head_dim, max_batch_size, max_tokens_q, max_tokens_kv, bias_b, bias_h, is_training, attn_scale, - p_dropout, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, devPtrQ, + p_dropout, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, bottom_right_diagonal, devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); @@ -1229,7 +1243,7 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, size_t num_tokens_q, size_t num_tokens_kv, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, + NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, bool deterministic, const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_O, const Tensor *input_dO, const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQ, Tensor *output_dKV, Tensor *output_dBias, const Tensor *cu_seqlens_q, @@ -1296,7 +1310,7 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked( fused_attn_arbitrary_seqlen_bwd_impl( batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, head_dim, max_batch_size, max_tokens_q, max_tokens_kv, bias_b, bias_h, attn_scale, p_dropout, - qkv_layout, bias_type, mask_type, window_size_left, window_size_right, deterministic, devPtrQ, + qkv_layout, bias_type, mask_type, window_size_left, window_size_right, bottom_right_diagonal, deterministic, devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias, devPtrdQ, devPtrdK, devPtrdV, 
devPtrdO, devPtrdBias, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), @@ -1322,7 +1336,7 @@ void fused_attn_arbitrary_seqlen_fwd( size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_tokens_q, size_t num_tokens_kv, bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - int64_t window_size_left, int64_t window_size_right, const Tensor *input_Q, + int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, @@ -1419,7 +1433,7 @@ void fused_attn_arbitrary_seqlen_fwd( fused_attn_arbitrary_seqlen_fwd_impl( batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, max_batch_size, max_tokens_q, max_tokens_kv, bias_b, bias_h, is_training, attn_scale, - p_dropout, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, devPtrQ, + p_dropout, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, bottom_right_diagonal, devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); @@ -1444,7 +1458,7 @@ void fused_attn_arbitrary_seqlen_bwd( size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_tokens_q, size_t num_tokens_kv, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, - int64_t window_size_right, bool deterministic, const Tensor *input_Q, const Tensor *input_K, + int64_t window_size_right, bool bottom_right_diagonal, bool deterministic, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_O, const Tensor *input_dO, const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQ, Tensor *output_dK, Tensor *output_dV, Tensor *output_dBias, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, @@ -1498,7 +1512,7 @@ void fused_attn_arbitrary_seqlen_bwd( fused_attn_arbitrary_seqlen_bwd_impl( batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, max_batch_size, max_tokens_q, max_tokens_kv, bias_b, bias_h, attn_scale, p_dropout, - qkv_layout, bias_type, mask_type, window_size_left, window_size_right, deterministic, devPtrQ, + qkv_layout, bias_type, mask_type, window_size_left, window_size_right, bottom_right_diagonal, deterministic, devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h index 3a1216f891..0f201e1165 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h +++ 
b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h @@ -22,14 +22,14 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked( size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, size_t num_tokens, bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, - int64_t window_size_right, const Tensor *input_QKV, const Tensor *input_Bias, Tensor *output_O, + int64_t window_size_right, bool bottom_right_diagonal, const Tensor *input_QKV, const Tensor *input_Bias, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); void fused_attn_arbitrary_seqlen_bwd_qkvpacked( size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, size_t num_tokens, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, + NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, bool deterministic, const Tensor *input_QKV, const Tensor *input_O, const Tensor *input_dO, const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQKV, Tensor *output_dBias, const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded, const Tensor *rng_state, @@ -40,7 +40,7 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked( size_t max_seqlen_kv, size_t head_dim, size_t num_tokens_q, size_t num_tokens_kv, bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, - int64_t window_size_right, const Tensor *input_Q, const Tensor *input_KV, + int64_t window_size_right, bool bottom_right_diagonal, const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_Bias, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace, @@ -50,7 +50,7 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, size_t num_tokens_q, size_t num_tokens_kv, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, + NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, bool deterministic, const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_O, const Tensor *input_dO, const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQ, Tensor *output_dKV, Tensor *output_dBias, const Tensor *cu_seqlens_q, @@ -63,7 +63,7 @@ void fused_attn_arbitrary_seqlen_fwd( size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_tokens_q, size_t num_tokens_kv, bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - int64_t window_size_left, int64_t window_size_right, const Tensor *input_Q, + int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias, Tensor *output_O, NVTETensorPack 
*Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, @@ -74,7 +74,7 @@ void fused_attn_arbitrary_seqlen_bwd( size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_tokens_q, size_t num_tokens_kv, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, - int64_t window_size_right, bool deterministic, const Tensor *input_Q, const Tensor *input_K, + int64_t window_size_right, bool bottom_right_diagonal, bool deterministic, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_O, const Tensor *input_dO, const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQ, Tensor *output_dK, Tensor *output_dV, Tensor *output_dBias, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index f8fe458219..9cead42ff1 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -1652,7 +1652,7 @@ void fused_attn_fp8_bwd_impl( void fused_attn_fp8_fwd_impl_v1( int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d, bool is_training, float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, void* devPtrQ, void* devPtrK, void* devPtrV, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, void* devPtrQ, void* devPtrK, void* devPtrV, void* devPtrM, void* devPtrZInv, void* devPtrO, void* devPtrDescaleQ, void* devPtrDescaleK, void* devPtrDescaleV, void* devPtrDescaleS, void* devPtrScaleS, void* devPtrScaleO, void* devPtrAmaxO, void* devPtrAmaxS, void* devPtrcuSeqlensQ, void* devPtrcuSeqlensKV, @@ -1688,9 +1688,7 @@ void fused_attn_fp8_fwd_impl_v1( dropout_probability, layout, bias_type, - mask_type, - 0, - 0, + mask_type, window_size_left, window_size_right, bottom_right_diagonal, true, fwd_tensor_type, fwd_tensor_type}; @@ -1952,7 +1950,7 @@ void fused_attn_fp8_fwd_impl_v1( void fused_attn_fp8_bwd_impl_v1( int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d, float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, void* devPtrQ, void* devPtrK, void* devPtrV, void* devPtrM, + NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, void* devPtrQ, void* devPtrK, void* devPtrV, void* devPtrM, void* devPtrZInv, void* devPtrO, void* devPtrdO, void* devPtrdQ, void* devPtrdK, void* devPtrdV, void* devPtrDescaleQ, void* devPtrDescaleK, void* devPtrDescaleV, void* devPtrDescaleO, void* devPtrDescaledO, void* devPtrDescaleS, void* devPtrDescaledP, void* devPtrScaleS, @@ -1992,9 +1990,7 @@ void fused_attn_fp8_bwd_impl_v1( dropout_probability, layout, bias_type, - mask_type, - 0, - 0, + mask_type, window_size_left, window_size_right, bottom_right_diagonal, false, fwd_tensor_type, bwd_tensor_type}; @@ -2352,7 +2348,7 @@ void fused_attn_fp8_bwd_impl_v1( void fused_attn_fp8_fwd_qkvpacked(size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, - 
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, const Tensor* input_QKV, Tensor* input_output_S, Tensor* output_O, NVTETensorPack* Aux_CTX_Tensors, const Tensor* cu_seqlens, const Tensor* rng_state, Tensor* workspace, cudaStream_t stream, @@ -2422,7 +2418,7 @@ void fused_attn_fp8_fwd_qkvpacked(size_t batch, size_t num_attn_heads, size_t ma if ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) || (qkv_format == NVTE_QKV_Format::NVTE_SBHD)) { fused_attn::fused_attn_fp8_fwd_impl_v1( batch, num_attn_heads, num_attn_heads, max_seqlen, max_seqlen, head_dim, is_training, - attn_scale, p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrM, + attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, bottom_right_diagonal, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, devPtrO, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleS, devPtrScaleS, devPtrScaleO, devPtrAmaxO, devPtrAmaxS, devPtrcuSeqlens, devPtrcuSeqlens, devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, @@ -2453,7 +2449,7 @@ void fused_attn_fp8_fwd_qkvpacked(size_t batch, size_t num_attn_heads, size_t ma // fused attention BWD FP8 with packed QKV void fused_attn_fp8_bwd_qkvpacked( size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, float attn_scale, - float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, const Tensor* input_QKV, const Tensor* input_O, const Tensor* input_dO, const Tensor* input_M, const Tensor* input_ZInv, const Tensor* input_S, Tensor* input_output_dP, const Tensor* output_dQKV, const Tensor* cu_seqlens, const Tensor* rng_state, Tensor* workspace, @@ -2514,7 +2510,7 @@ void fused_attn_fp8_bwd_qkvpacked( if ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) || (qkv_format == NVTE_QKV_Format::NVTE_SBHD)) { fused_attn::fused_attn_fp8_bwd_impl_v1( batch, num_attn_heads, num_attn_heads, max_seqlen, max_seqlen, head_dim, attn_scale, - p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, + p_dropout, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, bottom_right_diagonal, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, devPtrO, devPtrdO, devPtrdQ, devPtrdK, devPtrdV, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleO, devPtrDescaledO, devPtrDescaleS, devPtrDescaledP, devPtrScaleS, devPtrScaledP, devPtrScaledQ, devPtrScaledK, devPtrScaledV, devPtrAmaxdP, @@ -2551,7 +2547,7 @@ void fused_attn_fp8_fwd_kvpacked(size_t batch, size_t num_attn_heads, size_t num size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, const Tensor* input_Q, + NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, const Tensor* input_Q, const Tensor* input_KV, Tensor* input_output_S, Tensor* output_O, NVTETensorPack* Aux_CTX_Tensors, const Tensor* cu_seqlens_q, const Tensor* cu_seqlens_kv, const Tensor* rng_state, @@ -2623,7 +2619,7 @@ void fused_attn_fp8_fwd_kvpacked(size_t batch, size_t num_attn_heads, size_t num if 
((qkv_format == NVTE_QKV_Format::NVTE_BSHD) || (qkv_format == NVTE_QKV_Format::NVTE_SBHD)) { fused_attn::fused_attn_fp8_fwd_impl_v1( batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, is_training, - attn_scale, p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrM, + attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, bottom_right_diagonal, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, devPtrO, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleS, devPtrScaleS, devPtrScaleO, devPtrAmaxO, devPtrAmaxS, devPtrcuSeqlensQ, devPtrcuSeqlensKV, devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, @@ -2656,7 +2652,7 @@ void fused_attn_fp8_fwd_kvpacked(size_t batch, size_t num_attn_heads, size_t num void fused_attn_fp8_bwd_kvpacked( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, float attn_scale, float p_dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, const Tensor* input_Q, const Tensor* input_KV, const Tensor* input_O, const Tensor* input_dO, const Tensor* input_M, const Tensor* input_ZInv, const Tensor* input_S, Tensor* input_output_dP, const Tensor* output_dQ, const Tensor* output_dKV, const Tensor* cu_seqlens_q, @@ -2720,7 +2716,7 @@ void fused_attn_fp8_bwd_kvpacked( if ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) || (qkv_format == NVTE_QKV_Format::NVTE_SBHD)) { fused_attn::fused_attn_fp8_bwd_impl_v1( batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, attn_scale, - p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, + p_dropout, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, bottom_right_diagonal, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, devPtrO, devPtrdO, devPtrdQ, devPtrdK, devPtrdV, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleO, devPtrDescaledO, devPtrDescaleS, devPtrDescaledP, devPtrScaleS, devPtrScaledP, devPtrScaledQ, devPtrScaledK, devPtrScaledV, devPtrAmaxdP, @@ -2757,7 +2753,7 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, const Tensor* input_Q, const Tensor* input_K, + NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, const Tensor* input_Q, const Tensor* input_K, const Tensor* input_V, Tensor* input_output_S, Tensor* output_O, NVTETensorPack* Aux_CTX_Tensors, const Tensor* cu_seqlens_q, const Tensor* cu_seqlens_kv, const Tensor* rng_state, Tensor* workspace, @@ -2821,7 +2817,7 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou if ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) || (qkv_format == NVTE_QKV_Format::NVTE_SBHD)) { fused_attn::fused_attn_fp8_fwd_impl_v1( batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, is_training, - attn_scale, p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrM, + attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, 
bottom_right_diagonal, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, devPtrO, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleS, devPtrScaleS, devPtrScaleO, devPtrAmaxO, devPtrAmaxS, devPtrcuSeqlensQ, devPtrcuSeqlensKV, devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, @@ -2854,7 +2850,7 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, const Tensor* input_Q, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, const Tensor* input_Q, const Tensor* input_K, const Tensor* input_V, const Tensor* input_O, const Tensor* input_dO, const Tensor* input_M, const Tensor* input_ZInv, const Tensor* input_S, Tensor* input_output_dP, const Tensor* output_dQ, @@ -2911,7 +2907,7 @@ void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou if ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) || (qkv_format == NVTE_QKV_Format::NVTE_SBHD)) { fused_attn::fused_attn_fp8_bwd_impl_v1( batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, attn_scale, - p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, + p_dropout, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, bottom_right_diagonal, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, devPtrO, devPtrdO, devPtrdQ, devPtrdK, devPtrdV, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleO, devPtrDescaledO, devPtrDescaleS, devPtrDescaledP, devPtrScaleS, devPtrScaledP, devPtrScaledQ, devPtrScaledK, devPtrScaledV, devPtrAmaxdP, diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.h b/transformer_engine/common/fused_attn/fused_attn_fp8.h index 55830d3cda..4d896a7384 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.h +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.h @@ -17,7 +17,7 @@ namespace transformer_engine { void fused_attn_fp8_fwd_qkvpacked(size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, const Tensor *input_QKV, Tensor *input_output_S, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, @@ -26,7 +26,7 @@ void fused_attn_fp8_fwd_qkvpacked(size_t batch, size_t num_attn_heads, size_t ma // fused attention BWD FP8 with packed QKV void fused_attn_fp8_bwd_qkvpacked( size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, float attn_scale, - float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, const Tensor *input_QKV, const Tensor *input_O, const Tensor *input_dO, const Tensor *input_M, const Tensor *input_ZInv, const Tensor *input_S, Tensor *input_output_dP, const Tensor 
*output_dQKV, const Tensor *cu_seqlens, const Tensor *rng_state, Tensor *workspace, @@ -37,7 +37,7 @@ void fused_attn_fp8_fwd_kvpacked(size_t batch, size_t num_attn_heads, size_t num size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, const Tensor *input_Q, + NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, const Tensor *input_Q, const Tensor *input_KV, Tensor *input_output_S, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *rng_state, @@ -47,7 +47,7 @@ void fused_attn_fp8_fwd_kvpacked(size_t batch, size_t num_attn_heads, size_t num void fused_attn_fp8_bwd_kvpacked( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, float attn_scale, float p_dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_O, const Tensor *input_dO, const Tensor *input_M, const Tensor *input_ZInv, const Tensor *input_S, Tensor *input_output_dP, const Tensor *output_dQ, const Tensor *output_dKV, const Tensor *cu_seqlens_q, @@ -59,7 +59,7 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, const Tensor *input_Q, const Tensor *input_K, + NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, Tensor *input_output_S, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *rng_state, Tensor *workspace, @@ -69,7 +69,7 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, const Tensor *input_Q, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_O, const Tensor *input_dO, const Tensor *input_M, const Tensor *input_ZInv, const Tensor *input_S, Tensor *input_output_dP, const Tensor *output_dQ, diff --git a/transformer_engine/common/fused_attn/utils.h b/transformer_engine/common/fused_attn/utils.h index f790d3b567..8dfa97b63e 100644 --- a/transformer_engine/common/fused_attn/utils.h +++ b/transformer_engine/common/fused_attn/utils.h @@ -103,17 +103,18 @@ struct FADescriptor_v1 { NVTE_Mask_Type mask_type; std::int64_t window_size_left; std::int64_t window_size_right; + bool bottom_right_diagonal; bool deterministic; cudnn_frontend::DataType_t fwd_tensor_type; cudnn_frontend::DataType_t bwd_tensor_type; bool operator<(const FADescriptor_v1 &rhs) const { return std::tie(b, 
h, hg, s_q, s_kv, d_qk, d_v, bias_b, bias_h, attnScale, isTraining, - dropoutProbability, layout, mask_type, window_size_left, window_size_right, + dropoutProbability, layout, mask_type, window_size_left, window_size_right, bottom_right_diagonal, deterministic, bias_type, fwd_tensor_type, bwd_tensor_type) < std::tie(rhs.b, rhs.h, rhs.hg, rhs.s_q, rhs.s_kv, rhs.d_qk, rhs.d_v, rhs.bias_b, rhs.bias_h, rhs.attnScale, rhs.isTraining, rhs.dropoutProbability, rhs.layout, - rhs.mask_type, rhs.window_size_left, rhs.window_size_right, rhs.deterministic, + rhs.mask_type, rhs.window_size_left, rhs.window_size_right, rhs.bottom_right_diagonal, rhs.deterministic, rhs.bias_type, rhs.fwd_tensor_type, rhs.bwd_tensor_type); } }; diff --git a/transformer_engine/common/include/transformer_engine/fused_attn.h b/transformer_engine/common/include/transformer_engine/fused_attn.h index ae08f2a4aa..0e560e3b9d 100644 --- a/transformer_engine/common/include/transformer_engine/fused_attn.h +++ b/transformer_engine/common/include/transformer_engine/fused_attn.h @@ -205,6 +205,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( * \param[in] attn_mask_type Attention mask type. * \param[in] window_size_left Sliding window size (the left half). * \param[in] window_size_right Sliding window size (the right half). + * \param[in] bottom_right_diagonal Whether to align sliding window and ALiBi diagonal to the bottom right corner of the softmax matrix. * \param[in] workspace Workspace tensor. * \param[in] stream CUDA stream used for this operation. */ @@ -214,7 +215,7 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, const NVTETensor rng_state, size_t max_seqlen, bool is_training, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - int64_t window_size_left, int64_t window_size_right, + int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, NVTETensor workspace, cudaStream_t stream); /*! \brief Compute the backward of the dot product attention with packed QKV input. @@ -259,6 +260,7 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, * \param[in] attn_mask_type Attention mask type. * \param[in] window_size_left Sliding window size (the left half). * \param[in] window_size_right Sliding window size (the right half). + * \param[in] bottom_right_diagonal Whether to align sliding window and ALiBi diagonal to the bottom right corner of the softmax matrix. * \param[in] deterministic Whether to execute with deterministic behaviours. * \param[in] workspace Workspace tensor. * \param[in] stream CUDA stream used for this operation. @@ -270,7 +272,7 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con const NVTETensor cu_seqlens_padded, size_t max_seqlen, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - int64_t window_size_left, int64_t window_size_right, + int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, bool deterministic, NVTETensor workspace, cudaStream_t stream); /*! \brief Compute dot product attention with packed KV input. @@ -325,6 +327,7 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con * \param[in] attn_mask_type Attention mask type. * \param[in] window_size_left Sliding window size (the left half). * \param[in] window_size_right Sliding window size (the right half). 
+ * \param[in] bottom_right_diagonal Whether to align sliding window and ALiBi diagonal to the bottom right corner of the softmax matrix. * \param[in] deterministic Whether to execute with deterministic behaviours. * \param[in] workspace Workspace tensor. * \param[in] stream CUDA stream used for this operation. @@ -337,7 +340,7 @@ void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - int64_t window_size_left, int64_t window_size_right, + int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, NVTETensor workspace, cudaStream_t stream); /*! \brief Compute the backward of the dot product attention with packed KV input. @@ -388,6 +391,7 @@ void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const * \param[in] attn_mask_type Attention mask type. * \param[in] window_size_left Sliding window size (the left half). * \param[in] window_size_right Sliding window size (the right half). + * \param[in] bottom_right_diagonal Whether to align sliding window and ALiBi diagonal to the bottom right corner of the softmax matrix. * \param[in] deterministic Whether to execute with deterministic behaviours. * \param[in] workspace Workspace tensor. * \param[in] stream CUDA stream used for this operation. @@ -399,7 +403,7 @@ void nvte_fused_attn_bwd_kvpacked( const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - int64_t window_size_left, int64_t window_size_right, bool deterministic, NVTETensor workspace, + int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, bool deterministic, NVTETensor workspace, cudaStream_t stream); /*! \brief Compute dot product attention with separate Q, K and V. @@ -458,6 +462,7 @@ void nvte_fused_attn_bwd_kvpacked( * \param[in] attn_mask_type Attention mask type. * \param[in] window_size_left Sliding window size (the left half). * \param[in] window_size_right Sliding window size (the right half). + * \param[in] bottom_right_diagonal Whether to align sliding window and ALiBi diagonal to the bottom right corner of the softmax matrix. * \param[in] workspace Workspace tensor. * \param[in] stream CUDA stream used for this operation. */ @@ -469,7 +474,7 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - int64_t window_size_left, int64_t window_size_right, NVTETensor workspace, + int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, NVTETensor workspace, cudaStream_t stream); /*! \brief Compute the backward of the dot product attention with separate Q, K and V. @@ -525,6 +530,7 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso * \param[in] attn_mask_type Attention mask type. * \param[in] window_size_left Sliding window size (the left half). * \param[in] window_size_right Sliding window size (the right half). + * \param[in] bottom_right_diagonal Whether to align sliding window and ALiBi diagonal to the bottom right corner of the softmax matrix. 
* \param[in] deterministic Whether to execute with deterministic behaviours. * \param[in] workspace Workspace tensor. * \param[in] stream CUDA stream used for this operation. @@ -538,7 +544,7 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso size_t max_seqlen_kv, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, int64_t window_size_left, - int64_t window_size_right, bool deterministic, NVTETensor workspace, + int64_t window_size_right, bool bottom_right_diagonal, bool deterministic, NVTETensor workspace, cudaStream_t stream); #ifdef __cplusplus diff --git a/transformer_engine/jax/csrc/extensions/attention.cpp b/transformer_engine/jax/csrc/extensions/attention.cpp index 4bde10fc46..cfb3d47136 100644 --- a/transformer_engine/jax/csrc/extensions/attention.cpp +++ b/transformer_engine/jax/csrc/extensions/attention.cpp @@ -159,14 +159,14 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes( &aux_output_tensors, q_cu_seqlens_tensor.data(), ragged_offset_tensor.data(), dummy_rng_state_tensor.data(), q_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, window_size_left, - window_size_right, query_workspace_tensor.data(), nullptr); + window_size_right, True, query_workspace_tensor.data(), nullptr); } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { nvte_fused_attn_fwd_kvpacked( q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), ragged_offset_tensor.data(), ragged_offset_tensor.data(), dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout, - bias_type, mask_type, window_size_left, window_size_right, query_workspace_tensor.data(), + bias_type, mask_type, window_size_left, window_size_right, True, query_workspace_tensor.data(), nullptr); } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) { nvte_fused_attn_fwd( @@ -175,7 +175,7 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes( kv_cu_seqlens_tensor.data(), ragged_offset_tensor.data(), ragged_offset_tensor.data(), dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, window_size_left, - window_size_right, query_workspace_tensor.data(), nullptr); + window_size_right, True, query_workspace_tensor.data(), nullptr); } else { NVTE_ERROR("Unsupported QKVLayout."); } @@ -260,7 +260,7 @@ static void FusedAttnForwardImpl( q_seq_offsets_tensor.data(), rng_state_tensor.data(), q_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, window_size_left, - window_size_right, workspace_tensor.data(), stream); + window_size_right, True, workspace_tensor.data(), stream); } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, head_dim}; auto kv_shape = std::vector{input_batch * kv_max_seqlen, 2, num_gqa_groups, head_dim}; @@ -271,7 +271,7 @@ static void FusedAttnForwardImpl( &aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout, - bias_type, mask_type, window_size_left, window_size_right, workspace_tensor.data(), 
stream); + bias_type, mask_type, window_size_left, window_size_right, True, workspace_tensor.data(), stream); } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) { auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, head_dim}; auto k_shape = std::vector{input_batch * kv_max_seqlen, num_gqa_groups, head_dim}; @@ -285,7 +285,7 @@ static void FusedAttnForwardImpl( q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, - window_size_left, window_size_right, workspace_tensor.data(), stream); + window_size_left, window_size_right, True, workspace_tensor.data(), stream); } else { NVTE_ERROR("Unsupported qkv_layout."); } @@ -463,7 +463,7 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( &aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(), q_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(), q_max_seqlen, scaling_factor, dropout_probability, qkv_layout, - bias_type, mask_type, window_size_left, window_size_right, + bias_type, mask_type, window_size_left, window_size_right, True, deterministic, query_workspace_tensor.data(), nullptr); } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { nvte_fused_attn_bwd_kvpacked( @@ -474,7 +474,7 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, - window_size_left, window_size_right, deterministic, query_workspace_tensor.data(), + window_size_left, window_size_right, True, deterministic, query_workspace_tensor.data(), nullptr); } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) { nvte_fused_attn_bwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(), @@ -486,7 +486,7 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( kv_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, - window_size_left, window_size_right, deterministic, + window_size_left, window_size_right, True, deterministic, query_workspace_tensor.data(), nullptr); } else { NVTE_ERROR("Unsupported qkv_layout."); @@ -543,7 +543,7 @@ static void FusedAttnBackwardImpl( &aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(), q_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), q_max_seqlen, scaling_factor, dropout_probability, qkv_layout, - bias_type, mask_type, window_size_left, window_size_right, + bias_type, mask_type, window_size_left, window_size_right, True, deterministic, workspace_tensor.data(), stream); } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, head_dim}; @@ -563,7 +563,7 @@ static void FusedAttnBackwardImpl( &aux_input_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(), q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor, - dropout_probability, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, + dropout_probability, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, True, deterministic, workspace_tensor.data(), stream); } else if 
(layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) { auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, head_dim}; @@ -589,7 +589,7 @@ static void FusedAttnBackwardImpl( kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, window_size_left, - window_size_right, deterministic, workspace_tensor.data(), stream); + window_size_right, True, deterministic, workspace_tensor.data(), stream); } else { NVTE_ERROR("Unsupported qkv_layout."); } diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index ad036386e3..576b4052ca 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -256,6 +256,9 @@ class AttentionParams: Attention bias shape, {`1hss`, `b1ss`, `bhss`}. core_attention_bias_requires_grad: bool, default = `True` Whether attention bias requires gradient. + bottom_right_diagonal: bool, default = `True` + Whether to align sliding window and ALiBi diagonal to the bottom right corner + of the softmax matrix. pad_between_seqs: bool, default = `False` Whether there is padding between sequences in a batch. This only applies to `qkv_format=thd`. @@ -289,6 +292,7 @@ class AttentionParams: core_attention_bias_type: str = "no_bias" core_attention_bias_shape: str = "1hss" core_attention_bias_requires_grad: bool = True + bottom_right_diagonal: bool = True pad_between_seqs: bool = False attention_dropout: float = 0.0 context_parallel: bool = False @@ -303,7 +307,7 @@ class AttentionParams: "_alibi_slopes": None, "_max_seqlen_q": None, "_max_seqlen_kv": None, - "_bottom_right_alignment": True, + "_bottom_right_diagonal": True, "_alibi_bias": None, "_alibi_slopes_require_update": False, "_alibi_bias_require_update": False, @@ -358,6 +362,7 @@ def get_attention_backend( core_attention_bias_type = attention_params.core_attention_bias_type core_attention_bias_shape = attention_params.core_attention_bias_shape core_attention_bias_requires_grad = attention_params.core_attention_bias_requires_grad + bottom_right_diagonal = attention_params.bottom_right_diagonal pad_between_seqs = attention_params.pad_between_seqs attention_dropout = attention_params.attention_dropout context_parallel = attention_params.context_parallel @@ -679,54 +684,27 @@ def get_attention_backend( _use_flash_attn_3 = False # Filter: Sliding window attention - # backend | window_size | diagonal alignment + # backend | window_size (left, right) | diagonal alignment # --------------------------------------------------------------------------------- - # FlashAttention | (-1, -1) or (>=0, >=0) | bottom right - # FusedAttention | (-1, 0) or (>=0, 0) | top left - # UnfusedDotProductAttention | (-1, -1) or (>=0, >=0) | both; - # | | converts window_size to an 'arbitrary' mask + # FlashAttention | (-1 or >=0, -1 or >=0) | bottom right + # FusedAttention | (-1 or >=0, -1 or >=0) | top left and bottom right + # UnfusedDotProductAttention | (-1 or >=0, -1 or >=0) | top left and bottom right; + # | | converts window_size to an 'arbitrary' mask if window_size is None: window_size = check_set_window_size(attn_mask_type, window_size) - else: - if use_fused_attention and (window_size[0] != -1 or window_size[1] not in [-1, 0]): - # if fp8 and (fp8_meta["recipe"].fp8_dpa or fp8_meta["recipe"].fp8_mha): - # logger.debug( - # "Disabling FusedAttention as it does not support sliding window attention" - # " for FP8" - # ) - # 
use_fused_attention = False - # elif window_size[1] != 0 or attention_dropout != 0.0 or qkv_format == "thd": - if attention_dropout != 0.0: - logger.debug( - "Disabling FusedAttention as it does not support sliding window attention " - "with dropout" - ) - use_fused_attention = False - # elif max_seqlen_q != max_seqlen_kv and attn_mask_type in [ - # "no_mask", - # "padding", - # "causal_bottom_right", - # "padding_causal_bottom_right", - # ]: - # logger.debug( - # "Disabling FusedAttention as it does not support sliding window attention " - # "with attn_mask_type = %s for cross-attention", - # attn_mask_type, - # ) - # use_fused_attention = False - # elif "padding" in attn_mask_type: - # logger.debug( - # "Disabling FusedAttention as it does not support sliding window attention " - # "with attn_mask_type = %s", - # attn_mask_type, - # ) - # use_fused_attention = False - if use_flash_attention and (window_size[0] != -1 or window_size[1] not in [-1, 0]): - if _use_flash_attn_3: - logger.debug( - "Disabling FlashAttention 3 as it does not support sliding window attention" - ) + if use_fused_attention and (window_size[0] != -1 or window_size[1] not in [-1, 0]): + if attention_dropout != 0.0: + logger.debug( + "Disabling FusedAttention as it does not support sliding window attention " + "with dropout" + ) + use_fused_attention = False + if use_flash_attention and (window_size[0] != -1 or window_size[1] not in [-1, 0]): + if _use_flash_attn_3: + if not bottom_right_diagonal and max_seqlen_q != max_seqlen_kv: + logger.debug("Disabling FlashAttention 3 as it only supports sliding window with bottom right diagonal alignment for cross-attention") _use_flash_attn_3 = False + if not _use_flash_attn_3: if not _flash_attn_is_installed: _flash_attn_version_required = PkgVersion("2.3") elif not _flash_attn_2_3_plus: @@ -734,6 +712,9 @@ def get_attention_backend( "Disabling FlashAttention as sliding window attention requires flash-attn 2.3+" ) use_flash_attention = False + elif not bottom_right_diagonal and max_seqlen_q != max_seqlen_kv: + logger.debug("Disabling FlashAttention as it only supports sliding window with bottom right diagonal alignment for cross-attention") + use_flash_attention = False # Filter: Attention bias # backend | bias types | ALiBi diagonal alignment @@ -753,6 +734,9 @@ def get_attention_backend( elif not _flash_attn_2_4_plus: logger.debug("Disabling FlashAttention as ALiBi requires flash-attn 2.4+") use_flash_attention = False + elif not bottom_right_diagonal and max_seqlen_q != max_seqlen_kv: + logger.debug("Disabling FlashAttention as it only supports ALiBi with bottom right diagonal alignment for cross-attention") + use_flash_attention = False if use_flash_attention and ( core_attention_bias_type not in ["no_bias", "alibi"] @@ -1089,7 +1073,7 @@ def get_alibi( actual_seqlens_kv: Optional[torch.Tensor] = None, alibi_slopes: Optional[torch.Tensor] = None, bias_dtype: Optional[torch.dtype] = None, - bottom_right_alignment: bool = True, + bottom_right_diagonal: bool = True, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Parameters @@ -1108,7 +1092,7 @@ def get_alibi( Custom ALiBi slopes, FP32, CUDA tensor, in shape [num_heads] or [batch_size, num_heads]. bias_dtype: Optional[torch.dtype], default = `None` Dtype of the generated ALiBi bias. If None, use torch.float32. - bottom_right_alignment: bool, default = `True` + bottom_right_diagonal: bool, default = `True` Whether to align the diagonal of the ALiBi bias to the bottom right corner of the matrix (`True`) or top left (`False`). 
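As a rough illustration of what the diagonal-alignment flag controls in get_alibi, the sketch below builds an ALiBi-style bias for both alignments. It is not the library's implementation: alibi_bias_sketch and its arguments are hypothetical names, and the real get_alibi additionally handles per-sequence lengths, custom slope shapes, and caching.

    import torch

    def alibi_bias_sketch(slopes, max_seqlen_q, max_seqlen_kv, bottom_right_diagonal=True):
        """slopes: [num_heads] FP32 tensor; returns a [num_heads, max_seqlen_q, max_seqlen_kv] bias."""
        i = torch.arange(max_seqlen_q).view(max_seqlen_q, 1)    # query positions
        j = torch.arange(max_seqlen_kv).view(1, max_seqlen_kv)  # key positions
        rel = j - i                                             # zero on the top-left (i == j) diagonal
        if bottom_right_diagonal:
            rel = rel - (max_seqlen_kv - max_seqlen_q)          # move the zero line to the bottom-right diagonal
        return slopes.view(-1, 1, 1) * rel.abs().float().neg()  # penalty grows with distance from the diagonal

    slopes = torch.tensor([0.5, 0.25])
    print(alibi_bias_sketch(slopes, 4, 6, bottom_right_diagonal=True)[0])
    print(alibi_bias_sketch(slopes, 4, 6, bottom_right_diagonal=False)[0])

With bottom_right_diagonal=True the zero line ends in the bottom-right corner (the alignment FlashAttention uses); with False it starts in the top-left corner, which is what the cached bias previously assumed for 'causal' and 'padding_causal' masks.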
@@ -1157,12 +1141,12 @@ def get_alibi( 1, 1, 1, max_seqlen_kv ) if actual_seqlens_q is None and actual_seqlens_kv is None: - if bottom_right_alignment: + if bottom_right_diagonal: bias = bias + max_seqlen_kv - max_seqlen_q elif actual_seqlens_q is not None and actual_seqlens_kv is not None: batch_size = actual_seqlens_q.shape[0] bias = bias.expand(batch_size, 1, max_seqlen_q, max_seqlen_kv) - if bottom_right_alignment: + if bottom_right_diagonal: bias = bias + (actual_seqlens_kv - actual_seqlens_q).view(batch_size, 1, 1, 1) else: assert ( @@ -1171,7 +1155,7 @@ def get_alibi( bias = bias.abs().mul(-1) bias = bias * _alibi_cache["_alibi_slopes"].view(slopes_shape) _alibi_cache["_max_seqlen_q"], _alibi_cache["_max_seqlen_kv"] = max_seqlen_q, max_seqlen_kv - _alibi_cache["_bottom_right_alignment"] = bottom_right_alignment + _alibi_cache["_bottom_right_diagonal"] = bottom_right_diagonal bias_dtype = torch.float32 if bias_dtype is None else bias_dtype _alibi_cache["_alibi_bias"] = bias.contiguous().to(dtype=bias_dtype, device="cuda") _alibi_cache["_alibi_bias_require_update"] = False @@ -4735,6 +4719,7 @@ def forward( core_attention_bias_type: str = "no_bias", core_attention_bias: Optional[torch.Tensor] = None, alibi_slopes: Optional[torch.Tensor] = None, + bottom_right_diagonal: Optional[bool] = None, ) -> torch.Tensor: """Unfused attention fprop""" assert ( @@ -4874,7 +4859,7 @@ def forward( actual_seqlens_q=actual_seqlens_q if "padding" in attn_mask_type else None, actual_seqlens_kv=actual_seqlens_kv if "padding" in attn_mask_type else None, alibi_slopes=alibi_slopes, - bottom_right_alignment=attn_mask_type not in ["causal", "padding_causal"], + bottom_right_diagonal=bottom_right_diagonal, ) matmul_result = torch.baddbmm( matmul_result, @@ -6418,6 +6403,7 @@ def forward( attn_bias_type, attn_mask_type, window_size, + bottom_right_diagonal, rng_gen, fused_attention_backend, use_FAv2_bwd, @@ -6506,6 +6492,7 @@ def forward( attn_bias_type, attn_mask_type, window_size, + bottom_right_diagonal, rng_gen, ) if is_output_fp8: @@ -6637,6 +6624,7 @@ def forward( attn_bias_type, attn_mask_type, window_size, + bottom_right_diagonal, rng_gen, ) out_save = out_ret @@ -6682,6 +6670,7 @@ def forward( ctx.attn_bias_type = attn_bias_type ctx.attn_mask_type = attn_mask_type ctx.window_size = window_size + ctx.bottom_right_diagonal = bottom_right_diagonal ctx.fused_attention_backend = ( fused_attention_backend if ctx.fp8 else FusedAttnBackend["F16_arbitrary_seqlen"] ) @@ -6801,6 +6790,7 @@ def backward(ctx, d_out): ctx.attn_bias_type, ctx.attn_mask_type, ctx.window_size, + ctx.bottom_right_diagonal, ctx.deterministic, ) @@ -6926,6 +6916,7 @@ def backward(ctx, d_out): ctx.attn_bias_type, ctx.attn_mask_type, ctx.window_size, + ctx.bottom_right_diagonal, ctx.deterministic, ) @@ -6959,6 +6950,7 @@ def backward(ctx, d_out): None, None, None, + None, ) # else, return (dqkv, dbias) return ( @@ -6989,6 +6981,7 @@ def backward(ctx, d_out): None, None, None, + None, ) @@ -7081,6 +7074,7 @@ def forward( fused_attention_backend: tex.NVTE_Fused_Attn_Backend = tex.NVTE_Fused_Attn_Backend.NVTE_No_Backend, core_attention_bias_type: str = "no_bias", core_attention_bias: Optional[torch.Tensor] = None, + bottom_right_diagonal: Optional[bool] = None, fast_zero_fill: bool = True, cp_group: Optional[Union[dist_group_type, List[dist_group_type]]] = None, cp_global_ranks: List[int] = None, @@ -7248,6 +7242,7 @@ def forward( core_attention_bias_type, attn_mask_type, window_size, + bottom_right_diagonal, None, # rng_gen 
fused_attention_backend, use_FAv2_bwd, @@ -7335,6 +7330,11 @@ class DotProductAttention(TransformerEngineBaseModule): map to `window_size = (-1, 0)` and Transformer Engine distinguishes them based on `attn_mask_type`. Similar to :attr:`attn_mask_type`, `window_size` can be overridden by :attr:`window_size` in `forward` as well. + bottom_right_diagonal: Optional[bool], default = `None` + Align sliding window and ALiBi diagonal to the top left (`False`) + or bottom right (`True`) corner of the softmax matrix in the encoder. + If `None`, it will be set to `False` for `attn_mask_type` = + {'causal', 'padding_causal'} and `True` for other mask types. attention_type: str, default = `self` type of attention, either "`self`" and "`cross`". layer_number: int, default = `None` @@ -7397,6 +7397,7 @@ def __init__( qkv_format: str = "sbhd", attn_mask_type: str = "causal", window_size: Optional[Tuple[int, int]] = None, + bottom_right_diagonal: Optional[bool] = None, sequence_parallel: bool = False, tp_size: int = 1, get_rng_state_tracker: Optional[Callable] = None, @@ -7421,6 +7422,7 @@ def __init__( attn_mask_type = "padding_causal" self.attn_mask_type = attn_mask_type self.window_size = check_set_window_size(attn_mask_type, window_size) + self.bottom_right_diagonal = bottom_right_diagonal if tp_group is None: self.tp_size = tp_size if tp_size == 1: @@ -7638,6 +7640,7 @@ def forward( core_attention_bias_type: str = "no_bias", core_attention_bias: Optional[torch.Tensor] = None, alibi_slopes: Optional[torch.Tensor] = None, + bottom_right_diagonal: Optional[bool] = None, fast_zero_fill: bool = True, inference_params: Optional[InferenceParams] = None, is_first_microbatch: Optional[bool] = None, @@ -7798,6 +7801,11 @@ def forward( ALiBi slopes in FP32 and shape [nheads] or [batch_size, nheads]. It adds a bias of (-alibi_slope * (i + seqlen_k - seqlen_q - j)) to the attention score of query i and key j. + bottom_right_diagonal: Optional[bool], default = `None` + Align sliding window and ALiBi diagonal to the top left (`False`) + or bottom right (`True`) corner of the softmax matrix in the encoder. + If `None`, it will be set to `False` for `attn_mask_type` = + {'causal', 'padding_causal'} and `True` for other mask types. fast_zero_fill: bool, default = `True` Whether to use the fast path to set output tensors to 0 or not. 
inference_params: Optional[InferenceParams], default = `None` @@ -7889,6 +7897,12 @@ def forward( if window_size is None: window_size = self.window_size window_size = check_set_window_size(attn_mask_type, window_size) + if bottom_right_diagonal is None: + bottom_right_diagonal = self.bottom_right_diagonal + if attn_mask_type in {"causal", "padding_causal"}: + bottom_right_diagonal = False + if bottom_right_diagonal is None or attn_mask_type in {"causal_bottom_right", "padding_causal_bottom_right"}: + bottom_right_diagonal = True if self.rng_states_tracker is not None and is_graph_capturing(): assert isinstance( @@ -8060,7 +8074,6 @@ def forward( if self.layer_number == 1: _alibi_cache["_alibi_slopes_require_update"] = True _alibi_cache["_alibi_bias_require_update"] = True - bottom_right_alignment = (attn_mask_type not in ["causal", "padding_causal"],) if core_attention_bias_type == "alibi": assert ( core_attention_bias is None @@ -8069,7 +8082,7 @@ def forward( _alibi_cache["_num_heads"] != query_layer.shape[-2] or _alibi_cache["_max_seqlen_q"] != max_seqlen_q or _alibi_cache["_max_seqlen_kv"] != max_seqlen_kv - or _alibi_cache["_bottom_right_alignment"] != bottom_right_alignment + or _alibi_cache["_bottom_right_diagonal"] != bottom_right_diagonal or _alibi_cache["_alibi_slopes"] is None ): _alibi_cache["_alibi_slopes_require_update"] = True @@ -8125,6 +8138,7 @@ def forward( core_attention_bias_requires_grad=( core_attention_bias.requires_grad if core_attention_bias is not None else False ), + bottom_right_diagonal=bottom_right_diagonal, pad_between_seqs=pad_between_seqs, attention_dropout=self.attention_dropout, context_parallel=context_parallel, @@ -8209,7 +8223,7 @@ def forward( max_seqlen_kv, alibi_slopes=alibi_slopes, bias_dtype=query_layer.dtype, - bottom_right_alignment=attn_mask_type not in ["causal", "padding_causal"], + bottom_right_diagonal=bottom_right_diagonal, ) if checkpoint_core_attention: return self._checkpointed_attention_forward( @@ -8230,6 +8244,7 @@ def forward( fused_attention_backend=fused_attention_backend, core_attention_bias_type=fu_core_attention_bias_type, core_attention_bias=fu_core_attention_bias, + bottom_right_diagonal=bottom_right_diagonal, fast_zero_fill=fast_zero_fill, cp_group=self.cp_group, cp_global_ranks=self.cp_global_ranks, @@ -8255,6 +8270,7 @@ def forward( fused_attention_backend=fused_attention_backend, core_attention_bias_type=fu_core_attention_bias_type, core_attention_bias=fu_core_attention_bias, + bottom_right_diagonal=bottom_right_diagonal, fast_zero_fill=fast_zero_fill, cp_group=self.cp_group, cp_global_ranks=self.cp_global_ranks, @@ -8293,6 +8309,7 @@ def forward( core_attention_bias_type=core_attention_bias_type, core_attention_bias=core_attention_bias, alibi_slopes=alibi_slopes, + bottom_right_diagonal=bottom_right_diagonal, ) return self.unfused_attention( query_layer, @@ -8306,6 +8323,7 @@ def forward( core_attention_bias_type=core_attention_bias_type, core_attention_bias=core_attention_bias, alibi_slopes=alibi_slopes, + bottom_right_diagonal=bottom_right_diagonal, ) raise ValueError("No dot product attention support for the provided inputs!") @@ -8362,6 +8380,11 @@ class MultiheadAttention(torch.nn.Module): map to `window_size = (-1, 0)` and Transformer Engine distinguishes them based on `attn_mask_type`. Similar to :attr:`attn_mask_type`, `window_size` can be overridden by :attr:`window_size` in `forward` as well. 
+ bottom_right_diagonal: Optional[bool], default = `None` + Align sliding window and ALiBi diagonal to the top left (`False`) + or bottom right (`True`) corner of the softmax matrix in the encoder. + If `None`, it will be set to `False` for `attn_mask_type` = + {'causal', 'padding_causal'} and `True` for other mask types. num_gqa_groups : int, default = `None` number of GQA groups in the transformer layer. Grouped Query Attention is described in @@ -8462,6 +8485,7 @@ def __init__( layer_number: Optional[int] = None, attn_mask_type: str = "causal", window_size: Optional[Tuple[int, int]] = None, + bottom_right_diagonal: Optional[bool] = None, tp_group: Optional[dist_group_type] = None, tp_size: int = 1, num_gqa_groups: Optional[int] = None, @@ -8492,6 +8516,7 @@ def __init__( self.qkv_format = qkv_format self.attn_mask_type = attn_mask_type self.window_size = check_set_window_size(attn_mask_type, window_size) + self.bottom_right_diagonal = bottom_right_diagonal self.layer_number = layer_number self.input_layernorm = input_layernorm self.attention_type = attention_type @@ -8757,6 +8782,7 @@ def forward( core_attention_bias_type: str = "no_bias", core_attention_bias: Optional[torch.Tensor] = None, alibi_slopes: Optional[torch.Tensor] = None, + bottom_right_diagonal: Optional[bool] = None, cu_seqlens_q: Optional[torch.Tensor] = None, cu_seqlens_kv: Optional[torch.Tensor] = None, max_seqlen_q: Optional[int] = None, @@ -8826,6 +8852,11 @@ def forward( ALiBi slopes in FP32 and shape [nheads] or [batch_size, nheads]. It adds a bias of (-alibi_slope * (i + seqlen_k - seqlen_q - j)) to the attention score of query i and key j. + bottom_right_diagonal: Optional[bool], default = `None` + Align sliding window and ALiBi diagonal to the top left (`False`) + or bottom right (`True`) corner of the softmax matrix in the encoder. + If `None`, it will be set to `False` for `attn_mask_type` = + {'causal', 'padding_causal'} and `True` for other mask types. cu_seqlens_q: Optional[torch.Tensor], default = `None` Cumulative sum of sequence lengths (without offset) in a batch for `query_layer`, with shape [batch_size + 1] and dtype torch.int32. 
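The cu_seqlens_q and cu_seqlens_kv tensors documented above are plain cumulative sums of the per-sequence lengths. A minimal helper showing the documented [batch_size + 1], torch.int32 layout (build_cu_seqlens is a hypothetical name, not part of this patch):

    import torch

    def build_cu_seqlens(seqlens: torch.Tensor) -> torch.Tensor:
        """seqlens: [batch_size] tensor of valid token counts per sequence."""
        zero = torch.zeros(1, dtype=torch.int32, device=seqlens.device)
        return torch.cat([zero, seqlens.to(torch.int32).cumsum(0, dtype=torch.int32)])

    print(build_cu_seqlens(torch.tensor([3, 5, 2])))  # tensor([ 0,  3,  8, 10], dtype=torch.int32)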
@@ -8848,6 +8879,12 @@ def forward( if window_size is None: window_size = self.window_size window_size = check_set_window_size(attn_mask_type, window_size) + if bottom_right_diagonal is None: + bottom_right_diagonal = self.bottom_right_diagonal + if attn_mask_type in {"causal", "padding_causal"}: + bottom_right_diagonal = False + if bottom_right_diagonal is None or attn_mask_type in {"causal_bottom_right", "padding_causal_bottom_right"}: + bottom_right_diagonal = True if "padding" in attn_mask_type and attention_mask is not None: for mask in attention_mask: @@ -9109,6 +9146,7 @@ def forward( core_attention_bias_type=core_attention_bias_type, core_attention_bias=core_attention_bias, alibi_slopes=alibi_slopes, + bottom_right_diagonal=bottom_right_diagonal, fast_zero_fill=fast_zero_fill, inference_params=inference_params, ) diff --git a/transformer_engine/pytorch/cpp_extensions/fused_attn.py b/transformer_engine/pytorch/cpp_extensions/fused_attn.py index 1932e9feb2..18f6921575 100644 --- a/transformer_engine/pytorch/cpp_extensions/fused_attn.py +++ b/transformer_engine/pytorch/cpp_extensions/fused_attn.py @@ -117,6 +117,7 @@ def fused_attn_fwd_qkvpacked( attn_bias_type: str = "no_bias", attn_mask_type: str = "padding", window_size: Tuple[int, int] = (-1, -1), + bottom_right_diagonal: bool = True, rng_gen: torch.Generator = None, ) -> Tuple[Union[torch.Tensor, None], ...]: """Fused Attention FWD for packed QKV input. @@ -186,6 +187,9 @@ def fused_attn_fwd_qkvpacked( in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding window and causal mask specifically. + bottom_right_diagonal: bool, default = True + whether to align sliding window and ALiBi diagonal to the top left (False) or + bottom right (True) corner of the softmax matrix. rng_gen: torch.Generator, default = None random number generator; if None, uses the default CUDA generator from PyTorch; otherwise, uses rng_gen @@ -271,6 +275,7 @@ def fused_attn_fwd_qkvpacked( AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type], window_size, + bottom_right_diagonal, cu_seqlens, qkv, qkv_dtype, @@ -324,6 +329,7 @@ def fused_attn_bwd_qkvpacked( attn_bias_type: str = "no_bias", attn_mask_type: str = "padding", window_size: Tuple[int, int] = (-1, -1), + bottom_right_diagonal: bool = True, deterministic: bool = False, ) -> Tuple[Union[torch.Tensor, None], ...]: """Fused Attention BWD for packed QKV input. @@ -394,6 +400,9 @@ def fused_attn_bwd_qkvpacked( in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding window and causal mask specifically. + bottom_right_diagonal: bool, default = True + whether to align sliding window and ALiBi diagonal to the top left (False) or + bottom right (True) corner of the softmax matrix. deterministic: bool, default = False whether to execute the backward pass with deterministic behaviours. @@ -444,6 +453,7 @@ def fused_attn_bwd_qkvpacked( AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type], window_size, + bottom_right_diagonal, deterministic, cu_seqlens, qkv, @@ -500,6 +510,7 @@ def fused_attn_fwd_kvpacked( attn_bias_type: str = "no_bias", attn_mask_type: str = "padding", window_size: Tuple[int, int] = (-1, -1), + bottom_right_diagonal: bool = True, rng_gen: torch.Generator = None, ) -> Tuple[Union[torch.Tensor, None], ...]: """Fused Attention FWD for packed KV input. 
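The window_size docstrings above give the attended range for the bottom-right alignment: token i attends to keys in [i + seqlen_kv - seqlen_q - window_size[0], i + seqlen_kv - seqlen_q + window_size[1]]. Below is a sketch of the equivalent 'arbitrary' boolean mask; sliding_window_mask is a hypothetical helper, and the top-left branch simply drops the seqlen_kv - seqlen_q shift, which is how bottom_right_diagonal=False is described in these docstrings.

    import torch

    def sliding_window_mask(seqlen_q, seqlen_kv, window_size, bottom_right_diagonal=True):
        """Returns a [seqlen_q, seqlen_kv] bool mask; True marks positions that are masked out."""
        left, right = window_size                        # -1 means unbounded on that side
        i = torch.arange(seqlen_q).view(seqlen_q, 1)
        j = torch.arange(seqlen_kv).view(1, seqlen_kv)
        shift = (seqlen_kv - seqlen_q) if bottom_right_diagonal else 0
        lo = i + shift - (left if left >= 0 else seqlen_kv)
        hi = i + shift + (right if right >= 0 else seqlen_kv)
        return (j < lo) | (j > hi)

    # Causal sliding window over the 2 previous tokens, diagonal in the bottom-right corner.
    print(sliding_window_mask(3, 5, (2, 0), bottom_right_diagonal=True))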
@@ -579,6 +590,9 @@ def fused_attn_fwd_kvpacked( in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding window and causal mask specifically. + bottom_right_diagonal: bool, default = True + whether to align sliding window and ALiBi diagonal to the top left (False) or + bottom right (True) corner of the softmax matrix. rng_gen: torch.Generator, default = None random number generator; if None, uses the default CUDA generator from PyTorch; otherwise, uses rng_gen @@ -665,6 +679,7 @@ def fused_attn_fwd_kvpacked( AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type], window_size, + bottom_right_diagonal, cu_seqlens_q, cu_seqlens_kv, q, @@ -725,6 +740,7 @@ def fused_attn_bwd_kvpacked( attn_bias_type: str = "no_bias", attn_mask_type: str = "padding", window_size: Tuple[int, int] = (-1, -1), + bottom_right_diagonal: bool = True, deterministic: bool = False, ) -> Tuple[Union[torch.Tensor, None], ...]: """Fused Attention BWD for packed KV input. @@ -806,6 +822,9 @@ def fused_attn_bwd_kvpacked( in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding window and causal mask specifically. + bottom_right_diagonal: bool, default = True + whether to align sliding window and ALiBi diagonal to the top left (False) or + bottom right (True) corner of the softmax matrix. deterministic: bool, default = False whether to execute the backward pass with deterministic behaviours. @@ -859,6 +878,7 @@ def fused_attn_bwd_kvpacked( AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type], window_size, + bottom_right_diagonal, deterministic, cu_seqlens_q, cu_seqlens_kv, @@ -919,6 +939,7 @@ def fused_attn_fwd( attn_bias_type: str = "no_bias", attn_mask_type: str = "padding", window_size: Tuple[int, int] = (-1, -1), + bottom_right_diagonal: bool = True, rng_gen: torch.Generator = None, ) -> Tuple[Union[torch.Tensor, None], ...]: """Fused Attention FWD for separate QKV input. @@ -1004,6 +1025,9 @@ def fused_attn_fwd( in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding window and causal mask specifically. + bottom_right_diagonal: bool, default = True + whether to align sliding window and ALiBi diagonal to the top left (False) or + bottom right (True) corner of the softmax matrix. rng_gen: torch.Generator, default = None random number generator; if None, uses the default CUDA generator from PyTorch; otherwise, uses rng_gen @@ -1090,6 +1114,7 @@ def fused_attn_fwd( AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type], window_size, + bottom_right_diagonal, cu_seqlens_q, cu_seqlens_kv, q, @@ -1152,6 +1177,7 @@ def fused_attn_bwd( attn_bias_type: str = "no_bias", attn_mask_type: str = "padding", window_size: Tuple[int, int] = (-1, -1), + bottom_right_diagonal: bool = True, deterministic: bool = False, ) -> Tuple[Union[torch.Tensor, None], ...]: """Fused Attention BWD for packed KV input. @@ -1238,6 +1264,9 @@ def fused_attn_bwd( in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding window and causal mask specifically. + bottom_right_diagonal: bool, default = True + whether to align sliding window and ALiBi diagonal to the top left (False) or + bottom right (True) corner of the softmax matrix. 
deterministic: bool, default = False whether to execute the backward pass with deterministic behaviours. @@ -1293,6 +1322,7 @@ def fused_attn_bwd( AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type], window_size, + bottom_right_diagonal, deterministic, cu_seqlens_q, cu_seqlens_kv, diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 3b49ece4a3..6dcbbd708b 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -48,7 +48,7 @@ NVTE_Fused_Attn_Backend get_fused_attn_backend(const transformer_engine::DType q std::vector fused_attn_fwd_qkvpacked( size_t max_seqlen, bool is_training, float attn_scale, float p_dropout, bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - const std::vector window_size, const at::Tensor cu_seqlens, const at::Tensor QKV, + const std::vector window_size, bool bottom_right_diagonal, const at::Tensor cu_seqlens, const at::Tensor QKV, const transformer_engine::DType qkv_type, const c10::optional cu_seqlens_padded, const c10::optional descale_QKV, const int descale_QKV_offset, const c10::optional descale_S, const int descale_S_offset, @@ -60,7 +60,7 @@ std::vector fused_attn_fwd_qkvpacked( std::vector fused_attn_bwd_qkvpacked( size_t max_seqlen, float attn_scale, float p_dropout, bool set_zero, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, const std::vector window_size, + NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, const std::vector window_size, bool bottom_right_diagonal, bool deterministic, const at::Tensor cu_seqlens, const at::Tensor QKV, const at::Tensor O, const at::Tensor dO, const transformer_engine::DType qkv_type, const transformer_engine::DType dqkv_type, const std::vector Aux_CTX_Tensors, @@ -74,7 +74,7 @@ std::vector fused_attn_bwd_qkvpacked( std::vector fused_attn_fwd_kvpacked( size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout, bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, const std::vector window_size, + NVTE_Mask_Type attn_mask_type, const std::vector window_size, bool bottom_right_diagonal, const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const at::Tensor Q, const at::Tensor KV, const transformer_engine::DType qkv_type, const c10::optional cu_seqlens_q_padded, @@ -90,7 +90,7 @@ std::vector fused_attn_fwd_kvpacked( std::vector fused_attn_bwd_kvpacked( size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - const std::vector window_size, bool deterministic, const at::Tensor cu_seqlens_q, + const std::vector window_size, bool bottom_right_diagonal, bool deterministic, const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const at::Tensor Q, const at::Tensor KV, const at::Tensor O, const at::Tensor dO, const transformer_engine::DType qkv_type, const transformer_engine::DType dqkv_type, const std::vector Aux_CTX_Tensors, @@ -105,7 +105,7 @@ std::vector fused_attn_bwd_kvpacked( std::vector fused_attn_fwd( size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout, bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, const std::vector window_size, + NVTE_Mask_Type attn_mask_type, const std::vector window_size, bool 
bottom_right_diagonal, const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const at::Tensor Q, const at::Tensor K, const at::Tensor V, const transformer_engine::DType qkv_type, const c10::optional cu_seqlens_q_padded, @@ -121,7 +121,7 @@ std::vector fused_attn_fwd( std::vector fused_attn_bwd( size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - const std::vector window_size, bool deterministic, const at::Tensor cu_seqlens_q, + const std::vector window_size, bool bottom_right_diagonal, bool deterministic, const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const at::Tensor Q, const at::Tensor K, const at::Tensor V, const at::Tensor O, const at::Tensor dO, const transformer_engine::DType qkv_type, const transformer_engine::DType dqkv_type, const std::vector Aux_CTX_Tensors, diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cu b/transformer_engine/pytorch/csrc/extensions/attention.cu index d03a10ced3..b4c07dff75 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cu +++ b/transformer_engine/pytorch/csrc/extensions/attention.cu @@ -84,7 +84,7 @@ at::PhiloxCudaState init_philox_state(at::CUDAGeneratorImpl *gen, size_t elts_pe std::vector fused_attn_fwd_qkvpacked( size_t max_seqlen, bool is_training, float attn_scale, float p_dropout, bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - const std::vector window_size, const at::Tensor cu_seqlens, const at::Tensor QKV, + const std::vector window_size, bool bottom_right_diagonal, const at::Tensor cu_seqlens, const at::Tensor QKV, const transformer_engine::DType qkv_type, const c10::optional cu_seqlens_padded, const c10::optional descale_QKV, const int descale_QKV_offset, const c10::optional descale_S, const int descale_S_offset, @@ -200,7 +200,7 @@ std::vector fused_attn_fwd_qkvpacked( te_QKV.data(), te_Bias.data(), te_S.data(), te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens.data(), te_cu_seqlens_padded.data(), te_rng_state.data(), max_seqlen, is_training, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, window_size[0], - window_size[1], workspace.data(), at::cuda::getCurrentCUDAStream()); + window_size[1], bottom_right_diagonal, workspace.data(), at::cuda::getCurrentCUDAStream()); // allocate memory for workspace and auxiliary output tensors auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); @@ -240,7 +240,7 @@ std::vector fused_attn_fwd_qkvpacked( te_QKV.data(), te_Bias.data(), te_S.data(), te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens.data(), te_cu_seqlens_padded.data(), te_rng_state.data(), max_seqlen, is_training, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, window_size[0], - window_size[1], workspace.data(), at::cuda::getCurrentCUDAStream()); + window_size[1], bottom_right_diagonal, workspace.data(), at::cuda::getCurrentCUDAStream()); // destroy tensor wrappers, but not allocated memory nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); @@ -252,7 +252,7 @@ std::vector fused_attn_fwd_qkvpacked( // fused attention BWD with packed QKV std::vector fused_attn_bwd_qkvpacked( size_t max_seqlen, float attn_scale, float p_dropout, bool set_zero, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, const std::vector window_size, + NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, const std::vector window_size, bool 
bottom_right_diagonal, bool deterministic, const at::Tensor cu_seqlens, const at::Tensor QKV, const at::Tensor O, const at::Tensor dO, const transformer_engine::DType qkv_type, const transformer_engine::DType dqkv_type, const std::vector Aux_CTX_Tensors, @@ -396,7 +396,7 @@ std::vector fused_attn_bwd_qkvpacked( te_QKV.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(), &nvte_aux_tensor_pack, te_dQKV.data(), te_dBias.data(), te_cu_seqlens.data(), te_cu_seqlens_padded.data(), max_seqlen, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, window_size[0], - window_size[1], deterministic, workspace.data(), at::cuda::getCurrentCUDAStream()); + window_size[1], bottom_right_diagonal, deterministic, workspace.data(), at::cuda::getCurrentCUDAStream()); // allocate memory for workspace auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); @@ -408,7 +408,7 @@ std::vector fused_attn_bwd_qkvpacked( te_QKV.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(), &nvte_aux_tensor_pack, te_dQKV.data(), te_dBias.data(), te_cu_seqlens.data(), te_cu_seqlens_padded.data(), max_seqlen, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, window_size[0], - window_size[1], deterministic, workspace.data(), at::cuda::getCurrentCUDAStream()); + window_size[1], bottom_right_diagonal, deterministic, workspace.data(), at::cuda::getCurrentCUDAStream()); // destroy tensor wrappers nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); @@ -420,7 +420,7 @@ std::vector fused_attn_bwd_qkvpacked( std::vector fused_attn_fwd_kvpacked( size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout, bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, const std::vector window_size, + NVTE_Mask_Type attn_mask_type, const std::vector window_size, bool bottom_right_diagonal, const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const at::Tensor Q, const at::Tensor KV, const transformer_engine::DType qkv_type, const c10::optional cu_seqlens_q_padded, @@ -537,7 +537,7 @@ std::vector fused_attn_fwd_kvpacked( te_Q.data(), te_KV.data(), te_Bias.data(), te_S.data(), te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, - attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, window_size[0], window_size[1], + attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, window_size[0], window_size[1], bottom_right_diagonal, workspace.data(), at::cuda::getCurrentCUDAStream()); // allocate memory for workspace and auxiliary output tensors @@ -578,7 +578,7 @@ std::vector fused_attn_fwd_kvpacked( te_Q.data(), te_KV.data(), te_Bias.data(), te_S.data(), te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, - attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, window_size[0], window_size[1], + attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, window_size[0], window_size[1], bottom_right_diagonal, workspace.data(), at::cuda::getCurrentCUDAStream()); // destroy tensor wrappers, but not allocated memory @@ -592,7 +592,7 @@ std::vector fused_attn_fwd_kvpacked( std::vector fused_attn_bwd_kvpacked( size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, 
bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - const std::vector window_size, bool deterministic, const at::Tensor cu_seqlens_q, + const std::vector window_size, bool bottom_right_diagonal, bool deterministic, const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const at::Tensor Q, const at::Tensor KV, const at::Tensor O, const at::Tensor dO, const transformer_engine::DType qkv_type, const transformer_engine::DType dqkv_type, const std::vector Aux_CTX_Tensors, @@ -752,7 +752,7 @@ std::vector fused_attn_bwd_kvpacked( te_dBias.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout, qkv_layout, - bias_type, attn_mask_type, window_size[0], window_size[1], + bias_type, attn_mask_type, window_size[0], window_size[1], bottom_right_diagonal, deterministic, workspace.data(), at::cuda::getCurrentCUDAStream()); // allocate memory for workspace @@ -766,7 +766,7 @@ std::vector fused_attn_bwd_kvpacked( te_dBias.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout, qkv_layout, - bias_type, attn_mask_type, window_size[0], window_size[1], + bias_type, attn_mask_type, window_size[0], window_size[1], bottom_right_diagonal, deterministic, workspace.data(), at::cuda::getCurrentCUDAStream()); // destroy tensor wrappers @@ -779,7 +779,7 @@ std::vector fused_attn_bwd_kvpacked( std::vector fused_attn_fwd( size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout, bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, const std::vector window_size, + NVTE_Mask_Type attn_mask_type, const std::vector window_size, bool bottom_right_diagonal, const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const at::Tensor Q, const at::Tensor K, const at::Tensor V, const transformer_engine::DType qkv_type, const c10::optional cu_seqlens_q_padded, @@ -904,7 +904,7 @@ std::vector fused_attn_fwd( te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, attn_scale, p_dropout, qkv_layout, bias_type, - attn_mask_type, window_size[0], window_size[1], workspace.data(), + attn_mask_type, window_size[0], window_size[1], bottom_right_diagonal, workspace.data(), at::cuda::getCurrentCUDAStream()); // allocate memory for workspace and auxiliary output tensors @@ -946,7 +946,7 @@ std::vector fused_attn_fwd( te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, attn_scale, p_dropout, qkv_layout, bias_type, - attn_mask_type, window_size[0], window_size[1], workspace.data(), + attn_mask_type, window_size[0], window_size[1], bottom_right_diagonal, workspace.data(), at::cuda::getCurrentCUDAStream()); // destroy tensor wrappers, but not allocated memory @@ -960,7 +960,7 @@ std::vector fused_attn_fwd( std::vector fused_attn_bwd( size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - const std::vector window_size, bool deterministic, const at::Tensor cu_seqlens_q, + const std::vector window_size, bool bottom_right_diagonal, bool deterministic, const 
at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const at::Tensor Q, const at::Tensor K, const at::Tensor V, const at::Tensor O, const at::Tensor dO, const transformer_engine::DType qkv_type, const transformer_engine::DType dqkv_type, const std::vector Aux_CTX_Tensors, @@ -1199,7 +1199,7 @@ std::vector fused_attn_bwd( te_dBias.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, - window_size[0], window_size[1], deterministic, workspace.data(), + window_size[0], window_size[1], bottom_right_diagonal, deterministic, workspace.data(), at::cuda::getCurrentCUDAStream()); // allocate memory for workspace @@ -1213,7 +1213,7 @@ std::vector fused_attn_bwd( te_dBias.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, - window_size[0], window_size[1], deterministic, workspace.data(), + window_size[0], window_size[1], bottom_right_diagonal, deterministic, workspace.data(), at::cuda::getCurrentCUDAStream()); // destroy tensor wrappers diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py index ad5476450b..04a984e92b 100644 --- a/transformer_engine/pytorch/transformer.py +++ b/transformer_engine/pytorch/transformer.py @@ -147,11 +147,21 @@ class TransformerLayer(torch.nn.Module): distinguishes them based on `self_attn_mask_type` or `enc_dec_attn_mask_type`. Similar to :attr:`self_attn_mask_type`, `window_size` can be overridden by :attr:`window_size` in `forward` as well. + bottom_right_diagonal: Optional[bool], default = `None` + Align sliding window and ALiBi diagonal to the top left (`False`) + or bottom right (`True`) corner of the softmax matrix in the encoder. + If `None`, it will be set to `False` for `self_attn_mask_type` = + {'causal', 'padding_causal'} and `True` for other mask types. enc_dec_attn_mask_type: {'no_mask', 'causal', 'padding', 'padding_causal', 'arbitrary'}, default = `no_mask` type of attention mask passed into softmax operation for decoder. enc_dec_window_size: Optional[Tuple[int, int]], default = `None` sliding window size for local attention in decoder. + enc_dec_bottom_right_diagonal: Optional[bool], default = `None` + Align sliding window and ALiBi diagonal to the top left (`False`) + or bottom right (`True`) corner of the softmax matrix in the decoder. + If `None`, it will be set to `False` for `enc_dec_attn_mask_type` = + {'causal', 'padding_causal'} and `True` for other mask types. 
zero_centered_gamma : bool, default = 'False' if set to 'True', gamma parameter in LayerNorm is initialized to 0 and the LayerNorm formula changes to @@ -247,8 +257,10 @@ def __init__( kv_channels: Optional[int] = None, self_attn_mask_type: str = "causal", window_size: Optional[Tuple[int, int]] = None, + bottom_right_diagonal: bool = None, enc_dec_attn_mask_type: str = "no_mask", enc_dec_window_size: Optional[Tuple[int, int]] = None, + enc_dec_bottom_right_diagonal: bool = None, tp_group: Optional[dist_group_type] = None, tp_size: int = 1, params_dtype: Optional[torch.dtype] = None, @@ -282,10 +294,12 @@ def __init__( self.self_attn_mask_type = self_attn_mask_type self.window_size = check_set_window_size(self_attn_mask_type, window_size) + self.bottom_right_diagonal = bottom_right_diagonal self.enc_dec_attn_mask_type = enc_dec_attn_mask_type self.enc_dec_window_size = check_set_window_size( enc_dec_attn_mask_type, enc_dec_window_size ) + self.enc_dec_bottom_right_diagonal = enc_dec_bottom_right_diagonal params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype ub_bulk_wgrad = ub_tp_comm_overlap and ub_bulk_wgrad ub_bulk_dgrad = ub_tp_comm_overlap and ub_bulk_dgrad @@ -530,10 +544,12 @@ def forward( attention_mask: Optional[torch.Tensor] = None, self_attn_mask_type: Optional[str] = None, window_size: Optional[Tuple[int, int]] = None, + bottom_right_diagonal: Optional[bool] = None, encoder_output: Optional[torch.Tensor] = None, enc_dec_attn_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None, enc_dec_attn_mask_type: Optional[str] = None, enc_dec_window_size: Optional[Tuple[int, int]] = None, + enc_dec_bottom_right_diagonal: Optional[bool] = None, is_first_microbatch: Optional[bool] = None, checkpoint_core_attention: bool = False, inference_params: Optional[InferenceParams] = None, @@ -575,6 +591,11 @@ def forward( causal masks are aligned to the bottom right corner. window_size: Optional[Tuple[int, int]], default = `None` Sliding window size for local attention in encoder. + bottom_right_diagonal: Optional[bool] = `None` + Align sliding window and ALiBi diagonal to the top left (`False`) + or bottom right (`True`) corner of the softmax matrix in the encoder. + If `None`, it will be set to `False` for `self_attn_mask_type` = + {'causal', 'padding_causal'} and `True` for other mask types. encoder_output : Optional[torch.Tensor], default = `None` Output of the encoder block to be fed into the decoder block if using `layer_type="decoder"`. @@ -591,6 +612,11 @@ def forward( Type of attention mask passed into softmax operation for decoder. enc_dec_window_size: Optional[Tuple[int, int]], default = `None` Sliding window size for local attention in decoder. + enc_dec_bottom_right_diagonal: Optional[bool] = `None` + Align sliding window and ALiBi diagonal to the top left (`False`) + or bottom right (`True`) corner of the softmax matrix in the decoder. + If `None`, it will be set to `False` for `enc_dec_attn_mask_type` = + {'causal', 'padding_causal'} and `True` for other mask types. 
is_first_microbatch : {True, False, None}, default = None During training using either gradient accumulation or pipeline parallelism a minibatch of data is further split @@ -649,6 +675,18 @@ def forward( if enc_dec_window_size is None: enc_dec_window_size = self.enc_dec_window_size enc_dec_window_size = check_set_window_size(enc_dec_attn_mask_type, enc_dec_window_size) + if bottom_right_diagonal is None: + bottom_right_diagonal = self.bottom_right_diagonal + if attn_mask_type in {"causal", "padding_causal"}: + bottom_right_diagonal = False + if bottom_right_diagonal is None or attn_mask_type in {"causal_bottom_right", "padding_causal_bottom_right"}: + bottom_right_diagonal = True + if enc_dec_bottom_right_diagonal is None: + enc_dec_bottom_right_diagonal = self.enc_dec_bottom_right_diagonal + if enc_dec_attn_mask_type in {"causal", "padding_causal"}: + enc_dec_bottom_right_diagonal = False + if enc_dec_bottom_right_diagonal is None or enc_dec_attn_mask_type in {"causal_bottom_right", "padding_causal_bottom_right"}: + enc_dec_bottom_right_diagonal = True assert ( self_attn_mask_type in AttnMaskTypes @@ -692,6 +730,7 @@ def forward( core_attention_bias_type=core_attention_bias_type, core_attention_bias=core_attention_bias, alibi_slopes=alibi_slopes, + bottom_right_diagonal=bottom_right_diagonal, cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_kv, max_seqlen_q=max_seqlen_q, @@ -723,6 +762,7 @@ def forward( core_attention_bias_type=core_attention_bias_type, core_attention_bias=core_attention_bias, alibi_slopes=alibi_slopes, + bottom_right_diagonal=enc_dec_bottom_right_diagonal, fast_zero_fill=fast_zero_fill, ) if self.apply_residual_connection_post_layernorm:
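For reference, a usage sketch of the new flag on the PyTorch side, assuming this patch is applied. Only attn_mask_type, window_size and the bottom_right_diagonal arguments come from this diff; the remaining constructor and forward arguments, shapes, and dtypes are illustrative and follow the existing DotProductAttention API.

    import torch
    from transformer_engine.pytorch import DotProductAttention

    attn = DotProductAttention(
        num_attention_heads=16,
        kv_channels=64,
        attn_mask_type="causal_bottom_right",
        window_size=(128, 0),            # attend to at most the 128 previous tokens
        bottom_right_diagonal=True,      # align the window/ALiBi diagonal to the bottom-right corner
    )

    sq, skv, b, h, d = 512, 1024, 2, 16, 64
    q = torch.randn(sq, b, h, d, dtype=torch.bfloat16, device="cuda")
    k = torch.randn(skv, b, h, d, dtype=torch.bfloat16, device="cuda")
    v = torch.randn(skv, b, h, d, dtype=torch.bfloat16, device="cuda")
    out = attn(q, k, v)
    # The flag can also be passed per call, e.g. attn(q, k, v, bottom_right_diagonal=True),
    # mirroring how attn_mask_type and window_size can already be overridden in forward().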