|
5 | 5 | from __future__ import annotations
|
6 | 6 |
|
7 | 7 | import ctypes
|
| 8 | +import importlib |
| 9 | +import sys |
8 | 10 | import weakref
|
9 | 11 | from contextlib import contextmanager
|
10 | 12 | from dataclasses import dataclass
|
@@ -37,28 +39,34 @@ def _decide_nvjitlink_or_driver() -> bool:
|
37 | 39 |
|
38 | 40 | _driver_ver = handle_return(driver.cuDriverGetVersion())
|
39 | 41 | _driver_ver = (_driver_ver // 1000, (_driver_ver % 1000) // 10)
|
| 42 | + |
| 43 | + def libname(): |
| 44 | + return "nvJitLink*.dll" if sys.platform == "win32" else "libnvJitLink.so*" |
| 45 | + |
| 46 | + therefore_not_usable = ". Therefore cuda.bindings.nvjitlink is not usable and" |
| 47 | + |
40 | 48 | try:
|
41 |
| - from cuda.bindings import nvjitlink as _nvjitlink |
42 |
| - from cuda.bindings._internal import nvjitlink as inner_nvjitlink |
43 |
| - except ImportError: |
44 |
| - # binding is not available |
45 |
| - _nvjitlink = None |
| 49 | + _nvjitlink = importlib.import_module("cuda.bindings.nvjitlink") |
| 50 | + except ModuleNotFoundError: |
| 51 | + problem = "cuda.bindings.nvjitlink is not available, therefore" |
| 52 | + except RuntimeError: |
| 53 | + problem = libname() + " is not available" + therefore_not_usable |
46 | 54 | else:
|
| 55 | + from cuda.bindings._internal import nvjitlink as inner_nvjitlink |
| 56 | + |
47 | 57 | if inner_nvjitlink._inspect_function_pointer("__nvJitLinkVersion") == 0:
|
48 |
| - # binding is available, but nvJitLink is not installed |
49 |
| - _nvjitlink = None |
50 |
| - |
51 |
| - if _nvjitlink is None: |
52 |
| - warn( |
53 |
| - "nvJitLink is not installed or too old (<12.3). Therefore it is not usable " |
54 |
| - "and the culink APIs will be used instead.", |
55 |
| - stacklevel=3, |
56 |
| - category=RuntimeWarning, |
57 |
| - ) |
58 |
| - _driver = driver |
59 |
| - return True |
60 |
| - else: |
61 |
| - return False |
| 58 | + return False # Use nvjitlink |
| 59 | + |
| 60 | + problem = libname() + " is is too old (<12.3)" + therefore_not_usable |
| 61 | + _nvjitlink = None |
| 62 | + |
| 63 | + warn( |
| 64 | + problem + " the culink APIs will be used instead.", |
| 65 | + stacklevel=2, |
| 66 | + category=RuntimeWarning, |
| 67 | + ) |
| 68 | + _driver = driver |
| 69 | + return True |
62 | 70 |
|
63 | 71 |
|
64 | 72 | def _lazy_init():
|
|
0 commit comments