@@ -219,15 +219,6 @@ def test_reduce_scatter(self, device, dtype):
             self.assertEqual(outputs[i], expected[i])


-# Decorator
-def requires_nccl_backend_for_symmem():
-    return skip_but_pass_in_sandcastle_if(
-        not symm_mem.is_nccl_symmem_available(),
-        "test_nccl requires at least NCCL 2.28, skipping tests",
-    )
-
-
-@requires_nccl_backend_for_symmem()
 @requires_cuda_p2p_access()
 class NCCLSymmetricMemoryTest(MultiProcContinuousTest):
     @property
@@ -259,108 +250,6 @@ def foo():
         out = symm_mem.empty(numel, dtype=dtype, device=self.device)
         symm_mem.rendezvous(out, group=group_name)

-    @skip_but_pass_in_sandcastle_if(TEST_WITH_ROCM, "Skip NCCL tests for ROCm")
-    @skip_but_pass_in_sandcastle_if(IS_WINDOWS, "NCCL doesn't support Windows")
-    @skip_if_lt_x_gpu(2)
-    def test_nccl_symmem_collective(self):
-        symm_mem.set_backend("NCCL")
-        torch.cuda.set_device(self.rank)
-        # Need this all_reduce to initialize NCCL communicator. Otherwise, the
-        # test will hang. TODO: investigate how NCCLSymmetricMemory can
-        # initialize NCCL communicator.
-        c10d.all_reduce(torch.ones(1, device=self.device))
-        group_name = c10d.group.WORLD.group_name
-        symm_mem.enable_symm_mem_for_group(group_name)
-
-        dtype = torch.float
-        numel = 1024
-
-        out = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(self.rank)
-        symm_mem.rendezvous(out, group=group_name)
-        c10d.all_reduce(out)
-        torch.cuda.synchronize()
-        self.assertEqual(
-            out, torch.full_like(out, (self.world_size - 1) * self.world_size / 2)
-        )
-
-        inp = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(self.rank)
-        symm_mem.rendezvous(inp, group=group_name)
-        res = torch.ops.symm_mem.one_shot_all_reduce(inp, "sum", group_name)
-        self.assertEqual(out, res)
-
-    @skip_but_pass_in_sandcastle_if(TEST_WITH_ROCM, "Skip NCCL tests for ROCm")
-    @skip_but_pass_in_sandcastle_if(IS_WINDOWS, "NCCL doesn't support Windows")
-    @skip_if_lt_x_gpu(2)
-    def test_nccl_symmem_put(self):
-        symm_mem.set_backend("NCCL")
-        torch.cuda.set_device(self.rank)
-        # Need this all_reduce to initialize NCCL communicator. Otherwise, the
-        # test will hang. TODO: investigate how NCCLSymmetricMemory can
-        # initialize NCCL communicator.
-        c10d.all_reduce(torch.ones(1, device=self.device))
-        group_name = c10d.group.WORLD.group_name
-        symm_mem.enable_symm_mem_for_group(group_name)
-
-        dtype = torch.float
-        numel = 1024
-        tensor = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(self.rank)
-        # This is needed to make sure we don't get blocked the second time we call rendezvous
-        # for the same tensor because it will be cached by that moment.
-        symm_mem.rendezvous(tensor, group=group_name)
-        signal_val = 5
-        c10d.barrier()
-
-        if self.rank == 1:
-            torch.ops.symm_mem.nccl_put_with_signal(tensor, signal_val, 0)
-        elif self.rank == 0:
-            torch.ops.symm_mem.nccl_wait_for_signal(tensor, signal_val)
-            torch.testing.assert_close(
-                tensor, torch.ones(numel, dtype=dtype, device=self.device)
-            )
-        c10d.barrier()
-        if self.rank == 1:
-            tensor *= 2
-            torch.ops.symm_mem.nccl_put(tensor, 0)
-            c10d.barrier()
-        else:
-            c10d.barrier()
-            if self.rank == 0:
-                torch.testing.assert_close(
-                    tensor, torch.ones(numel, dtype=dtype, device=self.device) * 2
-                )
-
-    @skip_but_pass_in_sandcastle_if(TEST_WITH_ROCM, "Skip NCCL tests for ROCm")
-    @skip_but_pass_in_sandcastle_if(IS_WINDOWS, "NCCL doesn't support Windows")
-    @skip_if_lt_x_gpu(2)
-    def test_nccl_symmem_get(self):
-        symm_mem.set_backend("NCCL")
-        torch.cuda.set_device(self.rank)
-        # Need this all_reduce to initialize NCCL communicator. Otherwise, the
-        # test will hang. TODO: investigate how NCCLSymmetricMemory can
-        # initialize NCCL communicator.
-        c10d.all_reduce(torch.ones(1, device=self.device))
-        group_name = c10d.group.WORLD.group_name
-        symm_mem.enable_symm_mem_for_group(group_name)
-
-        dtype = torch.float
-        numel = 1024
-        tensor = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(self.rank)
-        # This is needed to make sure we don't get blocked the second time we call rendezvous
-        # for the same tensor because it will be cached by that moment.
-        symm_mem.rendezvous(tensor, group=group_name)
-        c10d.barrier()
-        if self.rank == 0:
-            torch.ops.symm_mem.nccl_get(tensor, 1)
-            # TODO: remove after we have wait_signal
-            c10d.barrier()
-            torch.testing.assert_close(
-                tensor, torch.ones(numel, dtype=dtype, device=self.device)
-            )
-        else:
-            # handle.wait_signal(src_rank=0)
-            # TODO: remove after we have wait_signal
-            c10d.barrier()
-

 instantiate_device_type_tests(TestNCCL, globals(), only_for="cuda")

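Aside on the gate removed in the first hunk: requires_nccl_backend_for_symmem() is an instance of the usual conditional-skip decorator pattern, letting the whole NCCLSymmetricMemoryTest class be skipped in one place when the NCCL backend is too old. Below is a minimal, self-contained sketch of that pattern, for illustration only: plain unittest.skipUnless stands in for the internal skip_but_pass_in_sandcastle_if helper, and backend_available(), requires_backend(), and ExampleTest are hypothetical names (the real check in the removed lines is symm_mem.is_nccl_symmem_available()).

import unittest


def backend_available() -> bool:
    # Hypothetical probe; the removed decorator calls
    # symm_mem.is_nccl_symmem_available() instead.
    return False


def requires_backend():
    # Return a decorator that skips the decorated test class (or method)
    # whenever the probe reports the feature is unavailable.
    return unittest.skipUnless(
        backend_available(), "requires NCCL symmetric memory, skipping tests"
    )


@requires_backend()
class ExampleTest(unittest.TestCase):
    def test_placeholder(self):
        self.assertTrue(True)


if __name__ == "__main__":
    unittest.main()

Decorating the class rather than each method keeps the version requirement in one spot, which matches how the removed @requires_nccl_backend_for_symmem() gated the entire test class alongside @requires_cuda_p2p_access().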