ValueError: Attempted to import a non-module #5

Open
ronghongbo opened this issue Feb 9, 2022 · 0 comments

Hello, I installed iree-jax and tried the example from the README, but it fails with the error shown below:

$ git clone https://github.com/google/iree-jax.git
$ cd iree-jax
$ python -m pip install -e .[test,xla,cpu] -f https://github.com/google/iree/releases
$ python --version
Python 3.9.7
$ vi tiny.py
Copied the example from the README and added the following two statements:
     import jax.numpy as jnp
     from collections import namedtuple
$ python tiny.py
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
Traceback (most recent call last):
  File "/home/u89062/iree-jax/examples/tiny.py", line 85, in <module>
    m = TrivialKernel()
  File "/home/u89062/iree-jax/iree/jax/program_api.py", line 559, in __new__
    export_function()
  File "/home/u89062/iree-jax/iree/jax/program_api.py", line 554, in export_function
    info.export_module.def_func(invoke_with_self,
  File "/home/u89062/iree-jax/iree/jax/exporter.py", line 208, in def_func
    return_py_value = f(*argument_py_tree)
  File "/home/u89062/iree-jax/iree/jax/program_api.py", line 552, in invoke_with_self
    return func_def.callable(self, *args, **kwargs)
  File "/home/u89062/iree-jax/examples/tiny.py", line 61, in run
    result = self._linear(multiplier, self._params.x, self._params.b)
  File "/home/u89062/iree-jax/iree/jax/tracing.py", line 55, in __call__
    return current_ir_trace().handle_call(self, args, kwargs)
  File "/home/u89062/iree-jax/iree/jax/tracing.py", line 115, in handle_call
    return target.resolve_call(self, *args, **kwargs)
  File "/home/u89062/iree-jax/iree/jax/builtins.py", line 64, in resolve_call
    imported_main_symbol_name = jax_utils.import_main_function(
  File "/home/u89062/iree-jax/iree/jax/jax_utils.py", line 116, in import_main_function
    source_module = import_module(context, source_module)
  File "/home/u89062/iree-jax/iree/jax/jax_utils.py", line 95, in import_module
    raise ValueError(
ValueError: Attempted to import a non-module (did you enable MLIR in JAX?). Got module @jit__linear.2 {
  func public @main(%arg0: tensor<3x4xf32>, %arg1: tensor<3x4xf32>, %arg2: tensor<3x4xf32>) -> tensor<3x4xf32> {
    %0 = mhlo.multiply %arg0, %arg1 : tensor<3x4xf32>
    %1 = mhlo.add %0, %arg2 : tensor<3x4xf32>
    return %1 : tensor<3x4xf32>
  }
}
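The error message asks whether MLIR is enabled in JAX, yet the object it prints is already MLIR text. One plausible cause is that the installed jax/jaxlib and IREE compiler packages come from mismatched builds, so the module JAX returns is not the type iree-jax expects. A minimal sketch for reporting the relevant versions follows. The distribution names in `CANDIDATE_PACKAGES` are assumptions and may differ from what `pip list` shows; `installed_versions` is a hypothetical helper, not part of iree-jax.

```python
# Sketch: report versions of packages that may be involved in this error.
# The distribution names below are assumptions; adjust them to match
# what `pip list` shows in your environment.
import sys
from importlib.metadata import PackageNotFoundError, version

CANDIDATE_PACKAGES = ["jax", "jaxlib", "iree-compiler", "iree-runtime"]


def installed_versions(names):
    """Map each distribution name to its installed version, or None if absent."""
    result = {}
    for name in names:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None
    return result


if __name__ == "__main__":
    print("python", sys.version.split()[0])
    for name, ver in installed_versions(CANDIDATE_PACKAGES).items():
        print(name, ver if ver is not None else "(not installed)")
```

Missing packages are reported rather than raising, so the script still runs when a guessed distribution name is wrong.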

Thanks!
