-
Notifications
You must be signed in to change notification settings - Fork 221
Adds use_pool option to DeviceMemoryResource. #1192
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -509,9 +509,15 @@ cdef class DeviceMemoryResourceOptions: | |
| max_size : int, optional | ||
| Maximum pool size. When set to 0, defaults to a system-dependent value. | ||
| (Default to 0) | ||
|
|
||
| use_pool : bool, optional | ||
| Whether to use a memory pool. Pool-based allocations cannot be captured | ||
| in a graph but are the only ones that support sharing via IPC. | ||
| (Default to True) | ||
| """ | ||
| ipc_enabled : cython.bint = False | ||
| max_size : cython.size_t = 0 | ||
| use_pool : cython.bint = True | ||
|
|
||
|
|
||
| # TODO: cythonize this? | ||
|
|
@@ -533,6 +539,8 @@ class DeviceMemoryResourceAttributes: | |
| mr = self._mr() | ||
| if mr is None: | ||
| raise RuntimeError("DeviceMemoryResource is expired") | ||
| if mr.handle is None: | ||
| raise RuntimeError("DeviceMemoryResource is not configured to use a memory pool") | ||
| # TODO: this implementation does not allow lowering to Cython + nogil | ||
| err, value = driver.cuMemPoolGetAttribute(mr.handle, attr_enum) | ||
| raise_if_driver_error(err) | ||
|
|
@@ -715,29 +723,39 @@ cdef class DeviceMemoryResource(MemoryResource): | |
| &max_threshold | ||
| )) | ||
| else: | ||
| # Create a new memory pool. | ||
| if opts.ipc_enabled and _IPC_HANDLE_TYPE == cydriver.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_NONE: | ||
| raise RuntimeError("IPC is not available on {platform.system()}") | ||
|
|
||
| memset(&properties, 0, sizeof(cydriver.CUmemPoolProps)) | ||
| properties.allocType = cydriver.CUmemAllocationType.CU_MEM_ALLOCATION_TYPE_PINNED | ||
| properties.handleTypes = _IPC_HANDLE_TYPE if opts.ipc_enabled else cydriver.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_NONE | ||
| properties.location.id = dev_id | ||
| properties.location.type = cydriver.CUmemLocationType.CU_MEM_LOCATION_TYPE_DEVICE | ||
| properties.maxSize = opts.max_size | ||
| properties.win32SecurityAttributes = NULL | ||
| properties.usage = 0 | ||
|
|
||
| self._dev_id = dev_id | ||
| self._ipc_handle_type = properties.handleTypes | ||
| self._mempool_owned = True | ||
| if opts.use_pool: | ||
| # Create a new memory pool. | ||
| if opts.ipc_enabled and _IPC_HANDLE_TYPE == cydriver.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_NONE: | ||
| raise RuntimeError("IPC is not available on {platform.system()}") | ||
|
|
||
| memset(&properties, 0, sizeof(cydriver.CUmemPoolProps)) | ||
| properties.allocType = cydriver.CUmemAllocationType.CU_MEM_ALLOCATION_TYPE_PINNED | ||
| properties.handleTypes = _IPC_HANDLE_TYPE if opts.ipc_enabled else cydriver.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_NONE | ||
| properties.location.id = dev_id | ||
| properties.location.type = cydriver.CUmemLocationType.CU_MEM_LOCATION_TYPE_DEVICE | ||
| properties.maxSize = opts.max_size | ||
| properties.win32SecurityAttributes = NULL | ||
| properties.usage = 0 | ||
|
|
||
| self._dev_id = dev_id | ||
| self._ipc_handle_type = properties.handleTypes | ||
| self._mempool_owned = True | ||
|
|
||
| with nogil: | ||
| HANDLE_RETURN(cydriver.cuMemPoolCreate(&(self._mempool_handle), &properties)) | ||
| # TODO: should we also set the threshold here? | ||
| with nogil: | ||
| HANDLE_RETURN(cydriver.cuMemPoolCreate(&(self._mempool_handle), &properties)) | ||
| # TODO: should we also set the threshold here? | ||
|
|
||
| if opts.ipc_enabled: | ||
| self.get_allocation_handle() # enables Buffer.get_ipc_descriptor, sets uuid | ||
| if opts.ipc_enabled: | ||
| self.get_allocation_handle() # enables Buffer.get_ipc_descriptor, sets uuid | ||
| else: | ||
| if opts.ipc_enabled: | ||
| raise RuntimeError("Cannot supply ipc_enabled=True with use_pool=False") | ||
|
Comment on lines
+751
to
+752
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. style: Setting |
||
| self._dev_id = dev_id | ||
| self._ipc_handle_type = cydriver.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_NONE | ||
| self._mempool_owned = False | ||
| self._is_mapped = False | ||
| self._uuid = None | ||
| self._alloc_handle = None | ||
|
|
||
| def __dealloc__(self): | ||
| self.close() | ||
|
|
@@ -887,8 +905,12 @@ cdef class DeviceMemoryResource(MemoryResource): | |
| cdef Buffer _allocate(self, size_t size, cyStream stream): | ||
| cdef cydriver.CUstream s = stream._handle | ||
| cdef cydriver.CUdeviceptr devptr | ||
| with nogil: | ||
| HANDLE_RETURN(cydriver.cuMemAllocFromPoolAsync(&devptr, size, self._mempool_handle, s)) | ||
| if self.is_using_pool: | ||
| with nogil: | ||
| HANDLE_RETURN(cydriver.cuMemAllocFromPoolAsync(&devptr, size, self._mempool_handle, s)) | ||
| else: | ||
| with nogil: | ||
| HANDLE_RETURN(cydriver.cuMemAllocAsync(&devptr, size, s)) | ||
| cdef Buffer buf = Buffer.__new__(Buffer) | ||
| buf._ptr = <intptr_t>(devptr) | ||
| buf._ptr_obj = None | ||
|
|
@@ -987,6 +1009,11 @@ cdef class DeviceMemoryResource(MemoryResource): | |
| """Whether this memory resource has IPC enabled.""" | ||
| return self._ipc_handle_type != cydriver.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_NONE | ||
|
|
||
| @property | ||
| def is_using_pool(self) -> bool: | ||
| """Whether this memory resource uses a memory pool.""" | ||
| return self._mempool_handle != NULL | ||
|
|
||
|
|
||
| def _deep_reduce_device_memory_resource(mr): | ||
| from . import Device | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,119 @@ | ||
| # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| # | ||
| # SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE | ||
|
|
||
| import numpy as np | ||
Andy-Jost marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| import pytest | ||
|
|
||
| try: | ||
| from cuda.bindings import nvrtc | ||
| except ImportError: | ||
| from cuda import nvrtc | ||
Andy-Jost marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| from cuda.core.experimental import ( | ||
| Device, | ||
| DeviceMemoryResource, | ||
| DeviceMemoryResourceOptions, | ||
| GraphBuilder, | ||
| GraphCompleteOptions, | ||
| GraphDebugPrintOptions, | ||
Andy-Jost marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| LaunchConfig, | ||
| LegacyPinnedMemoryResource, | ||
Andy-Jost marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| Program, | ||
| ProgramOptions, | ||
| launch, | ||
| ) | ||
| from cuda.core.experimental._utils.cuda_utils import NVRTCError, handle_return | ||
Andy-Jost marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| from helpers.buffers import compare_equal_buffers, make_scratch_buffer | ||
|
|
||
| def _common_kernels(): | ||
| code = """ | ||
| __global__ void set_zero(char *a, size_t nbytes) { | ||
| size_t idx = blockIdx.x * blockDim.x + threadIdx.x; | ||
| size_t stride = blockDim.x * gridDim.x; | ||
| for (size_t i = idx; i < nbytes; i += stride) { | ||
| a[i] = 0; | ||
| } | ||
| } | ||
| __global__ void add_one(char *a, size_t nbytes) { | ||
| size_t idx = blockIdx.x * blockDim.x + threadIdx.x; | ||
| size_t stride = blockDim.x * gridDim.x; | ||
| for (size_t i = idx; i < nbytes; i += stride) { | ||
| a[i] += 1; | ||
| } | ||
| } | ||
| """ | ||
| arch = "".join(f"{i}" for i in Device().compute_capability) | ||
| program_options = ProgramOptions(std="c++17", arch=f"sm_{arch}") | ||
| prog = Program(code, code_type="c++", options=program_options) | ||
| mod = prog.compile("cubin", name_expressions=("set_zero", "add_one")) | ||
| return mod | ||
|
|
||
|
|
||
| # def test_no_graph(init_cuda): | ||
| # device = Device() | ||
| # stream = device.create_stream() | ||
| # | ||
| # # Get kernels. | ||
| # mod = _common_kernels() | ||
| # set_zero = mod.get_kernel("set_zero") | ||
| # add_one = mod.get_kernel("add_one") | ||
| # | ||
| # # Run operations. | ||
| # NBYTES = 1 | ||
| # mr = DeviceMemoryResource(device) | ||
| # work_buffer = mr.allocate(NBYTES, stream=stream) | ||
| # launch(stream, LaunchConfig(grid=1, block=1), set_zero, int(work_buffer.handle)) | ||
| # launch(stream, LaunchConfig(grid=1, block=1), add_one, int(work_buffer.handle)) | ||
| # | ||
| # # Check the result. | ||
| # one = make_scratch_buffer(device, 1, NBYTES) | ||
| # compare_buffer = make_scratch_buffer(device, 0, NBYTES) | ||
| # compare_buffer.copy_from(work_buffer, stream=stream) | ||
| # stream.sync() | ||
| # assert compare_equal_buffers(one, compare_buffer) | ||
|
|
||
| # # Let's have a look. | ||
| # # options = GraphDebugPrintOptions(**{field: True for field in GraphDebugPrintOptions.__dataclass_fields__}) | ||
| # # gb.debug_dot_print(b"./debug.dot", options) | ||
|
|
||
|
|
||
| def test_graph(init_cuda): | ||
| device = Device() | ||
| stream = device.create_stream() | ||
| options = DeviceMemoryResourceOptions(use_pool=False) | ||
| mr = DeviceMemoryResource(device, options=options) | ||
|
|
||
| # Get kernels. | ||
| mod = _common_kernels() | ||
| set_zero = mod.get_kernel("set_zero") | ||
| add_one = mod.get_kernel("add_one") | ||
|
|
||
| NBYTES = 64 | ||
| target = mr.allocate(NBYTES, stream=stream) | ||
|
|
||
| # Begin graph capture. | ||
| gb = Device().create_graph_builder().begin_building(mode="thread_local") | ||
Andy-Jost marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| # import code | ||
| # code.interact(local=dict(globals(), **locals())) | ||
| work_buffer = mr.allocate(NBYTES, stream=gb.stream) | ||
| launch(gb, LaunchConfig(grid=1, block=1), set_zero, int(work_buffer.handle), NBYTES) | ||
| launch(gb, LaunchConfig(grid=1, block=1), add_one, int(work_buffer.handle), NBYTES) | ||
| launch(gb, LaunchConfig(grid=1, block=1), add_one, int(work_buffer.handle), NBYTES) | ||
| target.copy_from(work_buffer, stream=gb.stream) | ||
|
Comment on lines
+97
to
+103
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I believe this is the sort of use we are targeting. |
||
|
|
||
| # Finalize the graph. | ||
| graph = gb.end_building().complete() | ||
|
|
||
| # Upload and launch | ||
| graph.upload(stream) | ||
| graph.launch(stream) | ||
| stream.sync() | ||
|
|
||
| # Check the result. | ||
| expected_buffer = make_scratch_buffer(device, 2, NBYTES) | ||
| compare_buffer = make_scratch_buffer(device, 0, NBYTES) | ||
| compare_buffer.copy_from(target, stream=stream) | ||
| stream.sync() | ||
| assert compare_equal_buffers(expected_buffer, compare_buffer) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
logic:
mr.handleproperty returns adriver.CUmemoryPoolobject (line 982), notNone. The check should bemr._mempool_handle != NULLor use theis_using_poolproperty.