diff --git a/python/cudaq/kernel/ast_bridge.py b/python/cudaq/kernel/ast_bridge.py index 6f04773c3f..30bfb84bcb 100644 --- a/python/cudaq/kernel/ast_bridge.py +++ b/python/cudaq/kernel/ast_bridge.py @@ -9,6 +9,7 @@ import hashlib import graphlib import sys, os +from types import ModuleType from typing import Callable from collections import deque import numpy as np @@ -1296,60 +1297,65 @@ def visit_Call(self, node): """ global globalRegisteredOperations - if self.verbose: - print("[Visit Call] {}".format( - ast.unparse(node) if hasattr(ast, 'unparse') else node)) + #if self.verbose: + print("[Visit Call] {}".format( + ast.unparse(node) if hasattr(ast, 'unparse') else node)) self.currentNode = node - # do not walk the FunctionDef decorator_list arguments if isinstance(node.func, ast.Attribute): - if hasattr( - node.func.value, 'id' - ) and node.func.value.id == 'cudaq' and node.func.attr == 'kernel': - return + # When `node.func` is an attribute, then we have the case where the + # call has the following form: `..<...>.`. + value = node.func.value + + # First, we walk all the components until we reach a name. + components = [node.func.attr] + while isinstance(value, ast.Attribute): + components.append(value.attr) + value = value.value + components.append(value.id) + components = components[::-1] - # If we have a `func = ast.Attribute``, then it could be that - # we have a previously defined kernel function call with manually specified module names - # e.g. `cudaq.lib.test.hello.fermionic_swap``. In this case, we assume - # FindDepKernels has found something like this, loaded it, and now we just - # want to get the function name and call it. + # Check whether this is our knonw decorator `@cudaq.kernel`. If it + # is then we gracefully ignore it. + if components[0] == 'cudaq' and components[1] == 'kernel': + return - # First let's check for registered C++ kernels - cppDevModNames = [] - value = node.func.value - if isinstance(value, ast.Name) and value.id != 'cudaq': - cppDevModNames = [node.func.attr, value.id] - else: - while isinstance(value, ast.Attribute): - cppDevModNames.append(value.attr) - value = value.value - if isinstance(value, ast.Name): - cppDevModNames.append(value.id) - break - - devKey = '.'.join(cppDevModNames[::-1]) - - def get_full_module_path(partial_path): - parts = partial_path.split('.') - for module_name, module in sys.modules.items(): - if module_name.endswith(parts[0]): - try: - obj = module - for part in parts[1:]: - obj = getattr(obj, part) - return f"{module_name}.{'.'.join(parts[1:])}" - except AttributeError: - continue - return partial_path - - devKey = get_full_module_path(devKey) - if cudaq_runtime.isRegisteredDeviceModule(devKey): + # Get full module path. + # + # Note: Here we skip anything that starts with `cudaq.` because not + # all constructs are backed by an python object. See issue # + mod_path = "" + if components[0] != 'cudaq': + if components[0] in sys.modules: + module = sys.modules[components[0]] + obj = module + for attribute in components[1:]: + obj = getattr(obj, attribute) + if hasattr(obj, '__module__') and obj.__module__ != obj.__name__: + mod_path = obj.__module__ + else: + mod_path = obj.__name__ + else: + import inspect + current_frame = inspect.currentframe() + mod = None + while current_frame is not None: + local_vars = current_frame.f_locals + if components[0] in local_vars: + mod = local_vars[components[0]] + break; + current_frame = current_frame.f_back + + if isinstance(mod, ModuleType): + mod_path = mod.__name__ + + if cudaq_runtime.isRegisteredDeviceModule(mod_path): maybeKernelName = cudaq_runtime.checkRegisteredCppDeviceKernel( - self.module, devKey + '.' + node.func.attr) + self.module, mod_path + '.' + node.func.attr) if maybeKernelName == None: maybeKernelName = cudaq_runtime.checkRegisteredCppDeviceKernel( - self.module, devKey) + self.module, mod_path) if maybeKernelName != None: otherKernel = SymbolTable( self.module.operation)[maybeKernelName] @@ -1368,7 +1374,6 @@ def get_full_module_path(partial_path): func.CallOp(otherKernel, values) return - # Start by seeing if we have mod1.mod2.mod3... moduleNames = [] value = node.func.value while isinstance(value, ast.Attribute): diff --git a/python/tests/interop/qlib.py b/python/tests/interop/qlib.py new file mode 100644 index 0000000000..cea8b1861a --- /dev/null +++ b/python/tests/interop/qlib.py @@ -0,0 +1,10 @@ +# ============================================================================ # +# Copyright (c) 2022 - 2024 NVIDIA Corporation & Affiliates. # +# All rights reserved. # +# # +# This source code and the accompanying materials are made available under # +# the terms of the Apache License 2.0 which accompanies this distribution. # +# ============================================================================ # + +from cudaq_test_cpp_algo import * + diff --git a/python/tests/interop/test_interop.py b/python/tests/interop/test_interop.py index 78d46e576e..0d52d8e4e7 100644 --- a/python/tests/interop/test_interop.py +++ b/python/tests/interop/test_interop.py @@ -222,6 +222,27 @@ def callUCCSD(): callUCCSD() +def test_cpp_kernel_from_python_3(): + + import qlib + + # Sanity checks + print(qlib.qstd.qft) + print(qlib.qstd.another) + + @cudaq.kernel + def callQftAndAnother(): + q = cudaq.qvector(4) + qlib.qstd.qft(q) + h(q) + qlib.qstd.another(q, 2) + + callQftAndAnother() + + counts = cudaq.sample(callQftAndAnother) + counts.dump() + assert len(counts) == 1 and '0010' in counts + def test_capture(): @cudaq.kernel def takesCapture(s : int): @@ -232,4 +253,4 @@ def takesCapture(s : int): @cudaq.kernel(verbose=True) def entry(): takesCapture(spin) - entry.compile() \ No newline at end of file + entry.compile()