Skip to content

Commit 246e8a1

Browse files
committed
Use SynchronousMemoryResource if memory pools are not supported
1 parent 2699ff1 commit 246e8a1

File tree

1 file changed

+16
-4
lines changed

1 file changed

+16
-4
lines changed

cuda_core/tests/test_launcher.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import pytest
1111
from conftest import skipif_need_cuda_headers
1212

13+
from cuda.bindings import driver
1314
from cuda.core.experimental import (
1415
Device,
1516
DeviceMemoryResource,
@@ -19,6 +20,8 @@
1920
ProgramOptions,
2021
launch,
2122
)
23+
from cuda.core.experimental._memory import _SynchronousMemoryResource
24+
from cuda.core.experimental._utils.cuda_utils import handle_return
2225

2326

2427
def test_launch_config_init(init_cuda):
@@ -211,7 +214,7 @@ def test_cooperative_launch():
211214
@pytest.mark.parametrize(
212215
"memory_resource_class",
213216
[
214-
DeviceMemoryResource,
217+
"device_memory_resource", # kludgy, but can go away after #726 is resolved
215218
pytest.param(
216219
LegacyPinnedMemoryResource,
217220
marks=pytest.mark.skipif(
@@ -249,9 +252,18 @@ def test_launch_with_buffers_allocated_by_memory_resource(init_cuda, memory_reso
249252
kernel = mod.get_kernel("memory_ops")
250253

251254
# Create memory resource
252-
if memory_resource_class == DeviceMemoryResource:
253-
mr = memory_resource_class(dev.device_id)
254-
else: # LegacyPinnedMemoryResource
255+
if memory_resource_class == "device_memory_resource":
256+
if (
257+
handle_return(
258+
driver.cuDeviceGetAttribute(
259+
driver.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_MEMORY_POOLS_SUPPORTED, dev.device_id
260+
)
261+
)
262+
) == 1:
263+
mr = DeviceMemoryResource(dev.device_id)
264+
else:
265+
mr = _SynchronousMemoryResource(dev.device_id)
266+
else:
255267
mr = memory_resource_class()
256268

257269
# Allocate memory

0 commit comments

Comments
 (0)