Hello, I'm trying to install iree-jax to test GPT-2 on IREE. I first ran

python -m pip install -e '.[test,xla,cpu]' -f https://openxla.github.io/iree/pip-release-links.html

and then built jaxlib from source. However, when I run

lit -v tests/

4 of the 5 tests fail with "RuntimeError: Unknown backend iree". The same error occurs when running models/gpt2/test_jax.py. Did I miss a step during setup? Any help would be greatly appreciated. The full error log is attached below.
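To summarize the environment, a short script like this one can be used. It is a minimal sketch of my own, not an iree-jax utility. It prints the installed versions of jax, jaxlib and the IREE wheels and asks JAX whether it can resolve a backend named "iree", which is the lookup that fails in the log. The distribution names iree-compiler and iree-runtime are my assumption about what the -f link installed.

```python
import importlib
import importlib.metadata


def package_version(name):
    """Return the installed version of distribution `name`, or None."""
    try:
        return importlib.metadata.version(name)
    except importlib.metadata.PackageNotFoundError:
        return None


def backend_available(platform):
    """Report whether JAX can resolve a backend called `platform`.

    Returns True/False, or None if JAX itself cannot be imported.
    """
    try:
        jax = importlib.import_module("jax")
    except ImportError:
        return None
    try:
        # jax.devices(platform) goes through xla_bridge.get_backend(), the
        # same lookup that raises "Unknown backend iree" in the log below.
        jax.devices(platform)
    except RuntimeError:
        return False
    return True


if __name__ == "__main__":
    # Assumption: these are the distribution names of the installed wheels.
    for dist in ("jax", "jaxlib", "iree-compiler", "iree-runtime"):
        print(f"{dist}: {package_version(dist)}")
    print("cpu backend available:", backend_available("cpu"))
    print("iree backend available:", backend_available("iree"))
```

If "iree backend available" prints False while "cpu" prints True, the problem is probably that nothing registers an "iree" platform with this jaxlib build, rather than anything specific to the tests.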
Using pure python filecheck: /home/woongq/jax/bin/filecheck
-- Testing: 5 tests, 5 workers --
FAIL: IREE_JAX :: program/trivial_kernel.py (1 of 5)
******************** TEST 'IREE_JAX :: program/trivial_kernel.py' FAILED ********************
Script:
--
: 'RUN: at line 15'; /home/woongq/jax/bin/python /home/woongq/iree-jax/tests/program/trivial_kernel.py | /home/woongq/jax/bin/filecheck /home/woongq/iree-jax/tests/program/trivial_kernel.py
--
Exit Code: 2
Command Output (stdout):
--
$ ":" "RUN: at line 15"
$ "/home/woongq/jax/bin/python" "/home/woongq/iree-jax/tests/program/trivial_kernel.py"
# command stderr:
WARNING:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0002429485321044922 sec
DEBUG:jax._src.xla_bridge:Initializing backend 'cpu'
DEBUG:jax._src.xla_bridge:Backend 'cpu' initialized
DEBUG:jax._src.xla_bridge:Initializing backend 'cuda'
INFO:jax._src.xla_bridge:Unable to initialize backend 'cuda': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
DEBUG:jax._src.xla_bridge:Initializing backend 'rocm'
INFO:jax._src.xla_bridge:Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
DEBUG:jax._src.xla_bridge:Initializing backend 'tpu'
INFO:jax._src.xla_bridge:Unable to initialize backend 'tpu': module 'jaxlib.xla_extension' has no attribute 'get_tpu_client'
WARNING:jax._src.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
WARNING:jax._src.dispatch:Finished tracing + transforming jit(broadcast_in_dim) in 0.0002300739288330078 sec
DEBUG:jax._src.interpreters.pxla:Compiling broadcast_in_dim for with global shapes and types [ShapedArray(float32[])]. Argument mapping: (GSPMDSharding({replicated}),).
WARNING:jax._src.dispatch:Finished jaxpr to MLIR module conversion jit(broadcast_in_dim) in 0.001964092254638672 sec
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
WARNING:jax._src.dispatch:Finished XLA compilation of jit(broadcast_in_dim) in 0.012798309326171875 sec
WARNING:jax._src.dispatch:Finished tracing + transforming fn for pjit in 0.0004911422729492188 sec
DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(float32[3,4]), ShapedArray(float32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
WARNING:jax._src.dispatch:Finished jaxpr to MLIR module conversion jit(fn) in 0.0016129016876220703 sec
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
WARNING:jax._src.dispatch:Finished XLA compilation of jit(fn) in 0.0081329345703125 sec
DEBUG:iree_jax:Create new Program subclass: trivial_kernel
DEBUG:root:DEFINE PY_ONLY: _linear = <Exportable Pure Func: <function TrivialKernel._linear at 0x7f91ee93ce50>>
DEBUG:iree_jax:def_global_tree: array _params$0=(3, 4):dtype('float32')
DEBUG:iree_jax:def_global_tree: array _params$1=(3, 4):dtype('float32')
DEBUG:iree_jax:def_global_tree: new tree=Params(x=ConcreteArray(ExportedGlobalArray(@_params$0 : tensor<3x4xf32>), dtype=float32), b=ConcreteArray(ExportedGlobalArray(@_params$1 : tensor<3x4xf32>), dtype=float32))
DEBUG:iree_jax:def_global_tree: array _x$0=(3, 4):dtype('float32')
DEBUG:iree_jax:def_global_tree: new tree=ExportedGlobalArray(@_params$0 : tensor<3x4xf32>)
Traceback (most recent call last):
File "/home/woongq/iree-jax/tests/program/trivial_kernel.py", line 61, in <module>
m = TrivialKernel()
File "/home/woongq/iree-jax/iree/jax/program_api.py", line 459, in __new__
export_function()
File "/home/woongq/iree-jax/iree/jax/program_api.py", line 454, in export_function
info.export_module.def_func(invoke_with_self,
File "/home/woongq/iree-jax/iree/jax/exporter.py", line 206, in def_func
return_py_value = f(*argument_py_tree)
File "/home/woongq/iree-jax/iree/jax/program_api.py", line 452, in invoke_with_self
return func_def.callable(self, *args, **kwargs)
File "/home/woongq/iree-jax/tests/program/trivial_kernel.py", line 48, in run
result = self._linear(multiplier, self._params.x, self._params.b)
File "/home/woongq/iree-jax/iree/jax/tracing.py", line 54, in __call__
return current_ir_trace().handle_call(self, args, kwargs)
File "/home/woongq/iree-jax/iree/jax/tracing.py", line 114, in handle_call
return target.resolve_call(self, *args, **kwargs)
File "/home/woongq/iree-jax/iree/jax/builtins.py", line 60, in resolve_call
lowered = self.jit_f.lower(*abstract_args)
File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/pjit.py", line 338, in lower
donate_argnums) = infer_params_fn(*args, **kwargs)
File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/api.py", line 324, in infer_params
return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/pjit.py", line 452, in common_infer_params
in_shardings = out_shardings = _create_sharding_with_device_backend(
File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/pjit.py", line 826, in _create_sharding_with_device_backend
xb.get_backend(backend).get_default_device_assignment(1)[0])
File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 710, in get_backend
return _get_backend_uncached(platform)
File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 699, in _get_backend_uncached
raise RuntimeError(f"Unknown backend {platform}")
jax._src.traceback_util.UnfilteredStackTrace: RuntimeError: Unknown backend iree
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/woongq/iree-jax/tests/program/trivial_kernel.py", line 61, in <module>
m = TrivialKernel()
File "/home/woongq/iree-jax/iree/jax/program_api.py", line 459, in __new__
export_function()
File "/home/woongq/iree-jax/iree/jax/program_api.py", line 454, in export_function
info.export_module.def_func(invoke_with_self,
File "/home/woongq/iree-jax/iree/jax/exporter.py", line 206, in def_func
return_py_value = f(*argument_py_tree)
File "/home/woongq/iree-jax/iree/jax/program_api.py", line 452, in invoke_with_self
return func_def.callable(self, *args, **kwargs)
File "/home/woongq/iree-jax/tests/program/trivial_kernel.py", line 48, in run
result = self._linear(multiplier, self._params.x, self._params.b)
File "/home/woongq/iree-jax/iree/jax/tracing.py", line 54, in __call__
return current_ir_trace().handle_call(self, args, kwargs)
File "/home/woongq/iree-jax/iree/jax/tracing.py", line 114, in handle_call
return target.resolve_call(self, *args, **kwargs)
File "/home/woongq/iree-jax/iree/jax/builtins.py", line 60, in resolve_call
lowered = self.jit_f.lower(*abstract_args)
RuntimeError: Unknown backend iree
nanobind: leaked 66 instances!
nanobind: leaked 16 types!
- leaked type "iree._runtime.VmVariantList"
- leaked type "iree._runtime.HalBufferView"
- leaked type "iree._runtime.BufferUsage"
- leaked type "iree._runtime.VmContext"
- leaked type "iree._runtime.MappedMemory"
- leaked type "iree._runtime.ArgumentPacker"
- leaked type "iree._runtime.HalElementType"
- leaked type "iree._runtime.VmRef"
- leaked type "iree._runtime.VmModule"
- leaked type "iree._runtime.HalDevice"
- leaked type "iree._runtime._InvokeStatics"
- ... skipped remainder
nanobind: leaked 78 functions!
- leaked function ""
- leaked function "lookup_function"
- leaked function "__eq__"
- leaked function ""
- leaked function "__iree_vm_type__"
- leaked function "__or__"
- leaked function "__init__"
- leaked function "create_device_by_uri"
- leaked function ""
- leaked function "invoke"
- leaked function "__init__"
- ... skipped remainder
nanobind: this is likely caused by a reference counting issue in the binding code.
error: command failed with exit status: 1
$ "/home/woongq/jax/bin/filecheck" "/home/woongq/iree-jax/tests/program/trivial_kernel.py"
# command output:
CHECK: FileCheck error: '-' is empty.
FileCheck command line: /home/woongq/iree-jax/tests/program/trivial_kernel.py
error: command failed with exit status: 2
--
********************
FAIL: IREE_JAX :: program/fft.py (2 of 5)
******************** TEST 'IREE_JAX :: program/fft.py' FAILED ********************
Script:
--
: 'RUN: at line 15'; /home/woongq/jax/bin/python /home/woongq/iree-jax/tests/program/fft.py | /home/woongq/jax/bin/filecheck /home/woongq/iree-jax/tests/program/fft.py
--
Exit Code: 2
Command Output (stdout):
--
$ ":" "RUN: at line 15"
$ "/home/woongq/jax/bin/python" "/home/woongq/iree-jax/tests/program/fft.py"
# command stderr:
DEBUG:iree_jax:Create new Program subclass: f_f_t
DEBUG:root:DEFINE PY_ONLY: _fft = <Exportable Pure Func: <function FFT._fft at 0x7f92544a2290>>
DEBUG:jax._src.xla_bridge:Initializing backend 'cpu'
DEBUG:jax._src.xla_bridge:Backend 'cpu' initialized
DEBUG:jax._src.xla_bridge:Initializing backend 'cuda'
INFO:jax._src.xla_bridge:Unable to initialize backend 'cuda': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
DEBUG:jax._src.xla_bridge:Initializing backend 'rocm'
INFO:jax._src.xla_bridge:Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
DEBUG:jax._src.xla_bridge:Initializing backend 'tpu'
INFO:jax._src.xla_bridge:Unable to initialize backend 'tpu': module 'jaxlib.xla_extension' has no attribute 'get_tpu_client'
WARNING:jax._src.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
Traceback (most recent call last):
File "/home/woongq/iree-jax/tests/program/fft.py", line 41, in <module>
m = FFT()
File "/home/woongq/iree-jax/iree/jax/program_api.py", line 459, in __new__
export_function()
File "/home/woongq/iree-jax/iree/jax/program_api.py", line 454, in export_function
info.export_module.def_func(invoke_with_self,
File "/home/woongq/iree-jax/iree/jax/exporter.py", line 206, in def_func
return_py_value = f(*argument_py_tree)
File "/home/woongq/iree-jax/iree/jax/program_api.py", line 452, in invoke_with_self
return func_def.callable(self, *args, **kwargs)
File "/home/woongq/iree-jax/tests/program/fft.py", line 33, in fft
return self._fft(x)
File "/home/woongq/iree-jax/iree/jax/tracing.py", line 54, in __call__
return current_ir_trace().handle_call(self, args, kwargs)
File "/home/woongq/iree-jax/iree/jax/tracing.py", line 114, in handle_call
return target.resolve_call(self, *args, **kwargs)
File "/home/woongq/iree-jax/iree/jax/builtins.py", line 60, in resolve_call
lowered = self.jit_f.lower(*abstract_args)
File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/pjit.py", line 338, in lower
donate_argnums) = infer_params_fn(*args, **kwargs)
File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/api.py", line 324, in infer_params
return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/pjit.py", line 452, in common_infer_params
in_shardings = out_shardings = _create_sharding_with_device_backend(
File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/pjit.py", line 826, in _create_sharding_with_device_backend
xb.get_backend(backend).get_default_device_assignment(1)[0])
File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 710, in get_backend
return _get_backend_uncached(platform)
File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 699, in _get_backend_uncached
raise RuntimeError(f"Unknown backend {platform}")
jax._src.traceback_util.UnfilteredStackTrace: RuntimeError: Unknown backend iree
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/woongq/iree-jax/tests/program/fft.py", line 41, in <module>
m = FFT()
File "/home/woongq/iree-jax/iree/jax/program_api.py", line 459, in __new__
export_function()
File "/home/woongq/iree-jax/iree/jax/program_api.py", line 454, in export_function
info.export_module.def_func(invoke_with_self,
File "/home/woongq/iree-jax/iree/jax/exporter.py", line 206, in def_func
return_py_value = f(*argument_py_tree)
File "/home/woongq/iree-jax/iree/jax/program_api.py", line 452, in invoke_with_self
return func_def.callable(self, *args, **kwargs)
File "/home/woongq/iree-jax/tests/program/fft.py", line 33, in fft
return self._fft(x)
File "/home/woongq/iree-jax/iree/jax/tracing.py", line 54, in __call__
return current_ir_trace().handle_call(self, args, kwargs)
File "/home/woongq/iree-jax/iree/jax/tracing.py", line 114, in handle_call
return target.resolve_call(self, *args, **kwargs)
File "/home/woongq/iree-jax/iree/jax/builtins.py", line 60, in resolve_call
lowered = self.jit_f.lower(*abstract_args)
RuntimeError: Unknown backend iree
error: command failed with exit status: 1
$ "/home/woongq/jax/bin/filecheck" "/home/woongq/iree-jax/tests/program/fft.py"
# command output:
CHECK: FileCheck error: '-' is empty.
FileCheck command line: /home/woongq/iree-jax/tests/program/fft.py
error: command failed with exit status: 2
--
********************
PASS: IREE_JAX :: program/trivial_globals.py (3 of 5)
FAIL: IREE_JAX :: program/duplicate_helper.py (4 of 5)
******************** TEST 'IREE_JAX :: program/duplicate_helper.py' FAILED ********************
Script:
--
: 'RUN: at line 1'; /home/woongq/jax/bin/python /home/woongq/iree-jax/tests/program/duplicate_helper.py
--
Exit Code: 1
Command Output (stdout):
--
$ ":" "RUN: at line 1"
$ "/home/woongq/jax/bin/python" "/home/woongq/iree-jax/tests/program/duplicate_helper.py"
# command stderr:
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
Traceback (most recent call last):
File "/home/woongq/iree-jax/tests/program/duplicate_helper.py", line 67, in <module>
print(str(Program.get_mlir_module(module)))
File "/home/woongq/iree-jax/iree/jax/program_api.py", line 377, in get_mlir_module
info = Program.get_info(Program._get_instance(m))
File "/home/woongq/iree-jax/iree/jax/program_api.py", line 372, in _get_instance
m = m()
File "/home/woongq/iree-jax/iree/jax/program_api.py", line 459, in __new__
export_function()
File "/home/woongq/iree-jax/iree/jax/program_api.py", line 454, in export_function
info.export_module.def_func(invoke_with_self,
File "/home/woongq/iree-jax/iree/jax/exporter.py", line 206, in def_func
return_py_value = f(*argument_py_tree)
File "/home/woongq/iree-jax/iree/jax/program_api.py", line 452, in invoke_with_self
return func_def.callable(self, *args, **kwargs)
File "/home/woongq/iree-jax/tests/program/duplicate_helper.py", line 50, in encode
return mdl._encode(x, y)
File "/home/woongq/iree-jax/iree/jax/tracing.py", line 54, in __call__
return current_ir_trace().handle_call(self, args, kwargs)
File "/home/woongq/iree-jax/iree/jax/tracing.py", line 114, in handle_call
return target.resolve_call(self, *args, **kwargs)
File "/home/woongq/iree-jax/iree/jax/builtins.py", line 60, in resolve_call
lowered = self.jit_f.lower(*abstract_args)
File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/pjit.py", line 338, in lower
donate_argnums) = infer_params_fn(*args, **kwargs)
File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/api.py", line 324, in infer_params
return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/pjit.py", line 452, in common_infer_params
in_shardings = out_shardings = _create_sharding_with_device_backend(
File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/pjit.py", line 826, in _create_sharding_with_device_backend
xb.get_backend(backend).get_default_device_assignment(1)[0])
File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 710, in get_backend
return _get_backend_uncached(platform)
File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 699, in _get_backend_uncached
raise RuntimeError(f"Unknown backend {platform}")
jax._src.traceback_util.UnfilteredStackTrace: RuntimeError: Unknown backend iree
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/woongq/iree-jax/tests/program/duplicate_helper.py", line 67, in <module>
print(str(Program.get_mlir_module(module)))
File "/home/woongq/iree-jax/iree/jax/program_api.py", line 377, in get_mlir_module
info = Program.get_info(Program._get_instance(m))
File "/home/woongq/iree-jax/iree/jax/program_api.py", line 372, in _get_instance
m = m()
File "/home/woongq/iree-jax/iree/jax/program_api.py", line 459, in __new__
export_function()
File "/home/woongq/iree-jax/iree/jax/program_api.py", line 454, in export_function
info.export_module.def_func(invoke_with_self,
File "/home/woongq/iree-jax/iree/jax/exporter.py", line 206, in def_func
return_py_value = f(*argument_py_tree)
File "/home/woongq/iree-jax/iree/jax/program_api.py", line 452, in invoke_with_self
return func_def.callable(self, *args, **kwargs)
File "/home/woongq/iree-jax/tests/program/duplicate_helper.py", line 50, in encode
return mdl._encode(x, y)
File "/home/woongq/iree-jax/iree/jax/tracing.py", line 54, in __call__
return current_ir_trace().handle_call(self, args, kwargs)
File "/home/woongq/iree-jax/iree/jax/tracing.py", line 114, in handle_call
return target.resolve_call(self, *args, **kwargs)
File "/home/woongq/iree-jax/iree/jax/builtins.py", line 60, in resolve_call
lowered = self.jit_f.lower(*abstract_args)
RuntimeError: Unknown backend iree
error: command failed with exit status: 1
--
********************
FAIL: IREE_JAX :: program/program_api_test.py (5 of 5)
******************** TEST 'IREE_JAX :: program/program_api_test.py' FAILED ********************
Script:
--
: 'RUN: at line 1'; /home/woongq/jax/bin/python /home/woongq/iree-jax/tests/program/program_api_test.py
--
Exit Code: 1
Command Output (stdout):
--
$ ":" "RUN: at line 1"
$ "/home/woongq/jax/bin/python" "/home/woongq/iree-jax/tests/program/program_api_test.py"
# command stderr:
.DEBUG:iree_jax:Create new Program subclass: hidden
.DEBUG:iree_jax:Create new Program subclass: nullary
DEBUG:iree_jax:Create new Program subclass: unary
.DEBUG:iree_jax:Create new Program subclass: Foobar
.DEBUG:iree_jax:Create new Program subclass: error
.DEBUG:iree_jax:Create new Program subclass: error
.DEBUG:iree_jax:Create new Program subclass: error
.DEBUG:iree_jax:Create new Program subclass: error
.DEBUG:iree_jax:Create new Program subclass: global
.DEBUG:iree_jax:Create new Program subclass: my_subclass
./home/woongq/jax/lib/python3.10/site-packages/jax/_src/pjit.py:288: DeprecationWarning: backend and device argument on jit is deprecated. You can use a `jax.sharding.Mesh` context manager or device_put the arguments before passing them to `jit`. Please see https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html for more information.
warnings.warn(
DEBUG:iree_jax:Create new Program subclass: iree_jax
DEBUG:root:DEFINE PY_ONLY: _f = <Exportable Pure Func: <function ProgramApiTest.test_value_tracing_with_flax_frozen_dict.<locals>.IreeJaxProgram._f at 0x7f673b4e7760>>
DEBUG:jax._src.xla_bridge:Initializing backend 'cpu'
DEBUG:jax._src.xla_bridge:Backend 'cpu' initialized
DEBUG:jax._src.xla_bridge:Initializing backend 'cuda'
INFO:jax._src.xla_bridge:Unable to initialize backend 'cuda': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
DEBUG:jax._src.xla_bridge:Initializing backend 'rocm'
INFO:jax._src.xla_bridge:Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
DEBUG:jax._src.xla_bridge:Initializing backend 'tpu'
INFO:jax._src.xla_bridge:Unable to initialize backend 'tpu': module 'jaxlib.xla_extension' has no attribute 'get_tpu_client'
WARNING:jax._src.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
EDEBUG:iree_jax:Create new Program subclass: iree_jax
DEBUG:root:DEFINE PY_ONLY: _f = <Exportable Pure Func: <function ProgramApiTest.test_value_tracing_with_list.<locals>.IreeJaxProgram._f at 0x7f673b5384c0>>
E
======================================================================
ERROR: test_value_tracing_with_flax_frozen_dict (__main__.ProgramApiTest)
----------------------------------------------------------------------
Traceback (most recent call last):
File "/home/woongq/iree-jax/tests/program/program_api_test.py", line 163, in <module>
unittest.main()
File "/usr/lib/python3.10/unittest/main.py", line 101, in __init__
self.runTests()
File "/usr/lib/python3.10/unittest/main.py", line 271, in runTests
self.result = testRunner.run(self.test)
File "/usr/lib/python3.10/unittest/runner.py", line 184, in run
test(result)
File "/usr/lib/python3.10/unittest/suite.py", line 84, in __call__
return self.run(*args, **kwds)
File "/usr/lib/python3.10/unittest/suite.py", line 122, in run
test(result)
File "/usr/lib/python3.10/unittest/suite.py", line 84, in __call__
return self.run(*args, **kwds)
File "/usr/lib/python3.10/unittest/suite.py", line 122, in run
test(result)
File "/usr/lib/python3.10/unittest/case.py", line 650, in __call__
return self.run(*args, **kwds)
File "/usr/lib/python3.10/unittest/case.py", line 591, in run
self._callTestMethod(testMethod)
File "/usr/lib/python3.10/unittest/case.py", line 549, in _callTestMethod
method()
File "/home/woongq/iree-jax/tests/program/program_api_test.py", line 145, in test_value_tracing_with_flax_frozen_dict
IreeJaxProgram()
File "/home/woongq/iree-jax/iree/jax/program_api.py", line 459, in __new__
export_function()
File "/home/woongq/iree-jax/iree/jax/program_api.py", line 454, in export_function
info.export_module.def_func(invoke_with_self,
File "/home/woongq/iree-jax/iree/jax/exporter.py", line 206, in def_func
return_py_value = f(*argument_py_tree)
File "/home/woongq/iree-jax/iree/jax/program_api.py", line 452, in invoke_with_self
return func_def.callable(self, *args, **kwargs)
File "/home/woongq/iree-jax/tests/program/program_api_test.py", line 139, in f
return self._f(x)
File "/home/woongq/iree-jax/iree/jax/tracing.py", line 54, in __call__
return current_ir_trace().handle_call(self, args, kwargs)
File "/home/woongq/iree-jax/iree/jax/tracing.py", line 114, in handle_call
return target.resolve_call(self, *args, **kwargs)
File "/home/woongq/iree-jax/iree/jax/builtins.py", line 60, in resolve_call
lowered = self.jit_f.lower(*abstract_args)
File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/pjit.py", line 338, in lower
donate_argnums) = infer_params_fn(*args, **kwargs)
File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/api.py", line 324, in infer_params
return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/pjit.py", line 452, in common_infer_params
in_shardings = out_shardings = _create_sharding_with_device_backend(
File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/pjit.py", line 826, in _create_sharding_with_device_backend
xb.get_backend(backend).get_default_device_assignment(1)[0])
File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 710, in get_backend
return _get_backend_uncached(platform)
File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 699, in _get_backend_uncached
raise RuntimeError(f"Unknown backend {platform}")
jax._src.traceback_util.UnfilteredStackTrace: RuntimeError: Unknown backend iree
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/woongq/iree-jax/tests/program/program_api_test.py", line 145, in test_value_tracing_with_flax_frozen_dict
IreeJaxProgram()
File "/home/woongq/iree-jax/iree/jax/program_api.py", line 459, in __new__
export_function()
File "/home/woongq/iree-jax/iree/jax/program_api.py", line 454, in export_function
info.export_module.def_func(invoke_with_self,
File "/home/woongq/iree-jax/iree/jax/exporter.py", line 206, in def_func
return_py_value = f(*argument_py_tree)
File "/home/woongq/iree-jax/iree/jax/program_api.py", line 452, in invoke_with_self
return func_def.callable(self, *args, **kwargs)
File "/home/woongq/iree-jax/tests/program/program_api_test.py", line 139, in f
return self._f(x)
File "/home/woongq/iree-jax/iree/jax/tracing.py", line 54, in __call__
return current_ir_trace().handle_call(self, args, kwargs)
File "/home/woongq/iree-jax/iree/jax/tracing.py", line 114, in handle_call
return target.resolve_call(self, *args, **kwargs)
File "/home/woongq/iree-jax/iree/jax/builtins.py", line 60, in resolve_call
lowered = self.jit_f.lower(*abstract_args)
RuntimeError: Unknown backend iree
======================================================================
ERROR: test_value_tracing_with_list (__main__.ProgramApiTest)
----------------------------------------------------------------------
Traceback (most recent call last):
File "/home/woongq/iree-jax/tests/program/program_api_test.py", line 163, in <module>
unittest.main()
File "/usr/lib/python3.10/unittest/main.py", line 101, in __init__
self.runTests()
File "/usr/lib/python3.10/unittest/main.py", line 271, in runTests
self.result = testRunner.run(self.test)
File "/usr/lib/python3.10/unittest/runner.py", line 184, in run
test(result)
File "/usr/lib/python3.10/unittest/suite.py", line 84, in __call__
return self.run(*args, **kwds)
File "/usr/lib/python3.10/unittest/suite.py", line 122, in run
test(result)
File "/usr/lib/python3.10/unittest/suite.py", line 84, in __call__
return self.run(*args, **kwds)
File "/usr/lib/python3.10/unittest/suite.py", line 122, in run
test(result)
File "/usr/lib/python3.10/unittest/case.py", line 650, in __call__
return self.run(*args, **kwds)
File "/usr/lib/python3.10/unittest/case.py", line 591, in run
self._callTestMethod(testMethod)
File "/usr/lib/python3.10/unittest/case.py", line 549, in _callTestMethod
method()
File "/home/woongq/iree-jax/tests/program/program_api_test.py", line 159, in test_value_tracing_with_list
IreeJaxProgram()
File "/home/woongq/iree-jax/iree/jax/program_api.py", line 459, in __new__
export_function()
File "/home/woongq/iree-jax/iree/jax/program_api.py", line 454, in export_function
info.export_module.def_func(invoke_with_self,
File "/home/woongq/iree-jax/iree/jax/exporter.py", line 206, in def_func
return_py_value = f(*argument_py_tree)
File "/home/woongq/iree-jax/iree/jax/program_api.py", line 452, in invoke_with_self
return func_def.callable(self, *args, **kwargs)
File "/home/woongq/iree-jax/tests/program/program_api_test.py", line 153, in f
return self._f(x)
File "/home/woongq/iree-jax/iree/jax/tracing.py", line 54, in __call__
return current_ir_trace().handle_call(self, args, kwargs)
File "/home/woongq/iree-jax/iree/jax/tracing.py", line 114, in handle_call
return target.resolve_call(self, *args, **kwargs)
File "/home/woongq/iree-jax/iree/jax/builtins.py", line 60, in resolve_call
lowered = self.jit_f.lower(*abstract_args)
File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/pjit.py", line 338, in lower
donate_argnums) = infer_params_fn(*args, **kwargs)
File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/api.py", line 324, in infer_params
return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/pjit.py", line 452, in common_infer_params
in_shardings = out_shardings = _create_sharding_with_device_backend(
File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/pjit.py", line 826, in _create_sharding_with_device_backend
xb.get_backend(backend).get_default_device_assignment(1)[0])
File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 710, in get_backend
return _get_backend_uncached(platform)
File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 699, in _get_backend_uncached
raise RuntimeError(f"Unknown backend {platform}")
jax._src.traceback_util.UnfilteredStackTrace: RuntimeError: Unknown backend iree
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/woongq/iree-jax/tests/program/program_api_test.py", line 159, in test_value_tracing_with_list
IreeJaxProgram()
File "/home/woongq/iree-jax/iree/jax/program_api.py", line 459, in __new__
export_function()
File "/home/woongq/iree-jax/iree/jax/program_api.py", line 454, in export_function
info.export_module.def_func(invoke_with_self,
File "/home/woongq/iree-jax/iree/jax/exporter.py", line 206, in def_func
return_py_value = f(*argument_py_tree)
File "/home/woongq/iree-jax/iree/jax/program_api.py", line 452, in invoke_with_self
return func_def.callable(self, *args, **kwargs)
File "/home/woongq/iree-jax/tests/program/program_api_test.py", line 153, in f
return self._f(x)
File "/home/woongq/iree-jax/iree/jax/tracing.py", line 54, in __call__
return current_ir_trace().handle_call(self, args, kwargs)
File "/home/woongq/iree-jax/iree/jax/tracing.py", line 114, in handle_call
return target.resolve_call(self, *args, **kwargs)
File "/home/woongq/iree-jax/iree/jax/builtins.py", line 60, in resolve_call
lowered = self.jit_f.lower(*abstract_args)
RuntimeError: Unknown backend iree
----------------------------------------------------------------------
Ran 12 tests in 0.035s
FAILED (errors=2)
error: command failed with exit status: 1
--
********************
********************
Failed Tests (4):
IREE_JAX :: program/duplicate_helper.py
IREE_JAX :: program/fft.py
IREE_JAX :: program/program_api_test.py
IREE_JAX :: program/trivial_kernel.py
Testing Time: 0.73s
Passed: 1
Failed: 4
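In every failing test the error is raised while iree-jax lowers one of its helpers (iree/jax/builtins.py, line 60, self.jit_f.lower(*abstract_args)), at the point where JAX's xla_bridge looks up the "iree" platform. The DeprecationWarning about the backend argument on jit suggests those helpers are jitted with backend="iree", so the failure can probably be reproduced without iree-jax. The sketch below is my own reproduction under that assumption, not code from the repository; the function name lower_with_backend is hypothetical.

```python
import importlib
import warnings


def lower_with_backend(platform="iree"):
    """Lower a trivial jitted function pinned to `platform`.

    Returns "ok" on success, the error message on failure, or None if
    JAX is not installed.
    """
    try:
        jax = importlib.import_module("jax")
        jnp = importlib.import_module("jax.numpy")
    except ImportError:
        return None

    def scale(x):
        return x * 2.0

    with warnings.catch_warnings():
        # Silence the "backend and device argument on jit is deprecated"
        # warning that also appears in the log.
        warnings.simplefilter("ignore", DeprecationWarning)
        try:
            # Assumption: this mirrors how iree-jax pins its helpers to
            # "iree". Newer JAX releases may reject backend= with a TypeError.
            jitted = jax.jit(scale, backend=platform)
            jitted.lower(jax.ShapeDtypeStruct((3, 4), jnp.float32))
        except (RuntimeError, TypeError, ValueError) as exc:
            return str(exc)
    return "ok"


if __name__ == "__main__":
    print("cpu: ", lower_with_backend("cpu"))
    print("iree:", lower_with_backend("iree"))
```

If "cpu" lowers fine but "iree" reports "Unknown backend iree", the installed jax/jaxlib pair apparently has no way to provide the "iree" backend that iree-jax expects.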