[core][distributed] add pynccl broadcast (#10843)
Signed-off-by: youkaichao <[email protected]>
youkaichao authored Dec 3, 2024
1 parent a4cf256 commit 21fe7b4
Showing 3 changed files with 78 additions and 2 deletions.
45 changes: 43 additions & 2 deletions tests/distributed/test_pynccl.py
@@ -61,6 +61,7 @@ def worker_fn():
dtype=torch.float32).cuda(pynccl_comm.rank)
with pynccl_comm.change_state(enable=True):
tensor = pynccl_comm.all_reduce(tensor)
torch.cuda.synchronize()
result = tensor.mean().cpu().item()
assert result == pynccl_comm.world_size

@@ -86,10 +87,12 @@ def multiple_allreduce_worker_fn():
if torch.distributed.get_rank() in [0, 1]:
tensor = pynccl_comm.all_reduce(tensor)
tensor = pynccl_comm.all_reduce(tensor)
torch.cuda.synchronize()
result = tensor.mean().cpu().item()
assert result == 4
else:
tensor = pynccl_comm.all_reduce(tensor)
torch.cuda.synchronize()
result = tensor.mean().cpu().item()
assert result == 2

@@ -112,10 +115,12 @@ def multiple_allreduce_with_vllm_worker_fn():
if torch.distributed.get_rank() in [0, 1]:
tensor = tensor_model_parallel_all_reduce(tensor)
tensor = tensor_model_parallel_all_reduce(tensor)
torch.cuda.synchronize()
result = tensor.mean().cpu().item()
assert result == 4
else:
tensor = tensor_model_parallel_all_reduce(tensor)
torch.cuda.synchronize()
result = tensor.mean().cpu().item()
assert result == 2

@@ -141,9 +146,9 @@ def worker_fn_with_cudagraph():
graph, stream=pynccl_comm.stream), pynccl_comm.change_state(
enable=True):
a_out = pynccl_comm.all_reduce(a)
pynccl_comm.stream.synchronize()
torch.cuda.synchronize()
graph.replay()
pynccl_comm.stream.synchronize()
torch.cuda.synchronize()
assert a_out.mean().cpu().item() == pynccl_comm.world_size**1


@@ -170,6 +175,7 @@ def all_gather_worker_fn():

with pynccl_comm.change_state(enable=True):
pynccl_comm.all_gather(result, tensor)
torch.cuda.synchronize()
torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8)


@@ -207,6 +213,7 @@ def reduce_scatter_worker_fn():

with pynccl_comm.change_state(enable=True):
pynccl_comm.reduce_scatter(result, tensor)
torch.cuda.synchronize()
torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8)


@@ -241,6 +248,7 @@ def send_recv_worker_fn():
pynccl_comm.recv(tensor,
src=(pynccl_comm.rank - 1) %
pynccl_comm.world_size)
torch.cuda.synchronize()
result = tensor.mean().cpu().item()
assert result == 1

@@ -280,6 +288,7 @@ def multiple_send_recv_worker_fn():
pynccl_comm.recv(tensor,
src=(pynccl_comm.rank - 1) %
pynccl_comm.world_size)
torch.cuda.synchronize()
result = tensor.mean().cpu().item()
if torch.distributed.get_rank() in [0, 2]:
assert result == 1
@@ -293,6 +302,38 @@ def test_pynccl_multiple_send_recv():
distributed_run(multiple_send_recv_worker_fn, 4)


@pytest.mark.skipif(torch.cuda.device_count() < 4,
                    reason="Need at least 4 GPUs to run the test.")
def test_pynccl_broadcast():
    distributed_run(broadcast_worker_fn, 4)


@worker_fn_wrapper
def broadcast_worker_fn():
    # Test broadcast for every root rank.
    # Essentially this is an all-gather operation.
    pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group,
                                     device=get_world_group().device)
    recv_tensors = [
        torch.empty(16,
                    1024,
                    1024,
                    dtype=torch.float32,
                    device=pynccl_comm.device)
        for i in range(pynccl_comm.world_size)
    ]
    recv_tensors[pynccl_comm.rank] = torch.ones(
        16, 1024, 1024, dtype=torch.float32,
        device=pynccl_comm.device) * pynccl_comm.rank

    for i in range(pynccl_comm.world_size):
        pynccl_comm.broadcast(recv_tensors[i], src=i)
        # the broadcast op might be launched in a different stream
        # need to synchronize to make sure the tensor is ready
        torch.cuda.synchronize()
        assert torch.all(recv_tensors[i] == i).cpu().item()


def test_ncclGetUniqueId():
lib = NCCLLibrary()
unique_id = lib.ncclGetUniqueId()
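For reference, the broadcast-as-all-gather pattern that broadcast_worker_fn exercises can also be sketched with the stock torch.distributed API. This is a minimal sketch, not part of the commit: it assumes a torchrun launch with one process per GPU and the NCCL backend, and the tensor shapes simply mirror the test above.

import torch
import torch.distributed as dist


def broadcast_as_all_gather() -> None:
    # torchrun supplies RANK/WORLD_SIZE/MASTER_ADDR via environment variables.
    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)

    # One receive buffer per rank; each rank pre-fills only its own slot.
    recv = [
        torch.empty(16, 1024, 1024, dtype=torch.float32, device=device)
        for _ in range(world_size)
    ]
    recv[rank].fill_(float(rank))

    # Broadcasting from every root in turn leaves each rank holding every
    # other rank's tensor, i.e. an all-gather.
    for src in range(world_size):
        dist.broadcast(recv[src], src=src)
    torch.cuda.synchronize()

    for src in range(world_size):
        assert torch.all(recv[src] == src)

    dist.destroy_process_group()


if __name__ == "__main__":
    broadcast_as_all_gather()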
19 changes: 19 additions & 0 deletions vllm/distributed/device_communicators/pynccl.py
@@ -197,6 +197,25 @@ def recv(self, tensor: torch.Tensor, src: int, stream=None):
ncclDataTypeEnum.from_torch(tensor.dtype), src,
self.comm, cudaStream_t(stream.cuda_stream))

    def broadcast(self, tensor: torch.Tensor, src: int, stream=None):
        if self.disabled:
            return
        assert tensor.device == self.device, (
            f"this nccl communicator is created to work on {self.device}, "
            f"but the input tensor is on {tensor.device}")
        if stream is None:
            stream = self.stream
        if src == self.rank:
            sendbuff = buffer_type(tensor.data_ptr())
            # NCCL requires the sender also to have a receive buffer
            recvbuff = buffer_type(tensor.data_ptr())
        else:
            sendbuff = buffer_type()
            recvbuff = buffer_type(tensor.data_ptr())
        self.nccl.ncclBroadcast(sendbuff, recvbuff, tensor.numel(),
                                ncclDataTypeEnum.from_torch(tensor.dtype), src,
                                self.comm, cudaStream_t(stream.cuda_stream))

@contextmanager
def change_state(self,
enable: Optional[bool] = None,
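A minimal usage sketch of the new broadcast() method, mirroring how the tests construct the communicator. The import paths and the rank-0 example value are assumptions for illustration, not part of this commit, and the code must run inside an initialized vLLM distributed worker (as in the test's worker_fn_wrapper).

import torch

from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
from vllm.distributed.parallel_state import get_world_group


def broadcast_from_rank_zero() -> None:
    comm = PyNcclCommunicator(get_world_group().cpu_group,
                              device=get_world_group().device)
    tensor = torch.zeros(1024, dtype=torch.float32, device=comm.device)
    if comm.rank == 0:
        tensor.fill_(42.0)
    # Every rank passes the same tensor; broadcast() picks the send/recv
    # buffers internally based on whether this rank is the root.
    comm.broadcast(tensor, src=0)
    # The collective is issued on the communicator's stream, so synchronize
    # before reading the result on the default stream.
    torch.cuda.synchronize()
    assert torch.all(tensor == 42.0)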
16 changes: 16 additions & 0 deletions vllm/distributed/device_communicators/pynccl_wrapper.py
@@ -189,6 +189,15 @@ class NCCLLibrary:
ncclComm_t, cudaStream_t
]),

        # ncclResult_t ncclBroadcast(
        #   const void* sendbuff, void* recvbuff, size_t count,
        #   ncclDataType_t datatype, int root, ncclComm_t comm,
        #   cudaStream_t stream);
        Function("ncclBroadcast", ncclResult_t, [
            buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t,
            ctypes.c_int, ncclComm_t, cudaStream_t
        ]),

# be cautious! this is a collective call, it will block until all
# processes in the communicator have called this function.
# because Python object destruction can happen in random order,
@@ -312,6 +321,13 @@ def ncclRecv(self, recvbuff: buffer_type, count: int, datatype: int,
self.NCCL_CHECK(self._funcs["ncclRecv"](recvbuff, count, datatype, src,
comm, stream))

    def ncclBroadcast(self, sendbuff: buffer_type, recvbuff: buffer_type,
                      count: int, datatype: int, root: int, comm: ncclComm_t,
                      stream: cudaStream_t) -> None:
        self.NCCL_CHECK(self._funcs["ncclBroadcast"](sendbuff, recvbuff, count,
                                                     datatype, root, comm,
                                                     stream))

def ncclCommDestroy(self, comm: ncclComm_t) -> None:
self.NCCL_CHECK(self._funcs["ncclCommDestroy"](comm))

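As a standalone illustration of the ctypes pattern behind these two additions, the sketch below declares ncclBroadcast on a directly loaded NCCL handle and checks the returned ncclResult_t. The library name, type aliases, and error check are simplified assumptions for illustration; they are not vLLM's actual NCCLLibrary internals.

import ctypes

# Opaque NCCL/CUDA handles are plain pointers at the ABI level.
ncclResult_t = ctypes.c_int
ncclComm_t = ctypes.c_void_p
cudaStream_t = ctypes.c_void_p
buffer_type = ctypes.c_void_p

lib = ctypes.CDLL("libnccl.so.2")  # assumes NCCL 2.x is on the loader path

# ncclResult_t ncclBroadcast(const void* sendbuff, void* recvbuff, size_t count,
#                            ncclDataType_t datatype, int root, ncclComm_t comm,
#                            cudaStream_t stream);
lib.ncclBroadcast.restype = ncclResult_t
lib.ncclBroadcast.argtypes = [
    buffer_type, buffer_type, ctypes.c_size_t,
    ctypes.c_int,  # ncclDataType_t is an enum, passed as a C int
    ctypes.c_int, ncclComm_t, cudaStream_t
]


def nccl_check(result: int) -> None:
    # ncclSuccess == 0; any other value is an NCCL error code.
    if result != 0:
        raise RuntimeError(f"NCCL error {result}")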