Unable to use jax
#25641
I installed JAX as per the instructions. However, this simple code:

```python
from jax import numpy as jnp

a = jnp.array([0] * (3 * 210 * 160))
a = a.reshape((3, 210, 160))
a = jnp.resize(a, (1, 110, 84))
```

causes this issue:

```
Traceback (most recent call last):
  File "/home/haislich/Documents/dqn/src/dqn/dqn.py", line 5, in <module>
    a = jnp.resize(a, (1, 110, 84))
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: ptxas exited with non-zero error code 139, output: ptxas /tmp/tempfile-Workstation-c7212f129f34c799-9923-629b5d77da91f, line 5; fatal : Unsupported .version 8.3; current version is '8.0'
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
```

What can I do?
Answered by pearu, Dec 20, 2024
Hi, sorry you're having this issue! Can you give us some more information about what system you're on, what Python version you have, how you installed JAX, and what JAX version you're using?
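Most of those details can be collected with a short script. This is a minimal sketch that uses only the standard library, so it works even if importing JAX itself misbehaves. The package names `jax-cuda12-plugin` and `nvidia-cuda-nvcc-cu12` are assumptions based on a pip `jax[cuda12]` install; a conda or source build may use different names, in which case they simply show as "not installed".

```python
import platform
from importlib import metadata


def package_version(name: str) -> str:
    """Return the installed version of a distribution, or 'not installed'."""
    try:
        return metadata.version(name)
    except metadata.PackageNotFoundError:
        return "not installed"


def environment_summary() -> dict:
    """Collect OS, Python and JAX-related package versions without importing JAX."""
    summary = {
        "os": platform.platform(),
        "python": platform.python_version(),
    }
    # Assumption: CUDA-related package names as installed by pip "jax[cuda12]";
    # other install methods may not provide these distributions at all.
    for pkg in ("jax", "jaxlib", "jax-cuda12-plugin", "nvidia-cuda-nvcc-cu12"):
        summary[pkg] = package_version(pkg)
    return summary


if __name__ == "__main__":
    for key, value in environment_summary().items():
        print(f"{key}: {value}")
```

Once JAX imports, `jax.print_environment_info()` prints a similar report (including `nvidia-smi` output when available) and can be pasted into the thread directly.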
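This kind of error may appear when the CUDA toolkit and the NVIDIA driver are incompatible: `nvcc` reports 12.0 but the driver is 12.6. The `ptxas` message says the same thing: the generated PTX declares ISA version 8.3 (introduced with CUDA 12.3), while the `ptxas` being used only accepts up to 8.0, the version that ships with CUDA 12.0. Also check the output of … and consider updating the CUDA toolkit.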
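The toolkit-versus-driver comparison described in the answer can be scripted. This is a minimal sketch that runs `nvidia-smi`, `nvcc --version` and `ptxas --version` and extracts a CUDA major.minor version from each. It assumes these tools are on `PATH`, and the regexes are assumptions about their usual output (`release X.Y` for nvcc/ptxas, `CUDA Version: X.Y` in the `nvidia-smi` header), not a stable interface.

```python
import re
import shutil
import subprocess
from typing import Optional, Sequence, Tuple

Version = Tuple[int, int]


def _parse(text: str, pattern: str) -> Optional[Version]:
    """Return (major, minor) from the first regex match, or None."""
    match = re.search(pattern, text)
    if match is None:
        return None
    return int(match.group(1)), int(match.group(2))


def toolkit_version(output: str) -> Optional[Version]:
    # nvcc/ptxas --version print e.g. "Cuda compilation tools, release 12.0, V12.0.140"
    return _parse(output, r"release (\d+)\.(\d+)")


def driver_cuda_version(output: str) -> Optional[Version]:
    # The nvidia-smi header shows e.g. "CUDA Version: 12.6"
    return _parse(output, r"CUDA Version:\s*(\d+)\.(\d+)")


def fmt(version: Optional[Version]) -> str:
    """Format a (major, minor) tuple as "12.0", or "unknown"."""
    return "unknown" if version is None else f"{version[0]}.{version[1]}"


def run(cmd: Sequence[str]) -> Optional[str]:
    """Run a command if it is on PATH and return its stdout, else None."""
    if shutil.which(cmd[0]) is None:
        return None
    result = subprocess.run(list(cmd), capture_output=True, text=True, check=False)
    return result.stdout


def main() -> None:
    smi_out = run(["nvidia-smi"])
    driver = driver_cuda_version(smi_out) if smi_out else None
    print(f"nvidia-smi: driver supports CUDA {fmt(driver)}")
    for tool in ("nvcc", "ptxas"):
        out = run([tool, "--version"])
        version = toolkit_version(out) if out else None
        print(f"{tool} ({shutil.which(tool)}): CUDA {fmt(version)}")
        # Tuples compare element-wise, so (12, 0) < (12, 6).
        if driver is not None and version is not None and version < driver:
            print(f"  {tool} is older than the driver's CUDA version; "
                  "consider updating the CUDA toolkit")


if __name__ == "__main__":
    main()
```

On a setup like the one in the question, this would flag `nvcc` as 12.0 against a 12.6 driver. Run it in the same environment as the failing script, since that `PATH` decides which `nvcc` and `ptxas` the script reports.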