Skip to content

Commit

Permalink
formatted
Browse files Browse the repository at this point in the history
Signed-off-by: Phuong Nguyen <[email protected]>
  • Loading branch information
phu0ngng committed Dec 11, 2024
1 parent 26ba7be commit 1b961b0
Showing 1 changed file with 18 additions and 24 deletions.
42 changes: 18 additions & 24 deletions transformer_engine/jax/csrc/extensions/pybind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,30 +83,24 @@ pybind11::dict Registrations() {
EncapsulateFunction(ScaledUpperTriangMaskedSoftmaxBackwardHandler);

// Normalization
dict["te_layernorm_forward_ffi"] = pybind11::dict(
pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler),
pybind11::arg("execute") = EncapsulateFFI(LayerNormForwardHandler)
);
dict["te_layernorm_forward_fp8_ffi"] = pybind11::dict(
pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler),
pybind11::arg("execute") = EncapsulateFFI(LayerNormForwardFP8Handler)
);
dict["te_layernorm_backward_ffi"] = pybind11::dict(
pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler),
pybind11::arg("execute") = EncapsulateFFI(LayerNormBackwardHandler)
);
dict["te_rmsnorm_forward_ffi"] = pybind11::dict(
pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler),
pybind11::arg("execute") = EncapsulateFFI(RMSNormForwardHandler)
);
dict["te_rmsnorm_forward_fp8_ffi"] = pybind11::dict(
pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler),
pybind11::arg("execute") = EncapsulateFFI(RMSNormForwardFP8Handler)
);
dict["te_rmsnorm_backward_ffi"] = pybind11::dict(
pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler),
pybind11::arg("execute") = EncapsulateFFI(RMSNormBackwardHandler)
);
dict["te_layernorm_forward_ffi"] =
pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler),
pybind11::arg("execute") = EncapsulateFFI(LayerNormForwardHandler));
dict["te_layernorm_forward_fp8_ffi"] =
pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler),
pybind11::arg("execute") = EncapsulateFFI(LayerNormForwardFP8Handler));
dict["te_layernorm_backward_ffi"] =
pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler),
pybind11::arg("execute") = EncapsulateFFI(LayerNormBackwardHandler));
dict["te_rmsnorm_forward_ffi"] =
pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler),
pybind11::arg("execute") = EncapsulateFFI(RMSNormForwardHandler));
dict["te_rmsnorm_forward_fp8_ffi"] =
pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler),
pybind11::arg("execute") = EncapsulateFFI(RMSNormForwardFP8Handler));
dict["te_rmsnorm_backward_ffi"] =
pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler),
pybind11::arg("execute") = EncapsulateFFI(RMSNormBackwardHandler));

// Attention
pybind11::dict fused_attn_forward_ffi;
Expand Down

0 comments on commit 1b961b0

Please sign in to comment.