Skip to content

Set correct threads_per_warp to 64 for AMD GPU #170

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 9 commits into
base: rocm-jaxlib-v0.5.0
Choose a base branch
from

Conversation

yaomingamd
Copy link

In xla0.5.0, some algorithms depends on warp_size 32, so that thereads_per_warp has been manually set to 32. This PR tries to reset threads_per_warp to 64 correctly for current AMD GPU and at the same, fix unittest failure/numerical issues due to this change. I have run unittest of jax, it shows that unittests have much less failure/numerical issues than original/base branch rocm-jaxlib-v0.5.0. For example, pytest tests/pallas pass all unittests.
=============== 1907 passed, 3307 skipped in 1281.18s (0:21:21) ================

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant