Skip to content

Commit a6f8bcb

Browse files
committed
Add memory resource argument through to_device
While the `to_device` function already included a memory resource, it didn't use it. Plus other functions calling `to_device` did not use the argument. The change here makes sure `to_device` passes this to the `DeviceBuffer` constructor. Also it makes sure other functions have a default argument, which they set if one is not specified.
1 parent 1532bff commit a6f8bcb

File tree

1 file changed

+9
-6
lines changed

1 file changed

+9
-6
lines changed

python/rmm/_lib/device_buffer.pyx

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -182,15 +182,17 @@ cdef class DeviceBuffer:
182182

183183
@staticmethod
184184
cdef DeviceBuffer c_to_device(const unsigned char[::1] b,
185-
Stream stream=DEFAULT_STREAM):
185+
Stream stream=DEFAULT_STREAM,
186+
DeviceMemoryResource mr=None):
186187
"""Calls ``to_device`` function on arguments provided"""
187-
return to_device(b, stream)
188+
return to_device(b, stream, mr)
188189

189190
@staticmethod
190191
def to_device(const unsigned char[::1] b,
191-
Stream stream=DEFAULT_STREAM):
192+
Stream stream=DEFAULT_STREAM,
193+
DeviceMemoryResource mr=None):
192194
"""Calls ``to_device`` function on arguments provided."""
193-
return to_device(b, stream)
195+
return to_device(b, stream, mr)
194196

195197
cpdef copy_to_host(self, ary=None, Stream stream=DEFAULT_STREAM):
196198
"""Copy from a ``DeviceBuffer`` to a buffer on host.
@@ -356,7 +358,8 @@ cdef class DeviceBuffer:
356358

357359
@cython.boundscheck(False)
358360
cpdef DeviceBuffer to_device(const unsigned char[::1] b,
359-
Stream stream=DEFAULT_STREAM):
361+
Stream stream=DEFAULT_STREAM,
362+
DeviceMemoryResource mr=None):
360363
"""Return a new ``DeviceBuffer`` with a copy of the data.
361364
362365
Parameters
@@ -384,7 +387,7 @@ cpdef DeviceBuffer to_device(const unsigned char[::1] b,
384387

385388
cdef uintptr_t p = <uintptr_t>&b[0]
386389
cdef size_t s = len(b)
387-
return DeviceBuffer(ptr=p, size=s, stream=stream)
390+
return DeviceBuffer(ptr=p, size=s, stream=stream, mr=mr)
388391

389392

390393
@cython.boundscheck(False)

0 commit comments

Comments
 (0)