From c54b7ccc4fd25499399618d9da564e1ed721a34b Mon Sep 17 00:00:00 2001
From: Phuong Nguyen
Date: Fri, 13 Dec 2024 07:10:00 -0800
Subject: [PATCH 1/3] softmax custom calls with correct encapsulates

Signed-off-by: Phuong Nguyen
---
 .../jax/csrc/extensions/pybind.cpp | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp
index a319b74d76..79fb14eaf0 100644
--- a/transformer_engine/jax/csrc/extensions/pybind.cpp
+++ b/transformer_engine/jax/csrc/extensions/pybind.cpp
@@ -62,25 +62,25 @@ pybind11::dict Registrations() {
   dict["te_act_lu_fp8_ffi"] = EncapsulateFFI(ActLuFP8Handler);
   dict["te_dact_lu_ffi"] = EncapsulateFFI(DActLuHandler);
   dict["te_dact_lu_dbias_cast_transpose_ffi"] =
-      EncapsulateFunction(DActLuDBiasCastTransposeHandler);
+      EncapsulateFFI(DActLuDBiasCastTransposeHandler);
   dict["te_dgated_act_lu_cast_transpose_ffi"] =
-      EncapsulateFunction(DGatedActLuCastTransposeHandler);
+      EncapsulateFFI(DGatedActLuCastTransposeHandler);
 
   // Quantization
   dict["te_quantize_ffi"] = EncapsulateFFI(QuantizeHandler);
   dict["te_dequantize_ffi"] = EncapsulateFFI(DequantizeHandler);
 
   // Softmax
-  dict["te_scaled_softmax_forward_ffi"] = EncapsulateFunction(ScaledSoftmaxForwardHandler);
-  dict["te_scaled_softmax_backward_ffi"] = EncapsulateFunction(ScaledSoftmaxBackwardHandler);
+  dict["te_scaled_softmax_forward_ffi"] = EncapsulateFFI(ScaledSoftmaxForwardHandler);
+  dict["te_scaled_softmax_backward_ffi"] = EncapsulateFFI(ScaledSoftmaxBackwardHandler);
   dict["te_scaled_masked_softmax_forward_ffi"] =
-      EncapsulateFunction(ScaledMaskedSoftmaxForwardHandler);
+      EncapsulateFFI(ScaledMaskedSoftmaxForwardHandler);
   dict["te_scaled_masked_softmax_backward_ffi"] =
-      EncapsulateFunction(ScaledMaskedSoftmaxBackwardHandler);
+      EncapsulateFFI(ScaledMaskedSoftmaxBackwardHandler);
   dict["te_scaled_upper_triang_masked_softmax_forward_ffi"] =
-      EncapsulateFunction(ScaledUpperTriangMaskedSoftmaxForwardHandler);
+      EncapsulateFFI(ScaledUpperTriangMaskedSoftmaxForwardHandler);
   dict["te_scaled_upper_triang_masked_softmax_backward_ffi"] =
-      EncapsulateFunction(ScaledUpperTriangMaskedSoftmaxBackwardHandler);
+      EncapsulateFFI(ScaledUpperTriangMaskedSoftmaxBackwardHandler);
 
   // Normalization
   dict["te_layernorm_forward_ffi"] =

From 23ff830c6273c8dae96809c45b52110c4f95bba5 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Fri, 13 Dec 2024 15:12:52 +0000
Subject: [PATCH 2/3] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 transformer_engine/jax/csrc/extensions/pybind.cpp | 9 +++------
 1 file changed, 3 insertions(+), 6 deletions(-)

diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp
index 79fb14eaf0..a986b91b30 100644
--- a/transformer_engine/jax/csrc/extensions/pybind.cpp
+++ b/transformer_engine/jax/csrc/extensions/pybind.cpp
@@ -61,10 +61,8 @@ pybind11::dict Registrations() {
   dict["te_act_lu_ffi"] = EncapsulateFFI(ActLuHandler);
   dict["te_act_lu_fp8_ffi"] = EncapsulateFFI(ActLuFP8Handler);
   dict["te_dact_lu_ffi"] = EncapsulateFFI(DActLuHandler);
-  dict["te_dact_lu_dbias_cast_transpose_ffi"] =
-      EncapsulateFFI(DActLuDBiasCastTransposeHandler);
-  dict["te_dgated_act_lu_cast_transpose_ffi"] =
-      EncapsulateFFI(DGatedActLuCastTransposeHandler);
+  dict["te_dact_lu_dbias_cast_transpose_ffi"] = EncapsulateFFI(DActLuDBiasCastTransposeHandler);
+  dict["te_dgated_act_lu_cast_transpose_ffi"] = EncapsulateFFI(DGatedActLuCastTransposeHandler);
 
   // Quantization
   dict["te_quantize_ffi"] = EncapsulateFFI(QuantizeHandler);
@@ -73,8 +71,7 @@ pybind11::dict Registrations() {
   // Softmax
   dict["te_scaled_softmax_forward_ffi"] = EncapsulateFFI(ScaledSoftmaxForwardHandler);
   dict["te_scaled_softmax_backward_ffi"] = EncapsulateFFI(ScaledSoftmaxBackwardHandler);
-  dict["te_scaled_masked_softmax_forward_ffi"] =
-      EncapsulateFFI(ScaledMaskedSoftmaxForwardHandler);
+  dict["te_scaled_masked_softmax_forward_ffi"] = EncapsulateFFI(ScaledMaskedSoftmaxForwardHandler);
   dict["te_scaled_masked_softmax_backward_ffi"] =
       EncapsulateFFI(ScaledMaskedSoftmaxBackwardHandler);
   dict["te_scaled_upper_triang_masked_softmax_forward_ffi"] =

From 76cd159ce77dd0f0a85685fa27b53979623fe5c7 Mon Sep 17 00:00:00 2001
From: Phuong Nguyen
Date: Fri, 13 Dec 2024 07:54:03 -0800
Subject: [PATCH 3/3] rm jax deprecated features

Signed-off-by: Phuong Nguyen
---
 .../jax/cpp_extensions/activation.py             |  6 +++---
 transformer_engine/jax/cpp_extensions/base.py    |  2 +-
 .../jax/cpp_extensions/normalization.py          | 14 +++++++-------
 transformer_engine/jax/cpp_extensions/softmax.py |  8 ++++----
 4 files changed, 15 insertions(+), 15 deletions(-)

diff --git a/transformer_engine/jax/cpp_extensions/activation.py b/transformer_engine/jax/cpp_extensions/activation.py
index 44b396ad55..7f09e6f900 100644
--- a/transformer_engine/jax/cpp_extensions/activation.py
+++ b/transformer_engine/jax/cpp_extensions/activation.py
@@ -8,7 +8,7 @@
 
 import jax
 import jax.numpy as jnp
-from jax import core, dtypes
+from jax import dtypes
 from jax.interpreters.mlir import ir
 from jax.sharding import PartitionSpec, NamedSharding
 from jax.extend import ffi
@@ -98,7 +98,7 @@ def abstract(x_aval, *, act_enum):  # pylint: disable=unused-argument
         assert x_shape[-2] == 2 or x_shape[-2] == 1
         hidden_size = x_shape[-1]
         batch_shapes = x_shape[:-2]
-        out_aval = core.raise_to_shaped(x_aval)
+        out_aval = x_aval
         out_shape = (batch_shapes) + (hidden_size,)
         out_aval = out_aval.update(shape=out_shape, dtype=dtype)
 
@@ -225,7 +225,7 @@ def abstract(dz_aval, x_aval, *, act_enum):  # pylint: disable=unused-argument
         i_hidden_size = dz_aval.shape[-1]
         g_hidden_size = x_aval.shape[-1]
         assert i_hidden_size == g_hidden_size
-        out_aval = core.raise_to_shaped(x_aval)
+        out_aval = x_aval
 
         return out_aval
 
diff --git a/transformer_engine/jax/cpp_extensions/base.py b/transformer_engine/jax/cpp_extensions/base.py
index 3d88c1f078..3715e6f20c 100644
--- a/transformer_engine/jax/cpp_extensions/base.py
+++ b/transformer_engine/jax/cpp_extensions/base.py
@@ -7,7 +7,7 @@
 from abc import ABCMeta, abstractmethod
 from functools import partial
 
-from jax import core
+from jax.extend import core
 from jax.interpreters import xla, mlir
 from jax.experimental.custom_partitioning import custom_partitioning
 from jax._src.interpreters import batching
diff --git a/transformer_engine/jax/cpp_extensions/normalization.py b/transformer_engine/jax/cpp_extensions/normalization.py
index 69d7962b62..8ad7ee4fcb 100644
--- a/transformer_engine/jax/cpp_extensions/normalization.py
+++ b/transformer_engine/jax/cpp_extensions/normalization.py
@@ -9,7 +9,7 @@
 
 import jax
 import jax.numpy as jnp
-from jax import core, dtypes
+from jax import dtypes
 from jax.interpreters import mlir
 from jax.interpreters.mlir import ir
 from jax.sharding import PartitionSpec, NamedSharding
@@ -74,7 +74,7 @@ def abstract(x_aval, gamma_aval, beta_aval, **kwargs):
 
         mu_rsigama_dtype = jnp.float32
 
-        out_aval = core.raise_to_shaped(x_aval)
+        out_aval = x_aval
         mu_aval = rsigma_aval = out_aval.update(shape=out_aval.shape[:-1], dtype=mu_rsigama_dtype)
 
         assert gamma_aval.size == beta_aval.size
@@ -361,8 +361,8 @@ def abstract(dz_aval, x_aval, mu_aval, rsigma_aval, gamma_aval, **kwargs):
         assert mu_aval.shape == rsigma_aval.shape == x_aval.shape[:-1]
         assert mu_dtype == rsigma_dtype == jnp.float32
 
-        dx_aval = core.raise_to_shaped(dz_aval)
-        dgamma_aval = dbeta_aval = core.raise_to_shaped(gamma_aval)
+        dx_aval = dz_aval
+        dgamma_aval = dbeta_aval = gamma_aval
 
         (wkspace_info,) = transformer_engine_jax.get_layernorm_bwd_workspace_sizes(
             x_aval.size // gamma_aval.size,  # batch size
@@ -589,7 +589,7 @@ def abstract(x_aval, gamma_aval, **kwargs):
 
         rsigama_dtype = jnp.float32
 
-        out_aval = core.raise_to_shaped(x_aval)
+        out_aval = x_aval
         rsigma_aval = out_aval.update(shape=out_aval.shape[:-1], dtype=rsigama_dtype)
 
         hidden_size = gamma_aval.size
@@ -783,8 +783,8 @@ def abstract(dz_aval, x_aval, rsigma_aval, gamma_aval, **kwargs):
         assert rsigma_aval.shape == x_aval.shape[:-1]
         assert rsigma_dtype == jnp.float32
 
-        dx_aval = core.raise_to_shaped(dz_aval)
-        dgamma_aval = core.raise_to_shaped(gamma_aval)
+        dx_aval = dz_aval
+        dgamma_aval = gamma_aval
 
         (wkspace_info,) = transformer_engine_jax.get_layernorm_bwd_workspace_sizes(
             x_aval.size // gamma_aval.size,  # batch size
diff --git a/transformer_engine/jax/cpp_extensions/softmax.py b/transformer_engine/jax/cpp_extensions/softmax.py
index a12943f4c2..67053ecd8e 100644
--- a/transformer_engine/jax/cpp_extensions/softmax.py
+++ b/transformer_engine/jax/cpp_extensions/softmax.py
@@ -9,7 +9,7 @@
 
 import jax
 import jax.numpy as jnp
-from jax import core, dtypes
+from jax import dtypes
 from jax.interpreters.mlir import ir
 from jax.sharding import PartitionSpec, NamedSharding
 from jax.extend import ffi
@@ -126,7 +126,7 @@ def forward_abstract(logits_aval, scale_factor):
         assert k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported
         assert q_seqlen > 1
 
-        out_aval = core.raise_to_shaped(logits_aval)
+        out_aval = logits_aval
         return out_aval
 
     @staticmethod
@@ -237,7 +237,7 @@ def backward_abstract(
 
         assert dz_aval.shape == softmax_out_aval.shape
 
-        dx_aval = core.raise_to_shaped(dz_aval)
+        dx_aval = dz_aval
         return dx_aval
 
     @staticmethod
@@ -578,7 +578,7 @@ def abstract(logits_aval, mask_aval, scale_factor):  # pylint: disable=unused-argument
         assert mask_shape[-2] == q_seqlen
         assert mask_shape[-1] == k_seqlen
 
-        out_aval = core.raise_to_shaped(logits_aval)
+        out_aval = logits_aval
         return out_aval
 
     @staticmethod
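
Note on the change in PATCH 3/3: newer JAX releases removed jax.core.raise_to_shaped (it had become a no-op for shaped abstract values) and moved the public core module to jax.extend.core, so abstract-eval rules now use the incoming aval directly and base.py imports core from jax.extend. The snippet below is a minimal, self-contained sketch of that migration pattern only; the toy primitive and the names toy_identity_p and _toy_identity_abstract are hypothetical and are not Transformer Engine code.

    # Sketch of the deprecation cleanup: hypothetical toy primitive, not TE code.
    import jax
    import jax.numpy as jnp
    from jax.extend import core  # was: from jax import core

    toy_identity_p = core.Primitive("toy_identity")
    toy_identity_p.def_impl(lambda x: x)  # concrete evaluation: pass input through


    @toy_identity_p.def_abstract_eval
    def _toy_identity_abstract(x_aval):
        # Old style (removed in recent JAX):
        #     out_aval = jax.core.raise_to_shaped(x_aval)
        # New style: the operand aval is already shaped, so use it (or .update it).
        return x_aval.update(shape=x_aval.shape, dtype=x_aval.dtype)


    x = jnp.ones((2, 3), jnp.float32)
    print(toy_identity_p.bind(x).shape)                  # concrete path -> (2, 3)
    print(jax.eval_shape(toy_identity_p.bind, x).shape)  # abstract path -> (2, 3)

The abstract rule runs under tracing (e.g. jax.eval_shape) without any lowering registered, which is enough to show why the raise_to_shaped call could simply be dropped.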