diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 6858c58947..3784689f9a 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -46,7 +46,6 @@ find_package(Python COMPONENTS Interpreter Development.Module REQUIRED) include_directories(${PROJECT_SOURCE_DIR}/..) set(transformer_engine_SOURCES) list(APPEND transformer_engine_SOURCES - pycudnn.cpp cudnn_utils.cpp transformer_engine.cpp common.cu diff --git a/transformer_engine/common/cudnn_utils.cpp b/transformer_engine/common/cudnn_utils.cpp index 80b66de4bd..35e2d11799 100644 --- a/transformer_engine/common/cudnn_utils.cpp +++ b/transformer_engine/common/cudnn_utils.cpp @@ -4,13 +4,70 @@ * See LICENSE for license information. ************************************************************************/ -#include "../fused_attn/utils.h" +#include "cudnn_utils.h" + +#include "./util/logging.h" #include "transformer_engine/cudnn.h" namespace transformer_engine { +// get cuDNN data type +cudnnDataType_t get_cudnn_dtype(const transformer_engine::DType t) { + using namespace transformer_engine; + switch (t) { + case DType::kInt32: + return CUDNN_DATA_INT32; + case DType::kInt64: + return CUDNN_DATA_INT64; + case DType::kFloat16: + return CUDNN_DATA_HALF; + case DType::kFloat32: + return CUDNN_DATA_FLOAT; + case DType::kBFloat16: + return CUDNN_DATA_BFLOAT16; + case DType::kFloat8E4M3: + return CUDNN_DATA_FP8_E4M3; + case DType::kFloat8E5M2: + return CUDNN_DATA_FP8_E5M2; + default: + NVTE_ERROR("Invalid cuDNN data type. \n"); + } +} + +// get cuDNN data type +cudnn_frontend::DataType_t get_cudnn_fe_dtype(const transformer_engine::DType t) { + using namespace transformer_engine; + switch (t) { + case DType::kInt32: + return cudnn_frontend::DataType_t::INT32; + case DType::kInt64: + return cudnn_frontend::DataType_t::INT64; + case DType::kFloat16: + return cudnn_frontend::DataType_t::HALF; + case DType::kFloat32: + return cudnn_frontend::DataType_t::FLOAT; + case DType::kBFloat16: + return cudnn_frontend::DataType_t::BFLOAT16; + case DType::kFloat8E4M3: + return cudnn_frontend::DataType_t::FP8_E4M3; + case DType::kFloat8E5M2: + return cudnn_frontend::DataType_t::FP8_E5M2; + default: + NVTE_ERROR("Invalid cuDNN data type. \n"); + } +} + void nvte_cudnn_handle_init() { auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle(); } } // namespace transformer_engine + +namespace cudnn_frontend { + +// This is needed to define the symbol `cudnn_dlhandle` +// When using the flag NV_CUDNN_FRONTEND_USE_DYNAMIC_LOADING +// to enable dynamic loading. +void *cudnn_dlhandle = nullptr; + +} // namespace cudnn_frontend diff --git a/transformer_engine/common/cudnn_utils.h b/transformer_engine/common/cudnn_utils.h new file mode 100644 index 0000000000..d2827b637a --- /dev/null +++ b/transformer_engine/common/cudnn_utils.h @@ -0,0 +1,46 @@ +/************************************************************************* + * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#ifndef TRANSFORMER_ENGINE_CUDNN_UTILS_H_ +#define TRANSFORMER_ENGINE_CUDNN_UTILS_H_ + +#include +#include +#include + +#include +#include + +#include "transformer_engine/transformer_engine.h" + +namespace transformer_engine { + +cudnnDataType_t get_cudnn_dtype(const transformer_engine::DType t); + +cudnn_frontend::DataType_t get_cudnn_fe_dtype(const transformer_engine::DType t); + +class cudnnExecutionPlanManager { + public: + static cudnnExecutionPlanManager &Instance() { + static thread_local cudnnExecutionPlanManager instance; + return instance; + } + + cudnnHandle_t GetCudnnHandle() { + static thread_local std::once_flag flag; + std::call_once(flag, [&] { cudnnCreate(&handle_); }); + return handle_; + } + + ~cudnnExecutionPlanManager() {} + + private: + cudnnHandle_t handle_ = nullptr; +}; + +} // namespace transformer_engine + +#endif diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 4ea0ea5741..9cde765401 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -7,6 +7,7 @@ #include "transformer_engine/fused_attn.h" #include "../common.h" +#include "../cudnn_utils.h" #include "../util/cuda_runtime.h" #include "../util/system.h" #include "fused_attn_f16_arbitrary_seqlen.h" diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index 1a555a4999..f242502261 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -13,6 +13,7 @@ #include #include "../common.h" +#include "../cudnn_utils.h" #include "../util/cuda_runtime.h" #include "../util/system.h" #include "fused_attn_f16_arbitrary_seqlen.h" diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.cu index d3422de481..9341ebf5f9 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.cu @@ -12,6 +12,7 @@ #include #include "../common.h" +#include "../cudnn_utils.h" #include "fused_attn_f16_max512_seqlen.h" #include "utils.h" diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index fb7765e1a4..f8fe458219 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -5,6 +5,7 @@ ************************************************************************/ #include "../common.h" +#include "../cudnn_utils.h" #include "../util/system.h" #include "fused_attn_fp8.h" #include "utils.h" diff --git a/transformer_engine/common/fused_attn/utils.cu b/transformer_engine/common/fused_attn/utils.cu index ca00218d9a..a053c55fb6 100644 --- a/transformer_engine/common/fused_attn/utils.cu +++ b/transformer_engine/common/fused_attn/utils.cu @@ -8,6 +8,7 @@ #include #include "../common.h" +#include "../cudnn_utils.h" #include "transformer_engine/fused_attn.h" #include "utils.h" @@ -495,50 +496,4 @@ size_t get_max_tokens(size_t num_tokens) { } } // namespace fused_attn - -// get cuDNN data type -cudnnDataType_t get_cudnn_dtype(const transformer_engine::DType t) { - using namespace transformer_engine; - switch (t) { - case DType::kInt32: - return CUDNN_DATA_INT32; - case DType::kInt64: - return CUDNN_DATA_INT64; - case DType::kFloat16: - return CUDNN_DATA_HALF; - case DType::kFloat32: - return CUDNN_DATA_FLOAT; - case DType::kBFloat16: - return CUDNN_DATA_BFLOAT16; - case DType::kFloat8E4M3: - return CUDNN_DATA_FP8_E4M3; - case DType::kFloat8E5M2: - return CUDNN_DATA_FP8_E5M2; - default: - NVTE_ERROR("Invalid cuDNN data type. \n"); - } -} - -// get cuDNN data type -cudnn_frontend::DataType_t get_cudnn_fe_dtype(const transformer_engine::DType t) { - using namespace transformer_engine; - switch (t) { - case DType::kInt32: - return cudnn_frontend::DataType_t::INT32; - case DType::kInt64: - return cudnn_frontend::DataType_t::INT64; - case DType::kFloat16: - return cudnn_frontend::DataType_t::HALF; - case DType::kFloat32: - return cudnn_frontend::DataType_t::FLOAT; - case DType::kBFloat16: - return cudnn_frontend::DataType_t::BFLOAT16; - case DType::kFloat8E4M3: - return cudnn_frontend::DataType_t::FP8_E4M3; - case DType::kFloat8E5M2: - return cudnn_frontend::DataType_t::FP8_E5M2; - default: - NVTE_ERROR("Invalid cuDNN data type. \n"); - } -} } // namespace transformer_engine diff --git a/transformer_engine/common/fused_attn/utils.h b/transformer_engine/common/fused_attn/utils.h index c060c4907d..f790d3b567 100644 --- a/transformer_engine/common/fused_attn/utils.h +++ b/transformer_engine/common/fused_attn/utils.h @@ -140,29 +140,8 @@ DType get_ragged_offset_dtype(NVTE_QKV_Layout_Group layout_group, int64_t num_at size_t get_max_batch_size(size_t batch_size); size_t get_max_tokens(size_t num_tokens); -} // namespace fused_attn - -cudnnDataType_t get_cudnn_dtype(const transformer_engine::DType t); -cudnn_frontend::DataType_t get_cudnn_fe_dtype(const transformer_engine::DType t); -class cudnnExecutionPlanManager { - public: - static cudnnExecutionPlanManager &Instance() { - static thread_local cudnnExecutionPlanManager instance; - return instance; - } - - cudnnHandle_t GetCudnnHandle() { - static thread_local std::once_flag flag; - std::call_once(flag, [&] { cudnnCreate(&handle_); }); - return handle_; - } - - ~cudnnExecutionPlanManager() {} - - private: - cudnnHandle_t handle_ = nullptr; -}; +} // namespace fused_attn } // namespace transformer_engine #endif diff --git a/transformer_engine/common/pycudnn.cpp b/transformer_engine/common/pycudnn.cpp deleted file mode 100644 index 7d06f332cb..0000000000 --- a/transformer_engine/common/pycudnn.cpp +++ /dev/null @@ -1,14 +0,0 @@ -/************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * - * See LICENSE for license information. - ************************************************************************/ - -namespace cudnn_frontend { - -// This is needed to define the symbol `cudnn_dlhandle` -// When using the flag NV_CUDNN_FRONTEND_USE_DYNAMIC_LOADING -// to enable dynamic loading. -void *cudnn_dlhandle = nullptr; - -} // namespace cudnn_frontend