Skip to content

Commit f98a286

Browse files
committed
Simplify LatchKernel and fix formatting.
1 parent 02975e9 commit f98a286

File tree

8 files changed

+53
-70
lines changed

8 files changed

+53
-70
lines changed

cuda_core/tests/helpers/buffers.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,6 @@ def verify_buffer(self, buffer, seed=None, value=None):
8181
self.sync_target.sync()
8282
assert libc.memcmp(ptr_test, ptr_expected, self.size) == 0
8383

84-
8584
@staticmethod
8685
def _ptr(buffer):
8786
"""Get a pointer to the specified buffer."""
@@ -101,7 +100,7 @@ def _get_pattern_buffer(self, seed, value):
101100
pattern_buffer = DummyUnifiedMemoryResource(self.device).allocate(self.size)
102101
ptr = self._ptr(pattern_buffer)
103102
for i in range(self.size):
104-
ptr[i] = (seed + i) & 0xFF
103+
ptr[i] = (seed + i) & 0xFF
105104
self.pattern_buffers[key] = pattern_buffer
106105
return pattern_buffer
107106

@@ -121,5 +120,3 @@ def compare_equal_buffers(buffer1, buffer2):
121120
ptr1 = ctypes.cast(int(buffer1.handle), ctypes.POINTER(ctypes.c_byte))
122121
ptr2 = ctypes.cast(int(buffer2.handle), ctypes.POINTER(ctypes.c_byte))
123122
return libc.memcmp(ptr1, ptr2, buffer1.size) == 0
124-
125-

cuda_core/tests/helpers/latch.py

Lines changed: 12 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,28 @@
11
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
# SPDX-License-Identifier: Apache-2.0
33

4+
import ctypes
5+
6+
import pytest
47
from cuda.core.experimental import (
58
LaunchConfig,
69
LegacyPinnedMemoryResource,
710
Program,
811
ProgramOptions,
912
launch,
1013
)
14+
1115
import helpers
12-
import ctypes
16+
1317

1418
class LatchKernel:
1519
"""
16-
Manages a kernel that blocks progress until released.
20+
Manages a kernel that blocks stream progress until released.
1721
"""
1822

1923
def __init__(self, device):
24+
if helpers.CUDA_INCLUDE_PATH is None:
25+
pytest.skip("need CUDA header")
2026
code = """
2127
#include <cuda/atomic>
2228
@@ -44,26 +50,14 @@ def __init__(self, device):
4450
self.busy_wait_flag[0] = 0
4551

4652
def launch(self, stream):
53+
"""Launch the latch kernel, blocking stream progress via busy waiting."""
4754
config = LaunchConfig(grid=1, block=1)
48-
launch(stream, config, self.kernel, self.busy_wait_flag_address)
55+
launch(stream, config, self.kernel, int(self.buffer.handle))
4956

5057
def release(self):
58+
"""Release the latch, allowing stream progress."""
5159
self.busy_wait_flag[0] = 1
5260

53-
@property
54-
def busy_wait_flag_address(self):
55-
return int(self.buffer.handle)
56-
5761
@property
5862
def busy_wait_flag(self):
59-
return ctypes.cast(self.busy_wait_flag_address, ctypes.POINTER(ctypes.c_int32))
60-
61-
def close(self):
62-
buffer = getattr(self, 'buffer', None)
63-
if buffer is not None:
64-
buffer.close()
65-
66-
def __del__(self):
67-
self.close()
68-
69-
63+
return ctypes.cast(int(self.buffer.handle), ctypes.POINTER(ctypes.c_int32))

cuda_core/tests/helpers/logging.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import time
55

6+
67
class TimestampedLogger:
78
"""
89
A logger that prefixes each output with a timestamp, containing the elapsed

cuda_core/tests/memory_ipc/test_event_ipc.py

Lines changed: 28 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,28 @@
11
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
# SPDX-License-Identifier: Apache-2.0
33

4-
from conftest import skipif_need_cuda_headers
5-
from cuda.core.experimental import Device, DeviceMemoryResource, DeviceMemoryResourceOptions, EventOptions
6-
from helpers.buffers import make_scratch_buffer, compare_equal_buffers
7-
from helpers.latch import LatchKernel
8-
from helpers.logging import TimestampedLogger
9-
import ctypes
104
import multiprocessing as mp
5+
116
import pytest
12-
import time
7+
from cuda.core.experimental import Device, EventOptions
8+
from helpers.buffers import compare_equal_buffers, make_scratch_buffer
9+
from helpers.latch import LatchKernel
10+
from helpers.logging import TimestampedLogger
1311

1412
ENABLE_LOGGING = False # Set True for test debugging and development
1513
CHILD_TIMEOUT_SEC = 20
1614
NBYTES = 64
1715

16+
1817
class TestEventIpc:
1918
"""Check the basic usage of IPC-enabled events with a latch kernel."""
2019

21-
@skipif_need_cuda_headers # libcu++
2220
def test_main(self, ipc_device, ipc_memory_resource):
2321
log = TimestampedLogger(prefix="parent: ", enabled=ENABLE_LOGGING)
2422
device = ipc_device
2523
mr = ipc_memory_resource
2624
stream1 = device.create_stream()
25+
latch = LatchKernel(device)
2726

2827
# Start the child process.
2928
q_out, q_in = [mp.Queue() for _ in range(2)]
@@ -41,7 +40,6 @@ def test_main(self, ipc_device, ipc_memory_resource):
4140
q_out.put(buffer)
4241

4342
# Stream 1:
44-
latch = LatchKernel(device)
4543
log("enqueuing latch kernel on stream1")
4644
latch.launch(stream1)
4745
log("enqueuing copy on stream1")
@@ -69,7 +67,6 @@ def test_main(self, ipc_device, ipc_memory_resource):
6967
stream1.sync()
7068
assert compare_equal_buffers(target, twos)
7169

72-
7370
def child_main(self, log, q_in, q_out):
7471
log.prefix = " child: "
7572
log("ready")
@@ -99,13 +96,15 @@ def test_event_is_monadic(ipc_device):
9996

10097
stream = device.create_stream()
10198
e = stream.record(options={"ipc_enabled": True})
102-
with pytest.raises(TypeError, match=r"^IPC-enabled events should not be re-recorded, instead create a new event by supplying options\.$"):
99+
with pytest.raises(
100+
TypeError,
101+
match=r"^IPC-enabled events should not be re-recorded, instead create a new event by supplying options\.$",
102+
):
103103
stream.record(e)
104104

105105

106106
@pytest.mark.parametrize(
107-
"options", [ {"ipc_enabled": True, "enable_timing": True},
108-
EventOptions(ipc_enabled=True, enable_timing=True)]
107+
"options", [{"ipc_enabled": True, "enable_timing": True}, EventOptions(ipc_enabled=True, enable_timing=True)]
109108
)
110109
def test_event_timing_disabled(ipc_device, options):
111110
"""Check that IPC-enabled events cannot be created with timing enabled."""
@@ -114,11 +113,13 @@ def test_event_timing_disabled(ipc_device, options):
114113
with pytest.raises(TypeError, match=r"^IPC-enabled events cannot use timing\.$"):
115114
stream.record(options=options)
116115

116+
117117
class TestIpcEventProperties:
118118
"""
119119
Check that event properties are properly set after transfer to a child
120120
process.
121121
"""
122+
122123
@pytest.mark.parametrize("busy_waited_sync", [True, False])
123124
@pytest.mark.parametrize("use_options_cls", [True, False])
124125
@pytest.mark.parametrize("use_option_kw", [True, False])
@@ -132,13 +133,12 @@ def test_main(self, ipc_device, busy_waited_sync, use_options_cls, use_option_kw
132133
process.start()
133134

134135
# Create an event and send it.
135-
options = \
136-
EventOptions(ipc_enabled=True, busy_waited_sync=busy_waited_sync) \
137-
if use_options_cls else \
138-
{"ipc_enabled": True, "busy_waited_sync": busy_waited_sync}
139-
e = stream.record(options=options) \
140-
if use_option_kw else \
141-
stream.record(None, options)
136+
options = (
137+
EventOptions(ipc_enabled=True, busy_waited_sync=busy_waited_sync)
138+
if use_options_cls
139+
else {"ipc_enabled": True, "busy_waited_sync": busy_waited_sync}
140+
)
141+
e = stream.record(options=options) if use_option_kw else stream.record(None, options)
142142
q_out.put(e)
143143

144144
# Check its properties.
@@ -156,28 +156,17 @@ def test_main(self, ipc_device, busy_waited_sync, use_options_cls, use_option_kw
156156
def child_main(self, q_in, q_out):
157157
device = Device()
158158
device.set_current()
159-
stream = device.create_stream()
160159

161160
# Get the event.
162161
e = q_in.get(timeout=CHILD_TIMEOUT_SEC)
163162

164163
# Send its properties.
165-
props = (e.get_ipc_descriptor(),
166-
e.is_ipc_enabled,
167-
e.is_timing_disabled,
168-
e.is_sync_busy_waited,
169-
e.device,
170-
e.context,)
164+
props = (
165+
e.get_ipc_descriptor(),
166+
e.is_ipc_enabled,
167+
e.is_timing_disabled,
168+
e.is_sync_busy_waited,
169+
e.device,
170+
e.context,
171+
)
171172
q_out.put(props)
172-
173-
174-
175-
# TODO: daisy chain processes
176-
177-
if __name__ == "__main__":
178-
mp.set_start_method("spawn")
179-
device = Device()
180-
device.set_current()
181-
TestIpcEventWithLatch().test_main(device)
182-
183-

cuda_core/tests/memory_ipc/test_memory_ipc.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
# SPDX-License-Identifier: Apache-2.0
33

4+
import multiprocessing as mp
5+
46
from cuda.core.experimental import Buffer, DeviceMemoryResource
57
from helpers.buffers import PatternGen
6-
import multiprocessing as mp
78

89
CHILD_TIMEOUT_SEC = 20
910
NBYTES = 64

cuda_core/tests/test_event.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
)
1414
from helpers.latch import LatchKernel
1515

16-
from conftest import skipif_need_cuda_headers
1716
from cuda_python_test_helpers import IS_WSL
1817

1918

@@ -115,7 +114,6 @@ def test_error_timing_recorded():
115114
event3 - event2
116115

117116

118-
@skipif_need_cuda_headers # libcu++
119117
def test_error_timing_incomplete():
120118
device = Device()
121119
device.set_current()

cuda_core/tests/test_helpers.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,18 @@
22
#
33
# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE
44

5+
import time
6+
7+
import pytest
58
from cuda.core.experimental import Device
9+
from helpers.buffers import PatternGen, compare_equal_buffers, make_scratch_buffer
610
from helpers.latch import LatchKernel
711
from helpers.logging import TimestampedLogger
8-
from helpers.buffers import make_scratch_buffer, compare_equal_buffers, PatternGen
9-
import time
10-
import pytest
1112

1213
ENABLE_LOGGING = False # Set True for test debugging and development
1314
NBYTES = 64
1415

16+
1517
def test_latchkernel():
1618
"""Test LatchKernel."""
1719
log = TimestampedLogger()
@@ -38,6 +40,7 @@ def test_latchkernel():
3840
assert compare_equal_buffers(target, ones)
3941
log("done")
4042

43+
4144
def test_patterngen_seeds():
4245
"""Test PatternGen with seed argument."""
4346
device = Device()
@@ -49,10 +52,11 @@ def test_patterngen_seeds():
4952
for i in range(256):
5053
pgen.fill_buffer(buffer, seed=i)
5154
pgen.verify_buffer(buffer, seed=i)
52-
for j in range(i+1, 256):
55+
for j in range(i + 1, 256):
5356
with pytest.raises(AssertionError):
5457
pgen.verify_buffer(buffer, seed=j)
5558

59+
5660
def test_patterngen_values():
5761
"""Test PatternGen with value argument, also compare_equal_buffers."""
5862
device = Device()
@@ -64,4 +68,3 @@ def test_patterngen_values():
6468
pgen = PatternGen(device, NBYTES)
6569
pgen.verify_buffer(ones, value=1)
6670
pgen.verify_buffer(twos, value=2)
67-

cuda_core/tests/test_memory.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
np = None
1414
import ctypes
1515
import platform
16-
from helpers.buffers import DummyUnifiedMemoryResource
1716

1817
import pytest
1918
from cuda.core.experimental import (
@@ -28,6 +27,7 @@
2827
from cuda.core.experimental._memory import DLDeviceType, IPCBufferDescriptor
2928
from cuda.core.experimental._utils.cuda_utils import handle_return
3029
from cuda.core.experimental.utils import StridedMemoryView
30+
from helpers.buffers import DummyUnifiedMemoryResource
3131

3232
from cuda_python_test_helpers import IS_WSL, supports_ipc_mempool
3333

0 commit comments

Comments
 (0)