@@ -56,6 +56,7 @@ cdef class _cyBuffer:
5656 size_t _size
5757 _cyMemoryResource _mr
5858 object _ptr_obj
59+ cyStream _alloc_stream
5960
6061
6162cdef class _cyMemoryResource:
@@ -107,22 +108,31 @@ cdef class Buffer(_cyBuffer, MemoryResourceAttributes):
107108 """
108109 cdef dict __dict__ # required if inheriting from both Cython/Python classes
109110
def __cinit__(self):
    # Put a freshly created Buffer into the empty state: no device pointer,
    # zero size, and no memory resource, pointer object, or allocation stream
    # attached. Factory code such as Buffer._init() fills these in afterwards.
    self._alloc_stream = None
    self._ptr_obj = None
    self._mr = None
    self._size = 0
    self._ptr = 0
117+
def __init__(self, *args, **kwargs):
    """Reject direct construction of a Buffer.

    Buffers are created through the ``_init`` class method (which uses
    ``Buffer.__new__`` and therefore never runs this method), so any call
    that reaches here is a direct ``Buffer(...)`` instantiation.

    Raises
    ------
    RuntimeError
        Always.
    """
    raise RuntimeError("Buffer objects cannot be instantiated directly. Please use MemoryResource APIs.")
112120
@classmethod
def _init(cls, ptr: DevicePointerT, size_t size, mr: MemoryResource | None = None, stream: Stream | None = None):
    """Create a Buffer around an existing allocation without calling ``__init__``.

    Parameters
    ----------
    ptr : DevicePointerT
        Pointer object for the allocation. Its integer value is stored as the
        raw address and the object itself is kept alongside it.
    size : int
        Size of the allocation, in bytes.
    mr : MemoryResource, optional
        Memory resource associated with the buffer.
    stream : Stream, optional
        Stream recorded as the buffer's allocation stream. ``close()`` falls
        back to it when no explicit stream is given.

    Returns
    -------
    Buffer
        The newly constructed buffer.
    """
    cdef Buffer buf = Buffer.__new__(cls)
    buf._ptr_obj = ptr
    buf._ptr = <intptr_t>(int(ptr))
    buf._size = size
    buf._mr = mr
    if stream is not None:
        buf._alloc_stream = <cyStream>(stream)
    else:
        buf._alloc_stream = None
    return buf
121130
def __dealloc__(self):
    # Finalizer: release the allocation through close(), handing it the stream
    # that was recorded when the buffer was created (may be None).
    cdef cyStream release_stream = self._alloc_stream
    self.close(release_stream)
124133
def __reduce__(self):
    # Pickle support: the buffer is rebuilt on the receiving side by calling
    # Buffer.from_ipc_descriptor(memory_resource, ipc_descriptor).
    # Must not serialize the parent's stream!
    rebuild = Buffer.from_ipc_descriptor
    memory_resource = self.memory_resource
    ipc_descriptor = self.get_ipc_descriptor()
    return rebuild, (memory_resource, ipc_descriptor)
127137
128138 cpdef close(self , stream: Stream = None ):
@@ -137,15 +147,21 @@ cdef class Buffer(_cyBuffer, MemoryResourceAttributes):
137147 The stream object to use for asynchronous deallocation. If None,
138148 the behavior depends on the underlying memory resource.
139149 """
150+ cdef cyStream s
140151 if self ._ptr and self ._mr is not None :
141- # To be fixed in NVIDIA/cuda-python#1032
142152 if stream is None :
143- stream = Stream.__new__ (Stream)
144- (< cyStream> (stream))._handle = < cydriver.CUstream> (0 )
145- self ._mr._deallocate(self ._ptr, self ._size, < cyStream> stream)
153+ if self ._alloc_stream is not None :
154+ s = self ._alloc_stream
155+ else :
156+ # TODO: remove this branch when from_handle takes a stream
157+ s = < cyStream> (default_stream())
158+ else :
159+ s = < cyStream> stream
160+ self ._mr._deallocate(self ._ptr, self ._size, s)
146161 self ._ptr = 0
147162 self ._mr = None
148163 self ._ptr_obj = None
164+ self ._alloc_stream = None
149165
150166 @property
151167 def handle (self ) -> DevicePointerT:
@@ -206,16 +222,19 @@ cdef class Buffer(_cyBuffer, MemoryResourceAttributes):
206222 return IPCBufferDescriptor._init(data_b , self.size )
207223
@classmethod
def from_ipc_descriptor(cls, mr: DeviceMemoryResource, ipc_buffer: IPCBufferDescriptor, stream: Stream = None) -> Buffer:
    """Import a buffer that was exported from another process.

    Parameters
    ----------
    mr : DeviceMemoryResource
        IPC-enabled memory resource; the pointer is imported through its
        memory pool handle.
    ipc_buffer : IPCBufferDescriptor
        Descriptor carrying the exported pointer data and the buffer size.
    stream : Stream, optional
        Stream recorded on the returned buffer as its allocation stream.
        If None, the default stream is used.

    Returns
    -------
    Buffer
        A buffer wrapping the imported pointer and associated with ``mr``.

    Raises
    ------
    RuntimeError
        If ``mr`` is not IPC-enabled.
    """
    cdef cydriver.CUmemPoolPtrExportData export_data
    cdef cydriver.CUdeviceptr imported_ptr

    if not mr.is_ipc_enabled:
        raise RuntimeError("Memory resource is not IPC-enabled")
    if stream is None:
        # Note: match this behavior to DeviceMemoryResource.allocate()
        stream = default_stream()

    # Copy the opaque export payload from the descriptor into the driver struct.
    memcpy(export_data.reserved, <const void*><const char*>(ipc_buffer._reserved), sizeof(export_data.reserved))
    with nogil:
        HANDLE_RETURN(cydriver.cuMemPoolImportPointer(&imported_ptr, mr._mempool_handle, &export_data))
    return Buffer._init(<intptr_t>imported_ptr, ipc_buffer.size, mr, stream)
219238
220239 def copy_to(self , dst: Buffer = None , *, stream: Stream ) -> Buffer:
221240 """Copy from this buffer to the dst buffer asynchronously on the given stream.
@@ -336,6 +355,7 @@ cdef class Buffer(_cyBuffer, MemoryResourceAttributes):
336355 mr : :obj:`~_memory.MemoryResource`, optional
337356 Memory resource associated with the buffer
338357 """
358+ # TODO: It is better to take a stream for later deallocation
339359 return Buffer._init(ptr , size , mr = mr)
340360
341361
@@ -839,7 +859,7 @@ cdef class DeviceMemoryResource(MemoryResource):
839859 cdef int handle = int (alloc_handle)
840860 with nogil:
841861 HANDLE_RETURN(cydriver.cuMemPoolImportFromShareableHandle(
842- &(self._mempool_handle ), <void*>handle , _IPC_HANDLE_TYPE , 0)
862+ &(self._mempool_handle ), <void*><intptr_t>( handle ) , _IPC_HANDLE_TYPE , 0)
843863 )
844864 if uuid is not None:
845865 registered = self .register(uuid)
@@ -889,6 +909,7 @@ cdef class DeviceMemoryResource(MemoryResource):
889909 buf._ptr_obj = None
890910 buf._size = size
891911 buf._mr = self
912+ buf._alloc_stream = stream
892913 return buf
893914
894915 def allocate (self , size_t size , stream: Stream = None ) -> Buffer:
@@ -931,10 +952,9 @@ cdef class DeviceMemoryResource(MemoryResource):
931952 The size of the buffer to deallocate, in bytes.
932953 stream : Stream, optional
933954 The stream on which to perform the deallocation asynchronously.
934- If None, an internal stream is used.
955+ If the buffer is deallocated without an explicit stream, the allocation stream
956+ is used.
935957 """
936- if stream is None :
937- stream = default_stream()
938958 self ._deallocate(< intptr_t> ptr, size, < cyStream> stream)
939959
940960 @property
@@ -1017,11 +1037,13 @@ class LegacyPinnedMemoryResource(MemoryResource):
10171037 Buffer
10181038 The allocated buffer object , which is accessible on both host and device.
10191039 """
1040+ if stream is None:
1041+ stream = default_stream()
10201042 err , ptr = driver.cuMemAllocHost(size)
10211043 raise_if_driver_error(err )
1022- return Buffer._init(ptr , size , self )
1044+ return Buffer._init(ptr , size , self , stream )
10231045
1024- def deallocate(self , ptr: DevicePointerT , size_t size , stream: Stream = None ):
1046+ def deallocate(self , ptr: DevicePointerT , size_t size , stream: Stream ):
10251047 """ Deallocate a buffer previously allocated by this resource.
10261048
10271049 Parameters
@@ -1030,12 +1052,10 @@ class LegacyPinnedMemoryResource(MemoryResource):
10301052 The pointer or handle to the buffer to deallocate.
10311053 size : int
10321054 The size of the buffer to deallocate, in bytes.
1033- stream : Stream, optional
1034- The stream on which to perform the deallocation asynchronously.
1035- If None, no synchronization would happen.
1055+ stream : Stream
1056+ The stream on which to perform the deallocation synchronously.
10361057 """
1037- if stream:
1038- stream.sync()
1058+ stream.sync()
10391059 err, = driver.cuMemFreeHost(ptr)
10401060 raise_if_driver_error(err)
10411061
@@ -1063,13 +1083,13 @@ class _SynchronousMemoryResource(MemoryResource):
10631083 self ._dev_id = getattr (device_id, ' device_id' , device_id)
10641084
def allocate(self, size, stream=None) -> Buffer:
    """Allocate ``size`` bytes with ``cuMemAlloc`` and wrap them in a Buffer.

    Parameters
    ----------
    size : int
        The size of the buffer to allocate, in bytes.
    stream : Stream, optional
        Stream recorded on the returned buffer as its allocation stream.
        If None, the default stream is used.

    Returns
    -------
    Buffer
        The allocated buffer, associated with this memory resource.
    """
    if stream is None:
        stream = default_stream()
    err, ptr = driver.cuMemAlloc(size)
    raise_if_driver_error(err)
    # Pass the stream on so the buffer remembers it; previously it was
    # computed here but dropped, leaving the buffer without an allocation stream.
    return Buffer._init(ptr, size, self, stream)
10691091
def deallocate(self, ptr, size, stream):
    """Free a buffer previously allocated by this resource.

    Parameters
    ----------
    ptr
        The pointer to the buffer to free; passed to ``cuMemFree``.
    size
        The size of the buffer, in bytes (not used by this implementation).
    stream
        Stream that is synchronized before the memory is freed.
    """
    # Wait for the stream first, then free the memory.
    stream.sync()
    (status,) = driver.cuMemFree(ptr)
    raise_if_driver_error(status)
0 commit comments