Skip to content

Commit

Permalink
[Misc] Add pynccl wrappers for all_gather and reduce_scatter (vllm-pr…
Browse files Browse the repository at this point in the history
…oject#9432)

Signed-off-by: Maxime Fournioux <[email protected]>
  • Loading branch information
tlrmchlsmth authored and mfournioux committed Nov 28, 2024
1 parent 18b6856 commit 5fb7c3e
Show file tree
Hide file tree
Showing 3 changed files with 155 additions and 0 deletions.
69 changes: 69 additions & 0 deletions tests/distributed/test_pynccl.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,75 @@ def worker_fn_with_cudagraph():
assert a.mean().cpu().item() == pynccl_comm.world_size**1


@worker_fn_wrapper
def all_gather_worker_fn():
pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group,
device=get_world_group().device)

rank = pynccl_comm.rank
world_size = pynccl_comm.world_size
device = f'cuda:{pynccl_comm.rank}'

num_elems = 1000
tensor = torch.arange(num_elems, dtype=torch.float32,
device=device) + rank * num_elems
result = torch.zeros(num_elems * world_size,
dtype=torch.float32,
device=device)

expected = torch.cat([
torch.arange(num_elems, dtype=torch.float32) + r * num_elems
for r in range(world_size)
]).to(device)

with pynccl_comm.change_state(enable=True):
pynccl_comm.all_gather(result, tensor)
torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8)


@pytest.mark.skipif(torch.cuda.device_count() < 2,
reason="Need at least 2 GPUs to run the test.")
def test_pynccl_all_gather():
distributed_run(all_gather_worker_fn, 2)


@worker_fn_wrapper
def reduce_scatter_worker_fn():
pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group,
device=get_world_group().device)

rank = pynccl_comm.rank
world_size = pynccl_comm.world_size
device = f'cuda:{pynccl_comm.rank}'

num_elems = 1000
tensor = torch.arange(num_elems, dtype=torch.float32,
device=device) + rank * num_elems
assert (num_elems % world_size == 0)
result = torch.zeros(num_elems // world_size,
dtype=torch.float32,
device=device)

# Calculate expected result for this rank's chunk
scattered_size = num_elems // world_size
all_tensors = [
torch.arange(num_elems, dtype=torch.float32) + r * num_elems
for r in range(world_size)
]
expected = sum(tensor[rank * scattered_size:(rank + 1) * scattered_size]
for tensor in all_tensors).to(device)

with pynccl_comm.change_state(enable=True):
pynccl_comm.reduce_scatter(result, tensor)
torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8)


@pytest.mark.skipif(torch.cuda.device_count() < 2,
reason="Need at least 2 GPUs to run the test.")
def test_pynccl_reduce_scatter():
distributed_run(reduce_scatter_worker_fn, 2)


@pytest.mark.skipif(torch.cuda.device_count() < 2,
reason="Need at least 2 GPUs to run the test.")
def test_pynccl_with_cudagraph():
Expand Down
42 changes: 42 additions & 0 deletions vllm/distributed/device_communicators/pynccl.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,48 @@ def all_reduce(self,
ncclRedOpTypeEnum.from_torch(op), self.comm,
cudaStream_t(stream.cuda_stream))

def all_gather(self,
output_tensor: torch.Tensor,
input_tensor: torch.Tensor,
stream=None):
if self.disabled:
return
# nccl communicator created on a specific device
# will only work on tensors on the same device
# otherwise it will cause "illegal memory access"
assert input_tensor.device == self.device, (
f"this nccl communicator is created to work on {self.device}, "
f"but the input tensor is on {input_tensor.device}")
if stream is None:
stream = self.stream
self.nccl.ncclAllGather(
buffer_type(input_tensor.data_ptr()),
buffer_type(output_tensor.data_ptr()), input_tensor.numel(),
ncclDataTypeEnum.from_torch(input_tensor.dtype), self.comm,
cudaStream_t(stream.cuda_stream))

def reduce_scatter(self,
output_tensor: torch.Tensor,
input_tensor: torch.Tensor,
op: ReduceOp = ReduceOp.SUM,
stream=None):
if self.disabled:
return
# nccl communicator created on a specific device
# will only work on tensors on the same device
# otherwise it will cause "illegal memory access"
assert input_tensor.device == self.device, (
f"this nccl communicator is created to work on {self.device}, "
f"but the input tensor is on {input_tensor.device}")
if stream is None:
stream = self.stream
self.nccl.ncclReduceScatter(
buffer_type(input_tensor.data_ptr()),
buffer_type(output_tensor.data_ptr()), output_tensor.numel(),
ncclDataTypeEnum.from_torch(input_tensor.dtype),
ncclRedOpTypeEnum.from_torch(op), self.comm,
cudaStream_t(stream.cuda_stream))

def send(self, tensor: torch.Tensor, dst: int, stream=None):
if self.disabled:
return
Expand Down
44 changes: 44 additions & 0 deletions vllm/distributed/device_communicators/pynccl_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,28 @@ class NCCLLibrary:
ncclRedOp_t, ncclComm_t, cudaStream_t
]),

# ncclResult_t ncclAllGather(
# const void* sendbuff, void* recvbuff, size_t count,
# ncclDataType_t datatype, ncclComm_t comm,
# cudaStream_t stream);
# note that cudaStream_t is a pointer type, so the last argument
# is a pointer
Function("ncclAllGather", ncclResult_t, [
buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t,
ncclComm_t, cudaStream_t
]),

# ncclResult_t ncclReduceScatter(
# const void* sendbuff, void* recvbuff, size_t count,
# ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm,
# cudaStream_t stream);
# note that cudaStream_t is a pointer type, so the last argument
# is a pointer
Function("ncclReduceScatter", ncclResult_t, [
buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t,
ncclRedOp_t, ncclComm_t, cudaStream_t
]),

# ncclResult_t ncclSend(
# const void* sendbuff, size_t count, ncclDataType_t datatype,
# int dest, ncclComm_t comm, cudaStream_t stream);
Expand Down Expand Up @@ -258,6 +280,28 @@ def ncclAllReduce(self, sendbuff: buffer_type, recvbuff: buffer_type,
datatype, op, comm,
stream))

def ncclReduceScatter(self, sendbuff: buffer_type, recvbuff: buffer_type,
count: int, datatype: int, op: int, comm: ncclComm_t,
stream: cudaStream_t) -> None:
# `datatype` actually should be `ncclDataType_t`
# and `op` should be `ncclRedOp_t`
# both are aliases of `ctypes.c_int`
# when we pass int to a function, it will be converted to `ctypes.c_int`
# by ctypes automatically
self.NCCL_CHECK(self._funcs["ncclReduceScatter"](sendbuff, recvbuff,
count, datatype, op,
comm, stream))

def ncclAllGather(self, sendbuff: buffer_type, recvbuff: buffer_type,
count: int, datatype: int, comm: ncclComm_t,
stream: cudaStream_t) -> None:
# `datatype` actually should be `ncclDataType_t`
# which is an aliases of `ctypes.c_int`
# when we pass int to a function, it will be converted to `ctypes.c_int`
# by ctypes automatically
self.NCCL_CHECK(self._funcs["ncclAllGather"](sendbuff, recvbuff, count,
datatype, comm, stream))

def ncclSend(self, sendbuff: buffer_type, count: int, datatype: int,
dest: int, comm: ncclComm_t, stream: cudaStream_t) -> None:
self.NCCL_CHECK(self._funcs["ncclSend"](sendbuff, count, datatype,
Expand Down

0 comments on commit 5fb7c3e

Please sign in to comment.