Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Misc] Add pynccl wrappers for all_gather and reduce_scatter #9432

Merged
merged 1 commit into from
Nov 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 69 additions & 0 deletions tests/distributed/test_pynccl.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,75 @@ def worker_fn_with_cudagraph():
assert a.mean().cpu().item() == pynccl_comm.world_size**1


@worker_fn_wrapper
def all_gather_worker_fn():
pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group,
device=get_world_group().device)

rank = pynccl_comm.rank
world_size = pynccl_comm.world_size
device = f'cuda:{pynccl_comm.rank}'

num_elems = 1000
tensor = torch.arange(num_elems, dtype=torch.float32,
device=device) + rank * num_elems
result = torch.zeros(num_elems * world_size,
dtype=torch.float32,
device=device)

expected = torch.cat([
torch.arange(num_elems, dtype=torch.float32) + r * num_elems
for r in range(world_size)
]).to(device)

with pynccl_comm.change_state(enable=True):
pynccl_comm.all_gather(result, tensor)
torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8)


@pytest.mark.skipif(torch.cuda.device_count() < 2,
reason="Need at least 2 GPUs to run the test.")
def test_pynccl_all_gather():
distributed_run(all_gather_worker_fn, 2)


@worker_fn_wrapper
def reduce_scatter_worker_fn():
pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group,
device=get_world_group().device)

rank = pynccl_comm.rank
world_size = pynccl_comm.world_size
device = f'cuda:{pynccl_comm.rank}'

num_elems = 1000
tensor = torch.arange(num_elems, dtype=torch.float32,
device=device) + rank * num_elems
assert (num_elems % world_size == 0)
result = torch.zeros(num_elems // world_size,
dtype=torch.float32,
device=device)

# Calculate expected result for this rank's chunk
scattered_size = num_elems // world_size
all_tensors = [
torch.arange(num_elems, dtype=torch.float32) + r * num_elems
for r in range(world_size)
]
expected = sum(tensor[rank * scattered_size:(rank + 1) * scattered_size]
for tensor in all_tensors).to(device)

with pynccl_comm.change_state(enable=True):
pynccl_comm.reduce_scatter(result, tensor)
torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8)


@pytest.mark.skipif(torch.cuda.device_count() < 2,
reason="Need at least 2 GPUs to run the test.")
def test_pynccl_reduce_scatter():
distributed_run(reduce_scatter_worker_fn, 2)


@pytest.mark.skipif(torch.cuda.device_count() < 2,
reason="Need at least 2 GPUs to run the test.")
def test_pynccl_with_cudagraph():
Expand Down
42 changes: 42 additions & 0 deletions vllm/distributed/device_communicators/pynccl.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,48 @@ def all_reduce(self,
ncclRedOpTypeEnum.from_torch(op), self.comm,
cudaStream_t(stream.cuda_stream))

def all_gather(self,
output_tensor: torch.Tensor,
input_tensor: torch.Tensor,
stream=None):
if self.disabled:
return
# nccl communicator created on a specific device
# will only work on tensors on the same device
# otherwise it will cause "illegal memory access"
assert input_tensor.device == self.device, (
f"this nccl communicator is created to work on {self.device}, "
f"but the input tensor is on {input_tensor.device}")
if stream is None:
stream = self.stream
self.nccl.ncclAllGather(
buffer_type(input_tensor.data_ptr()),
buffer_type(output_tensor.data_ptr()), input_tensor.numel(),
ncclDataTypeEnum.from_torch(input_tensor.dtype), self.comm,
cudaStream_t(stream.cuda_stream))

def reduce_scatter(self,
output_tensor: torch.Tensor,
input_tensor: torch.Tensor,
op: ReduceOp = ReduceOp.SUM,
stream=None):
if self.disabled:
return
# nccl communicator created on a specific device
# will only work on tensors on the same device
# otherwise it will cause "illegal memory access"
assert input_tensor.device == self.device, (
f"this nccl communicator is created to work on {self.device}, "
f"but the input tensor is on {input_tensor.device}")
if stream is None:
stream = self.stream
self.nccl.ncclReduceScatter(
buffer_type(input_tensor.data_ptr()),
buffer_type(output_tensor.data_ptr()), output_tensor.numel(),
ncclDataTypeEnum.from_torch(input_tensor.dtype),
ncclRedOpTypeEnum.from_torch(op), self.comm,
cudaStream_t(stream.cuda_stream))

def send(self, tensor: torch.Tensor, dst: int, stream=None):
if self.disabled:
return
Expand Down
44 changes: 44 additions & 0 deletions vllm/distributed/device_communicators/pynccl_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,28 @@ class NCCLLibrary:
ncclRedOp_t, ncclComm_t, cudaStream_t
]),

# ncclResult_t ncclAllGather(
# const void* sendbuff, void* recvbuff, size_t count,
# ncclDataType_t datatype, ncclComm_t comm,
# cudaStream_t stream);
# note that cudaStream_t is a pointer type, so the last argument
# is a pointer
Function("ncclAllGather", ncclResult_t, [
buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t,
ncclComm_t, cudaStream_t
]),

# ncclResult_t ncclReduceScatter(
# const void* sendbuff, void* recvbuff, size_t count,
# ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm,
# cudaStream_t stream);
# note that cudaStream_t is a pointer type, so the last argument
# is a pointer
Function("ncclReduceScatter", ncclResult_t, [
buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t,
ncclRedOp_t, ncclComm_t, cudaStream_t
]),

# ncclResult_t ncclSend(
# const void* sendbuff, size_t count, ncclDataType_t datatype,
# int dest, ncclComm_t comm, cudaStream_t stream);
Expand Down Expand Up @@ -258,6 +280,28 @@ def ncclAllReduce(self, sendbuff: buffer_type, recvbuff: buffer_type,
datatype, op, comm,
stream))

def ncclReduceScatter(self, sendbuff: buffer_type, recvbuff: buffer_type,
count: int, datatype: int, op: int, comm: ncclComm_t,
stream: cudaStream_t) -> None:
# `datatype` actually should be `ncclDataType_t`
# and `op` should be `ncclRedOp_t`
# both are aliases of `ctypes.c_int`
# when we pass int to a function, it will be converted to `ctypes.c_int`
# by ctypes automatically
self.NCCL_CHECK(self._funcs["ncclReduceScatter"](sendbuff, recvbuff,
count, datatype, op,
comm, stream))

def ncclAllGather(self, sendbuff: buffer_type, recvbuff: buffer_type,
count: int, datatype: int, comm: ncclComm_t,
stream: cudaStream_t) -> None:
# `datatype` actually should be `ncclDataType_t`
# which is an aliases of `ctypes.c_int`
# when we pass int to a function, it will be converted to `ctypes.c_int`
# by ctypes automatically
self.NCCL_CHECK(self._funcs["ncclAllGather"](sendbuff, recvbuff, count,
datatype, comm, stream))

def ncclSend(self, sendbuff: buffer_type, count: int, datatype: int,
dest: int, comm: ncclComm_t, stream: cudaStream_t) -> None:
self.NCCL_CHECK(self._funcs["ncclSend"](sendbuff, count, datatype,
Expand Down