diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 7039d38cf5..b5b262baff 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -70,6 +70,7 @@ jobs: run: pip install . -v env: NVTE_FRAMEWORK: jax + MAX_JOBS: 1 - name: 'Sanity check' run: python tests/jax/test_sanity_import.py paddle: diff --git a/.github/workflows/deploy_nightly_docs.yml b/.github/workflows/deploy_nightly_docs.yml index cd68019c8f..fc5e27d0a4 100644 --- a/.github/workflows/deploy_nightly_docs.yml +++ b/.github/workflows/deploy_nightly_docs.yml @@ -16,13 +16,14 @@ jobs: runs-on: ubuntu-latest steps: - name: Download artifact - uses: actions/download-artifact@v4.1.7 + uses: actions/download-artifact@v4 with: name: "te_docs" path: "html" - name: Prepare for pages uses: actions/upload-pages-artifact@v1.0.7 with: + name: github-pages path: "html" deploy: needs: prepare diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index 4762cccee6..b6fadba1bd 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -27,7 +27,7 @@ jobs: cd docs make html - name: 'Upload docs' - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: name: te_docs path: docs/_build/html diff --git a/.github/workflows/trigger-ci.yml b/.github/workflows/trigger-ci.yml index c2317c6509..86d22b7944 100644 --- a/.github/workflows/trigger-ci.yml +++ b/.github/workflows/trigger-ci.yml @@ -40,6 +40,9 @@ jobs: || github.actor == 'vasunvidia' || github.actor == 'erhoo82' || github.actor == 'kocchop' + || github.actor == 'youngeunkwon0405' + || github.actor == 'KshitijLakhani' + || github.actor == 'jberchtold-nvidia' ) steps: - name: Check if comment is issued by authorized person diff --git a/build_tools/VERSION.txt b/build_tools/VERSION.txt index 28444e84a9..809a0327d8 100644 --- a/build_tools/VERSION.txt +++ b/build_tools/VERSION.txt @@ -1 +1 @@ -1.13.0.dev0 +1.14.0.dev0 diff --git a/qa/L1_pytorch_mcore_integration/.gitignore 
b/qa/L1_pytorch_mcore_integration/.gitignore new file mode 100644 index 0000000000..46426003ca --- /dev/null +++ b/qa/L1_pytorch_mcore_integration/.gitignore @@ -0,0 +1,2 @@ +Megatron-LM +vocab.json \ No newline at end of file diff --git a/qa/L1_pytorch_mcore_integration/merges.txt b/qa/L1_pytorch_mcore_integration/merges.txt new file mode 100644 index 0000000000..5e7f1fd949 --- /dev/null +++ b/qa/L1_pytorch_mcore_integration/merges.txt @@ -0,0 +1 @@ +#version: 0.2 diff --git a/qa/L1_pytorch_mcore_integration/test.sh b/qa/L1_pytorch_mcore_integration/test.sh new file mode 100644 index 0000000000..b0aba17ef5 --- /dev/null +++ b/qa/L1_pytorch_mcore_integration/test.sh @@ -0,0 +1,72 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +set -e + +# Paths +: ${TE_PATH:=/opt/transformerengine} +: ${MCORE_PATH:=${TE_PATH}/qa/L1_pytorch_mcore_integration/Megatron-LM} + +# Check whether FP8 is supported +DEVICE_ARCH=$(nvidia-smi --query-gpu=compute_cap --format=csv,noheader | head -n 1 | sed 's/[^0-9]//g') +if [[ ${DEVICE_ARCH} -ge 89 ]]; then + WITH_FP8=1 +fi + +# Download Megatron-LM if needed +if [ ! 
-d "${MCORE_PATH}" ]; then + pushd $(dirname ${MCORE_PATH}) + git clone -b core_r0.9.0 https://github.com/NVIDIA/Megatron-LM.git Megatron-LM + popd +fi + +# Create mock vocab +VOCAB_FILE=${TE_PATH}/qa/L1_pytorch_mcore_integration/vocab.json +printf "" > ${VOCAB_FILE} +printf "{" >> ${VOCAB_FILE} +printf "\"<|endoftext|>\": 0" >> ${VOCAB_FILE} +seq 1 4095 | awk '{ printf(", \"%d\": %d", $1, $1) }' >> ${VOCAB_FILE} +printf "}" >> ${VOCAB_FILE} + +# Megatron-LM invocation +COMMAND=" +NVTE_TORCH_COMPILE=0 +NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 +NVTE_FLASH_ATTN=1 +NVTE_FWD_LAYERNORM_SM_MARGIN=0 +NVTE_BWD_LAYERNORM_SM_MARGIN=0 +CUDA_DEVICE_MAX_CONNECTIONS=1 +NVTE_BIAS_GELU_NVFUSION=0 +NVTE_BIAS_DROPOUT_FUSION=0 + +python +-m torch.distributed.launch +--use_env +--nnodes=1 +--nproc_per_node=1 + +${MCORE_PATH}/pretrain_gpt.py +--tensor-model-parallel-size 1 +--pipeline-model-parallel-size 1 +--use-cpu-initialization +--num-layers 2 +--hidden-size 128 +--num-attention-heads 8 +--seq-length 128 +--max-position-embeddings 128 +--micro-batch-size 1 +--global-batch-size 8 +--train-iters 10 +--eval-iters 10 +--lr 1e-4 +--mock-data +--vocab-file ${VOCAB_FILE} +--merge-file ${TE_PATH}/qa/L1_pytorch_mcore_integration/merges.txt +--transformer-impl transformer_engine +${WITH_FP8:+--fp8-format hybrid} +" +COMMAND=$(echo "${COMMAND}" | tr '\n' ' ') + +# Launch Megatron-LM +bash -c "${COMMAND}" diff --git a/tests/cpp/operator/CMakeLists.txt b/tests/cpp/operator/CMakeLists.txt index 45806e7022..ab6b6a5316 100644 --- a/tests/cpp/operator/CMakeLists.txt +++ b/tests/cpp/operator/CMakeLists.txt @@ -10,8 +10,7 @@ add_executable(test_operator test_cast_transpose_dbias_dgelu.cu test_cast_transpose_dgeglu.cu test_act.cu - test_layernorm.cu - test_rmsnorm.cu + test_normalization.cu test_multi_cast_transpose.cu test_multi_padding.cu test_causal_softmax.cu diff --git a/tests/cpp/operator/test_layernorm.cu b/tests/cpp/operator/test_layernorm.cu deleted file mode 100644 index cdd8e7846c..0000000000 --- 
a/tests/cpp/operator/test_layernorm.cu +++ /dev/null @@ -1,302 +0,0 @@ -/************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * - * See LICENSE for license information. - ************************************************************************/ - -#include -#include -#include -#include -#include -#include - -#include -#include -#include - -#include -#include -#include "../test_common.h" - -using namespace transformer_engine; -using namespace test; - -namespace { - -template -void compute_ref_stats(const InputType *data, float *mu, float *rsigma, - const size_t N, const size_t H, const double epsilon) { - using compute_t = float; - for (size_t i = 0 ; i < N; ++i) { - compute_t sum = 0; - for (size_t j = 0; j < H; ++j) { - compute_t current = static_cast(data[i * H + j]); - sum += current; - } - mu[i] = sum / H; - compute_t m = mu[i]; - sum = 0; - for (size_t j = 0; j < H; ++j) { - compute_t current = static_cast(data[i * H + j]); - sum += (current - m) * (current - m); - } - sum = sum / H; - compute_t rs = rsqrtf(sum + epsilon); - rsigma[i] = rs; - } -} - -template -void compute_ref_output(const InputType *data, const InputType *gamma, const InputType *beta, - OutputType *output, const float *mu, const float *rsigma, - const size_t N, const size_t H, - float *amax, float scale, const bool zero_centered_gamma) { - using compute_t = float; - compute_t current_max = -1e100; - for (size_t i = 0 ; i < N; ++i) { - for (size_t j = 0; j < H; ++j) { - compute_t current = static_cast(data[i * H + j]); - compute_t g = static_cast(gamma[j]); - if (zero_centered_gamma) { - g += 1; - } - compute_t tmp = (current - mu[i]) * rsigma[i] * g + static_cast(beta[j]); - output[i * H + j] = static_cast(tmp * scale); - current_max = fmaxf(current_max, fabsf(tmp)); - } - } - *amax = current_max; -} - -template -void compute_ref_backward(const OutputType *output_grad, const InputType 
*data, - const float *mu, const float *rsigma, - const InputType *gamma, - InputType *data_grad, - InputType *gamma_grad, InputType *beta_grad, - const size_t N, const size_t H, - const bool zero_centered_gamma) { - using compute_t = float; - std::vector dgamma(H, 0.f); - std::vector dbeta(H, 0.f); - - for (size_t i = 0 ; i < N; ++i) { - // Reductions - compute_t mdy = 0, mdyy = 0; - for (size_t j = 0; j < H; ++j) { - const compute_t x = static_cast(data[i * H + j]); - const compute_t y = (x - mu[i]) * rsigma[i]; - compute_t g = static_cast(gamma[j]); - if (zero_centered_gamma) { - g += 1; - } - const compute_t dz = static_cast(output_grad[i * H + j]); - const compute_t dy = g * dz; - dgamma[j] += y * dz; - dbeta[j] += dz; - mdy += dy; - mdyy += dy * y; - } - mdy /= H; - mdyy /= H; - - // Input grads - for (size_t j = 0; j < H; ++j) { - const compute_t x = static_cast(data[i * H + j]); - const compute_t y = (x - mu[i]) * rsigma[i]; - compute_t g = static_cast(gamma[j]); - if (zero_centered_gamma) { - g += 1; - } - const compute_t dz = static_cast(output_grad[i * H + j]); - const compute_t dy = g * dz; - const compute_t dx = rsigma[i] * (dy - mdyy * y - mdy); - data_grad[i * H + j] = static_cast(dx); - } - } - - // Weight grads - for (size_t j = 0; j < H; ++j) { - gamma_grad[j] = static_cast(dgamma[j]); - beta_grad[j] = static_cast(dbeta[j]); - } -} - -template -void performTest(const size_t N, const size_t H, const bool zero_centered_gamma) { - if (sizeof(InputType) < sizeof(OutputType)) { - GTEST_SKIP() << "LN kernel does not support OutputType > InputType"; - return; - } - using WeightType = InputType; - DType itype = TypeInfo::dtype; - DType wtype = TypeInfo::dtype; - DType otype = TypeInfo::dtype; - - if ((itype == DType::kBFloat16 && otype == DType::kFloat16) || - (itype == DType::kFloat16 && otype == DType::kBFloat16)) { - GTEST_SKIP() << "LN kernel does not support mixing Float16 and BFloat16"; - return; - } - - Tensor input({ N, H }, itype); - Tensor z({ N, 
H }, otype); - Tensor gamma({ H }, wtype); - Tensor beta({ H }, wtype); - Tensor mu({ N }, DType::kFloat32); - Tensor rsigma({ N }, DType::kFloat32); - Tensor dz({ N, H }, wtype); - Tensor dx({ N, H }, itype); - Tensor dgamma({ H }, wtype); - Tensor dbeta({ H }, wtype); - Tensor workspace, barrier, dgamma_part, dbeta_part; - - fillUniform(&input); - fillUniform(&gamma); - fillUniform(&beta); - setRandomScale(&z); - fillUniform(&dz); - - std::unique_ptr ref_output = std::make_unique(N * H); - std::unique_ptr ref_mu = std::make_unique(N); - std::unique_ptr ref_rsigma = std::make_unique(N); - std::unique_ptr ref_dx = std::make_unique(N * H); - std::unique_ptr ref_dgamma = std::make_unique(H); - std::unique_ptr ref_dbeta = std::make_unique(H); - - cudaDeviceProp prop; - cudaGetDeviceProperties(&prop, 0); - - // Forward kernel - float epsilon = 1e-5; - auto fwd_function = zero_centered_gamma ? nvte_layernorm1p_fwd : nvte_layernorm_fwd; - fwd_function(input.data(), gamma.data(), beta.data(), epsilon, - z.data(), mu.data(), rsigma.data(), 0, prop.multiProcessorCount, - workspace.data(), barrier.data()); - workspace = Tensor(workspace.shape(), workspace.dtype()); - barrier = Tensor(barrier.shape(), barrier.dtype()); - fwd_function(input.data(), gamma.data(), beta.data(), epsilon, - z.data(), mu.data(), rsigma.data(), 0, prop.multiProcessorCount, - workspace.data(), barrier.data()); - - // Backward kernel - auto bwd_function = zero_centered_gamma ? 
nvte_layernorm1p_bwd : nvte_layernorm_bwd; - bwd_function(dz.data(), input.data(), - mu.data(), rsigma.data(), gamma.data(), - dx.data(), dgamma.data(), dbeta.data(), - dgamma_part.data(), dbeta_part.data(), - 0, prop.multiProcessorCount, - workspace.data(), barrier.data()); - workspace = Tensor(workspace.shape(), workspace.dtype()); - barrier = Tensor(barrier.shape(), barrier.dtype()); - dgamma_part = Tensor(dgamma_part.shape(), dgamma_part.dtype()); - dbeta_part = Tensor(dbeta_part.shape(), dbeta_part.dtype()); - bwd_function(dz.data(), input.data(), - mu.data(), rsigma.data(), gamma.data(), - dx.data(), dgamma.data(), dbeta.data(), - dgamma_part.data(), dbeta_part.data(), - 0, prop.multiProcessorCount, - workspace.data(), barrier.data()); - - // Reference implementations - // use the GPU stats to tighten the tolerances - mu.to_cpu(); - rsigma.to_cpu(); - float ref_amax; - compute_ref_stats(input.cpu_dptr(), ref_mu.get(), - ref_rsigma.get(), N, H, epsilon); - float ref_scale = isFp8Type(otype) ? 
z.scale() : 1.f; - compute_ref_output(input.cpu_dptr(), - gamma.cpu_dptr(), - beta.cpu_dptr(), - ref_output.get(), - mu.cpu_dptr(), - rsigma.cpu_dptr(), - N, H, - &ref_amax, - ref_scale, - zero_centered_gamma); - compute_ref_backward(dz.cpu_dptr(), input.cpu_dptr(), - mu.cpu_dptr(), rsigma.cpu_dptr(), - gamma.cpu_dptr(), - ref_dx.get(), ref_dgamma.get(), ref_dbeta.get(), - N, H, zero_centered_gamma); - - cudaDeviceSynchronize(); - auto err = cudaGetLastError(); - ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); - - auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32); - if (isFp8Type(otype)) { - compareResults("amax", z.amax(), ref_amax, atol_amax, rtol_amax); - float ref_scale_inv = 1.f / z.scale(); - compareResults("scale_inv", z.scale_inv(), ref_scale_inv, atol_amax, rtol_amax); - } - - auto [atol_stats, rtol_stats] = getTolerances(DType::kFloat32); - rtol_stats = 5e-5; - compareResults("mu", mu, ref_mu.get(), atol_stats, rtol_stats); - compareResults("rsigma", rsigma, ref_rsigma.get(), atol_stats, rtol_stats); - - auto [atol, rtol] = getTolerances(otype); - if (otype == DType::kFloat32) { - atol = 5e-7; - } - compareResults("output", z, ref_output.get(), atol, rtol); - - double atol_bwd = 1e-4; - double rtol_bwd = 1e-4; - compareResults("dx", dx, ref_dx.get(), atol_bwd, rtol_bwd); - compareResults("dgamma", dgamma, ref_dgamma.get(), atol_bwd, rtol_bwd); - compareResults("dbeta", dbeta, ref_dbeta.get(), atol_bwd, rtol_bwd); -} - -std::vector> test_cases = {{2048, 12288}, - {768, 1024}, - {256, 65536}, - {128, 6144}, - {64, 2304}, - {229, 541}, // Primes 50, 100 - {71, 3571}, // Primes 20, 500 - {29, 17389}}; // Primes 10, 2000 - -} // namespace - -class LNTestSuite : public ::testing::TestWithParam, - bool>> {}; - -TEST_P(LNTestSuite, TestLN) { - using namespace transformer_engine; - using namespace test; - - const DType input_type = std::get<0>(GetParam()); - const DType output_type = std::get<1>(GetParam()); - const auto size = 
std::get<2>(GetParam()); - const bool zero_centered_gamma = std::get<3>(GetParam()); - - TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType, - TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType, - performTest(size.first, size.second, zero_centered_gamma); - ); - ); -} - -INSTANTIATE_TEST_SUITE_P( - OperatorTest, - LNTestSuite, - ::testing::Combine( - ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16), - ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16, DType::kFloat8E4M3), - ::testing::ValuesIn(test_cases), - ::testing::Values(false, true)), - [](const testing::TestParamInfo& info) { - std::string name = test::typeName(std::get<0>(info.param)) + "X" + - test::typeName(std::get<1>(info.param)) + "X" + - std::to_string(std::get<2>(info.param).first) + "X" + - std::to_string(std::get<2>(info.param).second) + "X" + - std::to_string(std::get<3>(info.param)); - return name; - }); diff --git a/tests/cpp/operator/test_normalization.cu b/tests/cpp/operator/test_normalization.cu new file mode 100644 index 0000000000..bd6ee96af8 --- /dev/null +++ b/tests/cpp/operator/test_normalization.cu @@ -0,0 +1,380 @@ +/************************************************************************* + * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
+ ************************************************************************/ + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include +#include +#include "../test_common.h" + +using namespace transformer_engine; +using namespace test; + +namespace { + +enum NormType { + LayerNorm, + RMSNorm +}; + +std::map normToString = { + {NormType::LayerNorm, "LayerNorm"}, + {NormType::RMSNorm, "RmsNorm"} +}; + +template +void compute_ref_stats(NormType norm_type, + const InputType *data, float *mu, float *rsigma, + const size_t N, const size_t H, const double epsilon){ + using compute_t = float; + compute_t current, m; + for (size_t i = 0; i < N; ++i) { + compute_t sum = 0; + for (size_t j = 0; j < H; ++j) { + sum += static_cast(data[i * H + j]); + } + if (norm_type == LayerNorm){ + mu[i] = sum / H; + m = mu[i]; + } else { m = 0;} + + compute_t sum_sq = 0; + for (size_t j = 0; j < H; ++j) { + current = static_cast(data[i * H + j]); + sum_sq += (current - m) * (current - m); + } + rsigma[i] = rsqrtf((sum_sq / H) + epsilon); + } +} + +// For now, cudnn does static_cast(gamma + static_cast(1.0)) +// This will be changed in the future release +template +inline auto compute_gamma(InputType gamma, const bool zero_centered_gamma, const bool use_cudnn){ + + using compute_t = float; + if constexpr (std::is_same_v || std::is_same_v){ + compute_t g = static_cast(gamma); + if (zero_centered_gamma) { + g += static_cast(1.f); + } + return g; + } else { + if (use_cudnn){ + compute_t g = static_cast(0.f); + InputType gi = gamma; + if (zero_centered_gamma) { + gi = gi + static_cast(1.f); + } + g = static_cast(gi); + return g; + } else { + compute_t g = static_cast(gamma); + if (zero_centered_gamma) { + g += static_cast(1.f); + } + return g; + } + } +} + +template +void compute_ref_output(NormType norm_type, + const InputType *data, const InputType *gamma, const InputType *beta, + OutputType* output, + const float *mu, const float 
*rsigma, + const size_t N, const size_t H, + float *amax, float scale, const bool zero_centered_gamma, const bool use_cudnn) { + using compute_t = float; + compute_t current_max = -1e100; + for (size_t i = 0; i < N; ++i) { + for (size_t j = 0; j < H; ++j) { + compute_t current = static_cast(data[i * H + j]); + compute_t g = compute_gamma(gamma[j], zero_centered_gamma, use_cudnn); + + compute_t tmp; + if (norm_type == LayerNorm) { + tmp = (current - mu[i]) * rsigma[i] * g + static_cast(beta[j]); + } else { // RMSNorm + tmp = current * rsigma[i] * g; + } + + output[i * H + j] = static_cast(tmp * scale); + current_max = fmaxf(current_max, fabsf(tmp)); + } + } + *amax = current_max; +} + + +template +void compute_ref_backward(const NormType norm_type, const OutputType *output_grad, const InputType *data, + const float *mu, const float *rsigma, + const InputType *gamma, + InputType *data_grad, + InputType *gamma_grad, InputType *beta_grad, + const size_t N, const size_t H, + const bool zero_centered_gamma, const bool use_cudnn) { + using compute_t = float; + std::vector dgamma(H, 0.f); + std::vector dbeta(H, 0.f); + + for (size_t i = 0 ; i < N; ++i) { + // Reductions + auto local_mu = (norm_type == LayerNorm) ? 
mu[i] : 0.; + compute_t mdy = 0, mdyy = 0; + for (size_t j = 0; j < H; ++j) { + const compute_t x = static_cast(data[i * H + j]); + const compute_t y = (x - local_mu) * rsigma[i]; + compute_t g = compute_gamma(gamma[j], zero_centered_gamma, use_cudnn); + const compute_t dz = static_cast(output_grad[i * H + j]); + const compute_t dy = g * dz; + dgamma[j] += y * dz; + if (norm_type == LayerNorm) { + dbeta[j] += dz; + mdy += dy; + } + mdyy += dy * y; + } + mdy /= H; + mdyy /= H; + + // Input grads + for (size_t j = 0; j < H; ++j) { + const compute_t x = static_cast(data[i * H + j]); + const compute_t y = (x - local_mu) * rsigma[i]; + compute_t g = compute_gamma(gamma[j], zero_centered_gamma, use_cudnn); + const compute_t dz = static_cast(output_grad[i * H + j]); + const compute_t dy = g * dz; + const compute_t dx = rsigma[i] * (dy - mdyy * y - mdy); + data_grad[i * H + j] = static_cast(dx); + } + } + + // Weight grads + for (size_t j = 0; j < H; ++j) gamma_grad[j] = static_cast(dgamma[j]); + if (norm_type == LayerNorm) for (size_t j = 0; j < H; ++j) beta_grad[j] = static_cast(dbeta[j]); +} + +template +void performTest(const size_t N, const size_t H, const bool zero_centered_gamma, + NormType norm_type, bool use_cudnn) { + if (sizeof(InputType) < sizeof(OutputType)) { + GTEST_SKIP() << "LN kernel does not support OutputType > InputType"; + return; + } + using WeightType = InputType; + DType itype = TypeInfo::dtype; + DType wtype = TypeInfo::dtype; + DType otype = TypeInfo::dtype; + + if ((itype == DType::kBFloat16 && otype == DType::kFloat16) || + (itype == DType::kFloat16 && otype == DType::kBFloat16)) { + GTEST_SKIP() << "LN kernel does not support mixing Float16 and BFloat16"; + return; + } + + Tensor input({ N, H }, itype); + Tensor z({ N, H }, otype); + Tensor gamma({ H }, wtype); + Tensor beta({ H }, wtype); + Tensor mu({ N }, DType::kFloat32); + Tensor rsigma({ N }, DType::kFloat32); + Tensor dz({ N, H }, wtype); + Tensor dx({ N, H }, itype); + Tensor dgamma({ 
H }, wtype); + Tensor dbeta({ H }, wtype); + Tensor workspace_fwd, workspace_bwd; + + fillUniform(&input); + fillUniform(&gamma); + fillUniform(&beta); + setRandomScale(&z); + fillUniform(&dz); + + std::unique_ptr ref_output = std::make_unique(N * H); + std::unique_ptr ref_mu = std::make_unique(N); + std::unique_ptr ref_rsigma = std::make_unique(N); + std::unique_ptr ref_dx = std::make_unique(N * H); + std::unique_ptr ref_dgamma = std::make_unique(H); + std::unique_ptr ref_dbeta = std::make_unique(H); + + cudaDeviceProp prop; + cudaGetDeviceProperties(&prop, 0); + + if (use_cudnn){ + nvte_enable_cudnn_norm_fwd(true); + nvte_enable_cudnn_norm_bwd(true); + } + + // Forward kernel + float epsilon = 1e-5; + if (norm_type == LayerNorm){ + nvte_layernorm_fwd(input.data(), gamma.data(), beta.data(), epsilon, + z.data(), mu.data(), rsigma.data(), workspace_fwd.data(), + prop.multiProcessorCount, zero_centered_gamma, 0); + workspace_fwd = Tensor(workspace_fwd.shape(), workspace_fwd.dtype()); + nvte_layernorm_fwd(input.data(), gamma.data(), beta.data(), epsilon, + z.data(), mu.data(), rsigma.data(), workspace_fwd.data(), + prop.multiProcessorCount, zero_centered_gamma, 0); + + nvte_layernorm_bwd(dz.data(), input.data(), + mu.data(), rsigma.data(), gamma.data(), + dx.data(), dgamma.data(), dbeta.data(), + workspace_bwd.data(), + prop.multiProcessorCount, zero_centered_gamma, 0); + workspace_bwd = Tensor(workspace_bwd.shape(), workspace_bwd.dtype()); + nvte_layernorm_bwd(dz.data(), input.data(), + mu.data(), rsigma.data(), gamma.data(), + dx.data(), dgamma.data(), dbeta.data(), + workspace_bwd.data(), + prop.multiProcessorCount, zero_centered_gamma, 0); + } else { + nvte_rmsnorm_fwd(input.data(), gamma.data(), epsilon, + z.data(), rsigma.data(), workspace_fwd.data(), + prop.multiProcessorCount, zero_centered_gamma, 0); + workspace_fwd = Tensor(workspace_fwd.shape(), workspace_fwd.dtype()); + nvte_rmsnorm_fwd(input.data(), gamma.data(), epsilon, + z.data(), rsigma.data(), 
workspace_fwd.data(), + prop.multiProcessorCount, zero_centered_gamma, 0); + + nvte_rmsnorm_bwd(dz.data(), input.data(), rsigma.data(), gamma.data(), + dx.data(), dgamma.data(), + workspace_bwd.data(), + prop.multiProcessorCount, zero_centered_gamma, 0); + workspace_bwd = Tensor(workspace_bwd.shape(), workspace_bwd.dtype()); + nvte_rmsnorm_bwd(dz.data(), input.data(), rsigma.data(), gamma.data(), + dx.data(), dgamma.data(), + workspace_bwd.data(), + prop.multiProcessorCount, zero_centered_gamma, 0); + } + + if (use_cudnn){ + nvte_enable_cudnn_norm_fwd(false); + nvte_enable_cudnn_norm_bwd(false); + } + + // Reference implementations + // use the GPU stats to tighten the tolerances + mu.to_cpu(); + rsigma.to_cpu(); + float ref_amax; + compute_ref_stats(norm_type, input.cpu_dptr(), ref_mu.get(), + ref_rsigma.get(), N, H, epsilon); + float ref_scale = isFp8Type(otype) ? z.scale() : 1.f; + compute_ref_output(norm_type, input.cpu_dptr(), + gamma.cpu_dptr(), + beta.cpu_dptr(), + ref_output.get(), + mu.cpu_dptr(), + rsigma.cpu_dptr(), + N, H, + &ref_amax, + ref_scale, + zero_centered_gamma, + use_cudnn); + compute_ref_backward(norm_type, dz.cpu_dptr(), input.cpu_dptr(), + mu.cpu_dptr(), rsigma.cpu_dptr(), + gamma.cpu_dptr(), + ref_dx.get(), ref_dgamma.get(), ref_dbeta.get(), + N, H, zero_centered_gamma, + use_cudnn); + + cudaDeviceSynchronize(); + auto err = cudaGetLastError(); + ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); + + auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32); + if (isFp8Type(otype)) { + compareResults("amax", z.amax(), ref_amax, atol_amax, rtol_amax); + float ref_scale_inv = 1.f / z.scale(); + compareResults("scale_inv", z.scale_inv(), ref_scale_inv, atol_amax, rtol_amax); + } + + auto [atol_stats, rtol_stats] = getTolerances(DType::kFloat32); + rtol_stats = 5e-5; + compareResults("mu", mu, ref_mu.get(), atol_stats, rtol_stats); + compareResults("rsigma", rsigma, ref_rsigma.get(), atol_stats, rtol_stats); + + auto [atol, rtol] = 
getTolerances(otype); + if (otype == DType::kFloat32) { + atol = 5e-7; + } + compareResults("output", z, ref_output.get(), atol, rtol); + + double atol_bwd = 5e-4; + double rtol_bwd = 5e-4; + compareResults("dx", dx, ref_dx.get(), atol_bwd, rtol_bwd); + compareResults("dgamma", dgamma, ref_dgamma.get(), atol_bwd, rtol_bwd); + compareResults("dbeta", dbeta, ref_dbeta.get(), atol_bwd, rtol_bwd); +} + +std::vector> test_cases = { + {71, 229}, + {29, 541}, + {768, 6144}, + {2048, 12288}, +}; + +} // namespace + +class NormTestSuite : public ::testing::TestWithParam, + bool>> {}; + +TEST_P(NormTestSuite, TestNorm) { + using namespace transformer_engine; + using namespace test; + + const bool use_cudnn = std::get<0>(GetParam()); + const NormType norm_type = std::get<1>(GetParam()); + const DType input_type = std::get<2>(GetParam()); + const DType output_type = std::get<3>(GetParam()); + const auto size = std::get<4>(GetParam()); + const bool zero_centered_gamma = std::get<5>(GetParam()); + + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType, + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType, + performTest(size.first, size.second, zero_centered_gamma, norm_type, use_cudnn); + ); + ); +} + +INSTANTIATE_TEST_SUITE_P( + OperatorTest, + NormTestSuite, + ::testing::Combine( + ::testing::Values(false), //TODO: enabling tests for cudnn backend + ::testing::Values(NormType::LayerNorm, NormType::RMSNorm), + ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16), + ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16, DType::kFloat8E4M3), + ::testing::ValuesIn(test_cases), + ::testing::Values(false, true)), + [](const testing::TestParamInfo& info) { + auto backend = std::get<0>(info.param) == false ? 
"Te" : "Cudnn"; +std::string name = + backend + + normToString.at(std::get<1>(info.param)) + "_" + + test::typeName(std::get<2>(info.param)) + "X" + + test::typeName(std::get<3>(info.param)) + "X" + + std::to_string(std::get<4>(info.param).first) + "X" + + std::to_string(std::get<4>(info.param).second) + "X" + + std::to_string(std::get<5>(info.param)); + return name; + }); diff --git a/tests/cpp/operator/test_rmsnorm.cu b/tests/cpp/operator/test_rmsnorm.cu deleted file mode 100644 index 0ec3a877e5..0000000000 --- a/tests/cpp/operator/test_rmsnorm.cu +++ /dev/null @@ -1,249 +0,0 @@ -/************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * - * See LICENSE for license information. - ************************************************************************/ - -#include -#include -#include -#include -#include -#include - -#include -#include -#include - -#include -#include -#include "../test_common.h" - -using namespace transformer_engine; -using namespace test; - -namespace { - -template -void compute_ref_stats(const InputType *data, float *rsigma, const size_t N, const size_t H, - const double epsilon) { - using compute_t = float; - for (size_t i = 0; i < N; ++i) { - compute_t sum = 0; - for (size_t j = 0; j < H; ++j) { - compute_t current = static_cast(data[i * H + j]); - sum += (current) * (current); - } - sum = sum / H; - compute_t rs = rsqrtf(sum + epsilon); - rsigma[i] = rs; - } -} - -template -void compute_ref_output(const InputType *data, const InputType *gamma, OutputType *output, - const float *rsigma, const size_t N, const size_t H, float *amax, - float scale, const bool zero_centered_gamma) { - using compute_t = float; - compute_t current_max = -1e100; - for (size_t i = 0; i < N; ++i) { - for (size_t j = 0; j < H; ++j) { - compute_t current = static_cast(data[i * H + j]); - compute_t g = static_cast(gamma[j]); - if (zero_centered_gamma) { - g += 1; - } - 
compute_t tmp = current * rsigma[i] * g; - output[i * H + j] = static_cast(tmp * scale); - current_max = fmaxf(current_max, fabsf(tmp)); - } - } - *amax = current_max; -} - -template -void compute_ref_backward(const OutputType *output_grad, const InputType *data, const float *rsigma, - const InputType *gamma, InputType *data_grad, InputType *gamma_grad, - const size_t N, const size_t H, const bool zero_centered_gamma) { - using compute_t = float; - std::vector dgamma(H, 0.f); - - for (size_t i = 0; i < N; ++i) { - // Reductions - compute_t mdyy = 0; - for (size_t j = 0; j < H; ++j) { - const compute_t x = static_cast(data[i * H + j]); - const compute_t y = x * rsigma[i]; - compute_t g = static_cast(gamma[j]); - if (zero_centered_gamma) { - g += 1; - } - const compute_t dz = static_cast(output_grad[i * H + j]); - const compute_t dy = g * dz; - dgamma[j] += y * dz; - mdyy += dy * y; - } - mdyy /= H; - - // Input grads - for (size_t j = 0; j < H; ++j) { - const compute_t x = static_cast(data[i * H + j]); - const compute_t y = x * rsigma[i]; - compute_t g = static_cast(gamma[j]); - if (zero_centered_gamma) { - g += 1; - } - const compute_t dz = static_cast(output_grad[i * H + j]); - const compute_t dy = g * dz; - const compute_t dx = rsigma[i] * (dy - mdyy * y); - data_grad[i * H + j] = static_cast(dx); - } - } - - // Weight grads - for (size_t j = 0; j < H; ++j) { - gamma_grad[j] = static_cast(dgamma[j]); - } -} - -template -void performTest(const size_t N, const size_t H, const bool zero_centered_gamma) { - if (sizeof(InputType) < sizeof(OutputType)) { - GTEST_SKIP() << "RMSNorm kernel does not support OutputType > InputType"; - return; - } - using WeightType = InputType; - DType itype = TypeInfo::dtype; - DType wtype = TypeInfo::dtype; - DType otype = TypeInfo::dtype; - - if ((itype == DType::kBFloat16 && otype == DType::kFloat16) || - (itype == DType::kFloat16 && otype == DType::kBFloat16)) { - GTEST_SKIP() << "RMSNorm kernel does not support mixing Float16 and 
BFloat16"; - return; - } - - Tensor input({N, H}, itype); - Tensor z({N, H}, otype); - Tensor gamma({H}, wtype); - Tensor rsigma({N}, DType::kFloat32); - Tensor dz({N, H}, wtype); - Tensor dx({N, H}, itype); - Tensor dgamma({H}, wtype); - Tensor workspace, barrier, dgamma_part; - - fillUniform(&input); - fillUniform(&gamma); - fillUniform(&dz); - setRandomScale(&z); - - std::unique_ptr ref_output = std::make_unique(N * H); - std::unique_ptr ref_rsigma = std::make_unique(N); - std::unique_ptr ref_dx = std::make_unique(N * H); - std::unique_ptr ref_dgamma = std::make_unique(H); - - cudaDeviceProp prop; - cudaGetDeviceProperties(&prop, 0); - - // Forward kernel - float epsilon = 1e-5; - auto fwd_function = zero_centered_gamma ? nvte_rmsnorm1p_fwd : nvte_rmsnorm_fwd; - fwd_function(input.data(), gamma.data(), epsilon, z.data(), rsigma.data(), 0, - prop.multiProcessorCount, workspace.data(), barrier.data()); - workspace = Tensor(workspace.shape(), workspace.dtype()); - barrier = Tensor(barrier.shape(), barrier.dtype()); - fwd_function(input.data(), gamma.data(), epsilon, z.data(), rsigma.data(), 0, - prop.multiProcessorCount, workspace.data(), barrier.data()); - - // Backward kernel - auto bwd_function = zero_centered_gamma ? 
nvte_rmsnorm1p_bwd : nvte_rmsnorm_bwd; - bwd_function(dz.data(), input.data(), rsigma.data(), gamma.data(), dx.data(), dgamma.data(), - dgamma_part.data(), 0, prop.multiProcessorCount, workspace.data(), - barrier.data()); - workspace = Tensor(workspace.shape(), workspace.dtype()); - barrier = Tensor(barrier.shape(), barrier.dtype()); - dgamma_part = Tensor(dgamma_part.shape(), dgamma_part.dtype()); - bwd_function(dz.data(), input.data(), rsigma.data(), gamma.data(), dx.data(), dgamma.data(), - dgamma_part.data(), 0, prop.multiProcessorCount, workspace.data(), - barrier.data()); - - // Reference implementations - // use the GPU stats to tighten the tolerances - rsigma.to_cpu(); - float ref_amax; - compute_ref_stats(input.cpu_dptr(), ref_rsigma.get(), N, H, epsilon); - float ref_scale = isFp8Type(otype) ? z.scale() : 1.f; - compute_ref_output(input.cpu_dptr(), gamma.cpu_dptr(), ref_output.get(), - rsigma.cpu_dptr(), N, H, &ref_amax, ref_scale, - zero_centered_gamma); - compute_ref_backward(dz.cpu_dptr(), input.cpu_dptr(), - rsigma.cpu_dptr(), gamma.cpu_dptr(), ref_dx.get(), - ref_dgamma.get(), N, H, zero_centered_gamma); - - cudaDeviceSynchronize(); - auto err = cudaGetLastError(); - ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); - - auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32); - if (isFp8Type(otype)) { - compareResults("amax", z.amax(), ref_amax, atol_amax, rtol_amax); - float ref_scale_inv = 1.f / z.scale(); - compareResults("scale_inv", z.scale_inv(), ref_scale_inv, atol_amax, rtol_amax); - } - - auto [atol_stats, rtol_stats] = getTolerances(DType::kFloat32); - rtol_stats = 5e-5; - compareResults("rsigma", rsigma, ref_rsigma.get(), atol_stats, rtol_stats); - - auto [atol, rtol] = getTolerances(otype); - atol = 1e-8; - compareResults("output", z, ref_output.get(), atol, rtol); - - double atol_bwd = 5e-6; - double rtol_bwd = 1e-4; - compareResults("dx", dx, ref_dx.get(), atol_bwd, rtol_bwd); - compareResults("dgamma", dgamma, 
ref_dgamma.get(), atol_bwd, rtol_bwd); -} - -std::vector> test_cases = { - {2048, 4096}, {768, 2048}, {256, 1024}, {128, 768}, {64, 512}, {173, 409}, // Primes 40, 80 - {71, 3571}, // Primes 20, 500 - {29, 17389}}; // Primes 10, 2000 - -} // namespace - -class RMSNormTestSuite : public ::testing::TestWithParam, - bool>> {}; - -TEST_P(RMSNormTestSuite, TestRMSNorm) { - using namespace transformer_engine; - using namespace test; - - const DType input_type = std::get<0>(GetParam()); - const DType output_type = std::get<1>(GetParam()); - const auto size = std::get<2>(GetParam()); - const bool zero_centered_gamma = std::get<3>(GetParam()); - - TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType, - TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType, - performTest(size.first, size.second, zero_centered_gamma););); -} - -INSTANTIATE_TEST_SUITE_P(OperatorTest, RMSNormTestSuite, - ::testing::Combine(::testing::Values(DType::kFloat32, DType::kBFloat16, - DType::kFloat16), - ::testing::Values(DType::kFloat32, DType::kBFloat16, - DType::kFloat16, DType::kFloat8E4M3), - ::testing::ValuesIn(test_cases), - ::testing::Values(false, true)), - [](const testing::TestParamInfo &info) { - std::string name = - test::typeName(std::get<0>(info.param)) + "X" + - test::typeName(std::get<1>(info.param)) + "X" + - std::to_string(std::get<2>(info.param).first) + "X" + - std::to_string(std::get<2>(info.param).second) + "X" + - std::to_string(std::get<3>(info.param)); - return name; - }); diff --git a/tests/cpp/util/test_string.cpp b/tests/cpp/util/test_string.cpp index 531994aff8..14c1cc11f3 100644 --- a/tests/cpp/util/test_string.cpp +++ b/tests/cpp/util/test_string.cpp @@ -5,6 +5,7 @@ ************************************************************************/ #include +#include #include @@ -57,6 +58,12 @@ TEST(UtilTest, ToStringLike) { // to_string_like EXPECT_EQ(std::stof(to_string_like(-2.5f)), -2.5f); EXPECT_EQ(std::stod(to_string_like(2.25)), 2.25); 
EXPECT_EQ(std::stod(to_string_like(-4.5)), -4.5); + + // Container types + EXPECT_EQ(to_string_like(std::vector{-3,1,-4}), "(-3,1,-4)"); + EXPECT_EQ(to_string_like(std::vector{"Accept", "no", "substitutes", ".", + "Buy", "N", "V", "IDIA"}), + "(Accept,no,substitutes,.,Buy,N,V,IDIA)"); } TEST(UtilTest, ConcatStringsTest) { // concat_strings @@ -88,6 +95,9 @@ TEST(UtilTest, ConcatStringsTest) { // concat_strings EXPECT_EQ(std::stof(concat_strings(6.5f)), 6.5f); EXPECT_EQ(std::stod(concat_strings("-", 4.25)), -4.25); EXPECT_EQ(std::stod(concat_strings(8.5)), 8.5); + + // Container types + EXPECT_EQ(concat_strings("vector ", std::vector{1,-2,3}), "vector (1,-2,3)"); } TEST(UtilTest, RegexReplaceTest) { // regex_replace diff --git a/tests/jax/test_distributed_fused_attn.py b/tests/jax/test_distributed_fused_attn.py index 7ef0d68474..e194a228d2 100644 --- a/tests/jax/test_distributed_fused_attn.py +++ b/tests/jax/test_distributed_fused_attn.py @@ -341,8 +341,9 @@ def ref_func(query, kv, mask): @pytest.mark.parametrize( "data_shape", [ - pytest.param([2, 512, 12, 128], id="2-512-12-128"), - pytest.param([4, 1024, 16, 64], id="4-1024-16-64"), + # Sequence lengths will be scaled by CP so that we don't run with tiny sizes. + pytest.param([2, 128, 12, 128], id="2-128xCP-12-128"), + pytest.param([4, 256, 16, 64], id="4-256xCP-16-64"), ], ) @pytest.mark.parametrize("kv_groups", [1, 4, 8, 12, 16]) @@ -423,6 +424,12 @@ def impl_test_contex_parallel_attn( qkv_format = get_qkv_format(qkv_layout) batch, seqlen, num_head, hidden = data_shape + + # Scale the sequence length by 2*CP so its never too small as we scale up test. + # 2*CP is used since we split into two CP groups for load balancing. 
+ seqlen = seqlen * cp_size * 2 + data_shape = batch, seqlen, num_head, hidden + num_kv_heads = num_head // kv_groups scaling_factor = 1.0 / np.sqrt(num_head) diff --git a/tests/pytorch/distributed/test_comm_gemm_overlap.py b/tests/pytorch/distributed/test_comm_gemm_overlap.py index ce46a72189..f81fbae1fe 100644 --- a/tests/pytorch/distributed/test_comm_gemm_overlap.py +++ b/tests/pytorch/distributed/test_comm_gemm_overlap.py @@ -209,19 +209,39 @@ def test_atomic_gemm_overlaps(ag_type, rs_type, p2p, fp8_out): @pytest.mark.parametrize( - "comm_type,fp8", + "comm_type, fp8, connections", [ - ("AG", False), - ("RS", False), - ("RS", True), + ("AG", False, 1), + ("RS", False, 1), + ("RS", True, 1), + ("AG", False, 8), + ("RS", False, 8), + ("RS", True, 8), + ], + ids=[ + "ALL-GATHER - BF16 - 1 connections", + "REDUCE-SCATTER - BF16 - 1 connections", + "REDUCE-SCATTER - FP8 - 1 connections", + "ALL-GATHER - BF16 - 8 connections", + "REDUCE-SCATTER - BF16 - 8 connections", + "REDUCE-SCATTER - FP8 - 8 connections", ], - ids=[" ALL-GATHER - BF16 ", " REDUCE-SCATTER - BF16 ", " REDUCE-SCATTER - FP8 "], ) -def test_bulk_overlaps(comm_type, fp8): +def test_bulk_overlaps(comm_type, fp8, connections): """ Test bulk overlaps with direct calls to te.cpp_extensions.gemm or te.cpp_extensions.fp8_gemm. """ - _run_gemm_with_overlap(comm_type, True, False, False, fp8, False, False) + if connections == 8: + if torch.cuda.get_device_properties(0).major != 9: + pytest.skip( + "CUDA_DEVICE_MAX_CONNECTIONS=8 test only applies to devices with compute capability" + " 9.0 (HOPPER ARCH)." 
+ ) + os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "8" + _run_gemm_with_overlap(comm_type, True, False, False, fp8, False, False) + os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" + else: + _run_gemm_with_overlap(comm_type, True, False, False, fp8, False, False) @pytest.mark.parametrize( diff --git a/tests/pytorch/fused_attn/run_fused_attn_with_cp.py b/tests/pytorch/fused_attn/run_fused_attn_with_cp.py index 2d863b3bba..3ddfab055c 100644 --- a/tests/pytorch/fused_attn/run_fused_attn_with_cp.py +++ b/tests/pytorch/fused_attn/run_fused_attn_with_cp.py @@ -42,7 +42,7 @@ def run_dpa_with_cp( "causal", "no_mask", ], f"{config.attn_mask_type} is an unsupported attention mask type!" - if kernel_backend == "FusedAttention" and qkv_format == "thd": + if qkv_format == "thd": if "causal" in config.attn_mask_type: config.attn_mask_type = "padding_causal" else: diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index 29829ac4ac..fd2832c1d4 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -293,7 +293,7 @@ def test_fp8_scale_update( ) # Check that scaling factors match expected - w_amax_ref = max(w_vals[: step + 2]) + w_amax_ref = max(w_vals[: step + 1]) x_amax_ref = max(x_vals[: step + 1]) dy_amax_ref = max(dy_vals[: step + 1]) w_scale_ref = (fp8_format.value.max_fwd / w_amax_ref) / (2**margin) @@ -1362,6 +1362,166 @@ def test_make_extra_output( torch.testing.assert_close(y2_test, y2_ref, rtol=0, atol=0) torch.testing.assert_close(dx_test, x_ref.grad, **tols) + @pytest.mark.parametrize("activation", ("relu", "gelu", "geglu", "reglu", "swiglu")) + @pytest.mark.parametrize("out_shape", ((37,), (2, 13), (4, 1, 16))) + @pytest.mark.parametrize("dtype", _dtypes) + @pytest.mark.parametrize("fp8_input", (False, True)) + @pytest.mark.parametrize("fp8_output", (False, True)) + def test_activation( + self, + *, + activation: str, + out_shape: Iterable[int], + dtype: torch.dtype, + device: torch.device = "cuda", + 
fp8_input: bool, + fp8_output: bool, + ) -> None: + """Activation functions""" + + # Tensor dimensions + in_shape = list(out_shape) + if activation in ("geglu", "reglu", "swiglu"): + in_shape[-1] *= 2 + + # Skip invalid configurations + if fp8_input or fp8_output: + if not fp8_available: + pytest.skip(reason_for_no_fp8) + if torch.device(device).type != "cuda": + pytest.skip("FP8 is only supported on CUDA devices") + + # Random data + x_ref, x_test = make_reference_and_test_tensors( + in_shape, + test_dtype=dtype, + test_device=device, + test_is_fp8=fp8_input, + ) + dy_ref, dy_test = make_reference_and_test_tensors( + out_shape, + test_dtype=dtype, + test_device=device, + requires_grad=False, + ) + + # Plain PyTorch implementation + y_ref: torch.Tensor + if activation == "gelu": + y_ref = torch.nn.functional.gelu(x_ref, approximate="tanh") + elif activation == "relu": + y_ref = torch.nn.functional.relu(x_ref) + elif activation == "geglu": + x1, x2 = x_ref.chunk(2, dim=-1) + y_ref = torch.nn.functional.gelu(x1, approximate="tanh") * x2 + elif activation == "reglu": + x1, x2 = x_ref.chunk(2, dim=-1) + y_ref = torch.nn.functional.relu(x1) * x2 + elif activation == "swiglu": + x1, x2 = x_ref.chunk(2, dim=-1) + y_ref = torch.nn.functional.silu(x1) * x2 + else: + raise ValueError(f"Unexpected activation function ({activation})") + y_ref.backward(dy_ref) + + # Implementation with fusible operation + make_op = dict( + gelu=te_ops.GELU, + relu=te_ops.ReLU, + geglu=te_ops.GEGLU, + reglu=te_ops.ReGLU, + swiglu=te_ops.SwiGLU, + )[activation] + forward = te_ops.Sequential( + make_op(), + te_ops.Quantize(forward=fp8_output, backward=False), + ) + with te.fp8_autocast(enabled=fp8_output): + y_test = forward(x_test) + y_test.backward(dy_test) + + # Expected numerical error + tols = dtype_tols(dtype) + if fp8_output: + tols = dtype_tols(tex.DType.kFloat8E4M3) + + # Check results + y_test = y_test.to(dtype=torch.float64, device="cpu") + dx_test = x_test.grad.to(dtype=torch.float64, 
device="cpu") + torch.testing.assert_close(y_test, y_ref, **tols) + torch.testing.assert_close(dx_test, x_ref.grad, **tols) + + @pytest.mark.parametrize("dtype", _dtypes) + @pytest.mark.parametrize("fp8_output", (False, True)) + @pytest.mark.parametrize("fp8_grad_input", (False, True)) + def test_swiglu( + self, + *, + out_shape: Iterable[int] = (16, 16), + dtype: torch.dtype, + device: torch.device = "cuda", + fp8_output: bool, + fp8_grad_input: bool, + ): + + # Tensor dimensions + in_shape = list(out_shape) + in_shape[-1] *= 2 + + # Skip invalid configurations + fp8 = fp8_output or fp8_grad_input + if fp8: + if not fp8_available: + pytest.skip(reason_for_no_fp8) + if torch.device(device).type != "cuda": + pytest.skip("FP8 is only supported on CUDA devices") + + # FP8 recipe + fp8_recipe = None + if fp8_grad_input: + fp8_recipe = transformer_engine.common.recipe.DelayedScaling( + fp8_format=transformer_engine.common.recipe.Format.E4M3, + ) + + # Random data + x_ref, x_test = make_reference_and_test_tensors( + in_shape, + test_dtype=dtype, + test_device=device, + ) + dy_ref, dy_test = make_reference_and_test_tensors( + out_shape, + test_dtype=dtype, + test_device=device, + requires_grad=False, + ) + + # Plain PyTorch implementation + x1, x2 = x_ref.chunk(2, dim=-1) + y_ref = torch.nn.functional.silu(x1) * x2 + y_ref.backward(dy_ref) + + # Implementation with fusible operation + forward = te_ops.Sequential( + te_ops.Quantize(forward=False, backward=fp8_grad_input), + te_ops.SwiGLU(), + te_ops.Quantize(forward=fp8_output, backward=False), + ) + with te.fp8_autocast(enabled=fp8, fp8_recipe=fp8_recipe): + y_test = forward(x_test) + y_test.backward(dy_test) + + # Expected numerical error + tols = dtype_tols(dtype) + if fp8: + tols = dtype_tols(tex.DType.kFloat8E4M3) + + # Check results + y_test = y_test.to(dtype=torch.float64, device="cpu") + dx_test = x_test.grad.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(y_test, y_ref, **tols) + 
torch.testing.assert_close(dx_test, x_ref.grad, **tols) + class TestFusedOps: """Tests for fused operations""" diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 3784689f9a..84fc567cd3 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -30,14 +30,14 @@ endif() # cuDNN frontend API set(CUDNN_FRONTEND_INCLUDE_DIR - "${CMAKE_SOURCE_DIR}/../../3rdparty/cudnn-frontend/include") + "${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/cudnn-frontend/include") if(NOT EXISTS "${CUDNN_FRONTEND_INCLUDE_DIR}") message(FATAL_ERROR - "Could not find cuDNN frontend API. " + "Could not find cuDNN frontend API at ${CUDNN_FRONTEND_INCLUDE_DIR}. " "Try running 'git submodule update --init --recursive' " "within the Transformer Engine source.") endif() -include(${CMAKE_SOURCE_DIR}/../../3rdparty/cudnn-frontend/cmake/cuDNN.cmake) +include(${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/cudnn-frontend/cmake/cuDNN.cmake) # Python find_package(Python COMPONENTS Interpreter Development.Module REQUIRED) @@ -61,15 +61,17 @@ list(APPEND transformer_engine_SOURCES activation/swiglu.cu fused_attn/fused_attn_fp8.cu fused_attn/fused_attn.cpp + fused_attn/thd_utils.cu fused_attn/utils.cu gemm/cublaslt_gemm.cu - layer_norm/ln_api.cpp - layer_norm/ln_bwd_semi_cuda_kernel.cu - layer_norm/ln_fwd_cuda_kernel.cu + normalization/common.cpp + normalization/layernorm/ln_api.cpp + normalization/layernorm/ln_bwd_semi_cuda_kernel.cu + normalization/layernorm/ln_fwd_cuda_kernel.cu + normalization/rmsnorm/rmsnorm_api.cpp + normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu + normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu permutation/permutation.cu - rmsnorm/rmsnorm_api.cpp - rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu - rmsnorm/rmsnorm_fwd_cuda_kernel.cu util/cast.cu util/padding.cu util/cuda_driver.cpp diff --git a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp 
b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp index a663385b68..c6f0f870ff 100644 --- a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp +++ b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp @@ -90,6 +90,23 @@ CommOverlapCore::CommOverlapCore(int myrank, int numranks, int mylocal, int numl cudaEventCreateWithFlags(&_stop_compute, 0); cudaEventCreateWithFlags(&_start_comm, 0); cudaEventCreateWithFlags(&_stop_comm, 0); + + /* + Defining the launcher order between the communication and GEMM kernels + using Fast Dependent Launch when CUDA_DEVICE_MAX_CONNECTIONS>1. + The event is used to schedule the communication kernel before the GEMM. + This is needed only for Hopper, which uses persistent CTA execution. + */ + int max_connection = transformer_engine::getenv("CUDA_DEVICE_MAX_CONNECTIONS", 8); + int runtime_version = 0; + cudaRuntimeGetVersion(&runtime_version); + cudaDeviceProp deviceProp; + cudaGetDeviceProperties(&deviceProp, 0); + if (runtime_version >= 12030 && deviceProp.major == 9 && max_connection > 1) { + cudaEventCreateWithFlags(&_comm_launch_event, cudaEventDisableTiming); + } else { + _comm_launch_event = 0; + } } CommOverlapCore::~CommOverlapCore() { @@ -97,6 +114,7 @@ CommOverlapCore::~CommOverlapCore() { cudaEventDestroy(_start_comm); cudaEventDestroy(_stop_compute); cudaEventDestroy(_start_compute); + if (_comm_launch_event) cudaEventDestroy(_comm_launch_event); if (_atomic_gemm) cudaFree(_counter.dptr()); @@ -168,7 +186,8 @@ void CommOverlapBase::bulk_overlap(TensorWrapper &A, bool transa, TensorWrapper // Communication: AG and RS int comm_elements = (_ubuf.numel() / 2) * _ubuf.element_size(); // UBUF uses 2Byte element size if (comm_type == CommOverlapType::AG) { - allgather2_userbuff_inplace(_ub_reg, 0, comm_elements, _ub_comm, _stream_comm); + allgather2_userbuff_inplace(_ub_reg, 0, comm_elements, _ub_comm, _stream_comm, + (cudaEvent_t)_comm_launch_event); } else { if (_ubuf.element_size() == 
1) { assert(_ubuf_scale_inv_initialized); @@ -178,13 +197,18 @@ void CommOverlapBase::bulk_overlap(TensorWrapper &A, bool transa, TensorWrapper assert(rs_output.element_size() == 2); char *rs_output_ptr = reinterpret_cast(rs_output.dptr()); reducescatter2_userbuff_fp8<__nv_fp8_e5m2>(rs_output_ptr, _ubuf_scale_inv, _ub_reg, 0, - comm_elements, _ub_comm, _stream_comm); + comm_elements, _ub_comm, _stream_comm, + (cudaEvent_t)_comm_launch_event); } else { - reducescatter2_userbuff_inplace(_ub_reg, 0, comm_elements, _ub_comm, _stream_comm); + reducescatter2_userbuff_inplace(_ub_reg, 0, comm_elements, _ub_comm, _stream_comm, + (cudaEvent_t)_comm_launch_event); } } assert(pre_gelu_out.numel() == 0); + // When the kernel launch order is defined, enforce the GEMM kernel launch to wait for the communication kernel launch + if (_comm_launch_event) + NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _comm_launch_event, 0)); nvte_cublas_gemm(A.data(), B.data(), D.data(), bias.data(), pre_gelu_out.data(), transa, transb, grad, workspace.data(), accumulate, use_split_accumulator, _math_sms, stream_main); diff --git a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu index 26843d8107..91667958e7 100644 --- a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu +++ b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu @@ -1366,6 +1366,28 @@ __global__ void __launch_bounds__(MAX_THREADS) cfg.attrs = attribute_ub; \ cfg.numAttrs = comm->sm_arch >= 9 ? 
2 : 1; +#if (CUDART_VERSION >= 12030) +#define ADD_LAUNCH_COMPLETION_EVENT(attribute_ub, comm_launch_event) \ + attribute_ub[2].id = cudaLaunchAttributeLaunchCompletionEvent; \ + attribute_ub[2].val.launchCompletionEvent.event = comm_launch_event; +#define NUM_LAUNCH_ATTRIBUTE_FOR_FDL_LAUNCH 3 +#else +#define ADD_LAUNCH_COMPLETION_EVENT(attribute_ub, comm_launch_event) +#define NUM_LAUNCH_ATTRIBUTE_FOR_FDL_LAUNCH 2 +#endif + +#define SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT(sms, threads, stream, comm_launch_event) \ + cudaLaunchConfig_t cfg = {sms, threads, 0, stream, NULL, 0}; \ + cudaLaunchAttribute attribute_ub[NUM_LAUNCH_ATTRIBUTE_FOR_FDL_LAUNCH] = {}; \ + ADD_LAUNCH_COMPLETION_EVENT(attribute_ub, comm_launch_event) \ + attribute_ub[1].id = cudaLaunchAttributeClusterDimension; \ + attribute_ub[1].val.clusterDim.x = sms % comm->cga_size == 0 ? comm->cga_size : 1; \ + attribute_ub[1].val.clusterDim.y = 1; \ + attribute_ub[1].val.clusterDim.z = 1; \ + attribute_ub[0].id = cudaLaunchAttributeCooperative; \ + cfg.attrs = attribute_ub; \ + cfg.numAttrs = NUM_LAUNCH_ATTRIBUTE_FOR_FDL_LAUNCH; + #define callranks_ag(x) \ if (ar_nvsize == x) { \ int arg1 = op - NVTE_MAX_OPS, \ @@ -1753,7 +1775,8 @@ void reducescatter2_userbuff_strided_multiatomic(void *output, const int handler } void allgather2_userbuff_inplace(const int handler, const int offset, const int elements, - communicator *comm, cudaStream_t stream) { + communicator *comm, cudaStream_t stream, + cudaEvent_t comm_launch_event) { const int op = userbuffers_allreduceop_nonsharp2; const int ar_firstgpu = op == userbuffers_allreduceop_nonsharp ? 
comm->ar_firstgpu : comm->ar2_firstgpu; @@ -1766,11 +1789,20 @@ void allgather2_userbuff_inplace(const int handler, const int offset, const int int warps = comm->threads / 32; if (warps < ar_nvsize) warps = ar_nvsize; - SETUP_LAUNCH_CONFIG(sms, warps * 32, stream); - if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) { - callranks_agMC(2) callranks_agMC(4) callranks_agMC(8) + if (comm_launch_event) { + SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT(sms, warps * 32, stream, comm_launch_event); + if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) { + callranks_agMC(2) callranks_agMC(4) callranks_agMC(8) + } else { + callranks_ag(2) callranks_ag(4) callranks_ag(8) + } } else { - callranks_ag(2) callranks_ag(4) callranks_ag(8) + SETUP_LAUNCH_CONFIG(sms, warps * 32, stream); + if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) { + callranks_agMC(2) callranks_agMC(4) callranks_agMC(8) + } else { + callranks_ag(2) callranks_ag(4) callranks_ag(8) + } } } @@ -1790,7 +1822,8 @@ void allgather2_userbuff_inplace_sliced(const int handler, const int offset, con } void reducescatter2_userbuff_inplace(const int handler, const int offset, const int elements, - communicator *comm, cudaStream_t stream) { + communicator *comm, cudaStream_t stream, + cudaEvent_t comm_launch_event) { const int op = userbuffers_allreduceop_nonsharp2; const int ar_firstgpu = op == userbuffers_allreduceop_nonsharp ? 
comm->ar_firstgpu : comm->ar2_firstgpu; @@ -1803,17 +1836,26 @@ void reducescatter2_userbuff_inplace(const int handler, const int offset, const int warps = comm->threads / 32; if (warps < ar_nvsize) warps = ar_nvsize; - SETUP_LAUNCH_CONFIG(sms, warps * 32, stream); - if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) { - callranks_rsMC(2) callranks_rsMC(4) callranks_rsMC(8) + if (comm_launch_event) { + SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT(sms, warps * 32, stream, comm_launch_event); + if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) { + callranks_rsMC(2) callranks_rsMC(4) callranks_rsMC(8) + } else { + callranks_rs(2) callranks_rs(4) callranks_rs(8) + } } else { - callranks_rs(2) callranks_rs(4) callranks_rs(8) + SETUP_LAUNCH_CONFIG(sms, warps * 32, stream); + if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) { + callranks_rsMC(2) callranks_rsMC(4) callranks_rsMC(8) + } else { + callranks_rs(2) callranks_rs(4) callranks_rs(8) + } } } void reducescatter2_userbuff_stridedoutput(void *output, const int handler, const int offset, const int rowelements, const int colelements, const int strideelements, communicator *comm, - cudaStream_t stream) { + cudaStream_t stream, cudaEvent_t comm_launch_event) { const int elements = rowelements * colelements; const int op = userbuffers_allreduceop_nonsharp2; const int ar_firstgpu = @@ -1827,23 +1869,35 @@ void reducescatter2_userbuff_stridedoutput(void *output, const int handler, cons int warps = comm->threads / 32; if (warps < ar_nvsize) warps = ar_nvsize; - SETUP_LAUNCH_CONFIG(sms, warps * 32, stream); - if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) { - callranks_rs_oopMC(2) callranks_rs_oopMC(4) callranks_rs_oopMC(8) + if (comm_launch_event) { + SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT(sms, warps * 32, stream, comm_launch_event); + if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) { + callranks_rs_oopMC(2) callranks_rs_oopMC(4) 
callranks_rs_oopMC(8) + } else { + callranks_rs_oop(2) callranks_rs_oop(4) callranks_rs_oop(8) + } } else { - callranks_rs_oop(2) callranks_rs_oop(4) callranks_rs_oop(8) + SETUP_LAUNCH_CONFIG(sms, warps * 32, stream); + if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) { + callranks_rs_oopMC(2) callranks_rs_oopMC(4) callranks_rs_oopMC(8) + } else { + callranks_rs_oop(2) callranks_rs_oop(4) callranks_rs_oop(8) + } } } void reducescatter2_userbuff(void *output, const int handler, const int offset, const int elements, - communicator *comm, cudaStream_t stream) { - reducescatter2_userbuff_stridedoutput(output, handler, offset, elements, 1, 0, comm, stream); + communicator *comm, cudaStream_t stream, + cudaEvent_t comm_launch_event) { + reducescatter2_userbuff_stridedoutput(output, handler, offset, elements, 1, 0, comm, stream, + comm_launch_event); } template void reducescatter2_userbuff_stridedoutput_fp8(void *output, float *scale, const int handler, const int offset, const int rowelements, const int colelements, const int strideelements, - communicator *comm, cudaStream_t stream) { + communicator *comm, cudaStream_t stream, + cudaEvent_t comm_launch_event) { const int elements = rowelements * colelements; const int op = userbuffers_allreduceop_nonsharp2; const int ar_firstgpu = @@ -1857,33 +1911,43 @@ void reducescatter2_userbuff_stridedoutput_fp8(void *output, float *scale, const int warps = comm->threads / 32; if (warps < ar_nvsize) warps = ar_nvsize; - SETUP_LAUNCH_CONFIG(sms, warps * 32, stream); - callranks_rs_oop_fp8(2) callranks_rs_oop_fp8(4) callranks_rs_oop_fp8(8) + if (comm_launch_event) { + SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT(sms, warps * 32, stream, comm_launch_event); + callranks_rs_oop_fp8(2) callranks_rs_oop_fp8(4) callranks_rs_oop_fp8(8) + } else { + SETUP_LAUNCH_CONFIG(sms, warps * 32, stream); + callranks_rs_oop_fp8(2) callranks_rs_oop_fp8(4) callranks_rs_oop_fp8(8) + } } template void 
reducescatter2_userbuff_stridedoutput_fp8<__nv_fp8_e5m2>( void *output, float *scale, const int handler, const int offset, const int rowelements, - const int colelements, const int strideelements, communicator *comm, cudaStream_t stream); + const int colelements, const int strideelements, communicator *comm, cudaStream_t stream, + cudaEvent_t comm_launch_event); template void reducescatter2_userbuff_stridedoutput_fp8<__nv_fp8_e4m3>( void *output, float *scale, const int handler, const int offset, const int rowelements, - const int colelements, const int strideelements, communicator *comm, cudaStream_t stream); + const int colelements, const int strideelements, communicator *comm, cudaStream_t stream, + cudaEvent_t comm_launch_event); template void reducescatter2_userbuff_fp8(void *output, float *scale, const int handler, const int offset, - const int elements, communicator *comm, cudaStream_t stream) { + const int elements, communicator *comm, cudaStream_t stream, + cudaEvent_t comm_launch_event) { reducescatter2_userbuff_stridedoutput_fp8(output, scale, handler, offset, elements, 1, 0, - comm, stream); + comm, stream, comm_launch_event); } template void reducescatter2_userbuff_fp8<__nv_fp8_e5m2>(void *output, float *scale, const int handler, const int offset, const int elements, communicator *comm, - cudaStream_t stream); + cudaStream_t stream, + cudaEvent_t comm_launch_event); template void reducescatter2_userbuff_fp8<__nv_fp8_e4m3>(void *output, float *scale, const int handler, const int offset, const int elements, communicator *comm, - cudaStream_t stream); + cudaStream_t stream, + cudaEvent_t comm_launch_event); template void reducescatter2_userbuff_strided_atomic_fp8<__nv_fp8_e4m3>( void *output, float *scale, const int handler, const int offset, const int rowelements, diff --git a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h index 57e68afce0..75655ef691 100644 
--- a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h +++ b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h @@ -213,7 +213,8 @@ int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator * // for TP-parallelism, only single node is implemented void allgather2_userbuff_inplace(const int handler, const int offset, const int elements, - communicator *comm, cudaStream_t stream = 0); + communicator *comm, cudaStream_t stream = 0, + cudaEvent_t comm_launch_event = 0); /* each Rank input is allgather2_userbuff_inplace: offset+myrank*elements @@ -228,21 +229,26 @@ for(int slice=0;slice void reducescatter2_userbuff_stridedoutput_fp8(void *output, float *scale, const int handler, const int offset, const int rowelements, const int colelements, const int strideelements, - communicator *comm, cudaStream_t stream = 0); + communicator *comm, cudaStream_t stream = 0, + cudaEvent_t comm_launch_event = 0); template void reducescatter2_userbuff_fp8(void *output, float *scale, const int handler, const int offset, - const int elements, communicator *comm, cudaStream_t stream = 0); + const int elements, communicator *comm, cudaStream_t stream = 0, + cudaEvent_t comm_launch_event = 0); template void reducescatter2_userbuff_strided_atomic_fp8(void *output, float *scale, const int handler, const int offset, const int rowelements, diff --git a/transformer_engine/common/fused_attn/thd_utils.cu b/transformer_engine/common/fused_attn/thd_utils.cu new file mode 100644 index 0000000000..a1e353be71 --- /dev/null +++ b/transformer_engine/common/fused_attn/thd_utils.cu @@ -0,0 +1,76 @@ +/************************************************************************* + * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
+ ************************************************************************/ + +#include "../cudnn_utils.h" +#include "thd_utils.h" + +namespace transformer_engine { +namespace fused_attn { + +__global__ void thd_partition_indices_kernel(int *output, int *cu_seqlens, int batch, + int total_tokens, int world_size, int rank) { + extern __shared__ int cu_seqlens_s[]; + for (int i = threadIdx.x; i <= batch; i += blockDim.x) { + int seqlen = cu_seqlens[i]; + // Currently we assume that each sequence length is divisible by (world_size*2) since we have + // to distribute each sequence evenly to different GPUs. + assert(seqlen % (world_size * 2) == 0); + cu_seqlens_s[i] = seqlen / world_size; + } + __syncthreads(); + + int tid = blockIdx.x * blockDim.x + threadIdx.x; + int num_threads = blockDim.x * gridDim.x; + + for (int token_id = tid; token_id < total_tokens / world_size; token_id += num_threads) { + int seq_id = binary_search(token_id, cu_seqlens_s, batch + 1); + int seq_len = cu_seqlens_s[seq_id + 1] - cu_seqlens_s[seq_id]; + int index = token_id - cu_seqlens_s[seq_id]; + int offset = index < seq_len / 2 ? 
rank : (world_size - 1) * 2 - rank; + index += cu_seqlens_s[seq_id] * world_size + seq_len / 2 * offset; + output[token_id] = index; + } +} + +__global__ void thd_read_half_tensor_kernel(void *half, void *tensor, int *cu_seqlens, int batch, + int hidden_size_in_bytes, int half_idx, + int dim_size_of_token) { + extern __shared__ int cu_seqlens_s[]; + for (int i = threadIdx.x; i <= batch; i += blockDim.x) { + cu_seqlens_s[i] = cu_seqlens[i] / 2; + } + __syncthreads(); + + int warpid = (blockIdx.x * blockDim.x + threadIdx.x) / 32; + int laneid = threadIdx.x % 32; + int num_warps = (blockDim.x * gridDim.x) / 32; + int num_total_tokens = cu_seqlens_s[batch]; + int num_float4s_per_token = hidden_size_in_bytes / sizeof(float4); + + size_t offset = static_cast(dim_size_of_token) * hidden_size_in_bytes; + half = reinterpret_cast(reinterpret_cast(half) + offset / 2 * blockIdx.y); + tensor = reinterpret_cast(reinterpret_cast(tensor) + offset * blockIdx.y); + + for (int token_id = warpid; token_id < num_total_tokens; token_id += num_warps) { + int seqid = binary_search(token_id, cu_seqlens_s, batch + 1); + + size_t offset_in_bytes = static_cast(token_id) * hidden_size_in_bytes; + float4 *cur_half_token = + reinterpret_cast(reinterpret_cast(half) + offset_in_bytes); + + offset_in_bytes = + (static_cast(token_id) + cu_seqlens_s[seqid + half_idx]) * hidden_size_in_bytes; + float4 *cur_token = + reinterpret_cast(reinterpret_cast(tensor) + offset_in_bytes); + + for (int idx = laneid; idx < num_float4s_per_token; idx += 32) { + cur_half_token[idx] = cur_token[idx]; + } + } +} + +} // namespace fused_attn +} // namespace transformer_engine diff --git a/transformer_engine/common/fused_attn/thd_utils.h b/transformer_engine/common/fused_attn/thd_utils.h new file mode 100644 index 0000000000..c9a62727e6 --- /dev/null +++ b/transformer_engine/common/fused_attn/thd_utils.h @@ -0,0 +1,249 @@ +/************************************************************************* + * Copyright (c) 
2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#ifndef TRANSFORMER_ENGINE_FUSED_ATTN_THD_UTILS_H_ +#define TRANSFORMER_ENGINE_FUSED_ATTN_THD_UTILS_H_ + +#include +#include + +namespace transformer_engine { +namespace fused_attn { + +/*************************************************************************************************** + * Support THD format for Context Parallel: Binary search an array for a target value + **************************************************************************************************/ + +__forceinline__ __device__ int binary_search(int target, int *array, int len) { + int left = 1, right = len - 1; + while (left < right) { + int mid = (left + right) / 2; + if (array[mid] <= target) { + left = mid + 1; + } else { + right = mid; + } + } + return left - 1; +} + +/*************************************************************************************************** + * Support THD format for Context Parallel: Generate partitioned indices for input tokens + **************************************************************************************************/ + +__global__ void thd_partition_indices_kernel(int *output, int *cu_seqlens, int batch, + int total_tokens, int world_size, int rank); + +/*************************************************************************************************** + * Support THD format for Context Parallel: Read the half of a THD tensor + **************************************************************************************************/ + +__global__ void thd_read_half_tensor_kernel(void *half, void *tensor, int *cu_seqlens, int batch, + int hidden_size_in_bytes, int half_idx, + int dim_size_of_token); + +/*************************************************************************************************** + * Support THD format for Context Parallel: softmax_lse related 
operations + **************************************************************************************************/ + +struct LseCorrectionFunctor { + __forceinline__ __device__ static void run(double *lse, float *half_lse, size_t idx, + size_t half_idx) { + double val = lse[idx]; + float val_per_step = half_lse[half_idx]; + double max_scale = max(val, val_per_step); + double min_scale = min(val, val_per_step); + lse[idx] = max_scale + log(1.0 + exp(min_scale - max_scale)); + } +}; + +struct ReadLseFunctor { + __forceinline__ __device__ static void run(float *lse, float *half_lse, size_t idx, + size_t half_idx) { + half_lse[half_idx] = lse[idx]; + } +}; + +template +__global__ void thd_lse_kernel(lse_dtype *lse, float *half_lse, int *cu_seqlens, int batch, + int num_heads, int total_tokens) { + extern __shared__ int cu_seqlens_s[]; + for (int i = threadIdx.x; i <= batch; i += blockDim.x) { + cu_seqlens_s[i] = cu_seqlens[i] / 2; + } + __syncthreads(); + + int tid = blockIdx.x * blockDim.x + threadIdx.x; + int num_threads = blockDim.x * gridDim.x; + int num_total_tokens = cu_seqlens_s[batch]; + + for (int token_id = tid; token_id < num_total_tokens; token_id += num_threads) { + int seq_id = binary_search(token_id, cu_seqlens_s, batch + 1); + for (int head_id = blockIdx.y; head_id < num_heads; head_id += gridDim.y) { + size_t idx, half_idx; + if constexpr (lse_packed) { + idx = head_id * total_tokens + token_id + cu_seqlens_s[seq_id + 1]; + half_idx = head_id * total_tokens / 2 + token_id; + } else { + size_t row = static_cast(seq_id) * num_heads + head_id; + int col = token_id - cu_seqlens_s[seq_id]; + int seq_len = cu_seqlens_s[seq_id + 1] - cu_seqlens_s[seq_id]; + + idx = row * total_tokens + col + seq_len; + half_idx = row * total_tokens / 2 + col; + } + + Functor::run(lse, half_lse, idx, half_idx); + } + } +} + +/*************************************************************************************************** + * Support THD format for Context Parallel: Out 
correction in forward + **************************************************************************************************/ + +template +__global__ void thd_out_correction_kernel(dtype *out, dtype *out_per_step, float *lse, + float *lse_per_step, int *cu_seqlens, int batch, + int num_heads, int dim_per_head, int lse_seqlen) { + extern __shared__ int cu_seqlens_s[]; + for (int i = threadIdx.x; i <= batch; i += blockDim.x) { + cu_seqlens_s[i] = cu_seqlens[i] / (only_second_half + 1); + } + __syncthreads(); + + int tile_id = (blockIdx.x * blockDim.x + threadIdx.x) / tile_size; + int lane_id = threadIdx.x % tile_size; + int num_tiles = (blockDim.x * gridDim.x) / tile_size; + int num_total_tokens = cu_seqlens_s[batch]; + int num_loops_per_head = dim_per_head * sizeof(dtype) / sizeof(float4); + + for (int token_id = tile_id; token_id < num_total_tokens; token_id += num_tiles) { + int seq_id = binary_search(token_id, cu_seqlens_s, batch + 1); + for (int head_id = blockIdx.y; head_id < num_heads; head_id += gridDim.y) { + size_t idx, idx_per_step; + + if constexpr (lse_packed) { + idx = head_id * lse_seqlen + token_id + cu_seqlens_s[seq_id + 1] * only_second_half; + idx_per_step = head_id * lse_seqlen / (only_second_half + 1) + token_id; + } else { + size_t row = static_cast(seq_id) * num_heads + head_id; + int col = token_id - cu_seqlens_s[seq_id]; + int seq_len = cu_seqlens_s[seq_id + 1] - cu_seqlens_s[seq_id]; + idx = row * lse_seqlen + col + seq_len * only_second_half; + idx_per_step = row * lse_seqlen / (only_second_half + 1) + col; + } + float lse_corrected_exp = exp(lse_per_step[idx_per_step] - lse[idx]); + + idx = token_id + cu_seqlens_s[seq_id + 1] * only_second_half; + idx = (idx * num_heads + head_id) * dim_per_head; + idx_per_step = (static_cast(token_id) * num_heads + head_id) * dim_per_head; + dtype *cur_out = out + idx; + dtype *cur_out_per_step = out_per_step + idx_per_step; + + for (int j = lane_id; j < num_loops_per_head; j += tile_size) { + float4 
data_per_step = reinterpret_cast(cur_out_per_step)[j]; + float4 data = reinterpret_cast(cur_out)[j]; + dtype *p_per_step = reinterpret_cast(&data_per_step); + dtype *p = reinterpret_cast(&data); + for (int k = 0; k < sizeof(float4) / sizeof(dtype); k++) { + p[k] += (p_per_step[k] == 0 ? 0 : p_per_step[k] * lse_corrected_exp); + } + reinterpret_cast(cur_out)[j] = data; + } + } + } +} + +/*************************************************************************************************** + * Support THD format for Context Parallel: Gradients correction in backward + **************************************************************************************************/ + +struct EmptyFunctor { + __forceinline__ __device__ static void run(void *token, void *token_per_step, int idx) {} +}; + +struct CopyFunctor { + __forceinline__ __device__ static void run(void *token, void *token_per_step, int idx) { + reinterpret_cast(token)[idx] = reinterpret_cast(token_per_step)[idx]; + } +}; + +template +struct AddFunctor { + __forceinline__ __device__ static void run(dtype *token, dtype *token_per_step, int idx) { + float4 d_ = reinterpret_cast(token)[idx]; + dtype *p_ = reinterpret_cast(&d_); + + float4 d = reinterpret_cast(token_per_step)[idx]; + dtype *p = reinterpret_cast(&d); + +#pragma unroll + for (int i = 0; i < sizeof(float4) / sizeof(dtype); i++) { + p_[i] += p[i]; + } + + reinterpret_cast(token)[idx] = d_; + } +}; + +template +__global__ void thd_grad_correction_kernel(dtype *grad, dtype *grad_per_step, int *cu_seqlens, + int batch, int hidden_size, int dim_size_of_token) { + extern __shared__ int cu_seqlens_s[]; + for (int i = threadIdx.x; i <= batch; i += blockDim.x) { + if constexpr (functor_idx < 2) { + cu_seqlens_s[i] = cu_seqlens[i] / 2; + } else { + cu_seqlens_s[i] = cu_seqlens[i]; + } + } + __syncthreads(); + + int group_id = (blockIdx.x * blockDim.x + threadIdx.x) / group_size; + int lane_id = threadIdx.x % group_size; + int num_groups = (blockDim.x * gridDim.x) / 
group_size; + int num_total_tokens = cu_seqlens_s[batch]; + int num_inner_loops = hidden_size * sizeof(dtype) / sizeof(float4); + + size_t offset = static_cast(dim_size_of_token) * hidden_size; + if constexpr (functor_idx < 2) { + grad_per_step = grad_per_step + offset / 2 * blockIdx.y; + } else { + grad_per_step = grad_per_step + offset * blockIdx.y; + } + grad = grad + offset * blockIdx.y; + + for (int token_id = group_id; token_id < num_total_tokens; token_id += num_groups) { + int seq_id = binary_search(token_id, cu_seqlens_s, batch + 1); + + int token_offset; + bool is_first_half; + if constexpr (functor_idx < 2) { + token_offset = cu_seqlens_s[seq_id + functor_idx]; + is_first_half = (functor_idx == 0); + } else { + token_offset = 0; + int len = cu_seqlens_s[seq_id + 1] - cu_seqlens_s[seq_id]; + is_first_half = (token_id - cu_seqlens_s[seq_id]) < (len / 2); + } + + dtype *token = &grad[(token_id + token_offset) * static_cast(hidden_size)]; + dtype *token_per_step = &grad_per_step[token_id * static_cast(hidden_size)]; + for (int idx = lane_id; idx < num_inner_loops; idx += group_size) { + if (is_first_half) { + Functor_0::run(token, token_per_step, idx); + } else { + Functor_1::run(token, token_per_step, idx); + } + } + } +} + +} // namespace fused_attn +} // namespace transformer_engine + +#endif diff --git a/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h b/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h index 17ecca5ff0..1d5d192a39 100644 --- a/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h +++ b/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h @@ -62,7 +62,7 @@ class CommOverlapCore { bool _ubuf_scale_inv_initialized{false}; std::vector _stream_compute; - cudaEvent_t _start_compute, _stop_compute, _start_comm, _stop_comm; + cudaEvent_t _start_compute, _stop_compute, _start_comm, _stop_comm, _comm_launch_event; public: CommOverlapCore(int myrank, int 
numranks, int mylocal, int numlocal, int mynode, int numnodes, diff --git a/transformer_engine/common/include/transformer_engine/layer_norm.h b/transformer_engine/common/include/transformer_engine/layer_norm.h deleted file mode 100644 index 3bb4d47f29..0000000000 --- a/transformer_engine/common/include/transformer_engine/layer_norm.h +++ /dev/null @@ -1,159 +0,0 @@ -/************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * - * See LICENSE for license information. - ************************************************************************/ - -/*! \file layer_norm.h - * \brief LayerNorm functions. - */ - -#ifndef TRANSFORMER_ENGINE_LAYER_NORM_H_ -#define TRANSFORMER_ENGINE_LAYER_NORM_H_ - -#include "transformer_engine.h" - -#ifdef __cplusplus -extern "C" { -#endif - -/*! \brief Compute LayerNorm on the input. - * - * The formula used: - * @f[ - * y = \frac{x - E[x]}{\sqrt{Var[x] + \varepsilon}}\gamma + \beta - * @f] - * - * Calling this function with workspace and barrier set to empty tensor will not - * perform the operation, but instead set the shape and type of the workspace - * and barrier tensors to the required values. - * - * \param[in] x Input tensor of shape [N, H]. - * \param[in] gamma Gamma tensor of shape [H]. - * \param[in] beta Beta tensor of shape [H]. - * \param[in] epsilon Value added to denominator for numerical stability. - * \param[in,out] z Output tensor of shape [N, H]. - * \param[out] mu Mean of the input calculated over the last dimension. - * Shape: [N]. - * \param[out] rsigma Inverse of the variance of the input calculated over - * the last dimension. Shape: [N]. - * \param[in] stream CUDA stream used for the operation. - * \param[in] multiprocessorCount Number of SMs in the device. - * \param[out] workspace Workspace tensor. - * \param[out] barrier Barrier tensor. 
- */ -void nvte_layernorm_fwd(const NVTETensor x, const NVTETensor gamma, const NVTETensor beta, - const float epsilon, NVTETensor z, NVTETensor mu, NVTETensor rsigma, - cudaStream_t stream, const int multiprocessorCount, NVTETensor workspace, - NVTETensor barrier); - -/*! \brief Compute LayerNorm with zero-centered gamma on the input. - * - * The formula used: - * @f[ - * y = \frac{x - E[x]}{\sqrt{Var[x] + \varepsilon}}(1 + \gamma) + \beta - * @f] - * - * Calling this function with workspace and barrier set to empty tensor will not - * perform the operation, but instead set the shape and type of the workspace - * and barrier tensors to the required values. - * - * \param[in] x Input tensor of shape [N, H]. - * \param[in] gamma Gamma tensor of shape [H]. - * \param[in] beta Beta tensor of shape [H]. - * \param[in] epsilon Value added to denominator for numerical stability. - * \param[in,out] z Output tensor of shape [N, H]. - * \param[out] mu Mean of the input calculated over the last dimension. - * Shape: [N]. - * \param[out] rsigma Inverse of the variance of the input calculated over - * the last dimension. Shape: [N]. - * \param[in] stream CUDA stream used for the operation. - * \param[in] multiprocessorCount Number of SMs in the device. - * \param[out] workspace Workspace tensor. - * \param[out] barrier Barrier tensor. - */ -void nvte_layernorm1p_fwd(const NVTETensor x, const NVTETensor gamma, const NVTETensor beta, - const float epsilon, NVTETensor z, NVTETensor mu, NVTETensor rsigma, - cudaStream_t stream, const int multiprocessorCount, NVTETensor workspace, - NVTETensor barrier); - -/*! \brief Compute backward of LayerNorm. - * - * This function computes the gradient of function: - * @f[ - * y = \frac{x - E[x]}{\sqrt{Var[x] + \varepsilon}}\gamma + \beta - * @f] - * with respect to \f$x\f$, \f$\gamma\f$ and \f$\beta\f$. 
- * - * Calling this function with workspace, barrier, dgamma_part and dbeta_part set - * to empty tensor will not perform the operation, but instead set the shape and type - * of these tensors to the required values. - * - * \param[in] dz Incoming gradient tensor of shape [N, H]. - * \param[in] x Forward input tensor of shape [N, H]. - * \param[in] mu Mean of the input calculated over the last dimension. - * Shape: [N]. - * \param[in] rsigma Inverse of the variance of the input calculated over - * the last dimension. Shape: [N]. - * \param[in] gamma Gamma tensor of shape [H]. - * \param[out] dx Output gradient of shape [N, H]. - * \param[out] dgamma Gradient for gamma tensor of shape [H]. - * \param[out] dbeta Gradient for beta tensor of shape [H]. - * \param[out] dgamma_part Storage for partial gamma gradient. - * \param[out] dbeta_part Storage for partial bias gradient. - * \param[in] stream CUDA stream used for the operation. - * \param[in] multiprocessorCount Number of SMs in the device. - * \param[out] workspace Workspace tensor. - * \param[out] barrier Barrier tensor. - */ -void nvte_layernorm_bwd(const NVTETensor dz, // BxSxhidden_size - const NVTETensor x, // BxSxhidden_size - const NVTETensor mu, // BxS, FP32! - const NVTETensor rsigma, // BxS, FP32! - const NVTETensor gamma, // hidden_size - NVTETensor dx, NVTETensor dgamma, NVTETensor dbeta, NVTETensor dgamma_part, - NVTETensor dbeta_part, cudaStream_t stream, const int multiprocessorCount, - NVTETensor workspace, NVTETensor barrier); - -/*! \brief Compute backward of LayerNorm with zero-centered gamma. - * - * This function computes the gradient of function: - * @f[ - * y = \frac{x - E[x]}{\sqrt{Var[x] + \varepsilon}}(1 + \gamma) + \beta - * @f] - * with respect to \f$x\f$, \f$\gamma\f$ and \f$\beta\f$. 
- * - * Calling this function with workspace, barrier, dgamma_part and dbeta_part set - * to empty tensor will not perform the operation, but instead set the shape and type - * of these tensors to the required values. - * - * \param[in] dz Incoming gradient tensor of shape [N, H]. - * \param[in] x Forward input tensor of shape [N, H]. - * \param[in] mu Mean of the input calculated over the last dimension. - * Shape: [N]. - * \param[in] rsigma Inverse of the variance of the input calculated over - * the last dimension. Shape: [N]. - * \param[in] gamma Gamma tensor of shape [H]. - * \param[out] dx Output gradient of shape [N, H]. - * \param[out] dgamma Gradient for gamma tensor of shape [H]. - * \param[out] dbeta Gradient for beta tensor of shape [H]. - * \param[out] dgamma_part Storage for partial gamma gradient. - * \param[out] dbeta_part Storage for partial bias gradient. - * \param[in] stream CUDA stream used for the operation. - * \param[in] multiprocessorCount Number of SMs in the device. - * \param[out] workspace Workspace tensor. - * \param[out] barrier Barrier tensor. - */ -void nvte_layernorm1p_bwd(const NVTETensor dz, // BxSxhidden_size - const NVTETensor x, // BxSxhidden_size - const NVTETensor mu, // BxS, FP32! - const NVTETensor rsigma, // BxS, FP32! 
- const NVTETensor gamma, // hidden_size - NVTETensor dx, NVTETensor dgamma, NVTETensor dbeta, - NVTETensor dgamma_part, NVTETensor dbeta_part, cudaStream_t stream, - const int multiprocessorCount, NVTETensor workspace, NVTETensor barrier); -#ifdef __cplusplus -} // extern "C" -#endif - -#endif // TRANSFORMER_ENGINE_LAYER_NORM_H_ diff --git a/transformer_engine/common/include/transformer_engine/rmsnorm.h b/transformer_engine/common/include/transformer_engine/normalization.h similarity index 55% rename from transformer_engine/common/include/transformer_engine/rmsnorm.h rename to transformer_engine/common/include/transformer_engine/normalization.h index dc995e3c24..de9644792b 100644 --- a/transformer_engine/common/include/transformer_engine/rmsnorm.h +++ b/transformer_engine/common/include/transformer_engine/normalization.h @@ -4,12 +4,12 @@ * See LICENSE for license information. ************************************************************************/ -/*! \file rmsnorm.h - * \brief RMSNorm functions. +/*! \file normalization.h + * \brief LayerNorm and RMSNorm functions. */ -#ifndef TRANSFORMER_ENGINE_RMSNORM_H_ -#define TRANSFORMER_ENGINE_RMSNORM_H_ +#ifndef TRANSFORMER_ENGINE_NORMALIZATION_H_ +#define TRANSFORMER_ENGINE_NORMALIZATION_H_ #include "transformer_engine.h" @@ -17,41 +17,73 @@ extern "C" { #endif -/*! \brief Compute RMSNorm on the input. +/*! \brief Compute LayerNorm on the input. * * The formula used: * @f[ - * y = \frac{x}{RMS_\varepsilon(x)}\gamma - * @f] - * where - * @f[ - * RMS_\varepsilon(x) = \sqrt{\frac{1}{n}\sum_{i=0}^{n-1} x_i^2 + \varepsilon} + * y = \frac{x - E[x]}{\sqrt{Var[x] + \varepsilon}} \gamma + \beta * @f] * - * Calling this function with workspace and barrier set to empty tensor will not - * perform the operation, but instead set the shape and type of the workspace - * and barrier tensors to the required values. 
+ * Calling this function with workspace set to empty tensor will not perform the operation, + * but instead set the shape and type of the workspace tensor to the required values. * * \param[in] x Input tensor of shape [N, H]. * \param[in] gamma Gamma tensor of shape [H]. + * \param[in] beta Beta tensor of shape [H]. * \param[in] epsilon Value added to denominator for numerical stability. * \param[in,out] z Output tensor of shape [N, H]. - * \param[out] rsigma Reciprocal of the root mean square of the input - * calculated over the last dimension. Shape: [N]. - * \param[in] stream CUDA stream used for the operation. + * \param[out] mu Mean of the input calculated over the last dimension. + * Shape: [N]. + * \param[out] rsigma Inverse of the variance of the input calculated over + * the last dimension. Shape: [N]. + * \param[out] workspace Workspace tensor. * \param[in] multiprocessorCount Number of SMs in the device. + * \param[in] zero_centered_gamma Multiply normalized values by @f$ \gamma+1 @f$ instead of @f$ \gamma @f$ + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_layernorm_fwd(const NVTETensor x, const NVTETensor gamma, const NVTETensor beta, + const float epsilon, NVTETensor z, NVTETensor mu, NVTETensor rsigma, + NVTETensor workspace, const int multiprocessorCount, + const bool zero_centered_gamma, cudaStream_t stream); + +/*! \brief Compute backward of LayerNorm. + * + * This function computes the gradient of function: + * @f[ + * y = \frac{x - E[x]}{\sqrt{Var[x] + \varepsilon}}\gamma + \beta + * @f] + * else + * with respect to \f$x\f$, \f$\gamma\f$ and \f$\beta\f$. + * + * Calling this function with workspace set to empty tensor will not perform the operation, + * but instead set the shape and type of these tensors to the required values. + * + * \param[in] dz Incoming gradient tensor of shape [N, H]. + * \param[in] x Forward input tensor of shape [N, H]. + * \param[in] mu Mean of the input calculated over the last dimension. 
+ * Shape: [N]. + * \param[in] rsigma Inverse of the variance of the input calculated over + * the last dimension. Shape: [N]. + * \param[in] gamma Gamma tensor of shape [H]. + * \param[out] dx Output gradient of shape [N, H]. + * \param[out] dgamma Gradient for gamma tensor of shape [H]. + * \param[out] dbeta Gradient for beta tensor of shape [H]. * \param[out] workspace Workspace tensor. - * \param[out] barrier Barrier tensor. + * \param[in] multiprocessorCount Number of SMs in the device. + * \param[in] zero_centered_gamma Multiply normalized values by @f$ \gamma+1 @f$ instead of @f$ \gamma @f$ + * \param[in] stream CUDA stream used for the operation. */ -void nvte_rmsnorm_fwd(const NVTETensor x, const NVTETensor gamma, const float epsilon, NVTETensor z, - NVTETensor rsigma, cudaStream_t stream, const int multiprocessorCount, - NVTETensor workspace, NVTETensor barrier); +void nvte_layernorm_bwd(const NVTETensor dz, const NVTETensor x, const NVTETensor mu, + const NVTETensor rsigma, const NVTETensor gamma, NVTETensor dx, + NVTETensor dgamma, NVTETensor dbeta, NVTETensor workspace, + const int multiprocessorCount, const bool zero_centered_gamma, + cudaStream_t stream); -/*! \brief Compute RMSNorm with zero-centered gamma on the input. +/*! \brief Compute RMSNorm. * * The formula used: * @f[ - * y = \frac{x}{RMS_\varepsilon(x)}(1 + \gamma) + * y = \frac{x}{RMS_\varepsilon(x)}\gamma * @f] * where * @f[ @@ -68,14 +100,14 @@ void nvte_rmsnorm_fwd(const NVTETensor x, const NVTETensor gamma, const float ep * \param[in,out] z Output tensor of shape [N, H]. * \param[out] rsigma Reciprocal of the root mean square of the input * calculated over the last dimension. Shape: [N]. - * \param[in] stream CUDA stream used for the operation. - * \param[in] multiprocessorCount Number of SMs in the device. * \param[out] workspace Workspace tensor. - * \param[out] barrier Barrier tensor. + * \param[in] multiprocessorCount Number of SMs in the device. 
+ * \param[in] zero_centered_gamma Multiply normalized values by @f$ \gamma+1 @f$ instead of @f$ \gamma @f$ + * \param[in] stream CUDA stream used for the operation. */ -void nvte_rmsnorm1p_fwd(const NVTETensor x, const NVTETensor gamma, const float epsilon, - NVTETensor z, NVTETensor rsigma, cudaStream_t stream, - const int multiprocessorCount, NVTETensor workspace, NVTETensor barrier); +void nvte_rmsnorm_fwd(const NVTETensor x, const NVTETensor gamma, const float epsilon, NVTETensor z, + NVTETensor rsigma, NVTETensor workspace, const int multiprocessorCount, + const bool zero_centered_gamma, cudaStream_t stream); /*! \brief Compute backward of RMSNorm. * @@ -100,53 +132,25 @@ void nvte_rmsnorm1p_fwd(const NVTETensor x, const NVTETensor gamma, const float * \param[in] gamma Gamma tensor of shape [H]. * \param[out] dx Output gradient of shape [N, H]. * \param[out] dgamma Gradient for gamma tensor of shape [H]. - * \param[out] dgamma_part Storage for partial gamma gradient. - * \param[in] stream CUDA stream used for the operation. - * \param[in] multiprocessorCount Number of SMs in the device. * \param[out] workspace Workspace tensor. - * \param[out] barrier Barrier tensor. + * \param[in] multiprocessorCount Number of SMs in the device. + * \param[in] zero_centered_gamma Multiply normalized values by @f$ \gamma+1 @f$ instead of @f$ \gamma @f$ + * \param[in] stream CUDA stream used for the operation. */ void nvte_rmsnorm_bwd(const NVTETensor dz, const NVTETensor x, const NVTETensor rsigma, const NVTETensor gamma, NVTETensor dx, NVTETensor dgamma, - NVTETensor dgamma_part, cudaStream_t stream, const int multiprocessorCount, - NVTETensor workspace, NVTETensor barrier); + NVTETensor workspace, const int multiprocessorCount, + const bool zero_centered_gamma, cudaStream_t stream); -/*! \brief Compute backward of RMSNorm with zero-centered gamma. +/*! 
\brief Helper to enable cuDNN backend for normalization * - * This function computes the gradient of function: - * @f[ - * y = \frac{x}{RMS_\varepsilon(x)}(1 + \gamma) - * @f] - * where - * @f[ - * RMS_\varepsilon(x) = \sqrt{\frac{1}{n}\sum_{i=0}^{n-1} x_i^2 + \varepsilon} - * @f] - * with respect to \f$x\f$ and \f$gamma\f$. - * - * Calling this function with workspace, barrier, dgamma_part set - * to empty tensor will not perform the operation, but instead set the shape and type - * of these tensors to the required values. - * - * \param[in] dz Incoming gradient tensor of shape [N, H]. - * \param[in] x Forward input tensor of shape [N, H]. - * \param[in] rsigma Reciprocal of the root mean square of the input - * calculated over the last dimension. Shape: [N]. - * \param[in] gamma Gamma tensor of shape [H]. - * \param[out] dx Output gradient of shape [N, H]. - * \param[out] dgamma Gradient for gamma tensor of shape [H]. - * \param[out] dgamma_part Storage for partial gamma gradient. - * \param[in] stream CUDA stream used for the operation. - * \param[in] multiprocessorCount Number of SMs in the device. - * \param[out] workspace Workspace tensor. - * \param[out] barrier Barrier tensor. 
+ * \param[in] bool Enable if True */ -void nvte_rmsnorm1p_bwd(const NVTETensor dz, const NVTETensor x, const NVTETensor rsigma, - const NVTETensor gamma, NVTETensor dx, NVTETensor dgamma, - NVTETensor dgamma_part, cudaStream_t stream, const int multiprocessorCount, - NVTETensor workspace, NVTETensor barrier); +void nvte_enable_cudnn_norm_fwd(bool enable); +void nvte_enable_cudnn_norm_bwd(bool enable); #ifdef __cplusplus } // extern "C" #endif -#endif // TRANSFORMER_ENGINE_RMSNORM_H_ +#endif // TRANSFORMER_ENGINE_NORMALIZATION_H_ diff --git a/transformer_engine/common/layer_norm/ln.h b/transformer_engine/common/layer_norm/ln.h deleted file mode 100644 index 13543a10aa..0000000000 --- a/transformer_engine/common/layer_norm/ln.h +++ /dev/null @@ -1,239 +0,0 @@ -/************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * - * See LICENSE for license information. - ************************************************************************/ - -#ifndef TRANSFORMER_ENGINE_COMMON_LAYER_NORM_LN_H_ -#define TRANSFORMER_ENGINE_COMMON_LAYER_NORM_LN_H_ - -#include - -#include -#include -#include -#include -#include - -#include "../common.h" - -namespace transformer_engine { -namespace layer_norm { - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct LaunchParams { - size_t workspace_bytes; - size_t barrier_size; - - int multiprocessorCount; - cudaStream_t stream; - - Params params; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -struct ParamsBase { - ParamsBase() - : ctas_per_col(0), - rows(0), - cols(0), - x(nullptr), - mu(nullptr), - rs(nullptr), - gamma(nullptr), - workspace(nullptr), - barrier(nullptr), - zero_centered_gamma(false) {} - - // For Multi-CTA, number of different CTA groups. Otherwise same as gridDim.x. 
- int ctas_per_col; - // Size of CTA group. - int ctas_per_row; - - // Input is interpreted as matrix. We normalize across columns. - int rows; - int cols; - - // Common data pointers. - void *x; - void *mu; - void *rs; - void *gamma; - - // Multi-CTA workspace in gmem. - void *workspace; - - // Multi-CTA sync barriers in gmem. - int *barrier; - - // Whether gamma is centered around 0 - bool zero_centered_gamma; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -struct FwdParams : public ParamsBase { - FwdParams() : ParamsBase(), z(nullptr), beta(nullptr), epsilon(0.f), fp8_out(false) {} - - // Output of LN FWD. - void *z; - void *beta; - float epsilon; - - // Scaling factor - void *scale; - - // AMax output - void *amax; - - // Inverse of scaling factor - void *scale_inv; - - // Whether to compute scale and amax - bool fp8_out; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -struct BwdParams : public ParamsBase { - BwdParams() - : ParamsBase(), - dz(nullptr), - dbeta_part(nullptr), - dgamma_part(nullptr), - dx(nullptr), - dbeta(nullptr), - dgamma(nullptr) {} - - // Input: gradient wrt. LN FWD output. - void *dz; - - // Workspace for Wgrad pre-reduction. - void *dbeta_part; - void *dgamma_part; - - // Output: Dgrad. - void *dx; - // Output: Wgrad. 
- void *dbeta; - void *dgamma; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -using FwdFunction = std::function &, const bool)>; -using BwdFunction = std::function &, const bool)>; -using FunctionKey = uint64_t; -using FwdTunedRegistry = std::unordered_map; -using BwdTunedRegistry = std::unordered_map; -using FwdGeneralRegistry = std::unordered_map>; -using BwdGeneralRegistry = std::unordered_map>; - -extern FwdTunedRegistry FWD_TUNED_FUNCS; -extern BwdTunedRegistry BWD_TUNED_FUNCS; -extern FwdGeneralRegistry FWD_GENERAL_FUNCS; -extern BwdGeneralRegistry BWD_GENERAL_FUNCS; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct TypeId {}; - -template <> -struct TypeId { - constexpr static uint32_t Value = 0; -}; - -template <> -struct TypeId { - constexpr static uint32_t Value = 1; -}; - -template <> -struct TypeId { - constexpr static uint32_t Value = 2; -}; - -template <> -struct TypeId { - constexpr static uint32_t Value = 3; -}; - -template -struct Type2Key { - constexpr static uint32_t Value = TypeId::Value << S; -}; - -template -struct WeightType2Key : public Type2Key {}; - -template -struct InputType2Key : public Type2Key {}; - -template -struct OutputType2Key : public Type2Key {}; - -template -struct ComputeType2Key : public Type2Key {}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct Types2Key { - constexpr static uint32_t Value = WeightType2Key::Value | InputType2Key::Value | - OutputType2Key::Value | ComputeType2Key::Value; - constexpr static inline uint64_t get(const uint64_t hidden_size) { - constexpr uint64_t type_key = Value; - return (type_key << 32) | hidden_size; - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct FwdTunedRegistrar { - explicit 
FwdTunedRegistrar(FwdFunction f) { - uint64_t key = Types2Key::get(HIDDEN_SIZE); - FWD_TUNED_FUNCS.insert({key, f}); - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct FwdGeneralRegistrar { - explicit FwdGeneralRegistrar(FwdFunction f) { - uint64_t key = Types2Key::get(0); - FWD_GENERAL_FUNCS[key].insert({HIDDEN_SIZE, f}); - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct BwdTunedRegistrar { - explicit BwdTunedRegistrar(BwdFunction f) { - uint64_t key = Types2Key::get(HIDDEN_SIZE); - BWD_TUNED_FUNCS.insert({key, f}); - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct BwdGeneralRegistrar { - explicit BwdGeneralRegistrar(BwdFunction f) { - uint64_t key = Types2Key::get(0); - BWD_GENERAL_FUNCS[key].insert({HIDDEN_SIZE, f}); - } -}; - -////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace layer_norm -} // namespace transformer_engine - -#endif // TRANSFORMER_ENGINE_COMMON_LAYER_NORM_LN_H_ diff --git a/transformer_engine/common/layer_norm/ln_api.cpp b/transformer_engine/common/layer_norm/ln_api.cpp deleted file mode 100644 index 8a40450e59..0000000000 --- a/transformer_engine/common/layer_norm/ln_api.cpp +++ /dev/null @@ -1,457 +0,0 @@ -/************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * - * See LICENSE for license information. 
- ************************************************************************/ - -#include - -#include -#include - -#include "../common.h" -#include "ln.h" - -/* - -Supported Type combinations: - -input compute weights output -======================================= -fp32 fp32 fp32 fp32 -fp16 fp32 fp16 fp16 -bf16 fp32 bf16 bf16 -fp32 fp32 fp16 fp16 -fp32 fp32 bf16 bf16 -bf16 fp32 bf16 fp8 - -Remarks: -Output type = Weight type -Compute always in FP32 - -*/ - -namespace transformer_engine { -namespace layer_norm { - -using namespace transformer_engine; - -// Create registries and provide runtime versions of config hash functions. - -FwdTunedRegistry FWD_TUNED_FUNCS; -BwdTunedRegistry BWD_TUNED_FUNCS; -FwdGeneralRegistry FWD_GENERAL_FUNCS; -BwdGeneralRegistry BWD_GENERAL_FUNCS; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -uint32_t get_type_id(DType dtype) { - if (dtype == DType::kFloat16) { - return TypeId::Value; - } else if (dtype == DType::kBFloat16) { - return TypeId::Value; - } else if (dtype == DType::kFloat32) { - return TypeId::Value; - } else if (dtype == DType::kFloat8E4M3) { - return TypeId::Value; - } else { - NVTE_ERROR("Type not supported."); - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -uint64_t get_key(DType wtype, DType itype, DType otype, DType ctype, uint64_t hidden_size) { - using namespace layer_norm; - uint64_t type_key = get_type_id(wtype) | (get_type_id(itype) << 2) | (get_type_id(otype) << 4) | - (get_type_id(ctype) << 6); - uint64_t launcher_key = (type_key << 32) | hidden_size; - return launcher_key; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -layer_norm::FwdFunction& get_fwd_launcher(DType wtype, DType itype, DType otype, DType ctype, - const layer_norm::FwdParams& params) { - // Look for tuned kernel - auto tuned_key = layer_norm::get_key(wtype, 
itype, otype, ctype, params.cols); - auto is_aligned = [](const void* ptr) -> bool { - // Assume vectorized memory accesses are <=16B - return reinterpret_cast(ptr) % 16 == 0; - }; - if (params.rows % 4 == 0 && is_aligned(params.x) && is_aligned(params.mu) && - is_aligned(params.rs) && is_aligned(params.gamma) && is_aligned(params.beta) && - is_aligned(params.z) && layer_norm::FWD_TUNED_FUNCS.count(tuned_key) > 0) { - return layer_norm::FWD_TUNED_FUNCS.at(tuned_key); - } - - // Pick general kernel - auto general_key = layer_norm::get_key(wtype, itype, otype, ctype, 0); - if (layer_norm::FWD_GENERAL_FUNCS.count(general_key) == 0) { - NVTE_ERROR("FWD: Unsupported types."); - } - auto& general_func_map = layer_norm::FWD_GENERAL_FUNCS.at(general_key); - auto func_iter = general_func_map.lower_bound(params.cols); - if (func_iter == general_func_map.end()) { - // Hidden size is too big, need to use multi-CTA - return general_func_map.rbegin()->second; - } else { - return func_iter->second; - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -layer_norm::BwdFunction& get_bwd_launcher(DType wtype, DType itype, DType otype, DType ctype, - const layer_norm::BwdParams& params) { - // Look for tuned kernel - auto tuned_key = layer_norm::get_key(wtype, itype, otype, ctype, params.cols); - auto is_aligned = [](const void* ptr) -> bool { - // Assume vectorized memory accesses are <=16B - return reinterpret_cast(ptr) % 16 == 0; - }; - if (params.rows % 4 == 0 && is_aligned(params.x) && is_aligned(params.mu) && - is_aligned(params.rs) && is_aligned(params.gamma) && is_aligned(params.dz) && - is_aligned(params.dx) && is_aligned(params.dbeta) && is_aligned(params.dgamma) && - is_aligned(params.dbeta_part) && is_aligned(params.dgamma_part) && - layer_norm::BWD_TUNED_FUNCS.count(tuned_key) > 0) { - return layer_norm::BWD_TUNED_FUNCS.at(tuned_key); - } - - // Pick general kernel - auto general_key = layer_norm::get_key(wtype, 
itype, otype, ctype, 0); - if (layer_norm::BWD_GENERAL_FUNCS.count(general_key) == 0) { - NVTE_ERROR("BWD: Unsupported types."); - } - auto& general_func_map = layer_norm::BWD_GENERAL_FUNCS.at(general_key); - auto func_iter = general_func_map.lower_bound(params.cols); - if (func_iter == general_func_map.end()) { - // Hidden size is too big, need to use multi-CTA - return general_func_map.rbegin()->second; - } else { - return func_iter->second; - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -size_t product(const std::vector& shape) { - size_t ret = 1; - for (auto s : shape) { - ret *= s; - } - return ret; -} - -} // namespace layer_norm - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -void layernorm_fwd(const Tensor& x, // BxSxhidden_size - const Tensor& gamma, // hidden_size - const Tensor& beta, // hidden_size - const float epsilon, Tensor* z, Tensor* mu, Tensor* rsigma, cudaStream_t stream, - const int multiprocessorCount, Tensor* workspace, Tensor* barrier, - const bool zero_centered_gamma) { - const auto itype = x.data.dtype; - const auto wtype = gamma.data.dtype; - const auto otype = z->data.dtype; - const bool fp8_out = is_fp8_dtype(otype); - const auto ctype = layer_norm::DType::kFloat32; - - NVTE_CHECK(x.data.shape.size() == 2); - - const size_t rows = x.data.shape[0]; - const size_t cols = x.data.shape[1]; - const auto hidden_size = gamma.data.shape[0]; - - NVTE_CHECK(gamma.data.shape == beta.data.shape); - NVTE_CHECK(hidden_size == cols); - - NVTE_CHECK(epsilon >= 0.f); - - NVTE_CHECK(z->data.shape == x.data.shape); - - NVTE_CHECK(mu->data.shape == std::vector{rows}); - NVTE_CHECK(mu->data.dtype == ctype); - - NVTE_CHECK(rsigma->data.shape == std::vector{rows}); - NVTE_CHECK(rsigma->data.dtype == ctype); - - layer_norm::LaunchParams launch_params; - - launch_params.multiprocessorCount = multiprocessorCount; - launch_params.stream = 
stream; - - // Set the kernel runtime parameters. - layer_norm::FwdParams& params = launch_params.params; - params.rows = rows; - params.cols = cols; - params.x = x.data.dptr; - params.mu = mu->data.dptr; - params.rs = rsigma->data.dptr; - params.gamma = gamma.data.dptr; - params.beta = beta.data.dptr; - params.z = z->data.dptr; - params.epsilon = epsilon; - params.amax = z->amax.dptr; - params.scale = z->scale.dptr; - params.scale_inv = z->scale_inv.dptr; - params.fp8_out = fp8_out; - params.zero_centered_gamma = zero_centered_gamma; - - // Request the kernel launcher. - auto launcher = layer_norm::get_fwd_launcher(wtype, itype, otype, ctype, params); - - // Query the kernel-specific launch parameters. - launcher(launch_params, true); - if (launch_params.workspace_bytes == 0) { - launch_params.workspace_bytes = 1; - } - - if (workspace->data.dptr == nullptr) { - NVTE_CHECK(barrier->data.dptr == nullptr); - - workspace->data.dtype = layer_norm::DType::kByte; - workspace->data.shape = {launch_params.workspace_bytes}; - - barrier->data.dtype = layer_norm::DType::kInt32; - barrier->data.shape = {launch_params.barrier_size}; - - return; - } else { - NVTE_CHECK(workspace->data.dtype == layer_norm::DType::kByte); - NVTE_CHECK(workspace->data.shape == std::vector{launch_params.workspace_bytes}); - } - - if (launch_params.barrier_size > 0) { - NVTE_CHECK(barrier->data.dptr != nullptr); - NVTE_CHECK(barrier->data.dtype == layer_norm::DType::kInt32); - NVTE_CHECK(barrier->data.shape == std::vector{launch_params.barrier_size}); - } - - // Tensor checks are delayed here in order to recover workspace sizes with null data - CheckInputTensor(x, "x"); - CheckInputTensor(gamma, "gamma"); - CheckInputTensor(beta, "beta"); - - CheckOutputTensor(*z, "z"); - CheckOutputTensor(*mu, "mu"); - CheckOutputTensor(*rsigma, "rsigma"); - - if (launch_params.barrier_size > 0) { - params.workspace = workspace->data.dptr; - params.barrier = reinterpret_cast(barrier->data.dptr); - } - - // Clear 
buffers - if (params.fp8_out) { - cudaMemsetAsync(params.amax, 0, layer_norm::product(z->amax.shape) * typeToSize(z->amax.dtype), - stream); - } - if (launch_params.barrier_size > 0) { - cudaMemsetAsync(params.barrier, 0, - layer_norm::product(barrier->data.shape) * typeToSize(barrier->data.dtype), - stream); - } - - // Launch the kernel. - launcher(launch_params, false); - - return; -} - -void layernorm_bwd(const Tensor& dz, const Tensor& x, const Tensor& mu, const Tensor& rsigma, - const Tensor& gamma, Tensor* dx, Tensor* dgamma, Tensor* dbeta, - Tensor* dgamma_part, Tensor* dbeta_part, cudaStream_t stream, - const int multiprocessorCount, Tensor* workspace, Tensor* barrier, - const bool zero_centered_gamma) { - using namespace transformer_engine; - - auto itype = x.data.dtype; - auto wtype = gamma.data.dtype; - auto otype = wtype; - auto ctype = DType::kFloat32; - - NVTE_CHECK(dz.data.dtype == otype); - NVTE_CHECK(mu.data.dtype == ctype); - NVTE_CHECK(rsigma.data.dtype == ctype); - - NVTE_CHECK(x.data.shape.size() == 2); - NVTE_CHECK(dz.data.shape == x.data.shape); - auto rows = x.data.shape[0]; - auto cols = x.data.shape[1]; - - auto hidden_size = gamma.data.shape[0]; - - NVTE_CHECK(mu.data.shape[0] == rows); - NVTE_CHECK(mu.data.shape == rsigma.data.shape); - - NVTE_CHECK(gamma.data.shape[0] == cols); - - NVTE_CHECK(dx->data.shape == x.data.shape); - NVTE_CHECK(dx->data.dtype == x.data.dtype); - - NVTE_CHECK(dgamma->data.shape == gamma.data.shape); - NVTE_CHECK(dgamma->data.dtype == gamma.data.dtype); - - NVTE_CHECK(dbeta->data.shape == gamma.data.shape); - NVTE_CHECK(dbeta->data.dtype == gamma.data.dtype); - - layer_norm::LaunchParams launch_params; - launch_params.stream = stream; - launch_params.multiprocessorCount = multiprocessorCount; - - // Set the kernel runtime parameters. 
- layer_norm::BwdParams& params = launch_params.params; - params.rows = rows; - params.cols = cols; - params.x = x.data.dptr; - params.mu = mu.data.dptr; - params.rs = rsigma.data.dptr; - params.gamma = gamma.data.dptr; - params.dz = dz.data.dptr; - params.dx = dx->data.dptr; - params.dbeta = dbeta->data.dptr; - params.dgamma = dgamma->data.dptr; - params.dbeta_part = dbeta_part->data.dptr; - params.dgamma_part = dgamma_part->data.dptr; - params.zero_centered_gamma = zero_centered_gamma; - - auto launcher = layer_norm::get_bwd_launcher(wtype, itype, otype, ctype, params); - - // Query the kernel-specific launch parameters. - launcher(launch_params, true); - - // Populate shape and dtypes for FW to allocate memory - if (dgamma_part->data.dptr == nullptr) { - NVTE_CHECK(dbeta_part->data.dptr == nullptr); - - dgamma_part->data.dtype = ctype; - dgamma_part->data.shape = {static_cast(launch_params.params.ctas_per_col), - hidden_size}; - - dbeta_part->data.dtype = ctype; - dbeta_part->data.shape = {static_cast(launch_params.params.ctas_per_col), - hidden_size}; - - workspace->data.dtype = layer_norm::DType::kByte; - workspace->data.shape = {launch_params.workspace_bytes}; - - barrier->data.dtype = layer_norm::DType::kInt32; - barrier->data.shape = {launch_params.barrier_size}; - - return; - } else { - NVTE_CHECK(dbeta_part->data.dptr != nullptr); - auto pdw_shape = - std::vector{static_cast(launch_params.params.ctas_per_col), hidden_size}; - - NVTE_CHECK(dgamma_part->data.dtype == ctype); - NVTE_CHECK(dgamma_part->data.shape == pdw_shape); - NVTE_CHECK(dbeta_part->data.dtype == ctype); - NVTE_CHECK(dbeta_part->data.shape == pdw_shape); - } - - if (launch_params.barrier_size > 0) { - NVTE_CHECK(barrier->data.dptr != nullptr); - NVTE_CHECK(barrier->data.dtype == layer_norm::DType::kInt32); - NVTE_CHECK(barrier->data.shape == std::vector{launch_params.barrier_size}); - } - - if (launch_params.workspace_bytes > 0) { - NVTE_CHECK(workspace->data.dptr != nullptr); - 
NVTE_CHECK(workspace->data.dtype == layer_norm::DType::kByte); - NVTE_CHECK(workspace->data.shape == std::vector{launch_params.workspace_bytes}); - } - - // Tensor checks are delayed here in order to recover workspace sizes with null data - CheckInputTensor(dz, "dz"); - CheckInputTensor(x, "x"); - CheckInputTensor(mu, "mu"); - CheckInputTensor(rsigma, "rsigma"); - CheckInputTensor(gamma, "gamma"); - CheckOutputTensor(*dx, "dx"); - CheckOutputTensor(*dgamma, "dgamma"); - CheckOutputTensor(*dbeta, "dbeta"); - - if (launch_params.barrier_size > 0) { - params.workspace = workspace->data.dptr; - params.barrier = reinterpret_cast(barrier->data.dptr); - cudaMemsetAsync(params.barrier, 0, - layer_norm::product(barrier->data.shape) * typeToSize(barrier->data.dtype), - stream); - } - - // Launch the kernel. - launcher(launch_params, false); -} -} // namespace transformer_engine - -void nvte_layernorm_fwd(const NVTETensor x, // BxSxhidden_size - const NVTETensor gamma, // hidden_size - const NVTETensor beta, // hidden_size - const float epsilon, NVTETensor z, NVTETensor mu, NVTETensor rsigma, - cudaStream_t stream, const int multiprocessorCount, NVTETensor workspace, - NVTETensor barrier) { - NVTE_API_CALL(nvte_layernorm_fwd); - using namespace transformer_engine; - layernorm_fwd(*reinterpret_cast(x), *reinterpret_cast(gamma), - *reinterpret_cast(beta), epsilon, reinterpret_cast(z), - reinterpret_cast(mu), reinterpret_cast(rsigma), stream, - multiprocessorCount, reinterpret_cast(workspace), - reinterpret_cast(barrier), false); -} - -void nvte_layernorm_bwd(const NVTETensor dz, // BxSxhidden_size - const NVTETensor x, // BxSxhidden_size - const NVTETensor mu, // BxS, FP32! - const NVTETensor rsigma, // BxS, FP32! 
- const NVTETensor gamma, // hidden_size - NVTETensor dx, NVTETensor dgamma, NVTETensor dbeta, NVTETensor dgamma_part, - NVTETensor dbeta_part, cudaStream_t stream, const int multiprocessorCount, - NVTETensor workspace, NVTETensor barrier) { - NVTE_API_CALL(nvte_layernorm_bwd); - using namespace transformer_engine; - layernorm_bwd(*reinterpret_cast(dz), *reinterpret_cast(x), - *reinterpret_cast(mu), *reinterpret_cast(rsigma), - *reinterpret_cast(gamma), reinterpret_cast(dx), - reinterpret_cast(dgamma), reinterpret_cast(dbeta), - reinterpret_cast(dgamma_part), reinterpret_cast(dbeta_part), - stream, multiprocessorCount, reinterpret_cast(workspace), - reinterpret_cast(barrier), false); -} - -void nvte_layernorm1p_fwd(const NVTETensor x, // BxSxhidden_size - const NVTETensor gamma, // hidden_size - const NVTETensor beta, // hidden_size - const float epsilon, NVTETensor z, NVTETensor mu, NVTETensor rsigma, - cudaStream_t stream, const int multiprocessorCount, NVTETensor workspace, - NVTETensor barrier) { - NVTE_API_CALL(nvte_layernorm1p_fwd); - using namespace transformer_engine; - layernorm_fwd(*reinterpret_cast(x), *reinterpret_cast(gamma), - *reinterpret_cast(beta), epsilon, reinterpret_cast(z), - reinterpret_cast(mu), reinterpret_cast(rsigma), stream, - multiprocessorCount, reinterpret_cast(workspace), - reinterpret_cast(barrier), true); -} - -void nvte_layernorm1p_bwd(const NVTETensor dz, // BxSxhidden_size - const NVTETensor x, // BxSxhidden_size - const NVTETensor mu, // BxS, FP32! - const NVTETensor rsigma, // BxS, FP32! 
- const NVTETensor gamma, // hidden_size - NVTETensor dx, NVTETensor dgamma, NVTETensor dbeta, - NVTETensor dgamma_part, NVTETensor dbeta_part, cudaStream_t stream, - const int multiprocessorCount, NVTETensor workspace, NVTETensor barrier) { - NVTE_API_CALL(nvte_layernorm1p_bwd); - using namespace transformer_engine; - layernorm_bwd(*reinterpret_cast(dz), *reinterpret_cast(x), - *reinterpret_cast(mu), *reinterpret_cast(rsigma), - *reinterpret_cast(gamma), reinterpret_cast(dx), - reinterpret_cast(dgamma), reinterpret_cast(dbeta), - reinterpret_cast(dgamma_part), reinterpret_cast(dbeta_part), - stream, multiprocessorCount, reinterpret_cast(workspace), - reinterpret_cast(barrier), true); -} diff --git a/transformer_engine/common/layer_norm/ln_bwd_semi_cuda_kernel.cu b/transformer_engine/common/layer_norm/ln_bwd_semi_cuda_kernel.cu deleted file mode 100644 index 17f1256910..0000000000 --- a/transformer_engine/common/layer_norm/ln_bwd_semi_cuda_kernel.cu +++ /dev/null @@ -1,345 +0,0 @@ -/************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * - * See LICENSE for license information. 
- ************************************************************************/ - -#include "ln.h" -#include "ln_bwd_kernels.cuh" -#include "ln_kernel_traits.h" - -using namespace transformer_engine::layer_norm; - -template -void launch_tuned_(LaunchParams &launch_params, - const bool configure_params) { // NOLINT(*) - using Kernel_traits = Kernel_traits; - auto kernel = &ln_bwd_tuned_kernel; - - if (configure_params) { - int ctas_per_sm; - cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES); - launch_params.params.ctas_per_row = CTAS_PER_ROW; - launch_params.params.ctas_per_col = - launch_params.multiprocessorCount * ctas_per_sm / launch_params.params.ctas_per_row; - launch_params.barrier_size = 0; - launch_params.workspace_bytes = 0; - if (Kernel_traits::CTAS_PER_ROW > 1) { - launch_params.barrier_size = 2 * launch_params.params.ctas_per_col; - launch_params.workspace_bytes = launch_params.params.ctas_per_col * Kernel_traits::WARPS_M * - Kernel_traits::CTAS_PER_ROW * - sizeof(typename Kernel_traits::reduce_t) * 2; - } - return; - } - - if (Kernel_traits::SMEM_BYTES >= 48 * 1024) { - NVTE_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, - Kernel_traits::SMEM_BYTES)); - } - auto stream = launch_params.stream; - auto ctas_per_col = launch_params.params.ctas_per_col; - auto ctas_per_row = launch_params.params.ctas_per_row; - - if (ctas_per_row == 1) { - kernel<<>>( - launch_params.params); - } else { - dim3 grid(ctas_per_row * ctas_per_col); - dim3 block(Kernel_traits::THREADS_PER_CTA); - void *params_ = reinterpret_cast(&launch_params.params); - cudaLaunchCooperativeKernel(reinterpret_cast(kernel), grid, block, - reinterpret_cast(¶ms_), Kernel_traits::SMEM_BYTES, - stream); - } - - using Kernel_traits_f = layer_norm::Kernel_traits_finalize; - - auto kernel_f = &layer_norm::ln_bwd_finalize_tuned_kernel; - kernel_f<<>>( - 
launch_params.params); -} - -template -void launch_general_(LaunchParams &launch_params, - const bool configure_params) { // NOLINT(*) - auto ceil_div = [](int x, int y) -> int { return (x + y - 1) / y; }; - - // Instantiate kernel - using Kernel_traits = Kernel_traits; - auto kernel = &ln_bwd_general_kernel; - - // Configure kernel params - const int rows = launch_params.params.rows; - const int cols = launch_params.params.cols; - int ctas_per_col = launch_params.params.ctas_per_col; - int ctas_per_row = launch_params.params.ctas_per_row; - if (configure_params) { - int ctas_per_sm; - cudaOccupancyMaxActiveBlocksPerMultiprocessor(&ctas_per_sm, kernel, - Kernel_traits::THREADS_PER_CTA, 0); - const int max_ctas = launch_params.multiprocessorCount * ctas_per_sm; - ctas_per_row = ceil_div(cols, HIDDEN_SIZE); - ctas_per_col = std::min(ceil_div(rows, WARPS_M), max_ctas / ctas_per_row); - launch_params.params.ctas_per_row = ctas_per_row; - launch_params.params.ctas_per_col = ctas_per_col; - - launch_params.barrier_size = 0; - launch_params.workspace_bytes = 0; - if (launch_params.params.ctas_per_row > 1) { - launch_params.barrier_size = 2 * ctas_per_col; - launch_params.workspace_bytes = - (ctas_per_col * WARPS_M * ctas_per_row * sizeof(typename Kernel_traits::reduce_t) * 2); - } - return; - } - - // Launch kernel - auto stream = launch_params.stream; - dim3 grid(ctas_per_row * ctas_per_col); - dim3 block(Kernel_traits::THREADS_PER_CTA); - if (ctas_per_row == 1) { - kernel<<>>(launch_params.params); - } else { - void *params_ = reinterpret_cast(&launch_params.params); - cudaLaunchCooperativeKernel(reinterpret_cast(kernel), grid, block, - reinterpret_cast(¶ms_), 0, stream); - } - - // Launch finalization kernel - constexpr uint32_t WARPS_M_FINAL = 4; - constexpr uint32_t WARPS_N_FINAL = 1; - constexpr uint32_t ELTS_N_PER_CTA_FINAL = - (Kernel_traits::THREADS_PER_WARP * WARPS_N_FINAL * BYTES_PER_LDG_FINAL / sizeof(compute_t)); - auto kernel_final = - 
&ln_bwd_finalize_general_kernel; - dim3 block_final(Kernel_traits::THREADS_PER_WARP * WARPS_N_FINAL, WARPS_M_FINAL); - dim3 grid_final(ceil_div(cols, ELTS_N_PER_CTA_FINAL), 1); - kernel_final<<>>(launch_params.params); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#define REGISTER_BWD_TUNED_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, \ - WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINALIZE) \ - void ln_bwd_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ - LaunchParams &launch_params, const bool configure_params) { \ - launch_tuned_(launch_params, \ - configure_params); \ - } \ - static BwdTunedRegistrar \ - reg_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ - ln_bwd_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE) - -#define REGISTER_BWD_GENERAL_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, WARPS_M, WARPS_N, \ - BYTES_PER_LDG, BYTES_PER_LDG_FINALIZE) \ - void ln_bwd_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ - LaunchParams &launch_params, const bool configure_params) { \ - launch_general_(launch_params, configure_params); \ - } \ - static BwdGeneralRegistrar \ - reg_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ - ln_bwd_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE) - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// Create tuned launch function and register. Macro signature: -// HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, ... 
-// WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL - -REGISTER_BWD_TUNED_LAUNCHER(768, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(768, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(768, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(768, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(768, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4); - -REGISTER_BWD_TUNED_LAUNCHER(1024, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(1024, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(1024, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(1024, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(1024, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4); - -REGISTER_BWD_TUNED_LAUNCHER(1536, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(1536, fp16, fp16, fp16, fp32, 1, 1, 4, 8, 4); -REGISTER_BWD_TUNED_LAUNCHER(1536, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(1536, bf16, bf16, bf16, fp32, 1, 1, 4, 8, 4); -REGISTER_BWD_TUNED_LAUNCHER(1536, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); - -REGISTER_BWD_TUNED_LAUNCHER(2048, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(2048, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(2048, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(2048, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(2048, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); - -REGISTER_BWD_TUNED_LAUNCHER(2304, fp32, fp32, fp32, fp32, 1, 1, 4, 8, 4); -REGISTER_BWD_TUNED_LAUNCHER(2304, fp16, fp16, fp16, fp32, 1, 1, 4, 4, 4); -REGISTER_BWD_TUNED_LAUNCHER(2304, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4); -REGISTER_BWD_TUNED_LAUNCHER(2304, bf16, bf16, bf16, fp32, 1, 1, 4, 4, 4); -REGISTER_BWD_TUNED_LAUNCHER(2304, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4); - -REGISTER_BWD_TUNED_LAUNCHER(3072, fp32, 
fp32, fp32, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(3072, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(3072, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(3072, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(3072, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); - -REGISTER_BWD_TUNED_LAUNCHER(3840, fp32, fp32, fp32, fp32, 1, 1, 4, 8, 4); -REGISTER_BWD_TUNED_LAUNCHER(3840, fp16, fp16, fp16, fp32, 1, 1, 4, 4, 4); -REGISTER_BWD_TUNED_LAUNCHER(3840, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4); -REGISTER_BWD_TUNED_LAUNCHER(3840, bf16, bf16, bf16, fp32, 1, 1, 4, 4, 4); -REGISTER_BWD_TUNED_LAUNCHER(3840, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4); - -REGISTER_BWD_TUNED_LAUNCHER(4096, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(4096, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(4096, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(4096, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(4096, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); - -REGISTER_BWD_TUNED_LAUNCHER(5120, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(5120, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(5120, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(5120, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(5120, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); - -REGISTER_BWD_TUNED_LAUNCHER(6144, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(6144, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(6144, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(6144, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(6144, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4); - -REGISTER_BWD_TUNED_LAUNCHER(8192, fp32, fp32, fp32, fp32, 2, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(8192, fp16, fp16, fp16, fp32, 
2, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(8192, fp16, fp32, fp16, fp32, 2, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(8192, bf16, bf16, bf16, fp32, 2, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(8192, bf16, fp32, bf16, fp32, 2, 1, 4, 16, 4); - -REGISTER_BWD_TUNED_LAUNCHER(10240, fp32, fp32, fp32, fp32, 2, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(10240, fp16, fp16, fp16, fp32, 2, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(10240, fp16, fp32, fp16, fp32, 2, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(10240, bf16, bf16, bf16, fp32, 2, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(10240, bf16, fp32, bf16, fp32, 2, 1, 4, 16, 4); - -REGISTER_BWD_TUNED_LAUNCHER(12288, fp32, fp32, fp32, fp32, 4, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(12288, fp16, fp16, fp16, fp32, 4, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(12288, fp16, fp32, fp16, fp32, 4, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(12288, bf16, bf16, bf16, fp32, 4, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(12288, bf16, fp32, bf16, fp32, 4, 1, 4, 16, 4); - -REGISTER_BWD_TUNED_LAUNCHER(12800, fp32, fp32, fp32, fp32, 5, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(12800, fp16, fp16, fp16, fp32, 5, 1, 4, 8, 4); -REGISTER_BWD_TUNED_LAUNCHER(12800, fp16, fp32, fp16, fp32, 5, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(12800, bf16, bf16, bf16, fp32, 5, 1, 4, 8, 4); -REGISTER_BWD_TUNED_LAUNCHER(12800, bf16, fp32, bf16, fp32, 5, 1, 4, 16, 4); - -REGISTER_BWD_TUNED_LAUNCHER(15360, fp32, fp32, fp32, fp32, 4, 1, 4, 8, 4); -REGISTER_BWD_TUNED_LAUNCHER(15360, fp16, fp16, fp16, fp32, 4, 1, 4, 4, 4); -REGISTER_BWD_TUNED_LAUNCHER(15360, fp16, fp32, fp16, fp32, 4, 1, 4, 8, 4); -REGISTER_BWD_TUNED_LAUNCHER(15360, bf16, bf16, bf16, fp32, 4, 1, 4, 4, 4); -REGISTER_BWD_TUNED_LAUNCHER(15360, bf16, fp32, bf16, fp32, 4, 1, 4, 8, 4); - -REGISTER_BWD_TUNED_LAUNCHER(16384, fp32, fp32, fp32, fp32, 4, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(16384, fp16, fp16, fp16, fp32, 4, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(16384, fp16, fp32, fp16, 
fp32, 4, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(16384, bf16, bf16, bf16, fp32, 4, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(16384, bf16, fp32, bf16, fp32, 4, 1, 4, 16, 4); - -REGISTER_BWD_TUNED_LAUNCHER(18432, fp32, fp32, fp32, fp32, 4, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(18432, fp16, fp16, fp16, fp32, 4, 1, 4, 8, 4); -REGISTER_BWD_TUNED_LAUNCHER(18432, fp16, fp32, fp16, fp32, 4, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(18432, bf16, bf16, bf16, fp32, 4, 1, 4, 8, 4); -REGISTER_BWD_TUNED_LAUNCHER(18432, bf16, fp32, bf16, fp32, 4, 1, 4, 16, 4); - -REGISTER_BWD_TUNED_LAUNCHER(20480, fp32, fp32, fp32, fp32, 4, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(20480, fp16, fp16, fp16, fp32, 4, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(20480, fp16, fp32, fp16, fp32, 4, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(20480, bf16, bf16, bf16, fp32, 4, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(20480, bf16, fp32, bf16, fp32, 4, 1, 4, 16, 4); - -REGISTER_BWD_TUNED_LAUNCHER(24576, fp32, fp32, fp32, fp32, 4, 1, 8, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(24576, fp16, fp16, fp16, fp32, 4, 1, 8, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(24576, fp16, fp32, fp16, fp32, 4, 1, 8, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(24576, bf16, bf16, bf16, fp32, 4, 1, 8, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(24576, bf16, fp32, bf16, fp32, 4, 1, 8, 16, 4); - -REGISTER_BWD_TUNED_LAUNCHER(25600, fp32, fp32, fp32, fp32, 5, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(25600, fp16, fp16, fp16, fp32, 5, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(25600, fp16, fp32, fp16, fp32, 5, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(25600, bf16, bf16, bf16, fp32, 5, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(25600, bf16, fp32, bf16, fp32, 5, 1, 4, 16, 4); - -REGISTER_BWD_TUNED_LAUNCHER(30720, fp32, fp32, fp32, fp32, 4, 1, 8, 8, 4); -REGISTER_BWD_TUNED_LAUNCHER(30720, fp16, fp16, fp16, fp32, 4, 1, 8, 4, 4); -REGISTER_BWD_TUNED_LAUNCHER(30720, fp16, fp32, fp16, fp32, 4, 1, 8, 8, 4); -REGISTER_BWD_TUNED_LAUNCHER(30720, bf16, 
bf16, bf16, fp32, 4, 1, 8, 4, 4); -REGISTER_BWD_TUNED_LAUNCHER(30720, bf16, fp32, bf16, fp32, 4, 1, 8, 8, 4); - -REGISTER_BWD_TUNED_LAUNCHER(32768, fp32, fp32, fp32, fp32, 4, 1, 8, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(32768, fp16, fp16, fp16, fp32, 4, 1, 8, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(32768, fp16, fp32, fp16, fp32, 4, 1, 8, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(32768, bf16, bf16, bf16, fp32, 4, 1, 8, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(32768, bf16, fp32, bf16, fp32, 4, 1, 8, 16, 4); - -REGISTER_BWD_TUNED_LAUNCHER(40960, fp32, fp32, fp32, fp32, 4, 1, 8, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(40960, fp16, fp16, fp16, fp32, 4, 1, 8, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(40960, fp16, fp32, fp16, fp32, 4, 1, 8, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(40960, bf16, bf16, bf16, fp32, 4, 1, 8, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(40960, bf16, fp32, bf16, fp32, 4, 1, 8, 16, 4); - -REGISTER_BWD_TUNED_LAUNCHER(49152, fp32, fp32, fp32, fp32, 8, 1, 8, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(49152, fp16, fp16, fp16, fp32, 8, 1, 8, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(49152, fp16, fp32, fp16, fp32, 8, 1, 8, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(49152, bf16, bf16, bf16, fp32, 8, 1, 8, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(49152, bf16, fp32, bf16, fp32, 8, 1, 8, 16, 4); - -REGISTER_BWD_TUNED_LAUNCHER(65536, fp32, fp32, fp32, fp32, 8, 1, 8, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(65536, fp16, fp16, fp16, fp32, 8, 1, 8, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(65536, fp16, fp32, fp16, fp32, 8, 1, 8, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(65536, bf16, bf16, bf16, fp32, 8, 1, 8, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(65536, bf16, fp32, bf16, fp32, 8, 1, 8, 16, 4); - -// Create general launch function and register. Macro signature: -// HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, ... 
-// WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL - -REGISTER_BWD_GENERAL_LAUNCHER(128, fp32, fp32, fp32, fp32, 4, 1, 16, 4); -REGISTER_BWD_GENERAL_LAUNCHER(128, fp16, fp16, fp16, fp32, 4, 1, 8, 4); -REGISTER_BWD_GENERAL_LAUNCHER(128, fp16, fp32, fp16, fp32, 4, 1, 8, 4); -REGISTER_BWD_GENERAL_LAUNCHER(128, bf16, bf16, bf16, fp32, 4, 1, 8, 4); -REGISTER_BWD_GENERAL_LAUNCHER(128, bf16, fp32, bf16, fp32, 4, 1, 8, 4); - -REGISTER_BWD_GENERAL_LAUNCHER(512, fp32, fp32, fp32, fp32, 4, 1, 16, 4); -REGISTER_BWD_GENERAL_LAUNCHER(512, fp16, fp16, fp16, fp32, 4, 1, 16, 4); -REGISTER_BWD_GENERAL_LAUNCHER(512, fp16, fp32, fp16, fp32, 4, 1, 16, 4); -REGISTER_BWD_GENERAL_LAUNCHER(512, bf16, bf16, bf16, fp32, 4, 1, 16, 4); -REGISTER_BWD_GENERAL_LAUNCHER(512, bf16, fp32, bf16, fp32, 4, 1, 16, 4); - -REGISTER_BWD_GENERAL_LAUNCHER(1024, fp32, fp32, fp32, fp32, 4, 1, 16, 4); -REGISTER_BWD_GENERAL_LAUNCHER(1024, fp16, fp16, fp16, fp32, 4, 1, 16, 4); -REGISTER_BWD_GENERAL_LAUNCHER(1024, fp16, fp32, fp16, fp32, 4, 1, 16, 4); -REGISTER_BWD_GENERAL_LAUNCHER(1024, bf16, bf16, bf16, fp32, 4, 1, 16, 4); -REGISTER_BWD_GENERAL_LAUNCHER(1024, bf16, fp32, bf16, fp32, 4, 1, 16, 4); - -REGISTER_BWD_GENERAL_LAUNCHER(2048, fp32, fp32, fp32, fp32, 1, 4, 16, 4); -REGISTER_BWD_GENERAL_LAUNCHER(2048, fp16, fp16, fp16, fp32, 1, 4, 16, 4); -REGISTER_BWD_GENERAL_LAUNCHER(2048, fp16, fp32, fp16, fp32, 1, 4, 16, 4); -REGISTER_BWD_GENERAL_LAUNCHER(2048, bf16, bf16, bf16, fp32, 1, 4, 16, 4); -REGISTER_BWD_GENERAL_LAUNCHER(2048, bf16, fp32, bf16, fp32, 1, 4, 16, 4); - -REGISTER_BWD_GENERAL_LAUNCHER(4096, fp32, fp32, fp32, fp32, 1, 4, 16, 4); -REGISTER_BWD_GENERAL_LAUNCHER(4096, fp16, fp16, fp16, fp32, 1, 4, 16, 4); -REGISTER_BWD_GENERAL_LAUNCHER(4096, fp16, fp32, fp16, fp32, 1, 4, 16, 4); -REGISTER_BWD_GENERAL_LAUNCHER(4096, bf16, bf16, bf16, fp32, 1, 4, 16, 4); -REGISTER_BWD_GENERAL_LAUNCHER(4096, bf16, fp32, bf16, fp32, 1, 4, 16, 4); diff --git a/transformer_engine/common/layer_norm/ln_fwd_cuda_kernel.cu 
b/transformer_engine/common/layer_norm/ln_fwd_cuda_kernel.cu deleted file mode 100644 index 0c85f4aeb7..0000000000 --- a/transformer_engine/common/layer_norm/ln_fwd_cuda_kernel.cu +++ /dev/null @@ -1,413 +0,0 @@ -/************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * - * See LICENSE for license information. - ************************************************************************/ - -#include "ln.h" -#include "ln_fwd_kernels.cuh" -#include "ln_kernel_traits.h" - -using namespace transformer_engine::layer_norm; - -template -void launch_tuned_(LaunchParams &launch_params, - const bool configure_params) { // NOLINT(*) - using Kernel_traits = Kernel_traits; - auto kernel = &ln_fwd_tuned_kernel; - - if (configure_params) { - int ctas_per_sm; - cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD); - launch_params.params.ctas_per_row = CTAS_PER_ROW; - launch_params.params.ctas_per_col = - launch_params.multiprocessorCount * ctas_per_sm / launch_params.params.ctas_per_row; - launch_params.barrier_size = 0; - launch_params.workspace_bytes = 0; - if (Kernel_traits::CTAS_PER_ROW > 1) { - launch_params.barrier_size = 2 * launch_params.params.ctas_per_col; - launch_params.workspace_bytes = launch_params.params.ctas_per_col * Kernel_traits::WARPS_M * - Kernel_traits::CTAS_PER_ROW * - sizeof(typename Kernel_traits::Stats::stats_t) * 2; - } - return; - } - - if (Kernel_traits::SMEM_BYTES_FWD >= 48 * 1024) { - NVTE_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, - Kernel_traits::SMEM_BYTES_FWD)); - } - auto stream = launch_params.stream; - auto ctas_per_col = launch_params.params.ctas_per_col; - auto ctas_per_row = launch_params.params.ctas_per_row; - - if (ctas_per_row == 1) { - kernel<<>>( - launch_params.params); - } else { - dim3 
grid(ctas_per_row * ctas_per_col); - dim3 block(Kernel_traits::THREADS_PER_CTA); - void *params_ = reinterpret_cast(&launch_params.params); - cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)¶ms_, // NOLINT(*) - Kernel_traits::SMEM_BYTES_FWD, stream); - } -} - -template -void launch_general_(LaunchParams &launch_params, - const bool configure_params) { // NOLINT(*) - using Kernel_traits = Kernel_traits; - auto kernel = &ln_fwd_general_kernel; - auto ceil_div = [](int x, int y) -> int { return (x + y - 1) / y; }; - - // Configure kernel params - const int rows = launch_params.params.rows; - const int cols = launch_params.params.cols; - int ctas_per_col = launch_params.params.ctas_per_col; - int ctas_per_row = launch_params.params.ctas_per_row; - if (configure_params) { - int ctas_per_sm; - cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, 0); - const int max_ctas = launch_params.multiprocessorCount * ctas_per_sm; - ctas_per_row = ceil_div(cols, HIDDEN_SIZE); - ctas_per_col = std::min(ceil_div(rows, WARPS_M), max_ctas / ctas_per_row); - launch_params.params.ctas_per_row = ctas_per_row; - launch_params.params.ctas_per_col = ctas_per_col; - - launch_params.barrier_size = 0; - launch_params.workspace_bytes = 0; - if (launch_params.params.ctas_per_row > 1) { - launch_params.barrier_size = 2 * ctas_per_col; - launch_params.workspace_bytes = - (ctas_per_col * WARPS_M * ctas_per_row * sizeof(compute_t) * 2); - } - return; - } - - // Launch kernel - auto stream = launch_params.stream; - dim3 grid(ctas_per_row * ctas_per_col); - dim3 block(Kernel_traits::THREADS_PER_CTA); - if (ctas_per_row == 1) { - kernel<<>>(launch_params.params); - } else { - void *params_ = reinterpret_cast(&launch_params.params); - cudaLaunchCooperativeKernel(reinterpret_cast(kernel), grid, block, - reinterpret_cast(¶ms_), 0, stream); - } -} - 
-//////////////////////////////////////////////////////////////////////////////////////////////////// - -#define REGISTER_FWD_TUNED_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, \ - WARPS_M, WARPS_N, BYTES_PER_LDG) \ - void ln_fwd_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ - LaunchParams &launch_params, const bool configure_params) { \ - launch_tuned_(launch_params, configure_params); \ - } \ - static FwdTunedRegistrar \ - reg_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ - ln_fwd_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE) - -#define REGISTER_FWD_GENERAL_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, WARPS_M, WARPS_N, \ - BYTES_PER_LDG) \ - void ln_fwd_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ - LaunchParams &launch_params, const bool configure_params) { \ - launch_general_(launch_params, configure_params); \ - } \ - static FwdGeneralRegistrar \ - reg_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ - ln_fwd_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE) - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// Create tuned launch function and register. 
Macro signature: -// HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG - -REGISTER_FWD_TUNED_LAUNCHER(768, bf16, bf16, fp8e4m3, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(1024, bf16, bf16, fp8e4m3, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(1536, bf16, bf16, fp8e4m3, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(2048, bf16, bf16, fp8e4m3, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(2304, bf16, bf16, fp8e4m3, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(3072, bf16, bf16, fp8e4m3, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(3840, bf16, bf16, fp8e4m3, fp32, 1, 1, 4, 4); -REGISTER_FWD_TUNED_LAUNCHER(4096, bf16, bf16, fp8e4m3, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(5120, bf16, bf16, fp8e4m3, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(6144, bf16, bf16, fp8e4m3, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(8192, bf16, bf16, fp8e4m3, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(10240, bf16, bf16, fp8e4m3, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(12288, bf16, bf16, fp8e4m3, fp32, 2, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(12800, bf16, bf16, fp8e4m3, fp32, 2, 1, 4, 4); -REGISTER_FWD_TUNED_LAUNCHER(15360, bf16, bf16, fp8e4m3, fp32, 2, 1, 4, 8); -REGISTER_FWD_TUNED_LAUNCHER(16384, bf16, bf16, fp8e4m3, fp32, 2, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(18432, bf16, bf16, fp8e4m3, fp32, 2, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(20480, bf16, bf16, fp8e4m3, fp32, 2, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(24576, bf16, bf16, fp8e4m3, fp32, 2, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(25600, bf16, bf16, fp8e4m3, fp32, 2, 1, 4, 8); -REGISTER_FWD_TUNED_LAUNCHER(30720, bf16, bf16, fp8e4m3, fp32, 4, 1, 4, 4); -REGISTER_FWD_TUNED_LAUNCHER(32768, bf16, bf16, fp8e4m3, fp32, 4, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(40960, bf16, bf16, fp8e4m3, fp32, 4, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(49152, bf16, bf16, fp8e4m3, fp32, 4, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(65536, bf16, bf16, fp8e4m3, 
fp32, 8, 1, 4, 16); - -REGISTER_FWD_TUNED_LAUNCHER(768, fp16, fp16, fp8e4m3, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(1024, fp16, fp16, fp8e4m3, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(1536, fp16, fp16, fp8e4m3, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(2048, fp16, fp16, fp8e4m3, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(2304, fp16, fp16, fp8e4m3, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(3072, fp16, fp16, fp8e4m3, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(3840, fp16, fp16, fp8e4m3, fp32, 1, 1, 4, 4); -REGISTER_FWD_TUNED_LAUNCHER(4096, fp16, fp16, fp8e4m3, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(5120, fp16, fp16, fp8e4m3, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(6144, fp16, fp16, fp8e4m3, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(8192, fp16, fp16, fp8e4m3, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(10240, fp16, fp16, fp8e4m3, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(12288, fp16, fp16, fp8e4m3, fp32, 2, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(12800, fp16, fp16, fp8e4m3, fp32, 2, 1, 4, 4); -REGISTER_FWD_TUNED_LAUNCHER(15360, fp16, fp16, fp8e4m3, fp32, 2, 1, 4, 8); -REGISTER_FWD_TUNED_LAUNCHER(16384, fp16, fp16, fp8e4m3, fp32, 2, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(18432, fp16, fp16, fp8e4m3, fp32, 2, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(20480, fp16, fp16, fp8e4m3, fp32, 2, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(24576, fp16, fp16, fp8e4m3, fp32, 2, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(25600, fp16, fp16, fp8e4m3, fp32, 2, 1, 4, 8); -REGISTER_FWD_TUNED_LAUNCHER(30720, fp16, fp16, fp8e4m3, fp32, 4, 1, 4, 4); -REGISTER_FWD_TUNED_LAUNCHER(32768, fp16, fp16, fp8e4m3, fp32, 4, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(40960, fp16, fp16, fp8e4m3, fp32, 4, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(49152, fp16, fp16, fp8e4m3, fp32, 4, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(65536, fp16, fp16, fp8e4m3, fp32, 8, 1, 4, 16); - -REGISTER_FWD_TUNED_LAUNCHER(768, fp32, fp32, fp8e4m3, fp32, 1, 4, 
1, 16); -REGISTER_FWD_TUNED_LAUNCHER(1024, fp32, fp32, fp8e4m3, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(1536, fp32, fp32, fp8e4m3, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(2048, fp32, fp32, fp8e4m3, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(2304, fp32, fp32, fp8e4m3, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(3072, fp32, fp32, fp8e4m3, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(3840, fp32, fp32, fp8e4m3, fp32, 1, 1, 4, 4); -REGISTER_FWD_TUNED_LAUNCHER(4096, fp32, fp32, fp8e4m3, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(5120, fp32, fp32, fp8e4m3, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(6144, fp32, fp32, fp8e4m3, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(8192, fp32, fp32, fp8e4m3, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(10240, fp32, fp32, fp8e4m3, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(12288, fp32, fp32, fp8e4m3, fp32, 2, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(12800, fp32, fp32, fp8e4m3, fp32, 2, 1, 4, 4); -REGISTER_FWD_TUNED_LAUNCHER(15360, fp32, fp32, fp8e4m3, fp32, 2, 1, 4, 8); -REGISTER_FWD_TUNED_LAUNCHER(16384, fp32, fp32, fp8e4m3, fp32, 2, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(18432, fp32, fp32, fp8e4m3, fp32, 2, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(20480, fp32, fp32, fp8e4m3, fp32, 2, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(24576, fp32, fp32, fp8e4m3, fp32, 2, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(25600, fp32, fp32, fp8e4m3, fp32, 2, 1, 4, 8); -REGISTER_FWD_TUNED_LAUNCHER(30720, fp32, fp32, fp8e4m3, fp32, 4, 1, 4, 4); -REGISTER_FWD_TUNED_LAUNCHER(32768, fp32, fp32, fp8e4m3, fp32, 4, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(40960, fp32, fp32, fp8e4m3, fp32, 4, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(49152, fp32, fp32, fp8e4m3, fp32, 4, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(65536, fp32, fp32, fp8e4m3, fp32, 8, 1, 4, 16); - -REGISTER_FWD_TUNED_LAUNCHER(768, fp32, fp32, fp32, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(768, fp16, fp16, fp16, fp32, 1, 4, 1, 16); 
-REGISTER_FWD_TUNED_LAUNCHER(768, fp32, fp32, fp16, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(768, bf16, bf16, bf16, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(768, fp32, fp32, bf16, fp32, 1, 4, 1, 16); - -REGISTER_FWD_TUNED_LAUNCHER(1024, fp32, fp32, fp32, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(1024, fp16, fp16, fp16, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(1024, fp32, fp32, fp16, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(1024, bf16, bf16, bf16, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(1024, fp32, fp32, bf16, fp32, 1, 4, 1, 16); - -REGISTER_FWD_TUNED_LAUNCHER(1536, fp32, fp32, fp32, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(1536, fp16, fp16, fp16, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(1536, fp32, fp32, fp16, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(1536, bf16, bf16, bf16, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(1536, fp32, fp32, bf16, fp32, 1, 4, 1, 16); - -REGISTER_FWD_TUNED_LAUNCHER(2048, fp32, fp32, fp32, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(2048, fp16, fp16, fp16, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(2048, fp32, fp32, fp16, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(2048, bf16, bf16, bf16, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(2048, fp32, fp32, bf16, fp32, 1, 4, 1, 16); - -REGISTER_FWD_TUNED_LAUNCHER(2304, fp32, fp32, fp32, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(2304, fp16, fp16, fp16, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(2304, fp32, fp32, fp16, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(2304, bf16, bf16, bf16, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(2304, fp32, fp32, bf16, fp32, 1, 4, 1, 16); - -REGISTER_FWD_TUNED_LAUNCHER(3072, fp32, fp32, fp32, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(3072, fp16, fp16, fp16, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(3072, fp32, fp32, fp16, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(3072, bf16, bf16, bf16, fp32, 1, 1, 4, 16); 
-REGISTER_FWD_TUNED_LAUNCHER(3072, fp32, fp32, bf16, fp32, 1, 1, 4, 16); - -REGISTER_FWD_TUNED_LAUNCHER(3840, fp32, fp32, fp32, fp32, 1, 1, 4, 4); -REGISTER_FWD_TUNED_LAUNCHER(3840, fp16, fp16, fp16, fp32, 1, 1, 4, 4); -REGISTER_FWD_TUNED_LAUNCHER(3840, fp32, fp32, fp16, fp32, 1, 1, 4, 4); -REGISTER_FWD_TUNED_LAUNCHER(3840, bf16, bf16, bf16, fp32, 1, 1, 4, 4); -REGISTER_FWD_TUNED_LAUNCHER(3840, fp32, fp32, bf16, fp32, 1, 1, 4, 4); - -REGISTER_FWD_TUNED_LAUNCHER(4096, fp32, fp32, fp32, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(4096, fp16, fp16, fp16, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(4096, fp32, fp32, fp16, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(4096, bf16, bf16, bf16, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(4096, fp32, fp32, bf16, fp32, 1, 1, 4, 16); - -REGISTER_FWD_TUNED_LAUNCHER(5120, fp32, fp32, fp32, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(5120, fp16, fp16, fp16, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(5120, fp32, fp32, fp16, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(5120, bf16, bf16, bf16, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(5120, fp32, fp32, bf16, fp32, 1, 1, 4, 16); - -REGISTER_FWD_TUNED_LAUNCHER(6144, fp32, fp32, fp32, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(6144, fp16, fp16, fp16, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(6144, fp32, fp32, fp16, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(6144, bf16, bf16, bf16, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(6144, fp32, fp32, bf16, fp32, 1, 1, 4, 16); - -REGISTER_FWD_TUNED_LAUNCHER(8192, fp32, fp32, fp32, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(8192, fp16, fp16, fp16, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(8192, fp32, fp32, fp16, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(8192, bf16, bf16, bf16, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(8192, fp32, fp32, bf16, fp32, 1, 1, 4, 16); - -REGISTER_FWD_TUNED_LAUNCHER(10240, fp32, fp32, fp32, fp32, 2, 1, 4, 16); 
-REGISTER_FWD_TUNED_LAUNCHER(10240, fp16, fp16, fp16, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(10240, fp32, fp32, fp16, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(10240, bf16, bf16, bf16, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(10240, fp32, fp32, bf16, fp32, 1, 1, 4, 16); - -REGISTER_FWD_TUNED_LAUNCHER(12288, fp32, fp32, fp32, fp32, 2, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(12288, fp16, fp16, fp16, fp32, 2, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(12288, fp32, fp32, fp16, fp32, 2, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(12288, bf16, bf16, bf16, fp32, 2, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(12288, fp32, fp32, bf16, fp32, 2, 1, 4, 16); - -REGISTER_FWD_TUNED_LAUNCHER(12800, fp32, fp32, fp32, fp32, 2, 1, 4, 4); -REGISTER_FWD_TUNED_LAUNCHER(12800, fp16, fp16, fp16, fp32, 2, 1, 4, 4); -REGISTER_FWD_TUNED_LAUNCHER(12800, fp32, fp32, fp16, fp32, 2, 1, 4, 4); -REGISTER_FWD_TUNED_LAUNCHER(12800, bf16, bf16, bf16, fp32, 2, 1, 4, 4); -REGISTER_FWD_TUNED_LAUNCHER(12800, fp32, fp32, bf16, fp32, 2, 1, 4, 4); - -REGISTER_FWD_TUNED_LAUNCHER(15360, fp32, fp32, fp32, fp32, 2, 1, 4, 8); -REGISTER_FWD_TUNED_LAUNCHER(15360, fp16, fp16, fp16, fp32, 2, 1, 4, 8); -REGISTER_FWD_TUNED_LAUNCHER(15360, fp32, fp32, fp16, fp32, 2, 1, 4, 8); -REGISTER_FWD_TUNED_LAUNCHER(15360, bf16, bf16, bf16, fp32, 2, 1, 4, 8); -REGISTER_FWD_TUNED_LAUNCHER(15360, fp32, fp32, bf16, fp32, 2, 1, 4, 8); - -REGISTER_FWD_TUNED_LAUNCHER(16384, fp32, fp32, fp32, fp32, 2, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(16384, fp16, fp16, fp16, fp32, 2, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(16384, fp32, fp32, fp16, fp32, 2, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(16384, bf16, bf16, bf16, fp32, 2, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(16384, fp32, fp32, bf16, fp32, 2, 1, 4, 16); - -REGISTER_FWD_TUNED_LAUNCHER(18432, fp32, fp32, fp32, fp32, 4, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(18432, fp16, fp16, fp16, fp32, 2, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(18432, fp32, fp32, fp16, fp32, 2, 1, 4, 16); 
-REGISTER_FWD_TUNED_LAUNCHER(18432, bf16, bf16, bf16, fp32, 2, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(18432, fp32, fp32, bf16, fp32, 4, 1, 4, 16); - -REGISTER_FWD_TUNED_LAUNCHER(20480, fp32, fp32, fp32, fp32, 2, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(20480, fp16, fp16, fp16, fp32, 2, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(20480, fp32, fp32, fp16, fp32, 2, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(20480, bf16, bf16, bf16, fp32, 2, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(20480, fp32, fp32, bf16, fp32, 2, 1, 4, 16); - -REGISTER_FWD_TUNED_LAUNCHER(24576, fp32, fp32, fp32, fp32, 4, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(24576, fp16, fp16, fp16, fp32, 2, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(24576, fp32, fp32, fp16, fp32, 2, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(24576, bf16, bf16, bf16, fp32, 2, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(24576, fp32, fp32, bf16, fp32, 2, 1, 4, 16); - -REGISTER_FWD_TUNED_LAUNCHER(25600, fp32, fp32, fp32, fp32, 4, 1, 4, 4); -REGISTER_FWD_TUNED_LAUNCHER(25600, fp16, fp16, fp16, fp32, 2, 1, 4, 8); -REGISTER_FWD_TUNED_LAUNCHER(25600, fp32, fp32, fp16, fp32, 4, 1, 4, 4); -REGISTER_FWD_TUNED_LAUNCHER(25600, bf16, bf16, bf16, fp32, 2, 1, 4, 8); -REGISTER_FWD_TUNED_LAUNCHER(25600, fp32, fp32, bf16, fp32, 4, 1, 4, 4); - -REGISTER_FWD_TUNED_LAUNCHER(30720, fp32, fp32, fp32, fp32, 4, 1, 4, 4); -REGISTER_FWD_TUNED_LAUNCHER(30720, fp16, fp16, fp16, fp32, 4, 1, 4, 4); -REGISTER_FWD_TUNED_LAUNCHER(30720, fp32, fp32, fp16, fp32, 4, 1, 4, 4); -REGISTER_FWD_TUNED_LAUNCHER(30720, bf16, bf16, bf16, fp32, 4, 1, 4, 4); -REGISTER_FWD_TUNED_LAUNCHER(30720, fp32, fp32, bf16, fp32, 4, 1, 4, 4); - -REGISTER_FWD_TUNED_LAUNCHER(32768, fp32, fp32, fp32, fp32, 4, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(32768, fp16, fp16, fp16, fp32, 4, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(32768, fp32, fp32, fp16, fp32, 4, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(32768, bf16, bf16, bf16, fp32, 4, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(32768, fp32, fp32, bf16, fp32, 4, 1, 4, 16); - 
-REGISTER_FWD_TUNED_LAUNCHER(40960, fp32, fp32, fp32, fp32, 4, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(40960, fp16, fp16, fp16, fp32, 4, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(40960, fp32, fp32, fp16, fp32, 4, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(40960, bf16, bf16, bf16, fp32, 4, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(40960, fp32, fp32, bf16, fp32, 4, 1, 4, 16); - -REGISTER_FWD_TUNED_LAUNCHER(49152, fp32, fp32, fp32, fp32, 8, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(49152, fp16, fp16, fp16, fp32, 4, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(49152, fp32, fp32, fp16, fp32, 4, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(49152, bf16, bf16, bf16, fp32, 4, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(49152, fp32, fp32, bf16, fp32, 4, 1, 4, 16); - -REGISTER_FWD_TUNED_LAUNCHER(65536, fp32, fp32, fp32, fp32, 8, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(65536, fp16, fp16, fp16, fp32, 8, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(65536, fp32, fp32, fp16, fp32, 8, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(65536, bf16, bf16, bf16, fp32, 8, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(65536, fp32, fp32, bf16, fp32, 8, 1, 4, 16); - -// Create general launch function and register. 
Macro signature: -// HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, WARPS_M, WARPS_N, BYTES_PER_LDG - -REGISTER_FWD_GENERAL_LAUNCHER(128, bf16, bf16, fp8e4m3, fp32, 4, 1, 8); -REGISTER_FWD_GENERAL_LAUNCHER(512, bf16, bf16, fp8e4m3, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(1024, bf16, bf16, fp8e4m3, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(2048, bf16, bf16, fp8e4m3, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(8192, bf16, bf16, fp8e4m3, fp32, 1, 4, 16); - -REGISTER_FWD_GENERAL_LAUNCHER(128, fp16, fp16, fp8e4m3, fp32, 4, 1, 8); -REGISTER_FWD_GENERAL_LAUNCHER(512, fp16, fp16, fp8e4m3, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(1024, fp16, fp16, fp8e4m3, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(2048, fp16, fp16, fp8e4m3, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(8192, fp16, fp16, fp8e4m3, fp32, 1, 4, 16); - -REGISTER_FWD_GENERAL_LAUNCHER(128, fp32, fp32, fp8e4m3, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(512, fp32, fp32, fp8e4m3, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(1024, fp32, fp32, fp8e4m3, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(2048, fp32, fp32, fp8e4m3, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(8192, fp32, fp32, fp8e4m3, fp32, 1, 4, 16); - -REGISTER_FWD_GENERAL_LAUNCHER(128, fp32, fp32, fp32, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(128, fp16, fp16, fp16, fp32, 4, 1, 8); -REGISTER_FWD_GENERAL_LAUNCHER(128, fp32, fp32, fp16, fp32, 4, 1, 8); -REGISTER_FWD_GENERAL_LAUNCHER(128, bf16, bf16, bf16, fp32, 4, 1, 8); -REGISTER_FWD_GENERAL_LAUNCHER(128, fp32, fp32, bf16, fp32, 4, 1, 8); - -REGISTER_FWD_GENERAL_LAUNCHER(512, fp32, fp32, fp32, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(512, fp16, fp16, fp16, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(512, fp32, fp32, fp16, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(512, bf16, bf16, bf16, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(512, fp32, fp32, bf16, fp32, 4, 1, 16); - -REGISTER_FWD_GENERAL_LAUNCHER(1024, fp32, fp32, fp32, fp32, 4, 1, 
16); -REGISTER_FWD_GENERAL_LAUNCHER(1024, fp16, fp16, fp16, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(1024, fp32, fp32, fp16, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(1024, bf16, bf16, bf16, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(1024, fp32, fp32, bf16, fp32, 4, 1, 16); - -REGISTER_FWD_GENERAL_LAUNCHER(2048, fp32, fp32, fp32, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(2048, fp16, fp16, fp16, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(2048, fp32, fp32, fp16, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(2048, bf16, bf16, bf16, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(2048, fp32, fp32, bf16, fp32, 4, 1, 16); - -REGISTER_FWD_GENERAL_LAUNCHER(8192, fp32, fp32, fp32, fp32, 1, 4, 16); -REGISTER_FWD_GENERAL_LAUNCHER(8192, fp16, fp16, fp16, fp32, 1, 4, 16); -REGISTER_FWD_GENERAL_LAUNCHER(8192, fp32, fp32, fp16, fp32, 1, 4, 16); -REGISTER_FWD_GENERAL_LAUNCHER(8192, bf16, bf16, bf16, fp32, 1, 4, 16); -REGISTER_FWD_GENERAL_LAUNCHER(8192, fp32, fp32, bf16, fp32, 1, 4, 16); diff --git a/transformer_engine/common/normalization/common.cpp b/transformer_engine/common/normalization/common.cpp new file mode 100644 index 0000000000..5b6beb66b1 --- /dev/null +++ b/transformer_engine/common/normalization/common.cpp @@ -0,0 +1,445 @@ +/************************************************************************* + * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
+ ************************************************************************/ + +/* #include */ + +#include "common.h" + +#include +#include +#include +#include +#include + +#include "transformer_engine/normalization.h" + +/* + +Supported Type combinations: + +input compute weights output +======================================= +fp32 fp32 fp32 fp32 +fp16 fp32 fp16 fp16 +bf16 fp32 bf16 bf16 +fp32 fp32 fp16 fp16 +fp32 fp32 bf16 bf16 +bf16 fp32 bf16 fp8 + +Remarks: +Output type = Weight type +Compute always in FP32 + +*/ + +namespace transformer_engine { +namespace normalization { + +TupleKeyType get_key(NVTE_Norm_Type NormType, NVTE_Norm_Stage NormStage, DType wtype, DType itype, + DType otype, DType ctype, uint64_t batch_size, uint64_t hidden_size, + bool zero_centered_gamma, bool is_tuned) { + uint64_t general_key = static_cast(itype) | (static_cast(otype) << 3) | + (static_cast(ctype) << 6) | (static_cast(wtype) << 9) | + (uint32_t(NormType) << 12) | (uint32_t(NormStage)) << 14 | + (uint32_t(zero_centered_gamma) << 16); + return std::make_tuple(general_key, batch_size, hidden_size, is_tuned); +} + +template +TeNormalizationPlan::TeNormalizationPlan( + NVTE_Norm_Type NormType, NVTE_Norm_Stage NormStage, DType wtype, DType itype, DType otype, + DType ctype, const size_t batch_size, const size_t hidden_size, const size_t sm_count, + const bool zero_centered_gamma, const bool is_tuned) + : _is_layernorm(NormType == NVTE_Norm_Type::LayerNorm) { + _launch_params.multiprocessorCount = sm_count; + + auto& kernel_params = _launch_params.params; + kernel_params.rows = batch_size; + kernel_params.cols = hidden_size; + kernel_params.zero_centered_gamma = zero_centered_gamma; + if constexpr (std::is_same_v) { + kernel_params.fp8_out = is_fp8_dtype(otype); + } + // TE kernels have no template for batch_size and zero_centered_gamma, thus zero out those + auto key = + get_key(NormType, NormStage, wtype, itype, otype, ctype, 0, hidden_size, false, is_tuned); + _kernel = 
KernelRegistry::getKernel(key); + + this->_build(); +} + +template <> +void TeNormalizationPlan::execute(Tensor* z, void* x_dptr, void* gamma_dptr, + void* beta_dptr, void* mean_dptr, + void* eps_dptr, void* rsigma_dptr, + void* workspace_dptr, cudaStream_t stream) { + _launch_params.stream = stream; + + auto& kernel_params = _launch_params.params; + kernel_params.workspace = workspace_dptr; + kernel_params.x = x_dptr; + kernel_params.rs = rsigma_dptr; + kernel_params.gamma = gamma_dptr; + kernel_params.z = z->data.dptr; + kernel_params.epsilon = *reinterpret_cast(eps_dptr); + kernel_params.amax = z->amax.dptr; + kernel_params.scale = z->scale.dptr; + kernel_params.scale_inv = z->scale_inv.dptr; + + if (_is_layernorm) { + kernel_params.mu = mean_dptr; + kernel_params.beta = beta_dptr; + } + + _set_workspace(); + _kernel(_launch_params, false); +} + +template <> +void TeNormalizationPlan::execute(Tensor* z, void* x_dptr, void* gamma_dptr, + void* beta_dptr, void* mean_dptr, + void* eps_dptr, void* rsigma_dptr, + void* workspace_dptr, cudaStream_t stream) { + NVTE_ERROR("Backward normalization should not call the forward execute function!"); +} + +template +void TeNormalizationPlan::_build() { + _kernel(_launch_params, true); + _launch_params.alignWorkspace(); +} + +template +std::vector TeNormalizationPlan::getWorkspaceShape() const { + return {_launch_params.getTotalWorkspaceBytes(_is_layernorm)}; +} + +template +void TeNormalizationPlan::_set_workspace() { + if (_launch_params.getTotalWorkspaceBytes() > 0) { + auto workspace_dptr = reinterpret_cast(_launch_params.params.workspace); + + if (_launch_params.barrier_bytes > 0) { + _launch_params.params.barrier = + reinterpret_cast(workspace_dptr + _launch_params.workspace_bytes); + cudaMemsetAsync(_launch_params.params.barrier, 0, _launch_params.barrier_bytes, + _launch_params.stream); + } + if constexpr (std::is_same_v) { + _launch_params.params.dgamma_part = + workspace_dptr + _launch_params.workspace_bytes + 
_launch_params.barrier_bytes; + if (_is_layernorm) { + _launch_params.params.dbeta_part = + reinterpret_cast(_launch_params.params.dgamma_part) + + _launch_params.dgamma_part_bytes; + } + } + } +} + +template <> +void TeNormalizationPlan::execute(void* x_dptr, void* gamma_dptr, + void* mean_dptr, void* rsigma_dptr, + void* dx_dptr, void* dz_dptr, + void* dbeta_dptr, void* dgamma_dptr, + void* workspace_dptr, cudaStream_t stream) { + NVTE_ERROR("Forward normalization should not call the backward execute function!"); +} + +template <> +void TeNormalizationPlan::execute(void* x_dptr, void* gamma_dptr, + void* mean_dptr, void* rsigma_dptr, + void* dx_dptr, void* dz_dptr, + void* dbeta_dptr, void* dgamma_dptr, + void* workspace_dptr, cudaStream_t stream) { + _launch_params.stream = stream; + + auto& kernel_params = _launch_params.params; + kernel_params.workspace = workspace_dptr; + kernel_params.x = x_dptr; + kernel_params.gamma = gamma_dptr; + kernel_params.rs = rsigma_dptr; + kernel_params.dx = dx_dptr; + kernel_params.dz = dz_dptr; + kernel_params.dgamma = dgamma_dptr; + + if (_is_layernorm) { + kernel_params.mu = mean_dptr; + kernel_params.dbeta = dbeta_dptr; + } + + _set_workspace(); + _kernel(_launch_params, false); +} + +CudnnNormalizationPlan::CudnnNormalizationPlan(NVTE_Norm_Type NormType, NVTE_Norm_Stage NormStage, + DType wtype, DType itype, DType otype, DType ctype, + const size_t batch_size, const size_t hidden_size, + const size_t sm_count, + const bool zero_centered_gamma) + : _fp8_out(is_fp8_dtype(otype)), _zero_centered(zero_centered_gamma) { + static_assert(CUDNN_FRONTEND_VERSION >= 10601, + "CUDNN_FRONTEND_VERSION should be at least 1.6.1!"); + + namespace fe = cudnn_frontend; + + _scalar_dptr = std::make_unique(typeToSize(wtype)); + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + wtype, cpp_dtype, *(reinterpret_cast(_scalar_dptr.get())) = (cpp_dtype)1.0f;); + + _handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle(); + + 
_graph.set_io_data_type(get_cudnn_fe_dtype(itype)) + .set_intermediate_data_type(get_cudnn_fe_dtype(ctype)) + .set_compute_data_type(get_cudnn_fe_dtype(ctype)); + + if (cudnnGetVersion() >= 90400) _graph.set_sm_count(sm_count); + + const auto batch_dim = static_cast(batch_size); + const auto hidden_dim = static_cast(hidden_size); + + // Create graph tensors + _x = _graph.tensor(fe::graph::Tensor_attributes() + .set_name("X") + .set_dim({batch_dim, hidden_dim, 1, 1}) + .set_stride({hidden_dim, 1, hidden_dim, hidden_dim}) + .set_data_type(get_cudnn_fe_dtype(itype))); + + _gamma_zero = _graph.tensor(fe::graph::Tensor_attributes() + .set_name("gamma_zero") + .set_dim({1, hidden_dim, 1, 1}) + .set_stride({hidden_dim, 1, hidden_dim, hidden_dim}) + .set_data_type(get_cudnn_fe_dtype(wtype))); + if (zero_centered_gamma) { + _scalar_offset = _graph.tensor(fe::graph::Tensor_attributes() + .set_name("one") + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(get_cudnn_fe_dtype(wtype)) + .set_is_pass_by_value(true)); + auto centered_options = fe::graph::Pointwise_attributes() + .set_mode(fe::PointwiseMode_t::ADD) + .set_compute_data_type(get_cudnn_fe_dtype(ctype)); + _gamma = _graph.pointwise(_gamma_zero, _scalar_offset, centered_options); + _gamma->set_output(false).set_data_type(get_cudnn_fe_dtype(wtype)); + } else { + _gamma = _gamma_zero; + } + + // Create graph computation nodes + if (NormStage == NVTE_Norm_Stage::Forward) { + _eps = _graph.tensor(fe::graph::Tensor_attributes() + .set_name("epsilon") + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(get_cudnn_fe_dtype(ctype)) + .set_is_pass_by_value(true)); + if (NormType == NVTE_Norm_Type::LayerNorm) { + _beta = _graph.tensor(fe::graph::Tensor_attributes() + .set_name("bias") + .set_dim({1, hidden_dim, 1, 1}) + .set_stride({hidden_dim, 1, hidden_dim, hidden_dim}) + .set_data_type(get_cudnn_fe_dtype(wtype))); + auto norm_options = fe::graph::Layernorm_attributes() + 
.set_forward_phase(fe::NormFwdPhase_t::TRAINING) + .set_epsilon(_eps) + .set_compute_data_type(get_cudnn_fe_dtype(ctype)); + auto ret = _graph.layernorm(_x, _gamma, _beta, norm_options); + std::tie(_z, _mean, _rsigma) = std::make_tuple(ret[0], ret[1], ret[2]); + _mean->set_output(true).set_data_type(get_cudnn_fe_dtype(ctype)); + } else if (NormType == NVTE_Norm_Type::RMSNorm) { + auto norm_options = fe::graph::Rmsnorm_attributes() + .set_forward_phase(fe::NormFwdPhase_t::TRAINING) + .set_epsilon(_eps) + .set_compute_data_type(get_cudnn_fe_dtype(ctype)); + auto ret = _graph.rmsnorm(_x, _gamma, norm_options); + std::tie(_z, _rsigma) = std::make_tuple(ret[0], ret[1]); + } + + _rsigma->set_output(true).set_data_type(get_cudnn_fe_dtype(ctype)); + + const auto ZDtype = _fp8_out ? ctype : otype; + _z->set_output(!_fp8_out).set_data_type(get_cudnn_fe_dtype(ZDtype)); + + if (_fp8_out) { + // create a scale node + _z_scale = _graph.tensor(fe::graph::Tensor_attributes() + .set_name("z_scale") + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(get_cudnn_fe_dtype(ctype))); + auto z_scale_options = fe::graph::Pointwise_attributes() + .set_mode(fe::PointwiseMode_t::MUL) + .set_compute_data_type(get_cudnn_fe_dtype(ctype)); + _z_fp8 = _graph.pointwise(_z, _z_scale, z_scale_options); + + _z_fp8->set_output(true).set_data_type(get_cudnn_fe_dtype(otype)); + + // create an amax reduction node + _amax = _graph.reduction(_z, fe::graph::Reduction_attributes() + .set_mode(fe::ReductionMode_t::AMAX) + .set_compute_data_type(get_cudnn_fe_dtype(ctype))); + _amax->set_output(true).set_data_type(get_cudnn_fe_dtype(ctype)).set_dim({1, 1, 1, 1}); + } + } else { + _dz = _graph.tensor(fe::graph::Tensor_attributes() + .set_name("dz") + .set_dim({batch_dim, hidden_dim, 1, 1}) + .set_stride({hidden_dim, 1, hidden_dim, hidden_dim})); + _rsigma = _graph.tensor(fe::graph::Tensor_attributes() + .set_name("inv_var") + .set_dim({batch_dim, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + 
.set_data_type(get_cudnn_fe_dtype(ctype))); + _mean = _graph.tensor(fe::graph::Tensor_attributes() + .set_name("mean") + .set_dim({batch_dim, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(get_cudnn_fe_dtype(ctype))); + if (NormType == NVTE_Norm_Type::LayerNorm) { + auto norm_options = fe::graph::Layernorm_backward_attributes() + .set_saved_mean_and_inv_variance(_mean, _rsigma) + .set_compute_data_type(get_cudnn_fe_dtype(ctype)); + auto ret = _graph.layernorm_backward(_dz, _x, _gamma, norm_options); + std::tie(_dx, _dgamma, _dbeta) = std::make_tuple(ret[0], ret[1], ret[2]); + _dbeta->set_output(true).set_data_type(get_cudnn_fe_dtype(otype)); + } else { + auto norm_options = + fe::graph::Rmsnorm_backward_attributes().has_dbias(false).set_compute_data_type( + get_cudnn_fe_dtype(ctype)); + auto ret = _graph.rmsnorm_backward(_dz, _x, _gamma, _rsigma, norm_options); + std::tie(_dx, _dgamma, _dbeta) = std::make_tuple(ret[0], ret[1], ret[2]); + if (_dbeta != nullptr) NVTE_ERROR("cuDNN rmsnorm dbias incorrectly returned."); + } + _dx->set_output(true).set_data_type(get_cudnn_fe_dtype(otype)); + _dgamma->set_output(true).set_data_type(get_cudnn_fe_dtype(otype)); + } + // Build the graph + this->_build(); +} + +void CudnnNormalizationPlan::_build() { + NVTE_CHECK(_graph.validate().is_good()); + NVTE_CHECK(_graph.build_operation_graph(_handle).is_good()); + NVTE_CHECK(_graph + .create_execution_plans( + {cudnn_frontend::HeurMode_t::A, cudnn_frontend::HeurMode_t::FALLBACK}) + .is_good()); + NVTE_CHECK(_graph.check_support(_handle).is_good()); + NVTE_CHECK( + _graph.build_plans(_handle, cudnn_frontend::BuildPlanPolicy_t::HEURISTICS_CHOICE).is_good()); +} + +std::vector CudnnNormalizationPlan::getWorkspaceShape() const { + return {static_cast(_graph.get_workspace_size())}; +} + +void CudnnNormalizationPlan::execute(Tensor* z, void* x_dptr, void* gamma_dptr, void* beta_dptr, + void* mean_dptr, void* eps_dptr, void* rsigma_dptr, + void* workspace_dptr, cudaStream_t stream) 
{ + // Binding data pointers to graph tensors + _variant_pack = {{_x, x_dptr}, {_rsigma, rsigma_dptr}, {_eps, eps_dptr}}; + + // layernorm should have valid mean_dptr and beta_dptr + if (mean_dptr && beta_dptr) _variant_pack.insert({{_mean, mean_dptr}, {_beta, beta_dptr}}); + + if (_zero_centered) + _variant_pack.insert( + {{_scalar_offset, reinterpret_cast(_scalar_dptr.get())}, {_gamma_zero, gamma_dptr}}); + else + _variant_pack.insert({{_gamma, gamma_dptr}}); + + if (_fp8_out) + _variant_pack.insert( + {{_z_scale, z->scale.dptr}, {_amax, z->amax.dptr}, {_z_fp8, z->data.dptr}}); + else + _variant_pack.insert({{_z, z->data.dptr}}); + + // Execute the computation + NVTE_CHECK_CUDNN(cudnnSetStream(_handle, stream)); + NVTE_CHECK(_graph.execute(_handle, _variant_pack, workspace_dptr).is_good()); + if (_fp8_out) update_tensor_scale_inv(z, stream); +} + +void CudnnNormalizationPlan::execute(void* x_dptr, void* gamma_dptr, void* mean_dptr, + void* rsigma_dptr, void* dx_dptr, void* dz_dptr, + void* dbeta_dptr, void* dgamma_dptr, void* workspace_dptr, + cudaStream_t stream) { + // Binding data pointers to graph tensors + _variant_pack = { + {_x, x_dptr}, {_rsigma, rsigma_dptr}, {_dz, dz_dptr}, {_dgamma, dgamma_dptr}, {_dx, dx_dptr}}; + + if (_zero_centered) + _variant_pack.insert({{_scalar_offset, reinterpret_cast(this->_scalar_dptr.get())}, + {_gamma_zero, gamma_dptr}}); + else + _variant_pack.insert({{_gamma, gamma_dptr}}); + + // layernorm should have valid mean_dptr and beta_dptr + if (mean_dptr && dbeta_dptr) _variant_pack.insert({{_mean, mean_dptr}, {_dbeta, dbeta_dptr}}); + + // Execute the computation + NVTE_CHECK_CUDNN(cudnnSetStream(_handle, stream)); + NVTE_CHECK(_graph.execute(_handle, _variant_pack, workspace_dptr).is_good()); +} + +NormalizationPlanBase* NormalizationPlanRegistry::getNormalizationPlan( + NVTE_Norm_Backend NormBackend, NVTE_Norm_Type NormType, NVTE_Norm_Stage NormStage, DType wtype, + DType itype, DType otype, const size_t batch_size, const 
size_t hidden_size, + const size_t sm_count, const bool zero_centered_gamma, const bool is_aligned) { + const DType ctype = DType::kFloat32; + bool is_tuned = is_aligned && (batch_size % 4 == 0); + auto key = get_key(NormType, NormStage, wtype, itype, otype, ctype, batch_size, hidden_size, + zero_centered_gamma, is_tuned); + + auto it = normalizationPlanMap.find(key); + if (it != normalizationPlanMap.end()) { + return it->second.get(); + } + + std::unique_ptr plan; + if (NormBackend == NVTE_Norm_Backend::Cudnn) { + plan = std::make_unique(NormType, NormStage, wtype, itype, otype, ctype, + batch_size, hidden_size, sm_count, + zero_centered_gamma); + } else if (NormStage == NVTE_Norm_Stage::Forward) { + plan = std::make_unique>( + NormType, NormStage, wtype, itype, otype, ctype, batch_size, hidden_size, sm_count, + zero_centered_gamma, is_tuned); + } else { + plan = std::make_unique>( + NormType, NormStage, wtype, itype, otype, ctype, batch_size, hidden_size, sm_count, + zero_centered_gamma, is_tuned); + } + normalizationPlanMap.insert({key, std::move(plan)}); + return normalizationPlanMap[key].get(); +} + +bool& _cudnn_norm_fwd_flag() { + static bool flag = transformer_engine::getenv("NVTE_NORM_FWD_USE_CUDNN"); + return flag; +} + +bool& _cudnn_norm_bwd_flag() { + static bool flag = transformer_engine::getenv("NVTE_NORM_BWD_USE_CUDNN"); + return flag; +} + +bool use_cudnn_norm_fwd() { return _cudnn_norm_fwd_flag(); } +bool use_cudnn_norm_bwd() { return _cudnn_norm_bwd_flag(); } + +} // namespace normalization +} // namespace transformer_engine + +void nvte_enable_cudnn_norm_fwd(bool enable) { + NVTE_API_CALL(nvte_enable_cudnn_norm_fwd); + transformer_engine::normalization::_cudnn_norm_fwd_flag() = enable; +} + +void nvte_enable_cudnn_norm_bwd(bool enable) { + NVTE_API_CALL(nvte_enable_cudnn_norm_bwd); + transformer_engine::normalization::_cudnn_norm_bwd_flag() = enable; +} diff --git a/transformer_engine/common/normalization/common.h 
b/transformer_engine/common/normalization/common.h new file mode 100644 index 0000000000..d1d56d5cc9 --- /dev/null +++ b/transformer_engine/common/normalization/common.h @@ -0,0 +1,381 @@ +/************************************************************************* + * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#ifndef TRANSFORMER_ENGINE_COMMON_NORM_COMMON_H_ +#define TRANSFORMER_ENGINE_COMMON_NORM_COMMON_H_ + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "../common.h" +#include "../cudnn_utils.h" +#include "../util/system.h" + +namespace transformer_engine { + +namespace normalization { + +namespace fe = cudnn_frontend; + +template +struct LaunchParams { + size_t workspace_bytes = 0; + size_t barrier_bytes = 0; + size_t dgamma_part_bytes = 0; + int multiprocessorCount; + cudaStream_t stream; + + KernelParamsType params; + + size_t getTotalWorkspaceBytes(const bool _is_layernorm = true) const { + return (workspace_bytes + barrier_bytes + size_t(_is_layernorm + 1) * dgamma_part_bytes); + } + void alignWorkspace(size_t alignment = 16) { + workspace_bytes = DIVUP(workspace_bytes, alignment) * alignment; + barrier_bytes = DIVUP(barrier_bytes, alignment) * alignment; + dgamma_part_bytes = DIVUP(dgamma_part_bytes, alignment) * alignment; + } +}; + +struct KernelParamsBase { + KernelParamsBase() + : ctas_per_col(0), + rows(0), + cols(0), + x(nullptr), + mu(nullptr), + rs(nullptr), + gamma(nullptr), + workspace(nullptr), + barrier(nullptr), + zero_centered_gamma(false) {} + + // For Multi-CTA, number of different CTA groups. Otherwise same as gridDim.x. + int ctas_per_col; + // Size of CTA group. + int ctas_per_row; + + // Input is interpreted as matrix. We normalize across columns. + int rows; + int cols; + + // Common data pointers. 
+ void* x; + void* mu; + void* rs; + void* gamma; + + // Multi-CTA workspace in gmem. + void* workspace; + + // Multi-CTA sync barriers in gmem. + int* barrier; + + // Whether gamma is centered around 0 + bool zero_centered_gamma; +}; + +struct ForwardKernelParams : public KernelParamsBase { + ForwardKernelParams() + : KernelParamsBase(), z(nullptr), beta(nullptr), epsilon(0.f), fp8_out(false) {} + + // Output of LN FWD. + void* z; + void* beta; + float epsilon; + + // Scaling factor + void* scale; + int scale_byte_size; + + // Inverse of scaling factor + void* scale_inv; + + // AMax output + void* amax; + int amax_byte_size; + + // Whether to compute scale and amax + bool fp8_out; +}; + +struct BackwardKernelParams : public KernelParamsBase { + BackwardKernelParams() + : KernelParamsBase(), + dz(nullptr), + dbeta_part(nullptr), + dgamma_part(nullptr), + dx(nullptr), + dbeta(nullptr), + dgamma(nullptr) {} + + // Input: gradient wrt. LN FWD output. + void* dz; + + // Workspace for Wgrad pre-reduction. + void* dbeta_part; + void* dgamma_part; + + // Output: Dgrad. + void* dx; + // Output: Wgrad. 
+ void* dbeta; + void* dgamma; +}; + +enum class NVTE_Norm_Backend { Te, Cudnn }; +enum class NVTE_Norm_Type { LayerNorm, RMSNorm }; +enum class NVTE_Norm_Stage { Forward, Backward }; + +using TupleKeyType = std::tuple; +struct TupleHash { + size_t operator()(const TupleKeyType& t) const { + // Generate a hash for a tuple by combining the hashes of its entries + // See: https://www.boost.org/doc/libs/1_55_0/doc/html/hash/reference.html#boost.hash_combine + size_t seed = 0; + std::hash hasher; + seed ^= hasher(std::get<0>(t)) + 0x9e3779b9 + (seed << 6) + (seed >> 2); + seed ^= hasher(std::get<1>(t)) + 0x9e3779b9 + (seed << 6) + (seed >> 2); + seed ^= hasher(std::get<2>(t)) + 0x9e3779b9 + (seed << 6) + (seed >> 2); + return seed; + } +}; + +TupleKeyType get_key(NVTE_Norm_Type NormType, NVTE_Norm_Stage NormStage, DType wtype, DType itype, + DType otype, DType ctype, uint64_t batch_size, uint64_t hidden_size, + bool zero_centered_gamma, bool is_tuned); + +template +class TeNormalizationRegistry { + private: + using Function = std::function&, const bool)>; + std::unordered_map tuned_function_map; + std::unordered_map> general_function_map; + + TeNormalizationRegistry() = default; + + static TeNormalizationRegistry& getInstance() { + static TeNormalizationRegistry registry; + return registry; + } + + public: + static int registerFunction(TupleKeyType key, + void (*func)(LaunchParams&, const bool)) { + auto [general_key, batch_size, hidden_size, is_tuned] = key; + if (is_tuned) + getInstance().tuned_function_map.emplace(key, Function(func)); + else + getInstance().general_function_map[general_key].emplace(hidden_size, Function(func)); + return 0; + } + + static Function getKernel(TupleKeyType key) { + auto& instance = getInstance(); + auto [general_key, batch_size, hidden_size, is_tuned] = key; + if (is_tuned) { + auto it = instance.tuned_function_map.find(key); + if (it != instance.tuned_function_map.end()) return it->second; + } + if 
(instance.general_function_map.count(general_key) == 0) { + NVTE_ERROR("Unavailable kernel for this normalization config."); + } + auto& general_func_map = instance.general_function_map.at(general_key); + auto func_iter = general_func_map.lower_bound(hidden_size); + if (func_iter == general_func_map.end()) { + return general_func_map.rbegin()->second; // Hidden size is too big, need to use multi-CTA + } else { + return func_iter->second; + } + } + + TeNormalizationRegistry(const TeNormalizationRegistry&) = delete; + TeNormalizationRegistry& operator=(const TeNormalizationRegistry&) = delete; + TeNormalizationRegistry(TeNormalizationRegistry&&) = delete; + TeNormalizationRegistry& operator=(TeNormalizationRegistry&&) = delete; +}; + +class NormalizationPlanBase { + public: + virtual ~NormalizationPlanBase() = default; + virtual std::vector getWorkspaceShape() const = 0; + + virtual void execute(Tensor* z, void* x_dptr, void* gamma_dptr, void* beta_dptr, void* mean_dptr, + void* eps_dptr, void* rsigma_dptr, void* workspace_dptr, + cudaStream_t stream) = 0; + + virtual void execute(void* x_dptr, void* gamma_dptr, void* mean_dptr, void* rsigma_dptr, + void* dx_dptr, void* dz_dptr, void* dbeta_dptr, void* dgamma_dptr, + void* workspace_dptr, cudaStream_t stream) = 0; + + private: + virtual void _build() = 0; +}; + +template +class TeNormalizationPlan : public NormalizationPlanBase { + public: + TeNormalizationPlan(NVTE_Norm_Type NormType, NVTE_Norm_Stage NormStage, DType wtype, DType itype, + DType otype, DType ctype, const size_t batch_size, const size_t hidden_size, + const size_t sm_count, const bool zero_centered_gamma, const bool is_tuned); + std::vector getWorkspaceShape() const override; + + void execute(Tensor* z, void* x_dptr, void* gamma_dptr, void* beta_dptr, void* mean_dptr, + void* eps_dptr, void* rsigma_dptr, void* workspace_dptr, + cudaStream_t stream) override; + + void execute(void* x_dptr, void* gamma_dptr, void* mean_dptr, void* rsigma_dptr, void* 
dx_dptr, + void* dz_dptr, void* dbeta_dptr, void* dgamma_dptr, void* workspace_dptr, + cudaStream_t stream) override; + + private: + void _set_workspace(); + void _build(); + + using KernelRegistry = TeNormalizationRegistry; + LaunchParams _launch_params; + std::function&, const bool)> _kernel; + + const bool _is_layernorm; +}; + +class CudnnNormalizationPlan : public NormalizationPlanBase { + public: + CudnnNormalizationPlan(NVTE_Norm_Type NormType, NVTE_Norm_Stage NormStage, DType wtype, + DType itype, DType otype, DType ctype, const size_t batch_size, + const size_t hidden_size, const size_t sm_count, + const bool zero_centered_gamma); + + std::vector getWorkspaceShape() const override; + + void execute(Tensor* z, void* x_dptr, void* gamma_dptr, void* beta_dptr, void* mean_dptr, + void* eps_dptr, void* rsigma_dptr, void* workspace_dptr, + cudaStream_t stream) override; + + void execute(void* x_dptr, void* gamma_dptr, void* mean_dptr, void* rsigma_dptr, void* dx_dptr, + void* dz_dptr, void* dbeta_dptr, void* dgamma_dptr, void* workspace_dptr, + cudaStream_t stream) override; + + private: + void _build() override; + + const bool _zero_centered, _fp8_out; + std::unique_ptr _scalar_dptr; + // FWD + std::shared_ptr _x, _gamma_zero, _scalar_offset, _gamma, _beta, + _eps, _mean, _rsigma, _z, _z_scale, _amax, _z_fp8; + // BWD + std::shared_ptr _dz, _dx, _dgamma, _dbeta; + + fe::graph::Graph _graph; + std::unordered_map, void*> _variant_pack; + cudnnHandle_t _handle; +}; + +class NormalizationPlanRegistry { + public: + static NormalizationPlanRegistry& getInstance() { + static thread_local NormalizationPlanRegistry instance; + return instance; + } + + NormalizationPlanBase* getNormalizationPlan(NVTE_Norm_Backend NormBackend, + NVTE_Norm_Type NormType, NVTE_Norm_Stage NormStage, + DType wtype, DType itype, DType otype, + const size_t batch_size, const size_t hidden_size, + const size_t sm_count, const bool zero_centered_gamma, + const bool is_aligned); + + private: + 
NormalizationPlanRegistry() {} + NormalizationPlanRegistry(const NormalizationPlanRegistry&) = delete; + NormalizationPlanRegistry& operator=(const NormalizationPlanRegistry&) = delete; + + std::unordered_map, TupleHash> + normalizationPlanMap; +}; + +using byte = uint8_t; +using int32 = int32_t; +using fp32 = float; +using fp16 = half; +using bf16 = nv_bfloat16; +using fp8e4m3 = __nv_fp8_e4m3; +using fp8e5m2 = __nv_fp8_e5m2; + +template +struct TypeToDType; + +template <> +struct TypeToDType { + static constexpr DType value = DType::kFloat32; +}; +template <> +struct TypeToDType { + static constexpr DType value = DType::kFloat16; +}; +template <> +struct TypeToDType { + static constexpr DType value = DType::kBFloat16; +}; +template <> +struct TypeToDType { + static constexpr DType value = DType::kFloat8E4M3; +}; +template <> +struct TypeToDType { + static constexpr DType value = DType::kFloat8E5M2; +}; +template <> +struct TypeToDType { + static constexpr DType value = DType::kInt32; +}; +template <> +struct TypeToDType { + static constexpr DType value = DType::kByte; +}; + +#define IS_TUNED(x) (strcmp(#x, "tuned") == 0 ? 1 : 0) + +// TE kernels have no template for batch_size and zero_centered_gamma, thus zero out those +#define REGISTER_NORM_BASE(NORM_TYPE, NORM_STAGE, LAUNCH_TYPE, HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, \ + CTYPE, FUNC_NAME) \ + static int \ + register_##NORM_TYPE##_##NORM_STAGE##_##LAUNCH_TYPE##_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE = \ + TeNormalizationRegistry::registerFunction( \ + (get_key(NVTE_Norm_Type::NORM_TYPE, NVTE_Norm_Stage::NORM_STAGE, \ + (TypeToDType::value), (TypeToDType::value), \ + (TypeToDType::value), (TypeToDType::value), 0, HIDDEN_SIZE, \ + 0, IS_TUNED(LAUNCH_TYPE))), \ + FUNC_NAME) + +// For FP8 only +void ComputeScaleInv(void* scale, void* scale_inv); + +// Alignment check +template +bool is_ptr_aligned(const Args*... 
ptrs) { + return ((reinterpret_cast(ptrs) % Alignment == 0) && ...); +} + +bool use_cudnn_norm_fwd(); +bool use_cudnn_norm_bwd(); + +} // namespace normalization + +} // namespace transformer_engine + +#endif diff --git a/transformer_engine/common/layer_norm/ln_kernel_traits.h b/transformer_engine/common/normalization/kernel_traits.h similarity index 89% rename from transformer_engine/common/layer_norm/ln_kernel_traits.h rename to transformer_engine/common/normalization/kernel_traits.h index a72726c325..0f8fea3f0b 100644 --- a/transformer_engine/common/layer_norm/ln_kernel_traits.h +++ b/transformer_engine/common/normalization/kernel_traits.h @@ -4,16 +4,15 @@ * See LICENSE for license information. ************************************************************************/ -#ifndef TRANSFORMER_ENGINE_COMMON_LAYER_NORM_LN_KERNEL_TRAITS_H_ -#define TRANSFORMER_ENGINE_COMMON_LAYER_NORM_LN_KERNEL_TRAITS_H_ +#ifndef TRANSFORMER_ENGINE_COMMON_NORM_KERNEL_TRAITS_H_ +#define TRANSFORMER_ENGINE_COMMON_NORM_KERNEL_TRAITS_H_ #include "../common.h" #include "../utils.cuh" -//////////////////////////////////////////////////////////////////////////////////////////////////// - namespace transformer_engine { -namespace layer_norm { +namespace normalization { + template struct Kernel_traits_base { @@ -28,8 +27,6 @@ struct Kernel_traits_base { enum { THREADS_PER_WARP = 32 }; }; -//////////////////////////////////////////////////////////////////////////////////////////////////// - template + +#include +#include +#include +#include +#include + +#include "../../common.h" +#include "../common.h" + +namespace transformer_engine { + +using namespace normalization; + +void layernorm_fwd(const Tensor& x, // BxSxhidden_size + const Tensor& gamma, // hidden_size + const Tensor& beta, // hidden_size + const float epsilon, Tensor* z, Tensor* mu, Tensor* rsigma, Tensor* workspace, + const int multiprocessorCount, const bool zero_centered_gamma, + cudaStream_t stream) { + 
NVTE_CHECK(x.data.shape.size() == 2); + NVTE_CHECK(gamma.data.shape == beta.data.shape); + NVTE_CHECK(x.data.shape[1] == gamma.data.shape[0]); + + NVTE_CHECK(epsilon >= 0.f); + + NVTE_CHECK(z->data.shape == x.data.shape); + + NVTE_CHECK(mu->data.shape == std::vector{x.data.shape[0]}); + NVTE_CHECK(mu->data.dtype == DType::kFloat32); + + NVTE_CHECK(rsigma->data.shape == std::vector{x.data.shape[0]}); + NVTE_CHECK(rsigma->data.dtype == DType::kFloat32); + + if (!workspace->data.shape.empty()) { + CheckInputTensor(x, "x"); + CheckInputTensor(gamma, "gamma"); + CheckInputTensor(beta, "beta"); + + CheckOutputTensor(*z, "z"); + CheckOutputTensor(*mu, "mu"); + CheckOutputTensor(*rsigma, "rsigma"); + } + + NVTE_Norm_Backend norm_backend; + bool is_aligned = true; + if (use_cudnn_norm_fwd()) { + // TODO: add check for GPU ARCH + norm_backend = NVTE_Norm_Backend::Cudnn; + } else { + norm_backend = NVTE_Norm_Backend::Te; + is_aligned = is_ptr_aligned(z->data.dptr, x.data.dptr, gamma.data.dptr, beta.data.dptr, + mu->data.dptr, rsigma->data.dptr); + } + auto plan = NormalizationPlanRegistry::getInstance().getNormalizationPlan( + norm_backend, NVTE_Norm_Type::LayerNorm, NVTE_Norm_Stage::Forward, + gamma.data.dtype, // wtype + x.data.dtype, // itype + z->data.dtype, // otype + x.data.shape[0], // batch_size + x.data.shape[1], // hidden_size + multiprocessorCount, zero_centered_gamma, is_aligned); + + if (workspace->data.shape.empty()) { + workspace->data.shape = plan->getWorkspaceShape(); + workspace->data.dtype = DType::kByte; + return; + } else { + NVTE_CHECK(workspace->data.shape == plan->getWorkspaceShape()); + plan->execute(z, x.data.dptr, gamma.data.dptr, beta.data.dptr, mu->data.dptr, + reinterpret_cast(const_cast(&epsilon)), rsigma->data.dptr, + workspace->data.dptr, stream); + } + return; +} + +void layernorm_bwd(const Tensor& dz, const Tensor& x, const Tensor& mu, const Tensor& rsigma, + const Tensor& gamma, Tensor* dx, Tensor* dgamma, Tensor* dbeta, + Tensor* 
workspace, const int multiprocessorCount, const bool zero_centered_gamma, + cudaStream_t stream) { + using namespace transformer_engine; + NVTE_CHECK(dz.data.dtype == gamma.data.dtype); + NVTE_CHECK(mu.data.dtype == DType::kFloat32); + NVTE_CHECK(rsigma.data.dtype == mu.data.dtype); + + NVTE_CHECK(x.data.shape.size() == 2); + NVTE_CHECK(dz.data.shape == x.data.shape); + + NVTE_CHECK(mu.data.shape[0] == x.data.shape[0]); + NVTE_CHECK(mu.data.shape == rsigma.data.shape); + + NVTE_CHECK(gamma.data.shape[0] == x.data.shape[1]); + + NVTE_CHECK(dx->data.shape == x.data.shape); + NVTE_CHECK(dx->data.dtype == x.data.dtype); + + NVTE_CHECK(dgamma->data.shape == gamma.data.shape); + NVTE_CHECK(dgamma->data.dtype == gamma.data.dtype); + + NVTE_CHECK(dbeta->data.shape == gamma.data.shape); + NVTE_CHECK(dbeta->data.dtype == gamma.data.dtype); + + if (!workspace->data.shape.empty()) { + CheckInputTensor(dz, "dz"); + CheckInputTensor(x, "x"); + CheckInputTensor(mu, "mu"); + CheckInputTensor(rsigma, "rsigma"); + CheckInputTensor(gamma, "gamma"); + CheckOutputTensor(*dx, "dx"); + CheckOutputTensor(*dgamma, "dgamma"); + CheckOutputTensor(*dbeta, "dbeta"); + } + + NVTE_Norm_Backend norm_backend; + bool is_aligned = true; + if (use_cudnn_norm_bwd()) { + // TODO: add check for GPU ARCH + norm_backend = NVTE_Norm_Backend::Cudnn; + } else { + norm_backend = NVTE_Norm_Backend::Te; + is_aligned = is_ptr_aligned(x.data.dptr, gamma.data.dptr, mu.data.dptr, rsigma.data.dptr, + dx->data.dptr, dz.data.dptr, dbeta->data.dptr, dgamma->data.dptr); + } + auto plan = NormalizationPlanRegistry::getInstance().getNormalizationPlan( + norm_backend, NVTE_Norm_Type::LayerNorm, NVTE_Norm_Stage::Backward, + gamma.data.dtype, // wtype + x.data.dtype, // itype + gamma.data.dtype, // otype + x.data.shape[0], // batch_size + x.data.shape[1], // hidden_size + multiprocessorCount, zero_centered_gamma, is_aligned); + + if (workspace->data.shape.empty()) { + workspace->data.shape = plan->getWorkspaceShape(); + 
workspace->data.dtype = DType::kByte; + return; + } else { + NVTE_CHECK(workspace->data.shape == plan->getWorkspaceShape()); + plan->execute(x.data.dptr, gamma.data.dptr, mu.data.dptr, rsigma.data.dptr, dx->data.dptr, + dz.data.dptr, dbeta->data.dptr, dgamma->data.dptr, workspace->data.dptr, stream); + } + return; +} +} // namespace transformer_engine + +void nvte_layernorm_fwd(const NVTETensor x, // BxSxhidden_size + const NVTETensor gamma, // hidden_size + const NVTETensor beta, // hidden_size + const float epsilon, NVTETensor z, NVTETensor mu, NVTETensor rsigma, + NVTETensor workspace, const int multiprocessorCount, + const bool zero_centered_gamma, cudaStream_t stream) { + NVTE_API_CALL(nvte_layernorm_fwd); + using namespace transformer_engine; + layernorm_fwd(*reinterpret_cast(x), *reinterpret_cast(gamma), + *reinterpret_cast(beta), epsilon, reinterpret_cast(z), + reinterpret_cast(mu), reinterpret_cast(rsigma), + reinterpret_cast(workspace), multiprocessorCount, zero_centered_gamma, + stream); +} + +void nvte_layernorm_bwd(const NVTETensor dz, // BxSxhidden_size + const NVTETensor x, // BxSxhidden_size + const NVTETensor mu, // BxS, FP32! + const NVTETensor rsigma, // BxS, FP32! 
+ const NVTETensor gamma, // hidden_size + NVTETensor dx, NVTETensor dgamma, NVTETensor dbeta, NVTETensor workspace, + const int multiprocessorCount, const bool zero_centered_gamma, + cudaStream_t stream) { + NVTE_API_CALL(nvte_layernorm_bwd); + using namespace transformer_engine; + layernorm_bwd(*reinterpret_cast(dz), *reinterpret_cast(x), + *reinterpret_cast(mu), *reinterpret_cast(rsigma), + *reinterpret_cast(gamma), reinterpret_cast(dx), + reinterpret_cast(dgamma), reinterpret_cast(dbeta), + reinterpret_cast(workspace), multiprocessorCount, zero_centered_gamma, + stream); +} diff --git a/transformer_engine/common/layer_norm/ln_bwd_kernels.cuh b/transformer_engine/common/normalization/layernorm/ln_bwd_kernels.cuh similarity index 97% rename from transformer_engine/common/layer_norm/ln_bwd_kernels.cuh rename to transformer_engine/common/normalization/layernorm/ln_bwd_kernels.cuh index dbd0025244..44078a040b 100644 --- a/transformer_engine/common/layer_norm/ln_bwd_kernels.cuh +++ b/transformer_engine/common/normalization/layernorm/ln_bwd_kernels.cuh @@ -7,16 +7,15 @@ #ifndef TRANSFORMER_ENGINE_COMMON_LAYER_NORM_LN_BWD_KERNELS_CUH_ #define TRANSFORMER_ENGINE_COMMON_LAYER_NORM_LN_BWD_KERNELS_CUH_ -#include "../utils.cuh" -#include "ln.h" +#include "../../utils.cuh" +#include "../common.h" namespace transformer_engine { -namespace layer_norm { -using namespace transformer_engine; +namespace normalization { template __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_bwd_tuned_kernel( - layer_norm::BwdParams params) { + BackwardKernelParams params) { enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA }; enum { WARPS_M = Ktraits::WARPS_M }; enum { WARPS_N = Ktraits::WARPS_N }; @@ -119,8 +118,8 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_bwd_tuned_kernel( } reduce_t result = reducer.allreduce({mdy_local, mdyy_local}, sum); - mdy_local = layer_norm::Get<0>::of(result) * rn; - mdyy_local = layer_norm::Get<1>::of(result) * rn; + mdy_local = 
Get<0>::of(result) * rn; + mdyy_local = Get<1>::of(result) * rn; Ivec dx[LDGS]; idx = row * Ktraits::VEC_COLS + c; @@ -203,7 +202,7 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_bwd_tuned_kernel( template __global__ __launch_bounds__(Kernel_traits::THREADS_PER_CTA) void ln_bwd_finalize_tuned_kernel( - BwdParams params) { + BackwardKernelParams params) { using compute_t = typename Kernel_traits::compute_t; using weight_t = typename Kernel_traits::weight_t; using index_t = typename Kernel_traits::index_t; @@ -323,7 +322,7 @@ __global__ __launch_bounds__(Kernel_traits::THREADS_PER_CTA) void ln_bwd_finaliz template __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_bwd_general_kernel( - layer_norm::BwdParams params) { + BackwardKernelParams params) { enum { LDGS = Ktraits::LDGS }; enum { NUM_ELTS = Ktraits::ELTS_PER_LDG }; enum { WARPS_M = Ktraits::WARPS_M }; @@ -424,8 +423,8 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_bwd_general_kerne // Reduce over row reduce_t result = reducer.allreduce({mdy, mdyy}, sum); - mdy = layer_norm::Get<0>::of(result) * rn; - mdyy = layer_norm::Get<1>::of(result) * rn; + mdy = Get<0>::of(result) * rn; + mdyy = Get<1>::of(result) * rn; // Compute dx #pragma unroll @@ -507,7 +506,7 @@ template __global__ __launch_bounds__(WARPS_M *WARPS_N *THREADS_PER_WARP) void ln_bwd_finalize_general_kernel( - layer_norm::BwdParams params) { + BackwardKernelParams params) { enum { NUM_ELTS = BYTES_PER_LDG / sizeof(compute_t) }; using Wvec = Vec; using Cvec = Vec; @@ -573,7 +572,7 @@ __launch_bounds__(WARPS_M *WARPS_N *THREADS_PER_WARP) void ln_bwd_finalize_gener } } -} // namespace layer_norm +} // namespace normalization } // namespace transformer_engine #endif // TRANSFORMER_ENGINE_COMMON_LAYER_NORM_LN_BWD_KERNELS_CUH_ diff --git a/transformer_engine/common/normalization/layernorm/ln_bwd_semi_cuda_kernel.cu b/transformer_engine/common/normalization/layernorm/ln_bwd_semi_cuda_kernel.cu new file mode 
100644 index 0000000000..d6e15dfc30 --- /dev/null +++ b/transformer_engine/common/normalization/layernorm/ln_bwd_semi_cuda_kernel.cu @@ -0,0 +1,331 @@ +/************************************************************************* + * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include "../../common.h" +#include "../common.h" +#include "../kernel_traits.h" +#include "ln_bwd_kernels.cuh" + +using namespace transformer_engine::normalization; + +template +void launch_tuned_(LaunchParams &launch_params, + const bool configure_params) { // NOLINT(*) + using Kernel_traits = Kernel_traits; + auto kernel = &ln_bwd_tuned_kernel; + + if (configure_params) { + int ctas_per_sm; + cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES); + launch_params.params.ctas_per_row = CTAS_PER_ROW; + launch_params.params.ctas_per_col = + launch_params.multiprocessorCount * ctas_per_sm / launch_params.params.ctas_per_row; + if (Kernel_traits::CTAS_PER_ROW > 1) { + launch_params.barrier_bytes = 2 * launch_params.params.ctas_per_col * sizeof(index_t); + launch_params.workspace_bytes = launch_params.params.ctas_per_col * Kernel_traits::WARPS_M * + Kernel_traits::CTAS_PER_ROW * + sizeof(typename Kernel_traits::reduce_t) * 2; + } + launch_params.dgamma_part_bytes = + launch_params.params.ctas_per_col * launch_params.params.cols * sizeof(compute_t); + return; + } + + if (Kernel_traits::SMEM_BYTES >= 48 * 1024) { + NVTE_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, + Kernel_traits::SMEM_BYTES)); + } + auto stream = launch_params.stream; + auto ctas_per_col = launch_params.params.ctas_per_col; + auto ctas_per_row = launch_params.params.ctas_per_row; + + if (ctas_per_row == 1) { + kernel<<>>( + 
launch_params.params); + } else { + dim3 grid(ctas_per_row * ctas_per_col); + dim3 block(Kernel_traits::THREADS_PER_CTA); + void *params_ = reinterpret_cast(&launch_params.params); + cudaLaunchCooperativeKernel(reinterpret_cast(kernel), grid, block, + reinterpret_cast(&params_), Kernel_traits::SMEM_BYTES, + stream); + } + + using Kernel_traits_f = + Kernel_traits_finalize; + + auto kernel_f = &ln_bwd_finalize_tuned_kernel; + kernel_f<<>>( + launch_params.params); +} + +template +void launch_general_(LaunchParams &launch_params, + const bool configure_params) { // NOLINT(*) + auto ceil_div = [](int x, int y) -> int { return (x + y - 1) / y; }; + + // Instantiate kernel + using Kernel_traits = Kernel_traits; + auto kernel = &ln_bwd_general_kernel; + + // Configure kernel params + const int rows = launch_params.params.rows; + const int cols = launch_params.params.cols; + int ctas_per_col = launch_params.params.ctas_per_col; + int ctas_per_row = launch_params.params.ctas_per_row; + if (configure_params) { + int ctas_per_sm; + cudaOccupancyMaxActiveBlocksPerMultiprocessor(&ctas_per_sm, kernel, + Kernel_traits::THREADS_PER_CTA, 0); + const int max_ctas = launch_params.multiprocessorCount * ctas_per_sm; + ctas_per_row = ceil_div(cols, HIDDEN_SIZE); + ctas_per_col = std::min(ceil_div(rows, WARPS_M), max_ctas / ctas_per_row); + launch_params.params.ctas_per_row = ctas_per_row; + launch_params.params.ctas_per_col = ctas_per_col; + if (launch_params.params.ctas_per_row > 1) { + launch_params.barrier_bytes = 2 * ctas_per_col * sizeof(index_t); + launch_params.workspace_bytes = + (ctas_per_col * WARPS_M * ctas_per_row * sizeof(typename Kernel_traits::reduce_t) * 2); + } + launch_params.dgamma_part_bytes = ctas_per_col * cols * sizeof(compute_t); + return; + } + + // Launch kernel + auto stream = launch_params.stream; + dim3 grid(ctas_per_row * ctas_per_col); + dim3 block(Kernel_traits::THREADS_PER_CTA); + if (ctas_per_row == 1) { + kernel<<>>(launch_params.params); + } else { +
void *params_ = reinterpret_cast(&launch_params.params); + cudaLaunchCooperativeKernel(reinterpret_cast(kernel), grid, block, + reinterpret_cast(¶ms_), 0, stream); + } + + // Launch finalization kernel + constexpr uint32_t WARPS_M_FINAL = 4; + constexpr uint32_t WARPS_N_FINAL = 1; + constexpr uint32_t ELTS_N_PER_CTA_FINAL = + (Kernel_traits::THREADS_PER_WARP * WARPS_N_FINAL * BYTES_PER_LDG_FINAL / sizeof(compute_t)); + auto kernel_final = + &ln_bwd_finalize_general_kernel; + dim3 block_final(Kernel_traits::THREADS_PER_WARP * WARPS_N_FINAL, WARPS_M_FINAL); + dim3 grid_final(ceil_div(cols, ELTS_N_PER_CTA_FINAL), 1); + kernel_final<<>>(launch_params.params); +} + +#define REGISTER_NORM_LAUNCHER(NORM_TYPE, NORM_STAGE, LAUNCH_TYPE, HIDDEN_SIZE, WTYPE, ITYPE, \ + OTYPE, CTYPE, ...) \ + namespace { \ + void \ + norm_##NORM_TYPE##_##NORM_STAGE##_##LAUNCH_TYPE##_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ + LaunchParams &launch_params, const bool configure_params) { \ + launch_##LAUNCH_TYPE##_( \ + launch_params, configure_params); \ + } \ + REGISTER_NORM_BASE( \ + NORM_TYPE, NORM_STAGE, LAUNCH_TYPE, HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, \ + norm_##NORM_TYPE##_##NORM_STAGE##_##LAUNCH_TYPE##_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE); \ + } // namespace + +// Create tuned launch function and register. Macro signature: +// HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, ... 
+// WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL + +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 768, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 768, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 768, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 768, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 768, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4); + +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 1024, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 1024, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 1024, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 1024, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 1024, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4); + +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 1536, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 1536, fp16, fp16, fp16, fp32, 1, 1, 4, 8, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 1536, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 1536, bf16, bf16, bf16, fp32, 1, 1, 4, 8, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 1536, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); + +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 2048, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 2048, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 2048, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 2048, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 2048, 
bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); + +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 2304, fp32, fp32, fp32, fp32, 1, 1, 4, 8, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 2304, fp16, fp16, fp16, fp32, 1, 1, 4, 4, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 2304, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 2304, bf16, bf16, bf16, fp32, 1, 1, 4, 4, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 2304, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4); + +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 3072, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 3072, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 3072, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 3072, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 3072, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); + +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 3840, fp32, fp32, fp32, fp32, 1, 1, 4, 8, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 3840, fp16, fp16, fp16, fp32, 1, 1, 4, 4, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 3840, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 3840, bf16, bf16, bf16, fp32, 1, 1, 4, 4, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 3840, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4); + +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 4096, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 4096, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 4096, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 4096, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 4096, bf16, fp32, bf16, 
fp32, 1, 1, 4, 16, 4); + +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 5120, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 5120, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 5120, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 5120, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 5120, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); + +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 6144, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 6144, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 6144, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 6144, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 6144, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4); + +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 8192, fp32, fp32, fp32, fp32, 2, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 8192, fp16, fp16, fp16, fp32, 2, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 8192, fp16, fp32, fp16, fp32, 2, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 8192, bf16, bf16, bf16, fp32, 2, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 8192, bf16, fp32, bf16, fp32, 2, 1, 4, 16, 4); + +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 10240, fp32, fp32, fp32, fp32, 2, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 10240, fp16, fp16, fp16, fp32, 2, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 10240, fp16, fp32, fp16, fp32, 2, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 10240, bf16, bf16, bf16, fp32, 2, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 10240, bf16, fp32, bf16, fp32, 
2, 1, 4, 16, 4); + +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 12288, fp32, fp32, fp32, fp32, 4, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 12288, fp16, fp16, fp16, fp32, 4, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 12288, fp16, fp32, fp16, fp32, 4, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 12288, bf16, bf16, bf16, fp32, 4, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 12288, bf16, fp32, bf16, fp32, 4, 1, 4, 16, 4); + +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 12800, fp32, fp32, fp32, fp32, 5, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 12800, fp16, fp16, fp16, fp32, 5, 1, 4, 8, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 12800, fp16, fp32, fp16, fp32, 5, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 12800, bf16, bf16, bf16, fp32, 5, 1, 4, 8, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 12800, bf16, fp32, bf16, fp32, 5, 1, 4, 16, 4); + +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 15360, fp32, fp32, fp32, fp32, 4, 1, 4, 8, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 15360, fp16, fp16, fp16, fp32, 4, 1, 4, 4, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 15360, fp16, fp32, fp16, fp32, 4, 1, 4, 8, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 15360, bf16, bf16, bf16, fp32, 4, 1, 4, 4, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 15360, bf16, fp32, bf16, fp32, 4, 1, 4, 8, 4); + +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 16384, fp32, fp32, fp32, fp32, 4, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 16384, fp16, fp16, fp16, fp32, 4, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 16384, fp16, fp32, fp16, fp32, 4, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 16384, bf16, bf16, bf16, fp32, 4, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 16384, bf16, fp32, bf16, 
fp32, 4, 1, 4, 16, 4); + +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 18432, fp32, fp32, fp32, fp32, 4, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 18432, fp16, fp16, fp16, fp32, 4, 1, 4, 8, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 18432, fp16, fp32, fp16, fp32, 4, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 18432, bf16, bf16, bf16, fp32, 4, 1, 4, 8, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 18432, bf16, fp32, bf16, fp32, 4, 1, 4, 16, 4); + +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 20480, fp32, fp32, fp32, fp32, 4, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 20480, fp16, fp16, fp16, fp32, 4, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 20480, fp16, fp32, fp16, fp32, 4, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 20480, bf16, bf16, bf16, fp32, 4, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 20480, bf16, fp32, bf16, fp32, 4, 1, 4, 16, 4); + +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 24576, fp32, fp32, fp32, fp32, 4, 1, 8, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 24576, fp16, fp16, fp16, fp32, 4, 1, 8, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 24576, fp16, fp32, fp16, fp32, 4, 1, 8, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 24576, bf16, bf16, bf16, fp32, 4, 1, 8, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 24576, bf16, fp32, bf16, fp32, 4, 1, 8, 16, 4); + +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 25600, fp32, fp32, fp32, fp32, 5, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 25600, fp16, fp16, fp16, fp32, 5, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 25600, fp16, fp32, fp16, fp32, 5, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 25600, bf16, bf16, bf16, fp32, 5, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 25600, bf16, fp32, 
bf16, fp32, 5, 1, 4, 16, 4); + +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 30720, fp32, fp32, fp32, fp32, 4, 1, 8, 8, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 30720, fp16, fp16, fp16, fp32, 4, 1, 8, 4, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 30720, fp16, fp32, fp16, fp32, 4, 1, 8, 8, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 30720, bf16, bf16, bf16, fp32, 4, 1, 8, 4, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 30720, bf16, fp32, bf16, fp32, 4, 1, 8, 8, 4); + +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 32768, fp32, fp32, fp32, fp32, 4, 1, 8, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 32768, fp16, fp16, fp16, fp32, 4, 1, 8, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 32768, fp16, fp32, fp16, fp32, 4, 1, 8, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 32768, bf16, bf16, bf16, fp32, 4, 1, 8, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 32768, bf16, fp32, bf16, fp32, 4, 1, 8, 16, 4); + +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 40960, fp32, fp32, fp32, fp32, 4, 1, 8, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 40960, fp16, fp16, fp16, fp32, 4, 1, 8, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 40960, fp16, fp32, fp16, fp32, 4, 1, 8, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 40960, bf16, bf16, bf16, fp32, 4, 1, 8, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 40960, bf16, fp32, bf16, fp32, 4, 1, 8, 16, 4); + +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 49152, fp32, fp32, fp32, fp32, 8, 1, 8, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 49152, fp16, fp16, fp16, fp32, 8, 1, 8, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 49152, fp16, fp32, fp16, fp32, 8, 1, 8, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 49152, bf16, bf16, bf16, fp32, 8, 1, 8, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 49152, bf16, 
fp32, bf16, fp32, 8, 1, 8, 16, 4); + +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 65536, fp32, fp32, fp32, fp32, 8, 1, 8, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 65536, fp16, fp16, fp16, fp32, 8, 1, 8, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 65536, fp16, fp32, fp16, fp32, 8, 1, 8, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 65536, bf16, bf16, bf16, fp32, 8, 1, 8, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 65536, bf16, fp32, bf16, fp32, 8, 1, 8, 16, 4); + +// Create general launch function and register. Macro signature: +// HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, ... +// WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL + +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 128, fp32, fp32, fp32, fp32, 4, 1, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 128, fp16, fp16, fp16, fp32, 4, 1, 8, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 128, fp16, fp32, fp16, fp32, 4, 1, 8, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 128, bf16, bf16, bf16, fp32, 4, 1, 8, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 128, bf16, fp32, bf16, fp32, 4, 1, 8, 4); + +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 512, fp32, fp32, fp32, fp32, 4, 1, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 512, fp16, fp16, fp16, fp32, 4, 1, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 512, fp16, fp32, fp16, fp32, 4, 1, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 512, bf16, bf16, bf16, fp32, 4, 1, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 512, bf16, fp32, bf16, fp32, 4, 1, 16, 4); + +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 1024, fp32, fp32, fp32, fp32, 4, 1, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 1024, fp16, fp16, fp16, fp32, 4, 1, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 1024, fp16, fp32, fp16, fp32, 4, 1, 16, 4); 
+REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 1024, bf16, bf16, bf16, fp32, 4, 1, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 1024, bf16, fp32, bf16, fp32, 4, 1, 16, 4); + +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 2048, fp32, fp32, fp32, fp32, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 2048, fp16, fp16, fp16, fp32, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 2048, fp16, fp32, fp16, fp32, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 2048, bf16, bf16, bf16, fp32, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 2048, bf16, fp32, bf16, fp32, 1, 4, 16, 4); + +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 4096, fp32, fp32, fp32, fp32, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 4096, fp16, fp16, fp16, fp32, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 4096, fp16, fp32, fp16, fp32, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 4096, bf16, bf16, bf16, fp32, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 4096, bf16, fp32, bf16, fp32, 1, 4, 16, 4); diff --git a/transformer_engine/common/normalization/layernorm/ln_fwd_cuda_kernel.cu b/transformer_engine/common/normalization/layernorm/ln_fwd_cuda_kernel.cu new file mode 100644 index 0000000000..e7fe7a201b --- /dev/null +++ b/transformer_engine/common/normalization/layernorm/ln_fwd_cuda_kernel.cu @@ -0,0 +1,395 @@ +/************************************************************************* + * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
+ ************************************************************************/ + +#include "../common.h" +#include "../kernel_traits.h" +#include "ln_fwd_kernels.cuh" + +using namespace transformer_engine::normalization; + +template +void launch_tuned_(LaunchParams &launch_params, + const bool configure_params) { // NOLINT(*) + using Kernel_traits = Kernel_traits; + auto kernel = &ln_fwd_tuned_kernel; + if (configure_params) { + int ctas_per_sm; + cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD); + launch_params.params.ctas_per_row = CTAS_PER_ROW; + launch_params.params.ctas_per_col = + launch_params.multiprocessorCount * ctas_per_sm / launch_params.params.ctas_per_row; + if (Kernel_traits::CTAS_PER_ROW > 1) { + launch_params.barrier_bytes = 2 * launch_params.params.ctas_per_col * sizeof(index_t); + launch_params.workspace_bytes = launch_params.params.ctas_per_col * Kernel_traits::WARPS_M * + Kernel_traits::CTAS_PER_ROW * + sizeof(typename Kernel_traits::Stats::stats_t) * 2; + } + return; + } + + if (Kernel_traits::SMEM_BYTES_FWD >= 48 * 1024) { + NVTE_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, + Kernel_traits::SMEM_BYTES_FWD)); + } + auto stream = launch_params.stream; + auto ctas_per_col = launch_params.params.ctas_per_col; + auto ctas_per_row = launch_params.params.ctas_per_row; + + if (ctas_per_row == 1) { + kernel<<>>( + launch_params.params); + } else { + dim3 grid(ctas_per_row * ctas_per_col); + dim3 block(Kernel_traits::THREADS_PER_CTA); + void *params_ = reinterpret_cast(&launch_params.params); + cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)¶ms_, // NOLINT(*) + Kernel_traits::SMEM_BYTES_FWD, stream); + } +} + +template +void launch_general_(LaunchParams &launch_params, + const bool configure_params) { // NOLINT(*) + using Kernel_traits = Kernel_traits; + auto kernel = &ln_fwd_general_kernel; + 
auto ceil_div = [](int x, int y) -> int { return (x + y - 1) / y; }; + + // Configure kernel params + const int rows = launch_params.params.rows; + const int cols = launch_params.params.cols; + int ctas_per_col = launch_params.params.ctas_per_col; + int ctas_per_row = launch_params.params.ctas_per_row; + if (configure_params) { + int ctas_per_sm; + cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, 0); + const int max_ctas = launch_params.multiprocessorCount * ctas_per_sm; + ctas_per_row = ceil_div(cols, HIDDEN_SIZE); + ctas_per_col = std::min(ceil_div(rows, WARPS_M), max_ctas / ctas_per_row); + launch_params.params.ctas_per_row = ctas_per_row; + launch_params.params.ctas_per_col = ctas_per_col; + if (launch_params.params.ctas_per_row > 1) { + launch_params.barrier_bytes = 2 * ctas_per_col * sizeof(index_t); + launch_params.workspace_bytes = + (ctas_per_col * WARPS_M * ctas_per_row * sizeof(compute_t) * 2); + } + return; + } + + // Launch kernel + auto stream = launch_params.stream; + dim3 grid(ctas_per_row * ctas_per_col); + dim3 block(Kernel_traits::THREADS_PER_CTA); + if (ctas_per_row == 1) { + kernel<<>>(launch_params.params); + } else { + void *params_ = reinterpret_cast(&launch_params.params); + cudaLaunchCooperativeKernel(reinterpret_cast(kernel), grid, block, + reinterpret_cast(¶ms_), 0, stream); + } +} + +#define REGISTER_NORM_LAUNCHER(NORM_TYPE, NORM_STAGE, LAUNCH_TYPE, HIDDEN_SIZE, WTYPE, ITYPE, \ + OTYPE, CTYPE, ...) 
\ + namespace { \ + void \ + norm_##NORM_TYPE##_##NORM_STAGE##_##LAUNCH_TYPE##_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ + LaunchParams &launch_params, const bool configure_params) { \ + launch_##LAUNCH_TYPE##_( \ + launch_params, configure_params); \ + } \ + REGISTER_NORM_BASE( \ + NORM_TYPE, NORM_STAGE, LAUNCH_TYPE, HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, \ + norm_##NORM_TYPE##_##NORM_STAGE##_##LAUNCH_TYPE##_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE); \ + } // namespace + +// Create tuned launch function and register. Macro signature: +// HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG + +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 768, bf16, bf16, fp8e4m3, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 1024, bf16, bf16, fp8e4m3, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 1536, bf16, bf16, fp8e4m3, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 2048, bf16, bf16, fp8e4m3, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 2304, bf16, bf16, fp8e4m3, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 3072, bf16, bf16, fp8e4m3, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 3840, bf16, bf16, fp8e4m3, fp32, 1, 1, 4, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 4096, bf16, bf16, fp8e4m3, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 5120, bf16, bf16, fp8e4m3, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 6144, bf16, bf16, fp8e4m3, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 8192, bf16, bf16, fp8e4m3, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 10240, bf16, bf16, fp8e4m3, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 12288, bf16, bf16, fp8e4m3, fp32, 2, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 12800, 
bf16, bf16, fp8e4m3, fp32, 2, 1, 4, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 15360, bf16, bf16, fp8e4m3, fp32, 2, 1, 4, 8); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 16384, bf16, bf16, fp8e4m3, fp32, 2, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 18432, bf16, bf16, fp8e4m3, fp32, 2, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 20480, bf16, bf16, fp8e4m3, fp32, 2, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 24576, bf16, bf16, fp8e4m3, fp32, 2, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 25600, bf16, bf16, fp8e4m3, fp32, 2, 1, 4, 8); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 30720, bf16, bf16, fp8e4m3, fp32, 4, 1, 4, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 32768, bf16, bf16, fp8e4m3, fp32, 4, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 40960, bf16, bf16, fp8e4m3, fp32, 4, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 49152, bf16, bf16, fp8e4m3, fp32, 4, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 65536, bf16, bf16, fp8e4m3, fp32, 8, 1, 4, 16); + +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 768, fp16, fp16, fp8e4m3, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 1024, fp16, fp16, fp8e4m3, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 1536, fp16, fp16, fp8e4m3, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 2048, fp16, fp16, fp8e4m3, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 2304, fp16, fp16, fp8e4m3, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 3072, fp16, fp16, fp8e4m3, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 3840, fp16, fp16, fp8e4m3, fp32, 1, 1, 4, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 4096, fp16, fp16, fp8e4m3, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 5120, fp16, fp16, fp8e4m3, fp32, 1, 1, 
4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 6144, fp16, fp16, fp8e4m3, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 8192, fp16, fp16, fp8e4m3, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 10240, fp16, fp16, fp8e4m3, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 12288, fp16, fp16, fp8e4m3, fp32, 2, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 12800, fp16, fp16, fp8e4m3, fp32, 2, 1, 4, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 15360, fp16, fp16, fp8e4m3, fp32, 2, 1, 4, 8); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 16384, fp16, fp16, fp8e4m3, fp32, 2, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 18432, fp16, fp16, fp8e4m3, fp32, 2, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 20480, fp16, fp16, fp8e4m3, fp32, 2, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 24576, fp16, fp16, fp8e4m3, fp32, 2, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 25600, fp16, fp16, fp8e4m3, fp32, 2, 1, 4, 8); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 30720, fp16, fp16, fp8e4m3, fp32, 4, 1, 4, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 32768, fp16, fp16, fp8e4m3, fp32, 4, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 40960, fp16, fp16, fp8e4m3, fp32, 4, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 49152, fp16, fp16, fp8e4m3, fp32, 4, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 65536, fp16, fp16, fp8e4m3, fp32, 8, 1, 4, 16); + +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 768, fp32, fp32, fp8e4m3, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 1024, fp32, fp32, fp8e4m3, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 1536, fp32, fp32, fp8e4m3, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 2048, fp32, fp32, fp8e4m3, fp32, 1, 4, 1, 16); 
+REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 2304, fp32, fp32, fp8e4m3, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 3072, fp32, fp32, fp8e4m3, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 3840, fp32, fp32, fp8e4m3, fp32, 1, 1, 4, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 4096, fp32, fp32, fp8e4m3, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 5120, fp32, fp32, fp8e4m3, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 6144, fp32, fp32, fp8e4m3, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 8192, fp32, fp32, fp8e4m3, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 10240, fp32, fp32, fp8e4m3, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 12288, fp32, fp32, fp8e4m3, fp32, 2, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 12800, fp32, fp32, fp8e4m3, fp32, 2, 1, 4, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 15360, fp32, fp32, fp8e4m3, fp32, 2, 1, 4, 8); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 16384, fp32, fp32, fp8e4m3, fp32, 2, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 18432, fp32, fp32, fp8e4m3, fp32, 2, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 20480, fp32, fp32, fp8e4m3, fp32, 2, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 24576, fp32, fp32, fp8e4m3, fp32, 2, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 25600, fp32, fp32, fp8e4m3, fp32, 2, 1, 4, 8); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 30720, fp32, fp32, fp8e4m3, fp32, 4, 1, 4, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 32768, fp32, fp32, fp8e4m3, fp32, 4, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 40960, fp32, fp32, fp8e4m3, fp32, 4, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 49152, fp32, fp32, fp8e4m3, fp32, 4, 1, 4, 16); 
+REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 65536, fp32, fp32, fp8e4m3, fp32, 8, 1, 4, 16); + +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 768, fp32, fp32, fp32, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 768, fp16, fp16, fp16, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 768, fp32, fp32, fp16, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 768, bf16, bf16, bf16, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 768, fp32, fp32, bf16, fp32, 1, 4, 1, 16); + +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 1024, fp32, fp32, fp32, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 1024, fp16, fp16, fp16, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 1024, fp32, fp32, fp16, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 1024, bf16, bf16, bf16, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 1024, fp32, fp32, bf16, fp32, 1, 4, 1, 16); + +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 1536, fp32, fp32, fp32, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 1536, fp16, fp16, fp16, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 1536, fp32, fp32, fp16, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 1536, bf16, bf16, bf16, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 1536, fp32, fp32, bf16, fp32, 1, 4, 1, 16); + +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 2048, fp32, fp32, fp32, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 2048, fp16, fp16, fp16, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 2048, fp32, fp32, fp16, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 2048, bf16, bf16, bf16, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 2048, fp32, fp32, bf16, fp32, 1, 4, 1, 
16); + +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 2304, fp32, fp32, fp32, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 2304, fp16, fp16, fp16, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 2304, fp32, fp32, fp16, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 2304, bf16, bf16, bf16, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 2304, fp32, fp32, bf16, fp32, 1, 4, 1, 16); + +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 3072, fp32, fp32, fp32, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 3072, fp16, fp16, fp16, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 3072, fp32, fp32, fp16, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 3072, bf16, bf16, bf16, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 3072, fp32, fp32, bf16, fp32, 1, 1, 4, 16); + +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 3840, fp32, fp32, fp32, fp32, 1, 1, 4, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 3840, fp16, fp16, fp16, fp32, 1, 1, 4, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 3840, fp32, fp32, fp16, fp32, 1, 1, 4, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 3840, bf16, bf16, bf16, fp32, 1, 1, 4, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 3840, fp32, fp32, bf16, fp32, 1, 1, 4, 4); + +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 4096, fp32, fp32, fp32, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 4096, fp16, fp16, fp16, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 4096, fp32, fp32, fp16, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 4096, bf16, bf16, bf16, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 4096, fp32, fp32, bf16, fp32, 1, 1, 4, 16); + +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 5120, fp32, fp32, fp32, fp32, 1, 1, 4, 
16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 5120, fp16, fp16, fp16, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 5120, fp32, fp32, fp16, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 5120, bf16, bf16, bf16, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 5120, fp32, fp32, bf16, fp32, 1, 1, 4, 16); + +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 6144, fp32, fp32, fp32, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 6144, fp16, fp16, fp16, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 6144, fp32, fp32, fp16, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 6144, bf16, bf16, bf16, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 6144, fp32, fp32, bf16, fp32, 1, 1, 4, 16); + +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 8192, fp32, fp32, fp32, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 8192, fp16, fp16, fp16, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 8192, fp32, fp32, fp16, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 8192, bf16, bf16, bf16, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 8192, fp32, fp32, bf16, fp32, 1, 1, 4, 16); + +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 10240, fp32, fp32, fp32, fp32, 2, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 10240, fp16, fp16, fp16, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 10240, fp32, fp32, fp16, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 10240, bf16, bf16, bf16, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 10240, fp32, fp32, bf16, fp32, 1, 1, 4, 16); + +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 12288, fp32, fp32, fp32, fp32, 2, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 12288, fp16, fp16, fp16, 
fp32, 2, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 12288, fp32, fp32, fp16, fp32, 2, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 12288, bf16, bf16, bf16, fp32, 2, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 12288, fp32, fp32, bf16, fp32, 2, 1, 4, 16); + +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 12800, fp32, fp32, fp32, fp32, 2, 1, 4, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 12800, fp16, fp16, fp16, fp32, 2, 1, 4, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 12800, fp32, fp32, fp16, fp32, 2, 1, 4, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 12800, bf16, bf16, bf16, fp32, 2, 1, 4, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 12800, fp32, fp32, bf16, fp32, 2, 1, 4, 4); + +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 15360, fp32, fp32, fp32, fp32, 2, 1, 4, 8); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 15360, fp16, fp16, fp16, fp32, 2, 1, 4, 8); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 15360, fp32, fp32, fp16, fp32, 2, 1, 4, 8); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 15360, bf16, bf16, bf16, fp32, 2, 1, 4, 8); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 15360, fp32, fp32, bf16, fp32, 2, 1, 4, 8); + +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 16384, fp32, fp32, fp32, fp32, 2, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 16384, fp16, fp16, fp16, fp32, 2, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 16384, fp32, fp32, fp16, fp32, 2, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 16384, bf16, bf16, bf16, fp32, 2, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 16384, fp32, fp32, bf16, fp32, 2, 1, 4, 16); + +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 18432, fp32, fp32, fp32, fp32, 4, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 18432, fp16, fp16, fp16, fp32, 2, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 18432, 
fp32, fp32, fp16, fp32, 2, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 18432, bf16, bf16, bf16, fp32, 2, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 18432, fp32, fp32, bf16, fp32, 4, 1, 4, 16); + +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 20480, fp32, fp32, fp32, fp32, 2, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 20480, fp16, fp16, fp16, fp32, 2, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 20480, fp32, fp32, fp16, fp32, 2, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 20480, bf16, bf16, bf16, fp32, 2, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 20480, fp32, fp32, bf16, fp32, 2, 1, 4, 16); + +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 24576, fp32, fp32, fp32, fp32, 4, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 24576, fp16, fp16, fp16, fp32, 2, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 24576, fp32, fp32, fp16, fp32, 2, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 24576, bf16, bf16, bf16, fp32, 2, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 24576, fp32, fp32, bf16, fp32, 2, 1, 4, 16); + +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 25600, fp32, fp32, fp32, fp32, 4, 1, 4, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 25600, fp16, fp16, fp16, fp32, 2, 1, 4, 8); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 25600, fp32, fp32, fp16, fp32, 4, 1, 4, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 25600, bf16, bf16, bf16, fp32, 2, 1, 4, 8); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 25600, fp32, fp32, bf16, fp32, 4, 1, 4, 4); + +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 30720, fp32, fp32, fp32, fp32, 4, 1, 4, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 30720, fp16, fp16, fp16, fp32, 4, 1, 4, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 30720, fp32, fp32, fp16, fp32, 4, 1, 4, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, 
Forward, tuned, 30720, bf16, bf16, bf16, fp32, 4, 1, 4, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 30720, fp32, fp32, bf16, fp32, 4, 1, 4, 4); + +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 32768, fp32, fp32, fp32, fp32, 4, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 32768, fp16, fp16, fp16, fp32, 4, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 32768, fp32, fp32, fp16, fp32, 4, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 32768, bf16, bf16, bf16, fp32, 4, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 32768, fp32, fp32, bf16, fp32, 4, 1, 4, 16); + +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 40960, fp32, fp32, fp32, fp32, 4, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 40960, fp16, fp16, fp16, fp32, 4, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 40960, fp32, fp32, fp16, fp32, 4, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 40960, bf16, bf16, bf16, fp32, 4, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 40960, fp32, fp32, bf16, fp32, 4, 1, 4, 16); + +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 49152, fp32, fp32, fp32, fp32, 8, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 49152, fp16, fp16, fp16, fp32, 4, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 49152, fp32, fp32, fp16, fp32, 4, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 49152, bf16, bf16, bf16, fp32, 4, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 49152, fp32, fp32, bf16, fp32, 4, 1, 4, 16); + +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 65536, fp32, fp32, fp32, fp32, 8, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 65536, fp16, fp16, fp16, fp32, 8, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 65536, fp32, fp32, fp16, fp32, 8, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 65536, bf16, bf16, bf16, fp32, 8, 1, 4, 16); 
+REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 65536, fp32, fp32, bf16, fp32, 8, 1, 4, 16); + +// Create general launch function and register. Macro signature: +// HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, WARPS_M, WARPS_N, BYTES_PER_LDG + +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 128, bf16, bf16, fp8e4m3, fp32, 4, 1, 8); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 512, bf16, bf16, fp8e4m3, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 1024, bf16, bf16, fp8e4m3, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 2048, bf16, bf16, fp8e4m3, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 8192, bf16, bf16, fp8e4m3, fp32, 1, 4, 16); + +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 128, fp16, fp16, fp8e4m3, fp32, 4, 1, 8); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 512, fp16, fp16, fp8e4m3, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 1024, fp16, fp16, fp8e4m3, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 2048, fp16, fp16, fp8e4m3, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 8192, fp16, fp16, fp8e4m3, fp32, 1, 4, 16); + +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 128, fp32, fp32, fp8e4m3, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 512, fp32, fp32, fp8e4m3, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 1024, fp32, fp32, fp8e4m3, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 2048, fp32, fp32, fp8e4m3, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 8192, fp32, fp32, fp8e4m3, fp32, 1, 4, 16); + +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 128, fp32, fp32, fp32, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 128, fp16, fp16, fp16, fp32, 4, 1, 8); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 128, fp32, fp32, fp16, fp32, 4, 1, 8); 
+REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 128, bf16, bf16, bf16, fp32, 4, 1, 8); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 128, fp32, fp32, bf16, fp32, 4, 1, 8); + +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 512, fp32, fp32, fp32, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 512, fp16, fp16, fp16, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 512, fp32, fp32, fp16, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 512, bf16, bf16, bf16, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 512, fp32, fp32, bf16, fp32, 4, 1, 16); + +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 1024, fp32, fp32, fp32, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 1024, fp16, fp16, fp16, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 1024, fp32, fp32, fp16, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 1024, bf16, bf16, bf16, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 1024, fp32, fp32, bf16, fp32, 4, 1, 16); + +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 2048, fp32, fp32, fp32, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 2048, fp16, fp16, fp16, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 2048, fp32, fp32, fp16, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 2048, bf16, bf16, bf16, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 2048, fp32, fp32, bf16, fp32, 4, 1, 16); + +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 8192, fp32, fp32, fp32, fp32, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 8192, fp16, fp16, fp16, fp32, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 8192, fp32, fp32, fp16, fp32, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 8192, bf16, bf16, bf16, fp32, 1, 4, 16); 
+REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 8192, fp32, fp32, bf16, fp32, 1, 4, 16); diff --git a/transformer_engine/common/layer_norm/ln_fwd_kernels.cuh b/transformer_engine/common/normalization/layernorm/ln_fwd_kernels.cuh similarity index 97% rename from transformer_engine/common/layer_norm/ln_fwd_kernels.cuh rename to transformer_engine/common/normalization/layernorm/ln_fwd_kernels.cuh index bd3741d1d1..3ec5543c3a 100644 --- a/transformer_engine/common/layer_norm/ln_fwd_kernels.cuh +++ b/transformer_engine/common/normalization/layernorm/ln_fwd_kernels.cuh @@ -10,15 +10,16 @@ #include #include -#include "../utils.cuh" -#include "ln.h" +#include "../../utils.cuh" +#include "../common.h" namespace transformer_engine { -namespace layer_norm { +namespace normalization { using namespace transformer_engine; template -__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_tuned_kernel(FwdParams params) { +__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_tuned_kernel( + ForwardKernelParams params) { enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA }; enum { WARPS_N = Ktraits::WARPS_N }; enum { WARPS_M = Ktraits::WARPS_M }; @@ -92,8 +93,8 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_tuned_kernel( stats_t s = stats.compute(xf, rn); - compute_t mu = layer_norm::Get<0>::of(s); - compute_t m2 = layer_norm::Get<1>::of(s); + compute_t mu = Get<0>::of(s); + compute_t m2 = Get<1>::of(s); if (bidn == 0 && warp_n == 0 && lane == 0) { mu_ptr[row] = mu; @@ -150,7 +151,7 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_tuned_kernel( template __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_general_kernel( - FwdParams params) { + ForwardKernelParams params) { enum { LDGS = Ktraits::LDGS }; enum { NUM_ELTS = Ktraits::NUM_ELTS }; enum { WARPS_M = Ktraits::WARPS_M }; @@ -315,7 +316,7 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_general_kerne } } -} // namespace 
layer_norm +} // namespace normalization } // namespace transformer_engine #endif // TRANSFORMER_ENGINE_COMMON_LAYER_NORM_LN_FWD_KERNELS_CUH_ diff --git a/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp b/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp new file mode 100644 index 0000000000..f6e36ae3c9 --- /dev/null +++ b/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp @@ -0,0 +1,166 @@ +/************************************************************************* + * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include +#include +#include +#include + +#include "../../common.h" +#include "../common.h" +#include "transformer_engine/normalization.h" + +namespace transformer_engine { + +using namespace normalization; + +void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tensor *z, + Tensor *rsigma, Tensor *workspace, const int multiprocessorCount, + const bool zero_centered_gamma, cudaStream_t stream) { + NVTE_CHECK(x.data.shape.size() == 2); + + NVTE_CHECK(gamma.data.shape[0] == x.data.shape[1]); + NVTE_CHECK(epsilon >= 0.f); + + NVTE_CHECK(z->data.shape == x.data.shape); + + NVTE_CHECK(rsigma->data.shape == std::vector{x.data.shape[0]}); + NVTE_CHECK(rsigma->data.dtype == DType::kFloat32); + + if (!workspace->data.shape.empty()) { + CheckInputTensor(x, "x"); + CheckInputTensor(gamma, "gamma"); + + CheckOutputTensor(*z, "z"); + CheckOutputTensor(*rsigma, "rsigma"); + } + + Tensor empty; + + NVTE_Norm_Backend norm_backend; + bool is_aligned = true; + if (use_cudnn_norm_fwd()) { + // TODO: add check for GPU ARCH + norm_backend = NVTE_Norm_Backend::Cudnn; + } else { + norm_backend = NVTE_Norm_Backend::Te; + is_aligned = is_ptr_aligned(z->data.dptr, x.data.dptr, gamma.data.dptr, rsigma->data.dptr); + } + auto plan = 
NormalizationPlanRegistry::getInstance().getNormalizationPlan( + norm_backend, NVTE_Norm_Type::RMSNorm, NVTE_Norm_Stage::Forward, + gamma.data.dtype, // wtype + x.data.dtype, // itype + z->data.dtype, // otype + x.data.shape[0], // batch_size + x.data.shape[1], // hidden_size + multiprocessorCount, zero_centered_gamma, is_aligned); + + if (workspace->data.shape.empty()) { + workspace->data.shape = plan->getWorkspaceShape(); + workspace->data.dtype = DType::kByte; + return; + } else { + NVTE_CHECK(workspace->data.shape == plan->getWorkspaceShape()); + plan->execute(z, x.data.dptr, gamma.data.dptr, nullptr, nullptr, + reinterpret_cast(const_cast(&epsilon)), rsigma->data.dptr, + workspace->data.dptr, stream); + } + + return; +} + +void rmsnorm_bwd(const Tensor &dz, const Tensor &x, const Tensor &rsigma, const Tensor &gamma, + Tensor *dx, Tensor *dgamma, Tensor *workspace, const int multiprocessorCount, + const bool zero_centered_gamma, cudaStream_t stream) { + using namespace transformer_engine; + + NVTE_CHECK(dz.data.dtype == gamma.data.dtype); + NVTE_CHECK(rsigma.data.dtype == DType::kFloat32); + + NVTE_CHECK(x.data.shape.size() == 2); + NVTE_CHECK(dz.data.shape == x.data.shape); + + NVTE_CHECK(gamma.data.shape[0] == x.data.shape[1]); + + NVTE_CHECK(dx->data.shape == x.data.shape); + NVTE_CHECK(dx->data.dtype == x.data.dtype); + + NVTE_CHECK(dgamma->data.shape == gamma.data.shape); + NVTE_CHECK(dgamma->data.dtype == gamma.data.dtype); + + if (!workspace->data.shape.empty()) { + CheckInputTensor(dz, "dz"); + CheckInputTensor(x, "x"); + CheckInputTensor(rsigma, "rsigma"); + CheckInputTensor(gamma, "gamma"); + CheckOutputTensor(*dx, "dx"); + CheckOutputTensor(*dgamma, "dgamma"); + } + + Tensor empty; + + NVTE_Norm_Backend norm_backend; + bool is_aligned = true; + if (use_cudnn_norm_bwd()) { + // TODO: add check for GPU ARCH + norm_backend = NVTE_Norm_Backend::Cudnn; + } else { + norm_backend = NVTE_Norm_Backend::Te; + is_aligned = is_ptr_aligned(x.data.dptr, 
gamma.data.dptr, rsigma.data.dptr, dx->data.dptr, + dz.data.dptr, dgamma->data.dptr); + } + auto plan = NormalizationPlanRegistry::getInstance().getNormalizationPlan( + norm_backend, NVTE_Norm_Type::RMSNorm, NVTE_Norm_Stage::Backward, + gamma.data.dtype, // wtype + x.data.dtype, // itype + gamma.data.dtype, // otype + x.data.shape[0], // batch_size + x.data.shape[1], // hidden_size + multiprocessorCount, zero_centered_gamma, is_aligned); + + if (workspace->data.shape.empty()) { + workspace->data.shape = plan->getWorkspaceShape(); + workspace->data.dtype = DType::kByte; + return; + } else { + NVTE_CHECK(workspace->data.shape == plan->getWorkspaceShape()); + plan->execute(x.data.dptr, gamma.data.dptr, nullptr, rsigma.data.dptr, dx->data.dptr, + dz.data.dptr, nullptr, dgamma->data.dptr, workspace->data.dptr, stream); + } + return; +} + +} // namespace transformer_engine + +void nvte_rmsnorm_fwd(const NVTETensor x, // Nxhidden_size + const NVTETensor gamma, // hidden_size + const float epsilon, NVTETensor z, NVTETensor rsigma, NVTETensor workspace, + const int multiprocessorCount, const bool zero_centered_gamma, + cudaStream_t stream) { + NVTE_API_CALL(nvte_rmsnorm_fwd); + using namespace transformer_engine; + rmsnorm_fwd(*reinterpret_cast(x), *reinterpret_cast(gamma), + epsilon, reinterpret_cast(z), reinterpret_cast(rsigma), + reinterpret_cast(workspace), multiprocessorCount, zero_centered_gamma, + stream); +} + +void nvte_rmsnorm_bwd(const NVTETensor dz, // Nxhidden_size + const NVTETensor x, // Nxhidden_size + const NVTETensor rsigma, // N, FP32! 
+ const NVTETensor gamma, // hidden_size + NVTETensor dx, NVTETensor dgamma, NVTETensor workspace, + const int multiprocessorCount, const bool zero_centered_gamma, + cudaStream_t stream) { + NVTE_API_CALL(nvte_rmsnorm_bwd); + using namespace transformer_engine; + rmsnorm_bwd(*reinterpret_cast(dz), *reinterpret_cast(x), + *reinterpret_cast(rsigma), *reinterpret_cast(gamma), + reinterpret_cast(dx), reinterpret_cast(dgamma), + reinterpret_cast(workspace), multiprocessorCount, zero_centered_gamma, + stream); +} diff --git a/transformer_engine/common/rmsnorm/rmsnorm_bwd_kernels.cuh b/transformer_engine/common/normalization/rmsnorm/rmsnorm_bwd_kernels.cuh similarity index 97% rename from transformer_engine/common/rmsnorm/rmsnorm_bwd_kernels.cuh rename to transformer_engine/common/normalization/rmsnorm/rmsnorm_bwd_kernels.cuh index 92fd850baa..223ac7fd79 100644 --- a/transformer_engine/common/rmsnorm/rmsnorm_bwd_kernels.cuh +++ b/transformer_engine/common/normalization/rmsnorm/rmsnorm_bwd_kernels.cuh @@ -7,15 +7,15 @@ #ifndef TRANSFORMER_ENGINE_COMMON_RMSNORM_RMSNORM_BWD_KERNELS_CUH_ #define TRANSFORMER_ENGINE_COMMON_RMSNORM_RMSNORM_BWD_KERNELS_CUH_ -#include "../utils.cuh" +#include "../../utils.cuh" +#include "../common.h" namespace transformer_engine { -namespace rmsnorm { -using namespace transformer_engine; +namespace normalization { template __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_bwd_tuned_kernel( - BwdParams params) { + BackwardKernelParams params) { enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA }; enum { WARPS_M = Ktraits::WARPS_M }; enum { WARPS_N = Ktraits::WARPS_N }; @@ -172,7 +172,7 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_bwd_tuned_ke template __global__ __launch_bounds__(Kernel_traits::THREADS_PER_CTA) void rmsnorm_bwd_finalize_tuned_kernel( - BwdParams params) { + BackwardKernelParams params) { using compute_t = typename Kernel_traits::compute_t; using weight_t = typename Kernel_traits::weight_t; 
using index_t = typename Kernel_traits::index_t; @@ -276,7 +276,7 @@ __global__ __launch_bounds__(Kernel_traits::THREADS_PER_CTA) void rmsnorm_bwd_fi template __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_bwd_general_kernel( - BwdParams params) { + BackwardKernelParams params) { enum { LDGS = Ktraits::LDGS }; enum { NUM_ELTS = Ktraits::ELTS_PER_LDG }; enum { WARPS_M = Ktraits::WARPS_M }; @@ -430,8 +430,9 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_bwd_general_ template -__global__ __launch_bounds__( - WARPS_M *WARPS_N *THREADS_PER_WARP) void rmsnorm_bwd_finalize_general_kernel(BwdParams params) { +__global__ +__launch_bounds__(WARPS_M *WARPS_N *THREADS_PER_WARP) void rmsnorm_bwd_finalize_general_kernel( + BackwardKernelParams params) { enum { NUM_ELTS = BYTES_PER_LDG / sizeof(compute_t) }; using Wvec = Vec; using Cvec = Vec; @@ -474,7 +475,7 @@ __global__ __launch_bounds__( } } -} // namespace rmsnorm +} // namespace normalization } // namespace transformer_engine #endif // TRANSFORMER_ENGINE_COMMON_RMSNORM_RMSNORM_BWD_KERNELS_CUH_ diff --git a/transformer_engine/common/normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu b/transformer_engine/common/normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu new file mode 100644 index 0000000000..309075c1ec --- /dev/null +++ b/transformer_engine/common/normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu @@ -0,0 +1,206 @@ +/************************************************************************* + * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
+ ************************************************************************/ + +#include "../common.h" +#include "../kernel_traits.h" +#include "rmsnorm_bwd_kernels.cuh" + +using namespace transformer_engine::normalization; + +template +void launch_tuned_(LaunchParams &launch_params, + const bool configure_params) { // NOLINT(*) + using Kernel_traits = Kernel_traits; + auto kernel = &rmsnorm_bwd_tuned_kernel; + + if (configure_params) { + int ctas_per_sm; + cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES); + launch_params.params.ctas_per_row = CTAS_PER_ROW; + launch_params.params.ctas_per_col = + launch_params.multiprocessorCount * ctas_per_sm / launch_params.params.ctas_per_row; + if (Kernel_traits::CTAS_PER_ROW > 1) { + launch_params.barrier_bytes = 2 * launch_params.params.ctas_per_col * sizeof(index_t); + launch_params.workspace_bytes = launch_params.params.ctas_per_col * Kernel_traits::WARPS_M * + Kernel_traits::CTAS_PER_ROW * + sizeof(typename Kernel_traits::reduce_t) * 2; + } + launch_params.dgamma_part_bytes = + launch_params.params.ctas_per_col * launch_params.params.cols * sizeof(compute_t); + return; + } + + if (Kernel_traits::SMEM_BYTES >= 48 * 1024) { + NVTE_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, + Kernel_traits::SMEM_BYTES)); + } + auto stream = launch_params.stream; + auto ctas_per_col = launch_params.params.ctas_per_col; + auto ctas_per_row = launch_params.params.ctas_per_row; + + if (ctas_per_row == 1) { + kernel<<>>( + launch_params.params); + } else { + dim3 grid(ctas_per_row * ctas_per_col); + dim3 block(Kernel_traits::THREADS_PER_CTA); + void *params_ = reinterpret_cast(&launch_params.params); + cudaLaunchCooperativeKernel(reinterpret_cast(kernel), grid, block, + reinterpret_cast(¶ms_), Kernel_traits::SMEM_BYTES, + stream); + } + + using Kernel_traits_f = + Kernel_traits_finalize; + + auto kernel_f = 
&rmsnorm_bwd_finalize_tuned_kernel; + kernel_f<<>>( + launch_params.params); +} + +template +void launch_general_(LaunchParams &launch_params, + const bool configure_params) { // NOLINT(*) + auto ceil_div = [](int x, int y) -> int { return (x + y - 1) / y; }; + + // Instantiate kernel + using Kernel_traits = Kernel_traits; + auto kernel = &rmsnorm_bwd_general_kernel; + + // Configure kernel params + const int rows = launch_params.params.rows; + const int cols = launch_params.params.cols; + int ctas_per_col = launch_params.params.ctas_per_col; + int ctas_per_row = launch_params.params.ctas_per_row; + if (configure_params) { + int ctas_per_sm; + cudaOccupancyMaxActiveBlocksPerMultiprocessor(&ctas_per_sm, kernel, + Kernel_traits::THREADS_PER_CTA, 0); + const int max_ctas = launch_params.multiprocessorCount * ctas_per_sm; + ctas_per_row = ceil_div(cols, HIDDEN_SIZE); + ctas_per_col = std::min(ceil_div(rows, WARPS_M), max_ctas / ctas_per_row); + launch_params.params.ctas_per_row = ctas_per_row; + launch_params.params.ctas_per_col = ctas_per_col; + + if (launch_params.params.ctas_per_row > 1) { + launch_params.barrier_bytes = 2 * ctas_per_col * sizeof(index_t); + launch_params.workspace_bytes = + (ctas_per_col * WARPS_M * ctas_per_row * sizeof(typename Kernel_traits::reduce_t) * 2); + } + launch_params.dgamma_part_bytes = + launch_params.params.ctas_per_col * launch_params.params.cols * sizeof(compute_t); + return; + } + + // Launch kernel + auto stream = launch_params.stream; + dim3 grid(ctas_per_row * ctas_per_col); + dim3 block(Kernel_traits::THREADS_PER_CTA); + if (ctas_per_row == 1) { + kernel<<>>(launch_params.params); + } else { + void *params_ = reinterpret_cast(&launch_params.params); + cudaLaunchCooperativeKernel(reinterpret_cast(kernel), grid, block, + reinterpret_cast(¶ms_), 0, stream); + } + + // Launch finalization kernel + constexpr uint32_t WARPS_M_FINAL = 4; + constexpr uint32_t WARPS_N_FINAL = 1; + constexpr uint32_t ELTS_N_PER_CTA_FINAL = + 
(Kernel_traits::THREADS_PER_WARP * WARPS_N_FINAL * BYTES_PER_LDG_FINAL / sizeof(compute_t)); + auto kernel_final = + &rmsnorm_bwd_finalize_general_kernel; + dim3 block_final(Kernel_traits::THREADS_PER_WARP * WARPS_N_FINAL, WARPS_M_FINAL); + dim3 grid_final(ceil_div(cols, ELTS_N_PER_CTA_FINAL), 1); + kernel_final<<>>(launch_params.params); +} + +#define REGISTER_NORM_LAUNCHER(NORM_TYPE, NORM_STAGE, LAUNCH_TYPE, HIDDEN_SIZE, WTYPE, ITYPE, \ + OTYPE, CTYPE, ...) \ + namespace { \ + void \ + norm_##NORM_TYPE##_##NORM_STAGE##_##LAUNCH_TYPE##_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ + LaunchParams &launch_params, const bool configure_params) { \ + launch_##LAUNCH_TYPE##_( \ + launch_params, configure_params); \ + } \ + REGISTER_NORM_BASE( \ + NORM_TYPE, NORM_STAGE, LAUNCH_TYPE, HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, \ + norm_##NORM_TYPE##_##NORM_STAGE##_##LAUNCH_TYPE##_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE); \ + } // namespace + +// Create rmsnorm tuned launch function and register. Macro signature: +// HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, ... 
+// WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL + +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 512, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 512, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 512, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); + +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 768, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 768, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 768, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); + +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 1024, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 1024, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 1024, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); + +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 2048, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 2048, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 2048, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); + +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 4096, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 4096, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 4096, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); + +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 8192, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 8192, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 8192, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); + +// Create rmsnorm general launch function and register. Macro signature: +// HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, ... 
+// WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL + +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 128, fp32, fp32, fp32, fp32, 4, 1, 16, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 128, fp16, fp16, fp16, fp32, 4, 1, 8, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 128, fp16, fp32, fp16, fp32, 4, 1, 8, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 128, bf16, bf16, bf16, fp32, 4, 1, 8, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 128, bf16, fp32, bf16, fp32, 4, 1, 8, 4); + +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 512, fp32, fp32, fp32, fp32, 4, 1, 16, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 512, fp16, fp16, fp16, fp32, 4, 1, 16, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 512, fp16, fp32, fp16, fp32, 4, 1, 16, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 512, bf16, bf16, bf16, fp32, 4, 1, 16, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 512, bf16, fp32, bf16, fp32, 4, 1, 16, 4); + +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 1024, fp32, fp32, fp32, fp32, 4, 1, 16, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 1024, fp16, fp16, fp16, fp32, 4, 1, 16, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 1024, fp16, fp32, fp16, fp32, 4, 1, 16, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 1024, bf16, bf16, bf16, fp32, 4, 1, 16, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 1024, bf16, fp32, bf16, fp32, 4, 1, 16, 4); + +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 2048, fp32, fp32, fp32, fp32, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 2048, fp16, fp16, fp16, fp32, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 2048, fp16, fp32, fp16, fp32, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 2048, bf16, bf16, bf16, fp32, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 2048, bf16, fp32, bf16, fp32, 1, 4, 16, 4); + 
+REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 4096, fp32, fp32, fp32, fp32, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 4096, fp16, fp16, fp16, fp32, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 4096, fp16, fp32, fp16, fp32, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 4096, bf16, bf16, bf16, fp32, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 4096, bf16, fp32, bf16, fp32, 1, 4, 16, 4); diff --git a/transformer_engine/common/normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu b/transformer_engine/common/normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu new file mode 100644 index 0000000000..73634fc2dd --- /dev/null +++ b/transformer_engine/common/normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu @@ -0,0 +1,210 @@ +/************************************************************************* + * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
+ ************************************************************************/ + +#include "../common.h" +#include "../kernel_traits.h" +#include "rmsnorm_fwd_kernels.cuh" + +using namespace transformer_engine::normalization; + +template +void launch_tuned_(LaunchParams &launch_params, + const bool configure_params) { // NOLINT(*) + using Kernel_traits = Kernel_traits; + auto kernel = &rmsnorm_fwd_tuned_kernel; + + if (configure_params) { + int ctas_per_sm; + cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD); + launch_params.params.ctas_per_row = CTAS_PER_ROW; + launch_params.params.ctas_per_col = + launch_params.multiprocessorCount * ctas_per_sm / launch_params.params.ctas_per_row; + if (Kernel_traits::CTAS_PER_ROW > 1) { + launch_params.barrier_bytes = 2 * launch_params.params.ctas_per_col * sizeof(index_t); + launch_params.workspace_bytes = launch_params.params.ctas_per_col * Kernel_traits::WARPS_M * + Kernel_traits::CTAS_PER_ROW * + sizeof(typename Kernel_traits::Stats::stats_t) * 2; + } + return; + } + + if (Kernel_traits::SMEM_BYTES_FWD >= 48 * 1024) { + NVTE_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, + Kernel_traits::SMEM_BYTES_FWD)); + } + auto stream = launch_params.stream; + auto ctas_per_col = launch_params.params.ctas_per_col; + auto ctas_per_row = launch_params.params.ctas_per_row; + + if (ctas_per_row == 1) { + kernel<<>>( + launch_params.params); + } else { + dim3 grid(ctas_per_row * ctas_per_col); + dim3 block(Kernel_traits::THREADS_PER_CTA); + void *params_ = reinterpret_cast(&launch_params.params); + cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)&params_, // NOLINT(*) + Kernel_traits::SMEM_BYTES_FWD, stream); + } +} + +template +void launch_general_(LaunchParams &launch_params, + const bool configure_params) { // NOLINT(*) + using Kernel_traits = Kernel_traits; + auto kernel =
&rmsnorm_fwd_general_kernel; + auto ceil_div = [](int x, int y) -> int { return (x + y - 1) / y; }; + + // Configure kernel params + const int rows = launch_params.params.rows; + const int cols = launch_params.params.cols; + int ctas_per_col = launch_params.params.ctas_per_col; + int ctas_per_row = launch_params.params.ctas_per_row; + if (configure_params) { + int ctas_per_sm; + cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, 0); + const int max_ctas = launch_params.multiprocessorCount * ctas_per_sm; + ctas_per_row = ceil_div(cols, HIDDEN_SIZE); + ctas_per_col = std::min(ceil_div(rows, WARPS_M), max_ctas / ctas_per_row); + launch_params.params.ctas_per_row = ctas_per_row; + launch_params.params.ctas_per_col = ctas_per_col; + if (launch_params.params.ctas_per_row > 1) { + launch_params.barrier_bytes = 2 * ctas_per_col * sizeof(index_t); + launch_params.workspace_bytes = + (ctas_per_col * WARPS_M * ctas_per_row * sizeof(compute_t) * 2); + } + return; + } + + // Launch kernel + auto stream = launch_params.stream; + dim3 grid(ctas_per_row * ctas_per_col); + dim3 block(Kernel_traits::THREADS_PER_CTA); + if (ctas_per_row == 1) { + kernel<<>>(launch_params.params); + } else { + void *params_ = reinterpret_cast(&launch_params.params); + cudaLaunchCooperativeKernel(reinterpret_cast(kernel), grid, block, + reinterpret_cast(&params_), 0, stream); + } +} + +#define REGISTER_NORM_LAUNCHER(NORM_TYPE, NORM_STAGE, LAUNCH_TYPE, HIDDEN_SIZE, WTYPE, ITYPE, \ + OTYPE, CTYPE, ...)
\ + namespace { \ + void \ + norm_##NORM_TYPE##_##NORM_STAGE##_##LAUNCH_TYPE##_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ + LaunchParams &launch_params, const bool configure_params) { \ + launch_##LAUNCH_TYPE##_( \ + launch_params, configure_params); \ + } \ + REGISTER_NORM_BASE( \ + NORM_TYPE, NORM_STAGE, LAUNCH_TYPE, HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, \ + norm_##NORM_TYPE##_##NORM_STAGE##_##LAUNCH_TYPE##_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE); \ + } // namespace + +// Create rmsnorm tuned launch function and register. Macro signature: +// HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG + +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 512, bf16, bf16, fp8e4m3, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 512, fp16, fp16, fp8e4m3, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 512, fp32, fp32, fp8e4m3, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 512, fp32, fp32, fp32, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 512, fp16, fp16, fp16, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 512, bf16, bf16, bf16, fp32, 1, 4, 1, 16); + +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 768, bf16, bf16, fp8e4m3, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 768, fp16, fp16, fp8e4m3, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 768, fp32, fp32, fp8e4m3, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 768, fp32, fp32, fp32, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 768, fp16, fp16, fp16, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 768, bf16, bf16, bf16, fp32, 1, 4, 1, 16); + +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 1024, bf16, bf16, fp8e4m3, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 1024, fp16, fp16, fp8e4m3, fp32, 1, 4, 1, 16); 
+REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 1024, fp32, fp32, fp8e4m3, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 1024, fp32, fp32, fp32, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 1024, fp16, fp16, fp16, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 1024, bf16, bf16, bf16, fp32, 1, 4, 1, 16); + +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 2048, bf16, bf16, fp8e4m3, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 2048, fp16, fp16, fp8e4m3, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 2048, fp32, fp32, fp8e4m3, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 2048, fp32, fp32, fp32, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 2048, fp16, fp16, fp16, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 2048, bf16, bf16, bf16, fp32, 1, 4, 1, 16); + +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 4096, bf16, bf16, fp8e4m3, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 4096, fp16, fp16, fp8e4m3, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 4096, fp32, fp32, fp8e4m3, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 4096, fp32, fp32, fp32, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 4096, fp16, fp16, fp16, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 4096, bf16, bf16, bf16, fp32, 1, 1, 4, 16); + +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 8192, bf16, bf16, fp8e4m3, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 8192, fp16, fp16, fp8e4m3, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 8192, fp32, fp32, fp8e4m3, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 8192, fp32, fp32, fp32, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 8192, fp16, fp16, fp16, fp32, 1, 1, 4, 16); 
+REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 8192, bf16, bf16, bf16, fp32, 1, 1, 4, 16); + +// Create rmsnorm general launch function and register. Macro signature: +// HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, WARPS_M, WARPS_N, BYTES_PER_LDG + +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 128, bf16, bf16, fp8e4m3, fp32, 4, 1, 8); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 512, bf16, bf16, fp8e4m3, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 1024, bf16, bf16, fp8e4m3, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 2048, bf16, bf16, fp8e4m3, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 8192, bf16, bf16, fp8e4m3, fp32, 1, 4, 16); + +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 128, fp16, fp16, fp8e4m3, fp32, 4, 1, 8); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 512, fp16, fp16, fp8e4m3, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 1024, fp16, fp16, fp8e4m3, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 2048, fp16, fp16, fp8e4m3, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 8192, fp16, fp16, fp8e4m3, fp32, 1, 4, 16); + +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 128, fp32, fp32, fp8e4m3, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 512, fp32, fp32, fp8e4m3, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 1024, fp32, fp32, fp8e4m3, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 2048, fp32, fp32, fp8e4m3, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 8192, fp32, fp32, fp8e4m3, fp32, 1, 4, 16); + +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 128, fp32, fp32, fp32, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 128, fp16, fp16, fp16, fp32, 4, 1, 8); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 128, fp32, fp32, fp16, fp32, 4, 1, 8); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 128, 
bf16, bf16, bf16, fp32, 4, 1, 8); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 128, fp32, fp32, bf16, fp32, 4, 1, 8); + +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 512, fp32, fp32, fp32, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 512, fp16, fp16, fp16, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 512, fp32, fp32, fp16, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 512, bf16, bf16, bf16, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 512, fp32, fp32, bf16, fp32, 4, 1, 16); + +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 1024, fp32, fp32, fp32, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 1024, fp16, fp16, fp16, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 1024, fp32, fp32, fp16, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 1024, bf16, bf16, bf16, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 1024, fp32, fp32, bf16, fp32, 4, 1, 16); + +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 2048, fp32, fp32, fp32, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 2048, fp16, fp16, fp16, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 2048, fp32, fp32, fp16, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 2048, bf16, bf16, bf16, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 2048, fp32, fp32, bf16, fp32, 4, 1, 16); + +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 8192, fp32, fp32, fp32, fp32, 1, 4, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 8192, fp16, fp16, fp16, fp32, 1, 4, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 8192, fp32, fp32, fp16, fp32, 1, 4, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 8192, bf16, bf16, bf16, fp32, 1, 4, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 8192, fp32, fp32, bf16, fp32, 1, 4, 16); diff --git 
a/transformer_engine/common/rmsnorm/rmsnorm_fwd_kernels.cuh b/transformer_engine/common/normalization/rmsnorm/rmsnorm_fwd_kernels.cuh similarity index 98% rename from transformer_engine/common/rmsnorm/rmsnorm_fwd_kernels.cuh rename to transformer_engine/common/normalization/rmsnorm/rmsnorm_fwd_kernels.cuh index c435ae3744..5965ffdc5d 100644 --- a/transformer_engine/common/rmsnorm/rmsnorm_fwd_kernels.cuh +++ b/transformer_engine/common/normalization/rmsnorm/rmsnorm_fwd_kernels.cuh @@ -10,15 +10,15 @@ #include #include -#include "../utils.cuh" +#include "../../utils.cuh" +#include "../common.h" namespace transformer_engine { -namespace rmsnorm { -using namespace transformer_engine; +namespace normalization { template __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_tuned_kernel( - FwdParams params) { + ForwardKernelParams params) { enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA }; enum { WARPS_N = Ktraits::WARPS_N }; enum { WARPS_M = Ktraits::WARPS_M }; @@ -143,7 +143,7 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_tuned_ke template __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_general_kernel( - FwdParams params) { + ForwardKernelParams params) { enum { LDGS = Ktraits::LDGS }; enum { NUM_ELTS = Ktraits::NUM_ELTS }; enum { WARPS_M = Ktraits::WARPS_M }; @@ -291,7 +291,7 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_general_ } } -} // namespace rmsnorm +} // namespace normalization } // namespace transformer_engine #endif // TRANSFORMER_ENGINE_COMMON_RMSNORM_RMSNORM_FWD_KERNELS_CUH_ diff --git a/transformer_engine/common/rmsnorm/rmsnorm.h b/transformer_engine/common/rmsnorm/rmsnorm.h deleted file mode 100644 index 8b4e1cf24e..0000000000 --- a/transformer_engine/common/rmsnorm/rmsnorm.h +++ /dev/null @@ -1,89 +0,0 @@ -/************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. - * - * See LICENSE for license information. - ************************************************************************/ - -#ifndef TRANSFORMER_ENGINE_COMMON_RMSNORM_RMSNORM_H_ -#define TRANSFORMER_ENGINE_COMMON_RMSNORM_RMSNORM_H_ - -#include - -#include -#include -#include -#include -#include - -#include "../common.h" -#include "../layer_norm/ln.h" - -namespace transformer_engine { -namespace rmsnorm { - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct LaunchParams : public transformer_engine::layer_norm::LaunchParams {}; -struct FwdParams : public transformer_engine::layer_norm::FwdParams {}; -struct BwdParams : public transformer_engine::layer_norm::BwdParams {}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -using FwdFunction = std::function &, const bool)>; -using BwdFunction = std::function &, const bool)>; -using FunctionKey = uint64_t; -using FwdTunedRegistry = std::unordered_map; -using BwdTunedRegistry = std::unordered_map; -using FwdGeneralRegistry = std::unordered_map>; -using BwdGeneralRegistry = std::unordered_map>; - -extern FwdTunedRegistry FWD_TUNED_FUNCS; -extern BwdTunedRegistry BWD_TUNED_FUNCS; -extern FwdGeneralRegistry FWD_GENERAL_FUNCS; -extern BwdGeneralRegistry BWD_GENERAL_FUNCS; - -////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct FwdTunedRegistrar { - explicit FwdTunedRegistrar(FwdFunction f) { - uint64_t key = layer_norm::Types2Key::get(HIDDEN_SIZE); - FWD_TUNED_FUNCS.insert({key, f}); - } -}; - -////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct FwdGeneralRegistrar { - explicit FwdGeneralRegistrar(FwdFunction f) { - uint64_t key = layer_norm::Types2Key::get(0); - FWD_GENERAL_FUNCS[key].insert({HIDDEN_SIZE, f}); - } -}; - 
-////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct BwdTunedRegistrar { - explicit BwdTunedRegistrar(BwdFunction f) { - uint64_t key = layer_norm::Types2Key::get(HIDDEN_SIZE); - BWD_TUNED_FUNCS.insert({key, f}); - } -}; - -////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct BwdGeneralRegistrar { - explicit BwdGeneralRegistrar(BwdFunction f) { - uint64_t key = layer_norm::Types2Key::get(0); - BWD_GENERAL_FUNCS[key].insert({HIDDEN_SIZE, f}); - } -}; - -} // namespace rmsnorm -} // namespace transformer_engine - -#endif // TRANSFORMER_ENGINE_COMMON_RMSNORM_RMSNORM_H_ diff --git a/transformer_engine/common/rmsnorm/rmsnorm_api.cpp b/transformer_engine/common/rmsnorm/rmsnorm_api.cpp deleted file mode 100644 index 9b143b2f85..0000000000 --- a/transformer_engine/common/rmsnorm/rmsnorm_api.cpp +++ /dev/null @@ -1,387 +0,0 @@ -/************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * - * See LICENSE for license information. 
- ************************************************************************/ - -#include - -#include -#include - -#include "../common.h" -#include "rmsnorm.h" -#include "transformer_engine/rmsnorm.h" - -/* - -Supported Type combinations: - -input compute weights output -======================================= -fp32 fp32 fp32 fp32 -fp16 fp32 fp16 fp16 -bf16 fp32 bf16 bf16 -fp32 fp32 fp32 fp16 -fp32 fp32 fp32 bf16 -fp32 fp32 fp32 fp8 -fp16 fp32 fp16 fp8 -bf16 fp32 bf16 fp8 - -Remarks: -Input type = Weight type -Compute always in FP32 - -*/ - -namespace transformer_engine { - -namespace layer_norm { -uint64_t get_key(DType wtype, DType itype, DType otype, DType ctype, uint64_t hidden_size); -} - -namespace rmsnorm { - -using namespace transformer_engine; - -FwdTunedRegistry FWD_TUNED_FUNCS; -BwdTunedRegistry BWD_TUNED_FUNCS; -FwdGeneralRegistry FWD_GENERAL_FUNCS; -BwdGeneralRegistry BWD_GENERAL_FUNCS; - -FwdFunction &get_fwd_launcher(DType wtype, DType itype, DType otype, DType ctype, - const layer_norm::FwdParams &params) { - // Look for tuned kernel - auto tuned_key = layer_norm::get_key(wtype, itype, otype, ctype, params.cols); - auto is_aligned = [](const void *ptr) -> bool { - // Assume vectorized memory accesses are <=16B - return reinterpret_cast(ptr) % 16 == 0; - }; - if (params.rows % 4 == 0 && is_aligned(params.x) && is_aligned(params.rs) && - is_aligned(params.gamma) && is_aligned(params.z) && FWD_TUNED_FUNCS.count(tuned_key) > 0) { - return FWD_TUNED_FUNCS.at(tuned_key); - } - - // Pick general kernel - auto general_key = layer_norm::get_key(wtype, itype, otype, ctype, 0); - if (FWD_GENERAL_FUNCS.count(general_key) == 0) { - NVTE_ERROR("FWD: Unsupported types."); - } - auto &general_func_map = FWD_GENERAL_FUNCS.at(general_key); - auto func_iter = general_func_map.lower_bound(params.cols); - if (func_iter == general_func_map.end()) { - // Hidden size is too big, need to use multi-CTA - return general_func_map.rbegin()->second; - } else { - return
func_iter->second; - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -BwdFunction &get_bwd_launcher(DType wtype, DType itype, DType otype, DType ctype, - const layer_norm::BwdParams &params) { - // Look for tuned kernel - auto tuned_key = layer_norm::get_key(wtype, itype, otype, ctype, params.cols); - auto is_aligned = [](const void *ptr) -> bool { - // Assume vectorized memory accesses are <=16B - return reinterpret_cast(ptr) % 16 == 0; - }; - if (params.rows % 4 == 0 && is_aligned(params.x) && is_aligned(params.rs) && - is_aligned(params.gamma) && is_aligned(params.dz) && is_aligned(params.dx) && - is_aligned(params.dgamma) && is_aligned(params.dgamma_part) && - BWD_TUNED_FUNCS.count(tuned_key) > 0) { - return BWD_TUNED_FUNCS.at(tuned_key); - } - - // Pick general kernel - auto general_key = layer_norm::get_key(wtype, itype, otype, ctype, 0); - if (BWD_GENERAL_FUNCS.count(general_key) == 0) { - NVTE_ERROR("BWD: Unsupported types."); - } - auto &general_func_map = BWD_GENERAL_FUNCS.at(general_key); - auto func_iter = general_func_map.lower_bound(params.cols); - if (func_iter == general_func_map.end()) { - // Hidden size is too big, need to use multi-CTA - return general_func_map.rbegin()->second; - } else { - return func_iter->second; - } -} - -// //////////////////////////////////////////////////////////////////////////////////////////////////// - -inline size_t product(const std::vector &shape) { - return std::accumulate(shape.cbegin(), shape.cend(), size_t{1}, std::multiplies<>()); -} - -} // namespace rmsnorm - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tensor *z, - Tensor *rsigma, cudaStream_t stream, const int multiprocessorCount, - Tensor *workspace, Tensor *barrier, const bool zero_centered_gamma) { - auto itype = x.data.dtype; - auto wtype = gamma.data.dtype; -
auto otype = z->data.dtype; - const bool fp8_out = is_fp8_dtype(otype); - auto ctype = DType::kFloat32; - - NVTE_CHECK(x.data.shape.size() == 2); - - const size_t rows = x.data.shape[0]; - const size_t cols = x.data.shape[1]; - const auto hidden_size = gamma.data.shape[0]; - - NVTE_CHECK(hidden_size == cols); - NVTE_CHECK(epsilon >= 0.f); - - NVTE_CHECK(z->data.shape == x.data.shape); - - NVTE_CHECK(rsigma->data.shape == std::vector{rows}); - NVTE_CHECK(rsigma->data.dtype == ctype); - - rmsnorm::LaunchParams launch_params; - - launch_params.multiprocessorCount = multiprocessorCount; - launch_params.stream = stream; - - // Set the kernel runtime parameters. - rmsnorm::FwdParams &params = launch_params.params; - params.rows = rows; - params.cols = cols; - params.x = x.data.dptr; - params.mu = nullptr; - params.rs = rsigma->data.dptr; - params.gamma = gamma.data.dptr; - params.beta = nullptr; - params.z = z->data.dptr; - params.epsilon = epsilon; - params.amax = z->amax.dptr; - params.scale = z->scale.dptr; - params.scale_inv = z->scale_inv.dptr; - params.fp8_out = fp8_out; - params.zero_centered_gamma = zero_centered_gamma; - - // Request the kernel launcher. - auto launcher = rmsnorm::get_fwd_launcher(wtype, itype, otype, ctype, params); - - // Query the kernel-specific launch parameters.
- launcher(launch_params, true); - if (launch_params.workspace_bytes == 0) { - launch_params.workspace_bytes = 1; - } - - if (workspace->data.dptr == nullptr) { - NVTE_CHECK(barrier->data.dptr == nullptr); - - workspace->data.dtype = DType::kByte; - workspace->data.shape = {launch_params.workspace_bytes}; - - barrier->data.dtype = DType::kInt32; - barrier->data.shape = {launch_params.barrier_size}; - - return; - } else { - NVTE_CHECK(workspace->data.dtype == DType::kByte); - NVTE_CHECK(workspace->data.shape == std::vector{launch_params.workspace_bytes}); - } - - if (launch_params.barrier_size > 0) { - NVTE_CHECK(barrier->data.dptr != nullptr); - NVTE_CHECK(barrier->data.dtype == DType::kInt32); - NVTE_CHECK(barrier->data.shape == std::vector{launch_params.barrier_size}); - } - - // Tensor checks are delayed here in order to recover workspace sizes with null data - CheckInputTensor(x, "x"); - CheckInputTensor(gamma, "gamma"); - - CheckOutputTensor(*z, "z"); - CheckOutputTensor(*rsigma, "rsigma"); - - if (launch_params.barrier_size > 0) { - params.workspace = workspace->data.dptr; - params.barrier = reinterpret_cast(barrier->data.dptr); - } - - // Clear buffers - if (params.fp8_out) { - cudaMemsetAsync(params.amax, 0, rmsnorm::product(z->amax.shape) * typeToSize(z->amax.dtype), - stream); - } - if (launch_params.barrier_size > 0) { - cudaMemsetAsync(params.barrier, 0, - rmsnorm::product(barrier->data.shape) * typeToSize(barrier->data.dtype), - stream); - } - - // Launch the kernel. 
- launcher(launch_params, false); - - return; -} - -void rmsnorm_bwd(const Tensor &dz, const Tensor &x, const Tensor &rsigma, const Tensor &gamma, - Tensor *dx, Tensor *dgamma, Tensor *dgamma_part, cudaStream_t stream, - const int multiprocessorCount, Tensor *workspace, Tensor *barrier, - const bool zero_centered_gamma) { - using namespace transformer_engine; - - auto itype = x.data.dtype; - auto wtype = gamma.data.dtype; - auto otype = wtype; - auto ctype = DType::kFloat32; - - NVTE_CHECK(dz.data.dtype == otype); - NVTE_CHECK(rsigma.data.dtype == ctype); - - NVTE_CHECK(x.data.shape.size() == 2); - NVTE_CHECK(dz.data.shape == x.data.shape); - - const auto rows = x.data.shape[0]; - const auto cols = x.data.shape[1]; - const auto hidden_size = gamma.data.shape[0]; - - NVTE_CHECK(gamma.data.shape[0] == cols); - - NVTE_CHECK(dx->data.shape == x.data.shape); - NVTE_CHECK(dx->data.dtype == x.data.dtype); - - NVTE_CHECK(dgamma->data.shape == gamma.data.shape); - NVTE_CHECK(dgamma->data.dtype == gamma.data.dtype); - - rmsnorm::LaunchParams launch_params; - launch_params.stream = stream; - launch_params.multiprocessorCount = multiprocessorCount; - - // Set the kernel runtime parameters. - rmsnorm::BwdParams &params = launch_params.params; - params.rows = rows; - params.cols = cols; - params.x = x.data.dptr; - params.mu = nullptr; - params.rs = rsigma.data.dptr; - params.gamma = gamma.data.dptr; - params.dz = dz.data.dptr; - params.dx = dx->data.dptr; - params.dbeta = nullptr; - params.dgamma = dgamma->data.dptr; - params.dbeta_part = nullptr; - params.dgamma_part = dgamma_part->data.dptr; - params.zero_centered_gamma = zero_centered_gamma; - - // Request the kernel launcher. - auto launcher = rmsnorm::get_bwd_launcher(wtype, itype, otype, ctype, params); - - // Query the kernel-specific launch parameters.
- launcher(launch_params, true); - - // Populate shape and dtypes for FW to allocate memory - if (dgamma_part->data.dptr == nullptr) { - dgamma_part->data.dtype = ctype; - dgamma_part->data.shape = {static_cast(launch_params.params.ctas_per_col), - hidden_size}; - - workspace->data.dtype = DType::kByte; - workspace->data.shape = {launch_params.workspace_bytes}; - - barrier->data.dtype = DType::kInt32; - barrier->data.shape = {launch_params.barrier_size}; - - return; - } else { - auto pdw_shape = - std::vector{static_cast(launch_params.params.ctas_per_col), hidden_size}; - NVTE_CHECK(dgamma_part->data.dtype == ctype); - NVTE_CHECK(dgamma_part->data.shape == pdw_shape); - } - - if (launch_params.barrier_size > 0) { - NVTE_CHECK(barrier->data.dptr != nullptr); - NVTE_CHECK(barrier->data.dtype == DType::kInt32); - NVTE_CHECK(barrier->data.shape == std::vector{launch_params.barrier_size}); - } - - if (launch_params.workspace_bytes > 0) { - NVTE_CHECK(workspace->data.dptr != nullptr); - NVTE_CHECK(workspace->data.dtype == DType::kByte); - NVTE_CHECK(workspace->data.shape == std::vector{launch_params.workspace_bytes}); - } - - // Tensor checks are delayed here in order to recover workspace sizes with null data - CheckInputTensor(dz, "dz"); - CheckInputTensor(x, "x"); - CheckInputTensor(rsigma, "rsigma"); - CheckInputTensor(gamma, "gamma"); - CheckOutputTensor(*dx, "dx"); - CheckOutputTensor(*dgamma, "dgamma"); - - if (launch_params.barrier_size > 0) { - params.workspace = workspace->data.dptr; - params.barrier = reinterpret_cast(barrier->data.dptr); - cudaMemsetAsync(params.barrier, 0, - rmsnorm::product(barrier->data.shape) * typeToSize(barrier->data.dtype), - stream); - } - - // Launch the kernel. 
- launcher(launch_params, false); -} - -} // namespace transformer_engine - -void nvte_rmsnorm_fwd(const NVTETensor x, // Nxhidden_size - const NVTETensor gamma, // hidden_size - const float epsilon, NVTETensor z, NVTETensor rsigma, cudaStream_t stream, - const int multiprocessorCount, NVTETensor workspace, NVTETensor barrier) { - NVTE_API_CALL(nvte_rmsnorm_fwd); - using namespace transformer_engine; - rmsnorm_fwd(*reinterpret_cast(x), *reinterpret_cast(gamma), - epsilon, reinterpret_cast(z), reinterpret_cast(rsigma), stream, - multiprocessorCount, reinterpret_cast(workspace), - reinterpret_cast(barrier), false); -} - -void nvte_rmsnorm_bwd(const NVTETensor dz, // Nxhidden_size - const NVTETensor x, // Nxhidden_size - const NVTETensor rsigma, // N, FP32! - const NVTETensor gamma, // hidden_size - NVTETensor dx, NVTETensor dgamma, NVTETensor dgamma_part, cudaStream_t stream, - const int multiprocessorCount, NVTETensor workspace, NVTETensor barrier) { - NVTE_API_CALL(nvte_rmsnorm_bwd); - using namespace transformer_engine; - rmsnorm_bwd(*reinterpret_cast(dz), *reinterpret_cast(x), - *reinterpret_cast(rsigma), *reinterpret_cast(gamma), - reinterpret_cast(dx), reinterpret_cast(dgamma), - reinterpret_cast(dgamma_part), stream, multiprocessorCount, - reinterpret_cast(workspace), reinterpret_cast(barrier), false); -} - -void nvte_rmsnorm1p_fwd(const NVTETensor x, // Nxhidden_size - const NVTETensor gamma, // hidden_size - const float epsilon, NVTETensor z, NVTETensor rsigma, cudaStream_t stream, - const int multiprocessorCount, NVTETensor workspace, NVTETensor barrier) { - NVTE_API_CALL(nvte_rmsnorm1p_fwd); - using namespace transformer_engine; - rmsnorm_fwd(*reinterpret_cast(x), *reinterpret_cast(gamma), - epsilon, reinterpret_cast(z), reinterpret_cast(rsigma), stream, - multiprocessorCount, reinterpret_cast(workspace), - reinterpret_cast(barrier), true); -} - -void nvte_rmsnorm1p_bwd(const NVTETensor dz, // Nxhidden_size - const NVTETensor x, // Nxhidden_size - const 
NVTETensor rsigma, // N, FP32! - const NVTETensor gamma, // hidden_size - NVTETensor dx, NVTETensor dgamma, NVTETensor dgamma_part, - cudaStream_t stream, const int multiprocessorCount, NVTETensor workspace, - NVTETensor barrier) { - NVTE_API_CALL(nvte_rmsnorm1p_bwd); - using namespace transformer_engine; - rmsnorm_bwd(*reinterpret_cast(dz), *reinterpret_cast(x), - *reinterpret_cast(rsigma), *reinterpret_cast(gamma), - reinterpret_cast(dx), reinterpret_cast(dgamma), - reinterpret_cast(dgamma_part), stream, multiprocessorCount, - reinterpret_cast(workspace), reinterpret_cast(barrier), true); -} diff --git a/transformer_engine/common/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu b/transformer_engine/common/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu deleted file mode 100644 index 3215a6a9d4..0000000000 --- a/transformer_engine/common/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu +++ /dev/null @@ -1,220 +0,0 @@ -/************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * - * See LICENSE for license information. 
- ************************************************************************/ - -#include "rmsnorm.h" -#include "rmsnorm_bwd_kernels.cuh" -#include "rmsnorm_kernel_traits.h" - -using namespace transformer_engine::rmsnorm; - -template -void launch_tuned_(LaunchParams &launch_params, - const bool configure_params) { // NOLINT(*) - using Kernel_traits = - rmsnorm::Kernel_traits; - auto kernel = &rmsnorm_bwd_tuned_kernel; - - if (configure_params) { - int ctas_per_sm; - cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES); - launch_params.params.ctas_per_row = CTAS_PER_ROW; - launch_params.params.ctas_per_col = - launch_params.multiprocessorCount * ctas_per_sm / launch_params.params.ctas_per_row; - launch_params.barrier_size = 0; - launch_params.workspace_bytes = 0; - if (Kernel_traits::CTAS_PER_ROW > 1) { - launch_params.barrier_size = 2 * launch_params.params.ctas_per_col; - launch_params.workspace_bytes = launch_params.params.ctas_per_col * Kernel_traits::WARPS_M * - Kernel_traits::CTAS_PER_ROW * - sizeof(typename Kernel_traits::reduce_t) * 2; - } - return; - } - - if (Kernel_traits::SMEM_BYTES >= 48 * 1024) { - NVTE_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, - Kernel_traits::SMEM_BYTES)); - } - auto stream = launch_params.stream; - auto ctas_per_col = launch_params.params.ctas_per_col; - auto ctas_per_row = launch_params.params.ctas_per_row; - - if (ctas_per_row == 1) { - kernel<<>>( - launch_params.params); - } else { - dim3 grid(ctas_per_row * ctas_per_col); - dim3 block(Kernel_traits::THREADS_PER_CTA); - void *params_ = reinterpret_cast(&launch_params.params); - cudaLaunchCooperativeKernel(reinterpret_cast(kernel), grid, block, - reinterpret_cast(&params_), Kernel_traits::SMEM_BYTES, - stream); - } - - using Kernel_traits_f = - Kernel_traits_finalize; - - auto kernel_f = &rmsnorm::rmsnorm_bwd_finalize_tuned_kernel; - kernel_f<<>>(
- launch_params.params); -} - -template -void launch_general_(LaunchParams &launch_params, - const bool configure_params) { // NOLINT(*) - auto ceil_div = [](int x, int y) -> int { return (x + y - 1) / y; }; - - // Instantiate kernel - using Kernel_traits = Kernel_traits; - auto kernel = &rmsnorm_bwd_general_kernel; - - // Configure kernel params - const int rows = launch_params.params.rows; - const int cols = launch_params.params.cols; - int ctas_per_col = launch_params.params.ctas_per_col; - int ctas_per_row = launch_params.params.ctas_per_row; - if (configure_params) { - int ctas_per_sm; - cudaOccupancyMaxActiveBlocksPerMultiprocessor(&ctas_per_sm, kernel, - Kernel_traits::THREADS_PER_CTA, 0); - const int max_ctas = launch_params.multiprocessorCount * ctas_per_sm; - ctas_per_row = ceil_div(cols, HIDDEN_SIZE); - ctas_per_col = std::min(ceil_div(rows, WARPS_M), max_ctas / ctas_per_row); - launch_params.params.ctas_per_row = ctas_per_row; - launch_params.params.ctas_per_col = ctas_per_col; - - launch_params.barrier_size = 0; - launch_params.workspace_bytes = 0; - if (launch_params.params.ctas_per_row > 1) { - launch_params.barrier_size = 2 * ctas_per_col; - launch_params.workspace_bytes = - (ctas_per_col * WARPS_M * ctas_per_row * sizeof(typename Kernel_traits::reduce_t) * 2); - } - return; - } - - // Launch kernel - auto stream = launch_params.stream; - dim3 grid(ctas_per_row * ctas_per_col); - dim3 block(Kernel_traits::THREADS_PER_CTA); - if (ctas_per_row == 1) { - kernel<<>>(launch_params.params); - } else { - void *params_ = reinterpret_cast(&launch_params.params); - cudaLaunchCooperativeKernel(reinterpret_cast(kernel), grid, block, - reinterpret_cast(¶ms_), 0, stream); - } - - // Launch finalization kernel - constexpr uint32_t WARPS_M_FINAL = 4; - constexpr uint32_t WARPS_N_FINAL = 1; - constexpr uint32_t ELTS_N_PER_CTA_FINAL = - (Kernel_traits::THREADS_PER_WARP * WARPS_N_FINAL * BYTES_PER_LDG_FINAL / sizeof(compute_t)); - auto kernel_final = - 
&rmsnorm_bwd_finalize_general_kernel; - dim3 block_final(Kernel_traits::THREADS_PER_WARP * WARPS_N_FINAL, WARPS_M_FINAL); - dim3 grid_final(ceil_div(cols, ELTS_N_PER_CTA_FINAL), 1); - kernel_final<<>>(launch_params.params); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#define REGISTER_BWD_TUNED_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, \ - WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINALIZE) \ - void rmsnorm_bwd_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ - LaunchParams &launch_params, const bool configure_params) { \ - launch_tuned_(launch_params, \ - configure_params); \ - } \ - static BwdTunedRegistrar \ - reg_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ - rmsnorm_bwd_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE) - -#define REGISTER_BWD_GENERAL_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, WARPS_M, WARPS_N, \ - BYTES_PER_LDG, BYTES_PER_LDG_FINALIZE) \ - void rmsnorm_bwd_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ - LaunchParams &launch_params, const bool configure_params) { \ - launch_general_(launch_params, configure_params); \ - } \ - static BwdGeneralRegistrar \ - reg_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ - rmsnorm_bwd_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE) - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// Create rmsnorm tuned launch function and register. Macro signature: -// HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, ... 
-// WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL - -REGISTER_BWD_TUNED_LAUNCHER(512, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(512, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(512, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); - -REGISTER_BWD_TUNED_LAUNCHER(768, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(768, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(768, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); - -REGISTER_BWD_TUNED_LAUNCHER(1024, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(1024, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(1024, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); - -REGISTER_BWD_TUNED_LAUNCHER(2048, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(2048, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(2048, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); - -REGISTER_BWD_TUNED_LAUNCHER(4096, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(4096, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(4096, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); - -REGISTER_BWD_TUNED_LAUNCHER(8192, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(8192, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(8192, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); - -// Create rmsnorm general launch function and register. Macro signature: -// HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, ... 
-// WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL - -REGISTER_BWD_GENERAL_LAUNCHER(128, fp32, fp32, fp32, fp32, 4, 1, 16, 4); -REGISTER_BWD_GENERAL_LAUNCHER(128, fp16, fp16, fp16, fp32, 4, 1, 8, 4); -REGISTER_BWD_GENERAL_LAUNCHER(128, fp16, fp32, fp16, fp32, 4, 1, 8, 4); -REGISTER_BWD_GENERAL_LAUNCHER(128, bf16, bf16, bf16, fp32, 4, 1, 8, 4); -REGISTER_BWD_GENERAL_LAUNCHER(128, bf16, fp32, bf16, fp32, 4, 1, 8, 4); - -REGISTER_BWD_GENERAL_LAUNCHER(512, fp32, fp32, fp32, fp32, 4, 1, 16, 4); -REGISTER_BWD_GENERAL_LAUNCHER(512, fp16, fp16, fp16, fp32, 4, 1, 16, 4); -REGISTER_BWD_GENERAL_LAUNCHER(512, fp16, fp32, fp16, fp32, 4, 1, 16, 4); -REGISTER_BWD_GENERAL_LAUNCHER(512, bf16, bf16, bf16, fp32, 4, 1, 16, 4); -REGISTER_BWD_GENERAL_LAUNCHER(512, bf16, fp32, bf16, fp32, 4, 1, 16, 4); - -REGISTER_BWD_GENERAL_LAUNCHER(1024, fp32, fp32, fp32, fp32, 4, 1, 16, 4); -REGISTER_BWD_GENERAL_LAUNCHER(1024, fp16, fp16, fp16, fp32, 4, 1, 16, 4); -REGISTER_BWD_GENERAL_LAUNCHER(1024, fp16, fp32, fp16, fp32, 4, 1, 16, 4); -REGISTER_BWD_GENERAL_LAUNCHER(1024, bf16, bf16, bf16, fp32, 4, 1, 16, 4); -REGISTER_BWD_GENERAL_LAUNCHER(1024, bf16, fp32, bf16, fp32, 4, 1, 16, 4); - -REGISTER_BWD_GENERAL_LAUNCHER(2048, fp32, fp32, fp32, fp32, 1, 4, 16, 4); -REGISTER_BWD_GENERAL_LAUNCHER(2048, fp16, fp16, fp16, fp32, 1, 4, 16, 4); -REGISTER_BWD_GENERAL_LAUNCHER(2048, fp16, fp32, fp16, fp32, 1, 4, 16, 4); -REGISTER_BWD_GENERAL_LAUNCHER(2048, bf16, bf16, bf16, fp32, 1, 4, 16, 4); -REGISTER_BWD_GENERAL_LAUNCHER(2048, bf16, fp32, bf16, fp32, 1, 4, 16, 4); - -REGISTER_BWD_GENERAL_LAUNCHER(4096, fp32, fp32, fp32, fp32, 1, 4, 16, 4); -REGISTER_BWD_GENERAL_LAUNCHER(4096, fp16, fp16, fp16, fp32, 1, 4, 16, 4); -REGISTER_BWD_GENERAL_LAUNCHER(4096, fp16, fp32, fp16, fp32, 1, 4, 16, 4); -REGISTER_BWD_GENERAL_LAUNCHER(4096, bf16, bf16, bf16, fp32, 1, 4, 16, 4); -REGISTER_BWD_GENERAL_LAUNCHER(4096, bf16, fp32, bf16, fp32, 1, 4, 16, 4); diff --git 
a/transformer_engine/common/rmsnorm/rmsnorm_fwd_cuda_kernel.cu b/transformer_engine/common/rmsnorm/rmsnorm_fwd_cuda_kernel.cu deleted file mode 100644 index 3c8e121540..0000000000 --- a/transformer_engine/common/rmsnorm/rmsnorm_fwd_cuda_kernel.cu +++ /dev/null @@ -1,227 +0,0 @@ -/************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * - * See LICENSE for license information. - ************************************************************************/ - -#include "rmsnorm.h" -#include "rmsnorm_fwd_kernels.cuh" -#include "rmsnorm_kernel_traits.h" - -using namespace transformer_engine::rmsnorm; - -template -void launch_tuned_(LaunchParams &launch_params, - const bool configure_params) { // NOLINT(*) - using Kernel_traits = Kernel_traits; - auto kernel = &rmsnorm_fwd_tuned_kernel; - - if (configure_params) { - int ctas_per_sm; - cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD); - launch_params.params.ctas_per_row = CTAS_PER_ROW; - launch_params.params.ctas_per_col = - launch_params.multiprocessorCount * ctas_per_sm / launch_params.params.ctas_per_row; - launch_params.barrier_size = 0; - launch_params.workspace_bytes = 0; - if (Kernel_traits::CTAS_PER_ROW > 1) { - launch_params.barrier_size = 2 * launch_params.params.ctas_per_col; - launch_params.workspace_bytes = launch_params.params.ctas_per_col * Kernel_traits::WARPS_M * - Kernel_traits::CTAS_PER_ROW * - sizeof(typename Kernel_traits::Stats::stats_t) * 2; - } - return; - } - - if (Kernel_traits::SMEM_BYTES_FWD >= 48 * 1024) { - NVTE_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, - Kernel_traits::SMEM_BYTES_FWD)); - } - auto stream = launch_params.stream; - auto ctas_per_col = launch_params.params.ctas_per_col; - auto ctas_per_row = launch_params.params.ctas_per_row; - - if 
(ctas_per_row == 1) { - kernel<<>>( - launch_params.params); - } else { - dim3 grid(ctas_per_row * ctas_per_col); - dim3 block(Kernel_traits::THREADS_PER_CTA); - void *params_ = reinterpret_cast(&launch_params.params); - cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)¶ms_, // NOLINT(*) - Kernel_traits::SMEM_BYTES_FWD, stream); - } -} - -template -void launch_general_(LaunchParams &launch_params, - const bool configure_params) { // NOLINT(*) - using Kernel_traits = Kernel_traits; - auto kernel = &rmsnorm_fwd_general_kernel; - auto ceil_div = [](int x, int y) -> int { return (x + y - 1) / y; }; - - // Configure kernel params - const int rows = launch_params.params.rows; - const int cols = launch_params.params.cols; - int ctas_per_col = launch_params.params.ctas_per_col; - int ctas_per_row = launch_params.params.ctas_per_row; - if (configure_params) { - int ctas_per_sm; - cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, 0); - const int max_ctas = launch_params.multiprocessorCount * ctas_per_sm; - ctas_per_row = ceil_div(cols, HIDDEN_SIZE); - ctas_per_col = std::min(ceil_div(rows, WARPS_M), max_ctas / ctas_per_row); - launch_params.params.ctas_per_row = ctas_per_row; - launch_params.params.ctas_per_col = ctas_per_col; - - launch_params.barrier_size = 0; - launch_params.workspace_bytes = 0; - if (launch_params.params.ctas_per_row > 1) { - launch_params.barrier_size = 2 * ctas_per_col; - launch_params.workspace_bytes = - (ctas_per_col * WARPS_M * ctas_per_row * sizeof(compute_t) * 2); - } - return; - } - - // Launch kernel - auto stream = launch_params.stream; - dim3 grid(ctas_per_row * ctas_per_col); - dim3 block(Kernel_traits::THREADS_PER_CTA); - if (ctas_per_row == 1) { - kernel<<>>(launch_params.params); - } else { - void *params_ = reinterpret_cast(&launch_params.params); - cudaLaunchCooperativeKernel(reinterpret_cast(kernel), grid, block, - reinterpret_cast(¶ms_), 0, 
stream); - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#define REGISTER_FWD_TUNED_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, \ - WARPS_M, WARPS_N, BYTES_PER_LDG) \ - void rmsnorm_fwd_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ - LaunchParams &launch_params, const bool configure_params) { \ - launch_tuned_(launch_params, configure_params); \ - } \ - static FwdTunedRegistrar \ - reg_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ - rmsnorm_fwd_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE) - -#define REGISTER_FWD_GENERAL_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, WARPS_M, WARPS_N, \ - BYTES_PER_LDG) \ - void rmsnorm_fwd_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ - LaunchParams &launch_params, const bool configure_params) { \ - launch_general_(launch_params, configure_params); \ - } \ - static FwdGeneralRegistrar \ - reg_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ - rmsnorm_fwd_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE) - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// Create rmsnorm tuned launch function and register. 
Macro signature: -// HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG - -REGISTER_FWD_TUNED_LAUNCHER(512, bf16, bf16, fp8e4m3, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(512, fp16, fp16, fp8e4m3, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(512, fp32, fp32, fp8e4m3, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(512, fp32, fp32, fp32, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(512, fp16, fp16, fp16, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(512, bf16, bf16, bf16, fp32, 1, 4, 1, 16); - -REGISTER_FWD_TUNED_LAUNCHER(768, bf16, bf16, fp8e4m3, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(768, fp16, fp16, fp8e4m3, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(768, fp32, fp32, fp8e4m3, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(768, fp32, fp32, fp32, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(768, fp16, fp16, fp16, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(768, bf16, bf16, bf16, fp32, 1, 4, 1, 16); - -REGISTER_FWD_TUNED_LAUNCHER(1024, bf16, bf16, fp8e4m3, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(1024, fp16, fp16, fp8e4m3, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(1024, fp32, fp32, fp8e4m3, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(1024, fp32, fp32, fp32, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(1024, fp16, fp16, fp16, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(1024, bf16, bf16, bf16, fp32, 1, 4, 1, 16); - -REGISTER_FWD_TUNED_LAUNCHER(2048, bf16, bf16, fp8e4m3, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(2048, fp16, fp16, fp8e4m3, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(2048, fp32, fp32, fp8e4m3, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(2048, fp32, fp32, fp32, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(2048, fp16, fp16, fp16, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(2048, bf16, bf16, bf16, fp32, 1, 4, 1, 16); - -REGISTER_FWD_TUNED_LAUNCHER(4096, bf16, bf16, fp8e4m3, fp32, 1, 1, 4, 16); 
-REGISTER_FWD_TUNED_LAUNCHER(4096, fp16, fp16, fp8e4m3, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(4096, fp32, fp32, fp8e4m3, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(4096, fp32, fp32, fp32, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(4096, fp16, fp16, fp16, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(4096, bf16, bf16, bf16, fp32, 1, 1, 4, 16); - -REGISTER_FWD_TUNED_LAUNCHER(8192, bf16, bf16, fp8e4m3, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(8192, fp16, fp16, fp8e4m3, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(8192, fp32, fp32, fp8e4m3, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(8192, fp32, fp32, fp32, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(8192, fp16, fp16, fp16, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(8192, bf16, bf16, bf16, fp32, 1, 1, 4, 16); - -// Create rmsnorm general launch function and register. Macro signature: -// HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, WARPS_M, WARPS_N, BYTES_PER_LDG - -REGISTER_FWD_GENERAL_LAUNCHER(128, bf16, bf16, fp8e4m3, fp32, 4, 1, 8); -REGISTER_FWD_GENERAL_LAUNCHER(512, bf16, bf16, fp8e4m3, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(1024, bf16, bf16, fp8e4m3, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(2048, bf16, bf16, fp8e4m3, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(8192, bf16, bf16, fp8e4m3, fp32, 1, 4, 16); - -REGISTER_FWD_GENERAL_LAUNCHER(128, fp16, fp16, fp8e4m3, fp32, 4, 1, 8); -REGISTER_FWD_GENERAL_LAUNCHER(512, fp16, fp16, fp8e4m3, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(1024, fp16, fp16, fp8e4m3, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(2048, fp16, fp16, fp8e4m3, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(8192, fp16, fp16, fp8e4m3, fp32, 1, 4, 16); - -REGISTER_FWD_GENERAL_LAUNCHER(128, fp32, fp32, fp8e4m3, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(512, fp32, fp32, fp8e4m3, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(1024, fp32, fp32, fp8e4m3, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(2048, fp32, fp32, 
fp8e4m3, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(8192, fp32, fp32, fp8e4m3, fp32, 1, 4, 16); - -REGISTER_FWD_GENERAL_LAUNCHER(128, fp32, fp32, fp32, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(128, fp16, fp16, fp16, fp32, 4, 1, 8); -REGISTER_FWD_GENERAL_LAUNCHER(128, fp32, fp32, fp16, fp32, 4, 1, 8); -REGISTER_FWD_GENERAL_LAUNCHER(128, bf16, bf16, bf16, fp32, 4, 1, 8); -REGISTER_FWD_GENERAL_LAUNCHER(128, fp32, fp32, bf16, fp32, 4, 1, 8); - -REGISTER_FWD_GENERAL_LAUNCHER(512, fp32, fp32, fp32, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(512, fp16, fp16, fp16, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(512, fp32, fp32, fp16, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(512, bf16, bf16, bf16, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(512, fp32, fp32, bf16, fp32, 4, 1, 16); - -REGISTER_FWD_GENERAL_LAUNCHER(1024, fp32, fp32, fp32, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(1024, fp16, fp16, fp16, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(1024, fp32, fp32, fp16, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(1024, bf16, bf16, bf16, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(1024, fp32, fp32, bf16, fp32, 4, 1, 16); - -REGISTER_FWD_GENERAL_LAUNCHER(2048, fp32, fp32, fp32, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(2048, fp16, fp16, fp16, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(2048, fp32, fp32, fp16, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(2048, bf16, bf16, bf16, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(2048, fp32, fp32, bf16, fp32, 4, 1, 16); - -REGISTER_FWD_GENERAL_LAUNCHER(8192, fp32, fp32, fp32, fp32, 1, 4, 16); -REGISTER_FWD_GENERAL_LAUNCHER(8192, fp16, fp16, fp16, fp32, 1, 4, 16); -REGISTER_FWD_GENERAL_LAUNCHER(8192, fp32, fp32, fp16, fp32, 1, 4, 16); -REGISTER_FWD_GENERAL_LAUNCHER(8192, bf16, bf16, bf16, fp32, 1, 4, 16); -REGISTER_FWD_GENERAL_LAUNCHER(8192, fp32, fp32, bf16, fp32, 1, 4, 16); diff --git a/transformer_engine/common/rmsnorm/rmsnorm_kernel_traits.h 
b/transformer_engine/common/rmsnorm/rmsnorm_kernel_traits.h deleted file mode 100644 index 26d7da6400..0000000000 --- a/transformer_engine/common/rmsnorm/rmsnorm_kernel_traits.h +++ /dev/null @@ -1,42 +0,0 @@ -/************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * - * See LICENSE for license information. - ************************************************************************/ - -#ifndef TRANSFORMER_ENGINE_COMMON_RMSNORM_RMSNORM_KERNEL_TRAITS_H_ -#define TRANSFORMER_ENGINE_COMMON_RMSNORM_RMSNORM_KERNEL_TRAITS_H_ - -#include "../common.h" -#include "../layer_norm/ln_kernel_traits.h" -#include "../utils.cuh" - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -namespace transformer_engine { -namespace rmsnorm { - -template < - uint32_t HIDDEN_SIZE_, typename weight_t_, typename input_t_, typename output_t_, - typename compute_t_, typename index_t_, uint32_t THREADS_PER_CTA_, uint32_t BYTES_PER_LDG_, - typename Base = - layer_norm::Kernel_traits_finalize > -struct Kernel_traits_finalize : public Base {}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template > -struct Kernel_traits : public Base {}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace rmsnorm -} // namespace transformer_engine - -#endif // TRANSFORMER_ENGINE_COMMON_RMSNORM_RMSNORM_KERNEL_TRAITS_H_ diff --git a/transformer_engine/common/util/string.h b/transformer_engine/common/util/string.h index c0a2aa1077..3b0db02809 100644 --- a/transformer_engine/common/util/string.h +++ b/transformer_engine/common/util/string.h @@ -13,15 +13,34 @@ namespace transformer_engine { -/*! 
\brief Convert to C-style or C++-style string */ +inline const std::string &to_string_like(const std::string &val) noexcept { return val; } + +constexpr const char *to_string_like(const char *val) noexcept { return val; } + +/* \brief Convert arithmetic type to string */ template ::value>::type> inline std::string to_string_like(const T &val) { return std::to_string(val); } -inline const std::string &to_string_like(const std::string &val) noexcept { return val; } - -constexpr const char *to_string_like(const char *val) noexcept { return val; } +/* \brief Convert container to string */ +template ::value>::type, + typename = decltype(std::declval().begin())> +inline std::string to_string_like(const T &container) { + std::string str; + str.reserve(1024); // Assume strings are <1 KB + str += "("; + bool first = true; + for (const auto &val : container) { + if (!first) { + str += ","; + } + str += to_string_like(val); + first = false; + } + str += ")"; + return str; +} /*! \brief Convert arguments to strings and concatenate */ template diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index 6c4f518189..6591861057 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -380,8 +380,6 @@ def lowering( *bias_batch_shape, bias_heads, _, _ = bias_aval.shape bias_batch = reduce(operator.mul, bias_batch_shape) - wkspace_aval = ctx.avals_out[-1] - if is_ffi_enabled(): name = "te_fused_attn_forward_ffi" out = ffi.ffi_lowering(name)( @@ -433,6 +431,8 @@ def lowering( ] args = CustomCallArgsWrapper(out_types, operands, operand_shapes) + wkspace_aval = ctx.avals_out[-1] + opaque = transformer_engine_jax.pack_fused_attn_descriptor( input_batch, bias_batch, @@ -725,28 +725,6 @@ def lowering( """ Fused attention bwd lowering rules """ - operands = [ - q, - k, - v, - bias, - softmax_aux, - rng_state, - output, - doutput, - q_cu_seqlen, - kv_cu_seqlen, - 
q_seq_offsets, - k_seq_offsets, - ] - operand_shapes = map(lambda x: x.type.shape, operands) - out_types = [ - ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype)) - for output in ctx.avals_out - ] - - args = CustomCallArgsWrapper(out_types, operands, operand_shapes) - q_aval, k_aval, v_aval, bias_aval, *_ = ctx.avals_in batch_shape, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, head_dim = ( @@ -761,33 +739,90 @@ def lowering( *bias_batch_shape, bias_heads, _, _ = bias_aval.shape bias_batch = reduce(operator.mul, bias_batch_shape) - wkspace_aval = ctx.avals_out[-1] + if is_ffi_enabled(): + name = "te_fused_attn_backward_ffi" + out = ffi.ffi_lowering(name)( + ctx, + q, + k, + v, + bias, + softmax_aux, + rng_state, + output, + doutput, + q_cu_seqlen, + kv_cu_seqlen, + q_seq_offsets, + k_seq_offsets, + input_batch=input_batch, + bias_batch=bias_batch, + q_max_seqlen=q_max_seqlen, + kv_max_seqlen=kv_max_seqlen, + attn_heads=attn_heads, + num_gqa_groups=num_gqa_groups, + bias_heads=bias_heads, + head_dim=head_dim, + max_segments_per_seq=config.max_segments_per_seq, + scaling_factor=float(config.scaling_factor), + dropout_probability=float(config.dropout_probability), + bias_type=int(config.attn_bias_type), + mask_type=int(config.attn_mask_type), + qkv_layout=int(config.qkv_layout), + is_training=config.is_training, + deterministic=not FusedAttnHelper.is_non_deterministic_allowed(), + window_size_left=config.window_size[0], + window_size_right=config.window_size[1], + ) + else: + operands = [ + q, + k, + v, + bias, + softmax_aux, + rng_state, + output, + doutput, + q_cu_seqlen, + kv_cu_seqlen, + q_seq_offsets, + k_seq_offsets, + ] + operand_shapes = map(lambda x: x.type.shape, operands) + out_types = [ + ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype)) + for output in ctx.avals_out + ] + args = CustomCallArgsWrapper(out_types, operands, operand_shapes) + + wkspace_aval = ctx.avals_out[-1] - opaque = 
transformer_engine_jax.pack_fused_attn_descriptor( - input_batch, - bias_batch, - q_max_seqlen, - kv_max_seqlen, - attn_heads, - num_gqa_groups, - bias_heads, - head_dim, - config.max_segments_per_seq, - wkspace_aval.size, - config.scaling_factor, - config.dropout_probability, - config.attn_bias_type, - config.attn_mask_type, - config.qkv_layout, - jax_dtype_to_te_dtype(q_aval.dtype), - jax_dtype_to_te_dtype(wkspace_aval.dtype), - config.is_training, - not FusedAttnHelper.is_non_deterministic_allowed(), - config.window_size[0], - config.window_size[1], - ) + opaque = transformer_engine_jax.pack_fused_attn_descriptor( + input_batch, + bias_batch, + q_max_seqlen, + kv_max_seqlen, + attn_heads, + num_gqa_groups, + bias_heads, + head_dim, + config.max_segments_per_seq, + wkspace_aval.size, + config.scaling_factor, + config.dropout_probability, + config.attn_bias_type, + config.attn_mask_type, + config.qkv_layout, + jax_dtype_to_te_dtype(q_aval.dtype), + jax_dtype_to_te_dtype(wkspace_aval.dtype), + config.is_training, + not FusedAttnHelper.is_non_deterministic_allowed(), + config.window_size[0], + config.window_size[1], + ) - out = custom_caller(FusedAttnBwdPrimitive.name, args, opaque, has_side_effect=False) + out = custom_caller(FusedAttnBwdPrimitive.name, args, opaque, has_side_effect=False) return out diff --git a/transformer_engine/jax/cpp_extensions/normalization.py b/transformer_engine/jax/cpp_extensions/normalization.py index fd6cc09de9..69d7962b62 100644 --- a/transformer_engine/jax/cpp_extensions/normalization.py +++ b/transformer_engine/jax/cpp_extensions/normalization.py @@ -16,7 +16,6 @@ from jax.extend import ffi from transformer_engine import transformer_engine_jax -from transformer_engine.transformer_engine_jax import DType as TEDType from .base import BasePrimitive, register_primitive from .custom_call import custom_caller, CustomCallArgsWrapper @@ -82,7 +81,7 @@ def abstract(x_aval, gamma_aval, beta_aval, **kwargs): hidden_size = gamma_aval.size assert 
x_aval.size % hidden_size == 0 - wkspace_info, barrier_info = transformer_engine_jax.get_layernorm_fwd_workspace_sizes( + (wkspace_info,) = transformer_engine_jax.get_layernorm_fwd_workspace_sizes( x_aval.size // hidden_size, # batch size hidden_size, jax_dtype_to_te_dtype(x_aval.dtype), # in te_dtype @@ -96,18 +95,15 @@ def abstract(x_aval, gamma_aval, beta_aval, **kwargs): wkspace_aval = out_aval.update( shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1]) ) - barrier_aval = out_aval.update( - shape=barrier_info[0], dtype=te_dtype_to_jax_dtype(barrier_info[1]) - ) - return out_aval, mu_aval, rsigma_aval, wkspace_aval, barrier_aval + return out_aval, mu_aval, rsigma_aval, wkspace_aval @staticmethod def outer_abstract(*args, **kwargs): """ LayerNorm fwd outer primitive abstract """ - out_aval, mu_aval, rsigma_aval, _, _ = LayerNormFwdPrimitive.abstract(*args, **kwargs) + out_aval, mu_aval, rsigma_aval, _ = LayerNormFwdPrimitive.abstract(*args, **kwargs) return out_aval, mu_aval, rsigma_aval @staticmethod @@ -151,7 +147,7 @@ def lowering(ctx, x, gamma, beta, *, zero_centered_gamma, epsilon): batch_shape = out_shape[:-1] batch_size = reduce(operator.mul, x_shape) // hidden_size - wkspace_aval, barrier_aval = ctx.avals_out[-2:] + wkspace_aval = ctx.avals_out[-1] out_types = [ ir.RankedTensorType.get(out_shape, output_type), @@ -160,9 +156,6 @@ def lowering(ctx, x, gamma, beta, *, zero_centered_gamma, epsilon): ir.RankedTensorType.get( wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype) ), - ir.RankedTensorType.get( - barrier_aval.shape, jax_dtype_to_ir_dtype(barrier_aval.dtype) - ), ] operands = [x, gamma, beta] operand_shapes = [x_shape, g_shape, b_shape] @@ -174,15 +167,9 @@ def lowering(ctx, x, gamma, beta, *, zero_centered_gamma, epsilon): batch_size, hidden_size, wkspace_aval.size, - barrier_aval.size, - (0,), # no dgamma_part in FWD pass - (0,), # no dbeta_part in BWD pass jax_dtype_to_te_dtype(x_aval.dtype), 
jax_dtype_to_te_dtype(gamma_aval.dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype), - jax_dtype_to_te_dtype(barrier_aval.dtype), - TEDType.kByte, # dummy dgamma_part te_dtype - TEDType.kByte, # dummy dbeta_part te_dtype zero_centered_gamma, epsilon, sm_margin, @@ -198,7 +185,7 @@ def impl(x, gamma, beta, zero_centered_gamma, epsilon): to describe implementation """ assert LayerNormFwdPrimitive.inner_primitive is not None - out, mu, rsigma, _, _ = LayerNormFwdPrimitive.inner_primitive.bind( + out, mu, rsigma, _ = LayerNormFwdPrimitive.inner_primitive.bind( x, gamma, beta, zero_centered_gamma=zero_centered_gamma, epsilon=epsilon ) return out, mu, rsigma @@ -377,39 +364,25 @@ def abstract(dz_aval, x_aval, mu_aval, rsigma_aval, gamma_aval, **kwargs): dx_aval = core.raise_to_shaped(dz_aval) dgamma_aval = dbeta_aval = core.raise_to_shaped(gamma_aval) - wkspace_info, barrier_info, dgamma_part_info, dbeta_part_info = ( - transformer_engine_jax.get_layernorm_bwd_workspace_sizes( - x_aval.size // gamma_aval.size, # batch size - gamma_aval.size, # hidden size - jax_dtype_to_te_dtype(x_aval.dtype), # input te_dtype - jax_dtype_to_te_dtype(gamma_aval.dtype), # weight te_dtype - True, - kwargs["zero_centered_gamma"], - kwargs["epsilon"], - get_backward_sm_margin(), - ) + (wkspace_info,) = transformer_engine_jax.get_layernorm_bwd_workspace_sizes( + x_aval.size // gamma_aval.size, # batch size + gamma_aval.size, # hidden size + jax_dtype_to_te_dtype(x_aval.dtype), # input te_dtype + jax_dtype_to_te_dtype(gamma_aval.dtype), # weight te_dtype + True, + kwargs["zero_centered_gamma"], + kwargs["epsilon"], + get_backward_sm_margin(), ) wkspace_aval = dx_aval.update( shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1]) ) - barrier_aval = dx_aval.update( - shape=barrier_info[0], dtype=te_dtype_to_jax_dtype(barrier_info[1]) - ) - dgamma_part_aval = dgamma_aval.update( - shape=dgamma_part_info[0], dtype=te_dtype_to_jax_dtype(dgamma_part_info[1]) - ) - dbeta_part_aval = 
dbeta_aval.update( - shape=dbeta_part_info[0], dtype=te_dtype_to_jax_dtype(dbeta_part_info[1]) - ) return ( dx_aval, dgamma_aval, dbeta_aval, wkspace_aval, - barrier_aval, - dgamma_part_aval, - dbeta_part_aval, ) @staticmethod @@ -417,9 +390,7 @@ def outer_abstract(*args, **kwargs): """ LayerNorm bwd outer primitive abstract """ - dx_aval, dgamma_aval, dbeta_aval, _, _, _, _ = LayerNormBwdPrimitive.abstract( - *args, **kwargs - ) + dx_aval, dgamma_aval, dbeta_aval, _ = LayerNormBwdPrimitive.abstract(*args, **kwargs) return dx_aval, dgamma_aval, dbeta_aval @staticmethod @@ -470,20 +441,14 @@ def lowering(ctx, dz, x, mu, rsigma, gamma, *, zero_centered_gamma, epsilon): sm_margin = get_backward_sm_margin() - wkspace_aval, barrier_aval, dgamma_part_aval, dbeta_part_aval = ctx.avals_out[-4:] + wkspace_aval = ctx.avals_out[-1] opaque = transformer_engine_jax.pack_norm_descriptor( batch_size, hidden_size, wkspace_aval.size, - barrier_aval.size, - dgamma_part_aval.shape, - dbeta_part_aval.shape, jax_dtype_to_te_dtype(x_aval.dtype), jax_dtype_to_te_dtype(gamma_aval.dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype), - jax_dtype_to_te_dtype(barrier_aval.dtype), - jax_dtype_to_te_dtype(dgamma_part_aval.dtype), - jax_dtype_to_te_dtype(dbeta_part_aval.dtype), zero_centered_gamma, epsilon, sm_margin, @@ -496,7 +461,7 @@ def lowering(ctx, dz, x, mu, rsigma, gamma, *, zero_centered_gamma, epsilon): @staticmethod def impl(dz, x, mu, rsigma, gamma, zero_centered_gamma, epsilon): assert LayerNormBwdPrimitive.inner_primitive is not None - dx, dgamma, dbeta, _, _, _, _ = LayerNormBwdPrimitive.inner_primitive.bind( + dx, dgamma, dbeta, _ = LayerNormBwdPrimitive.inner_primitive.bind( dz, x, mu, rsigma, gamma, zero_centered_gamma=zero_centered_gamma, epsilon=epsilon ) return dx, dgamma, dbeta @@ -630,7 +595,7 @@ def abstract(x_aval, gamma_aval, **kwargs): hidden_size = gamma_aval.size assert x_aval.size % hidden_size == 0 - wkspace_info, barrier_info = 
transformer_engine_jax.get_layernorm_fwd_workspace_sizes( + (wkspace_info,) = transformer_engine_jax.get_layernorm_fwd_workspace_sizes( x_aval.size // hidden_size, # batch size hidden_size, jax_dtype_to_te_dtype(x_aval.dtype), # in te_dtype @@ -644,18 +609,15 @@ def abstract(x_aval, gamma_aval, **kwargs): wkspace_aval = out_aval.update( shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1]) ) - barrier_aval = out_aval.update( - shape=barrier_info[0], dtype=te_dtype_to_jax_dtype(barrier_info[1]) - ) - return out_aval, rsigma_aval, wkspace_aval, barrier_aval + return out_aval, rsigma_aval, wkspace_aval @staticmethod def outer_abstract(*args, **kwargs): """ RMSNorm fwd outer primitive abstract """ - out_aval, rsigma_aval, _, _ = RmsNormFwdPrimitive.abstract(*args, **kwargs) + out_aval, rsigma_aval, _ = RmsNormFwdPrimitive.abstract(*args, **kwargs) return out_aval, rsigma_aval @staticmethod @@ -688,7 +650,7 @@ def lowering(ctx, x, gamma, *, epsilon): batch_shape = out_shape[:-1] batch_size = reduce(operator.mul, x_shape) // hidden_size - wkspace_aval, barrier_aval = ctx.avals_out[-2:] + wkspace_aval = ctx.avals_out[-1] out_types = [ ir.RankedTensorType.get(out_shape, x_type.element_type), @@ -696,9 +658,6 @@ def lowering(ctx, x, gamma, *, epsilon): ir.RankedTensorType.get( wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype) ), - ir.RankedTensorType.get( - barrier_aval.shape, jax_dtype_to_ir_dtype(barrier_aval.dtype) - ), ] operands = [x, gamma] operand_shapes = [x_shape, g_shape] @@ -710,15 +669,9 @@ def lowering(ctx, x, gamma, *, epsilon): batch_size, hidden_size, wkspace_aval.size, - barrier_aval.size, - (0,), # no dgamma_part in FWD pass - (0,), # no dbeta_part in BWD pass jax_dtype_to_te_dtype(x_aval.dtype), jax_dtype_to_te_dtype(gamma_aval.dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype), - jax_dtype_to_te_dtype(barrier_aval.dtype), - TEDType.kByte, # dummy dgamma_part te_dtype - TEDType.kByte, # dummy dbeta_part te_dtype False, # RMSNorm 
doesn't support zero_centered_gamma epsilon, sm_margin, @@ -734,7 +687,7 @@ def impl(x, gamma, epsilon): to describe implementation """ assert RmsNormFwdPrimitive.inner_primitive is not None - out, rsigma, _, _ = RmsNormFwdPrimitive.inner_primitive.bind(x, gamma, epsilon=epsilon) + out, rsigma, _ = RmsNormFwdPrimitive.inner_primitive.bind(x, gamma, epsilon=epsilon) return out, rsigma @staticmethod @@ -833,36 +786,28 @@ def abstract(dz_aval, x_aval, rsigma_aval, gamma_aval, **kwargs): dx_aval = core.raise_to_shaped(dz_aval) dgamma_aval = core.raise_to_shaped(gamma_aval) - wkspace_info, barrier_info, dgamma_part_info, _ = ( - transformer_engine_jax.get_layernorm_bwd_workspace_sizes( - x_aval.size // gamma_aval.size, # batch size - gamma_aval.size, # hidden size - jax_dtype_to_te_dtype(x_aval.dtype), # in te_dtype - jax_dtype_to_te_dtype(gamma_aval.dtype), # weight te_dtype - False, - False, - kwargs["epsilon"], - get_backward_sm_margin(), - ) + (wkspace_info,) = transformer_engine_jax.get_layernorm_bwd_workspace_sizes( + x_aval.size // gamma_aval.size, # batch size + gamma_aval.size, # hidden size + jax_dtype_to_te_dtype(x_aval.dtype), # in te_dtype + jax_dtype_to_te_dtype(gamma_aval.dtype), # weight te_dtype + False, + False, + kwargs["epsilon"], + get_backward_sm_margin(), ) wkspace_aval = dx_aval.update( shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1]) ) - barrier_aval = dx_aval.update( - shape=barrier_info[0], dtype=te_dtype_to_jax_dtype(barrier_info[1]) - ) - dgamma_part_aval = dgamma_aval.update( - shape=dgamma_part_info[0], dtype=te_dtype_to_jax_dtype(dgamma_part_info[1]) - ) - return dx_aval, dgamma_aval, wkspace_aval, barrier_aval, dgamma_part_aval + return dx_aval, dgamma_aval, wkspace_aval @staticmethod def outer_abstract(*args, **kwargs): """ RMSNorm bwd outer primitive abstract """ - dx_aval, dgamma_aval, _, _, _ = RmsNormBwdPrimitive.abstract(*args, **kwargs) + dx_aval, dgamma_aval, _ = RmsNormBwdPrimitive.abstract(*args, **kwargs) 
return dx_aval, dgamma_aval @staticmethod @@ -896,7 +841,7 @@ def lowering(ctx, dz, x, rsigma, gamma, *, epsilon): hidden_size = reduce(operator.mul, g_shape) batch_size = reduce(operator.mul, x_shape) // hidden_size - wkspace_aval, barrier_aval, dgamma_part_aval = ctx.avals_out[-3:] + wkspace_aval = ctx.avals_out[-1] out_types = [ ir.RankedTensorType.get(x_shape, x_type.element_type), @@ -904,12 +849,6 @@ def lowering(ctx, dz, x, rsigma, gamma, *, epsilon): ir.RankedTensorType.get( wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype) ), - ir.RankedTensorType.get( - barrier_aval.shape, jax_dtype_to_ir_dtype(barrier_aval.dtype) - ), - ir.RankedTensorType.get( - dgamma_part_aval.shape, jax_dtype_to_ir_dtype(dgamma_part_aval.dtype) - ), ] operands = [dz, rsigma, x, gamma] operand_shapes = [dz_shape, rsigma_shape, x_shape, g_shape] @@ -921,15 +860,9 @@ def lowering(ctx, dz, x, rsigma, gamma, *, epsilon): batch_size, hidden_size, wkspace_aval.size, - barrier_aval.size, - dgamma_part_aval.shape, - (0,), # no dbeta_part for RMSnorm jax_dtype_to_te_dtype(x_aval.dtype), jax_dtype_to_te_dtype(gamma_aval.dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype), - jax_dtype_to_te_dtype(barrier_aval.dtype), - jax_dtype_to_te_dtype(dgamma_part_aval.dtype), - TEDType.kByte, # dummy dbeta_part te_dtype False, # RMSNorm doesn't support zero_centered_gamma epsilon, sm_margin, @@ -942,7 +875,7 @@ def lowering(ctx, dz, x, rsigma, gamma, *, epsilon): @staticmethod def impl(dz, x, rsigma, gamma, epsilon): assert RmsNormBwdPrimitive.inner_primitive is not None - dx, dgamma, _, _, _ = RmsNormBwdPrimitive.inner_primitive.bind( + dx, dgamma, _ = RmsNormBwdPrimitive.inner_primitive.bind( dz, x, rsigma, gamma, epsilon=epsilon ) return dx, dgamma @@ -1066,7 +999,7 @@ def abstract( assert gamma_aval.size == beta_aval.size - wkspace_info, barrier_info = transformer_engine_jax.get_layernorm_fwd_workspace_sizes( + (wkspace_info,) = transformer_engine_jax.get_layernorm_fwd_workspace_sizes( 
x_aval.size // gamma_aval.size, # batch size gamma_aval.size, # hidden size jax_dtype_to_te_dtype(x_aval.dtype), # in type @@ -1084,18 +1017,15 @@ def abstract( wkspace_aval = x_aval.update( shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1]) ) - barrier_aval = x_aval.update( - shape=barrier_info[0], dtype=te_dtype_to_jax_dtype(barrier_info[1]) - ) - return out_aval, mu_aval, rsigma_aval, updated_amax_aval, wkspace_aval, barrier_aval + return out_aval, mu_aval, rsigma_aval, updated_amax_aval, wkspace_aval @staticmethod def outer_abstract(*args, **kwargs): """ LayerNorm fwd (fp8 out) outer primitive abstract """ - out_aval, mu_aval, rsigma_aval, updated_amax_aval, _, _ = LayerNormFwdFp8Primitive.abstract( + out_aval, mu_aval, rsigma_aval, updated_amax_aval, _ = LayerNormFwdFp8Primitive.abstract( *args, **kwargs ) return out_aval, mu_aval, rsigma_aval, updated_amax_aval @@ -1158,7 +1088,7 @@ def lowering( batch_shape = out_shape[:-1] batch_size = reduce(operator.mul, x_shape) // hidden_size - wkspace_aval, barrier_aval = ctx.avals_out[-2:] + wkspace_aval = ctx.avals_out[-1] out_types = [ ir.RankedTensorType.get(out_shape, ir_out_dtype), @@ -1168,9 +1098,6 @@ def lowering( ir.RankedTensorType.get( wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype) ), - ir.RankedTensorType.get( - barrier_aval.shape, jax_dtype_to_ir_dtype(barrier_aval.dtype) - ), ] operands = [x, gamma, beta, amax, scale, scale_inv] operand_shapes = [ @@ -1189,15 +1116,9 @@ def lowering( batch_size, hidden_size, wkspace_aval.size, - barrier_aval.size, - (0,), # no dgamma_part in FWD pass - (0,), # no dbeta_part in BWD pass jax_dtype_to_te_dtype(x_aval.dtype), jax_dtype_to_te_dtype(gamma_aval.dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype), - jax_dtype_to_te_dtype(barrier_aval.dtype), - TEDType.kByte, # dummy dgamma_part te_dtype - TEDType.kByte, # dummy dbeta_part te_dtype zero_centered_gamma, epsilon, sm_margin, @@ -1215,7 +1136,7 @@ def impl(x, gamma, beta, amax, scale, 
scale_inv, out_dtype, zero_centered_gamma, to describe implementation """ assert LayerNormFwdFp8Primitive.inner_primitive is not None - out, mu, rsigma, updated_amax, _, _ = LayerNormFwdFp8Primitive.inner_primitive.bind( + out, mu, rsigma, updated_amax, _ = LayerNormFwdFp8Primitive.inner_primitive.bind( x, gamma, beta, @@ -1394,7 +1315,7 @@ def abstract(x_aval, gamma_aval, amax_aval, scale_aval, scale_inv_aval, out_dtyp rsigama_dtype = jnp.float32 - wkspace_info, barrier_info = transformer_engine_jax.get_layernorm_fwd_workspace_sizes( + (wkspace_info,) = transformer_engine_jax.get_layernorm_fwd_workspace_sizes( x_aval.size // hidden_size, # batch_size hidden_size, jax_dtype_to_te_dtype(x_aval.dtype), # in te_dtype @@ -1412,18 +1333,15 @@ def abstract(x_aval, gamma_aval, amax_aval, scale_aval, scale_inv_aval, out_dtyp wkspace_aval = x_aval.update( shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1]) ) - barrier_aval = x_aval.update( - shape=barrier_info[0], dtype=te_dtype_to_jax_dtype(barrier_info[1]) - ) - return out_aval, rsigma_aval, amax_aval, wkspace_aval, barrier_aval + return out_aval, rsigma_aval, amax_aval, wkspace_aval @staticmethod def outer_abstract(*args, **kwargs): """ RMSNorm fwd (fp8 out) outer primitive abstract """ - out_aval, rsigma_aval, amax_aval, _, _ = RmsNormFwdFp8Primitive.abstract(*args, **kwargs) + out_aval, rsigma_aval, amax_aval, _ = RmsNormFwdFp8Primitive.abstract(*args, **kwargs) return out_aval, rsigma_aval, amax_aval @staticmethod @@ -1476,7 +1394,7 @@ def lowering(ctx, x, gamma, amax, scale, scale_inv, *, out_dtype, epsilon): batch_shape = out_shape[:-1] batch_size = reduce(operator.mul, x_shape) // hidden_size - wkspace_aval, barrier_aval = ctx.avals_out[-2:] + wkspace_aval = ctx.avals_out[-1] out_types = [ ir.RankedTensorType.get(out_shape, ir_out_dtype), @@ -1485,9 +1403,6 @@ def lowering(ctx, x, gamma, amax, scale, scale_inv, *, out_dtype, epsilon): ir.RankedTensorType.get( wkspace_aval.shape, 
jax_dtype_to_ir_dtype(wkspace_aval.dtype) ), - ir.RankedTensorType.get( - barrier_aval.shape, jax_dtype_to_ir_dtype(barrier_aval.dtype) - ), ] operands = [x, gamma, amax, scale, scale_inv] operand_shapes = [x_shape, g_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape] @@ -1499,15 +1414,9 @@ def lowering(ctx, x, gamma, amax, scale, scale_inv, *, out_dtype, epsilon): batch_size, hidden_size, wkspace_aval.size, - barrier_aval.size, - (0,), # no dgamma_part in FWD pass - (0,), # no dbeta_part in BWD pass jax_dtype_to_te_dtype(x_aval.dtype), jax_dtype_to_te_dtype(gamma_aval.dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype), - jax_dtype_to_te_dtype(barrier_aval.dtype), - TEDType.kByte, # dummy dgamma_part te_dtype - TEDType.kByte, # dummy dbeta_part te_dtype False, # RMSNorm doesn't support zero_centered_gamma epsilon, sm_margin, @@ -1525,7 +1434,7 @@ def impl(x, gamma, amax, scale, scale_inv, out_dtype, epsilon): to describe implementation """ assert RmsNormFwdFp8Primitive.inner_primitive is not None - out, rsigma, amax, _, _ = RmsNormFwdFp8Primitive.inner_primitive.bind( + out, rsigma, amax, _ = RmsNormFwdFp8Primitive.inner_primitive.bind( x, gamma, amax, scale, scale_inv, out_dtype=out_dtype, epsilon=epsilon ) return out, rsigma, amax diff --git a/transformer_engine/jax/cpp_extensions/softmax.py b/transformer_engine/jax/cpp_extensions/softmax.py index bf92c00de3..a12943f4c2 100644 --- a/transformer_engine/jax/cpp_extensions/softmax.py +++ b/transformer_engine/jax/cpp_extensions/softmax.py @@ -12,12 +12,13 @@ from jax import core, dtypes from jax.interpreters.mlir import ir from jax.sharding import PartitionSpec, NamedSharding +from jax.extend import ffi from transformer_engine import transformer_engine_jax from .base import BasePrimitive, register_primitive from .custom_call import custom_caller, CustomCallArgsWrapper -from .misc import get_padded_spec, check_valid_batch_dims, jax_dtype_to_te_dtype +from .misc import get_padded_spec, check_valid_batch_dims, 
jax_dtype_to_te_dtype, is_ffi_enabled from ..softmax import SoftmaxType @@ -133,32 +134,36 @@ def forward_lowering(name, ctx, logits, *, scale_factor): """ softmax_forward lowering rules """ - (i_aval,) = ctx.avals_in - i_type = ir.RankedTensorType(logits.type) - i_shape = i_type.shape - # Assume [...Batch, Head, Q_Seqlen, K_Seqlen] - batch = reduce(operator.mul, i_shape[:-3]) - pad_batch = batch - heads = i_shape[-3] - q_seqlen = i_shape[-2] - k_seqlen = i_shape[-1] - - out_types = [ir.RankedTensorType.get(i_shape, i_type.element_type)] - operands = [logits] - operand_shapes = [i_shape] - args = CustomCallArgsWrapper(out_types, operands, operand_shapes) - - opaque = transformer_engine_jax.pack_softmax_descriptor( - batch, - pad_batch, - heads, - q_seqlen, - k_seqlen, - jax_dtype_to_te_dtype(i_aval.dtype), - scale_factor, - ) + if is_ffi_enabled(): + ffi_name = name + "_ffi" + out = ffi.ffi_lowering(ffi_name)(ctx, logits, scale_factor=scale_factor) + else: + (i_aval,) = ctx.avals_in + i_type = ir.RankedTensorType(logits.type) + i_shape = i_type.shape + # Assume [...Batch, Head, Q_Seqlen, K_Seqlen] + batch = reduce(operator.mul, i_shape[:-3]) + pad_batch = batch + heads = i_shape[-3] + q_seqlen = i_shape[-2] + k_seqlen = i_shape[-1] + + out_types = [ir.RankedTensorType.get(i_shape, i_type.element_type)] + operands = [logits] + operand_shapes = [i_shape] + args = CustomCallArgsWrapper(out_types, operands, operand_shapes) + + opaque = transformer_engine_jax.pack_softmax_descriptor( + batch, + pad_batch, + heads, + q_seqlen, + k_seqlen, + jax_dtype_to_te_dtype(i_aval.dtype), + scale_factor, + ) - out = custom_caller(name, args, opaque, False) + out = custom_caller(name, args, opaque, False) return out @@ -240,37 +245,41 @@ def backward_lowering(name, ctx, dz, softmax_out, *, scale_factor): """ softmax_backward lowering rules """ - dz_aval, _ = ctx.avals_in - - dz_type = ir.RankedTensorType(dz.type) - dz_shape = dz_type.shape - - # Assume [...Batch, Head, Q_Seqlen, 
K_Seqlen] - batch = reduce(operator.mul, dz_shape[:-3]) - pad_batch = batch # unused - heads = dz_shape[-3] - q_seqlen = dz_shape[-2] - k_seqlen = dz_shape[-1] - - softmax_out_type = ir.RankedTensorType(softmax_out.type) - softmax_out_shape = softmax_out_type.shape - - out_types = [ir.RankedTensorType.get(dz_shape, dz_type.element_type)] - operands = [dz, softmax_out] - operand_shapes = [dz_shape, softmax_out_shape] - args = CustomCallArgsWrapper(out_types, operands, operand_shapes) - - opaque = transformer_engine_jax.pack_softmax_descriptor( - batch, - pad_batch, - heads, - q_seqlen, - k_seqlen, - jax_dtype_to_te_dtype(dz_aval.dtype), - scale_factor, - ) + if is_ffi_enabled(): + ffi_name = name + "_ffi" + out = ffi.ffi_lowering(ffi_name)(ctx, dz, softmax_out, scale_factor=scale_factor) + else: + dz_aval, _ = ctx.avals_in + + dz_type = ir.RankedTensorType(dz.type) + dz_shape = dz_type.shape + + # Assume [...Batch, Head, Q_Seqlen, K_Seqlen] + batch = reduce(operator.mul, dz_shape[:-3]) + pad_batch = batch # unused + heads = dz_shape[-3] + q_seqlen = dz_shape[-2] + k_seqlen = dz_shape[-1] + + softmax_out_type = ir.RankedTensorType(softmax_out.type) + softmax_out_shape = softmax_out_type.shape + + out_types = [ir.RankedTensorType.get(dz_shape, dz_type.element_type)] + operands = [dz, softmax_out] + operand_shapes = [dz_shape, softmax_out_shape] + args = CustomCallArgsWrapper(out_types, operands, operand_shapes) + + opaque = transformer_engine_jax.pack_softmax_descriptor( + batch, + pad_batch, + heads, + q_seqlen, + k_seqlen, + jax_dtype_to_te_dtype(dz_aval.dtype), + scale_factor, + ) - out = custom_caller(name, args, opaque, False) + out = custom_caller(name, args, opaque, False) return out @@ -577,36 +586,39 @@ def lowering(ctx, logits, mask, *, scale_factor): """ te_scaled_masked_softmax_forward lowering rules """ + if is_ffi_enabled(): + ffi_name = "te_scaled_masked_softmax_forward_ffi" + out = ffi.ffi_lowering(ffi_name)(ctx, logits, mask, 
scale_factor=scale_factor) + else: + logits_aval, _ = ctx.avals_in + i_type = ir.RankedTensorType(logits.type) + i_shape = i_type.shape + # Assume [...Batch, Head, Q_Seqlen, K_Seqlen] + batch = reduce(operator.mul, i_shape[:-3]) + heads = i_shape[-3] + q_seqlen = i_shape[-2] + k_seqlen = i_shape[-1] + + mask_type = ir.RankedTensorType(mask.type) + mask_shape = mask_type.shape + pad_batch = reduce(operator.mul, mask_shape[:-3]) + + out_types = [ir.RankedTensorType.get(i_shape, i_type.element_type)] + operands = [logits, mask] + operand_shapes = [i_shape, mask_shape] + args = CustomCallArgsWrapper(out_types, operands, operand_shapes) + + opaque = transformer_engine_jax.pack_softmax_descriptor( + batch, + pad_batch, + heads, + q_seqlen, + k_seqlen, + jax_dtype_to_te_dtype(logits_aval.dtype), + scale_factor, + ) - logits_aval, _ = ctx.avals_in - i_type = ir.RankedTensorType(logits.type) - i_shape = i_type.shape - # Assume [...Batch, Head, Q_Seqlen, K_Seqlen] - batch = reduce(operator.mul, i_shape[:-3]) - heads = i_shape[-3] - q_seqlen = i_shape[-2] - k_seqlen = i_shape[-1] - - mask_type = ir.RankedTensorType(mask.type) - mask_shape = mask_type.shape - pad_batch = reduce(operator.mul, mask_shape[:-3]) - - out_types = [ir.RankedTensorType.get(i_shape, i_type.element_type)] - operands = [logits, mask] - operand_shapes = [i_shape, mask_shape] - args = CustomCallArgsWrapper(out_types, operands, operand_shapes) - - opaque = transformer_engine_jax.pack_softmax_descriptor( - batch, - pad_batch, - heads, - q_seqlen, - k_seqlen, - jax_dtype_to_te_dtype(logits_aval.dtype), - scale_factor, - ) - - out = custom_caller(ScaledMaskedSoftmaxFwdPrimitive.name, args, opaque, False) + out = custom_caller(ScaledMaskedSoftmaxFwdPrimitive.name, args, opaque, False) return out diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index 1ad0c9c51d..64f3c467b6 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ 
b/transformer_engine/jax/csrc/extensions.h @@ -81,25 +81,18 @@ struct CustomCallNormDescriptor { size_t batch_size; size_t hidden_size; size_t wkspace_size; - size_t barrier_size; - Shape dgamma_part_shape; - Shape dbeta_part_shape; DType x_dtype; DType w_dtype; DType wkspace_dtype; - DType barrier_dtype; - DType dgamma_part_dtype; - DType dbeta_part_dtype; bool zero_centered_gamma; float eps; int sm_margin; }; -pybind11::bytes PackCustomCallNormDescriptor( - size_t batch_size, size_t hidden_size, size_t wkspace_size, size_t barrier_size, - const std::vector &dgamma_part_shape, const std::vector &dbeta_part_shape, - DType x_dtype, DType w_dtype, DType wkspace_dtype, DType barrier_dtype, DType dgamma_part_dtype, - DType dbeta_part_dtype, bool zero_centered_gamma, float eps, int sm_margin); +pybind11::bytes PackCustomCallNormDescriptor(size_t batch_size, size_t hidden_size, + size_t wkspace_size, DType x_dtype, DType w_dtype, + DType wkspace_dtype, bool zero_centered_gamma, + float eps, int sm_margin); struct SoftmaxDescriptor { size_t batch_size; @@ -238,6 +231,8 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(QuantizeHandler); void Dequantize(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); +XLA_FFI_DECLARE_HANDLER_SYMBOL(DequantizeHandler); + // Softmax void ScaledSoftmaxForward(cudaStream_t stream, void **buffers, const char *opaque, @@ -258,8 +253,23 @@ void ScaledUpperTriangMaskedSoftmaxForward(cudaStream_t stream, void **buffers, void ScaledUpperTriangMaskedSoftmaxBackward(cudaStream_t stream, void **buffers, const char *opaque, std::size_t opaque_len); +XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledSoftmaxForwardHandler); + +XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledSoftmaxBackwardHandler); + +XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledMaskedSoftmaxForwardHandler); + +XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledMaskedSoftmaxBackwardHandler); + +XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledUpperTriangMaskedSoftmaxForwardHandler); + 
+XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledUpperTriangMaskedSoftmaxBackwardHandler); + // Attention +// Cudnn helpers +XLA_FFI_DECLARE_HANDLER_SYMBOL(CudnnHandleInitHandler); + NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, float dropout_probability, @@ -289,8 +299,7 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); -// Cudnn helpers -XLA_FFI_DECLARE_HANDLER_SYMBOL(CudnnHandleInitHandler); +XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedAttnBackwardHandler); } // namespace jax } // namespace transformer_engine diff --git a/transformer_engine/jax/csrc/extensions/activation.cpp b/transformer_engine/jax/csrc/extensions/activation.cpp index 9d5fb4f7b4..a2090bceba 100644 --- a/transformer_engine/jax/csrc/extensions/activation.cpp +++ b/transformer_engine/jax/csrc/extensions/activation.cpp @@ -264,8 +264,8 @@ Error_Type DActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type act auto *output = output_buf->untyped_data(); auto act_input_dims = act_input_buf.dimensions(); - auto m = product(act_input_dims, 0, act_input_dims.size() - 2); - auto n = act_input_dims.back(); + auto m = static_cast(product(act_input_dims, 0, act_input_dims.size() - 2)); + auto n = static_cast(act_input_dims.back()); auto act_len = act_input_dims.end()[-2]; auto input_shape = std::vector{m, n}; diff --git a/transformer_engine/jax/csrc/extensions/attention.cpp b/transformer_engine/jax/csrc/extensions/attention.cpp index 3679b46ee5..4bde10fc46 100644 --- a/transformer_engine/jax/csrc/extensions/attention.cpp +++ b/transformer_engine/jax/csrc/extensions/attention.cpp @@ -185,6 +185,33 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes( return pybind11::make_tuple(workspace_shape, query_workspace_tensor.dtype()); } +#define FUSED_ATTN_IMPL_COMMON_BLOCK \ + auto is_ragged = 
nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD; \ + auto bias_shape = std::vector{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen}; \ + size_t num_segments = input_batch; \ + if (is_ragged) { \ + auto cudnn_runtime_version = cudnnGetVersion(); \ + if (cudnn_runtime_version >= 90300) { \ + num_segments = input_batch * max_segments_per_seq; \ + } else { \ + size_t runtime_num_segments_q = \ + GetRuntimeNumSegments(q_cu_seqlens, workspace, input_batch * q_max_seqlen, stream); \ + size_t runtime_num_segments_kv = \ + GetRuntimeNumSegments(kv_cu_seqlens, workspace, input_batch * kv_max_seqlen, stream); \ + NVTE_CHECK(runtime_num_segments_q == runtime_num_segments_kv); \ + NVTE_CHECK(runtime_num_segments_q <= input_batch * max_segments_per_seq); \ + num_segments = runtime_num_segments_q; \ + } \ + } \ + std::vector seq_shape{num_segments + 1}; \ + auto q_cu_seqlens_tensor = TensorWrapper(q_cu_seqlens, seq_shape, DType::kInt32); \ + auto kv_cu_seqlens_tensor = TensorWrapper(kv_cu_seqlens, seq_shape, DType::kInt32); \ + auto q_seq_offsets_tensor = TensorWrapper(q_seq_offsets, seq_shape, DType::kInt32); \ + auto k_seq_offsets_tensor = TensorWrapper(k_seq_offsets, seq_shape, DType::kInt32); \ + auto workspace_tensor = \ + TensorWrapper(workspace, std::vector{wkspace_size}, wkspace_dtype); \ + auto layout_group = nvte_get_qkv_layout_group(qkv_layout); + static void FusedAttnForwardImpl( cudaStream_t stream, void *q, void *k, void *v, void *bias, void *q_cu_seqlens, void *kv_cu_seqlens, void *q_seq_offsets, void *k_seq_offsets, void *seed, void *output, @@ -194,43 +221,16 @@ static void FusedAttnForwardImpl( float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype, bool is_training, bool deterministic, int64_t window_size_left, int64_t window_size_right) { - auto is_ragged = nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD; + 
FUSED_ATTN_IMPL_COMMON_BLOCK; /* Input tensors */ - auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, head_dim}; - auto k_shape = std::vector{input_batch * kv_max_seqlen, num_gqa_groups, head_dim}; - auto v_shape = k_shape; - auto bias_shape = std::vector{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen}; auto bias_tensor = TensorWrapper(bias, bias_shape, dtype); - size_t num_segments = input_batch; // Non-THD format, input_batch = num_segments if (is_ragged) { - auto cudnn_runtime_version = cudnnGetVersion(); - if (cudnn_runtime_version >= 90300) { - num_segments = input_batch * max_segments_per_seq; - } else { - // workspace can be reused here as it is not used with cuDNN graph at the same time - size_t runtime_num_segments_q = - GetRuntimeNumSegments(q_cu_seqlens, workspace, input_batch * q_max_seqlen, stream); - size_t runtime_num_segments_kv = - GetRuntimeNumSegments(kv_cu_seqlens, workspace, input_batch * kv_max_seqlen, stream); - NVTE_CHECK(runtime_num_segments_q == runtime_num_segments_kv); - NVTE_CHECK(runtime_num_segments_q <= input_batch * max_segments_per_seq); - num_segments = runtime_num_segments_q; - } auto output_size = input_batch * q_max_seqlen * attn_heads * head_dim; cudaMemsetAsync(output, 0, output_size * typeToSize(dtype), stream); } - auto q_cu_seqlens_tensor = - TensorWrapper(q_cu_seqlens, std::vector{num_segments + 1}, DType::kInt32); - auto kv_cu_seqlens_tensor = - TensorWrapper(kv_cu_seqlens, std::vector{num_segments + 1}, DType::kInt32); - auto q_seq_offsets_tensor = - TensorWrapper(q_seq_offsets, std::vector{num_segments + 1}, DType::kInt32); - auto k_seq_offsets_tensor = - TensorWrapper(k_seq_offsets, std::vector{num_segments + 1}, DType::kInt32); - /* Output tensors */ auto s_tensor = TensorWrapper(nullptr, std::vector{1}, dtype); // not used in F16 auto o_shape = std::vector{input_batch * q_max_seqlen, attn_heads, head_dim}; @@ -251,12 +251,7 @@ static void FusedAttnForwardImpl( bias_heads, q_max_seqlen, 
kv_max_seqlen, dtype, bias_type, backend, softmax_aux); - /* cuDNN workspace */ - auto workspace_tensor = - TensorWrapper(workspace, std::vector{wkspace_size}, wkspace_dtype); - /* Call the underlying NVTE API */ - auto layout_group = nvte_get_qkv_layout_group(qkv_layout); if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { auto qkv_shape = std::vector{input_batch * q_max_seqlen, 3, attn_heads, head_dim}; auto qkv_tensor = TensorWrapper(q, qkv_shape, dtype); @@ -304,7 +299,9 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s auto is_ragged = nvte_get_qkv_format(descriptor.qkv_layout) == NVTE_QKV_Format::NVTE_THD; /* Input buffers from XLA */ - /* Buffers[0-2] are q, k, v, which are parsed later for different qkv_layout */ + void *q = buffers[0]; + void *k = buffers[1]; + void *v = buffers[2]; void *bias = buffers[3]; void *q_cu_seqlens = buffers[4]; void *kv_cu_seqlens = buffers[5]; @@ -319,16 +316,43 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s void *workspace = buffers[12]; FusedAttnForwardImpl( - stream, buffers[0], buffers[1], buffers[2], bias, q_cu_seqlens, kv_cu_seqlens, q_seq_offsets, - k_seq_offsets, seed, output, softmax_aux, rng_state, workspace, descriptor.input_batch, - descriptor.bias_batch, descriptor.q_max_seqlen, descriptor.kv_max_seqlen, - descriptor.attn_heads, descriptor.num_gqa_groups, descriptor.bias_heads, descriptor.head_dim, + stream, q, k, v, bias, q_cu_seqlens, kv_cu_seqlens, q_seq_offsets, k_seq_offsets, seed, + output, softmax_aux, rng_state, workspace, descriptor.input_batch, descriptor.bias_batch, + descriptor.q_max_seqlen, descriptor.kv_max_seqlen, descriptor.attn_heads, + descriptor.num_gqa_groups, descriptor.bias_heads, descriptor.head_dim, descriptor.max_segments_per_seq, descriptor.wkspace_size, descriptor.scaling_factor, descriptor.dropout_probability, descriptor.bias_type, descriptor.mask_type, descriptor.qkv_layout, descriptor.dtype, 
descriptor.wkspace_dtype, descriptor.is_training, descriptor.deterministic, descriptor.window_size_left, descriptor.window_size_right); } +#define FUSED_ATTN_FFI_GET_ATTRS \ + size_t input_batch = get_attr_value(attrs, "input_batch"); \ + size_t bias_batch = get_attr_value(attrs, "bias_batch"); \ + size_t q_max_seqlen = get_attr_value(attrs, "q_max_seqlen"); \ + size_t kv_max_seqlen = get_attr_value(attrs, "kv_max_seqlen"); \ + size_t attn_heads = get_attr_value(attrs, "attn_heads"); \ + size_t num_gqa_groups = get_attr_value(attrs, "num_gqa_groups"); \ + size_t bias_heads = get_attr_value(attrs, "bias_heads"); \ + size_t head_dim = get_attr_value(attrs, "head_dim"); \ + size_t max_segments_per_seq = get_attr_value(attrs, "max_segments_per_seq"); \ + auto window_size_left = get_attr_value(attrs, "window_size_left"); \ + auto window_size_right = get_attr_value(attrs, "window_size_right"); \ + float scaling_factor = get_attr_value(attrs, "scaling_factor"); \ + float dropout_probability = get_attr_value(attrs, "dropout_probability"); \ + NVTE_Bias_Type bias_type = \ + static_cast(get_attr_value(attrs, "bias_type")); \ + NVTE_Mask_Type mask_type = \ + static_cast(get_attr_value(attrs, "mask_type")); \ + NVTE_QKV_Layout qkv_layout = \ + static_cast(get_attr_value(attrs, "qkv_layout")); \ + bool is_training = get_attr_value(attrs, "is_training"); \ + bool deterministic = get_attr_value(attrs, "deterministic"); \ + auto is_ragged = nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD; \ + size_t wkspace_size = product(workspace_buf->dimensions()); \ + DType dtype = convert_ffi_datatype_to_te_dtype(q_buf.element_type()); \ + DType wkspace_dtype = convert_ffi_datatype_to_te_dtype(workspace_buf->element_type()); + Error_Type FusedAttnForwardFFI(cudaStream_t stream, Buffer_Type q_buf, Buffer_Type k_buf, Buffer_Type v_buf, Buffer_Type bias_buf, Buffer_Type q_cu_seqlens_buf, Buffer_Type kv_cu_seqlens_buf, @@ -336,37 +360,7 @@ Error_Type FusedAttnForwardFFI(cudaStream_t 
stream, Buffer_Type q_buf, Buffer_Ty Buffer_Type seed_buf, Result_Type output_buf, Result_Type softmax_aux_buf, Result_Type rng_state_buf, Result_Type workspace_buf, Dictionary attrs) { - /* Descriptor data type conversion */ - size_t input_batch = get_attr_value(attrs, "input_batch"); - size_t bias_batch = get_attr_value(attrs, "bias_batch"); - size_t q_max_seqlen = get_attr_value(attrs, "q_max_seqlen"); - size_t kv_max_seqlen = get_attr_value(attrs, "kv_max_seqlen"); - size_t attn_heads = get_attr_value(attrs, "attn_heads"); - size_t num_gqa_groups = get_attr_value(attrs, "num_gqa_groups"); - size_t bias_heads = get_attr_value(attrs, "bias_heads"); - size_t head_dim = get_attr_value(attrs, "head_dim"); - size_t max_segments_per_seq = get_attr_value(attrs, "max_segments_per_seq"); - auto window_size_left = get_attr_value(attrs, "window_size_left"); - auto window_size_right = get_attr_value(attrs, "window_size_right"); - - float scaling_factor = get_attr_value(attrs, "scaling_factor"); - float dropout_probability = get_attr_value(attrs, "dropout_probability"); - - NVTE_Bias_Type bias_type = - static_cast(get_attr_value(attrs, "bias_type")); - NVTE_Mask_Type mask_type = - static_cast(get_attr_value(attrs, "mask_type")); - NVTE_QKV_Layout qkv_layout = - static_cast(get_attr_value(attrs, "qkv_layout")); - - bool is_training = get_attr_value(attrs, "is_training"); - bool deterministic = get_attr_value(attrs, "deterministic"); - - auto is_ragged = nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD; - - size_t wkspace_size = product(workspace_buf->dimensions()); - DType dtype = convert_ffi_datatype_to_te_dtype(q_buf.element_type()); - DType wkspace_dtype = convert_ffi_datatype_to_te_dtype(workspace_buf->element_type()); + FUSED_ATTN_FFI_GET_ATTRS; FusedAttnForwardImpl( stream, q_buf.untyped_data(), k_buf.untyped_data(), v_buf.untyped_data(), @@ -503,81 +497,23 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( return pybind11::make_tuple(work_shape, 
query_workspace_tensor.dtype()); } -void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { - const CustomCallFusedAttnDescriptor &descriptor = - *UnpackOpaque(opaque, opaque_len); - - auto qkv_layout = descriptor.qkv_layout; - auto is_ragged = nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD; - - /* Input buffers from XLA */ - /* Buffers[0-2] are q, k, v, which are parsed later for different qkv_layout */ - void *bias = buffers[3]; - void *softmax_aux = buffers[4]; - void *rng_state = buffers[5]; - void *output = buffers[6]; - void *doutput = buffers[7]; - void *q_cu_seqlens = buffers[8]; - void *kv_cu_seqlens = buffers[9]; - void *q_seq_offsets = is_ragged ? buffers[10] : nullptr; - void *k_seq_offsets = is_ragged ? buffers[11] : nullptr; - - /* Output buffer from XLA */ - /* Buffers[12-14] are dq, dk, dv, which are parsed later for different qkv_layout */ - void *dbias = buffers[15]; - void *workspace = buffers[16]; - - /* Descriptor */ - auto input_batch = descriptor.input_batch; - auto bias_batch = descriptor.bias_batch; - auto q_max_seqlen = descriptor.q_max_seqlen; - auto kv_max_seqlen = descriptor.kv_max_seqlen; - auto attn_heads = descriptor.attn_heads; - auto num_gqa_groups = descriptor.num_gqa_groups; - auto bias_heads = descriptor.bias_heads; - auto head_dim = descriptor.head_dim; - auto scaling_factor = descriptor.scaling_factor; - auto dropout_probability = descriptor.dropout_probability; - auto bias_type = descriptor.bias_type; - auto mask_type = descriptor.mask_type; - auto dtype = descriptor.dtype; - auto deterministic = descriptor.deterministic; - auto max_segments_per_seq = descriptor.max_segments_per_seq; - auto window_size_left = descriptor.window_size_left; - auto window_size_right = descriptor.window_size_right; +static void FusedAttnBackwardImpl( + cudaStream_t stream, void *q, void *k, void *v, void *bias, void *softmax_aux, void *rng_state, + void *output, void *doutput, void 
*q_cu_seqlens, void *kv_cu_seqlens, void *q_seq_offsets, + void *k_seq_offsets, void *dq, void *dk, void *dv, void *dbias, void *workspace, + size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen, + size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim, + size_t max_segments_per_seq, size_t wkspace_size, float scaling_factor, + float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype, bool is_training, + bool deterministic, int64_t window_size_left, int64_t window_size_right) { + FUSED_ATTN_IMPL_COMMON_BLOCK; /* Input tensors */ auto output_shape = std::vector{input_batch * q_max_seqlen, attn_heads, head_dim}; - auto bias_shape = std::vector{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen}; auto output_tensor = TensorWrapper(output, output_shape, dtype); auto doutput_tensor = TensorWrapper(doutput, output_shape, dtype); - size_t num_segments = input_batch; // Non-THD format, input_batch = num_segments - if (is_ragged) { - auto cudnn_runtime_version = cudnnGetVersion(); - if (cudnn_runtime_version >= 90300) { - num_segments = input_batch * max_segments_per_seq; - } else { - // workspace can be reused here as it is not used with cuDNN graph at the same time - size_t runtime_num_segments_q = - GetRuntimeNumSegments(q_cu_seqlens, workspace, input_batch * q_max_seqlen, stream); - size_t runtime_num_segments_kv = - GetRuntimeNumSegments(kv_cu_seqlens, workspace, input_batch * kv_max_seqlen, stream); - NVTE_CHECK(runtime_num_segments_q == runtime_num_segments_kv); - NVTE_CHECK(runtime_num_segments_q <= input_batch * max_segments_per_seq); - num_segments = runtime_num_segments_q; - } - } - - auto q_cu_seqlens_tensor = - TensorWrapper(q_cu_seqlens, std::vector{num_segments + 1}, DType::kInt32); - auto kv_cu_seqlens_tensor = - TensorWrapper(kv_cu_seqlens, std::vector{num_segments + 1}, DType::kInt32); - auto q_seq_offsets_tensor = - 
TensorWrapper(q_seq_offsets, std::vector{num_segments + 1}, DType::kInt32); - auto k_seq_offsets_tensor = - TensorWrapper(k_seq_offsets, std::vector{num_segments + 1}, DType::kInt32); - /* Output tensors */ auto s_tensor = TensorWrapper(nullptr, std::vector{1}, dtype); // not used in F16 auto dbias_tensor = TensorWrapper(dbias, bias_shape, dtype); @@ -593,21 +529,13 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, bias_heads, q_max_seqlen, kv_max_seqlen, dtype, backend, softmax_aux, rng_state, bias); - /* cuDNN workspace */ - auto wkspace_size = std::vector{descriptor.wkspace_size}; - auto wkspace_dtype = descriptor.wkspace_dtype; - auto workspace_tensor = TensorWrapper(workspace, wkspace_size, wkspace_dtype); - /* Call the underly NVTE API */ - auto layout_group = nvte_get_qkv_layout_group(qkv_layout); if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { - auto qkv = buffers[0]; auto qkv_shape = std::vector{input_batch * q_max_seqlen, 3, attn_heads, head_dim}; - auto qkv_tensor = TensorWrapper(qkv, qkv_shape, dtype); - auto dqkv = buffers[12]; - auto dqkv_tensor = TensorWrapper(dqkv, qkv_shape, dtype); + auto qkv_tensor = TensorWrapper(q, qkv_shape, dtype); + auto dqkv_tensor = TensorWrapper(dq, qkv_shape, dtype); if (is_ragged) { - cudaMemsetAsync(dqkv, 0, transformer_engine::product(qkv_shape) * typeToSize(dtype), stream); + cudaMemsetAsync(dq, 0, transformer_engine::product(qkv_shape) * typeToSize(dtype), stream); } nvte_fused_attn_bwd_qkvpacked(qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(), s_tensor.data(), // not used for F16 @@ -618,19 +546,15 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, bias_type, mask_type, window_size_left, window_size_right, deterministic, workspace_tensor.data(), stream); } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { - auto q = buffers[0]; auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, head_dim}; - auto 
q_tensor = TensorWrapper(q, q_shape, dtype); - auto kv = buffers[1]; auto kv_shape = std::vector{input_batch * kv_max_seqlen, 2, num_gqa_groups, head_dim}; - auto kv_tensor = TensorWrapper(kv, kv_shape, dtype); - auto dq = buffers[12]; + auto q_tensor = TensorWrapper(q, q_shape, dtype); + auto kv_tensor = TensorWrapper(k, kv_shape, dtype); auto dq_tensor = TensorWrapper(dq, q_shape, dtype); - auto dkv = buffers[13]; - auto dkv_tensor = TensorWrapper(dkv, kv_shape, dtype); + auto dkv_tensor = TensorWrapper(dk, kv_shape, dtype); if (is_ragged) { cudaMemsetAsync(dq, 0, transformer_engine::product(q_shape) * typeToSize(dtype), stream); - cudaMemsetAsync(dkv, 0, transformer_engine::product(kv_shape) * typeToSize(dtype), stream); + cudaMemsetAsync(dk, 0, transformer_engine::product(kv_shape) * typeToSize(dtype), stream); } nvte_fused_attn_bwd_kvpacked( q_tensor.data(), kv_tensor.data(), output_tensor.data(), doutput_tensor.data(), @@ -642,20 +566,14 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, dropout_probability, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, deterministic, workspace_tensor.data(), stream); } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) { - auto q = buffers[0]; auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, head_dim}; - auto q_tensor = TensorWrapper(q, q_shape, dtype); - auto k = buffers[1]; auto k_shape = std::vector{input_batch * kv_max_seqlen, num_gqa_groups, head_dim}; - auto k_tensor = TensorWrapper(k, k_shape, dtype); - auto v = buffers[2]; auto v_shape = k_shape; + auto q_tensor = TensorWrapper(q, q_shape, dtype); + auto k_tensor = TensorWrapper(k, k_shape, dtype); auto v_tensor = TensorWrapper(v, v_shape, dtype); - auto dq = buffers[12]; auto dq_tensor = TensorWrapper(dq, q_shape, dtype); - auto dk = buffers[13]; auto dk_tensor = TensorWrapper(dk, k_shape, dtype); - auto dv = buffers[14]; auto dv_tensor = TensorWrapper(dv, v_shape, dtype); if 
(is_ragged) { cudaMemsetAsync(dq, 0, transformer_engine::product(q_shape) * typeToSize(dtype), stream); @@ -679,5 +597,93 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, nvte_tensor_pack_destroy(&aux_input_tensors); } +void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { + const CustomCallFusedAttnDescriptor &descriptor = + *UnpackOpaque(opaque, opaque_len); + + auto qkv_layout = descriptor.qkv_layout; + auto is_ragged = nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD; + + /* Input buffers from XLA */ + void *q = buffers[0]; + void *k = buffers[1]; + void *v = buffers[2]; + void *bias = buffers[3]; + void *softmax_aux = buffers[4]; + void *rng_state = buffers[5]; + void *output = buffers[6]; + void *doutput = buffers[7]; + void *q_cu_seqlens = buffers[8]; + void *kv_cu_seqlens = buffers[9]; + void *q_seq_offsets = is_ragged ? buffers[10] : nullptr; + void *k_seq_offsets = is_ragged ? buffers[11] : nullptr; + + /* Output buffer from XLA */ + void *dq = buffers[12]; + void *dk = buffers[13]; + void *dv = buffers[14]; + void *dbias = buffers[15]; + void *workspace = buffers[16]; + + FusedAttnBackwardImpl( + stream, q, k, v, bias, softmax_aux, rng_state, output, doutput, q_cu_seqlens, kv_cu_seqlens, + q_seq_offsets, k_seq_offsets, dq, dk, dv, dbias, workspace, descriptor.input_batch, + descriptor.bias_batch, descriptor.q_max_seqlen, descriptor.kv_max_seqlen, + descriptor.attn_heads, descriptor.num_gqa_groups, descriptor.bias_heads, descriptor.head_dim, + descriptor.max_segments_per_seq, descriptor.wkspace_size, descriptor.scaling_factor, + descriptor.dropout_probability, descriptor.bias_type, descriptor.mask_type, + descriptor.qkv_layout, descriptor.dtype, descriptor.wkspace_dtype, descriptor.is_training, + descriptor.deterministic, descriptor.window_size_left, descriptor.window_size_right); +} + +Error_Type FusedAttnBackwardFFI(cudaStream_t stream, Buffer_Type q_buf, 
Buffer_Type k_buf, + Buffer_Type v_buf, Buffer_Type bias_buf, + Buffer_Type softmax_aux_buf, Buffer_Type rng_state_buf, + Buffer_Type output_buf, Buffer_Type doutput_buf, + Buffer_Type q_cu_seqlens_buf, Buffer_Type kv_cu_seqlens_buf, + Buffer_Type q_seq_offsets_buf, Buffer_Type k_seq_offsets_buf, + Result_Type dq_buf, Result_Type dk_buf, Result_Type dv_buf, + Result_Type dbias_buf, Result_Type workspace_buf, + Dictionary attrs) { + FUSED_ATTN_FFI_GET_ATTRS; + + FusedAttnBackwardImpl( + stream, q_buf.untyped_data(), k_buf.untyped_data(), v_buf.untyped_data(), + bias_buf.untyped_data(), softmax_aux_buf.untyped_data(), rng_state_buf.untyped_data(), + output_buf.untyped_data(), doutput_buf.untyped_data(), q_cu_seqlens_buf.untyped_data(), + kv_cu_seqlens_buf.untyped_data(), is_ragged ? q_seq_offsets_buf.untyped_data() : nullptr, + is_ragged ? k_seq_offsets_buf.untyped_data() : nullptr, dq_buf->untyped_data(), + dk_buf->untyped_data(), dv_buf->untyped_data(), dbias_buf->untyped_data(), + workspace_buf->untyped_data(), input_batch, bias_batch, q_max_seqlen, kv_max_seqlen, + attn_heads, num_gqa_groups, bias_heads, head_dim, max_segments_per_seq, wkspace_size, + scaling_factor, dropout_probability, bias_type, mask_type, qkv_layout, dtype, wkspace_dtype, + is_training, deterministic, window_size_left, window_size_right); + + return ffi_with_cuda_error_check(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(FusedAttnBackwardHandler, FusedAttnBackwardFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // q + .Arg() // k + .Arg() // v + .Arg() // bias + .Arg() // softmax_aux + .Arg() // rng_state + .Arg() // output + .Arg() // doutput + .Arg() // q_cu_seqlens + .Arg() // kv_cu_seqlens + .Arg() // q_seq_offsets + .Arg() // k_seq_offsets + .Ret() // dq + .Ret() // dk + .Ret() // dv + .Ret() // dbias + .Ret() // workspace + .Attrs(), + FFI_CudaGraph_Traits); + } // namespace jax } // namespace transformer_engine diff --git a/transformer_engine/jax/csrc/extensions/normalization.cpp 
b/transformer_engine/jax/csrc/extensions/normalization.cpp index 9bd9951916..845eb844e2 100644 --- a/transformer_engine/jax/csrc/extensions/normalization.cpp +++ b/transformer_engine/jax/csrc/extensions/normalization.cpp @@ -3,9 +3,9 @@ * * See LICENSE for license information. ************************************************************************/ +#include "transformer_engine/normalization.h" + #include "extensions.h" -#include "transformer_engine/layer_norm.h" -#include "transformer_engine/rmsnorm.h" namespace transformer_engine { namespace jax { @@ -25,40 +25,36 @@ pybind11::tuple GetLayerNormForwardWorkspaceSizes(size_t batch_size, size_t hidd auto rsigma_tensor = TensorWrapper(nullptr, intermediates_shape, DType::kFloat32); // dummy tensor wrappers that will carry workspace size info later - TensorWrapper dummy_work_tensor, dummy_barrier_tensor; + TensorWrapper dummy_work_tensor; auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount() - sm_margin; - auto layernorm_fwd_func = zero_centered_gamma ? 
nvte_layernorm1p_fwd : nvte_layernorm_fwd; if (is_layer_norm) { auto beta_tensor = TensorWrapper(nullptr, weight_shape, w_dtype); auto mu_tensor = TensorWrapper(nullptr, intermediates_shape, DType::kFloat32); - layernorm_fwd_func(input_tensor.data(), gamma_tensor.data(), beta_tensor.data(), eps, - output_tensor.data(), mu_tensor.data(), rsigma_tensor.data(), nullptr, - num_sm, dummy_work_tensor.data(), dummy_barrier_tensor.data()); + nvte_layernorm_fwd(input_tensor.data(), gamma_tensor.data(), beta_tensor.data(), eps, + output_tensor.data(), mu_tensor.data(), rsigma_tensor.data(), + dummy_work_tensor.data(), num_sm, zero_centered_gamma, nullptr); } else { + // TODO(Phuong): Verify and remove this check NVTE_CHECK(!zero_centered_gamma, "rmsnorm doesn't support zero_centered_gamma."); nvte_rmsnorm_fwd(input_tensor.data(), gamma_tensor.data(), eps, output_tensor.data(), - rsigma_tensor.data(), nullptr, num_sm, dummy_work_tensor.data(), - dummy_barrier_tensor.data()); + rsigma_tensor.data(), dummy_work_tensor.data(), num_sm, zero_centered_gamma, + nullptr); } auto work_shape = MakeShapeVector(dummy_work_tensor.shape()); - auto barrier_shape = MakeShapeVector(dummy_barrier_tensor.shape()); - return pybind11::make_tuple(std::make_pair(work_shape, dummy_work_tensor.dtype()), - std::make_pair(barrier_shape, dummy_barrier_tensor.dtype())); + return pybind11::make_tuple(std::make_pair(work_shape, dummy_work_tensor.dtype())); } void LayerNormForwardImpl(size_t batch_size, size_t hidden_size, size_t workspace_size, - size_t barrier_size, bool zero_centered_gamma, float eps, void *input, - DType in_dtype, void *weight, DType w_dtype, void *bias, void *output, - DType out_dtype, void *workspace, DType work_dtype, void *barrier, - DType barrier_dtype, void *mu, void *rsigma, float *amax, float *scale, - float *scale_inv, int sm_margin, cudaStream_t stream) { + bool zero_centered_gamma, float eps, void *input, DType in_dtype, + void *weight, DType w_dtype, void *bias, void 
*output, DType out_dtype, + void *workspace, DType work_dtype, void *mu, void *rsigma, float *amax, + float *scale, float *scale_inv, int sm_margin, cudaStream_t stream) { auto input_shape = std::vector{batch_size, hidden_size}; auto weight_shape = std::vector{hidden_size}; auto intermediates_shape = std::vector{batch_size}; auto workspace_shape = std::vector{workspace_size}; - auto barrier_shape = std::vector{barrier_size}; auto is_layer_norm = (bias) ? true : false; auto input_tensor = TensorWrapper(input, input_shape, in_dtype); @@ -71,23 +67,21 @@ void LayerNormForwardImpl(size_t batch_size, size_t hidden_size, size_t workspac auto rsigma_tensor = TensorWrapper(rsigma, intermediates_shape, DType::kFloat32); auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount() - sm_margin; - auto layernorm_fwd_func = zero_centered_gamma ? nvte_layernorm1p_fwd : nvte_layernorm_fwd; auto workspace_tensor = TensorWrapper(workspace, workspace_shape, work_dtype); - auto barrier_tensor = TensorWrapper(barrier, barrier_shape, barrier_dtype); if (is_layer_norm) { auto beta_tensor = TensorWrapper(bias, weight_shape, w_dtype); auto mu_tensor = TensorWrapper(mu, intermediates_shape, DType::kFloat32); - layernorm_fwd_func(input_tensor.data(), gamma_tensor.data(), beta_tensor.data(), eps, - output_tensor.data(), mu_tensor.data(), rsigma_tensor.data(), stream, num_sm, - workspace_tensor.data(), barrier_tensor.data()); + nvte_layernorm_fwd(input_tensor.data(), gamma_tensor.data(), beta_tensor.data(), eps, + output_tensor.data(), mu_tensor.data(), rsigma_tensor.data(), + workspace_tensor.data(), num_sm, zero_centered_gamma, stream); } else { NVTE_CHECK(!zero_centered_gamma, "rmsnorm doesn't support zero_centered_gamma."); nvte_rmsnorm_fwd(input_tensor.data(), gamma_tensor.data(), eps, output_tensor.data(), - rsigma_tensor.data(), stream, num_sm, workspace_tensor.data(), - barrier_tensor.data()); + rsigma_tensor.data(), workspace_tensor.data(), num_sm, 
zero_centered_gamma, + stream); } } @@ -96,20 +90,17 @@ Error_Type LayerNormForwardImplFFI(cudaStream_t stream, Buffer_Type *x_buf, Buff Buffer_Type *scale_buf, Buffer_Type *scale_inv_buf, Result_Type *output_buf, Result_Type *mu_buf, Result_Type *rsigma_buf, Result_Type *amax_out_buf, - Result_Type *wkspace_buf, Result_Type *barrier_buf, - bool zero_centered_gamma, double eps_, int64_t sm_margin_, - bool is_layer_norm, bool is_fp8) { + Result_Type *wkspace_buf, bool zero_centered_gamma, double eps_, + int64_t sm_margin_, bool is_layer_norm, bool is_fp8) { auto in_dtype = convert_ffi_datatype_to_te_dtype((*x_buf).element_type()); auto w_dtype = convert_ffi_datatype_to_te_dtype((*gamma_buf).element_type()); auto wkspace_dtype = convert_ffi_datatype_to_te_dtype((*wkspace_buf)->element_type()); - auto barrier_dtype = convert_ffi_datatype_to_te_dtype((*barrier_buf)->element_type()); auto *input = x_buf->untyped_data(); auto *weight = gamma_buf->untyped_data(); auto *output = (*output_buf)->untyped_data(); auto *rsigma = (*rsigma_buf)->untyped_data(); auto *workspace = (*wkspace_buf)->untyped_data(); - auto *barrier = (*barrier_buf)->untyped_data(); void *bias = nullptr; void *mu = nullptr; @@ -135,17 +126,15 @@ Error_Type LayerNormForwardImplFFI(cudaStream_t stream, Buffer_Type *x_buf, Buff auto x_size = product(x_buf->dimensions()); auto gamma_size = product(gamma_buf->dimensions()); auto wkspace_size = product((*wkspace_buf)->dimensions()); - auto barrier_size = product((*barrier_buf)->dimensions()); auto hidden_size = gamma_size; auto batch_size = x_size / gamma_size; float eps = static_cast(eps_); int sm_margin = static_cast(sm_margin_); - LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, zero_centered_gamma, - eps, input, in_dtype, weight, w_dtype, bias, output, out_dtype, workspace, - wkspace_dtype, barrier, barrier_dtype, mu, rsigma, amax, scale, scale_inv, - sm_margin, stream); + LayerNormForwardImpl(batch_size, hidden_size, 
wkspace_size, zero_centered_gamma, eps, input, + in_dtype, weight, w_dtype, bias, output, out_dtype, workspace, wkspace_dtype, + mu, rsigma, amax, scale, scale_inv, sm_margin, stream); return ffi_with_cuda_error_check(); } @@ -154,11 +143,10 @@ Error_Type LayerNormForwardFP8FFI(cudaStream_t stream, Buffer_Type x_buf, Buffer Buffer_Type scale_inv_buf, Result_Type output_buf, Result_Type mu_buf, Result_Type rsigma_buf, Result_Type amax_out_buf, Result_Type wkspace_buf, - Result_Type barrier_buf, bool zero_centered_gamma, double eps_, - int64_t sm_margin_) { + bool zero_centered_gamma, double eps_, int64_t sm_margin_) { return LayerNormForwardImplFFI(stream, &x_buf, &gamma_buf, &beta_buf, &amax_buf, &scale_buf, &scale_inv_buf, &output_buf, &mu_buf, &rsigma_buf, &amax_out_buf, - &wkspace_buf, &barrier_buf, zero_centered_gamma, eps_, sm_margin_, + &wkspace_buf, zero_centered_gamma, eps_, sm_margin_, true, // is_layer_norm true // is_fp8 ); @@ -178,7 +166,6 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(LayerNormForwardFP8Handler, LayerNormForwardFP8FFI .Ret() // rsigma .Ret() // amax_out .Ret() // wkspace - .Ret() // barrier .Attr("zero_centered_gamma") .Attr("eps") .Attr("sm_margin"), @@ -187,15 +174,14 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(LayerNormForwardFP8Handler, LayerNormForwardFP8FFI Error_Type LayerNormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type gamma_buf, Buffer_Type beta_buf, Result_Type output_buf, Result_Type mu_buf, Result_Type rsigma_buf, Result_Type wkspace_buf, - Result_Type barrier_buf, bool zero_centered_gamma, double eps_, - int64_t sm_margin_) { + bool zero_centered_gamma, double eps_, int64_t sm_margin_) { return LayerNormForwardImplFFI(stream, &x_buf, &gamma_buf, &beta_buf, nullptr, // amax_buf nullptr, // scale_buf, nullptr, // scale_inv_buf, &output_buf, &mu_buf, &rsigma_buf, nullptr, // amax_out_buf, - &wkspace_buf, &barrier_buf, zero_centered_gamma, eps_, sm_margin_, + &wkspace_buf, zero_centered_gamma, eps_, sm_margin_, true, // is_layer_norm 
false // is_fp8 ); @@ -211,7 +197,6 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(LayerNormForwardHandler, LayerNormForwardFFI, .Ret() // mu .Ret() // rsigma .Ret() // wkspace - .Ret() // barrier .Attr("zero_centered_gamma") .Attr("eps") .Attr("sm_margin"), @@ -221,14 +206,14 @@ Error_Type RMSNormForwardFP8FFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_T Buffer_Type amax_buf, Buffer_Type scale_buf, Buffer_Type scale_inv_buf, Result_Type output_buf, Result_Type rsigma_buf, Result_Type amax_out_buf, - Result_Type wkspace_buf, Result_Type barrier_buf, - bool zero_centered_gamma, double eps_, int64_t sm_margin_) { + Result_Type wkspace_buf, bool zero_centered_gamma, double eps_, + int64_t sm_margin_) { return LayerNormForwardImplFFI(stream, &x_buf, &gamma_buf, nullptr, // beta_buf, &amax_buf, &scale_buf, &scale_inv_buf, &output_buf, nullptr, // mu_buf, - &rsigma_buf, &amax_out_buf, &wkspace_buf, &barrier_buf, - zero_centered_gamma, eps_, sm_margin_, + &rsigma_buf, &amax_out_buf, &wkspace_buf, zero_centered_gamma, + eps_, sm_margin_, false, // is_layer_norm true // is_fp8 ); @@ -246,7 +231,6 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(RMSNormForwardFP8Handler, RMSNormForwardFP8FFI, .Ret() // rsigma .Ret() // amax_out .Ret() // wkspace - .Ret() // barrier .Attr("zero_centered_gamma") .Attr("eps") .Attr("sm_margin"), @@ -254,8 +238,8 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(RMSNormForwardFP8Handler, RMSNormForwardFP8FFI, Error_Type RMSNormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type gamma_buf, Result_Type output_buf, Result_Type rsigma_buf, - Result_Type wkspace_buf, Result_Type barrier_buf, - bool zero_centered_gamma, double eps_, int64_t sm_margin_) { + Result_Type wkspace_buf, bool zero_centered_gamma, double eps_, + int64_t sm_margin_) { return LayerNormForwardImplFFI(stream, &x_buf, &gamma_buf, nullptr, // beta_buf, nullptr, // amax_buf, @@ -265,7 +249,7 @@ Error_Type RMSNormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type nullptr, // mu_buf, &rsigma_buf, 
nullptr, // amax_out_buf, - &wkspace_buf, &barrier_buf, zero_centered_gamma, eps_, sm_margin_, + &wkspace_buf, zero_centered_gamma, eps_, sm_margin_, false, // is_layer_norm false // is_fp8 ); @@ -279,7 +263,6 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(RMSNormForwardHandler, RMSNormForwardFFI, .Ret() // output .Ret() // rsigma .Ret() // wkspace - .Ret() // barrier .Attr("zero_centered_gamma") .Attr("eps") .Attr("sm_margin"), @@ -303,50 +286,34 @@ pybind11::tuple GetLayerNormBackwardWorkspaceSizes(size_t batch_size, size_t hid auto wgrad_tensor = TensorWrapper(nullptr, weight_shape, w_dtype); // dummy tensor wrappers that will carry workspace size info later - TensorWrapper dummy_work_tensor, dummy_barrier_tensor; - TensorWrapper dummy_dgamma_part_tensor, dummy_dbeta_part_tensor; + TensorWrapper dummy_work_tensor; auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount() - sm_margin; - auto layernorm_bwd_func = zero_centered_gamma ? nvte_layernorm1p_bwd : nvte_layernorm_bwd; - // initialize dBeta information here -- layernorm will modify but RMSnorm will not - std::vector dbeta_part_shape; if (is_layer_norm) { auto mu_tensor = TensorWrapper(nullptr, intermediates_shape, intermediates_dtype); auto dbeta_tensor = TensorWrapper(nullptr, weight_shape, w_dtype); - layernorm_bwd_func(dz_tensor.data(), x_tensor.data(), mu_tensor.data(), rsigma_tensor.data(), + nvte_layernorm_bwd(dz_tensor.data(), x_tensor.data(), mu_tensor.data(), rsigma_tensor.data(), gamma_tensor.data(), xgrad_tensor.data(), wgrad_tensor.data(), - dbeta_tensor.data(), dummy_dgamma_part_tensor.data(), - dummy_dbeta_part_tensor.data(), nullptr, num_sm, dummy_work_tensor.data(), - dummy_barrier_tensor.data()); + dbeta_tensor.data(), dummy_work_tensor.data(), num_sm, zero_centered_gamma, + nullptr); - dbeta_part_shape = MakeShapeVector(dummy_dbeta_part_tensor.shape()); } else { NVTE_CHECK(!zero_centered_gamma, "rmsnorm doesn't support zero_centered_gamma."); nvte_rmsnorm_bwd(dz_tensor.data(), 
x_tensor.data(), rsigma_tensor.data(), gamma_tensor.data(), - xgrad_tensor.data(), wgrad_tensor.data(), dummy_dgamma_part_tensor.data(), - nullptr, num_sm, dummy_work_tensor.data(), dummy_barrier_tensor.data()); - - dbeta_part_shape = std::vector{0, 0}; + xgrad_tensor.data(), wgrad_tensor.data(), dummy_work_tensor.data(), num_sm, + zero_centered_gamma, nullptr); } auto work_shape = MakeShapeVector(dummy_work_tensor.shape()); - auto barrier_shape = MakeShapeVector(dummy_barrier_tensor.shape()); - auto dgamma_part_shape = MakeShapeVector(dummy_dgamma_part_tensor.shape()); - return pybind11::make_tuple(std::make_pair(work_shape, dummy_work_tensor.dtype()), - std::make_pair(barrier_shape, dummy_barrier_tensor.dtype()), - std::make_pair(dgamma_part_shape, dummy_dgamma_part_tensor.dtype()), - std::make_pair(dbeta_part_shape, dummy_dbeta_part_tensor.dtype())); + return pybind11::make_tuple(std::make_pair(work_shape, dummy_work_tensor.dtype())); } void LayerNormBackwardImpl(size_t batch_size, size_t hidden_size, size_t wkspace_size, - size_t barrier_size, Shape dgamma_part_shape, Shape dbeta_part_shape, bool zero_centered_gamma, float eps, void *input, DType in_dtype, void *weight, DType w_dtype, void *ograd, void *workspace, - DType wkspace_dtype, void *barrier, DType barrier_dtype, void *mu, - void *rsigma, void *xgrad, void *wgrad, void *dbeta, void *dgamma_part, - DType dgamma_dtype, void *dbeta_part, DType dbeta_dtype, int sm_margin, - cudaStream_t stream) { + DType wkspace_dtype, void *mu, void *rsigma, void *xgrad, void *wgrad, + void *dbeta, int sm_margin, cudaStream_t stream) { auto input_shape = std::vector{batch_size, hidden_size}; auto weight_shape = std::vector{hidden_size}; auto intermediates_shape = std::vector{batch_size}; @@ -368,28 +335,23 @@ void LayerNormBackwardImpl(size_t batch_size, size_t hidden_size, size_t wkspace auto wgrad_tensor = TensorWrapper(wgrad, weight_shape, w_dtype); auto num_sm = 
cudaDevicePropertiesManager::Instance().GetMultiProcessorCount() - sm_margin; - auto layernorm_bwd_func = zero_centered_gamma ? nvte_layernorm1p_bwd : nvte_layernorm_bwd; auto workspace_shape = std::vector{wkspace_size}; auto workspace_tensor = TensorWrapper(workspace, workspace_shape, wkspace_dtype); - auto barrier_shape = std::vector{barrier_size}; - auto barrier_tensor = TensorWrapper(barrier, barrier_shape, barrier_dtype); - auto dgamma_part_tensor = TensorWrapper(dgamma_part, dgamma_part_shape.to_vector(), dgamma_dtype); if (is_layer_norm) { auto mu_tensor = TensorWrapper(mu, intermediates_shape, intermediates_dtype); auto dbeta_tensor = TensorWrapper(dbeta, weight_shape, w_dtype); - auto dbeta_part_tensor = TensorWrapper(dbeta_part, dbeta_part_shape.to_vector(), dbeta_dtype); - layernorm_bwd_func(dz_tensor.data(), x_tensor.data(), mu_tensor.data(), rsigma_tensor.data(), + nvte_layernorm_bwd(dz_tensor.data(), x_tensor.data(), mu_tensor.data(), rsigma_tensor.data(), gamma_tensor.data(), xgrad_tensor.data(), wgrad_tensor.data(), - dbeta_tensor.data(), dgamma_part_tensor.data(), dbeta_part_tensor.data(), - stream, num_sm, workspace_tensor.data(), barrier_tensor.data()); + dbeta_tensor.data(), workspace_tensor.data(), num_sm, zero_centered_gamma, + stream); } else { NVTE_CHECK(!zero_centered_gamma, "rmsnorm doesn't support zero_centered_gamma."); nvte_rmsnorm_bwd(dz_tensor.data(), x_tensor.data(), rsigma_tensor.data(), gamma_tensor.data(), - xgrad_tensor.data(), wgrad_tensor.data(), dgamma_part_tensor.data(), stream, - num_sm, workspace_tensor.data(), barrier_tensor.data()); + xgrad_tensor.data(), wgrad_tensor.data(), workspace_tensor.data(), num_sm, + zero_centered_gamma, stream); } } @@ -397,15 +359,11 @@ Error_Type LayerNormBackwardImplFFI(cudaStream_t stream, Buffer_Type *dz_buf, Bu Buffer_Type *mu_buf, Buffer_Type *rsigma_buf, Buffer_Type *gamma_buf, Result_Type *xgrad_buf, Result_Type *wgrad_buf, Result_Type *dbeta_buf, - Result_Type *wkspace_buf, 
Result_Type *barrier_buf, - Result_Type *dgamma_part_buf, Result_Type *dbeta_part_buf, - bool zero_centered_gamma, double eps_, int64_t sm_margin_, - bool is_layer_norm) { + Result_Type *wkspace_buf, bool zero_centered_gamma, double eps_, + int64_t sm_margin_, bool is_layer_norm) { auto in_dtype = convert_ffi_datatype_to_te_dtype(x_buf->element_type()); auto w_dtype = convert_ffi_datatype_to_te_dtype(gamma_buf->element_type()); auto wkspace_dtype = convert_ffi_datatype_to_te_dtype((*wkspace_buf)->element_type()); - auto barrier_dtype = convert_ffi_datatype_to_te_dtype((*barrier_buf)->element_type()); - auto dgamma_part_dtype = convert_ffi_datatype_to_te_dtype((*dgamma_part_buf)->element_type()); auto *ograd = dz_buf->untyped_data(); auto *rsigma = rsigma_buf->untyped_data(); @@ -414,62 +372,37 @@ Error_Type LayerNormBackwardImplFFI(cudaStream_t stream, Buffer_Type *dz_buf, Bu auto *xgrad = (*xgrad_buf)->untyped_data(); auto *wgrad = (*wgrad_buf)->untyped_data(); auto *workspace = (*wkspace_buf)->untyped_data(); - auto *barrier = (*barrier_buf)->untyped_data(); - auto *dgamma_part = (*dgamma_part_buf)->untyped_data(); void *mu = nullptr; void *dbeta = nullptr; - void *dbeta_part = nullptr; - auto dbeta_part_dtype = DType::kByte; if (is_layer_norm) { mu = (*mu_buf).untyped_data(); dbeta = (*dbeta_buf)->untyped_data(); - dbeta_part = (*dbeta_part_buf)->untyped_data(); - dbeta_part_dtype = convert_ffi_datatype_to_te_dtype((*dbeta_part_buf)->element_type()); } auto x_size = product(x_buf->dimensions()); auto gamma_size = product(gamma_buf->dimensions()); auto wkspace_size = product((*wkspace_buf)->dimensions()); - auto barrier_size = product((*barrier_buf)->dimensions()); auto hidden_size = gamma_size; auto batch_size = x_size / gamma_size; - Shape dgamma_part_shape; - auto dgamma_part_dims = (*dgamma_part_buf)->dimensions(); - std::vector dgamma_parts_dims_vector(dgamma_part_dims.begin(), dgamma_part_dims.end()); - 
dgamma_part_shape.from_vector(dgamma_parts_dims_vector); - - Shape dbeta_part_shape; - if (is_layer_norm) { - auto dbeta_part_dims = (*dbeta_part_buf)->dimensions(); - std::vector dbeta_parts_dims_vector(dbeta_part_dims.begin(), dbeta_part_dims.end()); - dbeta_part_shape.from_vector(dbeta_parts_dims_vector); - } else { - dbeta_part_shape.from_vector({0, 0}); - } - float eps = static_cast(eps_); int sm_margin = static_cast(sm_margin_); - LayerNormBackwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, dgamma_part_shape, - dbeta_part_shape, zero_centered_gamma, eps, input, in_dtype, weight, - w_dtype, ograd, workspace, wkspace_dtype, barrier, barrier_dtype, mu, - rsigma, xgrad, wgrad, dbeta, dgamma_part, dgamma_part_dtype, dbeta_part, - dbeta_part_dtype, sm_margin, stream); + LayerNormBackwardImpl(batch_size, hidden_size, wkspace_size, zero_centered_gamma, eps, input, + in_dtype, weight, w_dtype, ograd, workspace, wkspace_dtype, mu, rsigma, + xgrad, wgrad, dbeta, sm_margin, stream); return ffi_with_cuda_error_check(); } Error_Type LayerNormBackwardFFI(cudaStream_t stream, Buffer_Type dz_buf, Buffer_Type x_buf, Buffer_Type mu_buf, Buffer_Type rsigma_buf, Buffer_Type gamma_buf, Result_Type xgrad_buf, Result_Type wgrad_buf, Result_Type dbeta_buf, - Result_Type wkspace_buf, Result_Type barrier_buf, - Result_Type dgamma_part_buf, Result_Type dbeta_part_buf, - bool zero_centered_gamma, double eps_, int64_t sm_margin_) { + Result_Type wkspace_buf, bool zero_centered_gamma, double eps_, + int64_t sm_margin_) { return LayerNormBackwardImplFFI(stream, &dz_buf, &x_buf, &mu_buf, &rsigma_buf, &gamma_buf, - &xgrad_buf, &wgrad_buf, &dbeta_buf, &wkspace_buf, &barrier_buf, - &dgamma_part_buf, &dbeta_part_buf, zero_centered_gamma, eps_, - sm_margin_, + &xgrad_buf, &wgrad_buf, &dbeta_buf, &wkspace_buf, + zero_centered_gamma, eps_, sm_margin_, true // is_layer_norm ); } @@ -486,9 +419,6 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(LayerNormBackwardHandler, LayerNormBackwardFFI, .Ret() // 
wgrad .Ret() // dbeta .Ret() // wkspace - .Ret() // barrier - .Ret() // dgamma_part - .Ret() // dbeta_part .Attr("zero_centered_gamma") .Attr("eps") .Attr("sm_margin"), @@ -497,15 +427,12 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(LayerNormBackwardHandler, LayerNormBackwardFFI, Error_Type RMSNormBackwardFFI(cudaStream_t stream, Buffer_Type dz_buf, Buffer_Type x_buf, Buffer_Type rsigma_buf, Buffer_Type gamma_buf, Result_Type xgrad_buf, Result_Type wgrad_buf, Result_Type wkspace_buf, - Result_Type barrier_buf, Result_Type dgamma_part_buf, bool zero_centered_gamma, double eps_, int64_t sm_margin_) { return LayerNormBackwardImplFFI(stream, &dz_buf, &x_buf, nullptr, // mu_buf &rsigma_buf, &gamma_buf, &xgrad_buf, &wgrad_buf, nullptr, // dbeta_buf, - &wkspace_buf, &barrier_buf, &dgamma_part_buf, - nullptr, // dbeta_part_buf, - zero_centered_gamma, eps_, sm_margin_, + &wkspace_buf, zero_centered_gamma, eps_, sm_margin_, false // is_layer_norm ); } @@ -520,8 +447,6 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(RMSNormBackwardHandler, RMSNormBackwardFFI, .Ret() // xgrad .Ret() // wgrad .Ret() // wkspace - .Ret() // barrier - .Ret() // dgamma_part .Attr("zero_centered_gamma") .Attr("eps") .Attr("sm_margin"), @@ -540,7 +465,6 @@ void LayerNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque auto *rsigma = buffers[8]; auto *amax_out = buffers[9]; auto *workspace = buffers[10]; - auto *barrier = buffers[11]; NVTE_CHECK(amax_out == amax, "amax not bound to amax_out in TE/JAX LayerNormForwardFP8 primitive"); @@ -548,21 +472,18 @@ void LayerNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque auto batch_size = desc.batch_size; auto hidden_size = desc.hidden_size; auto wkspace_size = desc.wkspace_size; - auto barrier_size = desc.barrier_size; auto in_dtype = desc.x_dtype; auto w_dtype = desc.w_dtype; auto wkspace_dtype = desc.wkspace_dtype; - auto barrier_dtype = desc.barrier_dtype; auto eps = desc.eps; auto zero_centered_gamma = desc.zero_centered_gamma; auto sm_margin 
= desc.sm_margin; auto out_dtype = DType::kFloat8E4M3; - LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, zero_centered_gamma, - eps, input, in_dtype, weight, w_dtype, bias, output, out_dtype, workspace, - wkspace_dtype, barrier, barrier_dtype, mu, rsigma, amax, scale, scale_inv, - sm_margin, stream); + LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, zero_centered_gamma, eps, input, + in_dtype, weight, w_dtype, bias, output, out_dtype, workspace, wkspace_dtype, + mu, rsigma, amax, scale, scale_inv, sm_margin, stream); } void LayerNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { @@ -573,7 +494,6 @@ void LayerNormForward(cudaStream_t stream, void **buffers, const char *opaque, s auto *mu = buffers[4]; auto *rsigma = buffers[5]; auto *workspace = buffers[6]; - auto *barrier = buffers[7]; float *amax = nullptr; float *scale = nullptr; @@ -583,20 +503,17 @@ void LayerNormForward(cudaStream_t stream, void **buffers, const char *opaque, s auto batch_size = desc.batch_size; auto hidden_size = desc.hidden_size; auto wkspace_size = desc.wkspace_size; - auto barrier_size = desc.barrier_size; auto in_dtype = desc.x_dtype; auto w_dtype = desc.w_dtype; auto wkspace_dtype = desc.wkspace_dtype; - auto barrier_dtype = desc.barrier_dtype; auto eps = desc.eps; auto out_dtype = in_dtype; auto zero_centered_gamma = desc.zero_centered_gamma; auto sm_margin = desc.sm_margin; - LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, zero_centered_gamma, - eps, input, in_dtype, weight, w_dtype, bias, output, out_dtype, workspace, - wkspace_dtype, barrier, barrier_dtype, mu, rsigma, amax, scale, scale_inv, - sm_margin, stream); + LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, zero_centered_gamma, eps, input, + in_dtype, weight, w_dtype, bias, output, out_dtype, workspace, wkspace_dtype, + mu, rsigma, amax, scale, scale_inv, sm_margin, stream); } void LayerNormBackward(cudaStream_t 
stream, void **buffers, const char *opaque, size_t opaque_len) { @@ -605,15 +522,9 @@ void LayerNormBackward(cudaStream_t stream, void **buffers, const char *opaque, auto batch_size = desc.batch_size; auto hidden_size = desc.hidden_size; auto wkspace_size = desc.wkspace_size; - auto barrier_size = desc.barrier_size; - auto dgamma_part_shape = desc.dgamma_part_shape; - auto dbeta_part_shape = desc.dbeta_part_shape; auto in_dtype = desc.x_dtype; auto w_dtype = desc.w_dtype; auto wkspace_dtype = desc.wkspace_dtype; - auto barrier_dtype = desc.barrier_dtype; - auto dgamma_part_dtype = desc.dgamma_part_dtype; - auto dbeta_part_dtype = desc.dbeta_part_dtype; auto eps = desc.eps; auto zero_centered_gamma = desc.zero_centered_gamma; auto sm_margin = desc.sm_margin; @@ -627,15 +538,10 @@ void LayerNormBackward(cudaStream_t stream, void **buffers, const char *opaque, auto *wgrad = buffers[6]; auto *dbeta = buffers[7]; auto *workspace = buffers[8]; - auto *barrier = buffers[9]; - auto *dgamma_part = buffers[10]; - auto *dbeta_part = buffers[11]; - - LayerNormBackwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, dgamma_part_shape, - dbeta_part_shape, zero_centered_gamma, eps, input, in_dtype, weight, - w_dtype, ograd, workspace, wkspace_dtype, barrier, barrier_dtype, mu, - rsigma, xgrad, wgrad, dbeta, dgamma_part, dgamma_part_dtype, dbeta_part, - dbeta_part_dtype, sm_margin, stream); + + LayerNormBackwardImpl(batch_size, hidden_size, wkspace_size, zero_centered_gamma, eps, input, + in_dtype, weight, w_dtype, ograd, workspace, wkspace_dtype, mu, rsigma, + xgrad, wgrad, dbeta, sm_margin, stream); } void RMSNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { @@ -648,7 +554,6 @@ void RMSNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque, auto *rsigma = buffers[6]; auto *amax_out = buffers[7]; auto *workspace = buffers[8]; - auto *barrier = buffers[9]; NVTE_CHECK(amax_out == amax, "amax not bound to amax_out 
in TE/JAX RSMNormForwardFP8 primitive."); void *bias = nullptr; @@ -658,20 +563,17 @@ void RMSNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque, auto batch_size = desc.batch_size; auto hidden_size = desc.hidden_size; auto wkspace_size = desc.wkspace_size; - auto barrier_size = desc.barrier_size; auto in_dtype = desc.x_dtype; auto w_dtype = desc.w_dtype; auto wkspace_dtype = desc.wkspace_dtype; - auto barrier_dtype = desc.barrier_dtype; auto eps = desc.eps; auto zero_centered_gamma = desc.zero_centered_gamma; auto sm_margin = desc.sm_margin; auto out_dtype = DType::kFloat8E4M3; - LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, zero_centered_gamma, - eps, input, in_dtype, weight, w_dtype, bias, output, out_dtype, workspace, - wkspace_dtype, barrier, barrier_dtype, mu, rsigma, amax, scale, scale_inv, - sm_margin, stream); + LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, zero_centered_gamma, eps, input, + in_dtype, weight, w_dtype, bias, output, out_dtype, workspace, wkspace_dtype, + mu, rsigma, amax, scale, scale_inv, sm_margin, stream); } void RMSNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { @@ -680,7 +582,6 @@ void RMSNormForward(cudaStream_t stream, void **buffers, const char *opaque, siz auto *output = buffers[2]; auto *rsigma = buffers[3]; auto *workspace = buffers[4]; - auto *barrier = buffers[5]; void *bias = nullptr; void *mu = nullptr; @@ -692,20 +593,17 @@ void RMSNormForward(cudaStream_t stream, void **buffers, const char *opaque, siz auto batch_size = desc.batch_size; auto hidden_size = desc.hidden_size; auto wkspace_size = desc.wkspace_size; - auto barrier_size = desc.barrier_size; auto in_dtype = desc.x_dtype; auto w_dtype = desc.w_dtype; auto wkspace_dtype = desc.wkspace_dtype; - auto barrier_dtype = desc.barrier_dtype; auto eps = desc.eps; auto zero_centered_gamma = desc.zero_centered_gamma; auto sm_margin = desc.sm_margin; auto out_dtype = in_dtype; 
- LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, zero_centered_gamma, - eps, input, in_dtype, weight, w_dtype, bias, output, out_dtype, workspace, - wkspace_dtype, barrier, barrier_dtype, mu, rsigma, amax, scale, scale_inv, - sm_margin, stream); + LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, zero_centered_gamma, eps, input, + in_dtype, weight, w_dtype, bias, output, out_dtype, workspace, wkspace_dtype, + mu, rsigma, amax, scale, scale_inv, sm_margin, stream); } void RMSNormBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { @@ -716,36 +614,24 @@ void RMSNormBackward(cudaStream_t stream, void **buffers, const char *opaque, si auto *xgrad = buffers[4]; auto *wgrad = buffers[5]; auto *workspace = buffers[6]; - auto *barrier = buffers[7]; - auto *dgamma_part = buffers[8]; void *mu = nullptr; void *dbeta = nullptr; - void *dbeta_part = nullptr; const auto &desc = *UnpackOpaque(opaque, opaque_len); auto batch_size = desc.batch_size; auto hidden_size = desc.hidden_size; auto wkspace_size = desc.wkspace_size; - auto barrier_size = desc.barrier_size; - auto dgamma_part_shape = desc.dgamma_part_shape; - Shape dbeta_part_shape; - dbeta_part_shape.from_vector({0, 0}); auto in_dtype = desc.x_dtype; auto w_dtype = desc.w_dtype; auto wkspace_dtype = desc.wkspace_dtype; - auto barrier_dtype = desc.barrier_dtype; - auto dgamma_part_dtype = desc.dgamma_part_dtype; - auto dbeta_part_dtype = DType::kByte; auto eps = desc.eps; auto zero_centered_gamma = desc.zero_centered_gamma; auto sm_margin = desc.sm_margin; - LayerNormBackwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, dgamma_part_shape, - dbeta_part_shape, zero_centered_gamma, eps, input, in_dtype, weight, - w_dtype, ograd, workspace, wkspace_dtype, barrier, barrier_dtype, mu, - rsigma, xgrad, wgrad, dbeta, dgamma_part, dgamma_part_dtype, dbeta_part, - dbeta_part_dtype, sm_margin, stream); + LayerNormBackwardImpl(batch_size, hidden_size, 
wkspace_size, zero_centered_gamma, eps, input, + in_dtype, weight, w_dtype, ograd, workspace, wkspace_dtype, mu, rsigma, + xgrad, wgrad, dbeta, sm_margin, stream); } } // namespace jax diff --git a/transformer_engine/jax/csrc/extensions/packing.cpp b/transformer_engine/jax/csrc/extensions/packing.cpp index 298478603b..ccc6921f43 100644 --- a/transformer_engine/jax/csrc/extensions/packing.cpp +++ b/transformer_engine/jax/csrc/extensions/packing.cpp @@ -32,24 +32,17 @@ pybind11::bytes PackCustomCallCommonWkDescriptor(const std::vector &shap return PackOpaque(desc); } -pybind11::bytes PackCustomCallNormDescriptor( - size_t batch_size, size_t hidden_size, size_t wkspace_size, size_t barrier_size, - const std::vector &dgamma_part_shape, const std::vector &dbeta_part_shape, - DType x_dtype, DType w_dtype, DType wkspace_dtype, DType barrier_dtype, DType dgamma_part_dtype, - DType dbeta_part_dtype, bool zero_centered_gamma, float eps, int sm_margin) { +pybind11::bytes PackCustomCallNormDescriptor(size_t batch_size, size_t hidden_size, + size_t wkspace_size, DType x_dtype, DType w_dtype, + DType wkspace_dtype, bool zero_centered_gamma, + float eps, int sm_margin) { CustomCallNormDescriptor desc{}; desc.batch_size = batch_size; desc.hidden_size = hidden_size; desc.wkspace_size = wkspace_size; - desc.barrier_size = barrier_size; - desc.dgamma_part_shape.from_vector(dgamma_part_shape); - desc.dbeta_part_shape.from_vector(dbeta_part_shape); desc.x_dtype = x_dtype; desc.w_dtype = w_dtype; desc.wkspace_dtype = wkspace_dtype; - desc.barrier_dtype = barrier_dtype; - desc.dgamma_part_dtype = dgamma_part_dtype; - desc.dbeta_part_dtype = dbeta_part_dtype; desc.zero_centered_gamma = zero_centered_gamma; desc.eps = eps; desc.sm_margin = sm_margin; diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp index bf906f375e..a319b74d76 100644 --- a/transformer_engine/jax/csrc/extensions/pybind.cpp +++ 
b/transformer_engine/jax/csrc/extensions/pybind.cpp @@ -68,14 +68,39 @@ pybind11::dict Registrations() { // Quantization dict["te_quantize_ffi"] = EncapsulateFFI(QuantizeHandler); + dict["te_dequantize_ffi"] = EncapsulateFFI(DequantizeHandler); + + // Softmax + dict["te_scaled_softmax_forward_ffi"] = EncapsulateFunction(ScaledSoftmaxForwardHandler); + dict["te_scaled_softmax_backward_ffi"] = EncapsulateFunction(ScaledSoftmaxBackwardHandler); + dict["te_scaled_masked_softmax_forward_ffi"] = + EncapsulateFunction(ScaledMaskedSoftmaxForwardHandler); + dict["te_scaled_masked_softmax_backward_ffi"] = + EncapsulateFunction(ScaledMaskedSoftmaxBackwardHandler); + dict["te_scaled_upper_triang_masked_softmax_forward_ffi"] = + EncapsulateFunction(ScaledUpperTriangMaskedSoftmaxForwardHandler); + dict["te_scaled_upper_triang_masked_softmax_backward_ffi"] = + EncapsulateFunction(ScaledUpperTriangMaskedSoftmaxBackwardHandler); // Normalization - dict["te_layernorm_forward_ffi"] = EncapsulateFFI(LayerNormForwardHandler); - dict["te_layernorm_forward_fp8_ffi"] = EncapsulateFFI(LayerNormForwardFP8Handler); - dict["te_layernorm_backward_ffi"] = EncapsulateFFI(LayerNormBackwardHandler); - dict["te_rmsnorm_forward_ffi"] = EncapsulateFunction(RMSNormForwardHandler); - dict["te_rmsnorm_forward_fp8_ffi"] = EncapsulateFunction(RMSNormForwardFP8Handler); - dict["te_rmsnorm_backward_ffi"] = EncapsulateFunction(RMSNormBackwardHandler); + dict["te_layernorm_forward_ffi"] = + pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler), + pybind11::arg("execute") = EncapsulateFFI(LayerNormForwardHandler)); + dict["te_layernorm_forward_fp8_ffi"] = + pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler), + pybind11::arg("execute") = EncapsulateFFI(LayerNormForwardFP8Handler)); + dict["te_layernorm_backward_ffi"] = + pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler), + pybind11::arg("execute") = 
EncapsulateFFI(LayerNormBackwardHandler)); + dict["te_rmsnorm_forward_ffi"] = + pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler), + pybind11::arg("execute") = EncapsulateFFI(RMSNormForwardHandler)); + dict["te_rmsnorm_forward_fp8_ffi"] = + pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler), + pybind11::arg("execute") = EncapsulateFFI(RMSNormForwardFP8Handler)); + dict["te_rmsnorm_backward_ffi"] = + pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler), + pybind11::arg("execute") = EncapsulateFFI(RMSNormBackwardHandler)); // Attention pybind11::dict fused_attn_forward_ffi; @@ -83,6 +108,11 @@ pybind11::dict Registrations() { fused_attn_forward_ffi["execute"] = EncapsulateFFI(FusedAttnForwardHandler); dict["te_fused_attn_forward_ffi"] = fused_attn_forward_ffi; + pybind11::dict fused_attn_backward_ffi; + fused_attn_backward_ffi["prepare"] = EncapsulateFFI(CudnnHandleInitHandler); + fused_attn_backward_ffi["execute"] = EncapsulateFFI(FusedAttnBackwardHandler); + dict["te_fused_attn_backward_ffi"] = fused_attn_backward_ffi; + return dict; } diff --git a/transformer_engine/jax/csrc/extensions/quantization.cpp b/transformer_engine/jax/csrc/extensions/quantization.cpp index 5e33098eab..d08368657e 100644 --- a/transformer_engine/jax/csrc/extensions/quantization.cpp +++ b/transformer_engine/jax/csrc/extensions/quantization.cpp @@ -74,11 +74,41 @@ void Dequantize(cudaStream_t stream, void **buffers, const char *opaque, size_t auto shape = desc.shape.to_vector(); auto input_tensor = TensorWrapper(input, shape, desc.in_dtype, amax, scale, scale_inv); - auto output_tensor = TensorWrapper(output, shape, desc.out_dtype); nvte_fp8_dequantize(input_tensor.data(), output_tensor.data(), stream); } +Error_Type DequantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type amax_buf, + Buffer_Type scale_buf, Buffer_Type scale_inv_buf, Result_Type output_buf) { + auto in_dtype = 
convert_ffi_datatype_to_te_dtype(input_buf.element_type()); + auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type()); + + auto *input = input_buf.untyped_data(); + auto *amax = reinterpret_cast(amax_buf.untyped_data()); + auto *scale = reinterpret_cast(scale_buf.untyped_data()); + auto *scale_inv = reinterpret_cast(scale_inv_buf.untyped_data()); + + auto *output = output_buf->untyped_data(); + + auto input_dims = input_buf.dimensions(); + std::vector shape(input_dims.begin(), input_dims.end()); + auto input_tensor = TensorWrapper(input, shape, in_dtype, amax, scale, scale_inv); + auto output_tensor = TensorWrapper(output, shape, out_dtype); + + nvte_fp8_dequantize(input_tensor.data(), output_tensor.data(), stream); + return ffi_with_cuda_error_check(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(DequantizeHandler, DequantizeFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // input + .Arg() // amax + .Arg() // scale + .Arg() // scale_inv + .Ret(), // output + FFI_CudaGraph_Traits); + } // namespace jax } // namespace transformer_engine diff --git a/transformer_engine/jax/csrc/extensions/softmax.cpp b/transformer_engine/jax/csrc/extensions/softmax.cpp index 3af32d1d84..f54ebefcb0 100644 --- a/transformer_engine/jax/csrc/extensions/softmax.cpp +++ b/transformer_engine/jax/csrc/extensions/softmax.cpp @@ -7,6 +7,7 @@ #include "transformer_engine/softmax.h" #include "extensions.h" +#include "xla/ffi/api/c_api.h" namespace transformer_engine { namespace jax { @@ -108,5 +109,146 @@ void ScaledUpperTriangMaskedSoftmaxBackward(cudaStream_t stream, void **buffers, dgrad_tensor.data(), desc.scale_factor, stream); } +#define SOFTMAX_COMMON_BLOCK(tensor_buf) \ + auto dtype = convert_ffi_datatype_to_te_dtype((tensor_buf).element_type()); \ + auto tensor_dims = (tensor_buf).dimensions(); \ + auto tensor_ranks = tensor_dims.size(); \ + auto batch_size = product(tensor_dims, 0, tensor_ranks - 3); \ + auto head_dim = product(tensor_dims, tensor_ranks - 3, tensor_ranks - 
2); \ + auto q_seqlen = product(tensor_dims, tensor_ranks - 2, tensor_ranks - 1); \ + auto k_seqlen = product(tensor_dims, tensor_ranks - 1, tensor_ranks); \ + float scale_factor = static_cast(scale_factor_); + +#define SOFTMAX_FORWARD_COMMON_BLOCK \ + auto *input = input_buf.untyped_data(); \ + auto *output = output_buf->untyped_data(); \ + auto input_tensor = TensorWrapper(input, shape, dtype); \ + auto output_tensor = TensorWrapper(output, shape, dtype); + +Error_Type ScaledSoftmaxForwardFFI(cudaStream_t stream, Buffer_Type input_buf, + Result_Type output_buf, double scale_factor_) { + SOFTMAX_COMMON_BLOCK(input_buf); + auto shape = std::vector{batch_size, head_dim, q_seqlen, k_seqlen}; + SOFTMAX_FORWARD_COMMON_BLOCK; + nvte_scaled_softmax_forward(input_tensor.data(), output_tensor.data(), scale_factor, stream); + return ffi_with_cuda_error_check(); +} + +Error_Type ScaledMaskedSoftmaxForwardFFI(cudaStream_t stream, Buffer_Type input_buf, + Buffer_Type mask_buf, Result_Type output_buf, + double scale_factor_) { + SOFTMAX_COMMON_BLOCK(input_buf); + + // Mask would be casted to uint8_t + auto *mask = mask_buf.untyped_data(); + auto mask_dims = mask_buf.dimensions(); + auto padding_size = product(mask_dims, mask_dims.size() - 3); + auto mask_shape = std::vector{padding_size, 1, q_seqlen, k_seqlen}; + auto mask_tensor = TensorWrapper(mask, mask_shape, DType::kByte); + + auto shape = std::vector{batch_size, head_dim, q_seqlen, k_seqlen}; + SOFTMAX_FORWARD_COMMON_BLOCK; + nvte_scaled_masked_softmax_forward(input_tensor.data(), mask_tensor.data(), output_tensor.data(), + scale_factor, stream); + return ffi_with_cuda_error_check(); +} + +Error_Type ScaledUpperTriangMaskedSoftmaxForwardFFI(cudaStream_t stream, Buffer_Type input_buf, + Result_Type output_buf, double scale_factor_) { + SOFTMAX_COMMON_BLOCK(input_buf); + auto shape = std::vector{batch_size * head_dim, q_seqlen, k_seqlen}; + SOFTMAX_FORWARD_COMMON_BLOCK; + 
nvte_scaled_upper_triang_masked_softmax_forward(input_tensor.data(), output_tensor.data(), + scale_factor, stream); + return ffi_with_cuda_error_check(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(ScaledSoftmaxForwardHandler, ScaledSoftmaxForwardFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // input + .Ret() // output + .Attr("scale_factor"), + FFI_CudaGraph_Traits); + +XLA_FFI_DEFINE_HANDLER_SYMBOL(ScaledMaskedSoftmaxForwardHandler, ScaledMaskedSoftmaxForwardFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // input + .Arg() // mask + .Ret() // output + .Attr("scale_factor"), + FFI_CudaGraph_Traits); + +XLA_FFI_DEFINE_HANDLER_SYMBOL(ScaledUpperTriangMaskedSoftmaxForwardHandler, + ScaledUpperTriangMaskedSoftmaxForwardFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // input + .Ret() // output + .Attr("scale_factor"), + FFI_CudaGraph_Traits); + +#define SOFTMAX_BACKWARD_COMMON_BLOCK \ + auto *grad_output = grad_output_buf.untyped_data(); \ + auto *softmax_output = softmax_output_buf.untyped_data(); \ + auto *dgrad = dgrad_buf->untyped_data(); \ + auto grad_output_tensor = TensorWrapper(grad_output, shape, dtype); \ + auto softmax_output_tensor = TensorWrapper(softmax_output, shape, dtype); \ + auto dgrad_tensor = TensorWrapper(dgrad, shape, dtype); + +Error_Type ScaledSoftmaxBackwardFFI(cudaStream_t stream, Buffer_Type grad_output_buf, + Buffer_Type softmax_output_buf, Result_Type dgrad_buf, + double scale_factor_) { + SOFTMAX_COMMON_BLOCK(grad_output_buf); + auto shape = std::vector{batch_size, head_dim, q_seqlen, k_seqlen}; + SOFTMAX_BACKWARD_COMMON_BLOCK; + nvte_scaled_softmax_backward(grad_output_tensor.data(), softmax_output_tensor.data(), + dgrad_tensor.data(), scale_factor, stream); + return ffi_with_cuda_error_check(); +} + +Error_Type ScaledUpperTriangMaskedSoftmaxBackwardFFI(cudaStream_t stream, + Buffer_Type grad_output_buf, + Buffer_Type softmax_output_buf, + Result_Type dgrad_buf, double scale_factor_) { + SOFTMAX_COMMON_BLOCK(grad_output_buf); + auto shape = 
std::vector{batch_size * head_dim, q_seqlen, k_seqlen}; + SOFTMAX_BACKWARD_COMMON_BLOCK; + nvte_scaled_upper_triang_masked_softmax_backward(grad_output_tensor.data(), + softmax_output_tensor.data(), + dgrad_tensor.data(), scale_factor, stream); + return ffi_with_cuda_error_check(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(ScaledSoftmaxBackwardHandler, ScaledSoftmaxBackwardFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // grad_output + .Arg() // softmax_output + .Ret() // dgrad + .Attr("scale_factor"), + FFI_CudaGraph_Traits); + +// The backward of ScaledMaskedSoftmax is equivalent to ScaledSoftmax +XLA_FFI_DEFINE_HANDLER_SYMBOL(ScaledMaskedSoftmaxBackwardHandler, ScaledSoftmaxBackwardFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // grad_output + .Arg() // softmax_output + .Ret() // dgrad + .Attr("scale_factor"), + FFI_CudaGraph_Traits); + +XLA_FFI_DEFINE_HANDLER_SYMBOL(ScaledUpperTriangMaskedSoftmaxBackwardHandler, + ScaledUpperTriangMaskedSoftmaxBackwardFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // grad_output + .Arg() // softmax_output + .Ret() // dgrad + .Attr("scale_factor"), + FFI_CudaGraph_Traits); + } // namespace jax } // namespace transformer_engine diff --git a/transformer_engine/jax/praxis/module.py b/transformer_engine/jax/praxis/module.py index b82c0915e4..e5649bfe7c 100644 --- a/transformer_engine/jax/praxis/module.py +++ b/transformer_engine/jax/praxis/module.py @@ -4,6 +4,7 @@ """ Praxis Modules """ +from dataclasses import field from functools import partial from typing import Callable, Iterable, Sequence, Tuple, Union @@ -74,7 +75,9 @@ class LayerNorm(TransformerEngineBaseLayer): zero_centered_gamma: bool = False scale_init: WeightInit = None scale_axes: Tuple[str, ...] = () - bias_init: WeightInit = WeightInit.Constant(0.0) + bias_init: WeightInit = field( # pylint: disable=invalid-field-call + default_factory=partial(WeightInit.Constant, scale=0.0) + ) bias_axes: Tuple[str, ...] 
= () transpose_batch_sequence: bool = False @@ -129,7 +132,9 @@ class Linear(TransformerEngineBaseLayer): out_features: int = 512 kernel_axes: Tuple[str, ...] = () use_bias: bool = True - bias_init: WeightInit = WeightInit.Constant(0.0) + bias_init: WeightInit = field( # pylint: disable=invalid-field-call + default_factory=partial(WeightInit.Constant, scale=0.0) + ) bias_axes: Tuple[str, ...] = () enable_low_rank_adaptation: bool = False low_rank_adaptation_dim: int = 32 @@ -174,11 +179,15 @@ class LayerNormLinear(TransformerEngineBaseLayer): zero_centered_gamma: bool = False scale_init: WeightInit = None scale_axes: Tuple[str, ...] = () - ln_bias_init: WeightInit = WeightInit.Constant(1.0) + ln_bias_init: WeightInit = field( # pylint: disable=invalid-field-call + default_factory=partial(WeightInit.Constant, scale=1.0) + ) ln_bias_axes: Tuple[str, ...] = () kernel_axes: Tuple[str, ...] = () use_bias: bool = False - bias_init: WeightInit = WeightInit.Constant(0.0) + bias_init: WeightInit = field( # pylint: disable=invalid-field-call + default_factory=partial(WeightInit.Constant, scale=0.0) + ) bias_axes: Tuple[str, ...] = () enable_low_rank_adaptation: bool = False low_rank_adaptation_dim: int = 32 @@ -237,12 +246,16 @@ class LayerNormMLP(TransformerEngineBaseLayer): zero_centered_gamma: bool = False scale_init: WeightInit = None scale_axes: Tuple[str, ...] = () - ln_bias_init: WeightInit = WeightInit.Constant(1.0) + ln_bias_init: WeightInit = field( # pylint: disable=invalid-field-call + default_factory=partial(WeightInit.Constant, scale=1.0) + ) ln_bias_axes: Tuple[str, ...] = () kernel_axes_1: Tuple[str, ...] = () kernel_axes_2: Tuple[str, ...] = () use_bias: bool = False - bias_init: WeightInit = WeightInit.Constant(0.0) + bias_init: WeightInit = field( # pylint: disable=invalid-field-call + default_factory=partial(WeightInit.Constant, scale=0.0) + ) bias_axes_1: Tuple[str, ...] = () bias_axes_2: Tuple[str, ...] 
= () enable_low_rank_adaptation: bool = False diff --git a/transformer_engine/jax/praxis/transformer.py b/transformer_engine/jax/praxis/transformer.py index f2ac802f10..2ae212afb9 100644 --- a/transformer_engine/jax/praxis/transformer.py +++ b/transformer_engine/jax/praxis/transformer.py @@ -4,6 +4,7 @@ """ Praxis Modules related Transformer """ +from dataclasses import field from functools import partial from typing import Optional, Sequence, Tuple import warnings @@ -138,7 +139,9 @@ class MultiHeadAttention(TransformerEngineBaseLayer): zero_centered_gamma: bool = False return_layernorm_output: bool = False use_bias: bool = False - bias_init: WeightInit = WeightInit.Constant(0.0) + bias_init: WeightInit = field( # pylint: disable=invalid-field-call + default_factory=partial(WeightInit.Constant, scale=0.0) + ) attn_mask_type: str = "causal" attn_bias_type: Optional[str] = None enable_rotary_pos_emb: bool = False @@ -275,7 +278,9 @@ class TransformerLayer(TransformerEngineBaseLayer): dropout_rng_name: str = "dropout" mlp_activations: Sequence[str] = ("relu",) use_bias: bool = False - bias_init: WeightInit = WeightInit.Constant(0.0) + bias_init: WeightInit = field( # pylint: disable=invalid-field-call + default_factory=partial(WeightInit.Constant, scale=0.0) + ) apply_residual_connection_post_layernorm: bool = False output_layernorm: bool = False float32_attention_logits: bool = False diff --git a/transformer_engine/paddle/csrc/common.h b/transformer_engine/paddle/csrc/common.h index 6ce250432a..9b7e3d767a 100644 --- a/transformer_engine/paddle/csrc/common.h +++ b/transformer_engine/paddle/csrc/common.h @@ -10,9 +10,8 @@ #include #include #include -#include +#include #include -#include #include #include #include diff --git a/transformer_engine/paddle/csrc/custom_ops.cu b/transformer_engine/paddle/csrc/custom_ops.cu index 583cd0f47a..b35b4434db 100644 --- a/transformer_engine/paddle/csrc/custom_ops.cu +++ b/transformer_engine/paddle/csrc/custom_ops.cu @@ -353,24 
+353,23 @@ std::vector te_layernorm_fwd_fp8(const paddle::Tensor &input, const_cast(GetDataPtr(scale, index)), GetDataPtr(scale_inv, index)); auto mu_cu = MakeNvteTensor(mu); auto rsigma_cu = MakeNvteTensor(rsigma); - TensorWrapper workspace, barrier; + TensorWrapper workspace; auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount(); - // This call populates workspace and barrier tensors with the required config - const auto func = zero_centered_gamma ? nvte_layernorm1p_fwd : nvte_layernorm_fwd; - func(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(), mu_cu.data(), - rsigma_cu.data(), input.stream(), num_sm - sm_margin, workspace.data(), barrier.data()); + // This call populates workspace tensor with the required config + nvte_layernorm_fwd(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(), + mu_cu.data(), rsigma_cu.data(), workspace.data(), num_sm - sm_margin, + zero_centered_gamma, input.stream()); - // Fill workspace and barrier + // Fill workspace auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), input.place()); - auto barrier_data = AllocateSpace(barrier.shape(), barrier.dtype(), input.place(), true); workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype()); - barrier = MakeNvteTensor(barrier_data.data(), barrier.shape(), barrier.dtype()); // Actual call to fwd kernel - func(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(), mu_cu.data(), - rsigma_cu.data(), input.stream(), num_sm - sm_margin, workspace.data(), barrier.data()); + nvte_layernorm_fwd(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(), + mu_cu.data(), rsigma_cu.data(), workspace.data(), num_sm - sm_margin, + zero_centered_gamma, input.stream()); return {ln_out, mu, rsigma}; } @@ -394,24 +393,23 @@ std::vector te_layernorm_fwd(const paddle::Tensor &input, auto z_cu = MakeNvteTensor(ln_out.data(), {N, H}, Int2NvteDType(otype)); auto mu_cu = 
MakeNvteTensor(mu); auto rsigma_cu = MakeNvteTensor(rsigma); - TensorWrapper workspace, barrier; + TensorWrapper workspace; auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount(); - // This call populates workspace and barrier tensors with the required config - const auto func = zero_centered_gamma ? nvte_layernorm1p_fwd : nvte_layernorm_fwd; - func(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(), mu_cu.data(), - rsigma_cu.data(), input.stream(), num_sm - sm_margin, workspace.data(), barrier.data()); + // This call populates workspace tensor with the required config + nvte_layernorm_fwd(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(), + mu_cu.data(), rsigma_cu.data(), workspace.data(), num_sm - sm_margin, + zero_centered_gamma, input.stream()); - // Fill workspace and barrier + // Fill workspace auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), input.place()); - auto barrier_data = AllocateSpace(barrier.shape(), barrier.dtype(), input.place(), true); workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype()); - barrier = MakeNvteTensor(barrier_data.data(), barrier.shape(), barrier.dtype()); // Actual call to fwd kernel - func(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(), mu_cu.data(), - rsigma_cu.data(), input.stream(), num_sm - sm_margin, workspace.data(), barrier.data()); + nvte_layernorm_fwd(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(), + mu_cu.data(), rsigma_cu.data(), workspace.data(), num_sm - sm_margin, + zero_centered_gamma, input.stream()); return {ln_out, mu, rsigma}; } @@ -424,7 +422,7 @@ std::vector te_layernorm_bwd(const paddle::Tensor &dz, const pad auto dgamma = paddle::empty_like(gamma, gamma.dtype(), gamma.place()); auto dbeta = paddle::empty_like(gamma, gamma.dtype(), gamma.place()); - TensorWrapper workspace, barrier, dgamma_part, dbeta_part; + TensorWrapper workspace; auto dz_cu = 
MakeNvteTensor(dz); auto x_cu = MakeNvteTensor(x); @@ -438,25 +436,18 @@ std::vector te_layernorm_bwd(const paddle::Tensor &dz, const pad auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount(); // This call populates tensors with the required config. - const auto bwd_fun = zero_centered_gamma ? nvte_layernorm1p_bwd : nvte_layernorm_bwd; - bwd_fun(dz_cu.data(), x_cu.data(), mu_cu.data(), rsigma_cu.data(), gamma_cu.data(), dx_cu.data(), - dgamma_cu.data(), dbeta_cu.data(), dgamma_part.data(), dbeta_part.data(), dz.stream(), - num_sm - sm_margin, workspace.data(), barrier.data()); + nvte_layernorm_bwd(dz_cu.data(), x_cu.data(), mu_cu.data(), rsigma_cu.data(), gamma_cu.data(), + dx_cu.data(), dgamma_cu.data(), dbeta_cu.data(), workspace.data(), + num_sm - sm_margin, zero_centered_gamma, dz.stream()); // Alloc space for Tensors. auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), x.place()); - auto barrier_data = AllocateSpace(barrier.shape(), barrier.dtype(), x.place(), true); - auto dgamma_part_data = AllocateSpace(dgamma_part.shape(), dgamma_part.dtype(), x.place()); - auto dbeta_part_data = AllocateSpace(dbeta_part.shape(), dbeta_part.dtype(), x.place()); workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype()); - barrier = MakeNvteTensor(barrier_data.data(), barrier.shape(), barrier.dtype()); - dgamma_part = MakeNvteTensor(dgamma_part_data.data(), dgamma_part.shape(), dgamma_part.dtype()); - dbeta_part = MakeNvteTensor(dbeta_part_data.data(), dbeta_part.shape(), dbeta_part.dtype()); // Actual call to bwd kernel. 
- bwd_fun(dz_cu.data(), x_cu.data(), mu_cu.data(), rsigma_cu.data(), gamma_cu.data(), dx_cu.data(), - dgamma_cu.data(), dbeta_cu.data(), dgamma_part.data(), dbeta_part.data(), dz.stream(), - num_sm - sm_margin, workspace.data(), barrier.data()); + nvte_layernorm_bwd(dz_cu.data(), x_cu.data(), mu_cu.data(), rsigma_cu.data(), gamma_cu.data(), + dx_cu.data(), dgamma_cu.data(), dbeta_cu.data(), workspace.data(), + num_sm - sm_margin, zero_centered_gamma, dz.stream()); return {dx, dgamma, dbeta}; } @@ -477,24 +468,21 @@ std::vector te_rmsnorm_fwd(const paddle::Tensor &input, auto gamma_cu = MakeNvteTensor(weight); auto z_cu = MakeNvteTensor(ln_out.data(), {N, H}, Int2NvteDType(otype)); auto rsigma_cu = MakeNvteTensor(rsigma); - TensorWrapper workspace, barrier; + TensorWrapper workspace; auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount(); - // This call populates workspace and barrier tensors with the required config - + // This call populates workspace tensor with the required config nvte_rmsnorm_fwd(input_cu.data(), gamma_cu.data(), eps, z_cu.data(), rsigma_cu.data(), - input.stream(), num_sm - sm_margin, workspace.data(), barrier.data()); + workspace.data(), num_sm - sm_margin, zero_centered_gamma, input.stream()); - // Fill workspace and barrier + // Fill workspace auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), input.place()); - auto barrier_data = AllocateSpace(barrier.shape(), barrier.dtype(), input.place(), true); workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype()); - barrier = MakeNvteTensor(barrier_data.data(), barrier.shape(), barrier.dtype()); // Actual call to fwd kernel nvte_rmsnorm_fwd(input_cu.data(), gamma_cu.data(), eps, z_cu.data(), rsigma_cu.data(), - input.stream(), num_sm - sm_margin, workspace.data(), barrier.data()); + workspace.data(), num_sm - sm_margin, zero_centered_gamma, input.stream()); return {ln_out, rsigma}; } @@ -521,23 +509,21 @@ std::vector 
te_rmsnorm_fwd_fp8(const paddle::Tensor &input, ln_out.data(), {N, H}, Int2NvteDType(otype), GetDataPtr(amax, index), const_cast(GetDataPtr(scale, index)), GetDataPtr(scale_inv, index)); auto rsigma_cu = MakeNvteTensor(rsigma); - TensorWrapper workspace, barrier; + TensorWrapper workspace; auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount(); - // This call populates workspace and barrier tensors with the required config + // This call populates workspace tensor with the required config nvte_rmsnorm_fwd(input_cu.data(), gamma_cu.data(), eps, z_cu.data(), rsigma_cu.data(), - input.stream(), num_sm - sm_margin, workspace.data(), barrier.data()); + workspace.data(), num_sm - sm_margin, zero_centered_gamma, input.stream()); - // Fill workspace and barrier + // Fill workspace auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), input.place()); - auto barrier_data = AllocateSpace(barrier.shape(), barrier.dtype(), input.place(), true); workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype()); - barrier = MakeNvteTensor(barrier_data.data(), barrier.shape(), barrier.dtype()); // Actual call to fwd kernel nvte_rmsnorm_fwd(input_cu.data(), gamma_cu.data(), eps, z_cu.data(), rsigma_cu.data(), - input.stream(), num_sm - sm_margin, workspace.data(), barrier.data()); + workspace.data(), num_sm - sm_margin, zero_centered_gamma, input.stream()); return {ln_out, rsigma}; } @@ -550,7 +536,7 @@ std::vector te_rmsnorm_bwd(const paddle::Tensor &dz, const paddl auto dx = paddle::empty_like(x, x.dtype(), x.place()); auto dgamma = paddle::empty_like(gamma, gamma.dtype(), gamma.place()); - TensorWrapper workspace, barrier, dgamma_part; + TensorWrapper workspace; auto dz_cu = MakeNvteTensor(dz); auto x_cu = MakeNvteTensor(x); @@ -563,21 +549,17 @@ std::vector te_rmsnorm_bwd(const paddle::Tensor &dz, const paddl // This call populates tensors with the required config. 
nvte_rmsnorm_bwd(dz_cu.data(), x_cu.data(), rsigma_cu.data(), gamma_cu.data(), dx_cu.data(), - dgamma_cu.data(), dgamma_part.data(), dz.stream(), num_sm - sm_margin, - workspace.data(), barrier.data()); + dgamma_cu.data(), workspace.data(), num_sm - sm_margin, zero_centered_gamma, + dz.stream()); // Alloc space for Tensors. auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), x.place()); - auto barrier_data = AllocateSpace(barrier.shape(), barrier.dtype(), x.place(), true); - auto dgamma_part_data = AllocateSpace(dgamma_part.shape(), dgamma_part.dtype(), x.place()); workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype()); - barrier = MakeNvteTensor(barrier_data.data(), barrier.shape(), barrier.dtype()); - dgamma_part = MakeNvteTensor(dgamma_part_data.data(), dgamma_part.shape(), dgamma_part.dtype()); // Actual call to bwd kernel. nvte_rmsnorm_bwd(dz_cu.data(), x_cu.data(), rsigma_cu.data(), gamma_cu.data(), dx_cu.data(), - dgamma_cu.data(), dgamma_part.data(), dz.stream(), num_sm - sm_margin, - workspace.data(), barrier.data()); + dgamma_cu.data(), workspace.data(), num_sm - sm_margin, zero_centered_gamma, + dz.stream()); return {dx, dgamma}; } diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 6b153fd3c1..8c529c58d0 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -2528,12 +2528,13 @@ def backward(ctx, dout): recv_src = ctx.cp_global_ranks[(rank + 1) % cp_size * cp_size_a2a + rank_a2a] batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) or (cp_size == 2) - (q, kv, out, softmax_lse, cu_seqlens_q_padded, cu_seqlens_kv_padded) = ctx.saved_tensors[:6] - (fp8_fwd_scales, fp8_fwd_scale_invs) = ctx.saved_tensors[6:8] - cu_seqlens_q_per_step = ctx.saved_tensors[8 : 8 + cp_size] - cu_seqlens_kv_per_step = ctx.saved_tensors[8 + cp_size : 8 + cp_size * 2] - rng_states = ctx.saved_tensors[8 + cp_size * 2 : 8 + 
cp_size * 3] - attn_biases = ctx.saved_tensors[8 + cp_size * 3 : 8 + cp_size * 4] + (*saved_tensors,) = ctx.saved_tensors + (q, kv, out, softmax_lse, cu_seqlens_q_padded, cu_seqlens_kv_padded) = saved_tensors[:6] + (fp8_fwd_scales, fp8_fwd_scale_invs) = saved_tensors[6:8] + cu_seqlens_q_per_step = saved_tensors[8 : 8 + cp_size] + cu_seqlens_kv_per_step = saved_tensors[8 + cp_size : 8 + cp_size * 2] + rng_states = saved_tensors[8 + cp_size * 2 : 8 + cp_size * 3] + attn_biases = saved_tensors[8 + cp_size * 3 : 8 + cp_size * 4] causal = "causal" in ctx.attn_mask_type padding = "padding" in ctx.attn_mask_type @@ -3577,11 +3578,12 @@ def backward(ctx, dout): cp_size = get_distributed_world_size(ctx.cp_group) rank = get_distributed_rank(ctx.cp_group) - (q, k, v, cu_seqlens_q, cu_seqlens_q_padded) = ctx.saved_tensors[:5] - cu_seqlens_kv_per_step = ctx.saved_tensors[5:7] - out_per_step = ctx.saved_tensors[7:9] - softmax_lse_per_step = ctx.saved_tensors[9:11] - rng_states = ctx.saved_tensors[11:13] + (*saved_tensors,) = ctx.saved_tensors + (q, k, v, cu_seqlens_q, cu_seqlens_q_padded) = saved_tensors[:5] + cu_seqlens_kv_per_step = saved_tensors[5:7] + out_per_step = saved_tensors[7:9] + softmax_lse_per_step = saved_tensors[9:11] + rng_states = saved_tensors[11:13] kv_seq_range_per_step = ctx.kv_seq_range_per_step window_size_per_step = ctx.window_size_per_step @@ -4056,12 +4058,11 @@ def backward(ctx, dout): # pylint: disable=missing-function-docstring cp_size = get_distributed_world_size(ctx.cp_group) - q, k, v, out = ctx.saved_tensors[:4] - cu_seqlens_q, cu_seqlens_kv, cu_seqlens_q_padded, cu_seqlens_kv_padded = ctx.saved_tensors[ - 4:8 - ] - fp8_fwd_scales, fp8_fwd_scale_invs = ctx.saved_tensors[8:10] - aux_ctx_tensors = ctx.saved_tensors[10:] + (*saved_tensors,) = ctx.saved_tensors + q, k, v, out = saved_tensors[:4] + cu_seqlens_q, cu_seqlens_kv, cu_seqlens_q_padded, cu_seqlens_kv_padded = saved_tensors[4:8] + fp8_fwd_scales, fp8_fwd_scale_invs = saved_tensors[8:10] + 
aux_ctx_tensors = saved_tensors[10:] qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format causal = "causal" in ctx.attn_mask_type @@ -4308,14 +4309,6 @@ def attn_forward_func_with_cp( assert ( qkv_format != "sbhd" or use_fused_attention ), "FlashAttention does not support sbhd format!" - assert ( - qkv_format != "thd" - or not use_fused_attention - or attn_mask_type in ["padding", "padding_causal"] - ), ( - f"Context parallelism is not supported for {attn_mask_type} mask type and " - f"{qkv_format} format with {'FusedAttention' if use_fused_attention else 'FlashAttention'}!" - ) assert attn_bias is None or (use_fused_attention and "padding" not in attn_mask_type), ( """Attention bias is only supported with FusedAttention and "causal" """ """or "no_mask" mask types!""" @@ -7877,6 +7870,9 @@ def forward( ), f"Values have head_dim = {value_layer.shape[-1]}, " "but expected head_dim = {self.hidden_size_per_attention_head_v}!" + if qkv_format is None: + qkv_format = self.qkv_format + if attn_mask_type is None: attn_mask_type = self.attn_mask_type else: @@ -7903,9 +7899,6 @@ def forward( graph_safe_rng_available() ), "Upgrade PyTorch version to get RNG manipulation support for cuda graph capture." - if qkv_format is None: - qkv_format = self.qkv_format - if inference_params is not None: assert self.layer_number is not None, "Layer number must be set!" @@ -7951,7 +7944,10 @@ def forward( assert ( key_layer.shape[-2] == self.num_gqa_groups_per_partition and value_layer.shape[-2] == self.num_gqa_groups_per_partition - ), f"Keys and values must have num_gqa_group = {self.num_gqa_groups} heads!" + ), ( + "Keys and values must have num_gqa_group =" + f" {self.num_gqa_groups_per_partition} heads!" 
+ ) assert qkv_format in [ "sbhd", "bshd", diff --git a/transformer_engine/pytorch/cpp_extensions/transpose.py b/transformer_engine/pytorch/cpp_extensions/transpose.py index ddc3b67e9e..188c03b27c 100644 --- a/transformer_engine/pytorch/cpp_extensions/transpose.py +++ b/transformer_engine/pytorch/cpp_extensions/transpose.py @@ -16,6 +16,7 @@ "fp8_cast_transpose_fused", "fp8_cast_transpose_bgrad_fused", "fp8_cast_transpose_bgrad_dgelu_fused", + "fp8_dswiglu_cast_transpose_fused", "fp8_multi_cast_transpose_fused", "fp8_transpose_bgrad_fused", ] @@ -168,6 +169,44 @@ def fp8_cast_transpose_bgrad_dgelu_fused( ) +def fp8_dswiglu_cast_transpose_fused( + grad_output: torch.Tensor, + inp: torch.Tensor, + *, + grad_input: torch.Tensor, + grad_input_transpose: torch.Tensor, + otype: tex.DType, + fp8_meta: Optional[tex.FP8TensorMeta] = None, + fp8_meta_index: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None] = None, + scale: Optional[torch.Tensor] = None, + amax: Optional[torch.Tensor] = None, + scale_inv: Optional[torch.Tensor] = None, +) -> None: + """Fused SwiGLU backward + FP8 cast + FP8 transpose""" + + # Get FP8 scaling factors + fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales( + scale=scale, + amax=amax, + scale_inv=scale_inv, + fp8_meta=fp8_meta, + fp8_meta_index=fp8_meta_index, + ) + + # Launch kernel + return tex.fused_dswiglu_cast_transpose( + grad_output, + inp, + grad_input, + grad_input_transpose, + fp8_scales["scale"], + fp8_scales["amax"], + fp8_scales["scale_inv"], + otype, + **fp8_scales_offsets, + ) + + def fp8_multi_cast_transpose_fused( input_list: List[torch.Tensor], fp8_meta_tensor: tex.FP8TensorMeta, diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index 175a7b0e90..82f58b1eda 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -28,11 +28,10 @@ #include #include #include -#include +#include #include #include #include -#include #include #include 
#include diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index b039bf2d1b..3b49ece4a3 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -210,6 +210,12 @@ std::vector fused_cast_transpose_bgrad_dgelu(at::Tensor grad_output, int scale_offset = 0, int amax_offset = 0, int scale_inv_offset = 0); +void fused_dswiglu_cast_transpose(at::Tensor grad_output, at::Tensor input, at::Tensor grad_input, + at::Tensor grad_input_transpose, at::Tensor scale, + at::Tensor amax, at::Tensor scale_inv, + transformer_engine::DType otype, int scale_offset = 0, + int amax_offset = 0, int scale_inv_offset = 0); + void fused_multi_cast_transpose(std::vector input_list, std::vector scale_list, std::vector cast_output_list, diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cu b/transformer_engine/pytorch/csrc/extensions/attention.cu index 8088a2b8f1..d03a10ced3 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cu +++ b/transformer_engine/pytorch/csrc/extensions/attention.cu @@ -4,8 +4,11 @@ * See LICENSE for license information. 
************************************************************************/ +#include "common/fused_attn/thd_utils.h" #include "extensions.h" +using namespace transformer_engine::fused_attn; + constexpr int block_size = 512; constexpr int ctas_per_sm = 4; @@ -1359,64 +1362,10 @@ at::Tensor fa_prepare_bwd(at::Tensor q, at::Tensor k, at::Tensor v) { return qkv; } -/*************************************************************************************************** - * Support THD format for Context Parallel: Binary search - **************************************************************************************************/ - -__forceinline__ __device__ int binary_search(int target, int *array, int len) { - int left = 1, right = len - 1; - while (left < right) { - int mid = (left + right) / 2; - if (array[mid] <= target) { - left = mid + 1; - } else { - right = mid; - } - } - return left - 1; -} - /*************************************************************************************************** * Support THD format for Context Parallel: Read the half of a THD tensor **************************************************************************************************/ -__global__ void thd_read_half_tensor_kernel(void *half, void *tensor, int *cu_seqlens, int batch, - int hidden_size_in_bytes, int half_idx, - int dim_size_of_token) { - extern __shared__ int cu_seqlens_s[]; - for (int i = threadIdx.x; i <= batch; i += blockDim.x) { - cu_seqlens_s[i] = cu_seqlens[i] / 2; - } - __syncthreads(); - - int warpid = (blockIdx.x * blockDim.x + threadIdx.x) / 32; - int laneid = threadIdx.x % 32; - int num_warps = (blockDim.x * gridDim.x) / 32; - int num_total_tokens = cu_seqlens_s[batch]; - int num_float4s_per_token = hidden_size_in_bytes / sizeof(float4); - - size_t offset = static_cast(dim_size_of_token) * hidden_size_in_bytes; - half = reinterpret_cast(reinterpret_cast(half) + offset / 2 * blockIdx.y); - tensor = reinterpret_cast(reinterpret_cast(tensor) + offset * blockIdx.y); - - 
for (int token_id = warpid; token_id < num_total_tokens; token_id += num_warps) { - int seqid = binary_search(token_id, cu_seqlens_s, batch + 1); - - size_t offset_in_bytes = static_cast(token_id) * hidden_size_in_bytes; - float4 *cur_half_token = - reinterpret_cast(reinterpret_cast(half) + offset_in_bytes); - - offset_in_bytes = - (static_cast(token_id) + cu_seqlens_s[seqid + half_idx]) * hidden_size_in_bytes; - float4 *cur_token = - reinterpret_cast(reinterpret_cast(tensor) + offset_in_bytes); - - for (int idx = laneid; idx < num_float4s_per_token; idx += 32) { - cur_half_token[idx] = cur_token[idx]; - } - } -} - at::Tensor thd_read_half_tensor(const at::Tensor &tensor, const at::Tensor &cu_seqlens, int half_idx) { NVTE_CHECK(tensor.dim() == 3 || tensor.dim() == 4); @@ -1464,51 +1413,6 @@ at::Tensor thd_read_half_tensor(const at::Tensor &tensor, const at::Tensor &cu_s * Support THD format for Context Parallel: softmax_lse related operations **************************************************************************************************/ -template -__global__ void thd_lse_kernel(lse_dtype *lse, float *half_lse, int *cu_seqlens, int batch, - int num_heads, int total_tokens) { - extern __shared__ int cu_seqlens_s[]; - for (int i = threadIdx.x; i <= batch; i += blockDim.x) { - cu_seqlens_s[i] = cu_seqlens[i] / 2; - } - __syncthreads(); - - int tid = blockIdx.x * blockDim.x + threadIdx.x; - int num_threads = blockDim.x * gridDim.x; - int num_total_tokens = cu_seqlens_s[batch]; - - for (int token_id = tid; token_id < num_total_tokens; token_id += num_threads) { - int seq_id = binary_search(token_id, cu_seqlens_s, batch + 1); - for (int head_id = blockIdx.y; head_id < num_heads; head_id += gridDim.y) { - size_t idx, half_idx; - if constexpr (lse_packed) { - idx = head_id * total_tokens + token_id + cu_seqlens_s[seq_id + 1]; - half_idx = head_id * total_tokens / 2 + token_id; - } else { - size_t row = static_cast(seq_id) * num_heads + head_id; - int col = token_id - 
cu_seqlens_s[seq_id]; - int seq_len = cu_seqlens_s[seq_id + 1] - cu_seqlens_s[seq_id]; - - idx = row * total_tokens + col + seq_len; - half_idx = row * total_tokens / 2 + col; - } - - Functor::run(lse, half_lse, idx, half_idx); - } - } -} - -struct LseCorrectionFunctor { - __forceinline__ __device__ static void run(double *lse, float *half_lse, size_t idx, - size_t half_idx) { - double val = lse[idx]; - float val_per_step = half_lse[half_idx]; - double max_scale = max(val, val_per_step); - double min_scale = min(val, val_per_step); - lse[idx] = max_scale + log(1.0 + exp(min_scale - max_scale)); - } -}; - void thd_second_half_lse_correction(at::Tensor lse, const at::Tensor &lse_per_step, const at::Tensor &cu_seqlens, bool lse_packed) { NVTE_CHECK(lse.scalar_type() == at::ScalarType::Double); @@ -1559,13 +1463,6 @@ void thd_second_half_lse_correction(at::Tensor lse, const at::Tensor &lse_per_st } } -struct ReadLseFunctor { - __forceinline__ __device__ static void run(float *lse, float *half_lse, size_t idx, - size_t half_idx) { - half_lse[half_idx] = lse[idx]; - } -}; - at::Tensor thd_read_second_half_lse(const at::Tensor &lse, const at::Tensor &cu_seqlens, bool lse_packed) { NVTE_CHECK(lse.scalar_type() == at::ScalarType::Float); @@ -1620,59 +1517,6 @@ at::Tensor thd_read_second_half_lse(const at::Tensor &lse, const at::Tensor &cu_ * Support THD format for Context Parallel: Out correction in forward **************************************************************************************************/ -template -__global__ void thd_out_correction_kernel(dtype *out, dtype *out_per_step, float *lse, - float *lse_per_step, int *cu_seqlens, int batch, - int num_heads, int dim_per_head, int lse_seqlen) { - extern __shared__ int cu_seqlens_s[]; - for (int i = threadIdx.x; i <= batch; i += blockDim.x) { - cu_seqlens_s[i] = cu_seqlens[i] / (only_second_half + 1); - } - __syncthreads(); - - int tile_id = (blockIdx.x * blockDim.x + threadIdx.x) / tile_size; - int lane_id = 
threadIdx.x % tile_size; - int num_tiles = (blockDim.x * gridDim.x) / tile_size; - int num_total_tokens = cu_seqlens_s[batch]; - int num_loops_per_head = dim_per_head * sizeof(dtype) / sizeof(float4); - - for (int token_id = tile_id; token_id < num_total_tokens; token_id += num_tiles) { - int seq_id = binary_search(token_id, cu_seqlens_s, batch + 1); - for (int head_id = blockIdx.y; head_id < num_heads; head_id += gridDim.y) { - size_t idx, idx_per_step; - - if constexpr (lse_packed) { - idx = head_id * lse_seqlen + token_id + cu_seqlens_s[seq_id + 1] * only_second_half; - idx_per_step = head_id * lse_seqlen / (only_second_half + 1) + token_id; - } else { - size_t row = static_cast(seq_id) * num_heads + head_id; - int col = token_id - cu_seqlens_s[seq_id]; - int seq_len = cu_seqlens_s[seq_id + 1] - cu_seqlens_s[seq_id]; - idx = row * lse_seqlen + col + seq_len * only_second_half; - idx_per_step = row * lse_seqlen / (only_second_half + 1) + col; - } - float lse_corrected_exp = exp(lse_per_step[idx_per_step] - lse[idx]); - - idx = token_id + cu_seqlens_s[seq_id + 1] * only_second_half; - idx = (idx * num_heads + head_id) * dim_per_head; - idx_per_step = (static_cast(token_id) * num_heads + head_id) * dim_per_head; - dtype *cur_out = out + idx; - dtype *cur_out_per_step = out_per_step + idx_per_step; - - for (int j = lane_id; j < num_loops_per_head; j += tile_size) { - float4 data_per_step = reinterpret_cast(cur_out_per_step)[j]; - float4 data = reinterpret_cast(cur_out)[j]; - dtype *p_per_step = reinterpret_cast(&data_per_step); - dtype *p = reinterpret_cast(&data); - for (int k = 0; k < sizeof(float4) / sizeof(dtype); k++) { - p[k] += (p_per_step[k] == 0 ? 
0 : p_per_step[k] * lse_corrected_exp); - } - reinterpret_cast(cur_out)[j] = data; - } - } - } -} - template static void thd_out_correction_helper(at::Tensor out, const at::Tensor &out_per_step, const at::Tensor &lse, const at::Tensor &lse_per_step, @@ -1773,87 +1617,6 @@ void thd_out_correction(at::Tensor out, const at::Tensor &out_per_step, const at * Support THD format for Context Parallel: Gradients correction in backward **************************************************************************************************/ -template -__global__ void thd_grad_correction_kernel(dtype *grad, dtype *grad_per_step, int *cu_seqlens, - int batch, int hidden_size, int dim_size_of_token) { - extern __shared__ int cu_seqlens_s[]; - for (int i = threadIdx.x; i <= batch; i += blockDim.x) { - if constexpr (functor_idx < 2) { - cu_seqlens_s[i] = cu_seqlens[i] / 2; - } else { - cu_seqlens_s[i] = cu_seqlens[i]; - } - } - __syncthreads(); - - int group_id = (blockIdx.x * blockDim.x + threadIdx.x) / group_size; - int lane_id = threadIdx.x % group_size; - int num_groups = (blockDim.x * gridDim.x) / group_size; - int num_total_tokens = cu_seqlens_s[batch]; - int num_inner_loops = hidden_size * sizeof(dtype) / sizeof(float4); - - size_t offset = static_cast(dim_size_of_token) * hidden_size; - if constexpr (functor_idx < 2) { - grad_per_step = grad_per_step + offset / 2 * blockIdx.y; - } else { - grad_per_step = grad_per_step + offset * blockIdx.y; - } - grad = grad + offset * blockIdx.y; - - for (int token_id = group_id; token_id < num_total_tokens; token_id += num_groups) { - int seq_id = binary_search(token_id, cu_seqlens_s, batch + 1); - - int token_offset; - bool is_first_half; - if constexpr (functor_idx < 2) { - token_offset = cu_seqlens_s[seq_id + functor_idx]; - is_first_half = (functor_idx == 0); - } else { - token_offset = 0; - int len = cu_seqlens_s[seq_id + 1] - cu_seqlens_s[seq_id]; - is_first_half = (token_id - cu_seqlens_s[seq_id]) < (len / 2); - } - - dtype *token = 
&grad[(token_id + token_offset) * static_cast(hidden_size)]; - dtype *token_per_step = &grad_per_step[token_id * static_cast(hidden_size)]; - for (int idx = lane_id; idx < num_inner_loops; idx += group_size) { - if (is_first_half) { - Functor_0::run(token, token_per_step, idx); - } else { - Functor_1::run(token, token_per_step, idx); - } - } - } -} - -struct EmptyFunctor { - __forceinline__ __device__ static void run(void *token, void *token_per_step, int idx) {} -}; - -struct CopyFunctor { - __forceinline__ __device__ static void run(void *token, void *token_per_step, int idx) { - reinterpret_cast(token)[idx] = reinterpret_cast(token_per_step)[idx]; - } -}; - -template -struct AddFunctor { - __forceinline__ __device__ static void run(dtype *token, dtype *token_per_step, int idx) { - float4 d_ = reinterpret_cast(token)[idx]; - dtype *p_ = reinterpret_cast(&d_); - - float4 d = reinterpret_cast(token_per_step)[idx]; - dtype *p = reinterpret_cast(&d); - -#pragma unroll - for (int i = 0; i < sizeof(float4) / sizeof(dtype); i++) { - p_[i] += p[i]; - } - - reinterpret_cast(token)[idx] = d_; - } -}; - template static void thd_grad_correction_helper(at::Tensor grad, const at::Tensor &grad_per_step, const at::Tensor &cu_seqlens) { @@ -1945,31 +1708,6 @@ void thd_grad_correction(at::Tensor grad, const at::Tensor &grad_per_step, * Support THD format for Context Parallel: Generate partitioned indices for input tokens **************************************************************************************************/ -__global__ void thd_partition_indices_kernel(int *output, int *cu_seqlens, int batch, - int total_tokens, int world_size, int rank) { - extern __shared__ int cu_seqlens_s[]; - for (int i = threadIdx.x; i <= batch; i += blockDim.x) { - int seqlen = cu_seqlens[i]; - // Currently we assume that each sequence length is divisible by (world_size*2) since we have - // to distribute each sequence evenly to different GPUs. 
- assert(seqlen % (world_size * 2) == 0); - cu_seqlens_s[i] = seqlen / world_size; - } - __syncthreads(); - - int tid = blockIdx.x * blockDim.x + threadIdx.x; - int num_threads = blockDim.x * gridDim.x; - - for (int token_id = tid; token_id < total_tokens / world_size; token_id += num_threads) { - int seq_id = binary_search(token_id, cu_seqlens_s, batch + 1); - int seq_len = cu_seqlens_s[seq_id + 1] - cu_seqlens_s[seq_id]; - int index = token_id - cu_seqlens_s[seq_id]; - int offset = index < seq_len / 2 ? rank : (world_size - 1) * 2 - rank; - index += cu_seqlens_s[seq_id] * world_size + seq_len / 2 * offset; - output[token_id] = index; - } -} - at::Tensor thd_get_partitioned_indices(const at::Tensor &cu_seqlens, int total_tokens, int world_size, int rank) { NVTE_CHECK(cu_seqlens.scalar_type() == at::ScalarType::Int); diff --git a/transformer_engine/pytorch/csrc/extensions/normalization.cpp b/transformer_engine/pytorch/csrc/extensions/normalization.cpp index 04274ae2ef..2574b84352 100644 --- a/transformer_engine/pytorch/csrc/extensions/normalization.cpp +++ b/transformer_engine/pytorch/csrc/extensions/normalization.cpp @@ -19,7 +19,7 @@ std::vector layernorm_bwd(const at::Tensor &dz, const at::Tensor &x, auto dx = at::empty_like(x_); auto dgamma = at::empty_like(gamma_); auto dbeta = at::empty_like(gamma_); - transformer_engine::TensorWrapper workspace, barrier, dgamma_part, dbeta_part; + transformer_engine::TensorWrapper workspace; auto dz_cu = makeTransformerEngineTensor(dz_); auto x_cu = makeTransformerEngineTensor(x_); @@ -31,32 +31,21 @@ std::vector layernorm_bwd(const at::Tensor &dz, const at::Tensor &x, auto dbeta_cu = makeTransformerEngineTensor(dbeta); // This call populates tensors with the required config. - const auto bwd_fun = zero_centered_gamma ? 
nvte_layernorm1p_bwd : nvte_layernorm_bwd; - bwd_fun(dz_cu.data(), x_cu.data(), mu_cu.data(), rsigma_cu.data(), gamma_cu.data(), dx_cu.data(), - dgamma_cu.data(), dbeta_cu.data(), dgamma_part.data(), dbeta_part.data(), - at::cuda::getCurrentCUDAStream(), - at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, workspace.data(), - barrier.data()); + nvte_layernorm_bwd(dz_cu.data(), x_cu.data(), mu_cu.data(), rsigma_cu.data(), gamma_cu.data(), + dx_cu.data(), dgamma_cu.data(), dbeta_cu.data(), workspace.data(), + at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, + zero_centered_gamma, at::cuda::getCurrentCUDAStream()); // Alloc space for Tensors. auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); - auto barrier_data = allocateSpace(barrier.shape(), barrier.dtype(), true); - auto dgamma_part_data = allocateSpace(dgamma_part.shape(), dgamma_part.dtype()); - auto dbeta_part_data = allocateSpace(dbeta_part.shape(), dbeta_part.dtype()); workspace = makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype()); - barrier = makeTransformerEngineTensor(barrier_data.data_ptr(), barrier.shape(), barrier.dtype()); - dgamma_part = makeTransformerEngineTensor(dgamma_part_data.data_ptr(), dgamma_part.shape(), - dgamma_part.dtype()); - dbeta_part = makeTransformerEngineTensor(dbeta_part_data.data_ptr(), dbeta_part.shape(), - dbeta_part.dtype()); // Actual call to bwd kernel. 
- bwd_fun(dz_cu.data(), x_cu.data(), mu_cu.data(), rsigma_cu.data(), gamma_cu.data(), dx_cu.data(), - dgamma_cu.data(), dbeta_cu.data(), dgamma_part.data(), dbeta_part.data(), - at::cuda::getCurrentCUDAStream(), - at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, workspace.data(), - barrier.data()); + nvte_layernorm_bwd(dz_cu.data(), x_cu.data(), mu_cu.data(), rsigma_cu.data(), gamma_cu.data(), + dx_cu.data(), dgamma_cu.data(), dbeta_cu.data(), workspace.data(), + at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, + zero_centered_gamma, at::cuda::getCurrentCUDAStream()); return {dx, dgamma, dbeta}; } @@ -88,9 +77,6 @@ std::vector layernorm_fwd_fp8_noalloc( const auto &weight_ = weight.contiguous(); const auto &bias_ = bias.contiguous(); - // Choose kernel implementation - const auto func = zero_centered_gamma ? nvte_layernorm1p_fwd : nvte_layernorm_fwd; - // Tensor dimensions size_t N = static_cast(input.size(0)); size_t H = static_cast(input.size(1)); @@ -113,24 +99,22 @@ std::vector layernorm_fwd_fp8_noalloc( auto rsigma_cu = makeTransformerEngineTensor(rsigma); // Query workspace sizes - transformer_engine::TensorWrapper workspace, barrier; - func(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(), mu_cu.data(), - rsigma_cu.data(), at::cuda::getCurrentCUDAStream(), - at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, workspace.data(), - barrier.data()); + transformer_engine::TensorWrapper workspace; + nvte_layernorm_fwd(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(), + mu_cu.data(), rsigma_cu.data(), workspace.data(), + at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, + zero_centered_gamma, at::cuda::getCurrentCUDAStream()); // Allocate workspaces auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); - auto barrier_data = allocateSpace(barrier.shape(), barrier.dtype(), true); workspace = 
makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype()); - barrier = makeTransformerEngineTensor(barrier_data.data_ptr(), barrier.shape(), barrier.dtype()); // Launch kernel - func(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(), mu_cu.data(), - rsigma_cu.data(), at::cuda::getCurrentCUDAStream(), - at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, workspace.data(), - barrier.data()); + nvte_layernorm_fwd(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(), + mu_cu.data(), rsigma_cu.data(), workspace.data(), + at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, + zero_centered_gamma, at::cuda::getCurrentCUDAStream()); return {ln_out, mu, rsigma}; } @@ -194,7 +178,7 @@ std::vector rmsnorm_bwd(const at::Tensor &dz, const at::Tensor &x, auto dx = at::empty_like(x_); auto dgamma = at::empty_like(gamma_); - transformer_engine::TensorWrapper workspace, barrier, dgamma_part; + transformer_engine::TensorWrapper workspace; auto dz_cu = makeTransformerEngineTensor(dz_); auto x_cu = makeTransformerEngineTensor(x_); @@ -204,27 +188,21 @@ std::vector rmsnorm_bwd(const at::Tensor &dz, const at::Tensor &x, auto dgamma_cu = makeTransformerEngineTensor(dgamma); // This call populates tensors with the required config. - const auto bwd_fun = zero_centered_gamma ? nvte_rmsnorm1p_bwd : nvte_rmsnorm_bwd; - bwd_fun(dz_cu.data(), x_cu.data(), rsigma_cu.data(), gamma_cu.data(), dx_cu.data(), - dgamma_cu.data(), dgamma_part.data(), at::cuda::getCurrentCUDAStream(), - at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, workspace.data(), - barrier.data()); + nvte_rmsnorm_bwd(dz_cu.data(), x_cu.data(), rsigma_cu.data(), gamma_cu.data(), dx_cu.data(), + dgamma_cu.data(), workspace.data(), + at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, + zero_centered_gamma, at::cuda::getCurrentCUDAStream()); // Alloc space for Tensors. 
auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); - auto barrier_data = allocateSpace(barrier.shape(), barrier.dtype(), true); - auto dgamma_part_data = allocateSpace(dgamma_part.shape(), dgamma_part.dtype()); workspace = makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype()); - barrier = makeTransformerEngineTensor(barrier_data.data_ptr(), barrier.shape(), barrier.dtype()); - dgamma_part = makeTransformerEngineTensor(dgamma_part_data.data_ptr(), dgamma_part.shape(), - dgamma_part.dtype()); // Actual call to bwd kernel. - bwd_fun(dz_cu.data(), x_cu.data(), rsigma_cu.data(), gamma_cu.data(), dx_cu.data(), - dgamma_cu.data(), dgamma_part.data(), at::cuda::getCurrentCUDAStream(), - at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, workspace.data(), - barrier.data()); + nvte_rmsnorm_bwd(dz_cu.data(), x_cu.data(), rsigma_cu.data(), gamma_cu.data(), dx_cu.data(), + dgamma_cu.data(), workspace.data(), + at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, + zero_centered_gamma, at::cuda::getCurrentCUDAStream()); return {dx, dgamma}; } @@ -255,9 +233,6 @@ std::vector rmsnorm_fwd_fp8_noalloc(const at::Tensor &input, const a const int scale_inv_offset) { using namespace transformer_engine; - // Choose kernel implementation - const auto func = zero_centered_gamma ? 
nvte_rmsnorm1p_fwd : nvte_rmsnorm_fwd; - // Tensor dimensions size_t N = static_cast(input.size(0)); size_t H = static_cast(input.size(1)); @@ -277,24 +252,22 @@ std::vector rmsnorm_fwd_fp8_noalloc(const at::Tensor &input, const a auto rsigma_cu = makeTransformerEngineTensor(rsigma); // Query workspace sizes - transformer_engine::TensorWrapper workspace, barrier; - func(input_cu.data(), gamma_cu.data(), eps, z_cu.data(), rsigma_cu.data(), - at::cuda::getCurrentCUDAStream(), - at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, workspace.data(), - barrier.data()); + transformer_engine::TensorWrapper workspace; + nvte_rmsnorm_fwd(input_cu.data(), gamma_cu.data(), eps, z_cu.data(), rsigma_cu.data(), + workspace.data(), + at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, + zero_centered_gamma, at::cuda::getCurrentCUDAStream()); // Allocate workspaces auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); - auto barrier_data = allocateSpace(barrier.shape(), barrier.dtype(), true); workspace = makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype()); - barrier = makeTransformerEngineTensor(barrier_data.data_ptr(), barrier.shape(), barrier.dtype()); // Launch kernel - func(input_cu.data(), gamma_cu.data(), eps, z_cu.data(), rsigma_cu.data(), - at::cuda::getCurrentCUDAStream(), - at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, workspace.data(), - barrier.data()); + nvte_rmsnorm_fwd(input_cu.data(), gamma_cu.data(), eps, z_cu.data(), rsigma_cu.data(), + workspace.data(), + at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, + zero_centered_gamma, at::cuda::getCurrentCUDAStream()); return {ln_out, rsigma}; } diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 39679ed669..8856553c54 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ 
b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -91,6 +91,12 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::arg("grad_output"), py::arg("gelu_input"), py::arg("scale"), py::arg("amax"), py::arg("scale_inv"), py::arg("otype"), py::arg("scale_offset") = 0, py::arg("amax_offset") = 0, py::arg("scale_inv_offset") = 0); + m.def("fused_dswiglu_cast_transpose", &fused_dswiglu_cast_transpose, + "Fused SwiGLU backward + FP8 cast + FP8 transpose", + py::call_guard(), py::arg("grad_output"), py::arg("input"), + py::arg("grad_input"), py::arg("grad_input_transpose"), py::arg("scale"), py::arg("amax"), + py::arg("scale_inv"), py::arg("otype"), py::arg("scale_offset") = 0, + py::arg("amax_offset") = 0, py::arg("scale_inv_offset") = 0); m.def("fused_multi_cast_transpose", &fused_multi_cast_transpose, "Fused Multi-tensor Cast + Transpose", py::call_guard()); m.def("fused_multi_cast_transpose_alloc", &fused_multi_cast_transpose_alloc, diff --git a/transformer_engine/pytorch/csrc/extensions/transpose.cpp b/transformer_engine/pytorch/csrc/extensions/transpose.cpp index 56f6b56769..f373cdf83a 100644 --- a/transformer_engine/pytorch/csrc/extensions/transpose.cpp +++ b/transformer_engine/pytorch/csrc/extensions/transpose.cpp @@ -196,6 +196,75 @@ std::vector fused_cast_transpose_bgrad_dgelu(at::Tensor grad_output, return {grad_bias, dgelu, dgelu_transpose}; } +void fused_dswiglu_cast_transpose(at::Tensor grad_output, at::Tensor input, at::Tensor grad_input, + at::Tensor grad_input_transpose, at::Tensor scale, + at::Tensor amax, at::Tensor scale_inv, + transformer_engine::DType otype, int scale_offset, + int amax_offset, int scale_inv_offset) { + using namespace transformer_engine; + + // Tensor dimensions + auto outer_dim = [](const at::Tensor& tensor) -> size_t { + return tensor.numel() / tensor.size(-1); + }; + const auto M = outer_dim(grad_output); + const auto N = static_cast(grad_output.size(-1)); + + // Check tensor dims + NVTE_CHECK(grad_output.dim() == 2, "Expected 
grad output tensor to have 2 dims, but found ", + grad_output.dim()); + NVTE_CHECK(input.dim() == 2, "Expected input tensor to have 2 dims, but found ", input.dim()); + NVTE_CHECK(outer_dim(input) == M, "Expected input tensor to have outer dimension of ", M, + ", but found ", outer_dim(input)); + NVTE_CHECK(input.size(-1) == 2 * N, "Expected input tensor to have inner dimension of ", 2 * N, + ", but found ", input.size(-1)); + NVTE_CHECK(grad_input.dim() == 2, "Expected grad input tensor to have 2 dims, but found ", + grad_input.dim()); + NVTE_CHECK(outer_dim(grad_input) == M, "Expected grad input tensor to have outer dimension of ", + M, ", but found ", outer_dim(grad_input)); + NVTE_CHECK(grad_input.size(-1) == 2 * N, "Expected grad input tensor to have inner dimension of ", + 2 * N, ", but found ", grad_input.size(-1)); + NVTE_CHECK(grad_input_transpose.dim() == 2, + "Expected grad input transpose tensor to have 2 dims, but found ", + grad_input_transpose.dim()); + NVTE_CHECK(grad_input_transpose.size(0) == 2 * N, + "Expected grad input tensor to have outer dimension of ", 2 * N, ", but found ", + grad_input_transpose.size(0)); + NVTE_CHECK(grad_input_transpose.size(1) == M, + "Expected grad input tensor to have outer dimension of ", M, ", but found ", + grad_input_transpose.size(1)); + + // Check tensor format + NVTE_CHECK(grad_output.is_contiguous(), "Expected grad output tensor to be contiguous"); + NVTE_CHECK(input.is_contiguous(), "Expected input tensor to be contiguous"); + NVTE_CHECK(grad_input.is_contiguous(), "Expected grad input tensor to be contiguous"); + NVTE_CHECK(grad_input_transpose.is_contiguous(), + "Expected grad input transpose tensor to be contiguous"); + NVTE_CHECK(grad_output.scalar_type() == input.scalar_type(), + "Expected grad output tensor and input tensor to have same dtype"); + NVTE_CHECK(grad_input.scalar_type() == at::ScalarType::Byte, + "Expected grad input tensor to be uint8 buffer"); + 
NVTE_CHECK(grad_input_transpose.scalar_type() == at::ScalarType::Byte, + "Expected grad input transpose tensor to be uint8 buffer"); + + // Get pointers for FP8 scale, amax, scale-inverse + void* scale_dptr = getDataPtr(scale, scale_offset); + void* amax_dptr = getDataPtr(amax, amax_offset); + void* scale_inv_dptr = getDataPtr(scale_inv, scale_inv_offset); + + // Construct Transformer Engine tensors + auto dy_cu = makeTransformerEngineTensor(grad_output); + auto x_cu = makeTransformerEngineTensor(input); + auto dx_cu = makeTransformerEngineTensor(grad_input.data_ptr(), {M, 2 * N}, otype, amax_dptr, + scale_dptr, scale_inv_dptr); + auto dx_t_cu = makeTransformerEngineTensor(grad_input_transpose.data_ptr(), {2 * N, M}, otype, + amax_dptr, scale_dptr, scale_inv_dptr); + + // Launch kernel + nvte_dswiglu_cast_transpose(dy_cu.data(), x_cu.data(), dx_cu.data(), dx_t_cu.data(), + at::cuda::getCurrentCUDAStream()); +} + void fused_multi_cast_transpose_base(std::vector input_list, std::vector scale_dptr_list, std::vector cast_output_list, diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index f95ba515cb..15f20c81e5 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -109,8 +109,6 @@ def reset(cls) -> None: cls.fp8_available = None cls.reason_for_no_fp8 = "" cls.autocast_arguments = {} - cls.autocast_to_fp8_params = {} - cls.fp8_param_to_autocast = {} cls.skip_fp8_weight_update_tensor = None @classmethod @@ -156,28 +154,25 @@ def get_buffer_info(cls) -> str: def get_key_in_buffer( cls, forward: bool, - fp8_weights: bool, fp8_recipe: DelayedScaling, fp8_group: dist_group_type, ) -> str: """Returns a key into the global FP8 buffers.""" autocast_key = cls.get_unique_autocast_key(fp8_recipe, fp8_group) fwd_bwd_key = cls.get_fwd_bwd_key(forward) - return f"{fwd_bwd_key}_{fp8_weights}_{autocast_key}" + return f"{fwd_bwd_key}_{autocast_key}" @classmethod - def split_key_in_buffer(cls, key: str) -> Tuple[bool, 
bool, str]: + def split_key_in_buffer(cls, key: str) -> Tuple[bool, str]: """Splits buffer key into relevant parts.""" - forward, fp8_weights, autocast_key = key.split("_", 2) + forward, autocast_key = key.split("_", 1) forward = forward == "forward" - fp8_weights = fp8_weights == "True" - return forward, fp8_weights, autocast_key + return forward, autocast_key @classmethod def add_fp8_tensors_to_global_buffer( cls, fp8_meta: Dict[str, Any], - fp8_weights: Optional[List[torch.Tensor]] = None, ) -> None: """ The amax reduction process happens completely outside the FP8 modules. @@ -202,33 +197,12 @@ def add_fp8_tensors_to_global_buffer( fp8_meta[index_in_buffer] = [] for forward in (True, False): - # This algorithm creates a two-way map with `autocast_to_fp8_params` and - # `fp8_param_to_autocast`. This is used for keeping track of FP8 weights - # in an autocasted region and cross reference them in `float8_tensor.py` - # to perform the forward amax reduction. fp8_meta_tensor_key = cls.get_meta_tensor_key(forward=forward) if fp8_meta_tensor_key not in fp8_meta: # Handles non-parameter FP8 modules, e.g. DPA. continue - if forward and fp8_weights is not None: - autocast_key = cls.get_unique_autocast_key( - fp8_meta["recipe"], fp8_meta["fp8_group"] - ) - fp8_weight_set = {id(w._data) for w in fp8_weights} - if autocast_key not in cls.autocast_to_fp8_params: - cls.autocast_to_fp8_params[autocast_key] = fp8_weight_set - else: - cls.autocast_to_fp8_params[autocast_key] = cls.autocast_to_fp8_params[ - autocast_key - ].union(fp8_weight_set) - # Identify correct autocast key for a given param. 
- for w in fp8_weight_set: - cls.fp8_param_to_autocast[w] = autocast_key - - key = cls.get_key_in_buffer( - forward, fp8_weights is not None, fp8_meta["recipe"], fp8_meta["fp8_group"] - ) + key = cls.get_key_in_buffer(forward, fp8_meta["recipe"], fp8_meta["fp8_group"]) if key not in cls.global_amax_buffer: cls.global_amax_buffer[key] = [fp8_meta[fp8_meta_tensor_key].amax_history[0]] @@ -327,20 +301,13 @@ def reduce_tensor_across_group_op_max(tensor: torch.Tensor, group: dist_group_ty def reduce_and_update_fp8_tensors( cls, forward: bool = True, - fp8_weights: bool = False, ) -> None: """Concatenate, reduce, and split amaxes in the global buffer.""" for buffer_key, amax_buffer in cls.global_amax_buffer.items(): # Check for forward or backward reduction. - fwd_update, fp8_weights_update, autocast_key = cls.split_key_in_buffer(buffer_key) + fwd_update, autocast_key = cls.split_key_in_buffer(buffer_key) if fwd_update != forward: continue - # Only skip a forward update when `fp8_weights` is explicitly set to `True` - # (inside optimizer) and the current key is not an `fp8_weight_update` key. - # For other cases, we need to reduce because of activation tensors. - # TODO(ksivaman) consider separate weight and activation fp8_tensors. - if fwd_update and fp8_weights and not fp8_weights_update: - continue if len(amax_buffer) == 0: continue @@ -434,7 +401,7 @@ def fp8_autocast_exit(cls, enabled: bool, _graph: bool) -> None: # FP8 weight modules are reduced at the end of the optimizer # step after the weight amax is populated. 
if enabled and cls.FP8_AUTOCAST_DEPTH == 0 and not _graph and torch.is_grad_enabled(): - cls.reduce_and_update_fp8_tensors(forward=True, fp8_weights=False) + cls.reduce_and_update_fp8_tensors(forward=True) @classmethod def copy_forward_fp8_meta_tensors_for_recompute(cls, fp8_meta: Dict[str, Any]) -> None: @@ -475,16 +442,16 @@ def get_old_fp8_meta_tensors_for_recompute(cls, fp8_meta: Dict[str, Any]) -> Non stashed_fp8_meta = cls.fp8_tensors_recompute_buffer[fp8_meta[buffer_position_key]].popleft() # Replace amaxes and scales with stashed values for phase 2 forward - fp8_meta["scaling_fwd"].amax_history = stashed_fp8_meta[0] - fp8_meta["scaling_fwd"].scale = stashed_fp8_meta[1] - fp8_meta["scaling_fwd"].scale_inv = stashed_fp8_meta[2] + fp8_meta["scaling_fwd"].amax_history.copy_(stashed_fp8_meta[0]) + fp8_meta["scaling_fwd"].scale.copy_(stashed_fp8_meta[1]) + fp8_meta["scaling_fwd"].scale_inv.copy_(stashed_fp8_meta[2]) @staticmethod def restore_fp8_meta_tensors(fp8_meta: Dict[str, Any]) -> None: """Restore latest scaling factors and amaxes after recompute forward run.""" - fp8_meta["scaling_fwd"].amax_history = fp8_meta["updated_amax_history_fwd"] - fp8_meta["scaling_fwd"].scale = fp8_meta["updated_scale_fwd"] - fp8_meta["scaling_fwd"].scale_inv = fp8_meta["updated_scale_inv_fwd"] + fp8_meta["scaling_fwd"].amax_history.copy_(fp8_meta["updated_amax_history_fwd"]) + fp8_meta["scaling_fwd"].scale.copy_(fp8_meta["updated_scale_fwd"]) + fp8_meta["scaling_fwd"].scale_inv.copy_(fp8_meta["updated_scale_inv_fwd"]) @contextmanager diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index ed0ed1c008..f44500f7f2 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -12,6 +12,7 @@ from torch._C import _graph_pool_handle from transformer_engine.common.recipe import DelayedScaling +from transformer_engine.pytorch.constants import dist_group_type from .fp8 import ( fp8_autocast, FP8GlobalStateManager, @@ 
-63,6 +64,7 @@ def _make_graphed_callables( sample_kwargs: Optional[SingleOrTuple[Dict[str, Any]]] = None, _order: Optional[List[int]] = None, pool: Optional[Tuple[int, ...]] = None, + retain_graph_in_backward: bool = False, ) -> SingleOrTuple[Callable]: """ Helper method for `make_graphed_callables` @@ -173,11 +175,14 @@ def _make_graphed_callables( ] else: per_callable_module_params = [] - for c in callables: - for i in range(num_microbatches): - per_callable_module_params.append( - tuple(c.parameters()) if isinstance(c, torch.nn.Module) else () - ) + for m_chunk in range(num_model_chunks): + for _ in range(num_microbatches): + for l_no in range(num_layers): + per_callable_module_params.append( + tuple(callables[m_chunk * num_layers + l_no].parameters()) + if isinstance(callables[m_chunk * num_layers + l_no], torch.nn.Module) + else () + ) assert len(per_callable_module_params) == len(flatten_sample_args) per_callable_static_input_surfaces = [ flatten_sample_args[i] + per_callable_module_params[i] @@ -201,13 +206,55 @@ def _make_graphed_callables( # Hopefully prevents cudnn benchmarking and other lazy-initialization cuda work # from ending up in any captures. torch.cuda.synchronize() - with torch.cuda.stream(torch.cuda.Stream()): + + # Get warmup func and func_idx. + warmup_func_idx = [] + warmup_func = [] + if _order is None: for func_idx, func in enumerate(callables): + warmup_func_idx.append(func_idx) + warmup_func.append(func) + else: + fwd_idx = [0] * num_model_chunks + for c_id in _order: + if c_id > 0: + m_chunk = c_id - 1 + for l_no in range(num_layers): + func = callables[m_chunk * num_layers + l_no] + func_idx = (m_chunk * num_microbatches * num_layers) + ( + fwd_idx[m_chunk] * num_layers + l_no + ) + warmup_func_idx.append(func_idx) + warmup_func.append(func) + fwd_idx[m_chunk] += 1 + assert len(warmup_func) == len( + sample_args + ), f"Warmup runs {len(warmup_func)} don't match args {len(sample_args)}." 
+ assert len(warmup_func_idx) == len( + set(warmup_func_idx) + ), f"Warmup runs {len(warmup_func)} but only {len(set(warmup_func_idx))} are unique." + + # Filter the TE modules that cudagraph can access. + visited_te_modules = set() + + def hook_fn(module, inputs, outputs): # pylint: disable=unused-argument + if isinstance(module, TransformerEngineBaseModule): + visited_te_modules.add(module) + + # Run warmup and do the above filtering. + with torch.cuda.stream(torch.cuda.Stream()): + for func_idx, func in zip(warmup_func_idx, warmup_func): args = sample_args[func_idx] kwargs = sample_kwargs[func_idx] static_input_surface = per_callable_static_input_surfaces[func_idx] for _ in range(num_warmup_iters): + hooks = [] + for module in func.modules(): + hook = module.register_forward_hook(hook_fn) + hooks.append(hook) outputs, _ = _tree_flatten(func(*args, **kwargs)) + for hook in hooks: + hook.remove() grad_inputs = torch.autograd.grad( outputs=tuple(o for o in outputs if o.requires_grad), inputs=tuple(i for i in static_input_surface if i.requires_grad), @@ -216,6 +263,11 @@ def _make_graphed_callables( allow_unused=allow_unused_input, ) del outputs, grad_inputs + # The following code is added specifically for MCore's special requirements, + # aimed at preventing warmup from altering the control flow. + for module in func.modules(): + if hasattr(module, "is_first_microbatch"): + module.is_first_microbatch = True torch.cuda.synchronize() # All captures here share a mempool. 
To avoid replays corrupting each other's memory, @@ -269,6 +321,7 @@ def _make_graphed_callables( grad_outputs=tuple(o for o in static_grad_outputs if o is not None), only_inputs=True, allow_unused=allow_unused_input, + retain_graph=retain_graph_in_backward, ) # Constructs a tuple suitable for returning from Graphed.backward: # Pads out the actually-needed grads with Nones in gradient slots for inputs @@ -320,6 +373,7 @@ def _make_graphed_callables( grad_outputs=tuple(o for o in static_grad_outputs if o is not None), only_inputs=True, allow_unused=allow_unused_input, + retain_graph=retain_graph_in_backward, ) # Constructs a tuple suitable for returning from Graphed.backward: # Pads out the actually-needed grads with Nones in gradient slots for inputs that @@ -462,10 +516,23 @@ def new_fwd(*user_args, **user_kwargs): isinstance(m, TransformerEngineBaseModule) and FP8GlobalStateManager.is_fp8_enabled() ): + if m not in visited_te_modules: + # Only Set the FP8 meta for the modules included by forward + continue + fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() + from transformer_engine.pytorch.attention import DotProductAttention + + if ( + isinstance(m, DotProductAttention) + and not fp8_recipe.fp8_mha + and not fp8_recipe.fp8_dpa + ): + # Don't need to update FP8 meta for non-FP8 DPA + continue m.fp8_meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group() m.fp8_meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe() FP8GlobalStateManager.add_fp8_tensors_to_global_buffer( - m.fp8_meta, fp8_weights=m._get_fp8_params() + m.fp8_meta, ) return graphed(*user_args, **user_kwargs) return orig_fwd(*user_args, **user_kwargs) @@ -538,9 +605,11 @@ def make_graphed_callables( fp8_enabled: bool = False, fp8_calibrating: bool = False, fp8_recipe: Optional[DelayedScaling] = None, + fp8_group: Optional[dist_group_type] = None, fp8_weight_caching: bool = False, _order: Optional[List[int]] = None, pool: Optional[Tuple[int, ...]] = None, + retain_graph_in_backward: bool = False, 
) -> Union[Callable, Tuple[Callable, ...]]: """ Make CUDA graph version of Transformer Engine modules @@ -567,6 +636,8 @@ def make_graphed_callables( pool: (tuple of) int, default = `None`, optional An instance returned from function `torch.cuda.graph_pool_handle` that hints this graph may share memory with the indicated pool. + retain_graph_in_backward: bool, default = `False` + Whether to set retain_graph=True in backward graph capture. FP8-related parameters ---------------------- @@ -579,6 +650,9 @@ def make_graphed_callables( using a higher precision. fp8_recipe: recipe.DelayedScaling, default = `None` recipe used for FP8 training. + fp8_group: torch._C._distributed_c10d.ProcessGroup, default = `None` + distributed group over which amaxes for the fp8 tensors + are reduced at the end of each training step. fp8_weight_caching: bool, default = `False` Whether or not to cache FP8 weights across microbatches. if set to `True`, the `is_first_microbatch` boolean argument must be passed into the forward @@ -607,7 +681,11 @@ def wrap_autocast(block): def forward_func(*args, **kwargs): with fp8_autocast( - enabled=fp8_enabled, calibrating=fp8_calibrating, fp8_recipe=fp8_recipe, _graph=True + enabled=fp8_enabled, + calibrating=fp8_calibrating, + fp8_recipe=fp8_recipe, + fp8_group=fp8_group, + _graph=True, ): outputs = old_forward(*args, **kwargs) return outputs @@ -644,6 +722,7 @@ def forward_func(*args, **kwargs): sample_kwargs=sample_kwargs, _order=_order, pool=pool, + retain_graph_in_backward=retain_graph_in_backward, ) # Ensures warmup does not affect numerics for ops such as dropout. 
diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 5ca34f7597..68105617f0 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -589,20 +589,50 @@ def reset(key): def get_extra_state(self) -> torch.Tensor: """Save before checkpointing.""" - state = None + # This implementation is working around a few issues: + # + # (1) PyTorch's "extra state" infrastructure might be able to + # support any picklable type, but they make no guarantees. + # We have experienced problems (e.g. in ONNX export) with + # non-tensor extra state. + # (2) PyTorch's checkpointing infrastructure does not remap + # devices for "extra state" like it does for "state dict". + # Thus, we want to avoid putting extra state on the GPU + # since it may be loaded on the wrong device. + # (3) The extra state consists of many small tensors. If we + # want to copy them all to CPU, then we need to avoid the + # overhead of many GPU-CPU memory transfers. + # + # See: https://github.com/NVIDIA/TransformerEngine/pull/351 + # See: https://github.com/NVIDIA/TransformerEngine/pull/363 + + def to_cpu(src: torch.Tensor) -> torch.Tensor: + """Helper function to make CPU copy of tensor + + Memory transfer is asynchronous w.r.t. host, so GPU should + be synchronized before using result. 
+ + """ + dst = torch.empty_like(src, device="cpu") + dst.copy_(src, non_blocking=True) + return dst + + # Store FP8 state if needed + state = None fp8_checkpoint = self.fp8_meta["fp8_checkpoint"] or self.fp8 or self.fp8_calibration - if fp8_checkpoint: + + # Copy tensors to CPU and store state = {} - state["scale_fwd"] = self.fp8_meta["scaling_fwd"].scale - state["scale_inv_fwd"] = self.fp8_meta["scaling_fwd"].scale_inv - state["amax_history_fwd"] = self.fp8_meta["scaling_fwd"].amax_history - state["scale_bwd"] = self.fp8_meta["scaling_bwd"].scale - state["scale_inv_bwd"] = self.fp8_meta["scaling_bwd"].scale_inv - state["amax_history_bwd"] = self.fp8_meta["scaling_bwd"].amax_history - - # Store other pickelable values. + state["scale_fwd"] = to_cpu(self.fp8_meta["scaling_fwd"].scale) + state["amax_history_fwd"] = to_cpu(self.fp8_meta["scaling_fwd"].amax_history) + state["scale_inv_fwd"] = to_cpu(self.fp8_meta["scaling_fwd"].scale_inv) + state["scale_bwd"] = to_cpu(self.fp8_meta["scaling_bwd"].scale) + state["amax_history_bwd"] = to_cpu(self.fp8_meta["scaling_bwd"].amax_history) + state["scale_inv_bwd"] = to_cpu(self.fp8_meta["scaling_bwd"].scale_inv) + + # Store other pickelable values extra = {} for k, v in self.fp8_meta.items(): if k != "buffer_index_and_autocast_key" and isinstance( @@ -611,12 +641,10 @@ def get_extra_state(self) -> torch.Tensor: extra[k] = v state["extra_fp8_variables"] = extra - if is_in_onnx_export_mode(): - state_serialized = torch.frombuffer(pickle.dumps(state), dtype=torch.uint8) - else: - state_serialized = io.BytesIO() - torch.save(state, state_serialized) - + # Serialize state into byte tensor + torch.cuda.synchronize() + state_serialized = bytearray(pickle.dumps(state)) + state_serialized = torch.frombuffer(state_serialized, dtype=torch.uint8) return state_serialized def set_extra_state(self, state: torch.Tensor) -> None: @@ -624,9 +652,12 @@ def set_extra_state(self, state: torch.Tensor) -> None: if state is None: return + # Load 
state if isinstance(state, torch.Tensor): + # Default format: byte tensor with pickled data state = pickle.loads(state.detach().cpu().numpy().tobytes()) elif isinstance(state, io.BytesIO): + # Deprecated format with io.BytesIO state.seek(0) state = torch.load(state, map_location="cuda") else: @@ -635,20 +666,32 @@ def set_extra_state(self, state: torch.Tensor) -> None: if state is None: return - # Load extra items. + # Load extra items self.fp8_meta.update(state["extra_fp8_variables"]) self.fp8_meta["recipe"].amax_history_len = state["amax_history_fwd"].shape[0] if "global_fp8_buffer_pos_fwd_recompute" in self.fp8_meta: del self.fp8_meta["global_fp8_buffer_pos_fwd_recompute"] - # Initialize before loading. + # Initialize before loading self.init_fp8_meta_tensors() - self.fp8_meta["scaling_fwd"].scale.copy_(state["scale_fwd"]) - self.fp8_meta["scaling_fwd"].amax_history.copy_(state["amax_history_fwd"]) - self.fp8_meta["scaling_bwd"].scale.copy_(state["scale_bwd"]) - self.fp8_meta["scaling_bwd"].amax_history.copy_(state["amax_history_bwd"]) - self.fp8_meta["scaling_fwd"].scale_inv.copy_(state["scale_inv_fwd"]) - self.fp8_meta["scaling_bwd"].scale_inv.copy_(state["scale_inv_bwd"]) + + def copy_tensor(src: torch.Tensor, dst: torch.Tensor) -> None: + """Helper function to copy tensor from CPU + + Memory transfer is asynchronous w.r.t. host, so GPU should + be synchronized before using result. 
+ + """ + dst.copy_(src, non_blocking=True) + + # Load tensors + copy_tensor(state["scale_fwd"], self.fp8_meta["scaling_fwd"].scale) + copy_tensor(state["amax_history_fwd"], self.fp8_meta["scaling_fwd"].amax_history) + copy_tensor(state["scale_inv_fwd"], self.fp8_meta["scaling_fwd"].scale_inv) + copy_tensor(state["scale_bwd"], self.fp8_meta["scaling_bwd"].scale) + copy_tensor(state["amax_history_bwd"], self.fp8_meta["scaling_bwd"].amax_history) + copy_tensor(state["scale_inv_bwd"], self.fp8_meta["scaling_bwd"].scale_inv) + torch.cuda.synchronize() def set_activation_dtype(self, inp: torch.Tensor) -> None: """Get activation data type for AMP.""" @@ -763,9 +806,7 @@ def prepare_forward( ) if self.fp8 and not FP8GlobalStateManager.fp8_graph_capturing(): - FP8GlobalStateManager.add_fp8_tensors_to_global_buffer( - self.fp8_meta, fp8_weights=self._get_fp8_params() - ) + FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(self.fp8_meta) # Activation recomputation is used and this is the first forward phase. 
if self.fp8 and self.training and is_fp8_activation_recompute_enabled(): diff --git a/transformer_engine/pytorch/module/layernorm.py b/transformer_engine/pytorch/module/layernorm.py index 32142cf48c..b42079d299 100644 --- a/transformer_engine/pytorch/module/layernorm.py +++ b/transformer_engine/pytorch/module/layernorm.py @@ -61,15 +61,32 @@ class LayerNorm(_LayerNormOp): def __init__( self, - normalized_shape: Union[Iterable[int], int], + normalized_shape: Union[Iterable[int], int, None] = None, eps: float = 1e-5, sequence_parallel: Optional[bool] = None, # legacy params_dtype: Optional[torch.dtype] = None, # deprecated zero_centered_gamma: bool = False, + hidden_size: Optional[int] = None, # deprecated **kwargs, ) -> None: # Handle deprecated options + if normalized_shape is None: + if hidden_size is None: + raise RuntimeError( + "Neither `normalized_shape` nor `hidden_size` (deprecated) args are provided" + ) + warnings.warn( + "`hidden_size` arg has been renamed to `normalized_shape` " + "for compatibility with `torch.nn.LayerNorm`.", + DeprecationWarning, + stacklevel=2, + ) + normalized_shape = hidden_size + elif hidden_size is not None: + raise RuntimeError( + "Both `normalized_shape` and `hidden_size` (deprecated) args are provided" + ) if params_dtype is not None: if "dtype" in kwargs: raise RuntimeError( diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index fbf1b97704..92b37fcb07 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -1152,7 +1152,10 @@ def forward( produced) """ - skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor() + if FP8GlobalStateManager.fp8_graph_capturing(): + skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor() + else: + skip_fp8_weight_update = None if skip_fp8_weight_update is not None: is_first_microbatch = False diff 
--git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 64e8c9ce36..1a651474bf 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -1484,7 +1484,10 @@ def forward( produced) """ - skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor() + if FP8GlobalStateManager.fp8_graph_capturing(): + skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor() + else: + skip_fp8_weight_update = None if skip_fp8_weight_update is not None: is_first_microbatch = False diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 1fed467210..9492725f56 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -938,8 +938,10 @@ def forward( first microbatch (since it is the first gradient being produced) """ - - skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor() + if FP8GlobalStateManager.fp8_graph_capturing(): + skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor() + else: + skip_fp8_weight_update = None if skip_fp8_weight_update is not None: is_first_microbatch = False diff --git a/transformer_engine/pytorch/module/rmsnorm.py b/transformer_engine/pytorch/module/rmsnorm.py index f3651ecc19..bd7db1f775 100644 --- a/transformer_engine/pytorch/module/rmsnorm.py +++ b/transformer_engine/pytorch/module/rmsnorm.py @@ -65,15 +65,32 @@ class RMSNorm(_RMSNormOp): def __init__( self, - normalized_shape: Union[Iterable[int], int], + normalized_shape: Union[Iterable[int], int, None] = None, eps: float = 1e-5, sequence_parallel: Optional[bool] = None, # legacy params_dtype: Optional[torch.dtype] = None, # deprecated zero_centered_gamma: bool = False, + hidden_size: Optional[int] = None, # deprecated **kwargs, ) -> None: # Handle deprecated options + if 
normalized_shape is None: + if hidden_size is None: + raise RuntimeError( + "Neither `normalized_shape` nor `hidden_size` (deprecated) args are provided" + ) + warnings.warn( + "`hidden_size` arg has been renamed to `normalized_shape` " + "for compatibility with `torch.nn.LayerNorm`.", + DeprecationWarning, + stacklevel=2, + ) + normalized_shape = hidden_size + elif hidden_size is not None: + raise RuntimeError( + "Both `normalized_shape` and `hidden_size` (deprecated) args are provided" + ) if params_dtype is not None: if "dtype" in kwargs: raise RuntimeError( diff --git a/transformer_engine/pytorch/ops/basic/__init__.py b/transformer_engine/pytorch/ops/basic/__init__.py index 3dd8f64229..d6f4940c58 100644 --- a/transformer_engine/pytorch/ops/basic/__init__.py +++ b/transformer_engine/pytorch/ops/basic/__init__.py @@ -4,6 +4,7 @@ """Single tensor operations supported by the operation fuser.""" +from .activation import GELU, ReLU, GEGLU, ReGLU, SwiGLU from .add_in_place import AddInPlace from .all_gather import AllGather from .all_reduce import AllReduce diff --git a/transformer_engine/pytorch/ops/basic/activation.py b/transformer_engine/pytorch/ops/basic/activation.py new file mode 100644 index 0000000000..a2e5a24a85 --- /dev/null +++ b/transformer_engine/pytorch/ops/basic/activation.py @@ -0,0 +1,390 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
+ +"""Fusible operations for activation functions.""" + +from __future__ import annotations +import abc +from typing import Optional + +import torch + +import transformer_engine_torch +from ...constants import TE_DType +from ...cpp_extensions import ( + geglu as tex_geglu, + gelu as tex_gelu, + reglu as tex_reglu, + relu as tex_relu, + swiglu as tex_swiglu, + fp8_dswiglu_cast_transpose_fused, +) +from ...fp8 import FP8GlobalStateManager, get_fp8_te_dtype +from ...tensor import Float8Tensor, QuantizedTensor +from ...utils import clear_tensor_data, devices_match +from ..op import BasicOperation, OperationContext + + +class _ActivationOperation(BasicOperation, metaclass=abc.ABCMeta): + r"""Apply activation function + + Activation functions are either element-wise unary functions or + variants of the gated linear unit (GLU). Recall that GLU is + computed by splitting the input tensor into chunks :math:`a` and + :math:`b` along the last dimension and computing + + .. math:: + \text{GLU}(a,b) = \sigma(a) * b + + .. warning:: + + Transformer Engine gated activations and PyTorch's GLU + activation follow opposite conventions for :math:`a` and + :math:`b`. Transformer Engine applies the gating function to + the first half of the input tensor, while PyTorch applies it to + the second half. + + """ + + @abc.abstractmethod + def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor: + """Forward implementation + + Implementation from transformer_engine.pytorch.cpp_extensions. + + """ + + @abc.abstractmethod + def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor: + """Backward implementation + + Implementation from transformer_engine_torch. 
+ + """ + + def op_forward( + self, + ctx: OperationContext, + input_: torch.Tensor, + prev_op: Optional[BasicOperation] = None, + next_op: Optional[BasicOperation] = None, + ) -> torch.Tensor: + + # Compute dtype + dtype: torch.dtype + if torch.is_autocast_enabled(): + dtype = torch.get_autocast_dtype("cuda") + else: + dtype = input_.dtype + if dtype not in (torch.float32, torch.float16, torch.bfloat16): + raise RuntimeError(f"Unsupported dtype ({dtype})") + + # Check input tensor + x = input_ + if isinstance(x, QuantizedTensor): + x = x.dequantize() + if x.device.type != "cuda": + x = x.cuda() + if x.dtype != dtype: + x = x.to(dtype=dtype) + if not x.is_contiguous(): + x = x.contiguous() + + # Check if FP8 is enabled + fp8_enabled = FP8GlobalStateManager.is_fp8_enabled() + with_fp8_output = False + output_fp8_meta = None + output_dtype = TE_DType[dtype] + output_fp8_scale_inv = None + if fp8_enabled and next_op is not None and next_op.num_fp8_scales("input") > 0: + with_fp8_output = True + fp8_meta = next_op.get_fp8_meta("input") + fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(forward=True) + output_fp8_meta = fp8_meta[fp8_meta_key] + output_dtype = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) + output_fp8_scale_inv = torch.empty([1], dtype=torch.float32, device=x.device) + + # Launch kernel + y = self._activation_forward_impl( + x, + output_fp8_meta, + 0, + output_dtype, + scale_inv=output_fp8_scale_inv, + ) + + # Check output tensor + if y.dim() != x.dim(): + y = y.reshape(list(x.shape[:-1]) + [-1]) + if with_fp8_output: + y = Float8Tensor( + data=y, + fp8_meta=output_fp8_meta, + fp8_meta_forward=True, + fp8_meta_index=0, + fp8_dtype=output_dtype, + fp8_scale_inv=output_fp8_scale_inv, + dtype=dtype, + ) + + # Save state for backward pass + ctx.save_for_backward(x) + ctx.fp8_enabled = fp8_enabled + ctx.prev_op = prev_op + + return y + + def op_backward( + self, + ctx: OperationContext, + grad_output: torch.Tensor, + ) -> tuple[torch.Tensor, 
tuple[()]]: + + # Saved tensors from forward pass + (x,) = ctx.saved_tensors + + # Check grad output tensor + dy = grad_output + if isinstance(dy, QuantizedTensor): + dy = dy.dequantize() + if not devices_match(dy.device, x.device) or dy.dtype != x.dtype: + dy = dy.to(device=x.device, dtype=x.dtype) + if not dy.is_contiguous(): + dy = dy.contiguous() + + # Launch kernel + dx = self._activation_backward_impl(dy, x, TE_DType[x.dtype]) + + # Check grad input tensor + if dx.size() != x.size(): + dx = dx.reshape(x.size()) + + # Clear input tensor if possible + if ctx.prev_op is not None: + clear_tensor_data(x) + + return dx, () + + +class GELU(_ActivationOperation): + r"""Gaussian Error Linear Unit + + This computes the "tanh" approximation to GELU: + + .. math:: + + \text{GELU}(x) \approx \frac{x}{2} \left( 1 + \tanh\left( 0.797x+0.036 x^3 \right) \right) + + See `Gaussian Error Linear Units (GELUs)`__. + + """ + + def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor: + return tex_gelu(*args, **kwargs) + + def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor: + return transformer_engine_torch.dgelu(*args, **kwargs) + + +class ReLU(_ActivationOperation): + r"""Rectified linear unit + + .. math:: + + \text{ReLU}(x) = \max(x,0) + + """ + + def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor: + return tex_relu(*args, **kwargs) + + def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor: + return transformer_engine_torch.drelu(*args, **kwargs) + + +class GEGLU(_ActivationOperation): + r"""Gaussian error gated linear unit + + The input tensor is split into chunks :math:`a` and :math:`b` + along the last dimension and the following is computed: + + .. math:: + + \text{GEGLU}(a,b) = \text{GELU}(a) * b + + where + + .. math:: + + \text{GELU}(x) \approx \frac{x}{2} \left( 1 + \tanh\left( 0.797x+0.036 x^3 \right) \right) + + .. 
warning:: + + Transformer Engine's gated activations and PyTorch's GLU + activation follow opposite conventions for :math:`a` and + :math:`b`. Transformer Engine applies the gating function to + the first half of the input tensor, while PyTorch applies it to + the second half. + + See `GLU Variants Improve Transformer`__. + + """ + + def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor: + return tex_geglu(*args, **kwargs) + + def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor: + return transformer_engine_torch.dgeglu(*args, **kwargs) + + +class ReGLU(_ActivationOperation): + r"""Rectified gated linear unit + + The input tensor is split into chunks :math:`a` and :math:`b` + along the last dimension and the following is computed: + + .. math:: + + \text{ReGLU}(a,b) = \max(a,0) * b + + .. warning:: + + Transformer Engine's gated activations and PyTorch's GLU + activation follow opposite conventions for :math:`a` and + :math:`b`. Transformer Engine applies the gating function to + the first half of the input tensor, while PyTorch applies it to + the second half. + + See `GLU Variants Improve Transformer`__. + + """ + + def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor: + return tex_reglu(*args, **kwargs) + + def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor: + return transformer_engine_torch.dreglu(*args, **kwargs) + + +class SwiGLU(_ActivationOperation): + r"""Swish gated linear unit + + The input tensor is split into chunks :math:`a` and :math:`b` + along the last dimension and the following is computed: + + .. math:: + + \text{SwiGLU}(a,b) = \text{SiLU}(a) * b + + where + + .. math:: + + \text{SiLU}(x) = x \sigma(x) = \frac{x}{1+\exp(-x)} + + .. warning:: + + Transformer Engine's gated activations and PyTorch's GLU + activation follow opposite conventions for :math:`a` and + :math:`b`.
Transformer Engine applies the gating function to + the first half of the input tensor, while PyTorch applies it to + the second half. + + The Sigmoid Linear Unit (SiLU) gating function is also known as + the swish function. See + `GLU Variants Improve Transformer`__ + and `Gaussian Error Linear Units (GELUs)`__. + + """ + + def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor: + return tex_swiglu(*args, **kwargs) + + def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor: + return transformer_engine_torch.dswiglu(*args, **kwargs) + + def op_backward( + self, + ctx: OperationContext, + grad_output: torch.Tensor, + ) -> tuple[torch.Tensor, tuple[()]]: + + # Saved tensors from forward pass + (x,) = ctx.saved_tensors + + # Tensor attributes + dtype = x.dtype + device = x.device + + # Check grad output tensor + dy = grad_output + if isinstance(dy, QuantizedTensor): + dy = dy.dequantize() + if not devices_match(dy.device, device) or dy.dtype != dtype: + dy = dy.to(device=device, dtype=dtype) + if not dy.is_contiguous(): + dy = dy.contiguous() + + # Check if FP8 is enabled + with_fp8_grad_input = False + grad_input_fp8_meta = None + grad_input_dtype = TE_DType[dtype] + grad_input_fp8_scale_inv = None + if ( + ctx.fp8_enabled + and ctx.prev_op is not None + and ctx.prev_op.num_fp8_scales("grad_output") > 0 + ): + with_fp8_grad_input = True + fp8_meta = ctx.prev_op.get_fp8_meta("grad_output") + fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(forward=False) + grad_input_fp8_meta = fp8_meta[fp8_meta_key] + grad_input_dtype = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=False) + grad_input_fp8_scale_inv = torch.empty([1], dtype=torch.float32, device=device) + + # Launch kernel + if with_fp8_grad_input: + # Fused with FP8 cast-transpose + input_dims = x.size() + flat_input_dims = [x.numel() // input_dims[-1], input_dims[-1]] + flat_output_dims = [flat_input_dims[0], flat_input_dims[1] // 2] + dx = torch.empty(input_dims, 
dtype=torch.uint8, device=device) + dx_t = torch.empty( + (flat_input_dims[1], flat_input_dims[0]), + dtype=torch.uint8, + device=device, + ) + fp8_dswiglu_cast_transpose_fused( + dy.reshape(flat_output_dims), + x.reshape(flat_input_dims), + grad_input=dx.reshape(flat_input_dims), + grad_input_transpose=dx_t, + otype=grad_input_dtype, + fp8_meta=grad_input_fp8_meta, + fp8_meta_index=0, + scale_inv=grad_input_fp8_scale_inv, + ) + dx = Float8Tensor( + data=dx, + fp8_meta=grad_input_fp8_meta, + fp8_meta_forward=True, + fp8_meta_index=0, + fp8_dtype=grad_input_dtype, + fp8_scale_inv=grad_input_fp8_scale_inv, + dtype=dtype, + ) + dx._transpose = dx_t + dx._transpose_invalid = False + else: + # Standard impl + dx = self._activation_backward_impl(dy, x, TE_DType[dtype]) + if dx.size() != x.size(): + dx = dx.reshape(x.size()) + + # Note: This fails if op is preceded by an identity op like Quantize(forward=False) + # # Clear input tensor if possible + # if ctx.prev_op is not None: + # clear_tensor_data(x) + + return dx, () diff --git a/transformer_engine/pytorch/ops/basic/layer_norm.py b/transformer_engine/pytorch/ops/basic/layer_norm.py index 99c9c493db..710f838581 100644 --- a/transformer_engine/pytorch/ops/basic/layer_norm.py +++ b/transformer_engine/pytorch/ops/basic/layer_norm.py @@ -20,7 +20,12 @@ ) from ...fp8 import FP8GlobalStateManager, get_fp8_te_dtype from ...tensor import Float8Tensor, QuantizedTensor -from ...utils import canonicalize_device, canonicalize_dtype, clear_tensor_data +from ...utils import ( + canonicalize_device, + canonicalize_dtype, + clear_tensor_data, + devices_match, +) from ..op import BasicOperation, OperationContext from .._common import maybe_autocast_dtype, reshape @@ -84,28 +89,23 @@ def __init__( normalized_shape = (normalized_shape,) else: normalized_shape = tuple(normalized_shape) - self._shape: tuple[int, ...]
= normalized_shape # Parameter device defer_param_init = False device = canonicalize_device(device) if device.type == "meta": defer_param_init = True - device = canonicalize_device(None) - if device.type != "cuda": - raise ValueError(f"Only CUDA devices are supported (got {device})") - self.device: torch.device = device # Initialize parameters if needed dtype = canonicalize_dtype(dtype) weight = torch.empty( - self._shape, - device="meta", + normalized_shape, + device=device, dtype=dtype, ) bias = torch.empty( - self._shape, - device="meta", + normalized_shape, + device=device, dtype=dtype, ) weight = torch.nn.Parameter(weight) @@ -143,17 +143,18 @@ def getenv(name: str) -> int: def reset_parameters(self) -> None: """Initialize parameter buffers and values""" - # Make sure parameter is initialized + # Parameter device weight = self.weight bias = self.bias - if weight.device.type != "cuda": - weight = torch.empty_like(weight, device=self.device) - else: - weight = weight.to(device=self.device) - if bias.device.type != "cuda": - bias = torch.empty_like(bias, device=self.device) - else: - bias = bias.to(device=self.device) + device = weight.device + if device.type == "meta": + device = canonicalize_device(None) + + # Initialize param buffers + if not devices_match(weight.device, device): + weight = torch.empty_like(weight, device=device) + if not devices_match(bias.device, device): + bias = torch.empty_like(bias, device=device) # Initialize values if self.zero_centered_gamma: @@ -184,17 +185,21 @@ def op_forward( ) -> torch.Tensor: # Check tensor dims + weight = self.weight + weight_dims = tuple(weight.size()) input_dims = tuple(input_.size()) - if len(input_dims) < len(self._shape) or input_dims[-len(self._shape) :] != self._shape: + if len(input_dims) < len(weight_dims) or input_dims[-len(weight_dims) :] != weight_dims: raise ValueError( f"Input tensor (shape={input_dims}) " - f"and weight tensor (shape={self._shape}) are not compatible" + f"and weight tensor 
(shape={weight_dims}) are not compatible" ) # Check input tensors - inner_dim = math.prod(self._shape) - device = self.device - dtype = maybe_autocast_dtype(default_dtype=self.weight.dtype) + inner_dim = math.prod(weight_dims) + device = weight.device + if device.type != "cuda": + device = canonicalize_device(None) + dtype = maybe_autocast_dtype(default_dtype=weight.dtype) x = reshape(input_, (-1, inner_dim), device=device, dtype=dtype) w = reshape(self.weight, (inner_dim,), device=device, dtype=dtype) b = reshape(self.bias, (inner_dim,), device=device, dtype=dtype) @@ -266,6 +271,7 @@ def op_forward( # Save state for backward pass if requires_grad: ctx.save_for_backward(x, means, rstdevs) + ctx.device = device ctx.dtype = dtype ctx.has_prev_op = prev_op is not None @@ -282,9 +288,12 @@ def op_backward( # Saved tensors from forward pass x, means, rstdevs = ctx.saved_tensors + # Tensor dims + weight_dims = self.weight.size() + inner_dim = math.prod(weight_dims) + # Check input tensors - inner_dim = x.size(-1) - device = self.device + device = ctx.device dtype = ctx.dtype dy = reshape(grad_output, x.size(), device=device, dtype=dtype) w = reshape(self.weight, (inner_dim,), device=device, dtype=dtype) @@ -312,6 +321,6 @@ def op_backward( # Reshape results grad_input = reshape(dx, grad_output.size()) - grad_weight = reshape(dw, self._shape) - grad_bias = reshape(db, self._shape) + grad_weight = reshape(dw, weight_dims) + grad_bias = reshape(db, weight_dims) return grad_input, (grad_weight, grad_bias) diff --git a/transformer_engine/pytorch/ops/basic/rmsnorm.py b/transformer_engine/pytorch/ops/basic/rmsnorm.py index 4f0e2ddc22..84f05ce713 100644 --- a/transformer_engine/pytorch/ops/basic/rmsnorm.py +++ b/transformer_engine/pytorch/ops/basic/rmsnorm.py @@ -20,7 +20,12 @@ ) from ...fp8 import FP8GlobalStateManager, get_fp8_te_dtype from ...tensor import Float8Tensor, QuantizedTensor -from ...utils import canonicalize_device, canonicalize_dtype, clear_tensor_data +from 
...utils import ( + canonicalize_device, + canonicalize_dtype, + clear_tensor_data, + devices_match, +) from ..op import BasicOperation, OperationContext from .._common import maybe_autocast_dtype, reshape @@ -83,22 +88,17 @@ def __init__( normalized_shape = (normalized_shape,) else: normalized_shape = tuple(normalized_shape) - self._shape: tuple[int, ...] = normalized_shape # Parameter device defer_param_init = False device = canonicalize_device(device) if device.type == "meta": defer_param_init = True - device = canonicalize_device(None) - if device.type != "cuda": - raise ValueError(f"Only CUDA devices are supported (got {device})") - self.device: torch.device = device # Initialize parameters if needed weight = torch.empty( - self._shape, - device="meta", + normalized_shape, + device=device, dtype=canonicalize_dtype(dtype), ) weight = torch.nn.Parameter(weight) @@ -133,12 +133,15 @@ def getenv(name: str) -> int: def reset_parameters(self) -> None: """Initialize parameter buffers and values""" - # Make sure parameter is initialized + # Parameter device weight = self.weight - if weight.device.type != "cuda": - weight = torch.empty_like(weight, device=self.device) - else: - weight = weight.to(device=self.device) + device = weight.device + if device.type == "meta": + device = canonicalize_device(None) + + # Initialize param buffers + if not devices_match(weight.device, device): + weight = torch.empty_like(weight, device=device) # Initialize values if self.zero_centered_gamma: @@ -165,17 +168,21 @@ def op_forward( ) -> torch.Tensor: # Check tensor dims + weight = self.weight + weight_dims = tuple(weight.size()) input_dims = tuple(input_.size()) - if len(input_dims) < len(self._shape) or input_dims[-len(self._shape) :] != self._shape: + if len(input_dims) < len(weight_dims) or input_dims[-len(weight_dims) :] != weight_dims: raise ValueError( f"Input tensor (shape={input_dims}) " - f"and weight tensor (shape={self._shape}) are not compatible" + f"and weight tensor 
(shape={weight_dims}) are not compatible" ) # Check input tensors - inner_dim = math.prod(self._shape) - device = self.device - dtype = maybe_autocast_dtype(default_dtype=self.weight.dtype) + inner_dim = math.prod(weight_dims) + device = weight.device + if device.type != "cuda": + device = canonicalize_device(None) + dtype = maybe_autocast_dtype(default_dtype=weight.dtype) x = reshape(input_, (-1, inner_dim), device=device, dtype=dtype) w = reshape(self.weight, (inner_dim,), device=device, dtype=dtype) if isinstance(x, QuantizedTensor): @@ -241,6 +248,7 @@ def op_forward( # Save state for backward pass if requires_grad: ctx.save_for_backward(x, rstdevs) + ctx.device = device ctx.dtype = dtype ctx.has_prev_op = prev_op is not None @@ -257,9 +265,12 @@ def op_backward( # Saved tensors from forward pass x, rstdevs = ctx.saved_tensors + # Tensor dims + weight_dims = self.weight.size() + inner_dim = math.prod(weight_dims) + # Check input tensors - inner_dim = x.size(-1) - device = self.device + device = ctx.device dtype = ctx.dtype dy = reshape(grad_output, x.size(), device=device, dtype=dtype) w = reshape(self.weight, (inner_dim,), device=device, dtype=dtype) @@ -285,5 +296,5 @@ def op_backward( # Reshape results grad_input = reshape(dx, grad_output.size()) - grad_weight = reshape(dw, self._shape) + grad_weight = reshape(dw, weight_dims) return grad_input, (grad_weight,) diff --git a/transformer_engine/pytorch/ops/fuser.py b/transformer_engine/pytorch/ops/fuser.py index 6fcb435e5c..8b2a04cff8 100644 --- a/transformer_engine/pytorch/ops/fuser.py +++ b/transformer_engine/pytorch/ops/fuser.py @@ -135,7 +135,11 @@ def forward( requires_grad = any(any(x.requires_grad for x in xs) for xs in extra_inputs) for idx in basic_op_idxs: basic_op_ctxs[idx].requires_grad = requires_grad - x.requires_grad_(requires_grad=requires_grad) + if requires_grad != x.requires_grad: + if requires_grad: + x.requires_grad_() + else: + x = x.detach() # Forward op extra_inputs = 
[basic_op_extra_inputs[idx] for idx in basic_op_idxs] diff --git a/transformer_engine/pytorch/ops/op.py b/transformer_engine/pytorch/ops/op.py index 0bb6f25db8..c55e0f7c19 100644 --- a/transformer_engine/pytorch/ops/op.py +++ b/transformer_engine/pytorch/ops/op.py @@ -19,7 +19,7 @@ FP8GlobalStateManager, get_default_fp8_recipe, ) -from ._common import canonicalize_device, is_float8_tensor +from ._common import canonicalize_device @dataclasses.dataclass @@ -379,10 +379,8 @@ def pre_forward( self.get_fp8_meta("input"), ) if self.num_fp8_scales("param"): - fp8_params = list(filter(is_float8_tensor, self.parameters())) FP8GlobalStateManager.add_fp8_tensors_to_global_buffer( self.get_fp8_meta("param"), - fp8_weights=(fp8_params if fp8_params else None), ) if self.num_fp8_scales("grad_output"): FP8GlobalStateManager.add_fp8_tensors_to_global_buffer( @@ -505,7 +503,7 @@ def forward( basic_op_kwargs=[kwargs], ) - def get_extra_state(self) -> Optional[torch.Tensor]: + def get_extra_state(self) -> torch.Tensor: """Serialize extra state Contains metadata for FP8 casting. @@ -516,7 +514,7 @@ def get_extra_state(self) -> Optional[torch.Tensor]: # # (1) PyTorch's "extra state" infrastructure might be able to # support any picklable type, but they make no guarantees. - # It seems that ONNX export experiences issues with + # We have experienced problems (e.g. in ONNX export) with # non-tensor extra state. # (2) PyTorch's checkpointing infrastructure does not remap # devices for "extra state" like it does for "state dict". 
@@ -534,7 +532,7 @@ def get_extra_state(self) -> Optional[torch.Tensor]: self.num_fp8_scales(mode) > 0 for mode in ("input", "param", "grad_output") ) if not has_fp8_state: - return None + return torch.Tensor() def to_cpu(src: torch.Tensor) -> torch.Tensor: """Helper function to make CPU copy of tensor @@ -588,7 +586,7 @@ def to_cpu(src: torch.Tensor) -> torch.Tensor: def set_extra_state(self, state: Optional[torch.Tensor]) -> None: """Load extra state""" - if state is None: + if state is None or state.numel() == 0: return # Deserialize state from byte tensor diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index 36136292df..7ace68a222 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -74,30 +74,6 @@ def backward( return grad, None -def post_optimizer_step_fwd_amax_reduction(param: Float8Tensor) -> None: - """Amax scale and update when there is at least 1 trainable FP8 parameter.""" - param_id = id(param._data) - - if param_id not in FP8GlobalStateManager.fp8_param_to_autocast: - return - - autocast_key = FP8GlobalStateManager.fp8_param_to_autocast[param_id] - - if autocast_key not in FP8GlobalStateManager.autocast_to_fp8_params: - return - - if autocast_key in updated_fp8_params: - updated_fp8_params[autocast_key].add(param_id) - else: - updated_fp8_params[autocast_key] = {param_id} - - current_fp8_params_set = FP8GlobalStateManager.autocast_to_fp8_params[autocast_key] - # All FP8 trainable parameters have been updated. 
- if updated_fp8_params[autocast_key] == current_fp8_params_set: - FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=True, fp8_weights=True) - del updated_fp8_params[autocast_key] - - class _ToFloat8Func(torch.autograd.Function): """Cast to FP8 from other dtype""" @@ -676,9 +652,6 @@ def quantize_( ) dst._transpose_invalid = False - # Callback hook to perform amax reduction after optimizer step - post_optimizer_step_fwd_amax_reduction(self) - return self @classmethod