[core][distributed] add pynccl broadcast (#10843)
Signed-off-by: youkaichao <youkaichao@gmail.com>
Parent: a4cf256159
Commit: 21fe7b481a
Changes to the pynccl tests:

@@ -61,6 +61,7 @@ def worker_fn():
                        dtype=torch.float32).cuda(pynccl_comm.rank)
    with pynccl_comm.change_state(enable=True):
        tensor = pynccl_comm.all_reduce(tensor)
    torch.cuda.synchronize()
    result = tensor.mean().cpu().item()
    assert result == pynccl_comm.world_size

@@ -86,10 +87,12 @@ def multiple_allreduce_worker_fn():
    if torch.distributed.get_rank() in [0, 1]:
        tensor = pynccl_comm.all_reduce(tensor)
        tensor = pynccl_comm.all_reduce(tensor)
        torch.cuda.synchronize()
        result = tensor.mean().cpu().item()
        assert result == 4
    else:
        tensor = pynccl_comm.all_reduce(tensor)
        torch.cuda.synchronize()
        result = tensor.mean().cpu().item()
        assert result == 2

@@ -112,10 +115,12 @@ def multiple_allreduce_with_vllm_worker_fn():
    if torch.distributed.get_rank() in [0, 1]:
        tensor = tensor_model_parallel_all_reduce(tensor)
        tensor = tensor_model_parallel_all_reduce(tensor)
        torch.cuda.synchronize()
        result = tensor.mean().cpu().item()
        assert result == 4
    else:
        tensor = tensor_model_parallel_all_reduce(tensor)
        torch.cuda.synchronize()
        result = tensor.mean().cpu().item()
        assert result == 2

@@ -141,9 +146,9 @@ def worker_fn_with_cudagraph():
            graph, stream=pynccl_comm.stream), pynccl_comm.change_state(
                enable=True):
        a_out = pynccl_comm.all_reduce(a)
    pynccl_comm.stream.synchronize()
    torch.cuda.synchronize()
    graph.replay()
    pynccl_comm.stream.synchronize()
    torch.cuda.synchronize()
    assert a_out.mean().cpu().item() == pynccl_comm.world_size**1

@@ -170,6 +175,7 @@ def all_gather_worker_fn():

    with pynccl_comm.change_state(enable=True):
        pynccl_comm.all_gather(result, tensor)
    torch.cuda.synchronize()
    torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8)

@@ -207,6 +213,7 @@ def reduce_scatter_worker_fn():

    with pynccl_comm.change_state(enable=True):
        pynccl_comm.reduce_scatter(result, tensor)
    torch.cuda.synchronize()
    torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8)

@@ -241,6 +248,7 @@ def send_recv_worker_fn():
        pynccl_comm.recv(tensor,
                         src=(pynccl_comm.rank - 1) %
                         pynccl_comm.world_size)
    torch.cuda.synchronize()
    result = tensor.mean().cpu().item()
    assert result == 1

@@ -280,6 +288,7 @@ def multiple_send_recv_worker_fn():
        pynccl_comm.recv(tensor,
                         src=(pynccl_comm.rank - 1) %
                         pynccl_comm.world_size)
    torch.cuda.synchronize()
    result = tensor.mean().cpu().item()
    if torch.distributed.get_rank() in [0, 2]:
        assert result == 1

@@ -293,6 +302,38 @@ def test_pynccl_multiple_send_recv():
    distributed_run(multiple_send_recv_worker_fn, 4)


@pytest.mark.skipif(torch.cuda.device_count() < 4,
                    reason="Need at least 4 GPUs to run the test.")
def test_pynccl_broadcast():
    distributed_run(broadcast_worker_fn, 4)


@worker_fn_wrapper
def broadcast_worker_fn():
    # Test broadcast for every root rank.
    # Essentially this is an all-gather operation.
    pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group,
                                     device=get_world_group().device)
    recv_tensors = [
        torch.empty(16,
                    1024,
                    1024,
                    dtype=torch.float32,
                    device=pynccl_comm.device)
        for i in range(pynccl_comm.world_size)
    ]
    recv_tensors[pynccl_comm.rank] = torch.ones(
        16, 1024, 1024, dtype=torch.float32,
        device=pynccl_comm.device) * pynccl_comm.rank

    for i in range(pynccl_comm.world_size):
        pynccl_comm.broadcast(recv_tensors[i], src=i)
        # the broadcast op might be launched in a different stream
        # need to synchronize to make sure the tensor is ready
        torch.cuda.synchronize()
        assert torch.all(recv_tensors[i] == i).cpu().item()


def test_ncclGetUniqueId():
    lib = NCCLLibrary()
    unique_id = lib.ncclGetUniqueId()

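(Usage note, not part of the diff: on a node with at least 4 GPUs the new test can be run with pytest; the module path below is an assumption, adjust it to wherever this test file lives in the tree.)

    pytest -s -x tests/distributed/test_pynccl.py -k test_pynccl_broadcast
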
Changes to PyNcclCommunicator:

@@ -197,6 +197,25 @@ class PyNcclCommunicator:
                           ncclDataTypeEnum.from_torch(tensor.dtype), src,
                           self.comm, cudaStream_t(stream.cuda_stream))

    def broadcast(self, tensor: torch.Tensor, src: int, stream=None):
        if self.disabled:
            return
        assert tensor.device == self.device, (
            f"this nccl communicator is created to work on {self.device}, "
            f"but the input tensor is on {tensor.device}")
        if stream is None:
            stream = self.stream
        if src == self.rank:
            sendbuff = buffer_type(tensor.data_ptr())
            # NCCL requires the sender also to have a receive buffer
            recvbuff = buffer_type(tensor.data_ptr())
        else:
            sendbuff = buffer_type()
            recvbuff = buffer_type(tensor.data_ptr())
        self.nccl.ncclBroadcast(sendbuff, recvbuff, tensor.numel(),
                                ncclDataTypeEnum.from_torch(tensor.dtype), src,
                                self.comm, cudaStream_t(stream.cuda_stream))

    @contextmanager
    def change_state(self,
                     enable: Optional[bool] = None,

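For context, here is a minimal usage sketch of the new broadcast method (not taken from this commit). It assumes torch.distributed and the vLLM world group are already initialized, e.g. inside a worker launched the same way as the tests above, and the import paths are assumptions based on where these symbols appear to live in the tree. On the root rank the tensor doubles as both send and receive buffer, so the broadcast is effectively in place; non-root ranks receive into their buffer.

import torch

from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
from vllm.distributed.parallel_state import get_world_group


def broadcast_from_rank0():
    pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group,
                                     device=get_world_group().device)
    if pynccl_comm.rank == 0:
        # the root rank broadcasts its own data
        tensor = torch.arange(1024, dtype=torch.float32,
                              device=pynccl_comm.device)
    else:
        # non-root ranks receive into an (uninitialized) buffer of the
        # same shape and dtype
        tensor = torch.empty(1024, dtype=torch.float32,
                             device=pynccl_comm.device)
    pynccl_comm.broadcast(tensor, src=0)
    # the broadcast may run on the communicator's stream, so synchronize
    # before reading the result on the default stream
    torch.cuda.synchronize()
    return tensor
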
Changes to the NCCLLibrary ctypes wrapper:

@@ -189,6 +189,15 @@ class NCCLLibrary:
            ncclComm_t, cudaStream_t
        ]),

        # ncclResult_t ncclBroadcast(
        #   const void* sendbuff, void* recvbuff, size_t count,
        #   ncclDataType_t datatype, int root, ncclComm_t comm,
        #   cudaStream_t stream);
        Function("ncclBroadcast", ncclResult_t, [
            buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t,
            ctypes.c_int, ncclComm_t, cudaStream_t
        ]),

        # be cautious! this is a collective call, it will block until all
        # processes in the communicator have called this function.
        # because Python object destruction can happen in random order,

@@ -312,6 +321,13 @@ class NCCLLibrary:
        self.NCCL_CHECK(self._funcs["ncclRecv"](recvbuff, count, datatype, src,
                                                comm, stream))

    def ncclBroadcast(self, sendbuff: buffer_type, recvbuff: buffer_type,
                      count: int, datatype: int, root: int, comm: ncclComm_t,
                      stream: cudaStream_t) -> None:
        self.NCCL_CHECK(self._funcs["ncclBroadcast"](sendbuff, recvbuff, count,
                                                     datatype, root, comm,
                                                     stream))

    def ncclCommDestroy(self, comm: ncclComm_t) -> None:
        self.NCCL_CHECK(self._funcs["ncclCommDestroy"](comm))

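As background, the Function entry above declares the C signature of ncclBroadcast so that ctypes marshals arguments correctly. Below is a rough sketch of the plain-ctypes pattern such a function table maps onto; the library name and type aliases are illustrative assumptions, not code from this commit.

import ctypes

# assumed shared-library name; the wrapper resolves the real path itself
nccl_lib = ctypes.CDLL("libnccl.so.2")

# illustrative aliases: opaque handles are passed as void pointers,
# enums and results as plain ints
ncclResult_t = ctypes.c_int
ncclDataType_t = ctypes.c_int
ncclComm_t = ctypes.c_void_p
cudaStream_t = ctypes.c_void_p
buffer_type = ctypes.c_void_p

# equivalent of Function("ncclBroadcast", ncclResult_t, [...]) above:
# pin the return type and argument types on the loaded symbol
nccl_lib.ncclBroadcast.restype = ncclResult_t
nccl_lib.ncclBroadcast.argtypes = [
    buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t,
    ctypes.c_int, ncclComm_t, cudaStream_t
]
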