Ensure correct handling of buffers allocated with LegacyPinnedMemoryResource.allocate as kernel parameters #717


Merged · 11 commits · Jun 26, 2025
8 changes: 7 additions & 1 deletion cuda_core/cuda/core/experimental/_kernel_arg_handler.pyx
@@ -212,7 +212,13 @@ cdef class ParamHolder:
for i, arg in enumerate(kernel_args):
if isinstance(arg, Buffer):
# we need the address of where the actual buffer address is stored
self.data_addresses[i] = <void*><intptr_t>(arg.handle.getPtr())
if isinstance(arg.handle, int):
@shwina (Contributor, Author) commented on Jun 18, 2025:

Can we stomach the cost of an isinstance check here?

  • One alternative is to use a try..except, where entering the try block is cheap but entering the except block is expensive.

  • Another alternative, which would eliminate the need to change the kernel-arg handling logic here at all (see the sketch after this list):

    • Introduce a new type HostPtr that wraps an integer representing a pointer and exposes a getPtr() method to retrieve it.
    • Expand the return type of Buffer.handle to DevicePtrT | HostPtr.
    • Change LegacyPinnedMemoryResource to return a buffer whose handle is a HostPtr.
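A minimal sketch of that HostPtr alternative, for illustration only — HostPtr is hypothetical (not an existing cuda.core type), and the alias at the bottom simply extends the DevicePointerT shown later in this thread:

```python
from typing import Union

from cuda.bindings import driver


class HostPtr:
    """Hypothetical wrapper around an integer host pointer."""

    def __init__(self, ptr: int):
        self._ptr = ptr

    def getPtr(self) -> int:
        # Mirrors driver.CUdeviceptr.getPtr(), so the kernel-arg handler could
        # keep calling arg.handle.getPtr() unconditionally.
        return self._ptr


# Buffer.handle would then be typed to cover both pointer wrappers:
DevicePointerT = Union[driver.CUdeviceptr, int, HostPtr, None]
```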

@leofang (Member) commented on Jun 18, 2025:

I think isinstance in Cython is cheap, and what you have here is good. I don't want to introduce more types than needed, partly because we want MR providers to focus on the MR properties (is_host_accessible, etc.), which is nicer for programmatic checks. I actually think that Buffer.handle should be of Any type so as to not get in the way of the MR providers. From both the CUDA and cccl-rt perspectives these are all just void*; we don't want to encode the memory-space information as part of the type.
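For illustration, a consumer dispatching on the memory resource's properties rather than on the handle type could look like the following sketch (it mirrors what the new example and test in this PR do; the 1024-byte size is arbitrary, and np.from_dlpack on a host buffer needs a recent NumPy, per the version guards below):

```python
import cupy as cp
import numpy as np

from cuda.core.experimental import Device, LegacyPinnedMemoryResource

dev = Device()
dev.set_current()
stream = dev.create_stream()

mr = LegacyPinnedMemoryResource()
buf = mr.allocate(1024, stream=stream)

# Branch on the resource's properties, not on type(buf.handle):
if mr.is_host_accessible:
    view = np.from_dlpack(buf)  # host-side view of the pinned allocation
else:
    view = cp.from_dlpack(buf)  # device-side view
```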

@shwina (Contributor, Author) replied:

> I actually think that Buffer.handle should be of Any type so as to not get in the way of the MR providers.

If we did type it as Any, how would _kernel_arg_handler know how to grab the pointer from underneath the Buffer?

@leofang (Member) replied:

Well Python does not care about type annotations, right? 🙂

@shwina (Contributor, Author) replied on Jun 18, 2025:

My concern wasn't so much about the type annotation, but more that the kernel handler won't know what to do with a Buffer whose .handle is any arbitrary type.

Prior to this PR it could only handle the case when .handle is a CUdeviceptr, or something that has a .getPtr() method.

if isinstance(arg, Buffer):
# we need the address of where the actual buffer address is stored
self.data_addresses[i] = <void*><intptr_t>(arg.handle.getPtr())

This PR adds the ability to handle int.

Technically, .handle is also allowed to be None:

DevicePointerT = Union[driver.CUdeviceptr, int, None]

@leofang (Member) replied:

Ahh, I see, you meant the mini dispatcher here needs to enumerate all possible types.

Let me think about it. What you have is good and a generic treatment can follow later.

Most likely, with #564, we could rewrite the dispatcher so that it looks like this:

if isinstance(arg, Buffer):
    prepare_arg[intptr_t](self.data, self.data_addresses, get_cuda_native_handle(arg.handle), i)

On the MR provider side, we just need them to implement a protocol

class IsHandleT(Protocol):
    def __int__(self) -> int: ...

if they are not using generic cuda.bindings or Python types. (FWIW we already have IsStreamT.) So maybe eventually Buffer.handle can be typed as

DevicePointerT = Optional[Union[IsHandleT, int]] 

@shwina (Contributor, Author) replied on Jun 19, 2025:

It seems like a reasonable approach, and I agree it would simplify the handling here. A couple of comments:

  • Perhaps we should rename DevicePointerT to just PointerT? In the case of pinned memory, for instance, it doesn't actually represent a device pointer AFAIU.
  • If we use the protocol as written, then Union[IsHandleT, int] is equivalent to just IsHandleT (the int type implements __int__). The protocol would also allow types like float or bool.
    • I feel like this discussion has been had before, but it might be worth considering a protocol with a __cuda_handle__() method or something rather than __int__() (see the sketch below).
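A minimal sketch of that idea, assuming a hypothetical __cuda_handle__ protocol method — neither this variant of IsHandleT nor PointerT exists in cuda.core today:

```python
from typing import Optional, Protocol, Union, runtime_checkable


@runtime_checkable
class IsHandleT(Protocol):
    # A dedicated dunder instead of __int__, so plain ints, floats, and bools
    # do not accidentally satisfy the protocol.
    def __cuda_handle__(self) -> int: ...


# Buffer.handle could eventually be typed as:
PointerT = Optional[Union[IsHandleT, int]]
```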

# see note below on handling int arguments
prepare_arg[intptr_t](self.data, self.data_addresses, arg.handle, i)
continue
else:
# it's a CUdeviceptr:
self.data_addresses[i] = <void*><intptr_t>(arg.handle.getPtr())
continue
elif isinstance(arg, int):
# Here's the dilemma: We want to have a fast path to pass in Python
137 changes: 137 additions & 0 deletions cuda_core/examples/memory_ops.py
@@ -0,0 +1,137 @@
# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED.
#
# SPDX-License-Identifier: Apache-2.0

# ################################################################################
#
# This demo illustrates:
#
# 1. How to use different memory resources to allocate and manage memory
# 2. How to copy data between different memory types
# 3. How to use DLPack to interoperate with other libraries
#
# ################################################################################

import sys

import cupy as cp
import numpy as np

from cuda.core.experimental import (
Device,
LaunchConfig,
LegacyPinnedMemoryResource,
Program,
ProgramOptions,
launch,
)

if np.__version__ < "2.1.0":
print("This example requires NumPy 2.1.0 or later", file=sys.stderr)
sys.exit(0)

# Kernel for memory operations
code = """
extern "C"
__global__ void memory_ops(float* device_data,
float* pinned_data,
size_t N) {
const unsigned int tid = threadIdx.x + blockIdx.x * blockDim.x;
if (tid < N) {
// Access device memory
device_data[tid] = device_data[tid] + 1.0f;

// Access pinned memory (zero-copy from GPU)
pinned_data[tid] = pinned_data[tid] * 3.0f;
}
}
"""

dev = Device()
dev.set_current()
stream = dev.create_stream()
# tell CuPy to use our stream as the current stream:
cp.cuda.ExternalStream(int(stream.handle)).use()

# Compile kernel
arch = "".join(f"{i}" for i in dev.compute_capability)
program_options = ProgramOptions(std="c++17", arch=f"sm_{arch}")
prog = Program(code, code_type="c++", options=program_options)
mod = prog.compile("cubin")
kernel = mod.get_kernel("memory_ops")

# Create different memory resources
device_mr = dev.memory_resource
pinned_mr = LegacyPinnedMemoryResource()

# Allocate different types of memory
size = 1024
dtype = cp.float32
element_size = dtype().itemsize
total_size = size * element_size

# 1. Device Memory (GPU-only)
device_buffer = device_mr.allocate(total_size, stream=stream)
device_array = cp.from_dlpack(device_buffer).view(dtype=dtype)

# 2. Pinned Memory (CPU memory, GPU accessible)
pinned_buffer = pinned_mr.allocate(total_size, stream=stream)
pinned_array = np.from_dlpack(pinned_buffer).view(dtype=dtype)

# Initialize data
rng = cp.random.default_rng()
device_array[:] = rng.random(size, dtype=dtype)
pinned_array[:] = rng.random(size, dtype=dtype).get()

# Store original values for verification
device_original = device_array.copy()
pinned_original = pinned_array.copy()

# Sync before kernel launch
stream.sync()

# Launch kernel
block = 256
grid = (size + block - 1) // block
config = LaunchConfig(grid=grid, block=block)

launch(stream, config, kernel, device_buffer, pinned_buffer, cp.uint64(size))
stream.sync()

# Verify kernel operations
assert cp.allclose(device_array, device_original + 1.0), "Device memory operation failed"
assert cp.allclose(pinned_array, pinned_original * 3.0), "Pinned memory operation failed"

# Copy data between different memory types
print("\nCopying data between memory types...")

# Copy from device to pinned memory
device_buffer.copy_to(pinned_buffer, stream=stream)
stream.sync()

# Verify the copy operation
assert cp.allclose(pinned_array, device_array), "Device to pinned copy failed"

# Create a new device buffer and copy from pinned
new_device_buffer = device_mr.allocate(total_size, stream=stream)
new_device_array = cp.from_dlpack(new_device_buffer).view(dtype=dtype)

pinned_buffer.copy_to(new_device_buffer, stream=stream)
stream.sync()

# Verify the copy operation
assert cp.allclose(new_device_array, pinned_array), "Pinned to device copy failed"

# Clean up
device_buffer.close(stream)
pinned_buffer.close(stream)
new_device_buffer.close(stream)
stream.close()
cp.cuda.Stream.null.use() # reset CuPy's current stream to the null stream

# Verify buffers are properly closed
assert device_buffer.handle == 0, "Device buffer should be closed"
assert pinned_buffer.handle == 0, "Pinned buffer should be closed"
assert new_device_buffer.handle == 0, "New device buffer should be closed"

print("Memory management example completed!")
111 changes: 110 additions & 1 deletion cuda_core/tests/test_launcher.py
@@ -5,11 +5,21 @@
import os
import pathlib

import cupy as cp
import numpy as np
import pytest
from conftest import skipif_need_cuda_headers

from cuda.core.experimental import Device, LaunchConfig, LegacyPinnedMemoryResource, Program, ProgramOptions, launch
from cuda.core.experimental import (
Device,
DeviceMemoryResource,
LaunchConfig,
LegacyPinnedMemoryResource,
Program,
ProgramOptions,
launch,
)
from cuda.core.experimental._memory import _SynchronousMemoryResource


def test_launch_config_init(init_cuda):
@@ -197,3 +207,102 @@ def test_cooperative_launch():
config = LaunchConfig(grid=1, block=1, cooperative_launch=True)
launch(s, config, ker)
s.sync()


@pytest.mark.parametrize(
"memory_resource_class",
[
"device_memory_resource", # kludgy, but can go away after #726 is resolved
pytest.param(
LegacyPinnedMemoryResource,
marks=pytest.mark.skipif(
tuple(int(i) for i in np.__version__.split(".")[:3]) < (2, 2, 5),
reason="need numpy 2.2.5+, numpy GH #28632",
),
),
],
)
def test_launch_with_buffers_allocated_by_memory_resource(init_cuda, memory_resource_class):
"""Test that kernels can access memory allocated by memory resources."""
dev = Device()
dev.set_current()
stream = dev.create_stream()
# tell CuPy to use our stream as the current stream:
cp.cuda.ExternalStream(int(stream.handle)).use()

# Kernel that operates on memory
code = """
extern "C"
__global__ void memory_ops(float* data, size_t N) {
const unsigned int tid = threadIdx.x + blockIdx.x * blockDim.x;
if (tid < N) {
// Access memory (device or pinned)
data[tid] = data[tid] * 3.0f;
}
}
"""

# Compile kernel
arch = "".join(f"{i}" for i in dev.compute_capability)
program_options = ProgramOptions(std="c++17", arch=f"sm_{arch}")
prog = Program(code, code_type="c++", options=program_options)
mod = prog.compile("cubin")
kernel = mod.get_kernel("memory_ops")

# Create memory resource
if memory_resource_class == "device_memory_resource":
if dev.properties.memory_pools_supported:
mr = DeviceMemoryResource(dev.device_id)
else:
mr = _SynchronousMemoryResource(dev.device_id)
else:
mr = memory_resource_class()

# Allocate memory
size = 1024
dtype = np.float32
element_size = dtype().itemsize
total_size = size * element_size

buffer = mr.allocate(total_size, stream=stream)

# Create array view based on memory type
if mr.is_host_accessible:
# For pinned memory, use numpy
array = np.from_dlpack(buffer).view(dtype=dtype)
else:
array = cp.from_dlpack(buffer).view(dtype=dtype)

# Initialize data with random values
if mr.is_host_accessible:
rng = np.random.default_rng()
array[:] = rng.random(size, dtype=dtype)
else:
rng = cp.random.default_rng()
array[:] = rng.random(size, dtype=dtype)

# Store original values for verification
original = array.copy()

# Sync before kernel launch
stream.sync()

# Launch kernel
block = 256
grid = (size + block - 1) // block
config = LaunchConfig(grid=grid, block=block)

launch(stream, config, kernel, buffer, np.uint64(size))
stream.sync()

# Verify kernel operations
assert cp.allclose(array, original * 3.0), f"{memory_resource_class.__name__} operation failed"

# Clean up
buffer.close(stream)
stream.close()

cp.cuda.Stream.null.use() # reset CuPy's current stream to the null stream

# Verify buffer is properly closed
assert buffer.handle == 0, f"{memory_resource_class.__name__} buffer should be closed"