`jax.numpy` operations seem to be slower than plain NumPy unless they are also wrapped in `jax.jit` or offloaded to an accelerator (GPU/TPU). When JAX is in use, the code base converts a NumPy array to a JAX array before calling a JIT-compiled function, but does not convert the result back to a NumPy array once the JIT function returns. Any later, non-JIT work on that result may therefore go through JAX's slower op-by-op path. We should investigate whether converting the result back to NumPy (for example with `np.asarray`) makes the code that follows faster than the current process.
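A minimal benchmark sketch for that investigation, assuming a CPU-only setup. `jit_step`, `post_process`, and the array size are hypothetical stand-ins, not functions from the code base. The sketch times the current path (keep working on the returned JAX array) against converting the result back to NumPy with `np.asarray` before the non-JIT follow-up work.

```python
import timeit

import jax
import jax.numpy as jnp
import numpy as np


@jax.jit
def jit_step(x):
    # Hypothetical stand-in for one of the code base's JIT-compiled functions.
    return jnp.sin(x) * 2.0 + 1.0


def post_process(y):
    # Hypothetical non-JIT follow-up work, executed eagerly op by op.
    # On a JAX array each operation dispatches through JAX; on a NumPy
    # array it runs through NumPy.
    for _ in range(10):
        y = y * 0.5 + 1.0
    return y


def current_path(x_np):
    # Current behavior: convert to a JAX array, call the JIT function,
    # and keep using the returned JAX array. Note that jnp.asarray downcasts
    # float64 input to float32 unless jax_enable_x64 is turned on.
    y = jit_step(jnp.asarray(x_np))
    return post_process(y)


def convert_back_path(x_np):
    # Proposed behavior: convert the JIT result back to NumPy before
    # the non-JIT follow-up work.
    y = np.asarray(jit_step(jnp.asarray(x_np)))
    return post_process(y)


def benchmark(fn, x_np, number=100):
    # Warm up once so JIT compilation time is excluded from the timing.
    np.asarray(fn(x_np))
    # Wrapping the result in np.asarray waits for any pending asynchronous
    # JAX work, so the timing covers the full computation.
    return timeit.timeit(lambda: np.asarray(fn(x_np)), number=number) / number


if __name__ == "__main__":
    x = np.random.default_rng(0).standard_normal(1_000)
    for name, fn in [("current", current_path), ("convert back", convert_back_path)]:
        print(f"{name:>12}: {benchmark(fn, x) * 1e6:.1f} us per call")
```

The outcome will likely depend on array size, backend, and how much non-JIT work follows the JIT call. On a GPU, converting back also adds a device-to-host copy, so it is worth running the comparison on each target backend rather than assuming one answer.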