Add FP8 fused attention (#155)
* Add FP8 fused attention to TE for PyTorch

Signed-off-by: Charlene Yang <[email protected]>

* add license for cudnn-frontend, modify installation requirements, and refactor some headers for aesthetics

Signed-off-by: Charlene Yang <[email protected]>

* add c api docs for fused attention

Signed-off-by: Charlene Yang <[email protected]>

* add exception for unsupported precision/sequence length combinations

Signed-off-by: Charlene Yang <[email protected]>

* fix installation requirement for non fused attn use cases

Signed-off-by: Charlene Yang <[email protected]>

* fix docs for fused-attn

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* prefix enums with NVTE_ and replace old MHA_Matrix with NVTE_QKV_Matrix

Signed-off-by: Charlene Yang <[email protected]>

* minor fixes based on PR comments

Signed-off-by: Charlene Yang <[email protected]>

* fix description for kvpacked fwd

Signed-off-by: Charlene Yang <[email protected]>

* fix description of Bias in C api

Signed-off-by: Charlene Yang <[email protected]>

* minor fixes for cudnn requirement and description for QKV tensors

Signed-off-by: Charlene Yang <[email protected]>

* fix QKV layout description and support matrix for C api

Signed-off-by: Charlene Yang <[email protected]>

* add asserts to cpp_extensions for qkv layout/bias type/attn mask type

Signed-off-by: Charlene Yang <[email protected]>

* fix typo precision

Signed-off-by: Charlene Yang <[email protected]>

---------

Signed-off-by: Charlene Yang <[email protected]>
Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
Co-authored-by: Charlene Yang <[email protected]>
Co-authored-by: Kirthi Shankar Sivamani <[email protected]>
3 people authored Apr 21, 2023
1 parent c340730 commit 989a53a
Showing 29 changed files with 4,720 additions and 25 deletions.
6 changes: 6 additions & 0 deletions .github/workflows/build.yml
@@ -17,6 +17,8 @@ jobs:
      steps:
        - name: 'Checkout'
          uses: actions/checkout@v3
          with:
            submodules: recursive
        - name: 'Build'
          run: |
            mkdir -p wheelhouse && \
@@ -41,6 +43,8 @@ jobs:
      steps:
        - name: 'Checkout'
          uses: actions/checkout@v3
          with:
            submodules: recursive
        - name: 'Build'
          run: |
            pip install ninja pybind11 && \
@@ -66,6 +70,8 @@ jobs:
      steps:
        - name: 'Checkout'
          uses: actions/checkout@v3
          with:
            submodules: recursive
        - name: 'Build'
          run: |
            pip install ninja pybind11 && \
3 changes: 3 additions & 0 deletions .gitmodules
@@ -1,3 +1,6 @@
[submodule "3rdparty/googletest"]
path = 3rdparty/googletest
url = https://github.com/google/googletest.git
[submodule "3rdparty/cudnn-frontend"]
path = 3rdparty/cudnn-frontend
url = https://github.com/NVIDIA/cudnn-frontend.git
1 change: 1 addition & 0 deletions 3rdparty/cudnn-frontend
Submodule cudnn-frontend added at e7f643
22 changes: 22 additions & 0 deletions Acknowledgements.txt
@@ -138,3 +138,25 @@ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
========================
cudnn-frontend

Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.

Permission is hereby granted, free of charge, to any person obtaining a
copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation
the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the
Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
9 changes: 9 additions & 0 deletions docs/api/c/fused_attn.rst
@@ -0,0 +1,9 @@
..
    Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
    See LICENSE for license information.

fused_attn.h
============

.. doxygenfile:: fused_attn.h
1 change: 1 addition & 0 deletions docs/api/c/index.rst
@@ -17,6 +17,7 @@ directly from C/C++, without Python.
activation.h <activation>
cast.h <cast>
gemm.h <gemm>
fused_attn.h <fused_attn>
layer_norm.h <layer_norm>
softmax.h <softmax>
transformer_engine.h <transformer_engine>
2 changes: 2 additions & 0 deletions docs/installation.rst
@@ -14,6 +14,8 @@ Prerequisites
1. Linux x86_64
2. `CUDA 11.8 <https://developer.nvidia.com/cuda-downloads>`__
3. |driver link|_ supporting CUDA 11.8 or later.
4. `cuDNN 8 <https://developer.nvidia.com/cudnn>`__ or later.
5. For FP8 fused attention, `CUDA 12.1 <https://developer.nvidia.com/cuda-downloads>`__ or later, |driver link|_ supporting CUDA 12.1 or later, and `cuDNN 8.9 <https://developer.nvidia.com/cudnn>`__ or later.


Transformer Engine in NGC Containers
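
As a quick sanity check for the FP8 fused-attention prerequisites above (CUDA 12.1 or later and cuDNN 8.9 or later), a small standalone C++ program can query the CUDA runtime and cuDNN versions. This sketch is not part of the commit; the numeric thresholds (12010 and 8900) are assumptions derived from the prerequisite list, and only the runtime and cuDNN are checked, not the driver.

    #include <cstdio>
    #include <cuda_runtime.h>
    #include <cudnn.h>

    int main() {
      int cuda_runtime = 0;              // encoded as 1000*major + 10*minor, e.g. 12010 for CUDA 12.1
      cudaRuntimeGetVersion(&cuda_runtime);
      size_t cudnn = cudnnGetVersion();  // e.g. 8900 for cuDNN 8.9.0

      std::printf("CUDA runtime %d, cuDNN %zu\n", cuda_runtime, cudnn);
      if (cuda_runtime < 12010 || cudnn < 8900) {
        std::printf("FP8 fused attention prerequisites not met.\n");
        return 1;
      }
      std::printf("FP8 fused attention prerequisites met.\n");
      return 0;
    }

Compile with, for example, nvcc and link against cudnn.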
1 change: 1 addition & 0 deletions setup.py
@@ -105,6 +105,7 @@ def make_abs_path(l):
include_dirs = [
    "transformer_engine/common/include",
    "transformer_engine/pytorch/csrc",
    "3rdparty/cudnn-frontend/include",
]
if NVTE_WITH_USERBUFFERS:
    if MPI_HOME:
1 change: 1 addition & 0 deletions tests/cpp/test_common.cu
@@ -42,6 +42,7 @@ const std::string &typeName(DType type) {
static const std::unordered_map<DType, std::string> name_map = {
{DType::kByte, "byte"},
{DType::kInt32, "int32"},
{DType::kInt64, "int64"},
{DType::kFloat32, "float32"},
{DType::kFloat16, "float16"},
{DType::kBFloat16, "bfloat16"},
8 changes: 8 additions & 0 deletions tests/cpp/test_common.h
@@ -44,6 +44,7 @@ struct BytesToType<8> {

using byte = uint8_t;
using int32 = int32_t;
using int64 = int64_t;
using fp32 = float;
using fp16 = half;
using bf16 = nv_bfloat16;
@@ -54,6 +55,7 @@ template <typename T>
struct TypeInfo{
using types = std::tuple<byte,
int32,
int64,
fp32,
fp16,
bf16,
@@ -211,6 +213,12 @@ bool isFp8Type(DType type);
{__VA_ARGS__} \
} \
break; \
case DType::kInt64: \
{ \
using type = int64; \
{__VA_ARGS__} \
} \
break; \
case DType::kFloat32: \
{ \
using type = float; \
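
The hunk above extends the test utilities' runtime-type switch with a kInt64 case, so test code can be instantiated for 64-bit integer tensors as well. As a rough, self-contained illustration of the dispatch pattern (not code from this commit; the names below are hypothetical), the switch maps a runtime DType value to a concrete C++ type and runs the supplied body with that type:

    #include <cstdint>
    #include <cstdio>

    enum class DType { kByte, kInt32, kInt64, kFloat32 };

    // Invoke `fn` with a value whose static type corresponds to the runtime
    // DType, mirroring the `using type = ...;` pattern in the switch above.
    template <typename Fn>
    void dispatch_type(DType dtype, Fn&& fn) {
      switch (dtype) {
        case DType::kByte:    fn(uint8_t{});  break;
        case DType::kInt32:   fn(int32_t{});  break;
        case DType::kInt64:   fn(int64_t{});  break;  // the newly added case
        case DType::kFloat32: fn(float{});    break;
      }
    }

    int main() {
      dispatch_type(DType::kInt64, [](auto tag) {
        std::printf("element size: %zu bytes\n", sizeof(tag));
      });
      return 0;
    }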
2 changes: 2 additions & 0 deletions transformer_engine/CMakeLists.txt
@@ -19,7 +19,9 @@ if (CMAKE_BUILD_TYPE STREQUAL "Debug")
set(CMAKE_CUDA_FLAGS_DEBUG "${CMAKE_CUDA_FLAGS_DEBUG} -G")
endif()

list(APPEND CMAKE_MODULE_PATH "${CMAKE_SOURCE_DIR}/cmake/")
find_package(CUDAToolkit REQUIRED cublas nvToolsExt)
find_package(CUDNN REQUIRED cudnn)
find_package(Python COMPONENTS Interpreter Development REQUIRED)

include_directories(${PROJECT_SOURCE_DIR})
78 changes: 78 additions & 0 deletions transformer_engine/cmake/FindCUDNN.cmake
@@ -0,0 +1,78 @@
add_library(CUDNN::cudnn_all INTERFACE IMPORTED)

find_path(
CUDNN_INCLUDE_DIR cudnn.h
HINTS $ENV{CUDNN_PATH} ${CUDNN_PATH} ${CUDAToolkit_INCLUDE_DIRS}
PATH_SUFFIXES include
)

function(find_cudnn_library NAME)
string(TOUPPER ${NAME} UPPERCASE_NAME)

find_library(
${UPPERCASE_NAME}_LIBRARY ${NAME}
HINTS $ENV{CUDNN_PATH} ${CUDNN_PATH} ${CUDAToolkit_LIBRARY_DIR}
PATH_SUFFIXES lib64 lib/x64 lib
)

if(${UPPERCASE_NAME}_LIBRARY)
add_library(CUDNN::${NAME} UNKNOWN IMPORTED)
set_target_properties(
CUDNN::${NAME} PROPERTIES
INTERFACE_INCLUDE_DIRECTORIES ${CUDNN_INCLUDE_DIR}
IMPORTED_LOCATION ${${UPPERCASE_NAME}_LIBRARY}
)
message(STATUS "${NAME} found at ${${UPPERCASE_NAME}_LIBRARY}.")
else()
message(STATUS "${NAME} not found.")
endif()


endfunction()

find_cudnn_library(cudnn)
find_cudnn_library(cudnn_adv_infer)
find_cudnn_library(cudnn_adv_train)
find_cudnn_library(cudnn_cnn_infer)
find_cudnn_library(cudnn_cnn_train)
find_cudnn_library(cudnn_ops_infer)
find_cudnn_library(cudnn_ops_train)

include (FindPackageHandleStandardArgs)
find_package_handle_standard_args(
CUDNN REQUIRED_VARS
CUDNN_INCLUDE_DIR CUDNN_LIBRARY
)

if(CUDNN_INCLUDE_DIR AND CUDNN_LIBRARY)

message(STATUS "cuDNN: ${CUDNN_LIBRARY}")
message(STATUS "cuDNN: ${CUDNN_INCLUDE_DIR}")

set(CUDNN_FOUND ON CACHE INTERNAL "cuDNN Library Found")

else()

set(CUDNN_FOUND OFF CACHE INTERNAL "cuDNN Library Not Found")

endif()

target_include_directories(
CUDNN::cudnn_all
INTERFACE
$<INSTALL_INTERFACE:include>
$<BUILD_INTERFACE:${CUDNN_INCLUDE_DIR}>
)

target_link_libraries(
CUDNN::cudnn_all
INTERFACE
CUDNN::cudnn_adv_train
CUDNN::cudnn_ops_train
CUDNN::cudnn_cnn_train
CUDNN::cudnn_adv_infer
CUDNN::cudnn_cnn_infer
CUDNN::cudnn_ops_infer
CUDNN::cudnn
)
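
This find module defines an imported CUDNN::cudnn target (honoring a CUDNN_PATH environment variable when locating headers and libraries), which the library's CMake build links against below. A minimal C++ sketch, not part of this commit, to confirm that the cuDNN headers and libraries found this way are usable:

    #include <cstdio>
    #include <cudnn.h>

    int main() {
      cudnnHandle_t handle;
      cudnnStatus_t status = cudnnCreate(&handle);
      if (status != CUDNN_STATUS_SUCCESS) {
        std::printf("cudnnCreate failed: %s\n", cudnnGetErrorString(status));
        return 1;
      }
      std::printf("cuDNN handle created (header version %d).\n", CUDNN_VERSION);
      cudnnDestroy(handle);
      return 0;
    }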

7 changes: 6 additions & 1 deletion transformer_engine/common/CMakeLists.txt
@@ -12,6 +12,9 @@ list(APPEND transformer_engine_SOURCES
transpose/transpose_fusion.cu
transpose/multi_cast_transpose.cu
activation/gelu.cu
fused_attn/fused_attn_fp8.cu
fused_attn/fused_attn.cpp
fused_attn/utils.cu
gemm/cublaslt_gemm.cu
layer_norm/ln_api.cpp
layer_norm/ln_bwd_semi_cuda_kernel.cu
@@ -30,9 +33,11 @@ target_include_directories(transformer_engine PUBLIC
target_link_libraries(transformer_engine PUBLIC
CUDA::cublas
CUDA::cudart
CUDA::nvToolsExt)
CUDA::nvToolsExt
CUDNN::cudnn)
target_include_directories(transformer_engine PRIVATE
${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
target_include_directories(transformer_engine PRIVATE "${CMAKE_SOURCE_DIR}/../3rdparty/cudnn-frontend/include")

# Compiler options
set_source_files_properties(fused_softmax/scaled_masked_softmax.cu
(Diffs for the remaining changed files are not rendered in this view.)