Make xla_bridge.is_gpu more extensible #25521

Open
PragmaTwice opened this issue Dec 17, 2024 · 0 comments
Labels
enhancement New feature or request

PragmaTwice commented Dec 17, 2024

In jax._src.xla_bridge, the function is_gpu determines whether a platform is a GPU platform:

jax/jax/_src/xla_bridge.py

Lines 846 to 847 in 7dd401c

def is_gpu(platform):
  return platform in ("cuda", "rocm")

This function hardcodes the two platforms cuda and rocm. However, xla_bridge already has utilities for this kind of check, namely _platform_aliases, _alias_to_platforms and expand_platform_alias:

jax/jax/_src/xla_bridge.py

Lines 838 to 844 in 7dd401c

def expand_platform_alias(platform: str) -> list[str]:
  """Expands, e.g., "gpu" to ["cuda", "rocm"].

  This is used for convenience reasons: we expect cuda and rocm to act similarly
  in many respects since they share most of the same code.
  """
  return _alias_to_platforms.get(platform, [platform])

I think we could change is_gpu to:

def is_gpu(platform):
  return platform in expand_platform_alias("gpu")

This would make is_gpu extensible and consistent with _platform_aliases: to add a new GPU platform, we would only need to update the alias list, without touching this function.
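
To illustrate the behavior, here is a minimal, self-contained sketch that does not import JAX. The `_alias_to_platforms` table below is a stand-in whose contents are assumed from the docstring of expand_platform_alias (the real table in xla_bridge may be built differently), and `my_gpu` is a purely hypothetical platform name:

```python
# Stand-in for the alias table in jax._src.xla_bridge.
# Assumption: "gpu" maps to ["cuda", "rocm"], as the docstring describes.
_alias_to_platforms: dict[str, list[str]] = {"gpu": ["cuda", "rocm"]}


def expand_platform_alias(platform: str) -> list[str]:
  """Returns the platforms behind an alias, or [platform] if it is not an alias."""
  return _alias_to_platforms.get(platform, [platform])


def is_gpu(platform: str) -> bool:
  """Proposed version: derives the answer from the alias table."""
  return platform in expand_platform_alias("gpu")


print(is_gpu("cuda"))  # True
print(is_gpu("cpu"))   # False

# Registering a hypothetical new GPU platform only touches the alias table.
_alias_to_platforms["gpu"].append("my_gpu")
print(is_gpu("my_gpu"))  # True
```

With the hardcoded tuple, the last call would return False until is_gpu itself was edited as well.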

WDYT? I'd be glad to submit a patch if this looks good to people in the JAX community : )

@PragmaTwice PragmaTwice added the enhancement New feature or request label Dec 17, 2024