Skip to content

Commit

Permalink
fix tests
Browse files (browse the repository at this point in the history)
Signed-off-by: youkaichao <[email protected]>
  • Loading branch information
youkaichao committed Dec 3, 2024
1 parent 5db579c commit 6873cdb
Showing 1 changed file with 15 additions and 10 deletions.
25 changes: 15 additions & 10 deletions tests/distributed/test_pynccl.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -301,20 +301,25 @@ def test_pynccl_broadcast():

# Worker body for the pynccl broadcast test: every rank takes one turn as the
# broadcast source, and each rank checks the mean of the buffer it holds.
# NOTE(review): this span is a rendered diff whose +/- markers and indentation
# were lost, so pre-commit and post-commit lines appear together. The
# "removed"/"added" labels below are inferred from the hunk header
# (-301,20 +301,25) and the 15-additions / 10-deletions summary -- confirm
# against the actual commit before relying on them.
@worker_fn_wrapper
def broadcast_worker_fn():
# Test broadcast for every root rank.
# Essentially this is an all-gather operation.
# Build the communicator from the world group's CPU group and device.
pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group,
device=get_world_group().device)
# (apparently removed) Send buffer: all ones scaled by this rank, placed on
# CUDA device index `pynccl_comm.rank`, so every element equals the rank.
tensor = torch.ones(16, 1024, 1024, dtype=torch.float32).cuda(
pynccl_comm.rank) * pynccl_comm.rank

# (apparently removed) A single uninitialized receive buffer, reused for every
# source rank in the loop below.
recv = torch.empty(16,
1024,
1024,
dtype=torch.float32,
device=pynccl_comm.device)
# (apparently added) One uninitialized buffer per rank in the world, indexed
# by the rank that will act as broadcast source.
recv_tensors = [
torch.empty(16,
1024,
1024,
dtype=torch.float32,
device=pynccl_comm.device)
for i in range(pynccl_comm.world_size)
]
# (apparently added) This rank's own slot is replaced by its send data: all
# ones scaled by the rank.
recv_tensors[pynccl_comm.rank] = torch.ones(
16, 1024, 1024, dtype=torch.float32,
device=pynccl_comm.device) * pynccl_comm.rank

# Each rank i takes one turn as the broadcast source.
for i in range(pynccl_comm.world_size):
# (apparently removed) The source rank passes `tensor`, all other ranks pass
# `recv`, but the mean is always read from `recv`.
# NOTE(review): assuming broadcast only writes into the tensor it is given,
# the source rank never fills `recv` in its own iteration -- presumably the
# failure this commit fixes; confirm.
pynccl_comm.broadcast(recv if i != pynccl_comm.rank else tensor, src=i)
result = recv.mean().cpu().item()
# (apparently added) Broadcast slot i from rank i and read the mean of that
# same slot, so the source rank checks the data it sent.
pynccl_comm.broadcast(recv_tensors[i], src=i)
result = recv_tensors[i].mean().cpu().item()
# The checked buffer is expected to hold the value i in every element, so its
# mean is compared against the source rank.
assert result == i


Expand Down

0 comments on commit 6873cdb

Please sign in to comment.