Hello, I installed iree-jax and tried the example shown in the README. The error is shown below:
$ git clone https://github.com/google/iree-jax.git
$ cd iree-jax
$ python -m pip install -e .[test,xla,cpu] -f https://github.com/google/iree/releases
$ python --version
Python 3.9.7
$ vi tiny.py
Copy the example from the README and add the following two import statements:
import jax.numpy as jnp
from collections import namedtuple
$ python tiny.py
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
Traceback (most recent call last):
  File "/home/u89062/iree-jax/examples/tiny.py", line 85, in <module>
    m = TrivialKernel()
  File "/home/u89062/iree-jax/iree/jax/program_api.py", line 559, in __new__
    export_function()
  File "/home/u89062/iree-jax/iree/jax/program_api.py", line 554, in export_function
    info.export_module.def_func(invoke_with_self,
  File "/home/u89062/iree-jax/iree/jax/exporter.py", line 208, in def_func
    return_py_value = f(*argument_py_tree)
  File "/home/u89062/iree-jax/iree/jax/program_api.py", line 552, in invoke_with_self
    return func_def.callable(self, *args, **kwargs)
  File "/home/u89062/iree-jax/examples/tiny.py", line 61, in run
    result = self._linear(multiplier, self._params.x, self._params.b)
  File "/home/u89062/iree-jax/iree/jax/tracing.py", line 55, in __call__
    return current_ir_trace().handle_call(self, args, kwargs)
  File "/home/u89062/iree-jax/iree/jax/tracing.py", line 115, in handle_call
    return target.resolve_call(self, *args, **kwargs)
  File "/home/u89062/iree-jax/iree/jax/builtins.py", line 64, in resolve_call
    imported_main_symbol_name = jax_utils.import_main_function(
  File "/home/u89062/iree-jax/iree/jax/jax_utils.py", line 116, in import_main_function
    source_module = import_module(context, source_module)
  File "/home/u89062/iree-jax/iree/jax/jax_utils.py", line 95, in import_module
    raise ValueError(
ValueError: Attempted to import a non-module (did you enable MLIR in JAX?). Got module @jit__linear.2 {
  func public @main(%arg0: tensor<3x4xf32>, %arg1: tensor<3x4xf32>, %arg2: tensor<3x4xf32>) -> tensor<3x4xf32> {
    %0 = mhlo.multiply %arg0, %arg1 : tensor<3x4xf32>
    %1 = mhlo.add %0, %arg2 : tensor<3x4xf32>
    return %1 : tensor<3x4xf32>
  }
}
Thanks!