From 1b961b04b558e93be432fcbf3ef124a4456b4e71 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Wed, 11 Dec 2024 14:50:26 -0800 Subject: [PATCH] formatted Signed-off-by: Phuong Nguyen --- .../jax/csrc/extensions/pybind.cpp | 42 ++++++++----------- 1 file changed, 18 insertions(+), 24 deletions(-) diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp index 0188a23059..a319b74d76 100644 --- a/transformer_engine/jax/csrc/extensions/pybind.cpp +++ b/transformer_engine/jax/csrc/extensions/pybind.cpp @@ -83,30 +83,24 @@ pybind11::dict Registrations() { EncapsulateFunction(ScaledUpperTriangMaskedSoftmaxBackwardHandler); // Normalization - dict["te_layernorm_forward_ffi"] = pybind11::dict( - pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler), - pybind11::arg("execute") = EncapsulateFFI(LayerNormForwardHandler) - ); - dict["te_layernorm_forward_fp8_ffi"] = pybind11::dict( - pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler), - pybind11::arg("execute") = EncapsulateFFI(LayerNormForwardFP8Handler) - ); - dict["te_layernorm_backward_ffi"] = pybind11::dict( - pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler), - pybind11::arg("execute") = EncapsulateFFI(LayerNormBackwardHandler) - ); - dict["te_rmsnorm_forward_ffi"] = pybind11::dict( - pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler), - pybind11::arg("execute") = EncapsulateFFI(RMSNormForwardHandler) - ); - dict["te_rmsnorm_forward_fp8_ffi"] = pybind11::dict( - pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler), - pybind11::arg("execute") = EncapsulateFFI(RMSNormForwardFP8Handler) - ); - dict["te_rmsnorm_backward_ffi"] = pybind11::dict( - pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler), - pybind11::arg("execute") = EncapsulateFFI(RMSNormBackwardHandler) - ); + dict["te_layernorm_forward_ffi"] = + pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler), + pybind11::arg("execute") = EncapsulateFFI(LayerNormForwardHandler)); + dict["te_layernorm_forward_fp8_ffi"] = + pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler), + pybind11::arg("execute") = EncapsulateFFI(LayerNormForwardFP8Handler)); + dict["te_layernorm_backward_ffi"] = + pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler), + pybind11::arg("execute") = EncapsulateFFI(LayerNormBackwardHandler)); + dict["te_rmsnorm_forward_ffi"] = + pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler), + pybind11::arg("execute") = EncapsulateFFI(RMSNormForwardHandler)); + dict["te_rmsnorm_forward_fp8_ffi"] = + pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler), + pybind11::arg("execute") = EncapsulateFFI(RMSNormForwardFP8Handler)); + dict["te_rmsnorm_backward_ffi"] = + pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler), + pybind11::arg("execute") = EncapsulateFFI(RMSNormBackwardHandler)); // Attention pybind11::dict fused_attn_forward_ffi;