[JAX] Bug Fix: Softmax FFIs with correct Encapsulates #1375

Merged 4 commits on Dec 14, 2024

6 changes: 3 additions & 3 deletions transformer_engine/jax/cpp_extensions/activation.py
@@ -8,7 +8,7 @@
 
 import jax
 import jax.numpy as jnp
-from jax import core, dtypes
+from jax import dtypes
 from jax.interpreters.mlir import ir
 from jax.sharding import PartitionSpec, NamedSharding
 from jax.extend import ffi
@@ -98,7 +98,7 @@ def abstract(x_aval, *, act_enum): # pylint: disable=unused-argument
 assert x_shape[-2] == 2 or x_shape[-2] == 1
 hidden_size = x_shape[-1]
 batch_shapes = x_shape[:-2]
-out_aval = core.raise_to_shaped(x_aval)
+out_aval = x_aval
 out_shape = (batch_shapes) + (hidden_size,)
 out_aval = out_aval.update(shape=out_shape, dtype=dtype)
 
@@ -225,7 +225,7 @@ def abstract(dz_aval, x_aval, *, act_enum): # pylint: disable=unused-argument
 i_hidden_size = dz_aval.shape[-1]
 g_hidden_size = x_aval.shape[-1]
 assert i_hidden_size == g_hidden_size
-out_aval = core.raise_to_shaped(x_aval)
+out_aval = x_aval
 
 return out_aval

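Context for the Python-side edits above: newer JAX versions drop jax.core.raise_to_shaped (abstract values are already shaped) and expose Primitive via jax.extend.core, which is what the import changes in this PR track. A minimal sketch of the resulting abstract-eval pattern, using a hypothetical primitive rather than TE's actual ones:

# Sketch of the post-change abstract-eval pattern (hypothetical primitive).
# The incoming aval is already a ShapedArray, so it is used as-is and
# .update(...) derives the output aval; no raise_to_shaped round-trip.
import jax
import jax.numpy as jnp
from jax.extend import core  # supported home for Primitive in newer JAX

sketch_act_lu_p = core.Primitive("sketch_act_lu")

@sketch_act_lu_p.def_abstract_eval
def _sketch_act_lu_abstract(x_aval):
    # Gated activation: drop the size-1/2 gate axis, keep batch dims and dtype.
    assert x_aval.shape[-2] in (1, 2)
    out_shape = x_aval.shape[:-2] + (x_aval.shape[-1],)
    return x_aval.update(shape=out_shape, dtype=x_aval.dtype)

# Shape-level check of the rule; eval_shape needs no impl or lowering.
out = jax.eval_shape(sketch_act_lu_p.bind, jax.ShapeDtypeStruct((4, 2, 128), jnp.bfloat16))
assert out.shape == (4, 128)
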
2 changes: 1 addition & 1 deletion transformer_engine/jax/cpp_extensions/base.py
@@ -7,7 +7,7 @@
 from abc import ABCMeta, abstractmethod
 from functools import partial
 
-from jax import core
+from jax.extend import core
 from jax.interpreters import xla, mlir
 from jax.experimental.custom_partitioning import custom_partitioning
 from jax._src.interpreters import batching
14 changes: 7 additions & 7 deletions transformer_engine/jax/cpp_extensions/normalization.py
@@ -9,7 +9,7 @@
 
 import jax
 import jax.numpy as jnp
-from jax import core, dtypes
+from jax import dtypes
 from jax.interpreters import mlir
 from jax.interpreters.mlir import ir
 from jax.sharding import PartitionSpec, NamedSharding
@@ -74,7 +74,7 @@ def abstract(x_aval, gamma_aval, beta_aval, **kwargs):
 
 mu_rsigama_dtype = jnp.float32
 
-out_aval = core.raise_to_shaped(x_aval)
+out_aval = x_aval
 mu_aval = rsigma_aval = out_aval.update(shape=out_aval.shape[:-1], dtype=mu_rsigama_dtype)
 
 assert gamma_aval.size == beta_aval.size
@@ -361,8 +361,8 @@ def abstract(dz_aval, x_aval, mu_aval, rsigma_aval, gamma_aval, **kwargs):
 assert mu_aval.shape == rsigma_aval.shape == x_aval.shape[:-1]
 assert mu_dtype == rsigma_dtype == jnp.float32
 
-dx_aval = core.raise_to_shaped(dz_aval)
-dgamma_aval = dbeta_aval = core.raise_to_shaped(gamma_aval)
+dx_aval = dz_aval
+dgamma_aval = dbeta_aval = gamma_aval
 
 (wkspace_info,) = transformer_engine_jax.get_layernorm_bwd_workspace_sizes(
 x_aval.size // gamma_aval.size, # batch size
@@ -589,7 +589,7 @@ def abstract(x_aval, gamma_aval, **kwargs):
 
 rsigama_dtype = jnp.float32
 
-out_aval = core.raise_to_shaped(x_aval)
+out_aval = x_aval
 rsigma_aval = out_aval.update(shape=out_aval.shape[:-1], dtype=rsigama_dtype)
 
 hidden_size = gamma_aval.size
@@ -783,8 +783,8 @@ def abstract(dz_aval, x_aval, rsigma_aval, gamma_aval, **kwargs):
 assert rsigma_aval.shape == x_aval.shape[:-1]
 assert rsigma_dtype == jnp.float32
 
-dx_aval = core.raise_to_shaped(dz_aval)
-dgamma_aval = core.raise_to_shaped(gamma_aval)
+dx_aval = dz_aval
+dgamma_aval = gamma_aval
 
 (wkspace_info,) = transformer_engine_jax.get_layernorm_bwd_workspace_sizes(
 x_aval.size // gamma_aval.size, # batch size
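
The normalization rules above follow the same substitution; the extra detail is that the forward rules also derive auxiliary statistics avals (mu, rsigma) from the input aval. A hedged sketch of that bookkeeping with a hypothetical multi-output primitive, not TE's actual one:

# Hypothetical multi-output primitive mirroring the layernorm forward abstract
# rule above: out keeps x's shape/dtype, mu and rsigma drop the hidden
# dimension and are always float32.
import jax.numpy as jnp
from jax.extend import core

sketch_layernorm_fwd_p = core.Primitive("sketch_layernorm_fwd")
sketch_layernorm_fwd_p.multiple_results = True

@sketch_layernorm_fwd_p.def_abstract_eval
def _sketch_layernorm_fwd_abstract(x_aval, gamma_aval, beta_aval):
    assert gamma_aval.size == beta_aval.size
    out_aval = x_aval  # used directly; no raise_to_shaped needed
    stats_aval = out_aval.update(shape=out_aval.shape[:-1], dtype=jnp.float32)
    return out_aval, stats_aval, stats_aval  # (out, mu, rsigma)
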
8 changes: 4 additions & 4 deletions transformer_engine/jax/cpp_extensions/softmax.py
@@ -9,7 +9,7 @@
 
 import jax
 import jax.numpy as jnp
-from jax import core, dtypes
+from jax import dtypes
 from jax.interpreters.mlir import ir
 from jax.sharding import PartitionSpec, NamedSharding
 from jax.extend import ffi
@@ -126,7 +126,7 @@ def forward_abstract(logits_aval, scale_factor):
 assert k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported
 assert q_seqlen > 1
 
-out_aval = core.raise_to_shaped(logits_aval)
+out_aval = logits_aval
 return out_aval
 
 @staticmethod
@@ -237,7 +237,7 @@ def backward_abstract(
 
 assert dz_aval.shape == softmax_out_aval.shape
 
-dx_aval = core.raise_to_shaped(dz_aval)
+dx_aval = dz_aval
 return dx_aval
 
 @staticmethod
@@ -578,7 +578,7 @@ def abstract(logits_aval, mask_aval, scale_factor): # pylint: disable=unused-argument
 assert mask_shape[-2] == q_seqlen
 assert mask_shape[-1] == k_seqlen
 
-out_aval = core.raise_to_shaped(logits_aval)
+out_aval = logits_aval
 return out_aval
 
 @staticmethod
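
The softmax abstract rules above pair with FFI lowerings that forward to the C++ handlers registered in pybind.cpp below. A hedged sketch of how such a lowering is typically wired up through jax.extend.ffi, with a hypothetical primitive but the same call-target name that EncapsulateFFI registers:

# Hedged sketch: FFI-backed lowering for a hypothetical scaled-softmax
# primitive. Only the call-target name is taken from Registrations() below;
# everything else is illustrative, not TE's exact code.
from jax.interpreters import mlir
from jax.extend import core, ffi

sketch_scaled_softmax_fwd_p = core.Primitive("sketch_scaled_softmax_fwd")

@sketch_scaled_softmax_fwd_p.def_abstract_eval
def _sketch_softmax_abstract(logits_aval, *, scale_factor):
    del scale_factor  # does not affect shape or dtype
    return logits_aval  # softmax output keeps the logits' shape and dtype

# ffi_lowering builds an MLIR lowering rule that emits a custom call to the
# named XLA FFI target, forwarding primitive params as FFI attributes.
mlir.register_lowering(
    sketch_scaled_softmax_fwd_p,
    ffi.ffi_lowering("te_scaled_softmax_forward_ffi"),
    platform="cuda",
)
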
19 changes: 8 additions & 11 deletions transformer_engine/jax/csrc/extensions/pybind.cpp
@@ -61,26 +61,23 @@ pybind11::dict Registrations() {
 dict["te_act_lu_ffi"] = EncapsulateFFI(ActLuHandler);
 dict["te_act_lu_fp8_ffi"] = EncapsulateFFI(ActLuFP8Handler);
 dict["te_dact_lu_ffi"] = EncapsulateFFI(DActLuHandler);
-dict["te_dact_lu_dbias_cast_transpose_ffi"] =
-EncapsulateFunction(DActLuDBiasCastTransposeHandler);
-dict["te_dgated_act_lu_cast_transpose_ffi"] =
-EncapsulateFunction(DGatedActLuCastTransposeHandler);
+dict["te_dact_lu_dbias_cast_transpose_ffi"] = EncapsulateFFI(DActLuDBiasCastTransposeHandler);
+dict["te_dgated_act_lu_cast_transpose_ffi"] = EncapsulateFFI(DGatedActLuCastTransposeHandler);
 
 // Quantization
 dict["te_quantize_ffi"] = EncapsulateFFI(QuantizeHandler);
 dict["te_dequantize_ffi"] = EncapsulateFFI(DequantizeHandler);
 
 // Softmax
-dict["te_scaled_softmax_forward_ffi"] = EncapsulateFunction(ScaledSoftmaxForwardHandler);
-dict["te_scaled_softmax_backward_ffi"] = EncapsulateFunction(ScaledSoftmaxBackwardHandler);
-dict["te_scaled_masked_softmax_forward_ffi"] =
-EncapsulateFunction(ScaledMaskedSoftmaxForwardHandler);
+dict["te_scaled_softmax_forward_ffi"] = EncapsulateFFI(ScaledSoftmaxForwardHandler);
+dict["te_scaled_softmax_backward_ffi"] = EncapsulateFFI(ScaledSoftmaxBackwardHandler);
+dict["te_scaled_masked_softmax_forward_ffi"] = EncapsulateFFI(ScaledMaskedSoftmaxForwardHandler);
 dict["te_scaled_masked_softmax_backward_ffi"] =
-EncapsulateFunction(ScaledMaskedSoftmaxBackwardHandler);
+EncapsulateFFI(ScaledMaskedSoftmaxBackwardHandler);
 dict["te_scaled_upper_triang_masked_softmax_forward_ffi"] =
-EncapsulateFunction(ScaledUpperTriangMaskedSoftmaxForwardHandler);
+EncapsulateFFI(ScaledUpperTriangMaskedSoftmaxForwardHandler);
 dict["te_scaled_upper_triang_masked_softmax_backward_ffi"] =
-EncapsulateFunction(ScaledUpperTriangMaskedSoftmaxBackwardHandler);
+EncapsulateFFI(ScaledUpperTriangMaskedSoftmaxBackwardHandler);
 
 // Normalization
 dict["te_layernorm_forward_ffi"] =
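
For completeness: the capsules that Registrations() returns are consumed on the Python side by registering each one as an XLA FFI target. A minimal sketch, assuming the pybind module exposes the dict under a registrations() binding (that attribute name is illustrative; the real binding is defined elsewhere in pybind.cpp):

# Hedged sketch: hook the EncapsulateFFI capsules up as XLA FFI targets.
# registrations() is an assumed binding name for the Registrations() dict.
from jax.extend import ffi
import transformer_engine_jax  # the pybind11 extension shown above

for name, capsule in transformer_engine_jax.registrations().items():
    # Each capsule wraps an XLA FFI handler; register it for the CUDA backend.
    ffi.register_ffi_target(name, capsule, platform="CUDA")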