From 92aade8c3dca281abbc8126a92357707dbd68151 Mon Sep 17 00:00:00 2001 From: boschmitt <7152025+boschmitt@users.noreply.github.com> Date: Thu, 31 Oct 2024 16:10:16 +0100 Subject: [PATCH] Fix logic to find real module path With python device kernel interoperability, users can write quantum kernels in C++ and bind them to python. In such cases, the common pattern is to have a C++ module that gets imported into a python module. For example, if we have a python package named `foo` to which we add C++ extensions using pybind11. The common pattern is to end up with a with a module named `_cppfoo` (or whaterver). Then, we import all of its symbols to `foo`: foo/__init__.py: from ._cppfoo import * Now, if `_cppfoo` contains a binded device kernel named `bar`, then users are able to access it using `foo.bar(...)`. This, however, is not the real path of `bar`, the real path is `foo._cppfoo.bar(..)`. Curently, binded device kernels get registered with their real path name, and thus when the python AST bridge parse another kernel that uses `foo.bar(...)`, it needs to figure it out if that is its real path or not. This commit attemps to improve the robustness of discovering this real path because as-is it fails on some simple cases. This is how it works: In Python, many objects have a module attribute, which indicates the module in which the object was defined. This should be the case for functions. Thus the idea here is to walk the provide path until we reach the function object and ask it for its `__module__`. --- python/cudaq/kernel/ast_bridge.py | 28 ++++++++++++++++------------ 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/python/cudaq/kernel/ast_bridge.py b/python/cudaq/kernel/ast_bridge.py index 6f04773c3f..caa54d3977 100644 --- a/python/cudaq/kernel/ast_bridge.py +++ b/python/cudaq/kernel/ast_bridge.py @@ -1332,18 +1332,22 @@ def visit_Call(self, node): def get_full_module_path(partial_path): parts = partial_path.split('.') - for module_name, module in sys.modules.items(): - if module_name.endswith(parts[0]): - try: - obj = module - for part in parts[1:]: - obj = getattr(obj, part) - return f"{module_name}.{'.'.join(parts[1:])}" - except AttributeError: - continue - return partial_path - - devKey = get_full_module_path(devKey) + if len(parts[0]) == 0: + return partial_path + if parts[0] in sys.modules: + module = sys.modules[parts[0]] + else: + raise ImportError(f"Module '{parts[0]}' is not imported") + obj = module + for attribute in parts[1:]: + obj = getattr(obj, attribute) + if hasattr(obj, + '__module__') and obj.__module__ != obj.__name__: + return f"{obj.__module__}" + return f"{module.__name__}" + + devKey = get_full_module_path("{}.{}".format( + devKey, node.func.attr)) if cudaq_runtime.isRegisteredDeviceModule(devKey): maybeKernelName = cudaq_runtime.checkRegisteredCppDeviceKernel( self.module, devKey + '.' + node.func.attr)