Skip to content

Commit d12cc2c

Browse files
committed
Add LIBNAMES_REQUIRING_RTLD_DEEPBIND feature (for cufftMp)
1 parent 37b8822 commit d12cc2c

File tree

2 files changed

+18
-5
lines changed

2 files changed

+18
-5
lines changed

cuda_pathfinder/cuda/pathfinder/_dynamic_libs/load_dl_linux.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,10 @@
88
from typing import Optional, cast
99

1010
from cuda.pathfinder._dynamic_libs.load_dl_common import LoadedDL
11-
from cuda.pathfinder._dynamic_libs.supported_nvidia_libs import SUPPORTED_LINUX_SONAMES
11+
from cuda.pathfinder._dynamic_libs.supported_nvidia_libs import (
12+
LIBNAMES_REQUIRING_RTLD_DEEPBIND,
13+
SUPPORTED_LINUX_SONAMES,
14+
)
1215

1316
CDLL_MODE = os.RTLD_NOW | os.RTLD_GLOBAL
1417

@@ -138,6 +141,13 @@ def check_if_already_loaded_from_elsewhere(libname: str, _have_abs_path: bool) -
138141
return None
139142

140143

144+
def _load_lib(libname: str, filename: str) -> ctypes.CDLL:
145+
cdll_mode = CDLL_MODE
146+
if libname in LIBNAMES_REQUIRING_RTLD_DEEPBIND:
147+
cdll_mode |= os.RTLD_DEEPBIND
148+
return ctypes.CDLL(filename, cdll_mode)
149+
150+
141151
def load_with_system_search(libname: str) -> Optional[LoadedDL]:
142152
"""Try to load a library using system search paths.
143153
@@ -152,13 +162,14 @@ def load_with_system_search(libname: str) -> Optional[LoadedDL]:
152162
"""
153163
for soname in get_candidate_sonames(libname):
154164
try:
155-
handle = ctypes.CDLL(soname, CDLL_MODE)
165+
handle = _load_lib(libname, soname)
166+
except OSError:
167+
pass
168+
else:
156169
abs_path = abs_path_for_dynamic_library(libname, handle)
157170
if abs_path is None:
158171
raise RuntimeError(f"No expected symbol for {libname=!r}")
159172
return LoadedDL(abs_path, False, handle._handle)
160-
except OSError:
161-
pass
162173
return None
163174

164175

@@ -196,7 +207,7 @@ def load_with_abs_path(libname: str, found_path: str) -> LoadedDL:
196207
"""
197208
_work_around_known_bugs(libname, found_path)
198209
try:
199-
handle = ctypes.CDLL(found_path, CDLL_MODE)
210+
handle = _load_lib(libname, found_path)
200211
except OSError as e:
201212
raise RuntimeError(f"Failed to dlopen {found_path}: {e}") from e
202213
return LoadedDL(found_path, False, handle._handle)

cuda_pathfinder/cuda/pathfinder/_dynamic_libs/supported_nvidia_libs.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -394,6 +394,8 @@
394394
"nvrtc",
395395
)
396396

397+
LIBNAMES_REQUIRING_RTLD_DEEPBIND = ("cufftMp",)
398+
397399

398400
def is_suppressed_dll_file(path_basename: str) -> bool:
399401
if path_basename.startswith("nvrtc"):

0 commit comments

Comments
 (0)