From 5fb7c3ed25fb99d2db4d78b8d6edc5e4dfd4ea0b Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Fri, 22 Nov 2024 22:14:03 -0500 Subject: [PATCH] [Misc] Add pynccl wrappers for all_gather and reduce_scatter (#9432) Signed-off-by: Maxime Fournioux <55544262+mfournioux@users.noreply.github.com> --- tests/distributed/test_pynccl.py | 69 +++++++++++++++++++ .../device_communicators/pynccl.py | 42 +++++++++++ .../device_communicators/pynccl_wrapper.py | 44 ++++++++++++ 3 files changed, 155 insertions(+) diff --git a/tests/distributed/test_pynccl.py b/tests/distributed/test_pynccl.py index e0e424439e3a5..f702d7c46ea73 100644 --- a/tests/distributed/test_pynccl.py +++ b/tests/distributed/test_pynccl.py @@ -150,6 +150,75 @@ def worker_fn_with_cudagraph(): assert a.mean().cpu().item() == pynccl_comm.world_size**1 +@worker_fn_wrapper +def all_gather_worker_fn(): + pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group, + device=get_world_group().device) + + rank = pynccl_comm.rank + world_size = pynccl_comm.world_size + device = f'cuda:{pynccl_comm.rank}' + + num_elems = 1000 + tensor = torch.arange(num_elems, dtype=torch.float32, + device=device) + rank * num_elems + result = torch.zeros(num_elems * world_size, + dtype=torch.float32, + device=device) + + expected = torch.cat([ + torch.arange(num_elems, dtype=torch.float32) + r * num_elems + for r in range(world_size) + ]).to(device) + + with pynccl_comm.change_state(enable=True): + pynccl_comm.all_gather(result, tensor) + torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8) + + +@pytest.mark.skipif(torch.cuda.device_count() < 2, + reason="Need at least 2 GPUs to run the test.") +def test_pynccl_all_gather(): + distributed_run(all_gather_worker_fn, 2) + + +@worker_fn_wrapper +def reduce_scatter_worker_fn(): + pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group, + device=get_world_group().device) + + rank = pynccl_comm.rank + world_size = pynccl_comm.world_size + device = f'cuda:{pynccl_comm.rank}' + + num_elems = 1000 + tensor = torch.arange(num_elems, dtype=torch.float32, + device=device) + rank * num_elems + assert (num_elems % world_size == 0) + result = torch.zeros(num_elems // world_size, + dtype=torch.float32, + device=device) + + # Calculate expected result for this rank's chunk + scattered_size = num_elems // world_size + all_tensors = [ + torch.arange(num_elems, dtype=torch.float32) + r * num_elems + for r in range(world_size) + ] + expected = sum(tensor[rank * scattered_size:(rank + 1) * scattered_size] + for tensor in all_tensors).to(device) + + with pynccl_comm.change_state(enable=True): + pynccl_comm.reduce_scatter(result, tensor) + torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8) + + +@pytest.mark.skipif(torch.cuda.device_count() < 2, + reason="Need at least 2 GPUs to run the test.") +def test_pynccl_reduce_scatter(): + distributed_run(reduce_scatter_worker_fn, 2) + + @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Need at least 2 GPUs to run the test.") def test_pynccl_with_cudagraph(): diff --git a/vllm/distributed/device_communicators/pynccl.py b/vllm/distributed/device_communicators/pynccl.py index 7c6f48e88637b..7411304eb18fa 100644 --- a/vllm/distributed/device_communicators/pynccl.py +++ b/vllm/distributed/device_communicators/pynccl.py @@ -131,6 +131,48 @@ def all_reduce(self, ncclRedOpTypeEnum.from_torch(op), self.comm, cudaStream_t(stream.cuda_stream)) + def all_gather(self, + output_tensor: torch.Tensor, + input_tensor: torch.Tensor, + stream=None): + if self.disabled: + return + # nccl communicator created on a specific device + # will only work on tensors on the same device + # otherwise it will cause "illegal memory access" + assert input_tensor.device == self.device, ( + f"this nccl communicator is created to work on {self.device}, " + f"but the input tensor is on {input_tensor.device}") + if stream is None: + stream = self.stream + self.nccl.ncclAllGather( + buffer_type(input_tensor.data_ptr()), + buffer_type(output_tensor.data_ptr()), input_tensor.numel(), + ncclDataTypeEnum.from_torch(input_tensor.dtype), self.comm, + cudaStream_t(stream.cuda_stream)) + + def reduce_scatter(self, + output_tensor: torch.Tensor, + input_tensor: torch.Tensor, + op: ReduceOp = ReduceOp.SUM, + stream=None): + if self.disabled: + return + # nccl communicator created on a specific device + # will only work on tensors on the same device + # otherwise it will cause "illegal memory access" + assert input_tensor.device == self.device, ( + f"this nccl communicator is created to work on {self.device}, " + f"but the input tensor is on {input_tensor.device}") + if stream is None: + stream = self.stream + self.nccl.ncclReduceScatter( + buffer_type(input_tensor.data_ptr()), + buffer_type(output_tensor.data_ptr()), output_tensor.numel(), + ncclDataTypeEnum.from_torch(input_tensor.dtype), + ncclRedOpTypeEnum.from_torch(op), self.comm, + cudaStream_t(stream.cuda_stream)) + def send(self, tensor: torch.Tensor, dst: int, stream=None): if self.disabled: return diff --git a/vllm/distributed/device_communicators/pynccl_wrapper.py b/vllm/distributed/device_communicators/pynccl_wrapper.py index 7619c98f22148..ff88f72470b27 100644 --- a/vllm/distributed/device_communicators/pynccl_wrapper.py +++ b/vllm/distributed/device_communicators/pynccl_wrapper.py @@ -151,6 +151,28 @@ class NCCLLibrary: ncclRedOp_t, ncclComm_t, cudaStream_t ]), + # ncclResult_t ncclAllGather( + # const void* sendbuff, void* recvbuff, size_t count, + # ncclDataType_t datatype, ncclComm_t comm, + # cudaStream_t stream); + # note that cudaStream_t is a pointer type, so the last argument + # is a pointer + Function("ncclAllGather", ncclResult_t, [ + buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t, + ncclComm_t, cudaStream_t + ]), + + # ncclResult_t ncclReduceScatter( + # const void* sendbuff, void* recvbuff, size_t count, + # ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm, + # cudaStream_t stream); + # note that cudaStream_t is a pointer type, so the last argument + # is a pointer + Function("ncclReduceScatter", ncclResult_t, [ + buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t, + ncclRedOp_t, ncclComm_t, cudaStream_t + ]), + # ncclResult_t ncclSend( # const void* sendbuff, size_t count, ncclDataType_t datatype, # int dest, ncclComm_t comm, cudaStream_t stream); @@ -258,6 +280,28 @@ def ncclAllReduce(self, sendbuff: buffer_type, recvbuff: buffer_type, datatype, op, comm, stream)) + def ncclReduceScatter(self, sendbuff: buffer_type, recvbuff: buffer_type, + count: int, datatype: int, op: int, comm: ncclComm_t, + stream: cudaStream_t) -> None: + # `datatype` actually should be `ncclDataType_t` + # and `op` should be `ncclRedOp_t` + # both are aliases of `ctypes.c_int` + # when we pass int to a function, it will be converted to `ctypes.c_int` + # by ctypes automatically + self.NCCL_CHECK(self._funcs["ncclReduceScatter"](sendbuff, recvbuff, + count, datatype, op, + comm, stream)) + + def ncclAllGather(self, sendbuff: buffer_type, recvbuff: buffer_type, + count: int, datatype: int, comm: ncclComm_t, + stream: cudaStream_t) -> None: + # `datatype` actually should be `ncclDataType_t` + # which is an aliases of `ctypes.c_int` + # when we pass int to a function, it will be converted to `ctypes.c_int` + # by ctypes automatically + self.NCCL_CHECK(self._funcs["ncclAllGather"](sendbuff, recvbuff, count, + datatype, comm, stream)) + def ncclSend(self, sendbuff: buffer_type, count: int, datatype: int, dest: int, comm: ncclComm_t, stream: cudaStream_t) -> None: self.NCCL_CHECK(self._funcs["ncclSend"](sendbuff, count, datatype,