diff --git a/cuda_core/cuda/core/experimental/_linker.py b/cuda_core/cuda/core/experimental/_linker.py index eb850b5973..a3fa4b3e48 100644 --- a/cuda_core/cuda/core/experimental/_linker.py +++ b/cuda_core/cuda/core/experimental/_linker.py @@ -5,6 +5,7 @@ from __future__ import annotations import ctypes +import sys import weakref from contextlib import contextmanager from dataclasses import dataclass @@ -28,6 +29,11 @@ _nvjitlink_input_types = None # populated if nvJitLink cannot be used +def _nvjitlink_has_version_symbol(inner_nvjitlink) -> bool: + # This condition is equivalent to testing for version >= 12.3 + return bool(inner_nvjitlink._inspect_function_pointer("__nvJitLinkVersion")) + + # Note: this function is reused in the tests def _decide_nvjitlink_or_driver() -> bool: """Returns True if falling back to the cuLink* driver APIs.""" @@ -37,28 +43,36 @@ def _decide_nvjitlink_or_driver() -> bool: _driver_ver = handle_return(driver.cuDriverGetVersion()) _driver_ver = (_driver_ver // 1000, (_driver_ver % 1000) // 10) + + warn_txt_common = ( + "the driver APIs will be used instead, which do not support" + " minor version compatibility or linking LTO IRs." + " For best results, consider upgrading to a recent version of" + ) + try: - from cuda.bindings import nvjitlink as _nvjitlink - from cuda.bindings._internal import nvjitlink as inner_nvjitlink - except ImportError: - # binding is not available - _nvjitlink = None + import cuda.bindings.nvjitlink as _nvjitlink + except ModuleNotFoundError: + warn_txt = f"cuda.bindings.nvjitlink is not available, therefore {warn_txt_common} cuda-bindings." else: - if inner_nvjitlink._inspect_function_pointer("__nvJitLinkVersion") == 0: - # binding is available, but nvJitLink is not installed - _nvjitlink = None - - if _nvjitlink is None: - warn( - "nvJitLink is not installed or too old (<12.3). Therefore it is not usable " - "and the culink APIs will be used instead.", - stacklevel=3, - category=RuntimeWarning, + from cuda.bindings._internal import nvjitlink as inner_nvjitlink + + try: + if _nvjitlink_has_version_symbol(inner_nvjitlink): + return False # Use nvjitlink + except RuntimeError: + warn_detail = "not available" + else: + warn_detail = "too old (<12.3)" + warn_txt = ( + f"{'nvJitLink*.dll' if sys.platform == 'win32' else 'libnvJitLink.so*'} is {warn_detail}." + f" Therefore cuda.bindings.nvjitlink is not usable and {warn_txt_common} nvJitLink." ) - _driver = driver - return True - else: - return False + _nvjitlink = None + + warn(warn_txt, stacklevel=2, category=RuntimeWarning) + _driver = driver + return True def _lazy_init():