Skip to content

Commit 96ce480

Browse files
committed
cythonize event
1 parent 9519904 commit 96ce480

File tree

2 files changed

+36
-27
lines changed

2 files changed

+36
-27
lines changed

cuda_core/cuda/core/experimental/_event.pyx

Lines changed: 34 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,11 @@
44

55
from __future__ import annotations
66

7+
from libc.stdint cimport uintptr_t
8+
9+
# TODO: how about cuda.bindings < 12.6.2?
10+
from cuda.bindings cimport cydriver
11+
712
from cuda.core.experimental._utils.cuda_utils cimport (
813
_check_driver_error as raise_if_driver_error,
914
check_or_create_options,
@@ -78,43 +83,46 @@ cdef class Event:
7883
7984
"""
8085
cdef:
81-
object _handle
86+
cydriver.CUevent _handle
8287
bint _timing_disabled
8388
bint _busy_waited
8489
int _device_id
8590
object _ctx_handle
8691

92+
def __cinit__(self):
93+
self._handle = <cydriver.CUevent>(NULL)
94+
8795
def __init__(self, *args, **kwargs):
8896
raise RuntimeError("Event objects cannot be instantiated directly. Please use Stream APIs (record).")
8997

9098
@classmethod
9199
def _init(cls, device_id: int, ctx_handle: Context, options=None):
92100
cdef Event self = Event.__new__(cls)
93101
cdef EventOptions opts = check_or_create_options(EventOptions, options, "Event options")
94-
flags = 0x0
102+
cdef unsigned int flags = 0x0
95103
self._timing_disabled = False
96104
self._busy_waited = False
97105
if not opts.enable_timing:
98-
flags |= driver.CUevent_flags.CU_EVENT_DISABLE_TIMING
106+
flags |= cydriver.CUevent_flags.CU_EVENT_DISABLE_TIMING
99107
self._timing_disabled = True
100108
if opts.busy_waited_sync:
101-
flags |= driver.CUevent_flags.CU_EVENT_BLOCKING_SYNC
109+
flags |= cydriver.CUevent_flags.CU_EVENT_BLOCKING_SYNC
102110
self._busy_waited = True
103111
if opts.support_ipc:
104112
raise NotImplementedError("WIP: https://github.com/NVIDIA/cuda-python/issues/103")
105-
err, self._handle = driver.cuEventCreate(flags)
106-
raise_if_driver_error(err)
113+
# TODO: use HANDLE_RETURN
114+
err = cydriver.cuEventCreate(&self._handle, flags)
107115
self._device_id = device_id
108116
self._ctx_handle = ctx_handle
109117
return self
110118

111119
cdef _shutdown_safe_close(self, is_shutting_down=sys.is_finalizing):
112120
if is_shutting_down and is_shutting_down():
113121
return
114-
if self._handle is not None:
115-
err, = driver.cuEventDestroy(self._handle)
116-
self._handle = None
117-
raise_if_driver_error(err)
122+
if self._handle != NULL:
123+
# TODO: use HANDLE_RETURN
124+
err = cydriver.cuEventDestroy(self._handle)
125+
self._handle = <cydriver.CUevent>(NULL)
118126

119127
cpdef close(self):
120128
"""Destroy the event."""
@@ -129,14 +137,14 @@ cdef class Event:
129137
def __rsub__(self, other):
130138
return NotImplemented
131139

132-
def __sub__(self, other):
140+
def __sub__(self, other: Event):
133141
# return self - other (in milliseconds)
134-
err, timing = driver.cuEventElapsedTime(other.handle, self._handle)
135-
try:
136-
raise_if_driver_error(err)
142+
cdef float timing
143+
err = cydriver.cuEventElapsedTime(&timing, other._handle, self._handle)
144+
if err == 0:
137145
return timing
138-
except CUDAError as e:
139-
if err == driver.CUresult.CUDA_ERROR_INVALID_HANDLE:
146+
else:
147+
if err == cydriver.CUresult.CUDA_ERROR_INVALID_HANDLE:
140148
if self.is_timing_disabled or other.is_timing_disabled:
141149
explanation = (
142150
"Both Events must be created with timing enabled in order to subtract them; "
@@ -147,15 +155,15 @@ cdef class Event:
147155
"Both Events must be recorded before they can be subtracted; "
148156
"use Stream.record() to record both events to a stream."
149157
)
150-
elif err == driver.CUresult.CUDA_ERROR_NOT_READY:
158+
elif err == cydriver.CUresult.CUDA_ERROR_NOT_READY:
151159
explanation = (
152160
"One or both events have not completed; "
153161
"use Event.sync(), Stream.sync(), or Device.sync() to wait for the events to complete "
154162
"before subtracting them."
155163
)
156164
else:
157-
raise e
158-
raise RuntimeError(explanation) from e
165+
raise CUDAError(err)
166+
raise RuntimeError(explanation)
159167

160168
@property
161169
def is_timing_disabled(self) -> bool:
@@ -182,17 +190,18 @@ cdef class Event:
182190
has been completed.
183191
184192
"""
185-
handle_return(driver.cuEventSynchronize(self._handle))
193+
# TODO: use HANDLE_RETURN
194+
err = cydriver.cuEventSynchronize(self._handle)
186195

187196
@property
188197
def is_done(self) -> bool:
189198
"""Return True if all captured works have been completed, otherwise False."""
190-
result, = driver.cuEventQuery(self._handle)
191-
if result == driver.CUresult.CUDA_SUCCESS:
199+
result = cydriver.cuEventQuery(self._handle)
200+
if result == cydriver.CUresult.CUDA_SUCCESS:
192201
return True
193-
if result == driver.CUresult.CUDA_ERROR_NOT_READY:
202+
if result == cydriver.CUresult.CUDA_ERROR_NOT_READY:
194203
return False
195-
handle_return(result)
204+
# TODO: use HANDLE_RETURN
196205

197206
@property
198207
def handle(self) -> cuda.bindings.driver.CUevent:
@@ -203,7 +212,7 @@ cdef class Event:
203212
This handle is a Python object. To get the memory address of the underlying C
204213
handle, call ``int(Event.handle)``.
205214
"""
206-
return self._handle
215+
return driver.CUevent(<uintptr_t>(self._handle))
207216

208217
@property
209218
def device(self) -> Device:

cuda_core/cuda/core/experimental/_stream.pyx

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ cdef class Stream:
122122
object _device_id
123123
object _ctx_handle
124124

125-
def __cinit__(self, *args, **kwargs):
125+
def __cinit__(self):
126126
self._handle = <cydriver.CUstream>(NULL)
127127

128128
def __init__(self, *args, **kwargs):
@@ -235,7 +235,7 @@ cdef class Stream:
235235
This handle is a Python object. To get the memory address of the underlying C
236236
handle, call ``int(Stream.handle)``.
237237
"""
238-
return driver.CUstream(<uintptr_t><void*>(self._handle))
238+
return driver.CUstream(<uintptr_t>(self._handle))
239239

240240
@property
241241
def is_nonblocking(self) -> bool:

0 commit comments

Comments
 (0)