diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index ff64f1de72..24d87c0416 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -17,6 +17,8 @@ jobs: steps: - name: 'Checkout' uses: actions/checkout@v3 + with: + submodules: recursive - name: 'Build' run: | mkdir -p wheelhouse && \ @@ -41,6 +43,8 @@ jobs: steps: - name: 'Checkout' uses: actions/checkout@v3 + with: + submodules: recursive - name: 'Build' run: | pip install ninja pybind11 && \ @@ -66,6 +70,8 @@ jobs: steps: - name: 'Checkout' uses: actions/checkout@v3 + with: + submodules: recursive - name: 'Build' run: | pip install ninja pybind11 && \ diff --git a/.gitmodules b/.gitmodules index 85675ac0bc..21492db5ef 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,3 +1,6 @@ [submodule "3rdparty/googletest"] path = 3rdparty/googletest url = https://github.com/google/googletest.git +[submodule "3rdparty/cudnn-frontend"] + path = 3rdparty/cudnn-frontend + url = https://github.com/NVIDIA/cudnn-frontend.git diff --git a/3rdparty/cudnn-frontend b/3rdparty/cudnn-frontend new file mode 160000 index 0000000000..e7f64390e9 --- /dev/null +++ b/3rdparty/cudnn-frontend @@ -0,0 +1 @@ +Subproject commit e7f64390e9bb4a3db622ffe11c973834f572b609 diff --git a/Acknowledgements.txt b/Acknowledgements.txt index 7eec81a9ce..ad11acc047 100644 --- a/Acknowledgements.txt +++ b/Acknowledgements.txt @@ -138,3 +138,25 @@ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +======================== +cudnn-frontend + +Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. diff --git a/docs/api/c/fused_attn.rst b/docs/api/c/fused_attn.rst new file mode 100644 index 0000000000..c2384b7e12 --- /dev/null +++ b/docs/api/c/fused_attn.rst @@ -0,0 +1,9 @@ +.. + Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + + See LICENSE for license information. + +fused_attn.h +============ + +.. doxygenfile:: fused_attn.h diff --git a/docs/api/c/index.rst b/docs/api/c/index.rst index 0f83b8dc02..f98a419088 100644 --- a/docs/api/c/index.rst +++ b/docs/api/c/index.rst @@ -17,6 +17,7 @@ directly from C/C++, without Python. activation.h cast.h gemm.h + fused_attn.h layer_norm.h softmax.h transformer_engine.h diff --git a/docs/installation.rst b/docs/installation.rst index 2614dd0477..07252eef59 100644 --- a/docs/installation.rst +++ b/docs/installation.rst @@ -14,6 +14,8 @@ Prerequisites 1. Linux x86_64 2. `CUDA 11.8 `__ 3. |driver link|_ supporting CUDA 11.8 or later. +4. `cuDNN 8 `__ or later. +5. For FP8 fused attention, `CUDA 12.1 `__ or later, |driver link|_ supporting CUDA 12.1 or later, and `cuDNN 8.9 `__ or later. Transformer Engine in NGC Containers diff --git a/setup.py b/setup.py index cb0c37fe3a..b88e4fbcc4 100644 --- a/setup.py +++ b/setup.py @@ -105,6 +105,7 @@ def make_abs_path(l): include_dirs = [ "transformer_engine/common/include", "transformer_engine/pytorch/csrc", + "3rdparty/cudnn-frontend/include", ] if NVTE_WITH_USERBUFFERS: if MPI_HOME: diff --git a/tests/cpp/test_common.cu b/tests/cpp/test_common.cu index 151eddb9f9..bbb25bb2fc 100644 --- a/tests/cpp/test_common.cu +++ b/tests/cpp/test_common.cu @@ -42,6 +42,7 @@ const std::string &typeName(DType type) { static const std::unordered_map name_map = { {DType::kByte, "byte"}, {DType::kInt32, "int32"}, + {DType::kInt64, "int64"}, {DType::kFloat32, "float32"}, {DType::kFloat16, "float16"}, {DType::kBFloat16, "bfloat16"}, diff --git a/tests/cpp/test_common.h b/tests/cpp/test_common.h index f35d494c8d..7278f1827b 100644 --- a/tests/cpp/test_common.h +++ b/tests/cpp/test_common.h @@ -44,6 +44,7 @@ struct BytesToType<8> { using byte = uint8_t; using int32 = int32_t; +using int64 = int64_t; using fp32 = float; using fp16 = half; using bf16 = nv_bfloat16; @@ -54,6 +55,7 @@ template struct TypeInfo{ using types = std::tuple + $ +) + +target_link_libraries( + CUDNN::cudnn_all + INTERFACE + CUDNN::cudnn_adv_train + CUDNN::cudnn_ops_train + CUDNN::cudnn_cnn_train + CUDNN::cudnn_adv_infer + CUDNN::cudnn_cnn_infer + CUDNN::cudnn_ops_infer + CUDNN::cudnn +) + diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index c5bc6bb0f1..7b844540ae 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -12,6 +12,9 @@ list(APPEND transformer_engine_SOURCES transpose/transpose_fusion.cu transpose/multi_cast_transpose.cu activation/gelu.cu + fused_attn/fused_attn_fp8.cu + fused_attn/fused_attn.cpp + fused_attn/utils.cu gemm/cublaslt_gemm.cu layer_norm/ln_api.cpp layer_norm/ln_bwd_semi_cuda_kernel.cu @@ -30,9 +33,11 @@ target_include_directories(transformer_engine PUBLIC target_link_libraries(transformer_engine PUBLIC CUDA::cublas CUDA::cudart - CUDA::nvToolsExt) + CUDA::nvToolsExt + CUDNN::cudnn) target_include_directories(transformer_engine PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) +target_include_directories(transformer_engine PRIVATE "${CMAKE_SOURCE_DIR}/../3rdparty/cudnn-frontend/include") # Compiler options set_source_files_properties(fused_softmax/scaled_masked_softmax.cu diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp new file mode 100644 index 0000000000..17b6505038 --- /dev/null +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -0,0 +1,232 @@ +/************************************************************************* + * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include "transformer_engine/fused_attn.h" +#include "../common.h" +#include "utils.h" +#include "fused_attn_fp8.h" + +// NVTE fused attention FWD FP8 with packed QKV +void nvte_fused_attn_fwd_qkvpacked( + const NVTETensor QKV, + const NVTETensor Bias, + NVTETensor S, + NVTETensor O, + NVTETensorPack* Aux_Output_Tensors, + const NVTETensor cu_seqlens, + const NVTETensor rng_state, + size_t max_seqlen, + bool is_training, float attn_scale, float dropout, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, + NVTETensor workspace, + cudaStream_t stream) { + NVTE_API_CALL(nvte_flash_attn_fwd_qkvpacked); + using namespace transformer_engine; + const Tensor *input_cu_seqlens = reinterpret_cast(cu_seqlens); + const Tensor *input_rng_state = reinterpret_cast(rng_state); + const Tensor *input_QKV = reinterpret_cast(QKV); + const Tensor *input_Bias = reinterpret_cast(Bias); + Tensor *input_output_S = reinterpret_cast(S); + Tensor *output_O = reinterpret_cast(O); + Tensor *wkspace = reinterpret_cast(workspace); + + // QKV shape is [total_seqs, 3, h, d] + size_t b = input_cu_seqlens->data.shape[0] - 1; + size_t h = input_QKV->data.shape[2]; + size_t d = input_QKV->data.shape[3]; + const DType QKV_type = input_QKV->data.dtype; + + if (((QKV_type == DType::kFloat8E4M3) || (QKV_type == DType::kFloat8E5M2)) + && (max_seqlen <= 512)) { +#if (CUDNN_VERSION >= 8900) + auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle(); + // FP8 API doesn't use input_Bias, bias_type or attn_mask_type + fused_attn_fwd_fp8_qkvpacked( + b, max_seqlen, h, d, + is_training, attn_scale, dropout, qkv_layout, + input_QKV, input_output_S, output_O, + Aux_Output_Tensors, + input_cu_seqlens, + input_rng_state, + wkspace, stream, handle); +#else + NVTE_ERROR("cuDNN 8.9 is required to run FP8 fused attention. \n"); +#endif + } else if (((QKV_type == DType::kFloat16) || (QKV_type == DType::kBFloat16)) + && (max_seqlen <= 512)) { + NVTE_ERROR("TBD: No support for BF16/FP16 fused attention currently. \n"); + } else if (max_seqlen > 512) { + NVTE_ERROR("TBD: No support for fused attention with >512 seqlence length currently. \n"); + } else { + NVTE_ERROR("Invalid combination of data type and sequence length! \n"); + } +} +// NVTE fused attention BWD FP8 with packed QKV +void nvte_fused_attn_bwd_qkvpacked( + const NVTETensor QKV, + const NVTETensor dBias, + const NVTETensor O, + const NVTETensor dO, + const NVTETensor S, + NVTETensor dP, + const NVTETensorPack* Aux_CTX_Tensors, + NVTETensor dQKV, + const NVTETensor cu_seqlens, + size_t max_seqlen, + float attn_scale, float dropout, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, + NVTETensor workspace, + cudaStream_t stream) { + NVTE_API_CALL(nvte_flash_attn_bwd_qkvpacked); + using namespace transformer_engine; + const Tensor *input_cu_seqlens = reinterpret_cast(cu_seqlens); + const Tensor *input_QKV = reinterpret_cast(QKV); + const Tensor *input_dBias = reinterpret_cast(dBias); + const Tensor *input_O = reinterpret_cast(O); + const Tensor *input_dO = reinterpret_cast(dO); + const Tensor *input_S = reinterpret_cast(S); + Tensor *input_output_dP = reinterpret_cast(dP); + Tensor *output_dQKV = reinterpret_cast(dQKV); + Tensor *wkspace = reinterpret_cast(workspace); + + // QKV shape is [total_seqs, 3, h, d] + size_t b = input_cu_seqlens->data.shape[0] - 1; + size_t h = input_QKV->data.shape[2]; + size_t d = input_QKV->data.shape[3]; + const DType QKV_type = input_QKV->data.dtype; + + if (((QKV_type == DType::kFloat8E4M3) || (QKV_type == DType::kFloat8E5M2)) + && (max_seqlen <= 512)) { +#if (CUDNN_VERSION >= 8900) + // Aux_CTX_Tensors contain [M, ZInv, rng_state] generated by the forward pass + const Tensor *input_M = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); + const Tensor *input_ZInv = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); + const Tensor *input_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[2]); + auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle(); + // FP8 API doesn't use input_dBias, bias_type or attn_mask_type + fused_attn_bwd_fp8_qkvpacked( + b, max_seqlen, h, d, + attn_scale, dropout, qkv_layout, + input_QKV, input_O, input_dO, + input_M, input_ZInv, + input_S, input_output_dP, + output_dQKV, + input_cu_seqlens, + input_rng_state, + wkspace, stream, handle); +#else + NVTE_ERROR("cuDNN 8.9 is required to run FP8 fused attention. \n"); +#endif + } else if (((QKV_type == DType::kFloat16) || (QKV_type == DType::kBFloat16)) + && (max_seqlen <= 512)) { + NVTE_ERROR("TBD: No support for BF16/FP16 fused attention currently. \n"); + } else if (max_seqlen > 512) { + NVTE_ERROR("TBD: No support for fused attention with >512 seqlence length currently. \n"); + } else { + NVTE_ERROR("Invalid combination of data type and sequence length! \n"); + } +} +// NVTE fused attention FWD FP8 with packed KV +void nvte_fused_attn_fwd_kvpacked( + const NVTETensor Q, + const NVTETensor KV, + const NVTETensor Bias, + NVTETensor S, + NVTETensor O, + NVTETensorPack* Aux_Output_Tensors, + const NVTETensor cu_seqlens_q, + const NVTETensor cu_seqlens_kv, + const NVTETensor rng_state, + size_t max_seqlen_q, size_t max_seqlen_kv, + bool is_training, float attn_scale, float dropout, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, + NVTETensor workspace, + cudaStream_t stream) { + NVTE_API_CALL(nvte_flash_attn_fwd_kvpacked); + using namespace transformer_engine; + const Tensor *input_cu_seqlens_q = reinterpret_cast(cu_seqlens_q); + const Tensor *input_cu_seqlens_kv = reinterpret_cast(cu_seqlens_kv); + const Tensor *input_rng_state = reinterpret_cast(rng_state); + const Tensor *input_Q = reinterpret_cast(Q); + const Tensor *input_KV = reinterpret_cast(KV); + const Tensor *input_Bias = reinterpret_cast(Bias); + Tensor *input_output_S = reinterpret_cast(S); + Tensor *output_O = reinterpret_cast(O); + Tensor *wkspace = reinterpret_cast(workspace); + + // Q shape is [total_seqs, h, d] + size_t b = input_cu_seqlens_q->data.shape[0] - 1; + size_t h = input_Q->data.shape[1]; + size_t d = input_Q->data.shape[2]; + const DType QKV_type = input_Q->data.dtype; + + if (((QKV_type == DType::kFloat8E4M3) || (QKV_type == DType::kFloat8E5M2)) + && (max_seqlen_q <= 512) && (max_seqlen_kv <= 512)) { + NVTE_ERROR("The FP8 fused attention API only supports packed QKV input. \n"); + } else if (((QKV_type == DType::kFloat16) || (QKV_type == DType::kBFloat16)) + && (max_seqlen_q <= 512) && (max_seqlen_kv <= 512)) { + NVTE_ERROR("TBD: No support for BF16/FP16 fused attention currently. \n"); + } else if ((max_seqlen_q > 512) || (max_seqlen_kv > 512)) { + NVTE_ERROR("TBD: No support for fused attention with >512 seqlence length currently. \n"); + } else { + NVTE_ERROR("Invalid combination of data type and sequence length! \n"); + } +} +// NVTE fused attention BWD FP8 with packed KV +void nvte_fused_attn_bwd_kvpacked( + const NVTETensor Q, + const NVTETensor KV, + const NVTETensor dBias, + const NVTETensor O, + const NVTETensor dO, + const NVTETensor S, + NVTETensor dP, + const NVTETensorPack* Aux_CTX_Tensors, + NVTETensor dQ, + NVTETensor dKV, + const NVTETensor cu_seqlens_q, + const NVTETensor cu_seqlens_kv, + size_t max_seqlen_q, size_t max_seqlen_kv, + float attn_scale, float dropout, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, + NVTETensor workspace, + cudaStream_t stream) { + NVTE_API_CALL(nvte_flash_attn_bwd_kvpacked); + using namespace transformer_engine; + const Tensor *input_cu_seqlens_q = reinterpret_cast(cu_seqlens_q); + const Tensor *input_cu_seqlens_kv = reinterpret_cast(cu_seqlens_kv); + const Tensor *input_Q = reinterpret_cast(Q); + const Tensor *input_KV = reinterpret_cast(KV); + const Tensor *input_dBias = reinterpret_cast(dBias); + const Tensor *input_O = reinterpret_cast(O); + const Tensor *input_dO = reinterpret_cast(dO); + const Tensor *input_S = reinterpret_cast(S); + Tensor *input_output_dP = reinterpret_cast(dP); + Tensor *output_dQ = reinterpret_cast(dQ); + Tensor *output_dKV = reinterpret_cast(dKV); + Tensor *wkspace = reinterpret_cast(workspace); + + // Q shape is [total_seqs, h, d] + size_t b = input_cu_seqlens_q->data.shape[0] - 1; + size_t h = input_Q->data.shape[1]; + size_t d = input_Q->data.shape[2]; + const DType QKV_type = input_Q->data.dtype; + if (((QKV_type == DType::kFloat8E4M3) || (QKV_type == DType::kFloat8E5M2)) + && (max_seqlen_q <= 512) && (max_seqlen_kv <= 512)) { + NVTE_ERROR("The FP8 fused attention API only supports packed QKV input. \n"); + } else if (((QKV_type == DType::kFloat16) || (QKV_type == DType::kBFloat16)) + && (max_seqlen_q <= 512) && (max_seqlen_kv <= 512)) { + NVTE_ERROR("TBD: No support for BF16/FP16 fused attention currently. \n"); + } else if ((max_seqlen_q > 512) || (max_seqlen_kv > 512)) { + NVTE_ERROR("TBD: No support for fused attention with >512 seqlence length currently. \n"); + } else { + NVTE_ERROR("Invalid combination of data type and sequence length! \n"); + } +} diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu new file mode 100644 index 0000000000..633f46c51f --- /dev/null +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -0,0 +1,2138 @@ +/************************************************************************* + * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include "transformer_engine/fused_attn.h" +#include "../common.h" +#include "utils.h" +#include "fused_attn_fp8.h" + +namespace transformer_engine { +namespace fused_attn { + +using namespace transformer_engine; + +#if (CUDNN_VERSION >= 8900) +std::unordered_map tensor_name_to_uid = { + {"Q", 1}, + {"K", 2}, + {"V", 3}, + {"O", 4}, + {"S", 5}, + {"B", 6}, + {"DROPOUT_SCALE", 7}, + {"S_CONST", 8}, + {"MNK_OVERRIDE", 9}, + {"dQ", 11}, + {"dK", 12}, + {"dV", 13}, + {"dO", 14}, + {"MASK_VAL", 15}, + {"dS", 16}, + {"O_SEQLEN", 17}, + {"M", 18}, + {"Z", 19}, + {"descaleQ", 20}, + {"descaleK", 21}, + {"descaleV", 22}, + {"descaleS", 23}, + {"scaleS", 24}, + {"amaxS", 25}, + {"amaxO", 26}, + {"QKV_RAGGED", 27}, + {"O_RAGGED", 28}, + {"K_TRANSPOSE", 29}, + {"AttnScale", 30}, + {"scaleO", 31}, + {"Z_INV", 32}, + {"descaleO", 33}, + {"descaledO", 34}, + {"descaledS", 35}, + {"descaledQ", 36}, + {"descaledK", 37}, + {"descaledV", 38}, + {"scaledS", 39}, + {"scaledQ", 40}, + {"scaledK", 41}, + {"scaledV", 42}, + {"amaxdS", 43}, + {"amaxdQ", 44}, + {"amaxdK", 45}, + {"amaxdV", 46}, + {"V_TRANSPOSE", 47}, + {"AttnScale_dS_K", 48}, + {"AttnScale_dSTranspose_Q", 49}, + {"DROPOUT_SCALE_dOVt_OdO", 50}, + {"DROPOUT_OFFSET", 51}, + {"DROPOUT_SEED", 52}, + {"VIRTUAL", 80} +}; + +bool allowAllConfig(cudnnBackendDescriptor_t engine_config) { + (void)engine_config; + return false; +} + +static cudnn_frontend::Tensor tensor_create( + cudnnDataType_t type, int64_t id, + int64_t const * dim, int64_t const * stride, + bool is_virtual, bool is_value) { + int nbDims = 4; + auto tensor_created = cudnn_frontend::TensorBuilder() + .setDim(nbDims, dim) + .setStride(nbDims, stride) + .setId(id) + .setAlignment(16) // 16B alignment is needed to run a tensor core engine + .setDataType(type) + .setVirtual(is_virtual) + .setByValue(is_value) + .build(); + return tensor_created; +} + +static cudnn_frontend::Tensor tensor_create_with_offset( + cudnnDataType_t type, int64_t id, + int64_t const * dim, int64_t const * stride, + bool is_virtual, bool is_value, + std::shared_ptr raggedOffset) { + int nbDims = 4; + auto tensor_created = cudnn_frontend::TensorBuilder() + .setDim(nbDims, dim) + .setStride(nbDims, stride) + .setId(id) + .setAlignment(16) // 16B alignment is needed to run a tensor core engine + .setDataType(type) + .setVirtual(is_virtual) + .setByValue(is_value) + .setRaggedOffset(raggedOffset) + .build(); + return tensor_created; +} + +static cudnn_frontend::PointWiseDesc pw_desc_create( + cudnnDataType_t type, cudnnPointwiseMode_t mode) { + auto pw_desc_created = cudnn_frontend::PointWiseDescBuilder() + .setMode(mode) + .setComputeType(type) + .build(); + return pw_desc_created; +} + +static cudnn_frontend::Operation unary_pw_op_create( + cudnn_frontend::Tensor const &xDesc, + cudnn_frontend::Tensor const &yDesc, + cudnn_frontend::PointWiseDesc const &pwDesc) { + auto pw_op_created = cudnn_frontend::OperationBuilder( + CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) + .setxDesc(xDesc) + .setyDesc(yDesc) + .setpwDesc(pwDesc) + .build(); + return pw_op_created; +} + +static cudnn_frontend::Operation binary_pw_op_create( + cudnn_frontend::Tensor const &xDesc, + cudnn_frontend::Tensor const &bDesc, + cudnn_frontend::Tensor const &yDesc, + cudnn_frontend::PointWiseDesc const &pwDesc) { + auto pw_op_created = cudnn_frontend::OperationBuilder( + CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) + .setxDesc(xDesc) + .setbDesc(bDesc) + .setyDesc(yDesc) + .setpwDesc(pwDesc) + .build(); + return pw_op_created; +} + +static cudnn_frontend::Operation ternary_pw_op_create( + cudnn_frontend::Tensor const &xDesc, + cudnn_frontend::Tensor const &bDesc, + cudnn_frontend::Tensor const &tDesc, + cudnn_frontend::Tensor const &yDesc, + cudnn_frontend::PointWiseDesc const &pwDesc) { + auto pw_op_created = cudnn_frontend::OperationBuilder( + CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) + .setxDesc(xDesc) + .setbDesc(bDesc) + .settDesc(tDesc) + .setyDesc(yDesc) + .setpwDesc(pwDesc) + .build(); + return pw_op_created; +} + +static cudnn_frontend::Tensor createAmax( + const std::string& amax_tensor_name, + const cudnn_frontend::Tensor& prevBlockOutputTensor, + std::vector* ops) { + int64_t amax_dim[4] = {1, 1, 1, 1}; + int64_t amax_stride[4] = {1, 1, 1, 1}; + auto amaxTensor = tensor_create( + CUDNN_DATA_FLOAT, tensor_name_to_uid[amax_tensor_name], + amax_dim, amax_stride, false, false); + + // Define the amax descriptor + auto reductionDesc = cudnn_frontend::ReductionDescBuilder() + .setMathPrecision(CUDNN_DATA_FLOAT) + .setReductionOp(CUDNN_REDUCE_TENSOR_AMAX) + .build(); + + // Create a reduction amax Node + auto reduction_op = cudnn_frontend::OperationBuilder( + CUDNN_BACKEND_OPERATION_REDUCTION_DESCRIPTOR) + .setxDesc(prevBlockOutputTensor) + .setyDesc(amaxTensor) + .setreductionDesc(reductionDesc) + .build(); + ops->push_back(std::move(reduction_op)); + return amaxTensor; +} + +static cudnn_frontend::Tensor createScale( + const cudnn_frontend::Tensor& prevBlockOutputTensor, + const std::string& scale_tensor_name, + cudnnDataType_t tensorType, + bool isOutputVirtual, bool isScaleByValue, + std::vector* ops, + const std::string& output_tensor_name ="") { + int64_t scale_dim[4] = {1, 1, 1, 1}; + int64_t scale_stride[4] = {1, 1, 1, 1}; + + int64_t output_dim[4]; + int64_t output_stride[4]; + + for (int i = 0; i < 4; i++) { + output_dim[i] = prevBlockOutputTensor.getDim()[i]; + output_stride[i] = prevBlockOutputTensor.getStride()[i]; + } + + auto scaleTensor = tensor_create( + CUDNN_DATA_FLOAT, tensor_name_to_uid[scale_tensor_name], + scale_dim, scale_stride, false, isScaleByValue); // is by value + + int64_t outputUID = isOutputVirtual ? tensor_name_to_uid["VIRTUAL"] + + tensor_name_to_uid[scale_tensor_name] + 5000 : + tensor_name_to_uid[output_tensor_name]; + auto afterScaleKTensor = tensor_create( + tensorType, outputUID, output_dim, + output_stride, isOutputVirtual, false); // is virtual + + // Define the scale descriptor + auto scaleDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL); + + // Create a Scale Node + auto scale_op = binary_pw_op_create( + prevBlockOutputTensor, scaleTensor, afterScaleKTensor, scaleDesc); + + ops->push_back(std::move(scale_op)); + return afterScaleKTensor; +} + +static cudnn_frontend::Tensor createScale( + const cudnn_frontend::Tensor& prevBlockOutputTensor, + const cudnn_frontend::Tensor& scaleTensor, + cudnnDataType_t tensorType, + bool isOutputVirtual, bool isScaleByValue, + std::vector* ops, + int UID_offset, const std::string& output_tensor_name ="") { + int64_t output_dim[4]; + int64_t output_stride[4]; + for (int i = 0; i < 4; i++) { + output_dim[i] = prevBlockOutputTensor.getDim()[i]; + output_stride[i] = prevBlockOutputTensor.getStride()[i]; + } + + int64_t outputUID = isOutputVirtual ? + tensor_name_to_uid["VIRTUAL"] + UID_offset : + tensor_name_to_uid[output_tensor_name]; + auto afterScaleTensor = tensor_create( + tensorType, outputUID, output_dim, + output_stride, isOutputVirtual, false); // is virtual + + // Define the scale descriptor + auto scaleDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL); + + // Create a Scale Node + auto scale_op = binary_pw_op_create( + prevBlockOutputTensor, scaleTensor, afterScaleTensor, scaleDesc); + + ops->push_back(std::move(scale_op)); + return afterScaleTensor; +} + +static cudnn_frontend::Tensor createScaleWithOffset( + const cudnn_frontend::Tensor& prevBlockOutputTensor, + const std::string& scale_tensor_name, + cudnnDataType_t tensorType, + bool isOutputVirtual, + bool isScaleByValue, + std::vector* ops, + std::shared_ptr offsetTensor, + const std::string& output_tensor_name ="") { + int64_t scale_dim[4] = {1, 1, 1, 1}; + int64_t scale_stride[4] = {1, 1, 1, 1}; + + int64_t output_dim[4]; + int64_t output_stride[4]; + // If output tensor is dQ, dK, or dV, we need to generate QKV interleaved strides + if (output_tensor_name == "dQ" || output_tensor_name == "dK" || output_tensor_name == "dV") { + for (int i = 0; i < 4; i++) { + output_dim[i] = prevBlockOutputTensor.getDim()[i]; + } + generateMatrixStrides(output_dim[0], output_dim[1], output_dim[2], + 0 /*s_kv = 0 for placeholder*/, + output_dim[3], output_stride, + NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED, NVTE_QKV_Matrix::NVTE_Q_Matrix); + } else { + // Otherwise output dim and stride should be the same as prev block dim and stride + for (int i = 0; i < 4; i++) { + output_dim[i] = prevBlockOutputTensor.getDim()[i]; + output_stride[i] = prevBlockOutputTensor.getStride()[i]; + } + } + + auto scaleTensor = tensor_create( + CUDNN_DATA_FLOAT, tensor_name_to_uid[scale_tensor_name], + scale_dim, scale_stride, false, isScaleByValue); // is by value + + cudnnDataType_t outputDataType = isOutputVirtual ? CUDNN_DATA_FLOAT : tensorType; + int64_t outputUID = isOutputVirtual ? + tensor_name_to_uid["VIRTUAL"] + tensor_name_to_uid[scale_tensor_name] + 7000 : + tensor_name_to_uid[output_tensor_name]; + auto afterScaleTensor = tensor_create_with_offset( + outputDataType, outputUID, output_dim, + output_stride, isOutputVirtual, false, offsetTensor); // is virtual + + // Define the scale descriptor + auto scaleDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL); + + // Create a Scale Node + auto scale_op = binary_pw_op_create( + prevBlockOutputTensor, scaleTensor, afterScaleTensor, scaleDesc); + + ops->push_back(std::move(scale_op)); + return afterScaleTensor; +} + +static cudnn_frontend::Tensor createSoftmaxForward( + int64_t b, int64_t h, int64_t s_q, int64_t s_kv, + std::vector* ops, + const cudnn_frontend::Tensor& prevBlockOutputTensor, + bool isTraining) { + int64_t afterBMM1_dim[4] = {b, h, s_q, s_kv}; + int64_t afterBMM1_stride[4] = {h * s_q * s_kv, s_q * s_kv, s_kv, 1}; + + int64_t afterReduction_dim[4] = {b, h, s_q, 1}; + int64_t afterReduction_stride[4] = {h * s_q, s_q, 1, 1}; + + // max (x) (M tensor) + auto afterMaxReductionTensor = tensor_create( + CUDNN_DATA_FLOAT, tensor_name_to_uid["M"], + afterReduction_dim, afterReduction_stride, + !isTraining, false); // not virtual if training is true, + // virtual if training is false + // x - max(x) + auto afterSubtractionTensor = tensor_create( + CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 151, + afterBMM1_dim, afterBMM1_stride, true, false); // is virtual + // e^(x - max(x)) + auto afterExponentTensor = tensor_create( + CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 152, + afterBMM1_dim, afterBMM1_stride, true, false); // is virtual; + // sum (e^(x - max(x))) (Z tensor) + auto zTensor = tensor_create( + CUDNN_DATA_FLOAT, tensor_name_to_uid["Z"], + afterReduction_dim, afterReduction_stride, true, false); // is virtual + // 1 / sum (e^(x - max(x))) (Z_INV tensor) + auto zInvTensor = tensor_create( + CUDNN_DATA_FLOAT, tensor_name_to_uid["Z_INV"], + afterReduction_dim, afterReduction_stride, + !isTraining, false); // not virtual if training is true, + // virtual if training is false + // Final softmax output (After exponent * Z_INV) + auto beforeDropoutTensor = tensor_create( + CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 153, + afterBMM1_dim, afterBMM1_stride, true, false); // is virtual + + // Define the reduction descriptor + auto reductionMaxDesc = cudnn_frontend::ReductionDescBuilder() + .setComputeType(CUDNN_DATA_FLOAT) + .setReductionOp(CUDNN_REDUCE_TENSOR_MAX) + .build(); + + // Create a reduction max Node + auto reductionMax_op = cudnn_frontend::OperationBuilder( + CUDNN_BACKEND_OPERATION_REDUCTION_DESCRIPTOR) + .setxDesc(prevBlockOutputTensor) + .setyDesc(afterMaxReductionTensor) + .setreductionDesc(reductionMaxDesc) + .build(); + + // Define the subtract descriptor + auto subtractDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_SUB); + + // Create a subtract Node + auto subtract_op = binary_pw_op_create( + prevBlockOutputTensor, afterMaxReductionTensor, + afterSubtractionTensor, subtractDesc); + + // Define the exponent descriptor + auto exponentDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_EXP); + + // Create a exponent Node + auto exponent_op = unary_pw_op_create( + afterSubtractionTensor, afterExponentTensor, exponentDesc); + + // Define the reduction descriptor + auto reductionAddDesc = cudnn_frontend::ReductionDescBuilder() + .setComputeType(CUDNN_DATA_FLOAT) + .setReductionOp(CUDNN_REDUCE_TENSOR_ADD) + .build(); + + // Create a reduction add Node + auto reductionAdd_op = cudnn_frontend::OperationBuilder( + CUDNN_BACKEND_OPERATION_REDUCTION_DESCRIPTOR) + .setxDesc(afterExponentTensor) + .setyDesc(zTensor) + .setreductionDesc(reductionAddDesc) + .build(); + + // Define the reciprocal descriptor + auto reciprocalDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_RECIPROCAL); + + // Create a reciprocal Node + auto reciprocal_op = unary_pw_op_create(zTensor, zInvTensor, reciprocalDesc); + + // Define the pw multiply descriptor + auto multiplyDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL); + + // Create a multiply Node + auto mutliply_op = binary_pw_op_create( + afterExponentTensor, zInvTensor, beforeDropoutTensor, multiplyDesc); + + ops->push_back(std::move(reductionMax_op)); + ops->push_back(std::move(subtract_op)); + ops->push_back(std::move(exponent_op)); + ops->push_back(std::move(reductionAdd_op)); + ops->push_back(std::move(reciprocal_op)); + ops->push_back(std::move(mutliply_op)); + + return beforeDropoutTensor; +} + +static cudnn_frontend::Tensor createDropoutForward( + int64_t b, int64_t h, int64_t s_q, int64_t s_kv, + double probability, + std::vector* ops, + const cudnn_frontend::Tensor& beforeDropoutTensor) { + cudnn_frontend::throw_if(ops->size() == 0, + "Dropout DAG constructed incorrectly as the first one", + CUDNN_STATUS_BAD_PARAM); + + int64_t afterBMM1_dim[4] = {b, h, s_q, s_kv}; + int64_t afterBMM1_stride[4] = {h * s_q * s_kv, s_q * s_kv, s_kv, 1}; + + int64_t scale_dim[4] = {1, 1, 1, 1}; + int64_t scale_stride[4] = {1, 1, 1, 1}; + + // Mask for the dropout + auto dropoutMaskTensor = tensor_create( + CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 250, + afterBMM1_dim, afterBMM1_stride, true, false); // is virtual + auto dropoutSeedTensor = tensor_create( + CUDNN_DATA_INT64, tensor_name_to_uid["DROPOUT_SEED"], + scale_dim, scale_stride, false, false); // is by value + auto dropoutOffsetTensor = tensor_create( + CUDNN_DATA_INT64, tensor_name_to_uid["DROPOUT_OFFSET"], + scale_dim, scale_stride, false, false); // is by value + + // After dropout tensor befor scale + auto beforeDropoutScaleTensor = cudnn_frontend::TensorBuilder() + .setDim(4, afterBMM1_dim) + .setStride(4, afterBMM1_stride) + .setId(tensor_name_to_uid["VIRTUAL"] + 201) + .setAlignment(16) // 16B alignment is needed to run a tensor core engine + .setDataType(CUDNN_DATA_FLOAT) + .setVirtual(true) + .setByValue(false) + .setReorderType(cudnn_frontend::cudnnBackendTensorReordering_t:: + CUDNN_TENSOR_REORDERING_F16x16) + .build(); + // Scale after dropout + auto scaleDropoutTensor = tensor_create( + CUDNN_DATA_FLOAT, tensor_name_to_uid["DROPOUT_SCALE"], + scale_dim, scale_stride, false, true); // is by value + // After Scale + auto afterDropout_before_quan_S = tensor_create( + CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 202, + afterBMM1_dim, afterBMM1_stride, true, false); // is virtual + + // Define the reduction descriptor + auto rngDesc = cudnn_frontend::RngDescBuilder() + .setRngDistribution(CUDNN_RNG_DISTRIBUTION_BERNOULLI) + .setBernoulliDistProbability(1.0 - probability) + .build(); + + // Create a rng Node + auto rng_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_RNG_DESCRIPTOR) + .setyDesc(dropoutMaskTensor) + .setSeedDesc(dropoutSeedTensor) + .setOffsetDesc(dropoutOffsetTensor) + .setRngDesc(rngDesc) + .build(); + + + // Define the multiply mask descriptor + auto maskMulDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL); + + // Create a multiply mask Node + auto maskMul_op = binary_pw_op_create( + beforeDropoutTensor, dropoutMaskTensor, + beforeDropoutScaleTensor, maskMulDesc); + + // Define the multiply scale descriptor + auto scaleMulDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL); + + // Create a multiply mask Node + auto scaleMul_op = binary_pw_op_create( + beforeDropoutScaleTensor, scaleDropoutTensor, + afterDropout_before_quan_S, scaleMulDesc); + + ops->push_back(std::move(rng_op)); + ops->push_back(std::move(maskMul_op)); + ops->push_back(std::move(scaleMul_op)); + + return afterDropout_before_quan_S; +} + +static cudnn_frontend::Tensor createDropoutBackward( + int64_t b, int64_t h, int64_t s_q, int64_t s_kv, + double probability, + std::vector* ops, + const cudnn_frontend::Tensor& beforeDropoutTensor, + const cudnn_frontend::Tensor& dropoutMaskTensor) { + cudnn_frontend::throw_if(ops->size() == 0, + "Dropout DAG constructed incorrectly as the first one", + CUDNN_STATUS_BAD_PARAM); + + int64_t afterBMM1_dim[4] = {b, h, s_q, s_kv}; + int64_t afterBMM1_stride[4] = {h * s_q * s_kv, s_q * s_kv, s_kv, 1}; + + int64_t scale_dim[4] = {1, 1, 1, 1}; + int64_t scale_stride[4] = {1, 1, 1, 1}; + + auto dropoutSeedTensor = tensor_create( + CUDNN_DATA_INT64, tensor_name_to_uid["DROPOUT_SEED"], + scale_dim, scale_stride, false, false); // is by value + auto dropoutOffsetTensor = tensor_create( + CUDNN_DATA_INT64, tensor_name_to_uid["DROPOUT_OFFSET"], + scale_dim, scale_stride, false, false); // is by value + + // After dropout tensor befor scale + auto beforeDropoutScaleTensor = cudnn_frontend::TensorBuilder() + .setDim(4, afterBMM1_dim) + .setStride(4, afterBMM1_stride) + .setId(tensor_name_to_uid["VIRTUAL"] + 201) + .setAlignment(16) // 16B alignment is needed to run a tensor core engine + .setDataType(CUDNN_DATA_FLOAT) + .setVirtual(true) + .setByValue(false) + .setReorderType(cudnn_frontend::cudnnBackendTensorReordering_t:: + CUDNN_TENSOR_REORDERING_F16x16) + .build(); + // Scale after dropout (1 / (1 - p)) + auto scaleDropoutTensor = tensor_create( + CUDNN_DATA_FLOAT, tensor_name_to_uid["DROPOUT_SCALE"], + scale_dim, scale_stride, false, true); // is by value + // After Scale + auto afterDropout_before_quan_S = tensor_create( + CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 202, + afterBMM1_dim, afterBMM1_stride, true, false); // is virtual + + // Define the reduction descriptor + auto rngDesc = cudnn_frontend::RngDescBuilder() + .setRngDistribution(CUDNN_RNG_DISTRIBUTION_BERNOULLI) + .setBernoulliDistProbability(1.0 - probability) + .build(); + + // Create a rng Node + auto rng_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_RNG_DESCRIPTOR) + .setyDesc(dropoutMaskTensor) + .setSeedDesc(dropoutSeedTensor) + .setOffsetDesc(dropoutOffsetTensor) + .setRngDesc(rngDesc) + .build(); + + // Define the multiply mask descriptor + auto maskMulDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL); + + // Create a multiply mask Node + auto maskMul_op = binary_pw_op_create( + beforeDropoutTensor, dropoutMaskTensor, + beforeDropoutScaleTensor, maskMulDesc); + + // Define the multiply scale descriptor + auto scaleMulDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL); + + // Create a multiply mask Node + auto scaleMul_op = binary_pw_op_create( + beforeDropoutScaleTensor, scaleDropoutTensor, + afterDropout_before_quan_S, scaleMulDesc); + + ops->push_back(std::move(rng_op)); + ops->push_back(std::move(maskMul_op)); + ops->push_back(std::move(scaleMul_op)); + + return afterDropout_before_quan_S; +} + +static cudnn_frontend::Tensor createSoftmaxBackward( + int64_t b, int64_t h, int64_t s_q, int64_t s_kv, + std::vector* ops, + const cudnn_frontend::Tensor& dyTensor) { + cudnn_frontend::throw_if(ops->size() == 0, + "Softmax backward constructed incorrectly as the first one", + CUDNN_STATUS_BAD_PARAM); + + int64_t dx_dim[4] = {b, h, s_q, s_kv}; + int64_t dx_stride[4] = {h * s_q * s_kv, s_q * s_kv, s_kv, 1}; + + int64_t M_Z_dim[4] = {b, h, s_q, 1}; + int64_t M_Z_stride[4] = {h * s_q, s_q, 1, 1}; + + // Creating all tensors + auto MTensor = tensor_create( + CUDNN_DATA_FLOAT, tensor_name_to_uid["M"], + M_Z_dim, M_Z_stride, false, false); // not virtual + auto ZInvTensor = tensor_create( + CUDNN_DATA_FLOAT, tensor_name_to_uid["Z_INV"], + M_Z_dim, M_Z_stride, false, false); // not virtual + auto dxAfterSubtractionTensor = tensor_create( + CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 252, + dx_dim, dx_stride, true, false); // is virtual + auto dxAfterExponentiation = tensor_create( + CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 253, + dx_dim, dx_stride, true, false); // is virtual + auto dxBeforeDropout_QKt_Tensor = tensor_create( + CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 254, + dx_dim, dx_stride, true, false); // is virtual + + // Creating all ops + // sub (dy - M) + auto subtractionDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_SUB); + auto subtractionOp = binary_pw_op_create( + dyTensor, MTensor, dxAfterSubtractionTensor, subtractionDesc); + + // Define the exponent descriptor + auto exponentDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_EXP); + + // Create a exponent Node. (exp(dy - M)) + auto exponentOp = unary_pw_op_create( + dxAfterSubtractionTensor, dxAfterExponentiation, exponentDesc); + + // Define the pw multiply descriptor + auto multiplyDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL); + + // Create a multiply Node + auto mutliplyOp = binary_pw_op_create( + dxAfterExponentiation, ZInvTensor, dxBeforeDropout_QKt_Tensor, multiplyDesc); + + ops->push_back(std::move(subtractionOp)); + ops->push_back(std::move(exponentOp)); + ops->push_back(std::move(mutliplyOp)); + + return dxBeforeDropout_QKt_Tensor; +} + +static cudnn_frontend::Tensor createQKBMM( + int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d, + NVTE_QKV_Layout layout, + cudnnDataType_t tensorType, + std::vector* ops, + const cudnn_frontend::Tensor &qTensor, + const cudnn_frontend::Tensor &kTensor, + const cudnn_frontend::Tensor &mnkOverride, + std::shared_ptr QKVRaggedOffsetTensor) { + // Creates the necessary tensor descriptors + int64_t k_transpose_dim[4] = {b, h, d, s_kv}; + int64_t k_transpose_stride[4]; + generateMatrixStrides( + b, h, s_q, s_kv, d, + k_transpose_stride, layout, NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose); + + int64_t s_dim[4] = {b, h, s_q, s_kv}; + int64_t s_stride[4]; + generateMatrixStrides(b, h, s_q, s_kv, d, s_stride, layout, NVTE_QKV_Matrix::NVTE_S_Matrix); + + auto kTransposeTensor = tensor_create_with_offset( + tensorType, tensor_name_to_uid["K_TRANSPOSE"], + k_transpose_dim, k_transpose_stride, + false, false, QKVRaggedOffsetTensor); // is virtual + + // First GEMM output + auto afterQKTensor = tensor_create( + CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 1, + s_dim, s_stride, true, false); // is virtual + + // Define the matmul desc + auto matmulDesc = cudnn_frontend::MatMulDescBuilder() + .setComputeType(CUDNN_DATA_FLOAT) + .setPaddingValue(-2000000) + .build(); + + // Create reshape node for K -> K.T + auto reshape_op = cudnn_frontend::OperationBuilder( + CUDNN_BACKEND_OPERATION_RESHAPE_DESCRIPTOR) + .setxDesc(kTensor) + .setyDesc(kTransposeTensor) + .build(); + + // Create a matmul Node + auto matmulOp = cudnn_frontend::OperationBuilder( + CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR) + .setaMatDesc(qTensor) + .setbMatDesc(kTransposeTensor) + .setcMatDesc(afterQKTensor) + .setmOverrideDesc(mnkOverride) + .setnOverrideDesc(mnkOverride) + .setmatmulDesc(matmulDesc) + .build(); + + ops->push_back(std::move(reshape_op)); + ops->push_back(std::move(matmulOp)); + + return afterQKTensor; +} + +static cudnn_frontend::Tensor createSVBMM( + int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d, + NVTE_QKV_Layout layout, + cudnnDataType_t tensorType, + std::vector* ops, + const cudnn_frontend::Tensor &softmaxTensor, + const cudnn_frontend::Tensor &mnkOverride, + std::shared_ptr QKVRaggedOffsetTensor) { + cudnn_frontend::throw_if(ops->size() == 0, + "BMM2 op constructed incorrectly as the first one", + CUDNN_STATUS_BAD_PARAM); + + int64_t v_dim[4] = {b, h, s_kv, d}; + int64_t v_stride[4]; + generateMatrixStrides(b, h, s_q, s_kv, d, v_stride, layout, NVTE_QKV_Matrix::NVTE_V_Matrix); + + int64_t o_dim[4] = {b, h, s_q, d}; + int64_t o_stride[4]; + generateMatrixStrides(b, h, s_q, s_kv, d, o_stride, layout, NVTE_QKV_Matrix::NVTE_O_Matrix); + + auto vTensor = tensor_create_with_offset( + tensorType, tensor_name_to_uid["V"], + v_dim, v_stride, false, false, QKVRaggedOffsetTensor); + // Second fprop GEMM output + auto oTensor = tensor_create( + tensorType, tensor_name_to_uid["VIRTUAL"] + 300, + o_dim, o_stride, true, false); // is virtual + + // Define the matmul desc + auto matmulDesc = cudnn_frontend::MatMulDescBuilder() + .setComputeType(CUDNN_DATA_FLOAT) + .build(); + + // Create a matmul Node + auto matmulOp = cudnn_frontend::OperationBuilder( + CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR) + .setaMatDesc(softmaxTensor) + .setbMatDesc(vTensor) + .setcMatDesc(oTensor) + .setmOverrideDesc(mnkOverride) + .setkOverrideDesc(mnkOverride) + .setmatmulDesc(matmulDesc) + .build(); + + ops->push_back(std::move(matmulOp)); + + return oTensor; +} + +static cudnn_frontend::Tensor createSdOBMM( + int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d, + cudnnDataType_t tensorType, + std::vector* ops, + const cudnn_frontend::Tensor &softmaxTensor, + const cudnn_frontend::Tensor &dOTensor, + const cudnn_frontend::Tensor &mnkOverride) { + cudnn_frontend::throw_if(ops->size() == 0, + "BMM2 op constructed incorrectly as the first one", + CUDNN_STATUS_BAD_PARAM); + + int64_t s_dim_transpose[4] = {b, h, s_kv, s_q}; + int64_t s_stride_transpose[4] = {h * s_kv * s_q, s_kv * s_q, 1, s_kv}; + + int64_t v_dim[4] = {b, h, s_kv, d}; + int64_t v_stride[4] = {h * s_kv * d, d, h * d, 1}; + + auto sTransposeTensor = tensor_create( + tensorType, tensor_name_to_uid["VIRTUAL"] + 499, + s_dim_transpose, s_stride_transpose, + true, false); // is virtual + // S.T * dO + auto dVTensor_before_dequan_S = tensor_create( + CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 500, + v_dim, v_stride, + true, false); // is virtual + + // Create reshape node for softmax -> softmax.T + auto reshape_op = cudnn_frontend::OperationBuilder( + CUDNN_BACKEND_OPERATION_RESHAPE_DESCRIPTOR) + .setxDesc(softmaxTensor) + .setyDesc(sTransposeTensor) + .build(); + + // Define the matmul desc + auto matmulDesc = cudnn_frontend::MatMulDescBuilder() + .setComputeType(CUDNN_DATA_FLOAT) + .setPaddingValue(0) + .build(); + + // Create a matmul Node + auto matmulOp = cudnn_frontend::OperationBuilder( + CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR) + .setaMatDesc(sTransposeTensor) + .setbMatDesc(dOTensor) + .setcMatDesc(dVTensor_before_dequan_S) + .setmOverrideDesc(mnkOverride) + .setkOverrideDesc(mnkOverride) + .setmatmulDesc(matmulDesc) + .build(); + + ops->push_back(std::move(reshape_op)); + ops->push_back(std::move(matmulOp)); + + return dVTensor_before_dequan_S; +} + +static cudnn_frontend::Tensor createdOVBMM( + int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d, + NVTE_QKV_Layout layout, + cudnnDataType_t tensorType, + std::vector* ops, + const cudnn_frontend::Tensor &dOTensor, + const cudnn_frontend::Tensor &mnkOverride, + std::shared_ptr QKVRaggedOffsetTensor) { + // Creates the necessary tensor descriptors + int64_t v_dim[4] = {b, h, s_kv, d}; + int64_t v_stride[4]; + generateMatrixStrides(b, h, s_q, s_kv, d, v_stride, layout, NVTE_QKV_Matrix::NVTE_V_Matrix); + + int64_t v_transpose_dim[4] = {b, h, d, s_kv}; + int64_t v_transpose_stride[4]; + v_transpose_stride[0] = v_stride[0]; + v_transpose_stride[1] = v_stride[1]; + v_transpose_stride[2] = v_stride[3]; + v_transpose_stride[3] = v_stride[2]; + + int64_t s_dim[4] = {b, h, s_q, s_kv}; + int64_t s_stride[4]; + generateMatrixStrides(b, h, s_q, s_kv, d, s_stride, layout, NVTE_QKV_Matrix::NVTE_S_Matrix); + + auto vTensor = tensor_create_with_offset( + tensorType, tensor_name_to_uid["V"], + v_dim, v_stride, + false, false, QKVRaggedOffsetTensor); + auto vTransposeTensor = tensor_create_with_offset( + tensorType, tensor_name_to_uid["V_TRANSPOSE"], + v_transpose_dim, v_transpose_stride, + false, false, QKVRaggedOffsetTensor); // is virtual + + // dO * V.T + auto afterdOVTensor = tensor_create( + CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 600, + s_dim, s_stride, true, false); // is virtual + + // Define the matmul desc + auto matmulDesc = cudnn_frontend::MatMulDescBuilder() + .setComputeType(CUDNN_DATA_FLOAT) + .setPaddingValue(0) + .build(); + + // Create reshape node for V -> V.T + auto reshape_op = cudnn_frontend::OperationBuilder( + CUDNN_BACKEND_OPERATION_RESHAPE_DESCRIPTOR) + .setxDesc(vTensor) + .setyDesc(vTransposeTensor) + .build(); + + // Create a matmul Node + auto matmulOp = cudnn_frontend::OperationBuilder( + CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR) + .setaMatDesc(dOTensor) + .setbMatDesc(vTransposeTensor) + .setcMatDesc(afterdOVTensor) + .setmOverrideDesc(mnkOverride) + .setnOverrideDesc(mnkOverride) + .setmatmulDesc(matmulDesc) + .build(); + + ops->push_back(std::move(reshape_op)); + ops->push_back(std::move(matmulOp)); + + return afterdOVTensor; +} + +static cudnn_frontend::Tensor createdOAndORowReductionChain( + int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d, + NVTE_QKV_Layout layout, + std::vector* ops, + const cudnn_frontend::Tensor &O_after_dequan, + const cudnn_frontend::Tensor &dO_after_dequan, + const cudnn_frontend::Tensor &dropoutScale_dOVt_OdO_Tensor) { + int64_t o_dim[4] = {b, h, s_q, d}; + int64_t o_stride[4]; + generateMatrixStrides(b, h, s_q, s_kv, d, o_stride, layout, NVTE_QKV_Matrix::NVTE_O_Matrix); + int64_t o_dim_row_sum[4] = {b, h, s_q, 1}; + int64_t o_dim_row_sum_stride[4] = {s_q * h, s_q, 1, 1}; + + auto O_dO_after_pointwise_multiply = tensor_create( + CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 700, + o_dim, o_stride, true, false); // is virtual + auto O_dO_after_dropout_scale = tensor_create( + CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 701, + o_dim, o_stride, true, false); // is virtual + auto O_dO_after_rowsum = tensor_create( + CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 702, + o_dim_row_sum, o_dim_row_sum_stride, true, false); // is virtual + + // Define the pw multiply descriptor + auto multiplyDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL); + + // Create a multiply Node + auto mutliply_op = binary_pw_op_create( + O_after_dequan, dO_after_dequan, + O_dO_after_pointwise_multiply, multiplyDesc); + + // Create multiply node with dropout scale + auto dropout_scale_multiply_op = binary_pw_op_create( + O_dO_after_pointwise_multiply, dropoutScale_dOVt_OdO_Tensor, + O_dO_after_dropout_scale, multiplyDesc); + + // Define the reduction descriptor + auto reductionAddDesc = cudnn_frontend::ReductionDescBuilder() + .setComputeType(CUDNN_DATA_FLOAT) + .setReductionOp(CUDNN_REDUCE_TENSOR_ADD) + .build(); + + // Create a reduction add Node + auto reductionAdd_op = cudnn_frontend::OperationBuilder( + CUDNN_BACKEND_OPERATION_REDUCTION_DESCRIPTOR) + .setxDesc(O_dO_after_dropout_scale) + .setyDesc(O_dO_after_rowsum) + .setreductionDesc(reductionAddDesc) + .build(); + + ops->push_back(std::move(mutliply_op)); + ops->push_back(std::move(dropout_scale_multiply_op)); + ops->push_back(std::move(reductionAdd_op)); + + return O_dO_after_rowsum; +} + +static cudnn_frontend::Tensor createBiasSubtractionSoftmaxMulChain( + int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d, + NVTE_QKV_Layout layout, + std::vector* ops, + const cudnn_frontend::Tensor &dS_after_dropout, + const cudnn_frontend::Tensor &AfterDropout_before_quan_S, + const cudnn_frontend::Tensor &O_dO_after_rowsum, + const cudnn_frontend::Tensor &attnScale) { + int64_t o_dim[4] = {b, h, s_q, s_kv}; + int64_t o_stride[4]; + generateMatrixStrides(b, h, s_q, s_kv, d, o_stride, layout, NVTE_QKV_Matrix::NVTE_S_Matrix); + auto dS_minus_O_dO = tensor_create( + CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 800, + o_dim, o_stride, true, false); // is virtual + auto AfterAttnScale_before_dS = tensor_create( + CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 801, + o_dim, o_stride, true, false); // is virtual + auto S_mul_dS_minus_O_dO = tensor_create( + CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 802, + o_dim, o_stride, true, false); // is virtual + + // Define the pw subtraction descriptor + auto subDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_SUB); + + // Create a subtraction Node + auto sub_op = binary_pw_op_create( + dS_after_dropout, O_dO_after_rowsum, dS_minus_O_dO, subDesc); + + // Define the pw multiplication descriptor + auto multiplyDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL); + + // dS_minus_O_dO * attnScale + auto mutliply_attn_scale_op = binary_pw_op_create( + dS_minus_O_dO, attnScale, + AfterAttnScale_before_dS, multiplyDesc); + + // AfterDropout_before_quan_S * AfterAttnScale_before_dS + auto mutliply_op = binary_pw_op_create( + AfterDropout_before_quan_S, AfterAttnScale_before_dS, + S_mul_dS_minus_O_dO, multiplyDesc); + + ops->push_back(std::move(sub_op)); + ops->push_back(std::move(mutliply_attn_scale_op)); + ops->push_back(std::move(mutliply_op)); + + return S_mul_dS_minus_O_dO; +} + +static cudnn_frontend::Tensor createdSKBMM( + int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d, + std::vector* ops, + const cudnn_frontend::Tensor &dSTensor, + const cudnn_frontend::Tensor &kTensor, + const cudnn_frontend::Tensor &mnkOverride) { + // Creates the necessary tensor descriptors + int64_t after_dSK_dim[4] = {b, h, s_kv, d}; + int64_t after_dSK_stride[4] = {h * s_kv * d, d, h * d, 1}; + // dS * K + auto After_dS_K = tensor_create( + CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 875, + after_dSK_dim, after_dSK_stride, true, false); // is virtual + + // Define the matmul desc + auto matmulDesc = cudnn_frontend::MatMulDescBuilder() + .setComputeType(CUDNN_DATA_FLOAT) + .setPaddingValue(0) + .build(); + + // Create a matmul Node + auto matmulOp = cudnn_frontend::OperationBuilder( + CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR) + .setaMatDesc(dSTensor) + .setbMatDesc(kTensor) + .setcMatDesc(After_dS_K) + .setmOverrideDesc(mnkOverride) + .setkOverrideDesc(mnkOverride) + .setmatmulDesc(matmulDesc) + .build(); + + ops->push_back(std::move(matmulOp)); + + return After_dS_K; +} + +static cudnn_frontend::Tensor createdSQBMM( + int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d, + NVTE_QKV_Layout layout, + std::vector* ops, + const cudnn_frontend::Tensor &dSTensor, + const cudnn_frontend::Tensor &qTensor, + const cudnn_frontend::Tensor &mnkOverride) { + // Creates the necessary tensor descriptors + int64_t dS_stride[4]; + generateMatrixStrides(b, h, s_q, s_kv, d, dS_stride, layout, NVTE_QKV_Matrix::NVTE_S_Matrix); + + int64_t dS_transpose_dim[4] = {b, h, s_kv, s_q}; + int64_t dS_transpose_stride[4]; + dS_transpose_stride[0] = dS_stride[0]; + dS_transpose_stride[1] = dS_stride[1]; + dS_transpose_stride[2] = dS_stride[3]; + dS_transpose_stride[3] = dS_stride[2]; + + int64_t after_dSTranspose_Q_dim[4] = {b, h, s_kv, d}; + int64_t after_dSTranspose_Q_stride[4] = {h * s_kv * d, d, h * d, 1}; + + auto dSTransposeTensor = tensor_create( + CUDNN_DATA_FP8_E5M2, tensor_name_to_uid["VIRTUAL"] + 650, + dS_transpose_dim, dS_transpose_stride, true, false); // is virtual + + // dS.T * Q + auto After_dSTranspose_Q = tensor_create( + CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 651, + after_dSTranspose_Q_dim, after_dSTranspose_Q_stride, + true, false); // is virtual + + // Create reshape node for V -> V.T + auto reshape_op = cudnn_frontend::OperationBuilder( + CUDNN_BACKEND_OPERATION_RESHAPE_DESCRIPTOR) + .setxDesc(dSTensor) + .setyDesc(dSTransposeTensor) + .build(); + + // Define the matmul desc + auto matmulDesc = cudnn_frontend::MatMulDescBuilder() + .setComputeType(CUDNN_DATA_FLOAT) + .setPaddingValue(0) + .build(); + + // Create a matmul Node + auto matmulOp = cudnn_frontend::OperationBuilder( + CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR) + .setaMatDesc(dSTransposeTensor) + .setbMatDesc(qTensor) + .setcMatDesc(After_dSTranspose_Q) + .setmOverrideDesc(mnkOverride) + .setkOverrideDesc(mnkOverride) + .setmatmulDesc(matmulDesc) + .build(); + + ops->push_back(std::move(reshape_op)); + ops->push_back(std::move(matmulOp)); + + return After_dSTranspose_Q; +} + +// fused attention FWD FP8 +void fa_fwd_fp8(int64_t b, int64_t s_q, int64_t s_kv, int64_t h, int64_t d, + bool isTraining, float attnScale, + float dropoutProbability, NVTE_QKV_Layout layout, + void* devPtrQ, void* devPtrK, void* devPtrV, + void* devPtrM, void* devPtrZInv, + void* devPtrO, + void* devPtrDescaleQ, void* devPtrDescaleK, void* devPtrDescaleV, + void* devPtrDescaleS, void* devPtrScaleS, void* devPtrScaleO, + void* devPtrAmaxO, void* devPtrAmaxS, + void* devPtrcuSeqlensQ, void* devPtrcuSeqlensKV, + void* devPtrDropoutSeed, void* devPtrDropoutOffset, + cudnnDataType_t tensorType, + void* workspace_ptr, + size_t* workspace_size, + cudaStream_t stream, + cudnnHandle_t handle_) { + try { + NVTE_CHECK_CUDNN(cudnnSetStream(handle_, stream)); + + FADescriptor descriptor{ + b, h, s_q, s_kv, d, + attnScale, isTraining, dropoutProbability, layout, tensorType}; + + using CacheType = std::map; + static CacheType fa_fprop_cache; + + // Get plan from cache if cache is available, otherwise create one + auto get_plan = [&](CacheType &cache, const FADescriptor &descriptor) { + // If hit, return + auto it = cache.find(descriptor); + if (it != cache.end()) { + auto plan = it->second; + return plan; + } + + // Otherwise, build the op_graph and the plan. Then update cache + std::vector all_ops; + std::vector ops; + + cudnn_frontend::throw_if(dropoutProbability != 0.0f && !isTraining, + "Dropout probability should be 0.0f for inference mode", + CUDNN_STATUS_BAD_PARAM); + cudnn_frontend::throw_if(dropoutProbability == 1.0f, + "Dropout probability cannot be 1.0", + CUDNN_STATUS_BAD_PARAM); + + int64_t raggedDim[4] = {b + 1, 1, 1, 1}; + int64_t raggedStride[4] = {1, 1, 1, 1}; + // Create offset tensors + auto QKVOffsetTensor = tensor_create( + CUDNN_DATA_INT32, tensor_name_to_uid["QKV_RAGGED"], + raggedDim, raggedStride, false, false); + auto ORaggedOffsetTensor = tensor_create( + CUDNN_DATA_INT32, tensor_name_to_uid["O_RAGGED"], + raggedDim, raggedStride, false, false); + + int64_t seqlen_dim[4] = {b, 1, 1, 1}; + int64_t seqlen_stride[4] = {1, 1, 1, 1}; + // Create override tensors + auto seqlenMNKTensor = tensor_create( + CUDNN_DATA_INT32, tensor_name_to_uid["MNK_OVERRIDE"], + seqlen_dim, seqlen_stride, false, false); + + // Create shared ptrs to ragged offset tensors + // for multiple tensors to use ragged offset + std::shared_ptr QKVRaggedOffsetTensorPtr = + std::make_shared(std::move(QKVOffsetTensor)); + std::shared_ptr ORaggedOffsetTensorPtr = + std::make_shared(std::move(ORaggedOffsetTensor)); + + // Create Q and K tensors that are used in different places + int64_t q_dim[4] = {b, h, s_q, d}; + int64_t q_stride[4]; + generateMatrixStrides(b, h, s_q, s_kv, d, q_stride, layout, + NVTE_QKV_Matrix::NVTE_Q_Matrix); + + int64_t k_dim[4] = {b, h, s_kv, d}; + int64_t k_stride[4]; + generateMatrixStrides(b, h, s_q, s_kv, d, k_stride, layout, + NVTE_QKV_Matrix::NVTE_K_Matrix); + + auto qTensor = tensor_create_with_offset( + tensorType, tensor_name_to_uid["Q"], + q_dim, q_stride, false, false, + QKVRaggedOffsetTensorPtr); + auto kTensor = tensor_create_with_offset( + tensorType, tensor_name_to_uid["K"], + k_dim, k_stride, false, false, + QKVRaggedOffsetTensorPtr); + + // Q * K.T + auto afterQKTensor = createQKBMM( + b, h, s_q, s_kv, d, layout, tensorType, + &ops, qTensor, kTensor, + seqlenMNKTensor, QKVRaggedOffsetTensorPtr); + + // QK.T * attn scale + auto AfterAttnScale_before_dequan_Q_tensor = createScale( + afterQKTensor, // input tensor + "AttnScale", // scale tensor + CUDNN_DATA_FLOAT, // output tensor type + true, // output is virtual + true, // scale is by value + &ops); + + // QK.T * attn scale * dequant_Q + auto AfterAttnScale_before_dequan_K_tensor = createScale( + AfterAttnScale_before_dequan_Q_tensor, // input tensor + "descaleQ", // scale tensor + CUDNN_DATA_FLOAT, // output tensor type + true, // output is virtual + false, // scale is by value + &ops); + + // QK.T * attn scale * dequant_Q * dequant_K + auto AfterAttnScale_tensor = createScale( + AfterAttnScale_before_dequan_K_tensor, // input tensor + "descaleK", // scale tensor + CUDNN_DATA_FLOAT, // output tensor type + true, // output is virtual + false, // scale is by value + &ops); + + auto BeforeDropoutTensor = createSoftmaxForward( + b, h, s_q, s_kv, &ops, + AfterAttnScale_tensor, isTraining); + + auto AfterDropout_before_quan_S = createDropoutForward( + b, h, s_q, s_kv, dropoutProbability, + &ops, BeforeDropoutTensor); + + // Amax for S + createAmax("amaxS", BeforeDropoutTensor, &ops); + + // After softmax * dropout * scale S -> fp8 input to next bmm with V + auto AfterMultiplyDropout = createScale( + AfterDropout_before_quan_S, // input tensor + "scaleS", // scale tensor + tensorType, // output tensor type + true, // output is virtual + false, // scale is by value + &ops); + + // After softmax * Dropout * V + auto OTensor_before_dequan_S_tensor = createSVBMM( + b, h, s_q, s_kv, d, layout, tensorType, + &ops, AfterMultiplyDropout, + seqlenMNKTensor, QKVRaggedOffsetTensorPtr); + + // O * dequant_S + auto OTensor_before_dequan_V_tensor = createScale( + OTensor_before_dequan_S_tensor, // input tensor + "descaleS", // scale tensor + CUDNN_DATA_FLOAT, // output tensor type + true, // output is virtual + false, // scale is by value + &ops); + + // O * dequant_S * dequant_V + auto OTensor_before_quan_O_tensor = createScale( + OTensor_before_dequan_V_tensor, // input tensor + "descaleV", // scale tensor + CUDNN_DATA_FLOAT, // output tensor type + true, // output is virtual + false, // scale is by value + &ops); + + // O * dequant_S * dequant_V * scale O + auto OTensor = createScaleWithOffset( + OTensor_before_quan_O_tensor, // input tensor + "scaleO", // scale tensor + tensorType, // output tensor type + false, // output not virtual + false, // scale is by value + &ops, + ORaggedOffsetTensorPtr, // ragged offset + "O"); + + // Amax for O + createAmax("amaxO", OTensor_before_quan_O_tensor, &ops); + + for (unsigned int i = 0; i < ops.size(); i++) { + all_ops.push_back(&ops[i]); + } + + // Create an Operation Graph + auto opGraph = cudnn_frontend::OperationGraphBuilder() + .setHandle(handle_) + .setOperationGraph(all_ops.size(), all_ops.data()) + .build(); + + cudnn_frontend::EngineConfigList filtered_configs; + auto statuses = cudnn_frontend::get_heuristics_list<1>( + {"heuristics_instant"}, opGraph, + allowAllConfig, filtered_configs, true); + + if (filtered_configs.size() == 0) { + cudnn_frontend::set_error_and_throw_exception( + nullptr, + CUDNN_STATUS_NOT_SUPPORTED, + "run_mha_fprop: No config returned by the heuristics"); + } + + auto plan = cudnn_frontend::ExecutionPlanBuilder() + .setHandle(handle_) + .setEngineConfig(filtered_configs[0], opGraph.getTag()) + .build(); + cache.insert({descriptor, plan}); + return plan; + }; // end of get_plan + + auto plan = get_plan(fa_fprop_cache, descriptor); + size_t wkspace_size = static_cast(plan.getWorkspaceSize()); + + // Exit to request upper level API to allocate memory if needed + if (workspace_ptr == nullptr) { + *workspace_size = wkspace_size + ((b + 1) * 2 + b) * sizeof(int32_t); + return; + } + + int32_t* qkv_ragged_offset = reinterpret_cast( + reinterpret_cast(workspace_ptr) + wkspace_size); + int32_t* o_ragged_offset = reinterpret_cast( + reinterpret_cast(workspace_ptr) + + wkspace_size + (b + 1) * sizeof(int32_t)); + int32_t* actual_seqlens_q = reinterpret_cast( + reinterpret_cast(workspace_ptr) + + wkspace_size + (b + 1) * 2 * sizeof(int32_t)); + // FP8 currently only supports self-attention, so doesn't use devPtrcuSeqlensKV + dim3 blockDims(128); + dim3 gridDims((b + blockDims.x)/blockDims.x); + cu_seqlens_to_offsets<<>>( + b, h, d, reinterpret_cast(devPtrcuSeqlensQ), + actual_seqlens_q, qkv_ragged_offset, o_ragged_offset); + void* devPtrQKVRaggedOffset = reinterpret_cast(qkv_ragged_offset); + void* devPtrORaggedOffset = reinterpret_cast(o_ragged_offset); + void* devPtrMNKOverride = reinterpret_cast(actual_seqlens_q); + + float dropoutScale = 1.0f/(1.0f - dropoutProbability); + + std::set> data_ptrs; + data_ptrs.emplace(std::pair(tensor_name_to_uid["Q"], devPtrQ)); + data_ptrs.emplace(std::pair(tensor_name_to_uid["K"], devPtrK)); + data_ptrs.emplace(std::pair(tensor_name_to_uid["K_TRANSPOSE"], devPtrK)); + data_ptrs.emplace(std::pair(tensor_name_to_uid["V"], devPtrV)); + data_ptrs.emplace(std::pair(tensor_name_to_uid["AttnScale"], &attnScale)); + data_ptrs.emplace(std::pair( + tensor_name_to_uid["DROPOUT_SCALE"], &dropoutScale)); + data_ptrs.emplace(std::pair( + tensor_name_to_uid["DROPOUT_SEED"], devPtrDropoutSeed)); + data_ptrs.emplace(std::pair( + tensor_name_to_uid["DROPOUT_OFFSET"], devPtrDropoutOffset)); + data_ptrs.emplace(std::pair( + tensor_name_to_uid["O"], devPtrO)); + data_ptrs.emplace(std::pair( + tensor_name_to_uid["descaleQ"], devPtrDescaleQ)); + data_ptrs.emplace(std::pair( + tensor_name_to_uid["descaleK"], devPtrDescaleK)); + data_ptrs.emplace(std::pair( + tensor_name_to_uid["descaleV"], devPtrDescaleV)); + data_ptrs.emplace(std::pair( + tensor_name_to_uid["descaleS"], devPtrDescaleS)); + data_ptrs.emplace(std::pair( + tensor_name_to_uid["scaleS"], devPtrScaleS)); + data_ptrs.emplace(std::pair( + tensor_name_to_uid["scaleO"], devPtrScaleO)); + data_ptrs.emplace(std::pair( + tensor_name_to_uid["amaxO"], devPtrAmaxO)); + data_ptrs.emplace(std::pair( + tensor_name_to_uid["amaxS"], devPtrAmaxS)); + data_ptrs.emplace(std::pair( + tensor_name_to_uid["QKV_RAGGED"], devPtrQKVRaggedOffset)); + data_ptrs.emplace(std::pair( + tensor_name_to_uid["O_RAGGED"], devPtrORaggedOffset)); + data_ptrs.emplace(std::pair( + tensor_name_to_uid["MNK_OVERRIDE"], devPtrMNKOverride)); + + // If training, then we need to write out M and Z_INV + if (isTraining) { + data_ptrs.emplace(std::pair( + tensor_name_to_uid["M"], devPtrM)); + data_ptrs.emplace(std::pair( + tensor_name_to_uid["Z_INV"], devPtrZInv)); + } + + auto variantPack = cudnn_frontend::VariantPackBuilder() + .setWorkspacePointer(workspace_ptr) + .setDataPointers(data_ptrs) + .build(); + cudnnStatus_t status = cudnnBackendExecute( + handle_, plan.get_raw_desc(), variantPack.get_raw_desc()); + + cudnn_frontend::throw_if( + [status]() { return (status != CUDNN_STATUS_SUCCESS); }, + "Plan execute error", status); + } catch (cudnn_frontend::cudnnException& e) { + struct cudaDeviceProp prop; + NVTE_CHECK_CUDA(cudaGetDeviceProperties(&prop, 0)); + + // This example is only for GH100 cards (cudnn Version >= 8900) + if (!((prop.major == 9 && prop.minor == 0 && CUDNN_VERSION >= 8900)) + && (e.getCudnnStatus() == CUDNN_STATUS_ARCH_MISMATCH + || e.getCudnnStatus() == CUDNN_STATUS_NOT_SUPPORTED)) { + std::cout << "Example is only supported for GH100 (cuDNN >= 8900) GPUs" << std::endl; + } else { + std::cout << "[ERROR] Exception " << e.what() << std::endl; + } + } +} + +// fused attention BWD FP8 +void fa_bwd_fp8(int64_t b, int64_t s_q, int64_t s_kv, int64_t h, int64_t d, + float attnScale, float dropoutProbability, NVTE_QKV_Layout layout, + void* devPtrQ, void* devPtrK, void* devPtrV, + void* devPtrM, void* devPtrZInv, + void* devPtrO, void* devPtrdO, + void* devPtrdQ, void* devPtrdK, void* devPtrdV, + void* devPtrDescaleQ, void* devPtrDescaleK, void* devPtrDescaleV, + void* devPtrDescaleO, void* devPtrDescaledO, + void* devPtrDescaleS, void* devPtrDescaledS, + void* devPtrScaleS, void* devPtrScaledS, + void* devPtrScaledQ, void* devPtrScaledK, void* devPtrScaledV, + void* devPtrAmaxdS, + void* devPtrAmaxdQ, void* devPtrAmaxdK, void* devPtrAmaxdV, + void* devPtrcuSeqlensQ, void* devPtrcuSeqlensKV, + void* devPtrDropoutSeed, void* devPtrDropoutOffset, + cudnnDataType_t tensorType, + void* workspace_ptr, + size_t* workspace_size, + cudaStream_t stream, + cudnnHandle_t handle_) { + try { + NVTE_CHECK_CUDNN(cudnnSetStream(handle_, stream)); + + FADescriptor descriptor{ + b, h, s_q, s_kv, d, + attnScale, false, dropoutProbability, layout, tensorType}; + + using CacheType = std::map; + static CacheType fa_bprop_cache; + + // Get plan from cache if cache is available, otherwise create one + auto get_plan = [&](CacheType &cache, const FADescriptor &descriptor) { + // If hit, return + auto it = cache.find(descriptor); + if (it != cache.end()) { + auto plan = it->second; + return plan; + } + + // Otherwise, build the op_graph and the plan. Then update cache + std::vector all_ops; + std::vector ops; + + cudnn_frontend::throw_if(dropoutProbability == 1.0f, + "Dropout probability cannot be 1.0", + CUDNN_STATUS_BAD_PARAM); + + int64_t raggedDim[4] = {b + 1, 1, 1, 1}; + int64_t raggedStride[4] = {1, 1, 1, 1}; + // Create offset tensors + auto QKVOffsetTensor = tensor_create( + CUDNN_DATA_INT32, tensor_name_to_uid["QKV_RAGGED"], + raggedDim, raggedStride, false, false); + auto ORaggedOffsetTensor = tensor_create( + CUDNN_DATA_INT32, tensor_name_to_uid["O_RAGGED"], + raggedDim, raggedStride, false, false); + + // Create shared ptrs to ragged offset tensors for multiple tensors + std::shared_ptr QKVRaggedOffsetTensorPtr = + std::make_shared(std::move(QKVOffsetTensor)); + std::shared_ptr ORaggedOffsetTensorPtr = + std::make_shared(std::move(ORaggedOffsetTensor)); + + // Create Q and K tensors that are used in different places + int64_t q_dim[4] = {b, h, s_q, d}; + int64_t q_stride[4]; + generateMatrixStrides(b, h, s_q, s_kv, d, q_stride, layout, + NVTE_QKV_Matrix::NVTE_Q_Matrix); + + int64_t k_dim[4] = {b, h, s_kv, d}; + int64_t k_stride[4]; + generateMatrixStrides(b, h, s_q, s_kv, d, k_stride, layout, + NVTE_QKV_Matrix::NVTE_K_Matrix); + + auto qTensor = tensor_create_with_offset( + tensorType, tensor_name_to_uid["Q"], + q_dim, q_stride, false, false, QKVRaggedOffsetTensorPtr); + auto kTensor = tensor_create_with_offset( + tensorType, tensor_name_to_uid["K"], + k_dim, k_stride, false, false, QKVRaggedOffsetTensorPtr); + + int64_t scale_dim[4] = {1, 1, 1, 1}; + int64_t scale_stride[4] = {1, 1, 1, 1}; + + // Create attnScale tensor for multiple ops to use + auto attnScaleTensor = tensor_create( + CUDNN_DATA_FLOAT, tensor_name_to_uid["AttnScale"], + scale_dim, scale_stride, false, true); // is by value + + // Create descale Q K dO dS global tensors since they are used in multiple places + auto descaleQTensor = tensor_create( + CUDNN_DATA_FLOAT, tensor_name_to_uid["descaleQ"], + scale_dim, scale_stride, false, false); + auto descaleKTensor = tensor_create( + CUDNN_DATA_FLOAT, tensor_name_to_uid["descaleK"], + scale_dim, scale_stride, false, false); + auto descaledOTensor = tensor_create( + CUDNN_DATA_FLOAT, tensor_name_to_uid["descaledO"], + scale_dim, scale_stride, false, false); + auto descaledSTensor = tensor_create( + CUDNN_DATA_FLOAT, tensor_name_to_uid["descaledS"], + scale_dim, scale_stride, false, false); + + int64_t seqlen_dim[4] = {b, 1, 1, 1}; + int64_t seqlen_stride[4] = {1, 1, 1, 1}; + // Create MNK override tensor + auto seqlenMNKTensor = tensor_create( + CUDNN_DATA_INT32, tensor_name_to_uid["MNK_OVERRIDE"], + seqlen_dim, seqlen_stride, false, false); + + int64_t O_dim[4] = {b, h, s_q, d}; + int64_t O_stride[4]; + generateMatrixStrides(b, h, s_q, s_kv, d, O_stride, layout, + NVTE_QKV_Matrix::NVTE_O_Matrix); + // Create O and loss tensor + auto OTensor = tensor_create_with_offset( + tensorType, tensor_name_to_uid["O"], + O_dim, O_stride, false, false, ORaggedOffsetTensorPtr); + // dO is used in multiple places and E5M2 + auto dOTensor = tensor_create_with_offset( + CUDNN_DATA_FP8_E5M2, tensor_name_to_uid["dO"], + O_dim, O_stride, false, false, ORaggedOffsetTensorPtr); + + // Q * K.T + auto afterQKTensor = createQKBMM( + b, h, s_q, s_kv, d, layout, tensorType, + &ops, qTensor, kTensor, + seqlenMNKTensor, QKVRaggedOffsetTensorPtr); + + // QK.T * attn scale + auto AfterAttnScale_before_dequan_Q_tensor = createScale( + afterQKTensor, // input tensor + attnScaleTensor, // scale tensor + CUDNN_DATA_FLOAT, // output tensor type + true, // output is virtual + true, // scale is by value + &ops, + 1999 /*UID offset*/); + + // QK.T * attn scale * dequant_Q + auto AfterAttnScale_before_dequan_K_tensor = createScale( + AfterAttnScale_before_dequan_Q_tensor, // input tensor + descaleQTensor, // scale tensor + CUDNN_DATA_FLOAT, // output tensor type + true, // output is virtual + false, // scale is by value + &ops, + 2000 /*UID offset*/); + + // QK.T * attn scale * dequant_Q * dequant_K + auto AfterAttnScale_tensor = createScale( + AfterAttnScale_before_dequan_K_tensor, // input tensor + descaleKTensor, // scale tensor + CUDNN_DATA_FLOAT, // output tensor type + true, // output is virtual + false, // scale is by value + &ops, + 2001 /*UID offset*/); + + auto beforeDropout_QKt_Tensor = createSoftmaxBackward( + b, h, s_q, s_kv, &ops, AfterAttnScale_tensor); + + int64_t afterBMM1_dim[4] = {b, h, s_q, s_kv}; + int64_t afterBMM1_stride[4] = {h * s_q * s_kv, s_q * s_kv, s_kv, 1}; + + // mask for the dropout. Used in different places + auto dropoutMaskTensor = tensor_create( + CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 200, + afterBMM1_dim, afterBMM1_stride, true, false); // is virtual + + auto AfterDropout_before_quan_S = createDropoutBackward( + b, h, s_q, s_kv, dropoutProbability, + &ops, beforeDropout_QKt_Tensor, dropoutMaskTensor); + + // After softmax * scale S -> fp8 input to next bmm with V + auto AfterMultiply = createScale( + AfterDropout_before_quan_S, // input tensor + "scaleS", // scale tensor + tensorType, // output tensor type + true, // output is virtual + false, // scale is by value + &ops); + + // After softmax * dO + auto dVTensor_before_dequan_S = createSdOBMM( + b, h, s_q, s_kv, d, tensorType, + &ops, AfterMultiply, dOTensor, seqlenMNKTensor); + + // O * dequant_S + auto dVTensor_before_dequan_dO = createScale( + dVTensor_before_dequan_S, // input tensor + "descaleS", // scale tensor + CUDNN_DATA_FLOAT, // output tensor type + true, // output is virtual + false, // scale is by value + &ops); + + // O * dequant_S * dequant_dO + auto dVTensor_before_quan_dV = createScale( + dVTensor_before_dequan_dO, // input tensor + descaledOTensor, // scale tensor + CUDNN_DATA_FLOAT, // output tensor type + true, // output is virtual + false, // scale is by value + &ops, + 2002 /*UID offset*/); + + // O * dequant_S * dequant_dO * scale dV + auto dVTensor = createScaleWithOffset( + dVTensor_before_quan_dV, // input tensor + "scaledV", // scale tensor + CUDNN_DATA_FP8_E5M2, // output tensor type + false, // output not virtual + false, // scale is by value + &ops, + QKVRaggedOffsetTensorPtr, // ragged offset + "dV" /*Output tensor name*/); + + // Amax for dV + createAmax("amaxdV", dVTensor_before_quan_dV, &ops); + + auto dS_before_dequan_dO_Tensor = createdOVBMM( + b, h, s_q, s_kv, d, layout, tensorType, + &ops, dOTensor, seqlenMNKTensor, QKVRaggedOffsetTensorPtr); + + // dS * dequant_dO + auto dS_before_dequan_V = createScale( + dS_before_dequan_dO_Tensor, // input tensor + descaledOTensor, // scale tensor + CUDNN_DATA_FLOAT, // output tensor type + true, // output is virtual + false, // scale is by value + &ops, + 2003 /*UID offset*/); + + // O * dequant_S * dequant_dV + auto dS_after_dequan = createScale( + dS_before_dequan_V, // input tensor + "descaleV", // scale tensor + CUDNN_DATA_FLOAT, // output tensor type + true, // output is virtual + false, // scale is by value + &ops); + + // RNG Multiply + auto beforeDropoutScale_dOVt_Tensor = tensor_create( + CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 350, + afterBMM1_dim, afterBMM1_stride, true, false); // is virtual + // After dropout mask and scale + auto dS_after_dropout = tensor_create( + CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 351, + afterBMM1_dim, afterBMM1_stride, true, false); // is virtual + + // Define the multiply mask descriptor + auto mulDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL); + + // Create a multiply mask Node + auto maskMul_op = binary_pw_op_create( + dS_after_dequan, dropoutMaskTensor, + beforeDropoutScale_dOVt_Tensor, mulDesc); + + ops.push_back(std::move(maskMul_op)); + + // scale after dropout for dO and O chain + auto dropoutScale_dOVt_OdO_Tensor = tensor_create( + tensorType, tensor_name_to_uid["DROPOUT_SCALE_dOVt_OdO"], + scale_dim, scale_stride, false, true); // is by value + + // Create a multiply dropout scale Node + auto mul_dropout_scale_op = binary_pw_op_create( + beforeDropoutScale_dOVt_Tensor, + dropoutScale_dOVt_OdO_Tensor, + dS_after_dropout, mulDesc); + + ops.push_back(std::move(mul_dropout_scale_op)); + + // O * dequant_O + auto O_after_dequan_Tensor = createScale(OTensor, // input tensor + "descaleO", // scale tensor + CUDNN_DATA_FLOAT, // output tensor type + true, // output is virtual + false, // scale is by value + &ops); + + // dO * dequant_dO + auto dO_after_dequan_Tensor = createScale(dOTensor, // input tensor + descaledOTensor, // scale tensor + CUDNN_DATA_FLOAT, // output tensor type + true, // output is virtual + false, // scale is by value + &ops, + 2004 /*UID offset*/); + + // row reduction sum[(dO * dequant_dO) * (O * dequant_O) * (1 - p)] + auto O_dO_after_rowsum = createdOAndORowReductionChain( + b, h, s_q, s_kv, d, layout, + &ops, O_after_dequan_Tensor, + dO_after_dequan_Tensor, dropoutScale_dOVt_OdO_Tensor); + + // (dS_after_dropout - O_dO_after_rowsum) * AfterDropout_before_quan_S * attnScale + auto S_mul_dS_minus_O_dO = createBiasSubtractionSoftmaxMulChain( + b, h, s_q, s_kv, d, layout, + &ops, dS_after_dropout, + AfterDropout_before_quan_S, O_dO_after_rowsum, + attnScaleTensor); + + + // S_mul_dS_minus_O_dO * scaledS + auto S_mul_dS_minus_O_dO_after_quan_dS = createScale( + S_mul_dS_minus_O_dO, // input tensor + "scaledS", // scale tensor + CUDNN_DATA_FP8_E5M2, // output tensor type + true, // output is virtual + false, // scale is by value + &ops); + + // Amax for dS + createAmax("amaxdS", S_mul_dS_minus_O_dO, &ops); + + // dS @ K + auto After_dS_K = createdSKBMM( + b, h, s_q, s_kv, d, &ops, + S_mul_dS_minus_O_dO_after_quan_dS, + kTensor, seqlenMNKTensor); + + // (dS * K) * descale dS + auto After_dS_K_before_dequan_K = createScale( + After_dS_K, // input tensor + descaledSTensor, // scale tensor + CUDNN_DATA_FLOAT, // output tensor type + true, // output is virtual + false, // scale is by value + &ops, + 2006 /*UID offset*/); + + // (dS * K) * descale dS * descale K + auto After_dS_K_before_quan_dQ = createScale( + After_dS_K_before_dequan_K, // input tensor + descaleKTensor, // scale tensor + CUDNN_DATA_FLOAT, // output tensor type + true, // output is virtual + false, // scale is by value + &ops, + 2007 /*UID offset*/); + + // (dS * K) * descale dS * descale K * scale dQ + auto dQ = createScaleWithOffset( + After_dS_K_before_quan_dQ, // input tensor + "scaledQ", // scale tensor + CUDNN_DATA_FP8_E5M2, // output tensor type + false, // output not virtual + false, // scale is by value + &ops, + QKVRaggedOffsetTensorPtr, // ragged offset + "dQ"); + + // Amax for dQ + createAmax("amaxdQ", After_dS_K_before_quan_dQ, &ops); + + // dS.T @ Q + auto After_dSTranspose_Q = createdSQBMM( + b, h, s_q, s_kv, d, layout, &ops, + S_mul_dS_minus_O_dO_after_quan_dS, + qTensor, seqlenMNKTensor); + + // (dS.T * Q) * descale dS + auto After_dSTranspose_Q_before_dequan_Q = createScale( + After_dSTranspose_Q, // input tensor + descaledSTensor, // scale tensor + CUDNN_DATA_FLOAT, // output tensor type + true, // output is virtual + false, // scale is by value + &ops, + 2009 /*UID offset*/); + + // (dS.T * Q) * descale dS * descale Q + auto After_dSTranspose_Q_before_quan_dK = createScale( + After_dSTranspose_Q_before_dequan_Q, // input tensor + descaleQTensor, // scale tensor + CUDNN_DATA_FLOAT, // output tensor type + true, // output is virtual + false, // scale is by value + &ops, + 2010 /*UID offset*/); + + // (dS.T * Q) * descale dS * descale Q * scale dK + auto dK = createScaleWithOffset( + After_dSTranspose_Q_before_quan_dK, // input tensor + "scaledK", // scale tensor + CUDNN_DATA_FP8_E5M2, // output tensor type + false, // output not virtual + false, // scale is by value + &ops, + QKVRaggedOffsetTensorPtr, // ragged offset + "dK"); + + // Amax for dK + createAmax("amaxdK", After_dSTranspose_Q_before_quan_dK, &ops); + + for (unsigned int i = 0; i < ops.size(); i++) { + all_ops.push_back(&ops[i]); + } + + // Create an Operation Graph + auto opGraph = cudnn_frontend::OperationGraphBuilder() + .setHandle(handle_) + .setOperationGraph(all_ops.size(), all_ops.data()) + .build(); + + cudnn_frontend::EngineConfigList filtered_configs; + auto statuses = cudnn_frontend::get_heuristics_list<1>( + {"heuristics_instant"}, opGraph, + allowAllConfig, filtered_configs, true); + + if (filtered_configs.size() == 0) { + cudnn_frontend::set_error_and_throw_exception( + nullptr, + CUDNN_STATUS_NOT_SUPPORTED, + "run_mha_bprop: No config returned by the heuristics"); + } + + auto plan = cudnn_frontend::ExecutionPlanBuilder() + .setHandle(handle_) + .setEngineConfig(filtered_configs[0], opGraph.getTag()) + .build(); + cache.insert({descriptor, plan}); + return plan; + }; + + auto plan = get_plan(fa_bprop_cache, descriptor); + size_t wkspace_size = static_cast(plan.getWorkspaceSize()); + + // Exit to request upper level API to allocate memory if needed + if (workspace_ptr == nullptr) { + *workspace_size = wkspace_size + ((b + 1) * 2 + b) * sizeof(int32_t); + return; + } + + int32_t* qkv_ragged_offset = reinterpret_cast( + reinterpret_cast(workspace_ptr) + wkspace_size); + int32_t* o_ragged_offset = reinterpret_cast( + reinterpret_cast(workspace_ptr) + + wkspace_size + (b + 1) * sizeof(int32_t)); + int32_t* actual_seqlens_q = reinterpret_cast( + reinterpret_cast(workspace_ptr) + + wkspace_size + (b + 1) * 2 * sizeof(int32_t)); + // FP8 currently only supports self-attention, so doesn't use devPtrcuSeqlensKV + dim3 blockDims(128); + dim3 gridDims((b + blockDims.x)/blockDims.x); + cu_seqlens_to_offsets<<>>( + b, h, d, reinterpret_cast(devPtrcuSeqlensQ), + actual_seqlens_q, qkv_ragged_offset, o_ragged_offset); + void* devPtrQKVRaggedOffset = reinterpret_cast(qkv_ragged_offset); + void* devPtrORaggedOffset = reinterpret_cast(o_ragged_offset); + void* devPtrMNKOverride = reinterpret_cast(actual_seqlens_q); + + std::set> data_ptrs; + float dropoutScale = 1.0f/(1.0f - dropoutProbability); + float dropoutScale_dOVt_OdO = 1.0f - dropoutProbability; + data_ptrs.emplace(std::pair(tensor_name_to_uid["Q"], devPtrQ)); + data_ptrs.emplace(std::pair(tensor_name_to_uid["K"], devPtrK)); + data_ptrs.emplace(std::pair( + tensor_name_to_uid["K_TRANSPOSE"], devPtrK)); + data_ptrs.emplace(std::pair(tensor_name_to_uid["V"], devPtrV)); + data_ptrs.emplace(std::pair( + tensor_name_to_uid["V_TRANSPOSE"], devPtrV)); + data_ptrs.emplace(std::pair(tensor_name_to_uid["dQ"], devPtrdQ)); + data_ptrs.emplace(std::pair(tensor_name_to_uid["dK"], devPtrdK)); + data_ptrs.emplace(std::pair(tensor_name_to_uid["dV"], devPtrdV)); + data_ptrs.emplace(std::pair(tensor_name_to_uid["dO"], devPtrdO)); + data_ptrs.emplace(std::pair( + tensor_name_to_uid["AttnScale"], &attnScale)); + data_ptrs.emplace(std::pair( + tensor_name_to_uid["DROPOUT_SCALE"], &dropoutScale)); + data_ptrs.emplace(std::pair( + tensor_name_to_uid["DROPOUT_SCALE_dOVt_OdO"], + &dropoutScale_dOVt_OdO)); + data_ptrs.emplace(std::pair( + tensor_name_to_uid["DROPOUT_SEED"], devPtrDropoutSeed)); + data_ptrs.emplace(std::pair( + tensor_name_to_uid["DROPOUT_OFFSET"], devPtrDropoutOffset)); + data_ptrs.emplace(std::pair(tensor_name_to_uid["M"], devPtrM)); + data_ptrs.emplace(std::pair(tensor_name_to_uid["Z_INV"], devPtrZInv)); + data_ptrs.emplace(std::pair(tensor_name_to_uid["O"], devPtrO)); + data_ptrs.emplace(std::pair( + tensor_name_to_uid["descaleQ"], devPtrDescaleQ)); + data_ptrs.emplace(std::pair( + tensor_name_to_uid["descaleK"], devPtrDescaleK)); + data_ptrs.emplace(std::pair( + tensor_name_to_uid["descaleV"], devPtrDescaleV)); + data_ptrs.emplace(std::pair( + tensor_name_to_uid["descaleS"], devPtrDescaleS)); + data_ptrs.emplace(std::pair( + tensor_name_to_uid["descaledS"], devPtrDescaledS)); + data_ptrs.emplace(std::pair( + tensor_name_to_uid["descaleO"], devPtrDescaleO)); + data_ptrs.emplace(std::pair( + tensor_name_to_uid["descaledO"], devPtrDescaledO)); + data_ptrs.emplace(std::pair( + tensor_name_to_uid["scaleS"], devPtrScaleS)); + data_ptrs.emplace(std::pair( + tensor_name_to_uid["scaledS"], devPtrScaledS)); + data_ptrs.emplace(std::pair( + tensor_name_to_uid["scaledQ"], devPtrScaledQ)); + data_ptrs.emplace(std::pair( + tensor_name_to_uid["scaledK"], devPtrScaledK)); + data_ptrs.emplace(std::pair( + tensor_name_to_uid["scaledV"], devPtrScaledV)); + data_ptrs.emplace(std::pair( + tensor_name_to_uid["amaxdS"], devPtrAmaxdS)); + data_ptrs.emplace(std::pair( + tensor_name_to_uid["amaxdQ"], devPtrAmaxdQ)); + data_ptrs.emplace(std::pair( + tensor_name_to_uid["amaxdK"], devPtrAmaxdK)); + data_ptrs.emplace(std::pair( + tensor_name_to_uid["amaxdV"], devPtrAmaxdV)); + data_ptrs.emplace(std::pair( + tensor_name_to_uid["QKV_RAGGED"], devPtrQKVRaggedOffset)); + data_ptrs.emplace(std::pair( + tensor_name_to_uid["O_RAGGED"], devPtrORaggedOffset)); + data_ptrs.emplace(std::pair( + tensor_name_to_uid["MNK_OVERRIDE"], devPtrMNKOverride)); + + auto variantPack = cudnn_frontend::VariantPackBuilder() + .setWorkspacePointer(workspace_ptr) + .setDataPointers(data_ptrs) + .build(); + cudnnStatus_t status = cudnnBackendExecute( + handle_, plan.get_raw_desc(), variantPack.get_raw_desc()); + + cudnn_frontend::throw_if( + [status]() { return (status != CUDNN_STATUS_SUCCESS); }, + "Plan execute error", status); + } catch (cudnn_frontend::cudnnException& e) { + struct cudaDeviceProp prop; + NVTE_CHECK_CUDA(cudaGetDeviceProperties(&prop, 0)); + + // This example is only for GH100 cards (cudnn Version >= 8900) + if (!((prop.major == 9 && prop.minor == 0 && CUDNN_VERSION >= 8900)) + && (e.getCudnnStatus() == CUDNN_STATUS_ARCH_MISMATCH + || e.getCudnnStatus() == CUDNN_STATUS_NOT_SUPPORTED)) { + std::cout << "Example is only supported for GH100 (cuDNN >= 8900) GPUs" << std::endl; + } else { + std::cout << "[ERROR] Exception " << e.what() << std::endl; + } + } +} + +#endif + +} // namespace fused_attn + +#if (CUDNN_VERSION >= 8900) +// fused attention FWD FP8 with packed QKV +void fused_attn_fwd_fp8_qkvpacked( + size_t b, size_t max_seqlen, + size_t h, size_t d, + bool is_training, float attn_scale, + float p_dropout, NVTE_QKV_Layout qkv_layout, + const Tensor *input_QKV, + Tensor *input_output_S, + Tensor *output_O, + NVTETensorPack* Aux_Output_Tensors, + const Tensor *cu_seqlens, + const Tensor *rng_state, + Tensor *workspace, + cudaStream_t stream, + cudnnHandle_t handle) { + using namespace transformer_engine; + // QKV shape is [total_seqs, 3, h, d] + void* devPtrQKV = input_QKV->data.dptr; + void* devPtrQ = reinterpret_cast(devPtrQKV); + void* devPtrK = reinterpret_cast(reinterpret_cast(devPtrQKV) + h * d); + void* devPtrV = reinterpret_cast(reinterpret_cast(devPtrQKV) + 2 * h * d); + void* devPtrDescaleQ = input_QKV->scale_inv.dptr; + void* devPtrDescaleK = input_QKV->scale_inv.dptr; + void* devPtrDescaleV = input_QKV->scale_inv.dptr; + + void* devPtrO = output_O->data.dptr; + void* devPtrAmaxO = output_O->amax.dptr; + void* devPtrScaleO = output_O->scale.dptr; + + void* devPtrM = nullptr; + void* devPtrZInv = nullptr; + if (Aux_Output_Tensors->size == 0) { + if (is_training) { + Aux_Output_Tensors->size = 2; + Tensor *output_M = reinterpret_cast(Aux_Output_Tensors->tensors[0]); + Tensor *output_ZInv = reinterpret_cast(Aux_Output_Tensors->tensors[1]); + output_M->data.dptr = nullptr; + output_M->data.shape = {b, h, max_seqlen, 1}; + output_M->data.dtype = DType::kFloat32; + output_ZInv->data.dptr = nullptr; + output_ZInv->data.shape = {b, h, max_seqlen, 1}; + output_ZInv->data.dtype = DType::kFloat32; + } + } else if (Aux_Output_Tensors->size == 2) { + Tensor *output_M = reinterpret_cast(Aux_Output_Tensors->tensors[0]); + Tensor *output_ZInv = reinterpret_cast(Aux_Output_Tensors->tensors[1]); + devPtrM = output_M->data.dptr; + devPtrZInv = output_ZInv->data.dptr; + } + + void* devPtrAmaxS = input_output_S->amax.dptr; + void* devPtrScaleS = input_output_S->scale.dptr; + void* devPtrDescaleS = input_output_S->scale_inv.dptr; + + void* devPtrcuSeqlens = reinterpret_cast( + reinterpret_cast(cu_seqlens->data.dptr)); + void* devPtrDropoutSeed = reinterpret_cast( + reinterpret_cast(rng_state->data.dptr)); + void* devPtrDropoutOffset = reinterpret_cast( + reinterpret_cast(rng_state->data.dptr) + 1); + + const DType QKV_type = input_QKV->data.dtype; + size_t workspace_size = 0; + + fused_attn::fa_fwd_fp8( + b, max_seqlen, max_seqlen, h, d, + is_training, attn_scale, p_dropout, qkv_layout, + devPtrQ, devPtrK, devPtrV, + devPtrM, devPtrZInv, + devPtrO, + devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, + devPtrDescaleS, devPtrScaleS, devPtrScaleO, + devPtrAmaxO, devPtrAmaxS, + devPtrcuSeqlens, devPtrcuSeqlens, + devPtrDropoutSeed, devPtrDropoutOffset, + get_cudnn_dtype(QKV_type), + workspace->data.dptr, &workspace_size, stream, handle); + + if (workspace_size > 0) { + if (workspace->data.dptr == nullptr) { + workspace->data.shape = { workspace_size }; + workspace->data.dtype = DType::kByte; + return; + } + } else if (workspace_size == 0) { + workspace->data.shape = { 1 }; + workspace->data.dtype = DType::kByte; + return; + } +} +// fused attention BWD FP8 with packed QKV +void fused_attn_bwd_fp8_qkvpacked( + size_t b, size_t max_seqlen, + size_t h, size_t d, + float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, + const Tensor *input_QKV, + const Tensor *input_O, + const Tensor *input_dO, + const Tensor *input_M, + const Tensor *input_ZInv, + const Tensor *input_S, + Tensor *input_output_dP, + const Tensor *output_dQKV, + const Tensor *cu_seqlens, + const Tensor *rng_state, + Tensor *workspace, + cudaStream_t stream, + cudnnHandle_t handle) { + using namespace transformer_engine; + // QKV shape is [total_seqs, 3, h, d] + void* devPtrQKV = input_QKV->data.dptr; + void* devPtrQ = reinterpret_cast(devPtrQKV); + void* devPtrK = reinterpret_cast(reinterpret_cast(devPtrQKV) + h * d); + void* devPtrV = reinterpret_cast(reinterpret_cast(devPtrQKV) + 2 * h * d); + void* devPtrDescaleQ = input_QKV->scale_inv.dptr; + void* devPtrDescaleK = input_QKV->scale_inv.dptr; + void* devPtrDescaleV = input_QKV->scale_inv.dptr; + + void* devPtrO = input_O->data.dptr; + void* devPtrDescaleO = input_O->scale_inv.dptr; + void* devPtrdO = input_dO->data.dptr; + void* devPtrDescaledO = input_dO->scale_inv.dptr; + + void* devPtrM = input_M->data.dptr; + void* devPtrZInv = input_ZInv->data.dptr; + + void* devPtrScaleS = input_S->scale.dptr; + void* devPtrDescaleS = input_S->scale_inv.dptr; + void* devPtrAmaxdS = input_output_dP->amax.dptr; + void* devPtrScaledS = input_output_dP->scale.dptr; + void* devPtrDescaledS = input_output_dP->scale_inv.dptr; + + // dQKV shape is [total_seqs, 3, h, d] + void* devPtrdQKV = output_dQKV->data.dptr; + void* devPtrdQ = reinterpret_cast(devPtrdQKV); + void* devPtrdK = reinterpret_cast(reinterpret_cast(devPtrdQKV) + h * d); + void* devPtrdV = reinterpret_cast(reinterpret_cast(devPtrdQKV) + 2 * h * d); + void* devPtrAmaxdQ = output_dQKV->amax.dptr; + void* devPtrAmaxdK = output_dQKV->amax.dptr; + void* devPtrAmaxdV = output_dQKV->amax.dptr; + void* devPtrScaledQ = output_dQKV->scale.dptr; + void* devPtrScaledK = output_dQKV->scale.dptr; + void* devPtrScaledV = output_dQKV->scale.dptr; + + void* devPtrcuSeqlens = reinterpret_cast( + reinterpret_cast(cu_seqlens->data.dptr)); + void* devPtrDropoutSeed = reinterpret_cast( + reinterpret_cast(rng_state->data.dptr)); + void* devPtrDropoutOffset = reinterpret_cast( + reinterpret_cast(rng_state->data.dptr) + 1); + + const DType QKV_type = input_QKV->data.dtype; + size_t workspace_size = 0; + + fused_attn::fa_bwd_fp8( + b, max_seqlen, max_seqlen, h, d, + attn_scale, p_dropout, qkv_layout, + devPtrQ, devPtrK, devPtrV, + devPtrM, devPtrZInv, + devPtrO, devPtrdO, + devPtrdQ, devPtrdK, devPtrdV, + devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, + devPtrDescaleO, devPtrDescaledO, + devPtrDescaleS, devPtrDescaledS, + devPtrScaleS, devPtrScaledS, + devPtrScaledQ, devPtrScaledK, devPtrScaledV, + devPtrAmaxdS, + devPtrAmaxdQ, devPtrAmaxdK, devPtrAmaxdV, + devPtrcuSeqlens, devPtrcuSeqlens, + devPtrDropoutSeed, devPtrDropoutOffset, + get_cudnn_dtype(QKV_type), + workspace->data.dptr, &workspace_size, stream, handle); + + if (workspace_size > 0) { + if (workspace->data.dptr == nullptr) { + workspace->data.shape = { workspace_size }; + workspace->data.dtype = DType::kByte; + return; + } + } else if (workspace_size == 0) { + workspace->data.shape = { 1 }; + workspace->data.dtype = DType::kByte; + return; + } +} +#endif // end of CUDNN>=8900 +} // namespace transformer_engine diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.h b/transformer_engine/common/fused_attn/fused_attn_fp8.h new file mode 100644 index 0000000000..928e128737 --- /dev/null +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.h @@ -0,0 +1,46 @@ +/************************************************************************* + * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include "transformer_engine/transformer_engine.h" + +namespace transformer_engine { +#if (CUDNN_VERSION >= 8900) +// fused attention FWD FP8 with packed QKV +void fused_attn_fwd_fp8_qkvpacked( + size_t b, size_t max_seqlen, + size_t h, size_t d, + bool is_training, float attn_scale, + float p_dropout, NVTE_QKV_Layout qkv_layout, + const Tensor *input_QKV, + Tensor *input_output_S, + Tensor *output_O, + NVTETensorPack* Aux_Output_Tensors, + const Tensor *cu_seqlens, + const Tensor *rng_state, + Tensor *workspace, + cudaStream_t stream, + cudnnHandle_t handle); + +// fused attention BWD FP8 with packed QKV +void fused_attn_bwd_fp8_qkvpacked( + size_t b, size_t max_seqlen, + size_t h, size_t d, + float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, + const Tensor *input_QKV, + const Tensor *input_O, + const Tensor *input_dO, + const Tensor *input_M, + const Tensor *input_ZInv, + const Tensor *input_S, + Tensor *input_output_dP, + const Tensor *output_dQKV, + const Tensor *cu_seqlens, + const Tensor *rng_state, + Tensor *workspace, + cudaStream_t stream, + cudnnHandle_t handle); +#endif // end of CUDNN>=8900 +} // namespace transformer_engine diff --git a/transformer_engine/common/fused_attn/utils.cu b/transformer_engine/common/fused_attn/utils.cu new file mode 100644 index 0000000000..5b0b03cb3e --- /dev/null +++ b/transformer_engine/common/fused_attn/utils.cu @@ -0,0 +1,167 @@ +/************************************************************************* + * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include "transformer_engine/fused_attn.h" +#include "../common.h" +#include "utils.h" + +namespace transformer_engine { +namespace fused_attn { + +using namespace transformer_engine; + +// get matrix strides based on matrix type +void generateMatrixStrides( + int64_t b, int64_t h, + int64_t s_q, int64_t s_kv, + int64_t d, int64_t* strideA, + NVTE_QKV_Layout layout, NVTE_QKV_Matrix matrix) { + constexpr int batch_dim_idx = 0; + constexpr int head_dim_idx = 1; + constexpr int seqlen_dim_idx = 2; + constexpr int hidden_dim_idx = 3; + + constexpr int seqlen_transpose_dim_idx = 3; + constexpr int hidden_transpose_dim_idx = 2; + + constexpr int seqlen_q_dim_idx = 2; + constexpr int seqlen_kv_dim_idx = 3; + + switch (matrix) { + case NVTE_QKV_Matrix::NVTE_Q_Matrix: + if (layout == NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED) { + strideA[hidden_dim_idx] = 1; + strideA[seqlen_dim_idx] = 3 * h * d; + strideA[head_dim_idx] = d; + strideA[batch_dim_idx] = s_q * 3 * h * d; + } else { + strideA[hidden_dim_idx] = 1; + strideA[seqlen_dim_idx] = h * d; + strideA[head_dim_idx] = d; + strideA[batch_dim_idx] = s_q * h * d; + } + break; + case NVTE_QKV_Matrix::NVTE_K_Matrix: + if (layout == NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED) { + strideA[seqlen_dim_idx] = 3 * h * d; + strideA[hidden_dim_idx] = 1; + strideA[head_dim_idx] = d; + strideA[batch_dim_idx] = s_kv * 3 * h * d; + } else if (layout == NVTE_QKV_Layout::NVTE_KV_INTERLEAVED) { + strideA[seqlen_transpose_dim_idx] = 2 * h * d; + strideA[hidden_transpose_dim_idx] = 1; + strideA[head_dim_idx] = d; + strideA[batch_dim_idx] = s_kv * 2 * h * d; + } else { + strideA[seqlen_transpose_dim_idx] = h * d; + strideA[hidden_transpose_dim_idx] = 1; + strideA[head_dim_idx] = d; + strideA[batch_dim_idx] = s_kv * h * d; + } + break; + case NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose: + if (layout == NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED) { + strideA[seqlen_transpose_dim_idx] = 3 * h * d; + strideA[hidden_transpose_dim_idx] = 1; + strideA[head_dim_idx] = d; + strideA[batch_dim_idx] = s_kv * 3 * h * d; + } else if (layout == NVTE_QKV_Layout::NVTE_KV_INTERLEAVED) { + strideA[seqlen_transpose_dim_idx] = 2 * h * d; + strideA[hidden_transpose_dim_idx] = 1; + strideA[head_dim_idx] = d; + strideA[batch_dim_idx] = s_kv * 2 * h * d; + } else { + strideA[seqlen_transpose_dim_idx] = h * d; + strideA[hidden_transpose_dim_idx] = 1; + strideA[head_dim_idx] = d; + strideA[batch_dim_idx] = s_kv * h * d; + } + break; + case NVTE_QKV_Matrix::NVTE_V_Matrix: + if (layout == NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED) { + strideA[hidden_dim_idx] = 1; + strideA[seqlen_dim_idx] = 3 * h * d; + strideA[head_dim_idx] = d; + strideA[batch_dim_idx] = s_kv * 3 * h * d; + } else if (layout == NVTE_QKV_Layout::NVTE_KV_INTERLEAVED) { + strideA[hidden_dim_idx] = 1; + strideA[seqlen_dim_idx] = 2* h * d; + strideA[head_dim_idx] = d; + strideA[batch_dim_idx] = s_kv * 2 * h * d; + } else { + strideA[hidden_dim_idx] = 1; + strideA[seqlen_dim_idx] = h * d; + strideA[head_dim_idx] = d; + strideA[batch_dim_idx] = s_kv * h * d; + } + break; + case NVTE_QKV_Matrix::NVTE_V_Matrix_Transpose: + if (layout == NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED) { + strideA[hidden_transpose_dim_idx] = 1; + strideA[seqlen_transpose_dim_idx] = 3 * h * d; + strideA[head_dim_idx] = d; + strideA[batch_dim_idx] = s_kv * 3 * h * d; + } else if (layout == NVTE_QKV_Layout::NVTE_KV_INTERLEAVED) { + strideA[hidden_transpose_dim_idx] = 1; + strideA[seqlen_transpose_dim_idx] = 2* h * d; + strideA[head_dim_idx] = d; + strideA[batch_dim_idx] = s_kv * 2 * h * d; + } else { + strideA[hidden_transpose_dim_idx] = 1; + strideA[seqlen_transpose_dim_idx] = h * d; + strideA[head_dim_idx] = d; + strideA[batch_dim_idx] = s_kv * h * d; + } + break; + case NVTE_QKV_Matrix::NVTE_S_Matrix: + strideA[seqlen_kv_dim_idx] = 1; + strideA[seqlen_q_dim_idx] = s_kv; + strideA[head_dim_idx] = s_q * s_kv; + strideA[batch_dim_idx] = h * s_q * s_kv; + break; + case NVTE_QKV_Matrix::NVTE_O_Matrix: + strideA[seqlen_kv_dim_idx] = 1; + strideA[seqlen_q_dim_idx] = h * d; + strideA[head_dim_idx] = d; + strideA[batch_dim_idx] = s_q * h * d; + break; + } +} + +// convert cu_seqlens_q to qkv/o_ragged_offset and actual_seqlens_q +__global__ void cu_seqlens_to_offsets(size_t b, size_t h, size_t d, + int32_t *cu_seqlens_q, int32_t *actual_seqlens_q, + int32_t *qkv_ragged_offset, int32_t *o_ragged_offset) { + size_t tid = blockIdx.x * blockDim.x + threadIdx.x; + if (tid < b) { + actual_seqlens_q[tid] = cu_seqlens_q[tid + 1] - cu_seqlens_q[tid]; + } + if (tid < b + 1) { + qkv_ragged_offset[tid] = cu_seqlens_q[tid] * 3 * h * d; + o_ragged_offset[tid] = cu_seqlens_q[tid] * h * d; + } +} +} // namespace fused_attn + +// get cuDNN data type +cudnnDataType_t get_cudnn_dtype(const transformer_engine::DType t) { + using namespace transformer_engine; + switch (t) { + case DType::kFloat16: + return CUDNN_DATA_HALF; + case DType::kFloat32: + return CUDNN_DATA_FLOAT; + case DType::kBFloat16: + return CUDNN_DATA_BFLOAT16; + case DType::kFloat8E4M3: + return CUDNN_DATA_FP8_E4M3; + case DType::kFloat8E5M2: + return CUDNN_DATA_FP8_E5M2; + default: + NVTE_ERROR("Invalid cuDNN data type. \n"); + } +} +} // namespace transformer_engine diff --git a/transformer_engine/common/fused_attn/utils.h b/transformer_engine/common/fused_attn/utils.h new file mode 100644 index 0000000000..371a19990e --- /dev/null +++ b/transformer_engine/common/fused_attn/utils.h @@ -0,0 +1,90 @@ +/************************************************************************* + * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#ifndef TRANSFORMER_ENGINE_FUSED_ATTN_UTILS_H_ +#define TRANSFORMER_ENGINE_FUSED_ATTN_UTILS_H_ + +#include "transformer_engine/transformer_engine.h" +#include + +namespace transformer_engine { +namespace fused_attn { + +using namespace transformer_engine; + +enum NVTE_QKV_Matrix { + NVTE_Q_Matrix = 0, // queries + NVTE_K_Matrix = 1, // keys + NVTE_K_Matrix_Transpose = 2, // keys transposed + NVTE_V_Matrix = 3, // values + NVTE_V_Matrix_Transpose = 4, // value matrix transposed + NVTE_S_Matrix = 5, // output of GEMM1 + NVTE_O_Matrix = 6, // final output +}; + +void generateMatrixStrides( + int64_t b, int64_t h, + int64_t s_q, int64_t s_kv, + int64_t d, int64_t* strideA, + NVTE_QKV_Layout layout, NVTE_QKV_Matrix matrix); + +struct FADescriptor { + std::int64_t b; + std::int64_t h; + std::int64_t s_q; + std::int64_t s_kv; + std::int64_t d; + float attnScale; + bool isTraining; + float dropoutProbability; + NVTE_QKV_Layout layout; + cudnnDataType_t tensor_type; + + bool operator<(const FADescriptor &rhs) const { + return std::tie(b, h, s_q, s_kv, d, + attnScale, isTraining, dropoutProbability, + layout, tensor_type) < std::tie( + rhs.b, rhs.h, rhs.s_q, rhs.s_kv, rhs.d, + rhs.attnScale, rhs.isTraining, + rhs.dropoutProbability, rhs.layout, rhs.tensor_type); + } +}; + +__global__ void cu_seqlens_to_offsets(size_t b, size_t h, size_t d, + int32_t *cu_seqlens_q, int32_t *actual_seqlens_q, + int32_t *qkv_ragged_offset, int32_t *o_ragged_offset); + +} // namespace fused_attn + +cudnnDataType_t get_cudnn_dtype(const transformer_engine::DType t); + +class cudnnExecutionPlanManager { + public: + static cudnnExecutionPlanManager &Instance() { + static thread_local cudnnExecutionPlanManager instance; + return instance; + } + + cudnnHandle_t GetCudnnHandle() { + static thread_local std::once_flag flag; + std::call_once(flag, [&] { cudnnCreate(&handle_); }); + return handle_; + } + + ~cudnnExecutionPlanManager() { + static thread_local std::once_flag flag; + std::call_once(flag, [&] { + if (handle_ != nullptr) { + cudnnDestroy(handle_); + }}); + } + + private: + cudnnHandle_t handle_ = nullptr; +}; +} // namespace transformer_engine + +#endif diff --git a/transformer_engine/common/include/transformer_engine/fused_attn.h b/transformer_engine/common/include/transformer_engine/fused_attn.h new file mode 100644 index 0000000000..bb9262de18 --- /dev/null +++ b/transformer_engine/common/include/transformer_engine/fused_attn.h @@ -0,0 +1,262 @@ +/************************************************************************* + * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#ifndef TRANSFORMER_ENGINE_FUSED_ATTN_FP8_H_ +#define TRANSFORMER_ENGINE_FUSED_ATTN_FP8_H_ + +#include "transformer_engine.h" + +#ifdef __cplusplus +extern "C" { +#endif + +enum NVTE_QKV_Layout { +/*!< separate Q, K, V tensors: + Q: [total_seqs_q, num_heads, head_dim] + | Q Q Q ... Q + | \___________ _____________/ + total_seqs_q <| \/ + | num_heads * head_dim + K: [total_seqs_kv, num_heads, head_dim] + | K K K ... K + | \___________ _____________/ + total_seqs_kv <| \/ + | num_heads * head_dim + V: [total_seqs_kv, num_heads, head_dim] + | V V V ... V + | \___________ _____________/ + total_seqs_kv <| \/ + | num_heads * head_dim + */ + NVTE_NOT_INTERLEAVED = 0, + +/*!< packed QKV tensor: + QKV: [total_seqs, 3, num_heads, head_dim] + | Q Q Q ... Q K K K ... K V V V ... V + | \___________ _____________/ + total_seqs <| \/ + | num_heads * head_dim + */ + NVTE_QKV_INTERLEAVED = 1, + +/*!< Q and packed KV tensor: + Q: [total_seqs_q, num_heads, head_dim] + | Q Q Q ... Q + | \___________ _____________/ + total_seqs_q <| \/ + | num_heads * head_dim + KV: [total_seqs_kv, 2, num_heads, head_dim] + | K K K ... K V V V ... V + | \___________ _____________/ + total_seqs_kv <| \/ + | num_heads * head_dim + */ + NVTE_KV_INTERLEAVED = 2 +}; + +enum NVTE_Bias_Type { + NVTE_NO_BIAS = 0, /*!< no bias */ + NVTE_PRE_SCALE_BIAS = 1, /*!< bias before scale */ + NVTE_POST_SCALE_BIAS = 2 /*!< bias after scale */ +}; + +enum NVTE_Mask_Type { + NVTE_PADDING_MASK = 0, /*!< padding attention mask */ + NVTE_CAUSAL_MASK = 1, /*!< causal attention mask */ + NVTE_NO_MASK = 2 /*!< no masking */ +}; + +/*! \brief Compute dot product attention with packed QKV input. + * + * Computes: + * - P = Q * K.T + Bias + * - S = ScaleMaskSoftmax(P) + * - D = Dropout(S) + * - O = D * V.T + * + * Support Matrix: + * | precision | qkv layout | bias | mask | sequence length | head_dim | + * | FP8 | QKV_INTERLEAVED | NO_BIAS | PADDING | <= 512 | 64 | + * + * + * \param[in] QKV The QKV tensor in packed format, + * [total_seqs, 3, num_heads, head_dim]. + * \param[in] Bias The Bias tensor. + * \param[in,out] S The S tensor. + * \param[out] O The output O tensor. + * \param[out] Aux_Output_Tensors Auxiliary output tensors when training, e.g. M, ZInv. + * \param[in] cu_seqlens Accumulative sequence lengths, [batch_size + 1]. + * \param[in] rng_state Seed and offset of CUDA random number generator. + * \param[in] max_seqlen Max sequence length used for computing, + * it may be >= max(cu_seqlens). + * \param[in] is_training Whether this is in training mode or inference. + * \param[in] attn_scale Scaling factor for Q * K.T. + * \param[in] dropout Dropout probability. + * \param[in] qkv_layout QKV tensor's layout. + * \param[in] bias_type Bias type. + * \param[in] attn_mask_type Attention mask type. + * \param[in] workspace Workspace tensor. + * \param[in] stream CUDA stream used for this operation. + */ +void nvte_fused_attn_fwd_qkvpacked( + const NVTETensor QKV, + const NVTETensor Bias, + NVTETensor S, + NVTETensor O, + NVTETensorPack* Aux_Output_Tensors, + const NVTETensor cu_seqlens, + const NVTETensor rng_state, + size_t max_seqlen, + bool is_training, float attn_scale, float dropout, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, + NVTETensor workspace, + cudaStream_t stream); + +/*! \brief Compute the backward of the dot product attention with packed QKV input. + * + * Support Matrix: + * | precision | qkv layout | bias | mask | sequence length | head_dim | + * | FP8 | QKV_INTERLEAVED | NO_BIAS | PADDING | <= 512 | 64 | + * + * + * \param[in] QKV The QKV tensor in packed format, + * [total_seqs, 3, num_heads, head_dim]. + * \param[in] dBias The gradient of the Bias tensor. + * \param[in] O The O tensor from forward. + * \param[in] dO The gradient of the O tensor. + * \param[in] S The S tensor. + * \param[in,out] dP The gradient of the P tensor. + * \param[in] Aux_CTX_Tensors Auxiliary tensors from forward when in training mode. + * \param[out] dQKV The gradient of the QKV tensor. + * \param[in] cu_seqlens Accumulative sequence lengths, [batch_size + 1]. + * \param[in] rng_state Seed and offset of CUDA random number generator. + * \param[in] max_seqlen Max sequence length used for computing, + * it may be >= max(cu_seqlens). + * \param[in] attn_scale Scaling factor for Q * K.T. + * \param[in] dropout Dropout probability. + * \param[in] qkv_layout QKV tensor's layout. + * \param[in] bias_type Bias type. + * \param[in] attn_mask_type Attention mask type. + * \param[in] workspace Workspace tensor. + * \param[in] stream CUDA stream used for this operation. + */ +void nvte_fused_attn_bwd_qkvpacked( + const NVTETensor QKV, + const NVTETensor dBias, + const NVTETensor O, + const NVTETensor dO, + const NVTETensor S, + NVTETensor dP, + const NVTETensorPack* Aux_CTX_Tensors, + NVTETensor dQKV, + const NVTETensor cu_seqlens, + size_t max_seqlen, + float attn_scale, float dropout, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, + NVTETensor workspace, + cudaStream_t stream); + +/*! \brief Compute dot product attention with packed KV input. + * + * Computes: + * - P = Q * K.T + Bias + * - S = ScaleMaskSoftmax(P) + * - D = Dropout(S) + * - O = D * V.T + * + * \param[in] Q The Q tensor, [total_seqs_q, num_heads, head_dim]. + * \param[in] KV The KV tensor, [total_seqs_kv, 2, num_heads, head_dim]. + * \param[in] Bias The Bias tensor. + * \param[in,out] S The S tensor. + * \param[out] O The output O tensor. + * \param[out] Aux_Output_Tensors Auxiliary output tensors when training, e.g. M, ZInv. + * \param[in] cu_seqlens_q Accumulative sequence lengths for Q, [batch_size + 1]. + * \param[in] cu_seqlens_kv Accumulative sequence lengths for KV, [batch_size + 1]. + * \param[in] rng_state Seed and offset of CUDA random number generator. + * \param[in] max_seqlen_q Max sequence length used for computing for Q. + * it may be >= max(cu_seqlens_q). + * \param[in] max_seqlen_kv Max sequence length used for computing for KV. + * it may be >= max(cu_seqlens_kv). + * \param[in] is_training Whether this is in training mode or inference. + * \param[in] attn_scale Scaling factor for Q * K.T. + * \param[in] dropout Dropout probability. + * \param[in] qkv_layout QKV tensor's layout. + * \param[in] bias_type Bias type. + * \param[in] attn_mask_type Attention mask type. + * \param[in] workspace Workspace tensor. + * \param[in] stream CUDA stream used for this operation. + */ +void nvte_fused_attn_fwd_kvpacked( + const NVTETensor Q, + const NVTETensor KV, + const NVTETensor Bias, + NVTETensor S, + NVTETensor O, + NVTETensorPack* Aux_Output_Tensors, + const NVTETensor cu_seqlens_q, + const NVTETensor cu_seqlens_kv, + const NVTETensor rng_state, + size_t max_seqlen_q, size_t max_seqlen_kv, + bool is_training, float attn_scale, float dropout, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, + NVTETensor workspace, + cudaStream_t stream); + +/*! \brief Compute the backward of the dot product attention with packed KV input. + * + * \param[in] Q The Q tensor, [total_seqs_q, num_heads, head_dim]. + * \param[in] KV The KV tensor, [total_seqs_kv, 2, num_heads, head_dim]. + * \param[in] dBias The gradient of the Bias tensor. + * \param[in] O The O tensor from forward. + * \param[in] dO The gradient of the O tensor. + * \param[in] S The S tensor. + * \param[in,out] dP The gradient of the P tensor. + * \param[in] Aux_CTX_Tensors Auxiliary tensors from forward when in training mode. + * \param[out] dQ The gradient of the Q tensor. + * \param[out] dKV The gradient of the KV tensor. + * \param[in] cu_seqlens_q Accumulative sequence lengths for Q, [batch_size + 1]. + * \param[in] cu_seqlens_kv Accumulative sequence lengths for KV, [batch_size + 1]. + * \param[in] rng_state Seed and offset of CUDA random number generator. + * \param[in] max_seqlen_q Max sequence length used for computing for Q. + * it may be >= max(cu_seqlens_q). + * \param[in] max_seqlen_kv Max sequence length used for computing for KV. + * it may be >= max(cu_seqlens_kv). + * \param[in] attn_scale Scaling factor for Q * K.T. + * \param[in] dropout Dropout probability. + * \param[in] qkv_layout QKV tensor's layout. + * \param[in] bias_type Bias type. + * \param[in] attn_mask_type Attention mask type. + * \param[in] workspace Workspace tensor. + * \param[in] stream CUDA stream used for this operation. + */ +void nvte_fused_attn_bwd_kvpacked( + const NVTETensor Q, + const NVTETensor KV, + const NVTETensor dBias, + const NVTETensor O, + const NVTETensor dO, + const NVTETensor S, + NVTETensor dP, + const NVTETensorPack* Aux_CTX_Tensors, + NVTETensor dQ, + NVTETensor dKV, + const NVTETensor cu_seqlens_q, + const NVTETensor cu_seqlens_kv, + size_t max_seqlen_q, size_t max_seqlen_kv, + float attn_scale, float dropout, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, + NVTETensor workspace, + cudaStream_t stream); + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif diff --git a/transformer_engine/common/include/transformer_engine/logging.h b/transformer_engine/common/include/transformer_engine/logging.h index 36fd614f59..d488274579 100644 --- a/transformer_engine/common/include/transformer_engine/logging.h +++ b/transformer_engine/common/include/transformer_engine/logging.h @@ -9,6 +9,7 @@ #include #include +#include #include #include @@ -39,10 +40,18 @@ inline void check_cublas_(cublasStatus_t status) { } } +inline void check_cudnn_(cudnnStatus_t status) { + if ( status != CUDNN_STATUS_SUCCESS ) { + NVTE_ERROR("CUDNN Error: " + std::string(cudnnGetErrorString(status))); + } +} + } // namespace #define NVTE_CHECK_CUDA(ans) { check_cuda_(ans); } #define NVTE_CHECK_CUBLAS(ans) { check_cublas_(ans); } +#define NVTE_CHECK_CUDNN(ans) { check_cudnn_(ans); } + #endif // TRANSFORMER_ENGINE_LOGGING_H_ diff --git a/transformer_engine/common/include/transformer_engine/transformer_engine.h b/transformer_engine/common/include/transformer_engine/transformer_engine.h index 0f17a4926a..72383c36bc 100644 --- a/transformer_engine/common/include/transformer_engine/transformer_engine.h +++ b/transformer_engine/common/include/transformer_engine/transformer_engine.h @@ -24,11 +24,12 @@ extern "C" { enum NVTEDType { kNVTEByte = 0, /*!< Byte */ kNVTEInt32 = 1, /*!< 32-bit integer */ - kNVTEFloat32 = 2, /*!< 32-bit float */ - kNVTEFloat16 = 3, /*!< 16-bit float (E5M10) */ - kNVTEBFloat16 = 4, /*!< 16-bit bfloat (E8M7) */ - kNVTEFloat8E4M3 = 5, /*!< 8-bit float (E4M3) */ - kNVTEFloat8E5M2 = 6, /*!< 8-bit float (E5M2) */ + kNVTEInt64 = 2, /*!< 32-bit integer */ + kNVTEFloat32 = 3, /*!< 32-bit float */ + kNVTEFloat16 = 4, /*!< 16-bit float (E5M10) */ + kNVTEBFloat16 = 5, /*!< 16-bit bfloat (E8M7) */ + kNVTEFloat8E4M3 = 6, /*!< 8-bit float (E4M3) */ + kNVTEFloat8E5M2 = 7, /*!< 8-bit float (E5M2) */ kNVTENumTypes /*!< Number of supported types */ }; @@ -129,6 +130,19 @@ float *nvte_tensor_scale(const NVTETensor tensor); */ float *nvte_tensor_scale_inv(const NVTETensor tensor); +struct NVTETensorPack { + static const int MAX_SIZE = 10; /*!< we expect <10 matrices in auxiliary outputs */ + NVTETensor tensors[MAX_SIZE]; /*!< wrappers to tensors, do not hold memory */ + size_t size = 0; /*!< actual size of the tensor pack, 0 <= size <= MAX_SIZE */ +}; + +/*! \brief Create NVTETensors in NVTETensorPack. + */ +void nvte_tensor_pack_create(NVTETensorPack* pack); + +/*! \brief Destroy NVTETensors in NVTETensorPack. + */ +void nvte_tensor_pack_destroy(NVTETensorPack* pack); #ifdef __cplusplus } // extern "C" @@ -146,11 +160,12 @@ namespace transformer_engine { enum class DType { kByte = 0, kInt32 = 1, - kFloat32 = 2, - kFloat16 = 3, - kBFloat16 = 4, - kFloat8E4M3 = 5, - kFloat8E5M2 = 6, + kInt64 = 2, + kFloat32 = 3, + kFloat16 = 4, + kBFloat16 = 5, + kFloat8E4M3 = 6, + kFloat8E5M2 = 7, kNumTypes }; diff --git a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp index 679d1e93c4..708712ff9a 100644 --- a/transformer_engine/common/transformer_engine.cpp +++ b/transformer_engine/common/transformer_engine.cpp @@ -133,3 +133,16 @@ float *nvte_tensor_scale_inv(const NVTETensor tensor) { "Tensor's inverse of scale must have Float32 type!"); return reinterpret_cast(t.scale_inv.dptr); } + +void nvte_tensor_pack_create(NVTETensorPack* pack) { + for (int i = 0; i < pack->MAX_SIZE; i++) { + pack->tensors[i] = reinterpret_cast(new transformer_engine::Tensor); + } +} + +void nvte_tensor_pack_destroy(NVTETensorPack* pack) { + for (int i = 0; i < pack->MAX_SIZE; i++) { + auto *t = reinterpret_cast(pack->tensors[i]); + delete t; + } +} diff --git a/transformer_engine/pytorch/constants.py b/transformer_engine/pytorch/constants.py index 271c70fcab..cc8b063245 100644 --- a/transformer_engine/pytorch/constants.py +++ b/transformer_engine/pytorch/constants.py @@ -14,7 +14,7 @@ with enum in transformer_engine.h """ TE_DType = { - torch.int8: tex.DType.kByte, + torch.uint8: tex.DType.kByte, torch.int32: tex.DType.kInt32, torch.float32: tex.DType.kFloat32, torch.half: tex.DType.kFloat16, diff --git a/transformer_engine/pytorch/cpp_extensions.py b/transformer_engine/pytorch/cpp_extensions.py index fae64445f0..1353f1513e 100644 --- a/transformer_engine/pytorch/cpp_extensions.py +++ b/transformer_engine/pytorch/cpp_extensions.py @@ -3,11 +3,735 @@ # See LICENSE for license information. """TE FP8 extensions and GEMMs""" -from typing import Optional, Tuple, Union +import math +from typing import Optional, Tuple, List, Union import torch import transformer_engine_extensions as tex from .constants import TE_DType +TORCH_DType = { + tex.DType.kFloat8E4M3: torch.uint8, + tex.DType.kFloat8E5M2: torch.uint8, + tex.DType.kFloat16: torch.half, + tex.DType.kBFloat16: torch.bfloat16, + tex.DType.kFloat32: torch.float32, + tex.DType.kInt32: torch.int32, +} + +def check_tensor(x: torch.Tensor): + """Check tensor properties.""" + assert (x.is_cuda and x.is_contiguous() + ), "Tensor should be a GPU tensor and contiguous." + +def check_qkv(qkv: torch.Tensor, dtype: torch.dtype): + """Check tensor properties.""" + check_tensor(qkv) + assert (qkv.dtype is dtype + and qkv.dim() == 4 + and qkv.shape[1] == 3 + ), """QKV should be in [total_seqs, 3, num_heads, head_dim] shape + and {dtype} dtype.""" + +def check_q(q: torch.Tensor, dtype: torch.dtype): + """Check tensor properties.""" + check_tensor(q) + assert (q.dtype is dtype + and q.dim() == 3 + ), """Q should be in [total_seqs, num_heads, head_dim] shape + and {dtype} dtype.""" + +def check_kv(kv: torch.Tensor, dtype: torch.dtype): + """Check tensor properties.""" + check_tensor(kv) + assert (kv.dtype is dtype + and kv.dim() == 4 + and kv.shape[1] == 2 + ), """KV should be in [total_seqs, 2, num_heads, head_dim] shape + and {dtype} dtype.""" + +def check_o(o: torch.Tensor, dtype: torch.dtype): + """Check tensor properties.""" + check_tensor(o) + assert (o.dtype is dtype + and o.dim() == 3 + ), """O and dO should be in [total_seqs, num_heads, head_dim] shape + and {dtype} dtype.""" + +def check_stats(stats: torch.Tensor, b: int, h: int, s: int): + """Check tensor properties.""" + check_tensor(stats) + assert (stats.dtype is torch.float32 + and stats.dim() == 4 + and stats.shape == torch.Size([b, h, s, 1]) + ), """M and ZInv should be in [batch_size, num_heads, max_seqlen_q, 1] + shape and float32 dtype.""" + +def check_cu_seqlens(cu_seqlens: torch.Tensor): + """Check tensor properties.""" + check_tensor(cu_seqlens) + assert (cu_seqlens.dtype is torch.int32 + and cu_seqlens.dim() == 1 + ), """cu_seqlens should be in [batch_size +1] shape and int32 dtype.""" + +def check_scalar(scalar: torch.Tensor): + """Check tensor properties.""" + check_tensor(scalar) + assert (scalar.dtype is torch.float32 + and scalar.dim() <= 1 + and scalar.numel() == 1 + ), "amax/scale/descale tensors should be scalars in float32 dtype." + +def check_rng_state(rng_state: torch.Tensor): + """Check tensor properties.""" + check_tensor(rng_state) + assert (rng_state.dtype is torch.int64 + and rng_state.numel() == 2 + ), "rng_state should be [seed, offset] and in int64 dtype." + +def fused_attn_fwd_qkvpacked( + is_training: bool, + max_seqlen: int, + cu_seqlens: torch.Tensor, + qkv: torch.Tensor, + qkv_dtype: tex.DType, + bias: torch.Tensor = None, + d_scale_qkv: torch.Tensor = None, + q_scale_s: torch.Tensor = None, + q_scale_o: torch.Tensor = None, + amax_s: torch.Tensor = None, + amax_o: torch.Tensor = None, + attn_scale: float = None, + dropout: float = 0.0, + set_zero: bool = True, + qkv_layout: str = "qkv_interleaved", + bias_type: str = "no_bias", + attn_mask_type: str = "padding", + rng_gen: torch.Generator = None, +) -> Tuple[Union[torch.Tensor, None], ...]: + """Fused Attention FWD for packed QKV input. + + Parameters + ---------- + is_training: bool + if True, runs training and produces auxiliary tensors aux_ctx_tensors + for the backward; if False, runs inference and doesn't produce aux_ctx_tensors + max_seqlen: int + max sequence length for QKV, used for padding; may be larger than max(cu_seqlens) + cu_seqlens: torch.Tensor + accumulative sequence lengths for QKV; shape [batch_size + 1] + qkv: torch.Tensor + input tensor QKV; + shape [total_seqs, 3, num_heads, head_dim], where total_seqs = cu_seqlens[-1] + qkv_dtype: tex.DType + data type of QKV; in tex.DType, not torch.dtype + bias: torch.Tensor, default = None + input tensor Bias; + shape [total_seqs, num_heads, head_dim], where total_seqs = cu_seqlens[-1] + d_scale_qkv: torch.Tensor, default = None + input tensor for the dequantization of QKV in FP8 computations + q_scale_s: torch.Tensor, default = None + input tensor for the quantization of S in FP8 computations, S = Softmax(Q * K.T) + q_scale_o: torch.Tensor, default = None + input tensor for the quantization of O in FP8 computations + amax_s: torch.Tensor, default = None + output tensor, amax of S, used by the next iteration in FP8 computations + amax_o: torch.Tensor, default = None + output tensor, amax of O, used by the next iteration in FP8 computations + attn_scale: float, default = None + if not None, use attn_scale as the attention scale for Q*K.T BMM; + if None, use 1.0/sqrt(head_dim) as the default + dropout: float, default = 0.0 + dropout probability, 0.0 means no dropout, 1.0 means no output; + dropout must be 0.0 if is_training is False + set_zero: bool, default = True + if True, initializes the output tensor O to zero using the mha_fill method; + if False, doesn't initialize O after its allocation + qkv_layout: str, default = "qkv_interleaved" + layout of QKV; {"qkv_interleaved", "kv_interleaved", "not_interleaved"} + bias_type: str, default = "no_bias" + type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias"} + attn_mask_type: str, default = "padding" + type of the attention mask; {"padding", "causal", "no_mask"} + rng_gen: torch.Generator, default = None + random number generator; + if None, uses the default CUDA generator from PyTorch; otherwise, uses rng_gen + + Returns + ---------- + o: torch.Tensor + output tensor O, of the attention calculation; same data type as QKV; + shape [total_seqs, num_heads, head_dim], where total_seqs = cu_seqlens[-1] + aux_ctx_tensors: List[torch.Tensor] + auxiliary output tensors used for the backward; + if is_training is True, aux_ctx_tensors = [M, ZInv, rng_state] + if is_training is False, aux_ctx_tensors = [rng_state] + M: torch.Tensor + max(Q*K.T) + shape [batch_size, num_heads, max_seqlen, 1], dtype float32 + ZInv: torch.Tensor + 1/sum(e^(x - max(x))), where x=Q*K.T + shape [batch_size, num_heads, max_seqlen, 1], dtype float32 + rng_state: torch.Tensor + state of the random number generator; + [seed, offset], dtype uint64 + """ + + check_cu_seqlens(cu_seqlens) + b = cu_seqlens.numel() - 1 + qkv_type = TORCH_DType[qkv_dtype] + check_qkv(qkv, qkv_type) + + total_seqs = qkv.size(0) + h = qkv.size(2) + d = qkv.size(3) + + if attn_scale is None: + attn_scale = 1.0 / math.sqrt(d) + + # FP8 fused attention API + if (qkv_type is torch.uint8) and (max_seqlen <= 512) and (d == 64): + assert (qkv_layout == "qkv_interleaved" + and bias_type == "no_bias" + and attn_mask_type == "padding" + ), """The FP8 fused attention API currently only supports qkv_interleaved layout, + no_bias type, and padding attention mask type.""" + assert (d_scale_qkv is not None), "d_scale_qkv is required for the FP8 API." + assert (q_scale_s is not None), "q_scale_s is required for the FP8 API." + assert (q_scale_o is not None), "q_scale_o is required for the FP8 API." + assert (amax_s is not None), "amax_s is required for the FP8 API." + assert (amax_o is not None), "amax_o is required for the FP8 API." + check_scalar(d_scale_qkv) + check_scalar(q_scale_s) + check_scalar(q_scale_o) + check_scalar(amax_s) + check_scalar(amax_o) + + # BF16/FP16 fused attention API from fmha_v2 + elif (qkv_type is torch.bfloat16 or qkv_type is torch.float16) and (max_seqlen > 512): + # add BF/FP16 support for >512 sequence length + assert False, "The BF16/FP16 support for >512 sequence length is coming!" + + # BF16/FP16 fused attention API from fmha_v1 apex + elif (qkv_type is torch.bfloat16 or qkv_type is torch.float16) and (max_seqlen <= 512): + # add BF/FP16 support for <=512 sequence length + assert False, "The BF16/FP16 support for <=512 sequence length is coming!" + + else: + assert False, "No support for this dtype and max_seqlen combination." + + # execute kernel + output_tensors = tex.fused_attn_fwd_qkvpacked( + b, max_seqlen, total_seqs, h, d, + is_training, attn_scale, dropout, set_zero, qkv_layout, bias_type, attn_mask_type, + cu_seqlens, + qkv, + qkv_dtype, + d_scale_qkv, + q_scale_s, + q_scale_o, + amax_s, + amax_o, + bias, + rng_gen, + ) + + return output_tensors[0], output_tensors[1:] + + +def fused_attn_bwd_qkvpacked( + max_seqlen: int, + cu_seqlens: torch.Tensor, + qkv: torch.Tensor, + o: torch.Tensor, + d_o: torch.Tensor, + qkv_dtype: tex.DType, + aux_ctx_tensors: List[torch.Tensor] = None, + d_bias: torch.Tensor = None, + d_scale_qkv: torch.Tensor = None, + d_scale_s: torch.Tensor = None, + d_scale_o: torch.Tensor = None, + d_scale_do: torch.Tensor = None, + q_scale_s: torch.Tensor = None, + q_scale_dp: torch.Tensor = None, + q_scale_dqkv: torch.Tensor = None, + amax_dp: torch.Tensor = None, + amax_dqkv: torch.Tensor = None, + attn_scale: float = None, + dropout: float = 0.0, + set_zero: bool = True, + qkv_layout: str = "qkv_interleaved", + bias_type: str = "no_bias", + attn_mask_type: str = "padding", +) -> Tuple[Union[torch.Tensor, None], ...]: + """Fused Attention BWD for packed QKV input. + + Parameters + ---------- + max_seqlen: int + max sequence length for QKV, used for padding; may be larger than max(cu_seqlens_q) + cu_seqlens: torch.Tensor + accumulative sequence lengths for QKV; shape [batch_size + 1] + qkv: torch.Tensor + input tensor QKV; + shape [total_seqs, 3, num_heads, head_dim], where total_seqs = cu_seqlens[-1] + o: torch.Tensor + input tensor O (output of forward); + shape [total_seqs, num_heads, head_dim], where total_seqs = cu_seqlens[-1] + d_o: torch.Tensor + input tensor dO (gradient of O); + shape [total_seqs, num_heads, head_dim], where total_seqs = cu_seqlens[-1] + qkv_dtype: tex.DType + data type of QKV; in tex.DType, not torch.dtype + aux_ctx_tensors: List[torch.Tensor] + auxiliary output tensors of the forward pass when its is_training is True, + e.g. aux_ctx_tensors = [M, ZInv, rng_state] + d_bias: torch.Tensor, default = None + input tensor Bias; + shape [total_seqs, num_heads, head_dim], where total_seqs = cu_seqlens[-1] + d_scale_qkv: torch.Tensor, default = None + input tensor for the dequantization of QKV in FP8 computations + d_scale_s: torch.Tensor, default = None + input tensor for the dequantization of S in FP8 computations, S = Softmax(Q * K.T) + d_scale_o: torch.Tensor, default = None + input tensor for the dequantization of O in FP8 computations + d_scale_do: torch.Tensor, default = None + input tensor for the dequantization of dO in FP8 computations + q_scale_s: torch.Tensor, default = None + input tensor for the quantization of S in FP8 computations + q_scale_dp: torch.Tensor, default = None + input tensor for the quantization of dP in FP8 computations, P = Q * K.T + q_scale_dqkv: torch.Tensor, default = None + input tensor for the quantization of dQKV in FP8 computations + amax_dp: torch.Tensor, default = None + output tensor, amax of dP, used by the next iteration in FP8 computations + amax_dqkv: torch.Tensor, default = None + output tensor, amax of dQKV, used by the next iteration in FP8 computations + attn_scale: float, default = None + if not None, use attn_scale as the attention scale for Q*K.T BMM; + if None, use 1.0/sqrt(head_dim) as the default + dropout: float, default = 0.0 + dropout probability, 0.0 means no dropout, 1.0 means no output; + dropout must be 0.0 if is_training is False + set_zero: bool, default = True + if True, initializes the output tensor O to zero using the mha_fill method; + if False, doesn't initialize O after its allocation + qkv_layout: str, default = "qkv_interleaved" + layout of QKV; {"qkv_interleaved", "kv_interleaved", "not_interleaved"} + bias_type: str, default = "no_bias" + type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias"} + attn_mask_type: str, default = "padding" + type of the attention mask; {"padding", "causal", "no_mask"} + + Returns + ---------- + d_qkv: torch.Tensor + gradient tensor of QKV; same data type and shape as QKV + """ + + check_cu_seqlens(cu_seqlens) + b = cu_seqlens.numel() - 1 + qkv_type = TORCH_DType[qkv_dtype] + check_qkv(qkv, qkv_type) + check_o(o, qkv_type) + check_o(d_o, qkv_type) + + total_seqs = qkv.size(0) + h = qkv.size(2) + d = qkv.size(3) + + if attn_scale is None: + attn_scale = 1.0 / math.sqrt(d) + + assert (len(aux_ctx_tensors) >= 1 + ), "aux_ctx_tensors must contain rng_state as its last element." + rng_state = aux_ctx_tensors[-1] + check_rng_state(rng_state) + + # FP8 fused attention API + if (qkv_type is torch.uint8) and (max_seqlen <= 512) and d == 64: + assert (qkv_layout == "qkv_interleaved" + and bias_type == "no_bias" + and attn_mask_type == "padding" + ), """The FP8 fused attention API currently only supports qkv_interleaved layout, + no_bias type, and padding attention mask type.""" + assert (d_scale_qkv is not None), "d_scale_qkv is required for the FP8 API." + assert (d_scale_s is not None), "d_scale_s is required for the FP8 API." + assert (d_scale_o is not None), "d_scale_o is required for the FP8 API." + assert (d_scale_do is not None), "d_scale_do is required for the FP8 API." + assert (q_scale_s is not None), "q_scale_s is required for the FP8 API." + assert (q_scale_dp is not None), "q_scale_dp is required for the FP8 API." + assert (q_scale_dqkv is not None), "q_scale_dqkv is required for the FP8 API." + assert (amax_dp is not None), "amax_dp is required for the FP8 API." + assert (amax_dqkv is not None), "amax_dqkv is required for the FP8 API." + assert (len(aux_ctx_tensors) == 3 + ), "aux_ctx_tensors is required to be [M, ZInv, rng_state] for the FP8 API." + check_scalar(d_scale_qkv) + check_scalar(d_scale_s) + check_scalar(d_scale_o) + check_scalar(d_scale_do) + check_scalar(q_scale_s) + check_scalar(q_scale_dp) + check_scalar(q_scale_dqkv) + check_scalar(amax_dp) + check_scalar(amax_dqkv) + m, z_inv = aux_ctx_tensors[:2] + check_stats(m, b, h, max_seqlen) + check_stats(z_inv, b, h, max_seqlen) + + # BF16/FP16 fused attention API from fmha_v2 + elif (qkv_type is torch.bfloat16 or qkv_type is torch.float16) and (max_seqlen > 512): + # add BF/FP16 support for >512 sequence length + assert False, "The BF16/FP16 support for >512 sequence length is coming!" + + # BF16/FP16 fused attention API from fmha_v1 apex + elif (qkv_type is torch.bfloat16 or qkv_type is torch.float16) and (max_seqlen <= 512): + # add BF/FP16 support for <=512 sequence length + assert False, "The BF16/FP16 support for <=512 sequence length is coming!" + + else: + assert False, "No support for this dtype and max_seqlen combination." + + # execute kernel + output_tensors = tex.fused_attn_bwd_qkvpacked( + b, max_seqlen, total_seqs, h, d, + attn_scale, dropout, set_zero, qkv_layout, bias_type, attn_mask_type, + cu_seqlens, + qkv, o, d_o, + qkv_dtype, + aux_ctx_tensors, + d_scale_qkv, d_scale_s, d_scale_o, d_scale_do, + q_scale_s, q_scale_dp, q_scale_dqkv, + amax_dp, amax_dqkv, + d_bias, + ) + + return output_tensors[0] + + +def fused_attn_fwd_kvpacked( + is_training: bool, + max_seqlen_q: int, + max_seqlen_kv: int, + cu_seqlens_q: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + q: torch.Tensor, + kv: torch.Tensor, + qkv_dtype: tex.DType, + bias: torch.Tensor = None, + d_scale_qkv: torch.Tensor = None, + q_scale_s: torch.Tensor = None, + q_scale_o: torch.Tensor = None, + amax_s: torch.Tensor = None, + amax_o: torch.Tensor = None, + attn_scale: float = None, + dropout: float = 0.0, + set_zero: bool = True, + qkv_layout: str = "qkv_interleaved", + bias_type: str = "no_bias", + attn_mask_type: str = "padding", + rng_gen: torch.Generator = None, +) -> Tuple[Union[torch.Tensor, None], ...]: + """Fused Attention FWD for packed KV input. + + Parameters + ---------- + is_training: bool + if True, runs training and produces auxiliary tensors aux_ctx_tensors + for the backward; if False, runs inference and doesn't produce aux_ctx_tensors + max_seqlen_q: int + max sequence length for Q, used for padding; may be larger than max(cu_seqlens_q) + max_seqlen_kv: int + max sequence length for KV, used for padding; may be larger than max(cu_seqlens_kv) + cu_seqlens_q: torch.Tensor + accumulative sequence lengths for Q; shape [batch_size + 1] + cu_seqlens_kv: torch.Tensor + accumulative sequence lengths for KV; shape [batch_size + 1] + q: torch.Tensor + input tensor Q; + shape [total_seqs_q, num_heads, head_dim], where total_seqs_q = cu_seqlens_q[-1] + kv: torch.Tensor + packed input tensor KV; + shape [total_seqs_kv, 2, num_heads, head_dim], + where total_seqs_kv = cu_seqlens_kv[-1] + qkv_dtype: tex.DType + data type of QKV; in tex.DType, not torch.dtype + bias: torch.Tensor, default = None + input tensor Bias; + shape [total_seqs_q, num_heads, head_dim], where total_seqs_q = cu_seqlens_q[-1] + d_scale_qkv: torch.Tensor, default = None + input tensor for the dequantization of QKV in FP8 computations + q_scale_s: torch.Tensor, default = None + input tensor for the quantization of S in FP8 computations, S = Softmax(Q * K.T) + q_scale_o: torch.Tensor, default = None + input tensor for the quantization of O in FP8 computations + amax_s: torch.Tensor, default = None + output tensor, amax of S, used by the next iteration in FP8 computations + amax_o: torch.Tensor, default = None + output tensor, amax of O, used by the next iteration in FP8 computations + attn_scale: float, default = None + if not None, use attn_scale as the attention scale for Q*K.T BMM; + if None, use 1.0/sqrt(head_dim) as the default + dropout: float, default = 0.0 + dropout probability, 0.0 means no dropout, 1.0 means no output; + dropout must be 0.0 if is_training is False + set_zero: bool, default = True + if True, initializes the output tensor O to zero using the mha_fill method; + if False, doesn't initialize O after its allocation + qkv_layout: str, default = "qkv_interleaved" + layout of QKV; {"qkv_interleaved", "kv_interleaved", "not_interleaved"} + bias_type: str, default = "no_bias" + type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias"} + attn_mask_type: str, default = "padding" + type of the attention mask; {"padding", "causal", "no_mask"} + rng_gen: torch.Generator, default = None + random number generator; + if None, uses the default CUDA generator from PyTorch; otherwise, uses rng_gen + + Returns + ---------- + o: torch.Tensor + output tensor O, of the attention calculation; same data type as QKV; + shape [total_seqs, num_heads, head_dim], where total_seqs = cu_seqlens[-1] + aux_ctx_tensors: List[torch.Tensor] + auxiliary output tensors used for the backward; + if is_training is True, aux_ctx_tensors = [M, ZInv, rng_state] + if is_training is False, aux_ctx_tensors = [rng_state] + M: torch.Tensor + max(Q*K.T) + shape [batch_size, num_heads, max_seqlen, 1], dtype float32 + ZInv: torch.Tensor + 1/sum(e^(x - max(x))), where x=Q*K.T + shape [batch_size, num_heads, max_seqlen, 1], dtype float32 + rng_state: torch.Tensor + state of the random number generator; + [seed, offset], dtype uint64 + """ + + check_cu_seqlens(cu_seqlens_q) + check_cu_seqlens(cu_seqlens_kv) + assert (cu_seqlens_q.numel() == cu_seqlens_kv.numel() + ), "cu_seqlens_q and cu_seqlens_kv must have the same length." + b = cu_seqlens_q.numel() - 1 + qkv_type = TORCH_DType[qkv_dtype] + check_q(q, qkv_type) + check_kv(kv, qkv_type) + + assert (q.size(1) == kv.size(2) + and q.size(2) == kv.size(3) + ), "Q and KV must have the same num_heads and head_dim." + total_seqs_q = q.size(0) + total_seqs_kv = kv.size(0) + h = q.size(1) + d = q.size(2) + + if attn_scale is None: + attn_scale = 1.0 / math.sqrt(d) + + # FP8 fused attention API + if (qkv_type is torch.uint8) and (max_seqlen_q <= 512) and (max_seqlen_kv <= 512) \ + and (d == 64): + assert False, "The FP8 fused attention API currently only supports packed QKV input." + + # BF16/FP16 fused attention API from fmha_v2 + elif (qkv_type is torch.bfloat16 or qkv_type is torch.float16) \ + and (max_seqlen_q > 512) and (max_seqlen_kv > 512): + # add BF/FP16 support for >512 sequence length + assert False, "The BF16/FP16 support for >512 sequence length is coming!" + + # BF16/FP16 fused attention API from fmha_v1 apex + elif (qkv_type is torch.bfloat16 or qkv_type is torch.float16) \ + and (max_seqlen_q <= 512) and (max_seqlen_kv <= 512): + # add BF/FP16 support for <=512 sequence length + assert False, "The BF16/FP16 support for <=512 sequence length is coming!" + + else: + assert False, "No support for this dtype and max_seqlen combination." + + # execute kernel + output_tensors = tex.fused_attn_fwd_kvpacked( + b, max_seqlen_q, max_seqlen_kv, total_seqs_q, total_seqs_kv, h, d, + is_training, attn_scale, dropout, set_zero, qkv_layout, bias_type, attn_mask_type, + cu_seqlens_q, cu_seqlens_kv, + q, kv, + qkv_dtype, + d_scale_qkv, + q_scale_s, + q_scale_o, + amax_s, + amax_o, + bias, + rng_gen, + ) + + return output_tensors[0], output_tensors[1:] + + +def fused_attn_bwd_kvpacked( + max_seqlen_q: int, + max_seqlen_kv: int, + cu_seqlens_q: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + q: torch.Tensor, + kv: torch.Tensor, + o: torch.Tensor, + d_o: torch.Tensor, + qkv_dtype: tex.DType, + aux_ctx_tensors: List[torch.Tensor] = None, + d_bias: torch.Tensor = None, + d_scale_qkv: torch.Tensor = None, + d_scale_s: torch.Tensor = None, + d_scale_o: torch.Tensor = None, + d_scale_do: torch.Tensor = None, + q_scale_s: torch.Tensor = None, + q_scale_dp: torch.Tensor = None, + q_scale_dqkv: torch.Tensor = None, + amax_dp: torch.Tensor = None, + amax_dqkv: torch.Tensor = None, + attn_scale: float = None, + dropout: float = 0.0, + set_zero: bool = True, + qkv_layout: str = "qkv_interleaved", + bias_type: str = "no_bias", + attn_mask_type: str = "padding", +) -> Tuple[Union[torch.Tensor, None], ...]: + """Fused Attention BWD for packed KV input. + + Parameters + ---------- + max_seqlen_q: int + max sequence length for Q, used for padding; may be larger than max(cu_seqlens_q) + max_seqlen_kv: int + max sequence length for KV, used for padding; may be larger than max(cu_seqlens_kv) + cu_seqlens_q: torch.Tensor + accumulative sequence lengths for Q; shape [batch_size + 1] + cu_seqlens_kv: torch.Tensor + accumulative sequence lengths for KV; shape [batch_size + 1] + q: torch.Tensor + input tensor Q; + shape [total_seqs_q, num_heads, head_dim], where total_seqs_q = cu_seqlens_q[-1] + kv: torch.Tensor + packed input tensor KV; + shape [total_seqs_kv, 2, num_heads, head_dim], + where total_seqs_kv = cu_seqlens_kv[-1] + o: torch.Tensor + input tensor O (output of forward); + shape [total_seqs_q, num_heads, head_dim], where total_seqs_q = cu_seqlens_q[-1] + d_o: torch.Tensor + input tensor dO (gradient of O); + shape [total_seqs_q, num_heads, head_dim], where total_seqs_q = cu_seqlens_q[-1] + qkv_dtype: tex.DType + data type of QKV; in tex.DType, not torch.dtype + aux_ctx_tensors: List[torch.Tensor] + auxiliary output tensors of the forward pass when its is_training is True, + e.g. aux_ctx_tensors = [M, ZInv, rng_state] + bias: torch.Tensor, default = None + input tensor Bias; + shape [total_seqs_q, num_heads, head_dim], where total_seqs_q = cu_seqlens_q[-1] + d_scale_qkv: torch.Tensor, default = None + input tensor for the dequantization of QKV in FP8 computations + d_scale_s: torch.Tensor, default = None + input tensor for the dequantization of S in FP8 computations, S = Softmax(Q * K.T) + d_scale_o: torch.Tensor, default = None + input tensor for the dequantization of O in FP8 computations + d_scale_do: torch.Tensor, default = None + input tensor for the dequantization of dO in FP8 computations + q_scale_s: torch.Tensor, default = None + input tensor for the quantization of S in FP8 computations + q_scale_dp: torch.Tensor, default = None + input tensor for the quantization of dP in FP8 computations, P = Q * K.T + q_scale_dqkv: torch.Tensor, default = None + input tensor for the quantization of dQKV in FP8 computations + amax_dp: torch.Tensor, default = None + output tensor, amax of dP, used by the next iteration in FP8 computations, + P = Q * K.T + amax_dqkv: torch.Tensor, default = None + output tensor, amax of dQKV, used by the next iteration in FP8 computations + attn_scale: float, default = None + if not None, use attn_scale as the attention scale for Q*K.T BMM; + if None, use 1.0/sqrt(head_dim) as the default + dropout: float, default = 0.0 + dropout probability, 0.0 means no dropout, 1.0 means no output; + dropout must be 0.0 if is_training is False + set_zero: bool, default = True + if True, initializes the output tensor O to zero using the mha_fill method; + if False, doesn't initialize O after its allocation + qkv_layout: str, default = "qkv_interleaved" + layout of QKV; {"qkv_interleaved", "kv_interleaved", "not_interleaved"} + bias_type: str, default = "no_bias" + type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias"} + attn_mask_type: str, default = "padding" + type of the attention mask; {"padding", "causal", "no_mask"} + + Returns + ---------- + d_q: torch.Tensor + gradient tensor of Q; same data type and shape as Q + d_kv: torch.Tensor + gradient tensor of KV; same data type and shape as KV + """ + + check_cu_seqlens(cu_seqlens_q) + check_cu_seqlens(cu_seqlens_kv) + assert (cu_seqlens_q.numel() == cu_seqlens_kv.numel() + ), "cu_seqlens_q and cu_seqlens_kv must have the same length." + b = cu_seqlens_q.numel() - 1 + qkv_type = TORCH_DType[qkv_dtype] + check_q(q, qkv_type) + check_kv(kv, qkv_type) + check_o(o, qkv_type) + check_o(d_o, qkv_type) + + assert (q.size(1) == kv.size(2) + and q.size(2) == kv.size(3) + ), "Q and KV must have the same num_heads and head_dim." + total_seqs_q = q.size(0) + total_seqs_kv = q.size(0) + h = q.size(1) + d = q.size(2) + + if attn_scale is None: + attn_scale = 1.0 / math.sqrt(d) + + assert (len(aux_ctx_tensors) >= 1 + ), "aux_ctx_tensors must contain rng_state as its last element." + rng_state = aux_ctx_tensors[-1] + check_rng_state(rng_state) + + # FP8 fused attention API + if (qkv_type is torch.uint8) and (max_seqlen_q <= 512) and (max_seqlen_kv <= 512) \ + and d == 64: + assert False, "The FP8 fused attention API currently only supports packed QKV input." + + ############### BF16/FP16 fused attention API from fmha_v2 ################ + elif (qkv_type is torch.bfloat16 or qkv_type is torch.float16) \ + and (max_seqlen_q > 512) and (max_seqlen_kv > 512): + # add BF/FP16 support for >512 sequence length + assert False, "The BF16/FP16 support for >512 sequence length is coming!" + + ############### BF16/FP16 fused attention API from fmha_v1 apex ################ + elif (qkv_type is torch.bfloat16 or qkv_type is torch.float16) \ + and (max_seqlen_q <= 512) and (max_seqlen_kv <= 512): + # add BF/FP16 support for <=512 sequence length + assert False, "The BF16/FP16 support for <=512 sequence length is coming!" + + else: + assert False, "No support for this dtype and max_seqlen combination." + + # execute kernel + output_tensors = tex.fused_attn_bwd_kvpacked( + b, max_seqlen_q, max_seqlen_kv, total_seqs_q, total_seqs_kv, h, d, + attn_scale, dropout, set_zero, qkv_layout, bias_type, attn_mask_type, + cu_seqlens_q, cu_seqlens_kv, + q, kv, o, d_o, + qkv_dtype, + aux_ctx_tensors, + d_scale_qkv, d_scale_s, d_scale_o, d_scale_do, + q_scale_s, q_scale_dp, q_scale_dqkv, + amax_dp, amax_dqkv, + d_bias, + ) + + return output_tensors def fp8_gemm( A: torch.Tensor, @@ -233,9 +957,9 @@ def fp8_cast_transpose_fused( return_outputs = False if cast_out is None or transpose_out is None: - cast_out = torch.empty_like(inp, dtype=torch.int8) + cast_out = torch.empty_like(inp, dtype=torch.uint8) transpose_out = torch.empty( - inp.shape[1], inp.shape[0], device="cuda", dtype=torch.int8 + inp.shape[1], inp.shape[0], device="cuda", dtype=torch.uint8 ) return_outputs = True diff --git a/transformer_engine/pytorch/csrc/common.cu b/transformer_engine/pytorch/csrc/common.cu index 2146118382..1d20607940 100644 --- a/transformer_engine/pytorch/csrc/common.cu +++ b/transformer_engine/pytorch/csrc/common.cu @@ -88,6 +88,19 @@ size_t product(const std::vector &shape) { } +at::Tensor allocateSpace(const std::vector& shape, + const transformer_engine::DType type, + bool init_to_zeros) { + std::vector shape_int64(shape.begin(), shape.end()); + c10::IntArrayRef ar_shape(shape_int64); + if (init_to_zeros) { + return at::zeros(ar_shape, at::CUDA(GetATenDType(type))); + } else { + return at::empty(ar_shape, at::CUDA(GetATenDType(type))); + } +} + + at::Tensor allocateSpace(const NVTEShape &shape, const transformer_engine::DType type, bool init_to_zeros) { diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index f6c9898601..1d59fc7c43 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -15,9 +15,15 @@ #include #include #include +#include #include #include #include +#include +#include +#include +#include +#include #include #include #include @@ -101,6 +107,12 @@ inline transformer_engine::DType GetTransformerEngineDType(at::ScalarType t) { return transformer_engine::DType::kBFloat16; case at::kBool: return transformer_engine::DType::kByte; + case torch::kByte: + return transformer_engine::DType::kByte; + case torch::kInt32: + return transformer_engine::DType::kInt32; + case torch::kInt64: + return transformer_engine::DType::kInt64; default: NVTE_ERROR("Invalid type"); } @@ -141,6 +153,9 @@ transformer_engine::TensorWrapper makeTransformerEngineTensor(at::Tensor tensor, size_t product(const std::vector &shape); +at::Tensor allocateSpace(const std::vector& shape, + const transformer_engine::DType type, + bool init_to_zeros); at::Tensor allocateSpace(const NVTEShape &shape, const transformer_engine::DType type, diff --git a/transformer_engine/pytorch/csrc/extensions.cu b/transformer_engine/pytorch/csrc/extensions.cu index 23330efbf0..75d4abd031 100644 --- a/transformer_engine/pytorch/csrc/extensions.cu +++ b/transformer_engine/pytorch/csrc/extensions.cu @@ -9,6 +9,742 @@ #include "comm_gemm_overlap.h" #endif // NVTE_WITH_USERBUFFERS +constexpr int block_size = 512; +constexpr int ctas_per_sm = 4; + +// convert QKV layout to enum +NVTE_QKV_Layout get_nvte_qkv_layout(const std::string qkv_layout) { + if (qkv_layout == "not_interleaved") { + return NVTE_QKV_Layout::NVTE_NOT_INTERLEAVED; + } else if (qkv_layout == "qkv_interleaved") { + return NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED; + } else if (qkv_layout == "kv_interleaved") { + return NVTE_QKV_Layout::NVTE_KV_INTERLEAVED; + } else { + NVTE_ERROR("Invalid QKV layout. \n"); + } +} + +// convert bias type to enum +NVTE_Bias_Type get_nvte_bias_type(const std::string bias_type) { + if (bias_type == "no_bias") { + return NVTE_Bias_Type::NVTE_NO_BIAS; + } else if (bias_type == "pre_scale_bias") { + return NVTE_Bias_Type::NVTE_PRE_SCALE_BIAS; + } else if (bias_type == "post_scale_bias") { + return NVTE_Bias_Type::NVTE_POST_SCALE_BIAS; + } else { + NVTE_ERROR("Invalid bias type. \n"); + } +} + +// convert attn mask type to enum +NVTE_Mask_Type get_nvte_mask_type(const std::string mask_type) { + if (mask_type == "padding") { + return NVTE_Mask_Type::NVTE_PADDING_MASK; + } else if (mask_type == "causal") { + return NVTE_Mask_Type::NVTE_CAUSAL_MASK; + } else if (mask_type == "no_mask") { + return NVTE_Mask_Type::NVTE_NO_MASK; + } else { + NVTE_ERROR("Invalid attention mask type. \n"); + } +} + +// fast zero-fills of tensors +template +__global__ void __launch_bounds__(block_size) mha_fill_kernel(scalar_t* out_tensor, + const int32_t* const start_row, + const size_t num_rows) { + size_t row_stride = gridDim.y * blockDim.x; + size_t row_index = blockIdx.x + static_cast(start_row[0]); + size_t col_index = blockIdx.y * blockDim.x + threadIdx.x; + while (row_index < num_rows) { + out_tensor[row_index*row_stride + col_index] = 0; + row_index += gridDim.x; + } +} + +// fast zero-fills of tensors +void mha_fill(const at::Tensor &self, const at::Tensor &start_index) { + auto max_tokens = self.size(0); + auto self_2d = self.view({max_tokens, -1}); + auto fcd_size = self_2d.size(1); + TORCH_CHECK(self.is_contiguous(), "input not contiguous"); + TORCH_CHECK(fcd_size % block_size == 0, "input size not aligned to block size"); + const int num_mp = at::cuda::getCurrentDeviceProperties()->multiProcessorCount; + uint64_t num_blk_y = (uint64_t)(fcd_size / block_size); + uint64_t num_blk_x = (uint64_t)((num_mp * ctas_per_sm + num_blk_y - 1) / num_blk_y); + dim3 dim_grid(num_blk_x, num_blk_y); + dim3 dim_block(block_size); + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2( + at::ScalarType::Half, at::ScalarType::BFloat16, + self_2d.scalar_type(), "mha_fill", [&]() { + mha_fill_kernel<<>>( + self_2d.data_ptr(), + static_cast(start_index.data_ptr()), + max_tokens); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); +} + +// extract seed and offset from PhiloxCudaState +__global__ void unpack(at::PhiloxCudaState arg, int64_t* rng_state_ptr) { + if (arg.captured_) { + rng_state_ptr[0] = static_cast(*arg.seed_.ptr); + rng_state_ptr[1] = static_cast( + *(arg.offset_.ptr) + static_cast(arg.offset_intragraph_)); + } else { + rng_state_ptr[0] = static_cast(arg.seed_.val); + rng_state_ptr[1] = static_cast(arg.offset_.val); + } +} + +// extract PhiloxCudaState from CUDA random number generator +at::PhiloxCudaState init_philox_state( + at::CUDAGeneratorImpl* gen, + size_t max_seq_len, + size_t threads_per_cta) { + at::PhiloxCudaState philox_args; + size_t elts_per_thread = (max_seq_len * max_seq_len + threads_per_cta - 1)/threads_per_cta; + std::lock_guard lock(gen->mutex_); + philox_args = gen->philox_cuda_state(elts_per_thread); + return philox_args; +} + +// fused attention FWD with packed QKV +std::vector fused_attn_fwd_qkvpacked( + size_t b, size_t max_seqlen, size_t total_seqs, + size_t h, size_t d, + bool is_training, float attn_scale, float p_dropout, bool set_zero, + std::string qkv_layout, std::string bias_type, std::string attn_mask_type, + const at::Tensor cu_seqlens, + const at::Tensor QKV, + const transformer_engine::DType qkv_type, + const c10::optional descale_QKV, + const c10::optional scale_S, + const c10::optional scale_O, + c10::optional amax_S, + c10::optional amax_O, + const c10::optional Bias, + const c10::optional rng_gen) { + using namespace transformer_engine; + + // create output tensor O + auto options = torch::TensorOptions().dtype(GetATenDType(qkv_type)).device(torch::kCUDA); + auto O = torch::empty({static_cast(total_seqs), + static_cast(h), static_cast(d)}, options); + if (set_zero) { + mha_fill(O, cu_seqlens.index({torch::indexing::Slice(-1, torch::indexing::None)})); + } + + // construct NVTE tensors + TensorWrapper te_QKV, te_S, te_O, te_Bias, te_cu_seqlens; + if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) { + // FP8 + if ((!descale_QKV.has_value()) || (!scale_S.has_value()) || (!scale_O.has_value()) + || (!amax_S.has_value()) || (!amax_O.has_value())) { + std::string err_tensors = "descale_QKV, scale_S, scale_O, amax_S and amax_O"; + NVTE_ERROR(err_tensors + std::string("are required for FP8 operation. \n")); + } + te_QKV = makeTransformerEngineTensor(QKV.data_ptr(), {total_seqs, 3, h, d}, + qkv_type, nullptr, nullptr, descale_QKV.value().data_ptr()); + at::Tensor descale_S = torch::empty_like(scale_S.value()); + te_S = makeTransformerEngineTensor(nullptr, {0}, + DType::kFloat32, amax_S.value().data_ptr(), + scale_S.value().data_ptr(), descale_S.data_ptr()); + te_O = makeTransformerEngineTensor(O.data_ptr(), {total_seqs, h, d}, + qkv_type, amax_O.value().data_ptr(), scale_O.value().data_ptr(), nullptr); + } else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) { + // BF16 or FP16 + te_QKV = makeTransformerEngineTensor(QKV.data_ptr(), {total_seqs, 3, h, d}, + qkv_type, nullptr, nullptr, nullptr); + te_S = makeTransformerEngineTensor(nullptr, {0}, + DType::kFloat32, nullptr, nullptr, nullptr); + te_O = makeTransformerEngineTensor(O.data_ptr(), {total_seqs, h, d}, + qkv_type, nullptr, nullptr, nullptr); + } else { + NVTE_ERROR("Fused attention only supports FP8 and BF16/FP16 data types. \n"); + } + if (Bias.has_value()) { + auto bias_shape = Bias.value().sizes().vec(); + std::vector shape{bias_shape.begin(), bias_shape.end()}; + te_Bias = makeTransformerEngineTensor(Bias.value().data_ptr(), shape, + DType::kFloat32, nullptr, nullptr, nullptr); + } + te_cu_seqlens = makeTransformerEngineTensor(cu_seqlens.data_ptr(), {b+1}, + DType::kInt32, nullptr, nullptr, nullptr); + + // convert strings to enums + NVTE_QKV_Layout qkv_layout_enum = get_nvte_qkv_layout(qkv_layout); + NVTE_Bias_Type bias_type_enum = get_nvte_bias_type(bias_type); + NVTE_Mask_Type attn_mask_type_enum = get_nvte_mask_type(attn_mask_type); + + // extract random number generator seed and offset + auto gen = at::get_generator_or_default( + rng_gen, at::cuda::detail::getDefaultCUDAGenerator()); + size_t threads_per_cta = 128; + at::PhiloxCudaState philox_args = init_philox_state(gen, max_seqlen, threads_per_cta); + auto rng_state = torch::empty({2}, options.dtype(torch::kInt64)); + unpack<<<1, 1, 0, at::cuda::getCurrentCUDAStream()>>>( + philox_args, static_cast(rng_state.data_ptr())); + auto te_rng_state = makeTransformerEngineTensor(rng_state); + + // create auxiliary output tensors + // if training, tensors are [M, ZInv] + NVTETensorPack nvte_aux_tensor_pack; + nvte_tensor_pack_create(&nvte_aux_tensor_pack); + + // create workspace + TensorWrapper workspace; + + // populate tensors with appropriate shapes and dtypes + nvte_fused_attn_fwd_qkvpacked( + te_QKV.data(), + te_Bias.data(), + te_S.data(), + te_O.data(), + &nvte_aux_tensor_pack, + te_cu_seqlens.data(), + te_rng_state.data(), + max_seqlen, + is_training, attn_scale, p_dropout, + qkv_layout_enum, bias_type_enum, attn_mask_type_enum, + workspace.data(), + at::cuda::getCurrentCUDAStream()); + + // allocate memory for workspace and auxiliary output tensors + auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); + workspace = makeTransformerEngineTensor( + workspace_data.data_ptr(), + workspace.shape(), workspace.dtype()); + + // output_tensors = [O, nvte_aux_tensor_pack.tensors, rng_state] + std::vector output_tensors; + output_tensors.push_back(O); + // nvte_aux_tensor_pack.size is 0 if inference + for (size_t i = 0; i < nvte_aux_tensor_pack.size; ++i) { + auto tensor = reinterpret_cast(nvte_aux_tensor_pack.tensors[i]); + // allocate memory for nvte_aux_tensor_pack.tensors + auto output_tensor = allocateSpace(tensor->data.shape, tensor->data.dtype, false); + output_tensors.push_back(output_tensor); + tensor->data.dptr = output_tensor.data_ptr(); + } + if (is_training) { + output_tensors.push_back(rng_state); + } + + // execute the kernel + nvte_fused_attn_fwd_qkvpacked( + te_QKV.data(), + te_Bias.data(), + te_S.data(), + te_O.data(), + &nvte_aux_tensor_pack, + te_cu_seqlens.data(), + te_rng_state.data(), + max_seqlen, + is_training, attn_scale, p_dropout, + qkv_layout_enum, bias_type_enum, attn_mask_type_enum, + workspace.data(), + at::cuda::getCurrentCUDAStream()); + + // destroy tensor wrappers, but not allocated memory + nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); + + // if training, [O, M, ZInv, rng_state]; if inference, [O] + return output_tensors; +} + +// fused attention BWD with packed QKV +std::vector fused_attn_bwd_qkvpacked( + size_t b, size_t max_seqlen, size_t total_seqs, + size_t h, size_t d, + float attn_scale, float p_dropout, bool set_zero, + std::string qkv_layout, std::string bias_type, std::string attn_mask_type, + const at::Tensor cu_seqlens, + const at::Tensor QKV, + const at::Tensor O, + const at::Tensor dO, + const transformer_engine::DType qkv_type, + const std::vector Aux_CTX_Tensors, + const c10::optional descale_QKV, + const c10::optional descale_S, + const c10::optional descale_O, + const c10::optional descale_dO, + const c10::optional scale_S, + const c10::optional scale_dP, + const c10::optional scale_dQKV, + c10::optional amax_dP, + c10::optional amax_dQKV, + const c10::optional dBias) { + using namespace transformer_engine; + + // create output tensor dQKV + at::Tensor dQKV = torch::empty_like(QKV); + if (set_zero) { + mha_fill(dQKV, cu_seqlens.index({torch::indexing::Slice(-1, torch::indexing::None)})); + } + + // construct NVTE tensors + TensorWrapper te_QKV, te_O, te_dO, te_S, te_dP, te_dQKV, te_dBias; + if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) { + // FP8 + if ((!descale_QKV.has_value()) || (!descale_S.has_value()) + || (!descale_O.has_value()) || (!descale_dO.has_value()) + || (!scale_S.has_value()) || (!scale_dP.has_value()) + || (!scale_dQKV.has_value()) + || (!amax_dP.has_value()) || (!amax_dQKV.has_value())) { + std::string err_tensors = "descale_QKV, descale_S, descale_O, scale_S, scale_dP, "; + err_tensors = err_tensors + std::string("scale_dQKV, amax_dP and amax_dQKV"); + NVTE_ERROR(err_tensors + std::string("are required for FP8 operation. \n")); + } + te_QKV = makeTransformerEngineTensor(QKV.data_ptr(), {total_seqs, 3, h, d}, + qkv_type, nullptr, nullptr, descale_QKV.value().data_ptr()); + te_O = makeTransformerEngineTensor(O.data_ptr(), {total_seqs, h, d}, + qkv_type, nullptr, nullptr, descale_O.value().data_ptr()); + te_dO = makeTransformerEngineTensor(dO.data_ptr(), {total_seqs, h, d}, + qkv_type, nullptr, nullptr, descale_dO.value().data_ptr()); + te_S = makeTransformerEngineTensor(nullptr, {0}, + DType::kFloat32, + nullptr, scale_S.value().data_ptr(), descale_S.value().data_ptr()); + at::Tensor descale_dP = torch::empty_like(scale_dP.value()); + te_dP = makeTransformerEngineTensor(nullptr, {0}, + DType::kFloat32, amax_dP.value().data_ptr(), scale_dP.value().data_ptr(), + descale_dP.data_ptr()); + te_dQKV = makeTransformerEngineTensor(dQKV.data_ptr(), {total_seqs, 3, h, d}, + qkv_type, + amax_dQKV.value().data_ptr(), scale_dQKV.value().data_ptr(), nullptr); + } else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) { + // BF16 or FP16 + te_QKV = makeTransformerEngineTensor(QKV.data_ptr(), {total_seqs, 3, h, d}, + qkv_type, nullptr, nullptr, nullptr); + te_O = makeTransformerEngineTensor(O.data_ptr(), {total_seqs, h, d}, + qkv_type, nullptr, nullptr, nullptr); + te_dO = makeTransformerEngineTensor(dO.data_ptr(), {total_seqs, h, d}, + qkv_type, nullptr, nullptr, nullptr); + te_S = makeTransformerEngineTensor(nullptr, {0}, + DType::kFloat32, nullptr, nullptr, nullptr); + te_dP = makeTransformerEngineTensor(nullptr, {0}, + DType::kFloat32, nullptr, nullptr, nullptr); + te_dQKV = makeTransformerEngineTensor(dQKV.data_ptr(), {total_seqs, 3, h, d}, + qkv_type, nullptr, nullptr, nullptr); + } else { + NVTE_ERROR("Fused attention only supports FP8 and BF16/FP16 data types. \n"); + } + if (dBias.has_value()) { + auto bias_shape = dBias.value().sizes().vec(); + std::vector shape{bias_shape.begin(), bias_shape.end()}; + te_dBias = makeTransformerEngineTensor( + dBias.value().data_ptr(), shape, DType::kFloat32, + nullptr, nullptr, nullptr); + } + + // convert strings to enums + NVTE_QKV_Layout qkv_layout_enum = get_nvte_qkv_layout(qkv_layout); + NVTE_Bias_Type bias_type_enum = get_nvte_bias_type(bias_type); + NVTE_Mask_Type attn_mask_type_enum = get_nvte_mask_type(attn_mask_type); + + // convert auxiliary tensors from forward into NVTETensors + // aux_ctx_tensors are [M, ZInv, rng_state] + NVTETensorPack nvte_aux_tensor_pack; + nvte_tensor_pack_create(&nvte_aux_tensor_pack); + nvte_aux_tensor_pack.size = Aux_CTX_Tensors.size(); + for (size_t i = 0; i < nvte_aux_tensor_pack.size; ++i) { + auto tensor = reinterpret_cast(nvte_aux_tensor_pack.tensors[i]); + tensor->data.dptr = Aux_CTX_Tensors[i].data_ptr(); + std::vector tmp(Aux_CTX_Tensors[i].sizes().vec()); + tensor->data.shape = std::vector(tmp.begin(), tmp.end()); + tensor->data.dtype = GetTransformerEngineDType(Aux_CTX_Tensors[i].scalar_type()); + } + + // create cu_seqlens tensorwrappers + TensorWrapper te_cu_seqlens; + te_cu_seqlens = makeTransformerEngineTensor(cu_seqlens.data_ptr(), {b+1}, + DType::kInt32, nullptr, nullptr, nullptr); + + // create workspace + TensorWrapper workspace; + + // populate tensors with appropriate shapes and dtypes + nvte_fused_attn_bwd_qkvpacked( + te_QKV.data(), + te_dBias.data(), + te_O.data(), + te_dO.data(), + te_S.data(), + te_dP.data(), + &nvte_aux_tensor_pack, + te_dQKV.data(), + te_cu_seqlens.data(), + max_seqlen, + attn_scale, p_dropout, + qkv_layout_enum, bias_type_enum, attn_mask_type_enum, + workspace.data(), + at::cuda::getCurrentCUDAStream()); + + // allocate memory for workspace + auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); + workspace = makeTransformerEngineTensor( + workspace_data.data_ptr(), + workspace.shape(), workspace.dtype()); + + // execute kernel + nvte_fused_attn_bwd_qkvpacked( + te_QKV.data(), + te_dBias.data(), + te_O.data(), + te_dO.data(), + te_S.data(), + te_dP.data(), + &nvte_aux_tensor_pack, + te_dQKV.data(), + te_cu_seqlens.data(), + max_seqlen, + attn_scale, p_dropout, + qkv_layout_enum, bias_type_enum, attn_mask_type_enum, + workspace.data(), + at::cuda::getCurrentCUDAStream()); + + // destroy tensor wrappers + nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); + + return {dQKV}; +} + +// fused attention FWD with packed KV +std::vector fused_attn_fwd_kvpacked( + size_t b, size_t max_seqlen_q, size_t max_seqlen_kv, + size_t total_seqs_q, size_t total_seqs_kv, + size_t h, size_t d, + bool is_training, float attn_scale, float p_dropout, bool set_zero, + std::string qkv_layout, std::string bias_type, std::string attn_mask_type, + const at::Tensor cu_seqlens_q, + const at::Tensor cu_seqlens_kv, + const at::Tensor Q, + const at::Tensor KV, + const transformer_engine::DType qkv_type, + const c10::optional descale_QKV, + const c10::optional scale_S, + const c10::optional scale_O, + c10::optional amax_S, + c10::optional amax_O, + const c10::optional Bias, + const c10::optional rng_gen) { + using namespace transformer_engine; + + // create output tensor O + auto options = torch::TensorOptions().dtype(GetATenDType(qkv_type)).device(torch::kCUDA); + auto O = torch::empty({static_cast(total_seqs_q), + static_cast(h), static_cast(d)}, options); + if (set_zero) { + mha_fill(O, cu_seqlens_q.index({torch::indexing::Slice(-1, torch::indexing::None)})); + } + + // construct NVTE tensors + TensorWrapper te_Q, te_KV, te_S, te_O, te_Bias, te_cu_seqlens_q, te_cu_seqlens_kv; + if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) { + // FP8 + if ((!descale_QKV.has_value()) || (!scale_S.has_value()) || (!scale_O.has_value()) + || (!amax_S.has_value()) || (!amax_O.has_value())) { + std::string err_tensors = "descale_QKV, scale_S, scale_O, amax_S and amax_O"; + NVTE_ERROR(err_tensors + std::string("are required for FP8 operation. \n")); + } + te_Q = makeTransformerEngineTensor(Q.data_ptr(), {total_seqs_q, h, d}, + qkv_type, nullptr, nullptr, descale_QKV.value().data_ptr()); + te_KV = makeTransformerEngineTensor(KV.data_ptr(), {total_seqs_kv, 2, h, d}, + qkv_type, nullptr, nullptr, descale_QKV.value().data_ptr()); + at::Tensor descale_S = torch::empty_like(scale_S.value()); + te_S = makeTransformerEngineTensor(nullptr, {0}, + DType::kFloat32, amax_S.value().data_ptr(), + scale_S.value().data_ptr(), descale_S.data_ptr()); + te_O = makeTransformerEngineTensor(O.data_ptr(), {total_seqs_q, h, d}, + qkv_type, amax_O.value().data_ptr(), scale_O.value().data_ptr(), nullptr); + } else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) { + // BF16 or FP16 + te_Q = makeTransformerEngineTensor(Q.data_ptr(), {total_seqs_q, h, d}, + qkv_type, nullptr, nullptr, nullptr); + te_KV = makeTransformerEngineTensor(KV.data_ptr(), {total_seqs_kv, 2, h, d}, + qkv_type, nullptr, nullptr, nullptr); + te_S = makeTransformerEngineTensor(nullptr, {0}, + DType::kFloat32, nullptr, nullptr, nullptr); + te_O = makeTransformerEngineTensor(O.data_ptr(), {total_seqs_q, h, d}, + qkv_type, nullptr, nullptr, nullptr); + } else { + NVTE_ERROR("Fused attention only supports FP8 and BF16/FP16 data types. \n"); + } + if (Bias.has_value()) { + auto bias_shape = Bias.value().sizes().vec(); + std::vector shape{bias_shape.begin(), bias_shape.end()}; + te_Bias = makeTransformerEngineTensor(Bias.value().data_ptr(), shape, + DType::kFloat32, nullptr, nullptr, nullptr); + } + te_cu_seqlens_q = makeTransformerEngineTensor(cu_seqlens_q.data_ptr(), {b+1}, + DType::kInt32, nullptr, nullptr, nullptr); + te_cu_seqlens_kv = makeTransformerEngineTensor(cu_seqlens_kv.data_ptr(), {b+1}, + DType::kInt32, nullptr, nullptr, nullptr); + + // convert strings to enums + NVTE_QKV_Layout qkv_layout_enum = get_nvte_qkv_layout(qkv_layout); + NVTE_Bias_Type bias_type_enum = get_nvte_bias_type(bias_type); + NVTE_Mask_Type attn_mask_type_enum = get_nvte_mask_type(attn_mask_type); + + // extract rng seed and offset + auto gen = at::get_generator_or_default( + rng_gen, at::cuda::detail::getDefaultCUDAGenerator()); + size_t threads_per_cta = 128; + at::PhiloxCudaState philox_args = init_philox_state( + gen, max(max_seqlen_q, max_seqlen_kv), threads_per_cta); + auto rng_state = torch::empty({2}, options.dtype(torch::kInt64)); + unpack<<<1, 1, 0, at::cuda::getCurrentCUDAStream()>>>( + philox_args, static_cast(rng_state.data_ptr())); + auto te_rng_state = makeTransformerEngineTensor(rng_state); + + // create auxiliary output tensors + // if training, tensors are [M, ZInv] + NVTETensorPack nvte_aux_tensor_pack; + nvte_tensor_pack_create(&nvte_aux_tensor_pack); + + // create workspace + TensorWrapper workspace; + + // populate tensors with appropriate shapes and dtypes + nvte_fused_attn_fwd_kvpacked( + te_Q.data(), + te_KV.data(), + te_Bias.data(), + te_S.data(), + te_O.data(), + &nvte_aux_tensor_pack, + te_cu_seqlens_q.data(), + te_cu_seqlens_kv.data(), + te_rng_state.data(), + max_seqlen_q, max_seqlen_kv, + is_training, attn_scale, p_dropout, + qkv_layout_enum, bias_type_enum, attn_mask_type_enum, + workspace.data(), + at::cuda::getCurrentCUDAStream()); + + // allocate memory for workspace and auxiliary output tensors + auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); + workspace = makeTransformerEngineTensor( + workspace_data.data_ptr(), + workspace.shape(), workspace.dtype()); + + // output_tensors = [O, nvte_aux_tensor_pack.tensors, rng_state] + std::vector output_tensors; + output_tensors.push_back(O); + // nvte_aux_tensor_pack.size is 0 if inference + for (size_t i = 0; i < nvte_aux_tensor_pack.size; ++i) { + auto tensor = reinterpret_cast(nvte_aux_tensor_pack.tensors[i]); + // allocate memory for nvte_aux_tensor_pack.tensors + auto output_tensor = allocateSpace(tensor->data.shape, tensor->data.dtype, false); + output_tensors.push_back(output_tensor); + tensor->data.dptr = output_tensor.data_ptr(); + } + if (is_training) { + output_tensors.push_back(rng_state); + } + + // execute the kernel + nvte_fused_attn_fwd_kvpacked( + te_Q.data(), + te_KV.data(), + te_Bias.data(), + te_S.data(), + te_O.data(), + &nvte_aux_tensor_pack, + te_cu_seqlens_q.data(), + te_cu_seqlens_kv.data(), + te_rng_state.data(), + max_seqlen_q, max_seqlen_kv, + is_training, attn_scale, p_dropout, + qkv_layout_enum, bias_type_enum, attn_mask_type_enum, + workspace.data(), + at::cuda::getCurrentCUDAStream()); + + // destroy tensor wrappers, but not allocated memory + nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); + + // if training, [O, M, ZInv, rng_state]; if inference, [O] + return output_tensors; +} + +// fused attention BWD with packed KV +std::vector fused_attn_bwd_kvpacked( + size_t b, size_t max_seqlen_q, size_t max_seqlen_kv, + size_t total_seqs_q, size_t total_seqs_kv, + size_t h, size_t d, + float attn_scale, float p_dropout, bool set_zero, + std::string qkv_layout, std::string bias_type, std::string attn_mask_type, + const at::Tensor cu_seqlens_q, + const at::Tensor cu_seqlens_kv, + const at::Tensor Q, + const at::Tensor KV, + const at::Tensor O, + const at::Tensor dO, + const transformer_engine::DType qkv_type, + const std::vector Aux_CTX_Tensors, + const c10::optional descale_QKV, + const c10::optional descale_S, + const c10::optional descale_O, + const c10::optional descale_dO, + const c10::optional scale_S, + const c10::optional scale_dP, + const c10::optional scale_dQKV, + c10::optional amax_dP, + c10::optional amax_dQKV, + const c10::optional dBias) { + using namespace transformer_engine; + + // create output tensors dQ and dKV + at::Tensor dQ = torch::empty_like(Q); + at::Tensor dKV = torch::empty_like(KV); + if (set_zero) { + mha_fill(dQ, cu_seqlens_q.index({torch::indexing::Slice(-1, torch::indexing::None)})); + mha_fill(dKV, cu_seqlens_kv.index({torch::indexing::Slice(-1, torch::indexing::None)})); + } + + // construct NVTE tensors + TensorWrapper te_Q, te_KV, te_O, te_dO, te_S, te_dP, te_dQ, te_dKV, te_dBias; + if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) { + // FP8 + if ((!descale_QKV.has_value()) || (!descale_S.has_value()) + || (!descale_O.has_value()) || (!descale_dO.has_value()) + || (!scale_S.has_value()) || (!scale_dP.has_value()) + || (!scale_dQKV.has_value()) + || (!amax_dP.has_value()) || (!amax_dQKV.has_value())) { + std::string err_tensors = "descale_QKV, descale_S, descale_O, scale_S, scale_dP, "; + err_tensors = err_tensors + std::string("scale_dQKV, amax_dP and amax_dQKV"); + NVTE_ERROR(err_tensors + std::string("are required for FP8 operation. \n")); + } + te_Q = makeTransformerEngineTensor(Q.data_ptr(), {total_seqs_q, h, d}, + qkv_type, nullptr, nullptr, descale_QKV.value().data_ptr()); + te_KV = makeTransformerEngineTensor(KV.data_ptr(), {total_seqs_kv, 2, h, d}, + qkv_type, nullptr, nullptr, descale_QKV.value().data_ptr()); + te_O = makeTransformerEngineTensor(O.data_ptr(), {total_seqs_q, h, d}, + qkv_type, nullptr, nullptr, descale_O.value().data_ptr()); + te_dO = makeTransformerEngineTensor(dO.data_ptr(), {total_seqs_q, h, d}, + qkv_type, nullptr, nullptr, descale_dO.value().data_ptr()); + te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, nullptr, + scale_S.value().data_ptr(), descale_S.value().data_ptr()); + at::Tensor descale_dP = torch::empty_like(scale_dP.value()); + te_dP = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, + amax_dP.value().data_ptr(), scale_dP.value().data_ptr(), + descale_dP.data_ptr()); + te_dQ = makeTransformerEngineTensor(dQ.data_ptr(), {total_seqs_q, h, d}, qkv_type, + amax_dQKV.value().data_ptr(), scale_dQKV.value().data_ptr(), nullptr); + te_dKV = makeTransformerEngineTensor(dKV.data_ptr(), {total_seqs_kv, 2, h, d}, qkv_type, + amax_dQKV.value().data_ptr(), scale_dQKV.value().data_ptr(), nullptr); + } else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) { + // BF16 or FP16 + te_Q = makeTransformerEngineTensor(Q.data_ptr(), {total_seqs_q, h, d}, + qkv_type, nullptr, nullptr, nullptr); + te_KV = makeTransformerEngineTensor(KV.data_ptr(), {total_seqs_kv, 2, h, d}, + qkv_type, nullptr, nullptr, nullptr); + te_O = makeTransformerEngineTensor(O.data_ptr(), {total_seqs_q, h, d}, + qkv_type, nullptr, nullptr, nullptr); + te_dO = makeTransformerEngineTensor(dO.data_ptr(), {total_seqs_q, h, d}, + qkv_type, nullptr, nullptr, nullptr); + te_S = makeTransformerEngineTensor(nullptr, {0}, + DType::kFloat32, nullptr, nullptr, nullptr); + te_dP = makeTransformerEngineTensor(nullptr, {0}, + DType::kFloat32, nullptr, nullptr, nullptr); + te_dQ = makeTransformerEngineTensor(dQ.data_ptr(), {total_seqs_q, h, d}, + qkv_type, nullptr, nullptr, nullptr); + te_dKV = makeTransformerEngineTensor(dKV.data_ptr(), {total_seqs_kv, 2, h, d}, + qkv_type, nullptr, nullptr, nullptr); + } else { + NVTE_ERROR("Fused attention only supports FP8 and BF16/FP16 data types. \n"); + } + if (dBias.has_value()) { + auto bias_shape = dBias.value().sizes().vec(); + std::vector shape{bias_shape.begin(), bias_shape.end()}; + te_dBias = makeTransformerEngineTensor( + dBias.value().data_ptr(), shape, DType::kFloat32, + nullptr, nullptr, nullptr); + } + + // create cu_seqlens tensorwrappers + TensorWrapper te_cu_seqlens_q, te_cu_seqlens_kv; + te_cu_seqlens_q = makeTransformerEngineTensor(cu_seqlens_q.data_ptr(), {b+1}, + DType::kInt32, nullptr, nullptr, nullptr); + te_cu_seqlens_kv = makeTransformerEngineTensor(cu_seqlens_kv.data_ptr(), {b+1}, + DType::kInt32, nullptr, nullptr, nullptr); + + // convert strings to enums + NVTE_QKV_Layout qkv_layout_enum = get_nvte_qkv_layout(qkv_layout); + NVTE_Bias_Type bias_type_enum = get_nvte_bias_type(bias_type); + NVTE_Mask_Type attn_mask_type_enum = get_nvte_mask_type(attn_mask_type); + + // convert auxiliary tensors from forward to NVTETensors + // aux_ctx_tensors are [M, ZInv, rng_state] + NVTETensorPack nvte_aux_tensor_pack; + nvte_tensor_pack_create(&nvte_aux_tensor_pack); + nvte_aux_tensor_pack.size = Aux_CTX_Tensors.size(); + for (size_t i = 0; i < nvte_aux_tensor_pack.size; ++i) { + auto tensor = reinterpret_cast(nvte_aux_tensor_pack.tensors[i]); + tensor->data.dptr = Aux_CTX_Tensors[i].data_ptr(); + std::vector tmp(Aux_CTX_Tensors[i].sizes().vec()); + tensor->data.shape = std::vector(tmp.begin(), tmp.end()); + tensor->data.dtype = GetTransformerEngineDType(Aux_CTX_Tensors[i].scalar_type()); + } + + // create workspace + TensorWrapper workspace; + + // populate tensors with appropriate shapes and dtypes + nvte_fused_attn_bwd_kvpacked( + te_Q.data(), + te_KV.data(), + te_dBias.data(), + te_O.data(), + te_dO.data(), + te_S.data(), + te_dP.data(), + &nvte_aux_tensor_pack, + te_dQ.data(), + te_dKV.data(), + te_cu_seqlens_q.data(), + te_cu_seqlens_kv.data(), + max_seqlen_q, max_seqlen_kv, + attn_scale, p_dropout, + qkv_layout_enum, bias_type_enum, attn_mask_type_enum, + workspace.data(), + at::cuda::getCurrentCUDAStream()); + + // allocate memory for workspace + auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); + workspace = makeTransformerEngineTensor( + workspace_data.data_ptr(), + workspace.shape(), workspace.dtype()); + + // execute kernel + nvte_fused_attn_bwd_kvpacked( + te_Q.data(), + te_KV.data(), + te_dBias.data(), + te_O.data(), + te_dO.data(), + te_S.data(), + te_dP.data(), + &nvte_aux_tensor_pack, + te_dQ.data(), + te_dKV.data(), + te_cu_seqlens_q.data(), + te_cu_seqlens_kv.data(), + max_seqlen_q, max_seqlen_kv, + attn_scale, p_dropout, + qkv_layout_enum, bias_type_enum, attn_mask_type_enum, + workspace.data(), + at::cuda::getCurrentCUDAStream()); + + // destroy tensor wrappers + nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); + + return {dQ, dKV}; +} + void te_gemm(at::Tensor A, at::Tensor A_scale_inverse, transformer_engine::DType A_type, @@ -749,13 +1485,13 @@ at::Tensor cast_to_fp8(const at::Tensor &input, transformer_engine::DType otype ) { using namespace transformer_engine; - size_t N = static_cast(input.size(0)); - size_t H = static_cast(input.size(1)); + auto input_shape = input.sizes().vec(); + std::vector shape{input_shape.begin(), input_shape.end()}; auto output = at::empty_like(input, at::CUDA(GetATenDType(otype))); auto input_cu = makeTransformerEngineTensor(input); - auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {N, H}, otype, + auto output_cu = makeTransformerEngineTensor(output.data_ptr(), shape, otype, amax.data_ptr(), scale.data_ptr(), scale_inv.data_ptr()); @@ -795,12 +1531,12 @@ at::Tensor cast_from_fp8(const at::Tensor &input, transformer_engine::DType otype ) { using namespace transformer_engine; - size_t N = static_cast(input.size(0)); - size_t H = static_cast(input.size(1)); + auto input_shape = input.sizes().vec(); + std::vector shape{input_shape.begin(), input_shape.end()}; auto output = at::empty_like(input, at::CUDA(GetATenDType(otype))); - auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {N, H}, itype, + auto input_cu = makeTransformerEngineTensor(input.data_ptr(), shape, itype, nullptr, nullptr, scale_inv.data_ptr()); auto output_cu = makeTransformerEngineTensor(output); @@ -1066,6 +1802,14 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("cast_to_fp8_noalloc", &cast_to_fp8_noalloc, "Cast to FP8"); m.def("cast_from_fp8", &cast_from_fp8, "Cast from FP8"); m.def("te_gemm", &te_gemm, "CublasLt GEMM"); + m.def("fused_attn_fwd_qkvpacked", &fused_attn_fwd_qkvpacked, + "Fused Attention FP8/BF16/FP16 FWD with packed QKV"); + m.def("fused_attn_bwd_qkvpacked", &fused_attn_bwd_qkvpacked, + "Fused Attention FP8/BF16/FP16 BWD with packed QKV"); + m.def("fused_attn_fwd_kvpacked", &fused_attn_fwd_kvpacked, + "Fused Attention FP8/BF16/FP16 FWD with packed KV"); + m.def("fused_attn_bwd_kvpacked", &fused_attn_bwd_kvpacked, + "Fused Attention FP8/BF16/FP16 BWD with packed KV"); m.def("fp8_transpose", &fp8_transpose, "Transpose with FP8 I/O"); m.def("fp8_gelu", &fp8_gelu, "GeLU with FP8 output"); diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 6be404226e..561ba417e6 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -5,7 +5,95 @@ ************************************************************************/ #include "common.h" - +#include "../common.h" + +NVTE_QKV_Layout get_nvte_qkv_layout(const std::string qkv_layout); + +NVTE_Bias_Type get_nvte_bias_type(const std::string bias_type); + +NVTE_Mask_Type get_nvte_mask_type(const std::string mask_type); + +std::vector fused_attn_fwd_qkvpacked( + size_t b, size_t max_seqlen, size_t total_seqs, + size_t h, size_t d, + bool is_training, float attn_scale, float p_dropout, bool set_zero, + std::string qkv_layout, std::string bias_type, std::string attn_mask_type, + const at::Tensor cu_seqlens, + const at::Tensor QKV, + const transformer_engine::DType qkv_type, + const c10::optional descale_QKV, + const c10::optional scale_S, + const c10::optional scale_O, + c10::optional amax_S, + c10::optional amax_O, + const c10::optional Bias, + const c10::optional rng_gen); + +std::vector fused_attn_bwd_qkvpacked( + size_t b, size_t max_seqlen, size_t total_seqs, + size_t h, size_t d, + float attn_scale, float p_dropout, bool set_zero, + std::string qkv_layout, std::string bias_type, std::string attn_mask_type, + const at::Tensor cu_seqlens, + const at::Tensor QKV, + const at::Tensor O, + const at::Tensor dO, + const transformer_engine::DType qkv_type, + const std::vector Aux_CTX_Tensors, + const c10::optional descale_QKV, + const c10::optional descale_S, + const c10::optional descale_O, + const c10::optional descale_dO, + const c10::optional scale_S, + const c10::optional scale_dP, + const c10::optional scale_dQKV, + c10::optional amax_dP, + c10::optional amax_dQKV, + const c10::optional dBias); + +std::vector fused_attn_fwd_kvpacked( + size_t b, size_t max_seqlen_q, size_t max_seqlen_kv, + size_t total_seqs_q, size_t total_seqs_kv, + size_t h, size_t d, + bool is_training, float attn_scale, float p_dropout, bool set_zero, + std::string qkv_layout, std::string bias_type, std::string attn_mask_type, + const at::Tensor cu_seqlens_q, + const at::Tensor cu_seqlens_kv, + const at::Tensor Q, + const at::Tensor KV, + const transformer_engine::DType qkv_type, + const c10::optional descale_QKV, + const c10::optional scale_S, + const c10::optional scale_O, + c10::optional amax_S, + c10::optional amax_O, + const c10::optional Bias, + const c10::optional rng_gen); + +std::vector fused_attn_bwd_kvpacked( + size_t b, size_t max_seqlen_q, size_t max_seqlen_kv, + size_t total_seqs_q, size_t total_seqs_kv, + size_t h, size_t d, + float attn_scale, float p_dropout, bool set_zero, + std::string qkv_layout, std::string bias_type, std::string attn_mask_type, + const at::Tensor cu_seqlens_q, + const at::Tensor cu_seqlens_kv, + const at::Tensor Q, + const at::Tensor KV, + const at::Tensor O, + const at::Tensor dO, + const transformer_engine::DType qkv_type, + const std::vector Aux_CTX_Tensors, + const c10::optional descale_QKV, + const c10::optional descale_S, + const c10::optional descale_O, + const c10::optional descale_dO, + const c10::optional scale_S, + const c10::optional scale_dP, + const c10::optional scale_dQKV, + c10::optional amax_dP, + c10::optional amax_dQKV, + const c10::optional dBias); void te_gemm(at::Tensor A, at::Tensor A_scale_inverse, diff --git a/transformer_engine/pytorch/module.py b/transformer_engine/pytorch/module.py index 3e0a868047..07805088b2 100644 --- a/transformer_engine/pytorch/module.py +++ b/transformer_engine/pytorch/module.py @@ -102,7 +102,7 @@ def get_workspace() -> torch.Tensor: global _cublas_workspace if _cublas_workspace is None: _cublas_workspace = torch.empty( - get_cublas_workspace_size_bytes(), dtype=torch.int8, device="cuda" + get_cublas_workspace_size_bytes(), dtype=torch.uint8, device="cuda" ) return _cublas_workspace @@ -520,7 +520,7 @@ def set_fp8_weights(self) -> None: torch.empty( shape, device=torch.cuda.current_device(), - dtype=torch.int8, + dtype=torch.uint8, ), ) setattr( @@ -530,7 +530,7 @@ def set_fp8_weights(self) -> None: shape[1], shape[0], device=torch.cuda.current_device(), - dtype=torch.int8, + dtype=torch.uint8, ), )