Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[JAX] Canonicalize the dtype for the better user experience #480

Merged
merged 1 commit into from
Oct 20, 2023

Conversation

zlsh80826
Copy link
Collaborator

Use jax.dtypes.canonicalize_dtype to wrap the dtype may from the users to ensure both jnp.bfloat16 and 'bfloat16' work.

@zlsh80826 zlsh80826 added the enhancement New feature or request label Oct 17, 2023
@mingxu1067
Copy link
Collaborator

LGTM. Might have some conflicts with #472

@zlsh80826
Copy link
Collaborator Author

/te-ci jax

@zlsh80826
Copy link
Collaborator Author

zlsh80826 commented Oct 20, 2023

@timmoon10 All CIs are passed, please help merge it when you are available, thanks :)

@ksivaman ksivaman merged commit 2a86df2 into NVIDIA:main Oct 20, 2023
denera pushed a commit to denera/TransformerEngine that referenced this pull request Oct 23, 2023
canonicalize the dtype for the better user experience

Signed-off-by: Reese Wang <[email protected]>
mingxu1067 pushed a commit to mingxu1067/TransformerEngine that referenced this pull request Nov 3, 2023
canonicalize the dtype for the better user experience

Signed-off-by: Reese Wang <[email protected]>
cyanguwa pushed a commit to cyanguwa/TransformerEngine that referenced this pull request Nov 13, 2023
canonicalize the dtype for the better user experience

Signed-off-by: Reese Wang <[email protected]>
Signed-off-by: Charlene Yang <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants