[JAX] Canonicalize the dtype for a better user experience (NVIDIA#480)
Canonicalize the dtype for a better user experience: user-supplied dtypes (e.g. strings or NumPy dtypes) are normalized with dtypes.canonicalize_dtype before being compared against JAX dtypes, so the TE dtype conversion and the softmax kernel-availability checks behave consistently.

Signed-off-by: Reese Wang <[email protected]>
zlsh80826 authored and mingxu1067 committed Nov 3, 2023
1 parent 7c3e9ce commit 3cf9460
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions transformer_engine/jax/cpp_extensions.py
@@ -84,6 +84,7 @@ def jax_dtype_to_te_dtype(jax_dtype):
     """
     convert jax dtype to TE dtype
     """
+    jax_dtype = dtypes.canonicalize_dtype(jax_dtype)
     if jax_dtype == jnp.float32:
         return TEDType.kFloat32
     if jax_dtype == jnp.float16:
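
(Not part of the diff: a minimal illustrative sketch of why the added canonicalization helps. The variable names user_dtype and canonical are hypothetical; the point is that common spellings of a dtype all compare equal to the JAX type after dtypes.canonicalize_dtype, so the equality checks above no longer miss string or NumPy inputs, and 64-bit inputs follow JAX's usual x64-disabled downcast.)

# Illustrative sketch, not from the commit.
import numpy as np
import jax.numpy as jnp
from jax import dtypes

for user_dtype in ("float16", np.float16, np.dtype("float16"), jnp.float16):
    canonical = dtypes.canonicalize_dtype(user_dtype)
    # Every spelling now compares equal to jnp.float16.
    assert canonical == jnp.float16

# With JAX's default x64-disabled config, 64-bit inputs fold to 32-bit,
# so they still hit the jnp.float32 branch above.
assert dtypes.canonicalize_dtype(np.float64) == jnp.float32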
@@ -963,7 +964,7 @@ def forward_partition(impl, scale_factor, mesh, arg_infos, result_infos):
     @staticmethod
     def backward_abstract(dz_aval, softmax_out_aval, scale_factor=None):    # pylint: disable=unused-argument
         """
-        softmax_backward infer_sharding_from_operands
+        softmax_backward abstract
         """
         dz_dtype = dtypes.canonicalize_dtype(dz_aval.dtype)
         softmax_out_dtype = dtypes.canonicalize_dtype(softmax_out_aval.dtype)
@@ -1071,6 +1072,7 @@ def is_kernel_available(batch: int, heads: int, q_seqlen: int, k_seqlen: int,
         """Check Softmax kernel availability based on size"""
         attn_batches = batch * heads
 
+        dtype = dtypes.canonicalize_dtype(dtype)
         if (dtype in [jnp.float16, jnp.bfloat16]
                 and 16 < k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported
                 # k_seqlen must be 16 ~ 4096
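
(A similar sketch, again not part of the diff, for the membership test in is_kernel_available: a string spelling such as "bfloat16" only matches the [jnp.float16, jnp.bfloat16] list once it has been canonicalized.)

# Illustrative sketch, not from the commit.
import jax.numpy as jnp
from jax import dtypes

supported = [jnp.float16, jnp.bfloat16]
print("bfloat16" in supported)                              # False
print(dtypes.canonicalize_dtype("bfloat16") in supported)   # True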
@@ -1225,6 +1227,7 @@ def is_kernel_available(batch: int, heads: int, q_seqlen: int, k_seqlen: int,
         """Check Softmax kernel availability based on size"""
         attn_batches = batch * heads
 
+        dtype = dtypes.canonicalize_dtype(dtype)
         if (dtype in [jnp.float16, jnp.bfloat16]
                 and 16 < k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported
                 # k_seqlen must be 16 ~ 4096
@@ -1329,9 +1332,8 @@ def partition(scale_factor, mesh, arg_infos, result_infos):
         del result_infos
         logits_spec = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[0])))
         mask_spec = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
-        out_spec = logits_spec
         arg_shardings = (logits_spec, mask_spec)
-        out_shardings = out_spec
+        out_shardings = logits_spec
         impl = partial(ScaledMaskedSoftmaxFwdPrimitive.impl, scale_factor=scale_factor)
         return mesh, impl, out_shardings, arg_shardings

@@ -1442,6 +1444,7 @@ def is_kernel_available(batch: int, heads: int, q_seqlen: int, k_seqlen: int,
         """Check Softmax kernel availability based on size"""
         attn_batches = batch * heads
 
+        dtype = dtypes.canonicalize_dtype(dtype)
         if (dtype in [jnp.float16, jnp.bfloat16]
                 and 16 < k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported
                 # k_seqlen must be 16 ~ 4096
