-
Notifications
You must be signed in to change notification settings - Fork 363
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Add FP8 fused attention to TE for PyTorch
  Signed-off-by: Charlene Yang <[email protected]>
* add license for cudnn-frontend, modify installation requirements, and refactor some headers for aesthetics
  Signed-off-by: Charlene Yang <[email protected]>
* add c api docs for fused attention
  Signed-off-by: Charlene Yang <[email protected]>
* add exception for unsupported precision/sequence length combinations
  Signed-off-by: Charlene Yang <[email protected]>
* fix installation requirement for non fused attn use cases
  Signed-off-by: Charlene Yang <[email protected]>
* fix docs for fused-attn
  Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
* prefix enums with NVTE_ and replace old MHA_Matrix with NVTE_QKV_Matrix
  Signed-off-by: Charlene Yang <[email protected]>
* minor fixes based on PR comments
  Signed-off-by: Charlene Yang <[email protected]>
* fix description for kvpacked fwd
  Signed-off-by: Charlene Yang <[email protected]>
* fix description of Bias in C api
  Signed-off-by: Charlene Yang <[email protected]>
* minor fixes for cudnn requirement and description for QKV tensors
  Signed-off-by: Charlene Yang <[email protected]>
* fix QKV layout description and support matrix for C api
  Signed-off-by: Charlene Yang <[email protected]>
* add asserts to cpp_extensions for qkv layout/bias type/attn mask type
  Signed-off-by: Charlene Yang <[email protected]>
* fix typo precision
  Signed-off-by: Charlene Yang <[email protected]>

---------

Signed-off-by: Charlene Yang <[email protected]>
Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
Co-authored-by: Charlene Yang <[email protected]>
Co-authored-by: Kirthi Shankar Sivamani <[email protected]>
- Loading branch information
1 parent
c340730
commit 989a53a
Showing
29 changed files
with
4,720 additions
and
25 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,6 @@ | ||
[submodule "3rdparty/googletest"] | ||
path = 3rdparty/googletest | ||
url = https://github.com/google/googletest.git | ||
[submodule "3rdparty/cudnn-frontend"] | ||
path = 3rdparty/cudnn-frontend | ||
url = https://github.com/NVIDIA/cudnn-frontend.git |
Submodule cudnn-frontend
added at
e7f643
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
.. | ||
Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
See LICENSE for license information. | ||
|
||
fused_attn.h | ||
============ | ||
|
||
.. doxygenfile:: fused_attn.h |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
# Umbrella target that aggregates the cuDNN component libraries.
# The guard keeps this module safe to process more than once in the same
# directory scope: creating an imported target whose name already exists is a
# hard CMake error.
if(NOT TARGET CUDNN::cudnn_all)
  add_library(CUDNN::cudnn_all INTERFACE IMPORTED)
endif()

# Locate the directory that contains cudnn.h. The CUDNN_PATH environment
# variable, the CUDNN_PATH CMake variable and the CUDA toolkit include
# directories are offered as hints, each also probed with an "include" suffix.
find_path(
  CUDNN_INCLUDE_DIR cudnn.h
  HINTS $ENV{CUDNN_PATH} ${CUDNN_PATH} ${CUDAToolkit_INCLUDE_DIRS}
  PATH_SUFFIXES include
)
|
||
# find_cudnn_library(<NAME>)
#
# Searches for the library <NAME> and, when it is found, defines the imported
# target CUDNN::<NAME> pointing at it.
#
# Arguments:
#   NAME - library base name to search for (e.g. cudnn, cudnn_ops_infer).
#
# Results:
#   <UPPERCASE NAME>_LIBRARY - find_library() result variable holding the
#                              library path, or a *-NOTFOUND value.
#   CUDNN::<NAME>            - imported target, created only on success.
function(find_cudnn_library NAME)
  string(TOUPPER ${NAME} UPPERCASE_NAME)

  find_library(
    ${UPPERCASE_NAME}_LIBRARY ${NAME}
    HINTS $ENV{CUDNN_PATH} ${CUDNN_PATH} ${CUDAToolkit_LIBRARY_DIR}
    PATH_SUFFIXES lib64 lib/x64 lib
  )

  if(${UPPERCASE_NAME}_LIBRARY)
    # An imported target cannot be created twice in the same directory scope,
    # so skip creation when an earlier run of this module already defined it.
    if(NOT TARGET CUDNN::${NAME})
      add_library(CUDNN::${NAME} UNKNOWN IMPORTED)
      # Values are quoted so that an empty or list-valued variable cannot
      # shift the PROPERTIES key/value pairing.
      set_target_properties(
        CUDNN::${NAME} PROPERTIES
        INTERFACE_INCLUDE_DIRECTORIES "${CUDNN_INCLUDE_DIR}"
        IMPORTED_LOCATION "${${UPPERCASE_NAME}_LIBRARY}"
      )
    endif()
    message(STATUS "${NAME} found at ${${UPPERCASE_NAME}_LIBRARY}.")
  else()
    message(STATUS "${NAME} not found.")
  endif()
endfunction()
|
||
# Probe the core cuDNN library followed by each of its component libraries;
# every successful lookup yields a CUDNN::<name> imported target.
foreach(cudnn_component IN ITEMS
    cudnn
    cudnn_adv_infer
    cudnn_adv_train
    cudnn_cnn_infer
    cudnn_cnn_train
    cudnn_ops_infer
    cudnn_ops_train)
  find_cudnn_library(${cudnn_component})
endforeach()
|
||
# Hand the search results to the standard helper, which checks that both the
# header directory and the core library were located.
include(FindPackageHandleStandardArgs)
find_package_handle_standard_args(
  CUDNN
  REQUIRED_VARS CUDNN_INCLUDE_DIR CUDNN_LIBRARY
)

# Record the outcome in an internal cache entry and, on success, report where
# the library and headers were found.
if(NOT CUDNN_INCLUDE_DIR OR NOT CUDNN_LIBRARY)
  set(CUDNN_FOUND OFF CACHE INTERNAL "cuDNN Library Not Found")
else()
  message(STATUS "cuDNN: ${CUDNN_LIBRARY}")
  message(STATUS "cuDNN: ${CUDNN_INCLUDE_DIR}")
  set(CUDNN_FOUND ON CACHE INTERNAL "cuDNN Library Found")
endif()
|
||
# Expose the cuDNN headers to every consumer of CUDNN::cudnn_all.
target_include_directories(
  CUDNN::cudnn_all
  INTERFACE
    $<INSTALL_INTERFACE:include>
    $<BUILD_INTERFACE:${CUDNN_INCLUDE_DIR}>
)

# Attach the component libraries to the umbrella target. A CUDNN::<name>
# target exists only when the corresponding library was found, so each one is
# checked first: naming a non-existent "::" target in a link interface makes
# CMake fail at generate time for any consumer of CUDNN::cudnn_all.
# The original link order is preserved.
foreach(cudnn_component IN ITEMS
    cudnn_adv_train
    cudnn_ops_train
    cudnn_cnn_train
    cudnn_adv_infer
    cudnn_cnn_infer
    cudnn_ops_infer
    cudnn)
  if(TARGET CUDNN::${cudnn_component})
    target_link_libraries(
      CUDNN::cudnn_all
      INTERFACE
        CUDNN::${cudnn_component}
    )
  endif()
endforeach()
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.