[JAX] Added prepare phase for the FusedAttnForwardFFI (#1313)
* added prepare phase for the FusedAttnForwardFFI

* enabled FusedAttnForwardFFI by default

* moved prepare phase into pybind

---------

Signed-off-by: Phuong Nguyen <[email protected]>
phu0ngng authored Nov 7, 2024
1 parent 4d65073 commit e5ffaa7
Showing 8 changed files with 82 additions and 2 deletions.
1 change: 1 addition & 0 deletions transformer_engine/common/CMakeLists.txt
@@ -47,6 +47,7 @@ include_directories(${PROJECT_SOURCE_DIR}/..)
set(transformer_engine_SOURCES)
list(APPEND transformer_engine_SOURCES
     pycudnn.cpp
+    cudnn_utils.cpp
     transformer_engine.cpp
     common.cu
     transpose/cast_transpose.cu
16 changes: 16 additions & 0 deletions transformer_engine/common/cudnn_utils.cpp
@@ -0,0 +1,16 @@
/*************************************************************************
* Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/

#include "../fused_attn/utils.h"
#include "transformer_engine/cudnn.h"

namespace transformer_engine {

void nvte_cudnn_handle_init() {
  // Grabbing the handle from the singleton manager forces cudnnCreate() to
  // run now, eagerly allocating the handle; the handle itself is unused here.
  auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle();
  (void)handle;
}

} // namespace transformer_engine
29 changes: 29 additions & 0 deletions transformer_engine/common/include/transformer_engine/cudnn.h
@@ -0,0 +1,29 @@
/*************************************************************************
* Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/

/*! \file cudnn.h
* \brief Helper for cuDNN initialization
*/

#ifndef TRANSFORMER_ENGINE_CUDNN_H_
#define TRANSFORMER_ENGINE_CUDNN_H_

#include "transformer_engine.h"

/*! \namespace transformer_engine
*/
namespace transformer_engine {

/*! \brief Eagerly initialize the shared cuDNN handle.
 *
 *  TE/JAX CUDA graph support requires cuDNN initialization to happen outside
 *  of the capturing region, because cudnnCreate() allocates memory for the
 *  handle. This helper triggers that allocation and is called in the prepare
 *  phase of the related XLA custom calls, before any capture begins.
 */
void nvte_cudnn_handle_init();

} // namespace transformer_engine

#endif // TRANSFORMER_ENGINE_CUDNN_H_
2 changes: 1 addition & 1 deletion transformer_engine/jax/cpp_extensions/attention.py
@@ -379,7 +379,7 @@ def lowering(

wkspace_aval = ctx.avals_out[-1]

-if is_ffi_enabled() and bool(os.getenv("NVTE_JAX_FUSED_ATTN_WITH_FFI")):
+if is_ffi_enabled():
name = "te_fused_attn_forward_ffi"
out = ffi.ffi_lowering(name)(
ctx,
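With the NVTE_JAX_FUSED_ATTN_WITH_FFI environment-variable gate removed, the FFI lowering path above is taken whenever is_ffi_enabled() reports true. For context, here is a hedged, self-contained sketch of the general jax.extend.ffi.ffi_lowering pattern that the call above follows; the primitive and target name ("my_op", "my_op_ffi") are illustrative rather than TE code, and exact signatures can vary across JAX versions:

from jax.core import Primitive
from jax.extend import ffi
from jax.interpreters import mlir

# A toy primitive whose output shape/dtype match its input.
my_prim = Primitive("my_op")
my_prim.def_abstract_eval(lambda x: x)

def _my_op_lowering(ctx, x):
    # ffi_lowering builds an MLIR lowering rule that emits a custom call to
    # the FFI target registered under "my_op_ffi"; extra attributes would be
    # forwarded as keyword arguments, as in the TE lowering above.
    return ffi.ffi_lowering("my_op_ffi")(ctx, x)

mlir.register_lowering(my_prim, _my_op_lowering, platform="cuda")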
3 changes: 3 additions & 0 deletions transformer_engine/jax/csrc/extensions.h
@@ -289,6 +289,9 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(

void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);

+// cuDNN helpers
+XLA_FFI_DECLARE_HANDLER_SYMBOL(CudnnHandleInitHandler);

} // namespace jax
} // namespace transformer_engine

24 changes: 24 additions & 0 deletions transformer_engine/jax/csrc/extensions/cudnn.cpp
@@ -0,0 +1,24 @@
/*************************************************************************
* Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/

#include "transformer_engine/cudnn.h"

#include "extensions.h"
#include "xla/ffi/api/c_api.h"

namespace transformer_engine {
namespace jax {

Error_Type CudnnHandleInitFFI(Variadic_Buffer_Type args, Variadic_Result_Type rets,
                              Dictionary attrs) {
  // Only the side effect matters here: force eager cuDNN handle creation.
  nvte_cudnn_handle_init();
  return ffi_with_cuda_error_check();
}

// Bound to the prepare execution stage; the variadic args/results and generic
// attribute dictionary let this handler pair with any execute handler's
// call signature.
XLA_FFI_DEFINE_HANDLER_SYMBOL(CudnnHandleInitHandler, CudnnHandleInitFFI,
                              FFI::Bind<FFI_Prepare>().RemainingArgs().RemainingRets().Attrs());
} // namespace jax
} // namespace transformer_engine
4 changes: 4 additions & 0 deletions transformer_engine/jax/csrc/extensions/ffi.h
@@ -16,10 +16,14 @@ namespace jax {

using Buffer_Type = xla::ffi::AnyBuffer;
using Result_Type = xla::ffi::Result<xla::ffi::AnyBuffer>;
+using Variadic_Buffer_Type = xla::ffi::RemainingArgs;
+using Variadic_Result_Type = xla::ffi::RemainingRets;
using Error_Type = xla::ffi::Error;
using FFI = xla::ffi::Ffi;
using FFI_Stream_Type = xla::ffi::PlatformStream<cudaStream_t>;
+using Dictionary = xla::ffi::Dictionary;

+constexpr auto FFI_Prepare = xla::ffi::ExecutionStage::kPrepare;
constexpr auto FFI_CudaGraph_Traits = {xla::ffi::Traits::kCmdBufferCompatible};

DType convert_ffi_datatype_to_te_dtype(const xla::ffi::DataType& type);
5 changes: 4 additions & 1 deletion transformer_engine/jax/csrc/extensions/pybind.cpp
@@ -78,7 +78,10 @@ pybind11::dict Registrations() {
dict["te_rmsnorm_backward_ffi"] = EncapsulateFunction(RMSNormBackwardHandler);

// Attention
dict["te_fused_attn_forward_ffi"] = EncapsulateFFI(FusedAttnForwardHandler);
pybind11::dict fused_attn_forward_ffi;
fused_attn_forward_ffi["prepare"] = EncapsulateFFI(CudnnHandleInitHandler);
fused_attn_forward_ffi["execute"] = EncapsulateFFI(FusedAttnForwardHandler);
dict["te_fused_attn_forward_ffi"] = fused_attn_forward_ffi;

return dict;
}
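This is the core of the change: instead of a single handler capsule, the registration entry for te_fused_attn_forward_ffi is now a dictionary of per-stage handlers, so XLA runs CudnnHandleInitHandler in the prepare stage (outside any CUDA graph capture) and FusedAttnForwardHandler in the execute stage. As a hedged sketch of how such a dict is typically consumed on the Python side (the module and function names are assumptions, not part of this diff; register_ffi_target is JAX's public API):

import jax.extend as jex
import transformer_engine_jax  # assumed name of the pybind11 extension module

for name, target in transformer_engine_jax.Registrations().items():
    # `target` is either a single PyCapsule (execute-only) or a dict mapping
    # FFI execution stages to capsules, e.g. {"prepare": ..., "execute": ...};
    # XLA's typed-FFI registration accepts both forms.
    jex.ffi.register_ffi_target(name, target, platform="CUDA")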
