Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[JAX] Canonicalize the dtype for the better user experience #480

Merged
merged 1 commit into from
Oct 20, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions transformer_engine/jax/cpp_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def jax_dtype_to_te_dtype(jax_dtype):
"""
convert jax dtype to TE dtype
"""
jax_dtype = dtypes.canonicalize_dtype(jax_dtype)
if jax_dtype == jnp.float32:
return TEDType.kFloat32
if jax_dtype == jnp.float16:
Expand Down Expand Up @@ -1626,6 +1627,7 @@ def is_kernel_available(batch: int, heads: int, q_seqlen: int, k_seqlen: int,
"""Check Softmax kernel availability based on size"""
attn_batches = batch * heads

dtype = dtypes.canonicalize_dtype(dtype)
if (dtype in [jnp.float16, jnp.bfloat16]
and 16 < k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported
# k_seqlen must be 16 ~ 4096
Expand Down Expand Up @@ -1757,6 +1759,7 @@ def is_kernel_available(batch: int, heads: int, q_seqlen: int, k_seqlen: int,
"""Check Softmax kernel availability based on size"""
attn_batches = batch * heads

dtype = dtypes.canonicalize_dtype(dtype)
if (dtype in [jnp.float16, jnp.bfloat16]
and 16 < k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported
# k_seqlen must be 16 ~ 4096
Expand Down Expand Up @@ -1908,6 +1911,7 @@ def is_kernel_available(batch: int, heads: int, q_seqlen: int, k_seqlen: int,
"""Check Softmax kernel availability based on size"""
attn_batches = batch * heads

dtype = dtypes.canonicalize_dtype(dtype)
if (dtype in [jnp.float16, jnp.bfloat16]
and 16 < k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported
# k_seqlen must be 16 ~ 4096
Expand Down