Skip to content

Commit 7abcd36

Browse files
committed
Search for nvdisasm in the conda environment, CUDA_HOME, and the default CUDA
install location in addition to the PATH. Fixes NVIDIA#9.
1 parent d7be1f0 commit 7abcd36

File tree

4 files changed

+75
-9
lines changed

4 files changed

+75
-9
lines changed

numba_cuda/numba/cuda/codegen.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from numba.core.codegen import Codegen, CodeLibrary
55
from .cudadrv import devices, driver, nvvm, runtime
66
from numba.cuda.cudadrv.libs import get_cudalib
7+
from numba.cuda.cuda_paths import get_cuda_paths
78

89
import os
910
import subprocess
@@ -24,7 +25,8 @@ def run_nvdisasm(cubin, flags):
2425
f.write(cubin)
2526

2627
try:
27-
cp = subprocess.run(['nvdisasm', *flags, fname], check=True,
28+
nvdisasm = get_cuda_paths()['nvdisasm'].info
29+
cp = subprocess.run([nvdisasm, *flags, fname], check=True,
2830
stdout=subprocess.PIPE,
2931
stderr=subprocess.PIPE)
3032
except FileNotFoundError as e:

numba_cuda/numba/cuda/cuda_paths.py

Lines changed: 65 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import re
33
import os
44
from collections import namedtuple
5+
from shutil import which
56

67
from numba.core.config import IS_WIN32
78
from numba.misc.findlib import find_lib, find_file
@@ -10,6 +11,26 @@
1011
_env_path_tuple = namedtuple('_env_path_tuple', ['by', 'info'])
1112

1213

14+
def find_executable(name, bindirs=None):
    """Locate the executable file called *name*.

    *name* is matched as an exact file name; no extension is appended.
    NOTE(review): on Windows the tool is presumably ``<name>.exe``; confirm
    what callers pass.

    *bindirs* is a single directory (``str`` or ``os.PathLike``), an
    iterable of directories, or ``None``.

    Returns:

    - When *bindirs* is ``None``: whatever ``shutil.which(name)`` returns,
      i.e. one path string, or ``None`` if *name* is not on the PATH.
    - Otherwise: a list holding ``os.path.join(bdir, name)`` for every
      *bdir* in which that path is a regular file, in search order. The
      list is empty when there is no match.

    Callers have to cope with both shapes.
    """
    # This should probably go into numba.misc.findlib
    if bindirs is None:
        # Check the path if we've been given no directories.
        return which(name)

    # A lone directory is searched as a one-element list; iterating a str
    # would otherwise walk it character by character.
    if isinstance(bindirs, (str, os.PathLike)):
        bindirs = [bindirs]

    # os.path.isfile() reports False instead of raising when a search
    # directory is missing, is not a directory, or cannot be accessed, so
    # such entries are skipped.
    candidates = (os.path.join(bdir, name) for bdir in bindirs)
    return [path for path in candidates if os.path.isfile(path)]
32+
33+
1334
def _find_valid_path(options):
1435
"""Find valid path from *options*, which is a list of 2-tuple of
1536
(name, path). Return first pair where *path* is not None.
@@ -52,6 +73,18 @@ def _get_nvvm_path_decision():
5273
return by, path
5374

5475

76+
def _get_nvdisasm_path_decision():
    """Pick the directory in which to look for nvdisasm.

    Candidate locations are gathered in priority order and handed to
    ``_find_valid_path``; the ``(by, path)`` pair it selects is returned,
    where *by* names the source and *path* is the directory.
    """
    conda_dir = get_conda_ctk()
    nvidia_conda_dir = get_nvidia_nvdisasm_ctk()
    cuda_home_dir = get_cuda_home('bin')
    system_dir = get_system_ctk('bin')

    # The trailing 'Path' entry carries no directory of its own.
    # NOTE(review): what _find_valid_path returns when every entry is None
    # is not visible here - confirm.
    search_order = [
        ('Conda environment', conda_dir),
        ('Conda environment (NVIDIA package)', nvidia_conda_dir),
        ('CUDA_HOME', cuda_home_dir),
        ('System', system_dir),
        ('Path', None),
    ]
    source, location = _find_valid_path(search_order)
    return source, location
86+
87+
5588
def _get_libdevice_paths():
5689
by, libdir = _get_libdevice_path_decision()
5790
# Search for pattern
@@ -161,6 +194,29 @@ def get_nvidia_nvvm_ctk():
161194
return os.path.dirname(max(paths))
162195

163196

197+
def get_nvidia_nvdisasm_ctk():
    """Return path to directory containing the nvdisasm executable.

    Only a conda environment is considered, detected by a ``conda-meta``
    entry under ``sys.prefix``. Returns ``None`` when not in one, when
    ``<sys.prefix>/bin`` is not a directory, or when no nvdisasm is found
    there.
    """
    conda_meta = os.path.join(sys.prefix, 'conda-meta')
    if not os.path.exists(conda_meta):
        return None

    # Finding nvdisasm inside the conda env is taken to mean that a CUDA
    # toolkit conda package is installed.

    # Location used on Linux and by the Windows 11.x packages.
    bindir = os.path.join(sys.prefix, 'bin')
    if not os.path.isdir(bindir):
        return None

    found = find_executable('nvdisasm', bindir)
    if found:
        # Report the directory of the max path.
        return os.path.dirname(max(found))
    return None
218+
219+
164220
def get_nvidia_libdevice_ctk():
165221
"""Return path to directory containing the libdevice library.
166222
"""
@@ -220,6 +276,13 @@ def _get_nvvm_path():
220276
return _env_path_tuple(by, path)
221277

222278

279+
def _get_nvdisasm_path():
    """Return an ``_env_path_tuple`` describing where nvdisasm was found.

    ``by`` is the source reported by ``_get_nvdisasm_path_decision`` and
    ``info`` is the full path to the nvdisasm executable, or ``None`` when
    it could not be located.
    """
    by, bindir = _get_nvdisasm_path_decision()
    found = find_executable('nvdisasm', bindir)

    # With no directory to search, find_executable() falls back to a PATH
    # lookup and hands back a single path string (or None) rather than a
    # list. Wrap the string so that max() below compares whole paths; on a
    # bare string it would return its largest character.
    if isinstance(found, str):
        found = [found]

    path = max(found) if found else None
    return _env_path_tuple(by, path)
284+
285+
223286
def get_cuda_paths():
224287
"""Returns a dictionary mapping component names to a 2-tuple
225288
of (source_variable, info).
@@ -238,9 +301,10 @@ def get_cuda_paths():
238301
# Not in cache
239302
d = {
240303
'nvvm': _get_nvvm_path(),
304+
'nvdisasm': _get_nvdisasm_path(),
241305
'libdevice': _get_libdevice_paths(),
242306
'cudalib_dir': _get_cudalib_dir(),
243-
'static_cudalib_dir': _get_static_cudalib_dir(),
307+
'static_cudalib_dir': _get_static_cudalib_dir()
244308
}
245309
# Cache result
246310
get_cuda_paths._cached_result = d

numba_cuda/numba/cuda/dispatcher.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -247,15 +247,15 @@ def inspect_sass_cfg(self):
247247
'''
248248
Returns the CFG of the SASS for this kernel.
249249
250-
Requires nvdisasm to be available on the PATH.
250+
Requires nvdisasm to be available.
251251
'''
252252
return self._codelibrary.get_sass_cfg()
253253

254254
def inspect_sass(self):
255255
'''
256256
Returns the SASS code for this kernel.
257257
258-
Requires nvdisasm to be available on the PATH.
258+
Requires nvdisasm to be available.
259259
'''
260260
return self._codelibrary.get_sass()
261261

@@ -995,7 +995,7 @@ def inspect_sass_cfg(self, signature=None):
995995
996996
The CFG for the device in the current context is returned.
997997
998-
Requires nvdisasm to be available on the PATH.
998+
Requires nvdisasm to be available.
999999
'''
10001000
if self.targetoptions.get('device'):
10011001
raise RuntimeError('Cannot get the CFG of a device function')
@@ -1017,7 +1017,7 @@ def inspect_sass(self, signature=None):
10171017
10181018
SASS for the device in the current context is returned.
10191019
1020-
Requires nvdisasm to be available on the PATH.
1020+
Requires nvdisasm to be available.
10211021
'''
10221022
if self.targetoptions.get('device'):
10231023
raise RuntimeError('Cannot inspect SASS of a device function')

numba_cuda/numba/cuda/testing.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import shutil
44

55
from numba.tests.support import SerialMixin
6-
from numba.cuda.cuda_paths import get_conda_ctk
6+
from numba.cuda.cuda_paths import get_cuda_paths, get_conda_ctk
77
from numba.cuda.cudadrv import driver, devices, libs
88
from numba.core import config
99
from numba.tests.support import TestCase
@@ -97,12 +97,12 @@ def skip_under_cuda_memcheck(reason):
9797

9898

9999
def skip_without_nvdisasm(reason):
100-
nvdisasm_path = shutil.which('nvdisasm')
100+
nvdisasm_path = get_cuda_paths()['nvdisasm'].info
101101
return unittest.skipIf(nvdisasm_path is None, reason)
102102

103103

104104
def skip_with_nvdisasm(reason):
105-
nvdisasm_path = shutil.which('nvdisasm')
105+
nvdisasm_path = get_cuda_paths()['nvdisasm'].info
106106
return unittest.skipIf(nvdisasm_path is not None, reason)
107107

108108

0 commit comments

Comments
 (0)