Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 93 additions & 19 deletions cuda_core/build_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,13 @@
# - https://setuptools.pypa.io/en/latest/build_meta.html#dynamic-build-dependencies-and-other-build-meta-tweaks
# Specifically, there are 5 APIs required to create a proper build backend, see below.

import ctypes
import functools
import glob
import os
import pathlib
import re
import subprocess
import sys

from Cython.Build import cythonize
from setuptools import Extension
Expand All @@ -23,6 +25,88 @@
get_requires_for_build_sdist = _build_meta.get_requires_for_build_sdist


@functools.cache
def _get_cuda_paths():
CUDA_PATH = os.environ.get("CUDA_PATH", os.environ.get("CUDA_HOME", None))
if CUDA_PATH is None:
return None
Comment on lines +31 to +32
Copy link
Member

@leofang leofang Oct 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: I changed this to return None if neither CUDA_PATH nor CUDA_HOME is set, but I think that was a mistake, because we still need the paths to the headers when creating the Extension instances below. I think it is time to evaluate using pathfinder. It is already a build-time dependency anyway (cuda-bindings brought it in).

FWIW, it seems #692 is suddenly within reach!

CUDA_PATH = CUDA_PATH.split(os.pathsep)
print("CUDA paths:", CUDA_PATH, flush=True)
return CUDA_PATH


@functools.cache
def _get_cuda_version_from_cuda_h(cuda_home=None):
"""
Given CUDA_HOME, try to extract the CUDA_VERSION macro from include/cuda.h.

Example line in cuda.h:
#define CUDA_VERSION 13000

Returns the integer (e.g. 13000) or None if not found / on error.
"""
if cuda_home is None:
cuda_home = _get_cuda_paths()
if cuda_home is None:
return None
else:
cuda_home = cuda_home[0]

cuda_h = pathlib.Path(cuda_home) / "include" / "cuda.h"
if not cuda_h.is_file():
return None

try:
text = cuda_h.read_text(encoding="utf-8", errors="ignore")
except OSError:
# Permissions issue, unreadable file, etc.
return None

m = re.search(r"^\s*#define\s+CUDA_VERSION\s+(\d+)", text, re.MULTILINE)
if not m:
return None
print(f"CUDA_VERSION from {cuda_h}:", m.group(1), flush=True)
return int(m.group(1))


def _get_cuda_driver_version():
"""
Try to load ``libcuda.so`` or ``nvcuda.dll`` via standard dynamic library lookup
and call ``cuDriverGetVersion``.

Returns the integer (e.g. 13000) or None if not found / on error.
"""
CUDA_SUCCESS = 0

if sys.platform == "win32":
try:
# WinDLL => stdcall (CUDAAPI on Windows), matches CUDA Driver API.
lib = ctypes.WinDLL("nvcuda.dll")
except OSError:
return None
else:
cdll_mode = os.RTLD_NOW | os.RTLD_GLOBAL
try:
# Use system search paths only; do not provide an absolute path.
# Make symbols globally available to any dependent libraries.
lib = ctypes.CDLL("libcuda.so.1", mode=cdll_mode)
except OSError:
return None

# int cuDriverGetVersion(int* driverVersion);
cuDriverGetVersion = lib.cuDriverGetVersion
cuDriverGetVersion.restype = ctypes.c_int # CUresult
cuDriverGetVersion.argtypes = [ctypes.POINTER(ctypes.c_int)]

out = ctypes.c_int(0)
rc = cuDriverGetVersion(ctypes.byref(out))
if rc != CUDA_SUCCESS:
return None

print("CUDA_VERSION from driver:", int(out.value), flush=True)
return int(out.value)


@functools.cache
def _get_proper_cuda_bindings_major_version() -> str:
# for local development (with/without build isolation)
Expand All @@ -38,15 +122,14 @@ def _get_proper_cuda_bindings_major_version() -> str:
if cuda_major is not None:
return cuda_major

cuda_version = _get_cuda_version_from_cuda_h()
if cuda_version:
return str(cuda_version // 1000)

# also for local development
try:
out = subprocess.run("nvidia-smi", env=os.environ, capture_output=True, check=True) # noqa: S603, S607
m = re.search(r"CUDA Version:\s*([\d\.]+)", out.stdout.decode())
if m:
return m.group(1).split(".")[0]
except (FileNotFoundError, subprocess.CalledProcessError):
# the build machine has no driver installed
pass
cuda_version = _get_cuda_driver_version()
if cuda_version:
return str(cuda_version // 1000)
Comment on lines +130 to +132
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since we need the CUDA headers at build time anyway, this is no longer needed for deciding the cuda-bindings version, but it can still be a useful check in case there is a local driver installation that cannot support the package being built. We can raise a nice warning in such cases via this check.


# default fallback
return "13"
Expand Down Expand Up @@ -75,20 +158,11 @@ def strip_prefix_suffix(filename):

module_names = (strip_prefix_suffix(f) for f in ext_files)

@functools.cache
def get_cuda_paths():
CUDA_PATH = os.environ.get("CUDA_PATH", os.environ.get("CUDA_HOME", None))
if not CUDA_PATH:
raise RuntimeError("Environment variable CUDA_PATH or CUDA_HOME is not set")
CUDA_PATH = CUDA_PATH.split(os.pathsep)
print("CUDA paths:", CUDA_PATH)
return CUDA_PATH

ext_modules = tuple(
Extension(
f"cuda.core.experimental.{mod.replace(os.path.sep, '.')}",
sources=[f"cuda/core/experimental/{mod}.pyx"],
include_dirs=list(os.path.join(root, "include") for root in get_cuda_paths()),
include_dirs=list(os.path.join(root, "include") for root in _get_cuda_paths()),
language="c++",
)
for mod in module_names
Expand Down