Skip to content

Commit acfe654

Browse files
authored
Fix #938: Call win32 APIs directly (#942)
* Fix #938: Call win32 APIs directly * Address comments from PR * Address comments from PR * Remove APIs * Don't check return type * Address comments in PR
1 parent 6daacba commit acfe654

File tree

10 files changed

+1405
-3106
lines changed

10 files changed

+1405
-3106
lines changed

cuda_bindings/cuda/bindings/_bindings/cydriver.pyx.in

Lines changed: 1115 additions & 2783 deletions
Large diffs are not rendered by default.

cuda_bindings/cuda/bindings/_bindings/cynvrtc.pyx.in

Lines changed: 50 additions & 123 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,12 @@
44
# This code was automatically generated with version 13.0.0. Do not modify it directly.
55
{{if 'Windows' == platform.system()}}
66
import os
7-
import win32api
7+
cimport cuda.bindings._lib.windll as windll
88
{{else}}
99
cimport cuda.bindings._lib.dlfcn as dlfcn
10-
from libc.stdint cimport uintptr_t
1110
{{endif}}
1211
from cuda.pathfinder import load_nvidia_dynamic_lib
13-
from libc.stdint cimport intptr_t
12+
from libc.stdint cimport intptr_t, uintptr_t
1413
import threading
1514

1615
cdef object __symbol_lock = threading.Lock()
@@ -50,172 +49,100 @@ cdef int _cuPythonInit() except -1 nogil:
5049

5150
# Load function
5251
{{if 'nvrtcGetErrorString' in found_functions}}
53-
try:
54-
global __nvrtcGetErrorString
55-
__nvrtcGetErrorString = <void*><unsigned long long>win32api.GetProcAddress(handle, 'nvrtcGetErrorString')
56-
except:
57-
pass
52+
global __nvrtcGetErrorString
53+
__nvrtcGetErrorString = windll.GetProcAddress(handle, 'nvrtcGetErrorString')
5854
{{endif}}
5955
{{if 'nvrtcVersion' in found_functions}}
60-
try:
61-
global __nvrtcVersion
62-
__nvrtcVersion = <void*><unsigned long long>win32api.GetProcAddress(handle, 'nvrtcVersion')
63-
except:
64-
pass
56+
global __nvrtcVersion
57+
__nvrtcVersion = windll.GetProcAddress(handle, 'nvrtcVersion')
6558
{{endif}}
6659
{{if 'nvrtcGetNumSupportedArchs' in found_functions}}
67-
try:
68-
global __nvrtcGetNumSupportedArchs
69-
__nvrtcGetNumSupportedArchs = <void*><unsigned long long>win32api.GetProcAddress(handle, 'nvrtcGetNumSupportedArchs')
70-
except:
71-
pass
60+
global __nvrtcGetNumSupportedArchs
61+
__nvrtcGetNumSupportedArchs = windll.GetProcAddress(handle, 'nvrtcGetNumSupportedArchs')
7262
{{endif}}
7363
{{if 'nvrtcGetSupportedArchs' in found_functions}}
74-
try:
75-
global __nvrtcGetSupportedArchs
76-
__nvrtcGetSupportedArchs = <void*><unsigned long long>win32api.GetProcAddress(handle, 'nvrtcGetSupportedArchs')
77-
except:
78-
pass
64+
global __nvrtcGetSupportedArchs
65+
__nvrtcGetSupportedArchs = windll.GetProcAddress(handle, 'nvrtcGetSupportedArchs')
7966
{{endif}}
8067
{{if 'nvrtcCreateProgram' in found_functions}}
81-
try:
82-
global __nvrtcCreateProgram
83-
__nvrtcCreateProgram = <void*><unsigned long long>win32api.GetProcAddress(handle, 'nvrtcCreateProgram')
84-
except:
85-
pass
68+
global __nvrtcCreateProgram
69+
__nvrtcCreateProgram = windll.GetProcAddress(handle, 'nvrtcCreateProgram')
8670
{{endif}}
8771
{{if 'nvrtcDestroyProgram' in found_functions}}
88-
try:
89-
global __nvrtcDestroyProgram
90-
__nvrtcDestroyProgram = <void*><unsigned long long>win32api.GetProcAddress(handle, 'nvrtcDestroyProgram')
91-
except:
92-
pass
72+
global __nvrtcDestroyProgram
73+
__nvrtcDestroyProgram = windll.GetProcAddress(handle, 'nvrtcDestroyProgram')
9374
{{endif}}
9475
{{if 'nvrtcCompileProgram' in found_functions}}
95-
try:
96-
global __nvrtcCompileProgram
97-
__nvrtcCompileProgram = <void*><unsigned long long>win32api.GetProcAddress(handle, 'nvrtcCompileProgram')
98-
except:
99-
pass
76+
global __nvrtcCompileProgram
77+
__nvrtcCompileProgram = windll.GetProcAddress(handle, 'nvrtcCompileProgram')
10078
{{endif}}
10179
{{if 'nvrtcGetPTXSize' in found_functions}}
102-
try:
103-
global __nvrtcGetPTXSize
104-
__nvrtcGetPTXSize = <void*><unsigned long long>win32api.GetProcAddress(handle, 'nvrtcGetPTXSize')
105-
except:
106-
pass
80+
global __nvrtcGetPTXSize
81+
__nvrtcGetPTXSize = windll.GetProcAddress(handle, 'nvrtcGetPTXSize')
10782
{{endif}}
10883
{{if 'nvrtcGetPTX' in found_functions}}
109-
try:
110-
global __nvrtcGetPTX
111-
__nvrtcGetPTX = <void*><unsigned long long>win32api.GetProcAddress(handle, 'nvrtcGetPTX')
112-
except:
113-
pass
84+
global __nvrtcGetPTX
85+
__nvrtcGetPTX = windll.GetProcAddress(handle, 'nvrtcGetPTX')
11486
{{endif}}
11587
{{if 'nvrtcGetCUBINSize' in found_functions}}
116-
try:
117-
global __nvrtcGetCUBINSize
118-
__nvrtcGetCUBINSize = <void*><unsigned long long>win32api.GetProcAddress(handle, 'nvrtcGetCUBINSize')
119-
except:
120-
pass
88+
global __nvrtcGetCUBINSize
89+
__nvrtcGetCUBINSize = windll.GetProcAddress(handle, 'nvrtcGetCUBINSize')
12190
{{endif}}
12291
{{if 'nvrtcGetCUBIN' in found_functions}}
123-
try:
124-
global __nvrtcGetCUBIN
125-
__nvrtcGetCUBIN = <void*><unsigned long long>win32api.GetProcAddress(handle, 'nvrtcGetCUBIN')
126-
except:
127-
pass
92+
global __nvrtcGetCUBIN
93+
__nvrtcGetCUBIN = windll.GetProcAddress(handle, 'nvrtcGetCUBIN')
12894
{{endif}}
12995
{{if 'nvrtcGetLTOIRSize' in found_functions}}
130-
try:
131-
global __nvrtcGetLTOIRSize
132-
__nvrtcGetLTOIRSize = <void*><unsigned long long>win32api.GetProcAddress(handle, 'nvrtcGetLTOIRSize')
133-
except:
134-
pass
96+
global __nvrtcGetLTOIRSize
97+
__nvrtcGetLTOIRSize = windll.GetProcAddress(handle, 'nvrtcGetLTOIRSize')
13598
{{endif}}
13699
{{if 'nvrtcGetLTOIR' in found_functions}}
137-
try:
138-
global __nvrtcGetLTOIR
139-
__nvrtcGetLTOIR = <void*><unsigned long long>win32api.GetProcAddress(handle, 'nvrtcGetLTOIR')
140-
except:
141-
pass
100+
global __nvrtcGetLTOIR
101+
__nvrtcGetLTOIR = windll.GetProcAddress(handle, 'nvrtcGetLTOIR')
142102
{{endif}}
143103
{{if 'nvrtcGetOptiXIRSize' in found_functions}}
144-
try:
145-
global __nvrtcGetOptiXIRSize
146-
__nvrtcGetOptiXIRSize = <void*><unsigned long long>win32api.GetProcAddress(handle, 'nvrtcGetOptiXIRSize')
147-
except:
148-
pass
104+
global __nvrtcGetOptiXIRSize
105+
__nvrtcGetOptiXIRSize = windll.GetProcAddress(handle, 'nvrtcGetOptiXIRSize')
149106
{{endif}}
150107
{{if 'nvrtcGetOptiXIR' in found_functions}}
151-
try:
152-
global __nvrtcGetOptiXIR
153-
__nvrtcGetOptiXIR = <void*><unsigned long long>win32api.GetProcAddress(handle, 'nvrtcGetOptiXIR')
154-
except:
155-
pass
108+
global __nvrtcGetOptiXIR
109+
__nvrtcGetOptiXIR = windll.GetProcAddress(handle, 'nvrtcGetOptiXIR')
156110
{{endif}}
157111
{{if 'nvrtcGetProgramLogSize' in found_functions}}
158-
try:
159-
global __nvrtcGetProgramLogSize
160-
__nvrtcGetProgramLogSize = <void*><unsigned long long>win32api.GetProcAddress(handle, 'nvrtcGetProgramLogSize')
161-
except:
162-
pass
112+
global __nvrtcGetProgramLogSize
113+
__nvrtcGetProgramLogSize = windll.GetProcAddress(handle, 'nvrtcGetProgramLogSize')
163114
{{endif}}
164115
{{if 'nvrtcGetProgramLog' in found_functions}}
165-
try:
166-
global __nvrtcGetProgramLog
167-
__nvrtcGetProgramLog = <void*><unsigned long long>win32api.GetProcAddress(handle, 'nvrtcGetProgramLog')
168-
except:
169-
pass
116+
global __nvrtcGetProgramLog
117+
__nvrtcGetProgramLog = windll.GetProcAddress(handle, 'nvrtcGetProgramLog')
170118
{{endif}}
171119
{{if 'nvrtcAddNameExpression' in found_functions}}
172-
try:
173-
global __nvrtcAddNameExpression
174-
__nvrtcAddNameExpression = <void*><unsigned long long>win32api.GetProcAddress(handle, 'nvrtcAddNameExpression')
175-
except:
176-
pass
120+
global __nvrtcAddNameExpression
121+
__nvrtcAddNameExpression = windll.GetProcAddress(handle, 'nvrtcAddNameExpression')
177122
{{endif}}
178123
{{if 'nvrtcGetLoweredName' in found_functions}}
179-
try:
180-
global __nvrtcGetLoweredName
181-
__nvrtcGetLoweredName = <void*><unsigned long long>win32api.GetProcAddress(handle, 'nvrtcGetLoweredName')
182-
except:
183-
pass
124+
global __nvrtcGetLoweredName
125+
__nvrtcGetLoweredName = windll.GetProcAddress(handle, 'nvrtcGetLoweredName')
184126
{{endif}}
185127
{{if 'nvrtcGetPCHHeapSize' in found_functions}}
186-
try:
187-
global __nvrtcGetPCHHeapSize
188-
__nvrtcGetPCHHeapSize = <void*><unsigned long long>win32api.GetProcAddress(handle, 'nvrtcGetPCHHeapSize')
189-
except:
190-
pass
128+
global __nvrtcGetPCHHeapSize
129+
__nvrtcGetPCHHeapSize = windll.GetProcAddress(handle, 'nvrtcGetPCHHeapSize')
191130
{{endif}}
192131
{{if 'nvrtcSetPCHHeapSize' in found_functions}}
193-
try:
194-
global __nvrtcSetPCHHeapSize
195-
__nvrtcSetPCHHeapSize = <void*><unsigned long long>win32api.GetProcAddress(handle, 'nvrtcSetPCHHeapSize')
196-
except:
197-
pass
132+
global __nvrtcSetPCHHeapSize
133+
__nvrtcSetPCHHeapSize = windll.GetProcAddress(handle, 'nvrtcSetPCHHeapSize')
198134
{{endif}}
199135
{{if 'nvrtcGetPCHCreateStatus' in found_functions}}
200-
try:
201-
global __nvrtcGetPCHCreateStatus
202-
__nvrtcGetPCHCreateStatus = <void*><unsigned long long>win32api.GetProcAddress(handle, 'nvrtcGetPCHCreateStatus')
203-
except:
204-
pass
136+
global __nvrtcGetPCHCreateStatus
137+
__nvrtcGetPCHCreateStatus = windll.GetProcAddress(handle, 'nvrtcGetPCHCreateStatus')
205138
{{endif}}
206139
{{if 'nvrtcGetPCHHeapSizeRequired' in found_functions}}
207-
try:
208-
global __nvrtcGetPCHHeapSizeRequired
209-
__nvrtcGetPCHHeapSizeRequired = <void*><unsigned long long>win32api.GetProcAddress(handle, 'nvrtcGetPCHHeapSizeRequired')
210-
except:
211-
pass
140+
global __nvrtcGetPCHHeapSizeRequired
141+
__nvrtcGetPCHHeapSizeRequired = windll.GetProcAddress(handle, 'nvrtcGetPCHHeapSizeRequired')
212142
{{endif}}
213143
{{if 'nvrtcSetFlowCallback' in found_functions}}
214-
try:
215-
global __nvrtcSetFlowCallback
216-
__nvrtcSetFlowCallback = <void*><unsigned long long>win32api.GetProcAddress(handle, 'nvrtcSetFlowCallback')
217-
except:
218-
pass
144+
global __nvrtcSetFlowCallback
145+
__nvrtcSetFlowCallback = windll.GetProcAddress(handle, 'nvrtcSetFlowCallback')
219146
{{endif}}
220147

221148
{{else}}

cuda_bindings/cuda/bindings/_internal/cufile_linux.pyx

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -32,14 +32,31 @@ cdef extern from "<dlfcn.h>" nogil:
3232

3333
const void* RTLD_DEFAULT 'RTLD_DEFAULT'
3434

35+
cdef int get_cuda_version():
36+
cdef void* handle = NULL
37+
cdef int err, driver_ver = 0
38+
39+
# Load driver to check version
40+
handle = dlopen('libcuda.so.1', RTLD_NOW | RTLD_GLOBAL)
41+
if handle == NULL:
42+
err_msg = dlerror()
43+
raise NotSupportedError(f'CUDA driver is not found ({err_msg.decode()})')
44+
cuDriverGetVersion = dlsym(handle, "cuDriverGetVersion")
45+
if cuDriverGetVersion == NULL:
46+
raise RuntimeError('something went wrong')
47+
err = (<int (*)(int*) noexcept nogil>cuDriverGetVersion)(&driver_ver)
48+
if err != 0:
49+
raise RuntimeError('something went wrong')
50+
51+
return driver_ver
52+
3553

3654
###############################################################################
3755
# Wrapper init
3856
###############################################################################
3957

4058
cdef object __symbol_lock = threading.Lock()
4159
cdef bint __py_cufile_init = False
42-
cdef void* __cuDriverGetVersion = NULL
4360

4461
cdef void* __cuFileHandleRegister = NULL
4562
cdef void* __cuFileHandleDeregister = NULL
@@ -97,24 +114,9 @@ cdef int _check_or_init_cufile() except -1 nogil:
97114
return 0
98115

99116
cdef void* handle = NULL
100-
cdef int err, driver_ver = 0
101117

102118
with gil, __symbol_lock:
103-
# Load driver to check version
104-
handle = dlopen('libcuda.so.1', RTLD_NOW | RTLD_GLOBAL)
105-
if handle == NULL:
106-
err_msg = dlerror()
107-
raise NotSupportedError(f'CUDA driver is not found ({err_msg.decode()})')
108-
global __cuDriverGetVersion
109-
if __cuDriverGetVersion == NULL:
110-
__cuDriverGetVersion = dlsym(handle, "cuDriverGetVersion")
111-
if __cuDriverGetVersion == NULL:
112-
raise RuntimeError('something went wrong')
113-
err = (<int (*)(int*) noexcept nogil>__cuDriverGetVersion)(&driver_ver)
114-
if err != 0:
115-
raise RuntimeError('something went wrong')
116-
#dlclose(handle)
117-
handle = NULL
119+
driver_ver = get_cuda_version()
118120

119121
# Load function
120122
global __cuFileHandleRegister

cuda_bindings/cuda/bindings/_internal/nvjitlink_linux.pyx

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -30,14 +30,31 @@ cdef extern from "<dlfcn.h>" nogil:
3030

3131
const void* RTLD_DEFAULT 'RTLD_DEFAULT'
3232

33+
cdef int get_cuda_version():
34+
cdef void* handle = NULL
35+
cdef int err, driver_ver = 0
36+
37+
# Load driver to check version
38+
handle = dlopen('libcuda.so.1', RTLD_NOW | RTLD_GLOBAL)
39+
if handle == NULL:
40+
err_msg = dlerror()
41+
raise NotSupportedError(f'CUDA driver is not found ({err_msg.decode()})')
42+
cuDriverGetVersion = dlsym(handle, "cuDriverGetVersion")
43+
if cuDriverGetVersion == NULL:
44+
raise RuntimeError('something went wrong')
45+
err = (<int (*)(int*) noexcept nogil>cuDriverGetVersion)(&driver_ver)
46+
if err != 0:
47+
raise RuntimeError('something went wrong')
48+
49+
return driver_ver
50+
3351

3452
###############################################################################
3553
# Wrapper init
3654
###############################################################################
3755

3856
cdef object __symbol_lock = threading.Lock()
3957
cdef bint __py_nvjitlink_init = False
40-
cdef void* __cuDriverGetVersion = NULL
4158

4259
cdef void* __nvJitLinkCreate = NULL
4360
cdef void* __nvJitLinkDestroy = NULL
@@ -66,24 +83,9 @@ cdef int _check_or_init_nvjitlink() except -1 nogil:
6683
return 0
6784

6885
cdef void* handle = NULL
69-
cdef int err, driver_ver = 0
7086

7187
with gil, __symbol_lock:
72-
# Load driver to check version
73-
handle = dlopen('libcuda.so.1', RTLD_NOW | RTLD_GLOBAL)
74-
if handle == NULL:
75-
err_msg = dlerror()
76-
raise NotSupportedError(f'CUDA driver is not found ({err_msg.decode()})')
77-
global __cuDriverGetVersion
78-
if __cuDriverGetVersion == NULL:
79-
__cuDriverGetVersion = dlsym(handle, "cuDriverGetVersion")
80-
if __cuDriverGetVersion == NULL:
81-
raise RuntimeError('something went wrong')
82-
err = (<int (*)(int*) noexcept nogil>__cuDriverGetVersion)(&driver_ver)
83-
if err != 0:
84-
raise RuntimeError('something went wrong')
85-
#dlclose(handle)
86-
handle = NULL
88+
driver_ver = get_cuda_version()
8789

8890
# Load function
8991
global __nvJitLinkCreate

0 commit comments

Comments
 (0)