From 486f5fb788592b7ff2d53f09f0ced957077a02b6 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Wed, 11 Dec 2024 09:33:25 -0800 Subject: [PATCH 1/4] fix ctx.aval_out indexing for workspace Signed-off-by: Phuong Nguyen --- .../jax/cpp_extensions/normalization.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/normalization.py b/transformer_engine/jax/cpp_extensions/normalization.py index 0b7df0b5a8..69d7962b62 100644 --- a/transformer_engine/jax/cpp_extensions/normalization.py +++ b/transformer_engine/jax/cpp_extensions/normalization.py @@ -147,7 +147,7 @@ def lowering(ctx, x, gamma, beta, *, zero_centered_gamma, epsilon): batch_shape = out_shape[:-1] batch_size = reduce(operator.mul, x_shape) // hidden_size - wkspace_aval = ctx.avals_out[-2:] + wkspace_aval = ctx.avals_out[-1] out_types = [ ir.RankedTensorType.get(out_shape, output_type), @@ -441,7 +441,7 @@ def lowering(ctx, dz, x, mu, rsigma, gamma, *, zero_centered_gamma, epsilon): sm_margin = get_backward_sm_margin() - wkspace_aval = ctx.avals_out[-4:] + wkspace_aval = ctx.avals_out[-1] opaque = transformer_engine_jax.pack_norm_descriptor( batch_size, hidden_size, @@ -650,7 +650,7 @@ def lowering(ctx, x, gamma, *, epsilon): batch_shape = out_shape[:-1] batch_size = reduce(operator.mul, x_shape) // hidden_size - wkspace_aval = ctx.avals_out[-2:] + wkspace_aval = ctx.avals_out[-1] out_types = [ ir.RankedTensorType.get(out_shape, x_type.element_type), @@ -841,7 +841,7 @@ def lowering(ctx, dz, x, rsigma, gamma, *, epsilon): hidden_size = reduce(operator.mul, g_shape) batch_size = reduce(operator.mul, x_shape) // hidden_size - wkspace_aval = ctx.avals_out[-3:] + wkspace_aval = ctx.avals_out[-1] out_types = [ ir.RankedTensorType.get(x_shape, x_type.element_type), @@ -1088,7 +1088,7 @@ def lowering( batch_shape = out_shape[:-1] batch_size = reduce(operator.mul, x_shape) // hidden_size - wkspace_aval = ctx.avals_out[-2:] + wkspace_aval = ctx.avals_out[-1] out_types = [ ir.RankedTensorType.get(out_shape, ir_out_dtype), @@ -1394,7 +1394,7 @@ def lowering(ctx, x, gamma, amax, scale, scale_inv, *, out_dtype, epsilon): batch_shape = out_shape[:-1] batch_size = reduce(operator.mul, x_shape) // hidden_size - wkspace_aval = ctx.avals_out[-2:] + wkspace_aval = ctx.avals_out[-1] out_types = [ ir.RankedTensorType.get(out_shape, ir_out_dtype), From a149fb08f97282d8b6acc24ff700dd151230c03b Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Wed, 11 Dec 2024 09:34:09 -0800 Subject: [PATCH 2/4] added cudnn init to prepare phase of norm custom calls Signed-off-by: Phuong Nguyen --- .../jax/csrc/extensions/pybind.cpp | 30 +++++++++++++++---- 1 file changed, 24 insertions(+), 6 deletions(-) diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp index 9b5c156e5d..0188a23059 100644 --- a/transformer_engine/jax/csrc/extensions/pybind.cpp +++ b/transformer_engine/jax/csrc/extensions/pybind.cpp @@ -83,12 +83,30 @@ pybind11::dict Registrations() { EncapsulateFunction(ScaledUpperTriangMaskedSoftmaxBackwardHandler); // Normalization - dict["te_layernorm_forward_ffi"] = EncapsulateFFI(LayerNormForwardHandler); - dict["te_layernorm_forward_fp8_ffi"] = EncapsulateFFI(LayerNormForwardFP8Handler); - dict["te_layernorm_backward_ffi"] = EncapsulateFFI(LayerNormBackwardHandler); - dict["te_rmsnorm_forward_ffi"] = EncapsulateFunction(RMSNormForwardHandler); - dict["te_rmsnorm_forward_fp8_ffi"] = EncapsulateFunction(RMSNormForwardFP8Handler); - dict["te_rmsnorm_backward_ffi"] = EncapsulateFunction(RMSNormBackwardHandler); + dict["te_layernorm_forward_ffi"] = pybind11::dict( + pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler), + pybind11::arg("execute") = EncapsulateFFI(LayerNormForwardHandler) + ); + dict["te_layernorm_forward_fp8_ffi"] = pybind11::dict( + pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler), + pybind11::arg("execute") = EncapsulateFFI(LayerNormForwardFP8Handler) + ); + dict["te_layernorm_backward_ffi"] = pybind11::dict( + pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler), + pybind11::arg("execute") = EncapsulateFFI(LayerNormBackwardHandler) + ); + dict["te_rmsnorm_forward_ffi"] = pybind11::dict( + pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler), + pybind11::arg("execute") = EncapsulateFFI(RMSNormForwardHandler) + ); + dict["te_rmsnorm_forward_fp8_ffi"] = pybind11::dict( + pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler), + pybind11::arg("execute") = EncapsulateFFI(RMSNormForwardFP8Handler) + ); + dict["te_rmsnorm_backward_ffi"] = pybind11::dict( + pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler), + pybind11::arg("execute") = EncapsulateFFI(RMSNormBackwardHandler) + ); // Attention pybind11::dict fused_attn_forward_ffi; From 26ba7be0fa00587f6eaea78d64e8ac73c08e8177 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Wed, 11 Dec 2024 14:49:06 -0800 Subject: [PATCH 3/4] add thread_local for norm registry instance Signed-off-by: Phuong Nguyen --- transformer_engine/common/normalization/common.h | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/transformer_engine/common/normalization/common.h b/transformer_engine/common/normalization/common.h index 8a8df63ba4..d1d56d5cc9 100644 --- a/transformer_engine/common/normalization/common.h +++ b/transformer_engine/common/normalization/common.h @@ -287,9 +287,8 @@ class CudnnNormalizationPlan : public NormalizationPlanBase { class NormalizationPlanRegistry { public: - // TODO thread-safe static NormalizationPlanRegistry& getInstance() { - static NormalizationPlanRegistry instance; + static thread_local NormalizationPlanRegistry instance; return instance; } From 1b961b04b558e93be432fcbf3ef124a4456b4e71 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Wed, 11 Dec 2024 14:50:26 -0800 Subject: [PATCH 4/4] formatted Signed-off-by: Phuong Nguyen --- .../jax/csrc/extensions/pybind.cpp | 42 ++++++++----------- 1 file changed, 18 insertions(+), 24 deletions(-) diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp index 0188a23059..a319b74d76 100644 --- a/transformer_engine/jax/csrc/extensions/pybind.cpp +++ b/transformer_engine/jax/csrc/extensions/pybind.cpp @@ -83,30 +83,24 @@ pybind11::dict Registrations() { EncapsulateFunction(ScaledUpperTriangMaskedSoftmaxBackwardHandler); // Normalization - dict["te_layernorm_forward_ffi"] = pybind11::dict( - pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler), - pybind11::arg("execute") = EncapsulateFFI(LayerNormForwardHandler) - ); - dict["te_layernorm_forward_fp8_ffi"] = pybind11::dict( - pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler), - pybind11::arg("execute") = EncapsulateFFI(LayerNormForwardFP8Handler) - ); - dict["te_layernorm_backward_ffi"] = pybind11::dict( - pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler), - pybind11::arg("execute") = EncapsulateFFI(LayerNormBackwardHandler) - ); - dict["te_rmsnorm_forward_ffi"] = pybind11::dict( - pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler), - pybind11::arg("execute") = EncapsulateFFI(RMSNormForwardHandler) - ); - dict["te_rmsnorm_forward_fp8_ffi"] = pybind11::dict( - pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler), - pybind11::arg("execute") = EncapsulateFFI(RMSNormForwardFP8Handler) - ); - dict["te_rmsnorm_backward_ffi"] = pybind11::dict( - pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler), - pybind11::arg("execute") = EncapsulateFFI(RMSNormBackwardHandler) - ); + dict["te_layernorm_forward_ffi"] = + pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler), + pybind11::arg("execute") = EncapsulateFFI(LayerNormForwardHandler)); + dict["te_layernorm_forward_fp8_ffi"] = + pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler), + pybind11::arg("execute") = EncapsulateFFI(LayerNormForwardFP8Handler)); + dict["te_layernorm_backward_ffi"] = + pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler), + pybind11::arg("execute") = EncapsulateFFI(LayerNormBackwardHandler)); + dict["te_rmsnorm_forward_ffi"] = + pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler), + pybind11::arg("execute") = EncapsulateFFI(RMSNormForwardHandler)); + dict["te_rmsnorm_forward_fp8_ffi"] = + pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler), + pybind11::arg("execute") = EncapsulateFFI(RMSNormForwardFP8Handler)); + dict["te_rmsnorm_backward_ffi"] = + pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler), + pybind11::arg("execute") = EncapsulateFFI(RMSNormBackwardHandler)); // Attention pybind11::dict fused_attn_forward_ffi;