Skip to content

Commit ee32f67

Browse files
committed
_decide_nvjitlink_or_driver(): catch RuntimeError (bug fix), use importlib + ModuleNotFoundError (more selective than ImportError) and produce specific error messages
1 parent c4bb623 commit ee32f67

File tree

1 file changed

+27
-19
lines changed

1 file changed

+27
-19
lines changed

cuda_core/cuda/core/experimental/_linker.py

Lines changed: 27 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
from __future__ import annotations
66

77
import ctypes
8+
import importlib
9+
import sys
810
import weakref
911
from contextlib import contextmanager
1012
from dataclasses import dataclass
@@ -37,28 +39,34 @@ def _decide_nvjitlink_or_driver() -> bool:
3739

3840
_driver_ver = handle_return(driver.cuDriverGetVersion())
3941
_driver_ver = (_driver_ver // 1000, (_driver_ver % 1000) // 10)
42+
43+
def libname():
44+
return "nvJitLink*.dll" if sys.platform == "win32" else "libnvJitLink.so*"
45+
46+
therefore_not_usable = ". Therefore cuda.bindings.nvjitlink is not usable and"
47+
4048
try:
41-
from cuda.bindings import nvjitlink as _nvjitlink
42-
from cuda.bindings._internal import nvjitlink as inner_nvjitlink
43-
except ImportError:
44-
# binding is not available
45-
_nvjitlink = None
49+
_nvjitlink = importlib.import_module("cuda.bindings.nvjitlink")
50+
except ModuleNotFoundError:
51+
problem = "cuda.bindings.nvjitlink is not available, therefore"
52+
except RuntimeError:
53+
problem = libname() + " is not available" + therefore_not_usable
4654
else:
55+
from cuda.bindings._internal import nvjitlink as inner_nvjitlink
56+
4757
if inner_nvjitlink._inspect_function_pointer("__nvJitLinkVersion") == 0:
48-
# binding is available, but nvJitLink is not installed
49-
_nvjitlink = None
50-
51-
if _nvjitlink is None:
52-
warn(
53-
"nvJitLink is not installed or too old (<12.3). Therefore it is not usable "
54-
"and the culink APIs will be used instead.",
55-
stacklevel=3,
56-
category=RuntimeWarning,
57-
)
58-
_driver = driver
59-
return True
60-
else:
61-
return False
58+
return False # Use nvjitlink
59+
60+
problem = libname() + " is is too old (<12.3)" + therefore_not_usable
61+
_nvjitlink = None
62+
63+
warn(
64+
problem + " the culink APIs will be used instead.",
65+
stacklevel=2,
66+
category=RuntimeWarning,
67+
)
68+
_driver = driver
69+
return True
6270

6371

6472
def _lazy_init():

0 commit comments

Comments
 (0)