rm jax deprecated features
Signed-off-by: Phuong Nguyen <[email protected]>
phu0ngng committed Dec 13, 2024
1 parent 23ff830 commit 76cd159
Showing 4 changed files with 15 additions and 15 deletions.
6 changes: 3 additions & 3 deletions transformer_engine/jax/cpp_extensions/activation.py
@@ -8,7 +8,7 @@
 
 import jax
 import jax.numpy as jnp
-from jax import core, dtypes
+from jax import dtypes
 from jax.interpreters.mlir import ir
 from jax.sharding import PartitionSpec, NamedSharding
 from jax.extend import ffi
@@ -98,7 +98,7 @@ def abstract(x_aval, *, act_enum):  # pylint: disable=unused-argument
         assert x_shape[-2] == 2 or x_shape[-2] == 1
         hidden_size = x_shape[-1]
         batch_shapes = x_shape[:-2]
-        out_aval = core.raise_to_shaped(x_aval)
+        out_aval = x_aval
         out_shape = (batch_shapes) + (hidden_size,)
         out_aval = out_aval.update(shape=out_shape, dtype=dtype)
 
@@ -225,7 +225,7 @@ def abstract(dz_aval, x_aval, *, act_enum):  # pylint: disable=unused-argument
         i_hidden_size = dz_aval.shape[-1]
         g_hidden_size = x_aval.shape[-1]
         assert i_hidden_size == g_hidden_size
-        out_aval = core.raise_to_shaped(x_aval)
+        out_aval = x_aval
 
         return out_aval
 
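Every raise_to_shaped change in this commit follows the pattern above. Abstract-eval rules already receive ShapedArray avals, and recent JAX releases deprecate (and then remove) jax.core.raise_to_shaped as a redundant identity on such avals, so using the incoming aval directly and adjusting it with aval.update is behavior-preserving. A minimal sketch under that assumption; my_act_abstract and its shapes are hypothetical, not TransformerEngine code:

import jax
import jax.numpy as jnp

def my_act_abstract(x_aval):
    # Before the migration this was: out_aval = core.raise_to_shaped(x_aval)
    # The incoming aval is already a ShapedArray, so it can be used as-is.
    out_aval = x_aval
    # Collapse the gated-activation pair axis, mirroring the rule above:
    # (..., 2, hidden) -> (..., hidden).
    out_shape = x_aval.shape[:-2] + (x_aval.shape[-1],)
    return out_aval.update(shape=out_shape, dtype=x_aval.dtype)

# One way to obtain a real aval to exercise the rule:
aval = jax.make_jaxpr(lambda x: x)(jnp.zeros((4, 2, 8))).in_avals[0]
print(my_act_abstract(aval))  # float32[4,8]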
2 changes: 1 addition & 1 deletion transformer_engine/jax/cpp_extensions/base.py
@@ -7,7 +7,7 @@
 from abc import ABCMeta, abstractmethod
 from functools import partial
 
-from jax import core
+from jax.extend import core
 from jax.interpreters import xla, mlir
 from jax.experimental.custom_partitioning import custom_partitioning
 from jax._src.interpreters import batching
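The base.py hunk addresses the second deprecation: third-party code should no longer import core from the jax namespace, and jax.extend.core is the supported home for Primitive. A minimal sketch of defining a primitive through the new path; my_prim and its rules are hypothetical and only illustrate that the downstream Primitive API is unchanged:

import jax.numpy as jnp
from jax.extend import core  # was: from jax import core

my_prim = core.Primitive("my_prim")
my_prim.def_impl(jnp.tanh)  # concrete (eager) implementation

@my_prim.def_abstract_eval
def _my_prim_abstract(x_aval):
    # Same identity-shaped pattern as the abstract rules in this commit.
    return x_aval

print(my_prim.bind(jnp.ones(3)))  # [0.7615942 0.7615942 0.7615942]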
14 changes: 7 additions & 7 deletions transformer_engine/jax/cpp_extensions/normalization.py
@@ -9,7 +9,7 @@
 
 import jax
 import jax.numpy as jnp
-from jax import core, dtypes
+from jax import dtypes
 from jax.interpreters import mlir
 from jax.interpreters.mlir import ir
 from jax.sharding import PartitionSpec, NamedSharding
@@ -74,7 +74,7 @@ def abstract(x_aval, gamma_aval, beta_aval, **kwargs):
 
         mu_rsigama_dtype = jnp.float32
 
-        out_aval = core.raise_to_shaped(x_aval)
+        out_aval = x_aval
         mu_aval = rsigma_aval = out_aval.update(shape=out_aval.shape[:-1], dtype=mu_rsigama_dtype)
 
         assert gamma_aval.size == beta_aval.size
@@ -361,8 +361,8 @@ def abstract(dz_aval, x_aval, mu_aval, rsigma_aval, gamma_aval, **kwargs):
         assert mu_aval.shape == rsigma_aval.shape == x_aval.shape[:-1]
         assert mu_dtype == rsigma_dtype == jnp.float32
 
-        dx_aval = core.raise_to_shaped(dz_aval)
-        dgamma_aval = dbeta_aval = core.raise_to_shaped(gamma_aval)
+        dx_aval = dz_aval
+        dgamma_aval = dbeta_aval = gamma_aval
 
         (wkspace_info,) = transformer_engine_jax.get_layernorm_bwd_workspace_sizes(
             x_aval.size // gamma_aval.size,  # batch size
@@ -589,7 +589,7 @@ def abstract(x_aval, gamma_aval, **kwargs):
 
         rsigama_dtype = jnp.float32
 
-        out_aval = core.raise_to_shaped(x_aval)
+        out_aval = x_aval
         rsigma_aval = out_aval.update(shape=out_aval.shape[:-1], dtype=rsigama_dtype)
 
         hidden_size = gamma_aval.size
@@ -783,8 +783,8 @@ def abstract(dz_aval, x_aval, rsigma_aval, gamma_aval, **kwargs):
         assert rsigma_aval.shape == x_aval.shape[:-1]
         assert rsigma_dtype == jnp.float32
 
-        dx_aval = core.raise_to_shaped(dz_aval)
-        dgamma_aval = core.raise_to_shaped(gamma_aval)
+        dx_aval = dz_aval
+        dgamma_aval = gamma_aval
 
         (wkspace_info,) = transformer_engine_jax.get_layernorm_bwd_workspace_sizes(
             x_aval.size // gamma_aval.size,  # batch size
8 changes: 4 additions & 4 deletions transformer_engine/jax/cpp_extensions/softmax.py
@@ -9,7 +9,7 @@
 
 import jax
 import jax.numpy as jnp
-from jax import core, dtypes
+from jax import dtypes
 from jax.interpreters.mlir import ir
 from jax.sharding import PartitionSpec, NamedSharding
 from jax.extend import ffi
@@ -126,7 +126,7 @@ def forward_abstract(logits_aval, scale_factor):
         assert k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported
         assert q_seqlen > 1
 
-        out_aval = core.raise_to_shaped(logits_aval)
+        out_aval = logits_aval
         return out_aval
 
     @staticmethod
@@ -237,7 +237,7 @@ def backward_abstract(
 
         assert dz_aval.shape == softmax_out_aval.shape
 
-        dx_aval = core.raise_to_shaped(dz_aval)
+        dx_aval = dz_aval
         return dx_aval
 
     @staticmethod
@@ -578,7 +578,7 @@ def abstract(logits_aval, mask_aval, scale_factor):  # pylint: disable=unused-argument
         assert mask_shape[-2] == q_seqlen
         assert mask_shape[-1] == k_seqlen
 
-        out_aval = core.raise_to_shaped(logits_aval)
+        out_aval = logits_aval
         return out_aval
 
     @staticmethod
