Skip to content

Commit cc44883

Browse files
committed
Address comments in PR
1 parent 43b71f0 commit cc44883

File tree

3 files changed

+12
-4
lines changed

3 files changed

+12
-4
lines changed

cuda_bindings/cuda/bindings/_lib/windll.pxd

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ cdef extern from "windows.h" nogil:
1212
ctypedef unsigned long DWORD
1313
ctypedef const wchar_t *LPCWSTR
1414
ctypedef const char *LPCSTR
15+
ctypedef int BOOL
1516

1617
cdef DWORD LOAD_LIBRARY_SEARCH_SYSTEM32 = 0x00000800
1718

@@ -23,6 +24,8 @@ cdef extern from "windows.h" nogil:
2324

2425
FARPROC _GetProcAddress "GetProcAddress"(HMODULE hModule, LPCSTR lpProcName)
2526

27+
BOOL _FreeLibrary "FreeLibrary"(HMODULE hLibModule)
28+
2629
cdef inline uintptr_t LoadLibraryExW(str path, HANDLE hFile, DWORD dwFlags):
2730
cdef uintptr_t result
2831
cdef wchar_t* wpath = PyUnicode_AsWideCharString(path, NULL)
@@ -37,3 +40,6 @@ cdef inline uintptr_t LoadLibraryExW(str path, HANDLE hFile, DWORD dwFlags):
3740

3841
cdef inline FARPROC GetProcAddress(uintptr_t hModule, const char* lpProcName) nogil:
3942
return _GetProcAddress(<HMODULE>hModule, lpProcName)
43+
44+
cdef inline BOOL FreeLibrary(uintptr_t hLibModule) nogil:
45+
return _FreeLibrary(<HMODULE>hLibModule)

cuda_bindings/cuda/bindings/cyruntime.pyx.in

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1897,7 +1897,7 @@ cdef cudaError_t getLocalRuntimeVersion(int* runtimeVersion) except ?cudaErrorCa
18971897

18981898
# Load
18991899
with gil:
1900-
lib = load_nvidia_dynamic_lib("cudart")
1900+
loaded_dl = load_nvidia_dynamic_lib("cudart")
19011901
{{if 'Windows' == platform.system()}}
19021902
handle = <uintptr_t>lib._handle_uint
19031903
{{else}}
@@ -1912,14 +1912,16 @@ cdef cudaError_t getLocalRuntimeVersion(int* runtimeVersion) except ?cudaErrorCa
19121912

19131913
if __cudaRuntimeGetVersion == NULL:
19141914
with gil:
1915-
raise RuntimeError(f'Function "cudaRuntimeGetVersion" not found in {lib.abs_path}')
1915+
raise RuntimeError(f'Function "cudaRuntimeGetVersion" not found in {loaded_dl.abs_path}')
19161916

19171917
# Call
19181918
cdef cudaError_t err = cudaSuccess
19191919
err = (<cudaError_t (*)(int*) except ?cudaErrorCallRequiresNewerDriver nogil> __cudaRuntimeGetVersion)(runtimeVersion)
19201920

19211921
# Unload
1922-
{{if 'Windows' != platform.system()}}
1922+
{{if 'Windows' == platform.system()}}
1923+
windll.FreeLibrary(handle)
1924+
{{else}}
19231925
dlfcn.dlclose(handle)
19241926
{{endif}}
19251927

cuda_bindings/tests/test_cudart.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1407,7 +1407,7 @@ def test_getLocalRuntimeVersion():
14071407
try:
14081408
err, version = cudart.getLocalRuntimeVersion()
14091409
except pathfinder.DynamicLibNotFoundError:
1410-
pass
1410+
pytest.skip("cudart dynamic lib not available")
14111411
else:
14121412
assertSuccess(err)
14131413
assert version >= 12000 # CUDA 12.0

0 commit comments

Comments
 (0)