11# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22# SPDX-License-Identifier: Apache-2.0
33
4- from conftest import skipif_need_cuda_headers
5- from cuda .core .experimental import Device , DeviceMemoryResource , DeviceMemoryResourceOptions , EventOptions
6- from helpers .buffers import make_scratch_buffer , compare_equal_buffers
7- from helpers .latch import LatchKernel
8- from helpers .logging import TimestampedLogger
9- import ctypes
104import multiprocessing as mp
5+
116import pytest
12- import time
7+ from cuda .core .experimental import Device , EventOptions
8+ from helpers .buffers import compare_equal_buffers , make_scratch_buffer
9+ from helpers .latch import LatchKernel
10+ from helpers .logging import TimestampedLogger
1311
1412ENABLE_LOGGING = False # Set True for test debugging and development
1513CHILD_TIMEOUT_SEC = 20
1614NBYTES = 64
1715
16+
1817class TestEventIpc :
1918 """Check the basic usage of IPC-enabled events with a latch kernel."""
2019
21- @skipif_need_cuda_headers # libcu++
2220 def test_main (self , ipc_device , ipc_memory_resource ):
2321 log = TimestampedLogger (prefix = "parent: " , enabled = ENABLE_LOGGING )
2422 device = ipc_device
2523 mr = ipc_memory_resource
2624 stream1 = device .create_stream ()
25+ latch = LatchKernel (device )
2726
2827 # Start the child process.
2928 q_out , q_in = [mp .Queue () for _ in range (2 )]
@@ -41,7 +40,6 @@ def test_main(self, ipc_device, ipc_memory_resource):
4140 q_out .put (buffer )
4241
4342 # Stream 1:
44- latch = LatchKernel (device )
4543 log ("enqueuing latch kernel on stream1" )
4644 latch .launch (stream1 )
4745 log ("enqueuing copy on stream1" )
@@ -69,7 +67,6 @@ def test_main(self, ipc_device, ipc_memory_resource):
6967 stream1 .sync ()
7068 assert compare_equal_buffers (target , twos )
7169
72-
7370 def child_main (self , log , q_in , q_out ):
7471 log .prefix = " child: "
7572 log ("ready" )
@@ -99,13 +96,15 @@ def test_event_is_monadic(ipc_device):
9996
10097 stream = device .create_stream ()
10198 e = stream .record (options = {"ipc_enabled" : True })
102- with pytest .raises (TypeError , match = r"^IPC-enabled events should not be re-recorded, instead create a new event by supplying options\.$" ):
99+ with pytest .raises (
100+ TypeError ,
101+ match = r"^IPC-enabled events should not be re-recorded, instead create a new event by supplying options\.$" ,
102+ ):
103103 stream .record (e )
104104
105105
106106@pytest .mark .parametrize (
107- "options" , [ {"ipc_enabled" : True , "enable_timing" : True },
108- EventOptions (ipc_enabled = True , enable_timing = True )]
107+ "options" , [{"ipc_enabled" : True , "enable_timing" : True }, EventOptions (ipc_enabled = True , enable_timing = True )]
109108)
110109def test_event_timing_disabled (ipc_device , options ):
111110 """Check that IPC-enabled events cannot be created with timing enabled."""
@@ -114,11 +113,13 @@ def test_event_timing_disabled(ipc_device, options):
114113 with pytest .raises (TypeError , match = r"^IPC-enabled events cannot use timing\.$" ):
115114 stream .record (options = options )
116115
116+
117117class TestIpcEventProperties :
118118 """
119119 Check that event properties are properly set after transfer to a child
120120 process.
121121 """
122+
122123 @pytest .mark .parametrize ("busy_waited_sync" , [True , False ])
123124 @pytest .mark .parametrize ("use_options_cls" , [True , False ])
124125 @pytest .mark .parametrize ("use_option_kw" , [True , False ])
@@ -132,13 +133,12 @@ def test_main(self, ipc_device, busy_waited_sync, use_options_cls, use_option_kw
132133 process .start ()
133134
134135 # Create an event and send it.
135- options = \
136- EventOptions (ipc_enabled = True , busy_waited_sync = busy_waited_sync ) \
137- if use_options_cls else \
138- {"ipc_enabled" : True , "busy_waited_sync" : busy_waited_sync }
139- e = stream .record (options = options ) \
140- if use_option_kw else \
141- stream .record (None , options )
136+ options = (
137+ EventOptions (ipc_enabled = True , busy_waited_sync = busy_waited_sync )
138+ if use_options_cls
139+ else {"ipc_enabled" : True , "busy_waited_sync" : busy_waited_sync }
140+ )
141+ e = stream .record (options = options ) if use_option_kw else stream .record (None , options )
142142 q_out .put (e )
143143
144144 # Check its properties.
@@ -156,28 +156,17 @@ def test_main(self, ipc_device, busy_waited_sync, use_options_cls, use_option_kw
156156 def child_main (self , q_in , q_out ):
157157 device = Device ()
158158 device .set_current ()
159- stream = device .create_stream ()
160159
161160 # Get the event.
162161 e = q_in .get (timeout = CHILD_TIMEOUT_SEC )
163162
164163 # Send its properties.
165- props = (e .get_ipc_descriptor (),
166- e .is_ipc_enabled ,
167- e .is_timing_disabled ,
168- e .is_sync_busy_waited ,
169- e .device ,
170- e .context ,)
164+ props = (
165+ e .get_ipc_descriptor (),
166+ e .is_ipc_enabled ,
167+ e .is_timing_disabled ,
168+ e .is_sync_busy_waited ,
169+ e .device ,
170+ e .context ,
171+ )
171172 q_out .put (props )
172-
173-
174-
175- # TODO: daisy chain processes
176-
177- if __name__ == "__main__" :
178- mp .set_start_method ("spawn" )
179- device = Device ()
180- device .set_current ()
181- TestIpcEventWithLatch ().test_main (device )
182-
183-
0 commit comments