Skip to content

Commit 4cbb627

Browse files
authored
Ensure allocation stream is used for buffer deallocation if no explicit stream is provided (#1032)
* release gil * release gil for record * use __dealloc__ in event/stream * nit: remove print * fix linter error * reduce further the number of Python objects held by Stream * replace a few more __del__ by __dealloc__ * improve __cuda_stream__ performance * cythonize Buffer & DMR (WIP - not working!) * minor fixes - still failing * fix casting and avoid repetive assignment * ensure we have C access for DeviceMemoryResource * restore the contracts for now (the deallocation stream should be fixed) * fully cythonize Stream * make linter happy * ensure allocation stream can be used for deallocation; suppress casting warning * int -> intptr_t -> void* is safer
1 parent ae55f7c commit 4cbb627

File tree

1 file changed

+42
-22
lines changed

1 file changed

+42
-22
lines changed

cuda_core/cuda/core/experimental/_memory.pyx

Lines changed: 42 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ cdef class _cyBuffer:
5656
size_t _size
5757
_cyMemoryResource _mr
5858
object _ptr_obj
59+
cyStream _alloc_stream
5960

6061

6162
cdef class _cyMemoryResource:
@@ -107,22 +108,31 @@ cdef class Buffer(_cyBuffer, MemoryResourceAttributes):
107108
"""
108109
cdef dict __dict__ # required if inheriting from both Cython/Python classes
109110

111+
def __cinit__(self):
112+
self._ptr = 0
113+
self._size = 0
114+
self._mr = None
115+
self._ptr_obj = None
116+
self._alloc_stream = None
117+
110118
def __init__(self, *args, **kwargs):
111119
raise RuntimeError("Buffer objects cannot be instantiated directly. Please use MemoryResource APIs.")
112120

113121
@classmethod
114-
def _init(cls, ptr: DevicePointerT, size_t size, mr: MemoryResource | None = None):
122+
def _init(cls, ptr: DevicePointerT, size_t size, mr: MemoryResource | None = None, stream: Stream | None = None):
115123
cdef Buffer self = Buffer.__new__(cls)
116124
self._ptr = <intptr_t>(int(ptr))
117125
self._ptr_obj = ptr
118126
self._size = size
119127
self._mr = mr
128+
self._alloc_stream = <cyStream>(stream) if stream is not None else None
120129
return self
121130

122131
def __dealloc__(self):
123-
self.close()
132+
self.close(self._alloc_stream)
124133

125134
def __reduce__(self):
135+
# Must not serialize the parent's stream!
126136
return Buffer.from_ipc_descriptor, (self.memory_resource, self.get_ipc_descriptor())
127137

128138
cpdef close(self, stream: Stream = None):
@@ -137,15 +147,21 @@ cdef class Buffer(_cyBuffer, MemoryResourceAttributes):
137147
The stream object to use for asynchronous deallocation. If None,
138148
the behavior depends on the underlying memory resource.
139149
"""
150+
cdef cyStream s
140151
if self._ptr and self._mr is not None:
141-
# To be fixed in NVIDIA/cuda-python#1032
142152
if stream is None:
143-
stream = Stream.__new__(Stream)
144-
(<cyStream>(stream))._handle = <cydriver.CUstream>(0)
145-
self._mr._deallocate(self._ptr, self._size, <cyStream>stream)
153+
if self._alloc_stream is not None:
154+
s = self._alloc_stream
155+
else:
156+
# TODO: remove this branch when from_handle takes a stream
157+
s = <cyStream>(default_stream())
158+
else:
159+
s = <cyStream>stream
160+
self._mr._deallocate(self._ptr, self._size, s)
146161
self._ptr = 0
147162
self._mr = None
148163
self._ptr_obj = None
164+
self._alloc_stream = None
149165

150166
@property
151167
def handle(self) -> DevicePointerT:
@@ -206,16 +222,19 @@ cdef class Buffer(_cyBuffer, MemoryResourceAttributes):
206222
return IPCBufferDescriptor._init(data_b, self.size)
207223

208224
@classmethod
209-
def from_ipc_descriptor(cls, mr: DeviceMemoryResource, ipc_buffer: IPCBufferDescriptor) -> Buffer:
225+
def from_ipc_descriptor(cls, mr: DeviceMemoryResource, ipc_buffer: IPCBufferDescriptor, stream: Stream = None) -> Buffer:
210226
"""Import a buffer that was exported from another process."""
211227
if not mr.is_ipc_enabled:
212228
raise RuntimeError("Memory resource is not IPC-enabled")
229+
if stream is None:
230+
# Note: match this behavior to DeviceMemoryResource.allocate()
231+
stream = default_stream()
213232
cdef cydriver.CUmemPoolPtrExportData share_data
214233
memcpy(share_data.reserved, <const void*><const char*>(ipc_buffer._reserved), sizeof(share_data.reserved))
215234
cdef cydriver.CUdeviceptr ptr
216235
with nogil:
217236
HANDLE_RETURN(cydriver.cuMemPoolImportPointer(&ptr, mr._mempool_handle, &share_data))
218-
return Buffer.from_handle(<intptr_t>ptr, ipc_buffer.size, mr)
237+
return Buffer._init(<intptr_t>ptr, ipc_buffer.size, mr, stream)
219238

220239
def copy_to(self, dst: Buffer = None, *, stream: Stream) -> Buffer:
221240
"""Copy from this buffer to the dst buffer asynchronously on the given stream.
@@ -336,6 +355,7 @@ cdef class Buffer(_cyBuffer, MemoryResourceAttributes):
336355
mr : :obj:`~_memory.MemoryResource`, optional
337356
Memory resource associated with the buffer
338357
"""
358+
# TODO: It is better to take a stream for latter deallocation
339359
return Buffer._init(ptr, size, mr=mr)
340360

341361

@@ -839,7 +859,7 @@ cdef class DeviceMemoryResource(MemoryResource):
839859
cdef int handle = int(alloc_handle)
840860
with nogil:
841861
HANDLE_RETURN(cydriver.cuMemPoolImportFromShareableHandle(
842-
&(self._mempool_handle), <void*>handle, _IPC_HANDLE_TYPE, 0)
862+
&(self._mempool_handle), <void*><intptr_t>(handle), _IPC_HANDLE_TYPE, 0)
843863
)
844864
if uuid is not None:
845865
registered = self.register(uuid)
@@ -889,6 +909,7 @@ cdef class DeviceMemoryResource(MemoryResource):
889909
buf._ptr_obj = None
890910
buf._size = size
891911
buf._mr = self
912+
buf._alloc_stream = stream
892913
return buf
893914

894915
def allocate(self, size_t size, stream: Stream = None) -> Buffer:
@@ -931,10 +952,9 @@ cdef class DeviceMemoryResource(MemoryResource):
931952
The size of the buffer to deallocate, in bytes.
932953
stream : Stream, optional
933954
The stream on which to perform the deallocation asynchronously.
934-
If None, an internal stream is used.
955+
If the buffer is deallocated without an explicit stream, the allocation stream
956+
is used.
935957
"""
936-
if stream is None:
937-
stream = default_stream()
938958
self._deallocate(<intptr_t>ptr, size, <cyStream>stream)
939959

940960
@property
@@ -1017,11 +1037,13 @@ class LegacyPinnedMemoryResource(MemoryResource):
10171037
Buffer
10181038
The allocated buffer object, which is accessible on both host and device.
10191039
"""
1040+
if stream is None:
1041+
stream = default_stream()
10201042
err, ptr = driver.cuMemAllocHost(size)
10211043
raise_if_driver_error(err)
1022-
return Buffer._init(ptr, size, self)
1044+
return Buffer._init(ptr, size, self, stream)
10231045

1024-
def deallocate(self, ptr: DevicePointerT, size_t size, stream: Stream = None):
1046+
def deallocate(self, ptr: DevicePointerT, size_t size, stream: Stream):
10251047
"""Deallocate a buffer previously allocated by this resource.
10261048
10271049
Parameters
@@ -1030,12 +1052,10 @@ class LegacyPinnedMemoryResource(MemoryResource):
10301052
The pointer or handle to the buffer to deallocate.
10311053
size : int
10321054
The size of the buffer to deallocate, in bytes.
1033-
stream : Stream, optional
1034-
The stream on which to perform the deallocation asynchronously.
1035-
If None, no synchronization would happen.
1055+
stream : Stream
1056+
The stream on which to perform the deallocation synchronously.
10361057
"""
1037-
if stream:
1038-
stream.sync()
1058+
stream.sync()
10391059
err, = driver.cuMemFreeHost(ptr)
10401060
raise_if_driver_error(err)
10411061

@@ -1063,13 +1083,13 @@ class _SynchronousMemoryResource(MemoryResource):
10631083
self._dev_id = getattr(device_id, 'device_id', device_id)
10641084

10651085
def allocate(self, size, stream=None) -> Buffer:
1086+
if stream is None:
1087+
stream = default_stream()
10661088
err, ptr = driver.cuMemAlloc(size)
10671089
raise_if_driver_error(err)
10681090
return Buffer._init(ptr, size, self)
10691091

1070-
def deallocate(self, ptr, size, stream=None):
1071-
if stream is None:
1072-
stream = default_stream()
1092+
def deallocate(self, ptr, size, stream):
10731093
stream.sync()
10741094
err, = driver.cuMemFree(ptr)
10751095
raise_if_driver_error(err)

0 commit comments

Comments
 (0)