In `jax._src.xla_bridge` we have a function named `is_gpu` to determine whether a platform is a GPU platform: `jax/jax/_src/xla_bridge.py`, lines 846 to 847 in 7dd401c.
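For reference, a hardcoded check of the kind described here looks roughly like the sketch below. It is an approximation based on the description in this issue, not a verbatim copy of the JAX source at 7dd401c:

```python
def is_gpu(platform: str) -> bool:
    # Only the platform names listed in this tuple are treated as GPUs;
    # supporting another GPU backend means editing this function directly.
    return platform in ("cuda", "rocm")


print(is_gpu("cuda"), is_gpu("rocm"), is_gpu("cpu"))  # True True False
```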
It seems to just hardcode two platforms, `cuda` and `rocm`, inside this function. However, I notice that there are some utilities for such a check in `xla_bridge`, like `_platform_aliases`, `_alias_to_platforms` and `expand_platform_alias`: `jax/jax/_src/xla_bridge.py`, lines 838 to 844 in 7dd401c.
I think we could change this function to derive the set of GPU platforms from the alias table, so that it is extensible and consistent with `_platform_aliases`. To add a new GPU platform we would then only need to change the alias list, instead of also updating this function. WDYT? I'm glad to submit a patch if it looks good to people in the JAX community :)
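A minimal sketch of the proposed direction, assuming alias utilities shaped like the ones named above. `is_gpu` asks the alias table which platforms count as `"gpu"` rather than listing them itself. The platform name `my_gpu` is purely hypothetical; it is registered only in the alias table to show that `is_gpu` needs no edit when a new GPU platform is added:

```python
# Assumed alias table; "my_gpu" is a hypothetical new GPU platform.
_platform_aliases: dict[str, str] = {
    "cuda": "gpu",
    "rocm": "gpu",
    "my_gpu": "gpu",  # the only change needed to register the new platform
}

_alias_to_platforms: dict[str, list[str]] = {}
for _platform, _alias in _platform_aliases.items():
    _alias_to_platforms.setdefault(_alias, []).append(_platform)


def expand_platform_alias(platform: str) -> list[str]:
    # Non-alias names are returned unchanged in a one-element list.
    return _alias_to_platforms.get(platform, [platform])


def is_gpu(platform: str) -> bool:
    # The GPU platforms come from the alias table instead of a hardcoded tuple.
    return platform in expand_platform_alias("gpu")


print([p for p in ("cuda", "rocm", "my_gpu", "cpu", "tpu") if is_gpu(p)])
# ['cuda', 'rocm', 'my_gpu']
```

One subtlety of this sketch: if no platform were mapped to `"gpu"`, `expand_platform_alias("gpu")` would fall back to `["gpu"]`, so `is_gpu("gpu")` would return `True`. Reading `_alias_to_platforms.get("gpu", [])` directly avoids that corner case.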