Skip to content

Commit

Permalink
added cudnn init to prepare phase of norm custom calls
Browse files Browse the repository at this point in the history
Signed-off-by: Phuong Nguyen <[email protected]>
  • Loading branch information
phu0ngng committed Dec 11, 2024
1 parent 486f5fb commit a149fb0
Showing 1 changed file with 24 additions and 6 deletions.
30 changes: 24 additions & 6 deletions transformer_engine/jax/csrc/extensions/pybind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,12 +83,30 @@ pybind11::dict Registrations() {
EncapsulateFunction(ScaledUpperTriangMaskedSoftmaxBackwardHandler);

// Normalization
dict["te_layernorm_forward_ffi"] = EncapsulateFFI(LayerNormForwardHandler);
dict["te_layernorm_forward_fp8_ffi"] = EncapsulateFFI(LayerNormForwardFP8Handler);
dict["te_layernorm_backward_ffi"] = EncapsulateFFI(LayerNormBackwardHandler);
dict["te_rmsnorm_forward_ffi"] = EncapsulateFunction(RMSNormForwardHandler);
dict["te_rmsnorm_forward_fp8_ffi"] = EncapsulateFunction(RMSNormForwardFP8Handler);
dict["te_rmsnorm_backward_ffi"] = EncapsulateFunction(RMSNormBackwardHandler);
dict["te_layernorm_forward_ffi"] = pybind11::dict(
pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler),
pybind11::arg("execute") = EncapsulateFFI(LayerNormForwardHandler)
);
dict["te_layernorm_forward_fp8_ffi"] = pybind11::dict(
pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler),
pybind11::arg("execute") = EncapsulateFFI(LayerNormForwardFP8Handler)
);
dict["te_layernorm_backward_ffi"] = pybind11::dict(
pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler),
pybind11::arg("execute") = EncapsulateFFI(LayerNormBackwardHandler)
);
dict["te_rmsnorm_forward_ffi"] = pybind11::dict(
pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler),
pybind11::arg("execute") = EncapsulateFFI(RMSNormForwardHandler)
);
dict["te_rmsnorm_forward_fp8_ffi"] = pybind11::dict(
pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler),
pybind11::arg("execute") = EncapsulateFFI(RMSNormForwardFP8Handler)
);
dict["te_rmsnorm_backward_ffi"] = pybind11::dict(
pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler),
pybind11::arg("execute") = EncapsulateFFI(RMSNormBackwardHandler)
);

// Attention
pybind11::dict fused_attn_forward_ffi;
Expand Down

0 comments on commit a149fb0

Please sign in to comment.