From 21fe7b481a3a84dc9ebe2497ec89a17002ad52c5 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 2 Dec 2024 20:53:23 -0800 Subject: [PATCH] [core][distributed] add pynccl broadcast (#10843) Signed-off-by: youkaichao --- tests/distributed/test_pynccl.py | 45 ++++++++++++++++++- .../device_communicators/pynccl.py | 19 ++++++++ .../device_communicators/pynccl_wrapper.py | 16 +++++++ 3 files changed, 78 insertions(+), 2 deletions(-) diff --git a/tests/distributed/test_pynccl.py b/tests/distributed/test_pynccl.py index fb24d6bc2c100..4e27babf12cc3 100644 --- a/tests/distributed/test_pynccl.py +++ b/tests/distributed/test_pynccl.py @@ -61,6 +61,7 @@ def worker_fn(): dtype=torch.float32).cuda(pynccl_comm.rank) with pynccl_comm.change_state(enable=True): tensor = pynccl_comm.all_reduce(tensor) + torch.cuda.synchronize() result = tensor.mean().cpu().item() assert result == pynccl_comm.world_size @@ -86,10 +87,12 @@ def multiple_allreduce_worker_fn(): if torch.distributed.get_rank() in [0, 1]: tensor = pynccl_comm.all_reduce(tensor) tensor = pynccl_comm.all_reduce(tensor) + torch.cuda.synchronize() result = tensor.mean().cpu().item() assert result == 4 else: tensor = pynccl_comm.all_reduce(tensor) + torch.cuda.synchronize() result = tensor.mean().cpu().item() assert result == 2 @@ -112,10 +115,12 @@ def multiple_allreduce_with_vllm_worker_fn(): if torch.distributed.get_rank() in [0, 1]: tensor = tensor_model_parallel_all_reduce(tensor) tensor = tensor_model_parallel_all_reduce(tensor) + torch.cuda.synchronize() result = tensor.mean().cpu().item() assert result == 4 else: tensor = tensor_model_parallel_all_reduce(tensor) + torch.cuda.synchronize() result = tensor.mean().cpu().item() assert result == 2 @@ -141,9 +146,9 @@ def worker_fn_with_cudagraph(): graph, stream=pynccl_comm.stream), pynccl_comm.change_state( enable=True): a_out = pynccl_comm.all_reduce(a) - pynccl_comm.stream.synchronize() + torch.cuda.synchronize() graph.replay() - pynccl_comm.stream.synchronize() + torch.cuda.synchronize() assert a_out.mean().cpu().item() == pynccl_comm.world_size**1 @@ -170,6 +175,7 @@ def all_gather_worker_fn(): with pynccl_comm.change_state(enable=True): pynccl_comm.all_gather(result, tensor) + torch.cuda.synchronize() torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8) @@ -207,6 +213,7 @@ def reduce_scatter_worker_fn(): with pynccl_comm.change_state(enable=True): pynccl_comm.reduce_scatter(result, tensor) + torch.cuda.synchronize() torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8) @@ -241,6 +248,7 @@ def send_recv_worker_fn(): pynccl_comm.recv(tensor, src=(pynccl_comm.rank - 1) % pynccl_comm.world_size) + torch.cuda.synchronize() result = tensor.mean().cpu().item() assert result == 1 @@ -280,6 +288,7 @@ def multiple_send_recv_worker_fn(): pynccl_comm.recv(tensor, src=(pynccl_comm.rank - 1) % pynccl_comm.world_size) + torch.cuda.synchronize() result = tensor.mean().cpu().item() if torch.distributed.get_rank() in [0, 2]: assert result == 1 @@ -293,6 +302,38 @@ def test_pynccl_multiple_send_recv(): distributed_run(multiple_send_recv_worker_fn, 4) +@pytest.mark.skipif(torch.cuda.device_count() < 4, + reason="Need at least 4 GPUs to run the test.") +def test_pynccl_broadcast(): + distributed_run(broadcast_worker_fn, 4) + + +@worker_fn_wrapper +def broadcast_worker_fn(): + # Test broadcast for every root rank. + # Essentially this is an all-gather operation. + pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group, + device=get_world_group().device) + recv_tensors = [ + torch.empty(16, + 1024, + 1024, + dtype=torch.float32, + device=pynccl_comm.device) + for i in range(pynccl_comm.world_size) + ] + recv_tensors[pynccl_comm.rank] = torch.ones( + 16, 1024, 1024, dtype=torch.float32, + device=pynccl_comm.device) * pynccl_comm.rank + + for i in range(pynccl_comm.world_size): + pynccl_comm.broadcast(recv_tensors[i], src=i) + # the broadcast op might be launched in a different stream + # need to synchronize to make sure the tensor is ready + torch.cuda.synchronize() + assert torch.all(recv_tensors[i] == i).cpu().item() + + def test_ncclGetUniqueId(): lib = NCCLLibrary() unique_id = lib.ncclGetUniqueId() diff --git a/vllm/distributed/device_communicators/pynccl.py b/vllm/distributed/device_communicators/pynccl.py index d4e3f81747038..a6800f93f167b 100644 --- a/vllm/distributed/device_communicators/pynccl.py +++ b/vllm/distributed/device_communicators/pynccl.py @@ -197,6 +197,25 @@ def recv(self, tensor: torch.Tensor, src: int, stream=None): ncclDataTypeEnum.from_torch(tensor.dtype), src, self.comm, cudaStream_t(stream.cuda_stream)) + def broadcast(self, tensor: torch.Tensor, src: int, stream=None): + if self.disabled: + return + assert tensor.device == self.device, ( + f"this nccl communicator is created to work on {self.device}, " + f"but the input tensor is on {tensor.device}") + if stream is None: + stream = self.stream + if src == self.rank: + sendbuff = buffer_type(tensor.data_ptr()) + # NCCL requires the sender also to have a receive buffer + recvbuff = buffer_type(tensor.data_ptr()) + else: + sendbuff = buffer_type() + recvbuff = buffer_type(tensor.data_ptr()) + self.nccl.ncclBroadcast(sendbuff, recvbuff, tensor.numel(), + ncclDataTypeEnum.from_torch(tensor.dtype), src, + self.comm, cudaStream_t(stream.cuda_stream)) + @contextmanager def change_state(self, enable: Optional[bool] = None, diff --git a/vllm/distributed/device_communicators/pynccl_wrapper.py b/vllm/distributed/device_communicators/pynccl_wrapper.py index ff88f72470b27..7dea61b6a09f1 100644 --- a/vllm/distributed/device_communicators/pynccl_wrapper.py +++ b/vllm/distributed/device_communicators/pynccl_wrapper.py @@ -189,6 +189,15 @@ class NCCLLibrary: ncclComm_t, cudaStream_t ]), + # ncclResult_t ncclBroadcast( + # const void* sendbuff, void* recvbuff, size_t count, + # ncclDataType_t datatype, int root, ncclComm_t comm, + # cudaStream_t stream); + Function("ncclBroadcast", ncclResult_t, [ + buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t, + ctypes.c_int, ncclComm_t, cudaStream_t + ]), + # be cautious! this is a collective call, it will block until all # processes in the communicator have called this function. # because Python object destruction can happen in random order, @@ -312,6 +321,13 @@ def ncclRecv(self, recvbuff: buffer_type, count: int, datatype: int, self.NCCL_CHECK(self._funcs["ncclRecv"](recvbuff, count, datatype, src, comm, stream)) + def ncclBroadcast(self, sendbuff: buffer_type, recvbuff: buffer_type, + count: int, datatype: int, root: int, comm: ncclComm_t, + stream: cudaStream_t) -> None: + self.NCCL_CHECK(self._funcs["ncclBroadcast"](sendbuff, recvbuff, count, + datatype, root, comm, + stream)) + def ncclCommDestroy(self, comm: ncclComm_t) -> None: self.NCCL_CHECK(self._funcs["ncclCommDestroy"](comm))