Skip to content

Commit

Permalink
[JAX] Canonicalize the dtype for the better user experience (#480)
Browse files Browse the repository at this point in the history
canonicalize the dtype for the better user experience

Signed-off-by: Reese Wang <[email protected]>
  • Loading branch information
zlsh80826 authored Oct 20, 2023
1 parent 1afb625 commit 2a86df2
Showing 1 changed file with 4 additions and 0 deletions.
4 changes: 4 additions & 0 deletions transformer_engine/jax/cpp_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def jax_dtype_to_te_dtype(jax_dtype):
"""
convert jax dtype to TE dtype
"""
jax_dtype = dtypes.canonicalize_dtype(jax_dtype)
if jax_dtype == jnp.float32:
return TEDType.kFloat32
if jax_dtype == jnp.float16:
Expand Down Expand Up @@ -1626,6 +1627,7 @@ def is_kernel_available(batch: int, heads: int, q_seqlen: int, k_seqlen: int,
"""Check Softmax kernel availability based on size"""
attn_batches = batch * heads

dtype = dtypes.canonicalize_dtype(dtype)
if (dtype in [jnp.float16, jnp.bfloat16]
and 16 < k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported
# k_seqlen must be 16 ~ 4096
Expand Down Expand Up @@ -1757,6 +1759,7 @@ def is_kernel_available(batch: int, heads: int, q_seqlen: int, k_seqlen: int,
"""Check Softmax kernel availability based on size"""
attn_batches = batch * heads

dtype = dtypes.canonicalize_dtype(dtype)
if (dtype in [jnp.float16, jnp.bfloat16]
and 16 < k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported
# k_seqlen must be 16 ~ 4096
Expand Down Expand Up @@ -1908,6 +1911,7 @@ def is_kernel_available(batch: int, heads: int, q_seqlen: int, k_seqlen: int,
"""Check Softmax kernel availability based on size"""
attn_batches = batch * heads

dtype = dtypes.canonicalize_dtype(dtype)
if (dtype in [jnp.float16, jnp.bfloat16]
and 16 < k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported
# k_seqlen must be 16 ~ 4096
Expand Down

0 comments on commit 2a86df2

Please sign in to comment.