Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Search for nvdisasm in the conda environment / CUDA_HOME / default CUDA install location in addition to the path #13

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion numba_cuda/numba/cuda/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from numba.core.codegen import Codegen, CodeLibrary
from .cudadrv import devices, driver, nvvm, runtime
from numba.cuda.cudadrv.libs import get_cudalib
from numba.cuda.cuda_paths import get_cuda_paths

import os
import subprocess
Expand All @@ -24,7 +25,8 @@ def run_nvdisasm(cubin, flags):
f.write(cubin)

try:
cp = subprocess.run(['nvdisasm', *flags, fname], check=True,
nvdisasm = get_cuda_paths()['nvdisasm'].info
cp = subprocess.run([nvdisasm, *flags, fname], check=True,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE)
except FileNotFoundError as e:
Expand Down
66 changes: 65 additions & 1 deletion numba_cuda/numba/cuda/cuda_paths.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import re
import os
from collections import namedtuple
from shutil import which

from numba.core.config import IS_WIN32
from numba.misc.findlib import find_lib, find_file
Expand All @@ -10,6 +11,26 @@
_env_path_tuple = namedtuple('_env_path_tuple', ['by', 'info'])


def find_executable(name, bindirs=None):
# This should probably go into numba.misc.findlib
if bindirs is None:
return which(name) # Check the path if we've been given no directories.

if isinstance(bindirs, str):
bindirs = [bindirs,]
else:
bindirs = list(bindirs)
files = []
for bdir in bindirs:
try:
entries = os.listdir(bdir)
except FileNotFoundError:
continue
candidates = [os.path.join(bdir, ent) for ent in entries if name == ent]
files.extend([c for c in candidates if os.path.isfile(c)])
return files


def _find_valid_path(options):
"""Find valid path from *options*, which is a list of 2-tuple of
(name, path). Return first pair where *path* is not None.
Expand Down Expand Up @@ -52,6 +73,18 @@ def _get_nvvm_path_decision():
return by, path


def _get_nvdisasm_path_decision():
options = [
('Conda environment', get_conda_ctk()),
('Conda environment (NVIDIA package)', get_nvidia_nvdisasm_ctk()),
('CUDA_HOME', get_cuda_home('bin')),
('System', get_system_ctk('bin')),
('Path', None),
]
by, path = _find_valid_path(options)
return by, path


def _get_libdevice_paths():
by, libdir = _get_libdevice_path_decision()
# Search for pattern
Expand Down Expand Up @@ -161,6 +194,29 @@ def get_nvidia_nvvm_ctk():
return os.path.dirname(max(paths))


def get_nvidia_nvdisasm_ctk():
"""Return path to directory containing the nvdisasm executable.
"""
is_conda_env = os.path.exists(os.path.join(sys.prefix, 'conda-meta'))
if not is_conda_env:
return

# Assume the existence of nvdisasm in the conda env implies that a CUDA
# toolkit conda package is installed.

# Try the location used on Linux and the Windows 11.x packages
bindir = os.path.join(sys.prefix, 'bin')
if not os.path.exists(bindir) or not os.path.isdir(bindir):
return

paths = find_executable('nvdisasm', bindir)
if not paths:
return

# Use the directory name of the max path
return os.path.dirname(max(paths))


def get_nvidia_libdevice_ctk():
"""Return path to directory containing the libdevice library.
"""
Expand Down Expand Up @@ -220,6 +276,13 @@ def _get_nvvm_path():
return _env_path_tuple(by, path)


def _get_nvdisasm_path():
by, path = _get_nvdisasm_path_decision()
candidates = find_executable('nvdisasm', path)
path = max(candidates) if candidates else None
return _env_path_tuple(by, path)


def get_cuda_paths():
"""Returns a dictionary mapping component names to a 2-tuple
of (source_variable, info).
Expand All @@ -238,9 +301,10 @@ def get_cuda_paths():
# Not in cache
d = {
'nvvm': _get_nvvm_path(),
'nvdisasm': _get_nvdisasm_path(),
'libdevice': _get_libdevice_paths(),
'cudalib_dir': _get_cudalib_dir(),
'static_cudalib_dir': _get_static_cudalib_dir(),
'static_cudalib_dir': _get_static_cudalib_dir()
}
# Cache result
get_cuda_paths._cached_result = d
Expand Down
8 changes: 4 additions & 4 deletions numba_cuda/numba/cuda/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,15 +247,15 @@ def inspect_sass_cfg(self):
'''
Returns the CFG of the SASS for this kernel.

Requires nvdisasm to be available on the PATH.
Requires nvdisasm to be available.
'''
return self._codelibrary.get_sass_cfg()

def inspect_sass(self):
'''
Returns the SASS code for this kernel.

Requires nvdisasm to be available on the PATH.
Requires nvdisasm to be available.
'''
return self._codelibrary.get_sass()

Expand Down Expand Up @@ -995,7 +995,7 @@ def inspect_sass_cfg(self, signature=None):

The CFG for the device in the current context is returned.

Requires nvdisasm to be available on the PATH.
Requires nvdisasm to be available.
'''
if self.targetoptions.get('device'):
raise RuntimeError('Cannot get the CFG of a device function')
Expand All @@ -1017,7 +1017,7 @@ def inspect_sass(self, signature=None):

SASS for the device in the current context is returned.

Requires nvdisasm to be available on the PATH.
Requires nvdisasm to be available.
'''
if self.targetoptions.get('device'):
raise RuntimeError('Cannot inspect SASS of a device function')
Expand Down
6 changes: 3 additions & 3 deletions numba_cuda/numba/cuda/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import shutil

from numba.tests.support import SerialMixin
from numba.cuda.cuda_paths import get_conda_ctk
from numba.cuda.cuda_paths import get_cuda_paths, get_conda_ctk
from numba.cuda.cudadrv import driver, devices, libs
from numba.core import config
from numba.tests.support import TestCase
Expand Down Expand Up @@ -97,12 +97,12 @@ def skip_under_cuda_memcheck(reason):


def skip_without_nvdisasm(reason):
nvdisasm_path = shutil.which('nvdisasm')
nvdisasm_path = get_cuda_paths()['nvdisasm'].info
return unittest.skipIf(nvdisasm_path is None, reason)


def skip_with_nvdisasm(reason):
nvdisasm_path = shutil.which('nvdisasm')
nvdisasm_path = get_cuda_paths()['nvdisasm'].info
return unittest.skipIf(nvdisasm_path is not None, reason)


Expand Down
Loading