[core][distributed] add pynccl broadcast (#10843)
Signed-off-by: youkaichao <youkaichao@gmail.com>
commit 21fe7b481a
parent a4cf256159
@@ -61,6 +61,7 @@ def worker_fn():
                         dtype=torch.float32).cuda(pynccl_comm.rank)
     with pynccl_comm.change_state(enable=True):
         tensor = pynccl_comm.all_reduce(tensor)
+    torch.cuda.synchronize()
     result = tensor.mean().cpu().item()
     assert result == pynccl_comm.world_size
 
@@ -86,10 +87,12 @@ def multiple_allreduce_worker_fn():
         if torch.distributed.get_rank() in [0, 1]:
             tensor = pynccl_comm.all_reduce(tensor)
             tensor = pynccl_comm.all_reduce(tensor)
+            torch.cuda.synchronize()
             result = tensor.mean().cpu().item()
             assert result == 4
         else:
             tensor = pynccl_comm.all_reduce(tensor)
+            torch.cuda.synchronize()
             result = tensor.mean().cpu().item()
             assert result == 2
 
@@ -112,10 +115,12 @@ def multiple_allreduce_with_vllm_worker_fn():
         if torch.distributed.get_rank() in [0, 1]:
             tensor = tensor_model_parallel_all_reduce(tensor)
             tensor = tensor_model_parallel_all_reduce(tensor)
+            torch.cuda.synchronize()
             result = tensor.mean().cpu().item()
             assert result == 4
         else:
             tensor = tensor_model_parallel_all_reduce(tensor)
+            torch.cuda.synchronize()
             result = tensor.mean().cpu().item()
             assert result == 2
 
@@ -141,9 +146,9 @@ def worker_fn_with_cudagraph():
             graph, stream=pynccl_comm.stream), pynccl_comm.change_state(
                 enable=True):
         a_out = pynccl_comm.all_reduce(a)
-    pynccl_comm.stream.synchronize()
+    torch.cuda.synchronize()
     graph.replay()
-    pynccl_comm.stream.synchronize()
+    torch.cuda.synchronize()
     assert a_out.mean().cpu().item() == pynccl_comm.world_size**1
 
 
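Note on the synchronization change above (and on the torch.cuda.synchronize() lines the other test hunks add): synchronizing a single stream only waits for work queued on that stream, while torch.cuda.synchronize() drains every stream on the current device. A minimal sketch of the difference, illustrative only and not part of the commit (the tensors and the side stream are made up for the example):

import torch

side_stream = torch.cuda.Stream()
a = torch.ones(4, device="cuda")

# Make the side stream wait for the default stream that produced `a`,
# then enqueue a kernel on the side stream instead of the default one.
side_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(side_stream):
    b = a * 2  # kernel enqueued on side_stream

# side_stream.synchronize() would only wait for work queued on that one
# stream; torch.cuda.synchronize() blocks until every stream on the current
# device has drained, which is the safer choice when a collective may have
# been launched on a stream the caller does not control.
torch.cuda.synchronize()
assert b.sum().item() == 8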
@@ -170,6 +175,7 @@ def all_gather_worker_fn():
 
     with pynccl_comm.change_state(enable=True):
         pynccl_comm.all_gather(result, tensor)
+    torch.cuda.synchronize()
     torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8)
 
 
@@ -207,6 +213,7 @@ def reduce_scatter_worker_fn():
 
     with pynccl_comm.change_state(enable=True):
         pynccl_comm.reduce_scatter(result, tensor)
+    torch.cuda.synchronize()
     torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8)
 
 
@@ -241,6 +248,7 @@ def send_recv_worker_fn():
             pynccl_comm.recv(tensor,
                              src=(pynccl_comm.rank - 1) %
                              pynccl_comm.world_size)
+    torch.cuda.synchronize()
     result = tensor.mean().cpu().item()
     assert result == 1
 
@@ -280,6 +288,7 @@ def multiple_send_recv_worker_fn():
             pynccl_comm.recv(tensor,
                              src=(pynccl_comm.rank - 1) %
                              pynccl_comm.world_size)
+    torch.cuda.synchronize()
     result = tensor.mean().cpu().item()
     if torch.distributed.get_rank() in [0, 2]:
         assert result == 1
@@ -293,6 +302,38 @@ def test_pynccl_multiple_send_recv():
     distributed_run(multiple_send_recv_worker_fn, 4)
 
 
+@pytest.mark.skipif(torch.cuda.device_count() < 4,
+                    reason="Need at least 4 GPUs to run the test.")
+def test_pynccl_broadcast():
+    distributed_run(broadcast_worker_fn, 4)
+
+
+@worker_fn_wrapper
+def broadcast_worker_fn():
+    # Test broadcast for every root rank.
+    # Essentially this is an all-gather operation.
+    pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group,
+                                     device=get_world_group().device)
+    recv_tensors = [
+        torch.empty(16,
+                    1024,
+                    1024,
+                    dtype=torch.float32,
+                    device=pynccl_comm.device)
+        for i in range(pynccl_comm.world_size)
+    ]
+    recv_tensors[pynccl_comm.rank] = torch.ones(
+        16, 1024, 1024, dtype=torch.float32,
+        device=pynccl_comm.device) * pynccl_comm.rank
+
+    for i in range(pynccl_comm.world_size):
+        pynccl_comm.broadcast(recv_tensors[i], src=i)
+        # the broadcast op might be launched in a different stream
+        # need to synchronize to make sure the tensor is ready
+        torch.cuda.synchronize()
+        assert torch.all(recv_tensors[i] == i).cpu().item()
+
+
 def test_ncclGetUniqueId():
     lib = NCCLLibrary()
     unique_id = lib.ncclGetUniqueId()
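The comments in the new test call broadcasting from every root "essentially an all-gather". A minimal sketch of that equivalence with plain torch.distributed, outside of vLLM (illustrative only; it assumes a process group is already initialized and every rank passes a same-shaped tensor):

import torch
import torch.distributed as dist

def broadcast_as_all_gather(local: torch.Tensor) -> list:
    # Gather `local` from every rank by issuing one broadcast per root:
    # rank == root sends its tensor, the other ranks receive into slot `root`.
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    gathered = [torch.empty_like(local) for _ in range(world_size)]
    gathered[rank] = local.clone()
    for root in range(world_size):
        dist.broadcast(gathered[root], src=root)
    return gathered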
@@ -197,6 +197,25 @@ class PyNcclCommunicator:
                            ncclDataTypeEnum.from_torch(tensor.dtype), src,
                            self.comm, cudaStream_t(stream.cuda_stream))
 
+    def broadcast(self, tensor: torch.Tensor, src: int, stream=None):
+        if self.disabled:
+            return
+        assert tensor.device == self.device, (
+            f"this nccl communicator is created to work on {self.device}, "
+            f"but the input tensor is on {tensor.device}")
+        if stream is None:
+            stream = self.stream
+        if src == self.rank:
+            sendbuff = buffer_type(tensor.data_ptr())
+            # NCCL requires the sender also to have a receive buffer
+            recvbuff = buffer_type(tensor.data_ptr())
+        else:
+            sendbuff = buffer_type()
+            recvbuff = buffer_type(tensor.data_ptr())
+        self.nccl.ncclBroadcast(sendbuff, recvbuff, tensor.numel(),
+                                ncclDataTypeEnum.from_torch(tensor.dtype), src,
+                                self.comm, cudaStream_t(stream.cuda_stream))
+
     @contextmanager
     def change_state(self,
                      enable: Optional[bool] = None,
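For context, a usage sketch of the new PyNcclCommunicator.broadcast, written in the same style as the tests above (a sketch, not part of the commit; it assumes it runs inside a worker where vLLM's world group is already initialized and uses the same imports as the test file):

pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group,
                                 device=get_world_group().device)

# The broadcast is in place: rank 0 sends the contents of `tensor`, and
# every other rank overwrites its own `tensor` with the received data.
tensor = torch.ones(16, 1024, 1024,
                    dtype=torch.float32,
                    device=pynccl_comm.device) * pynccl_comm.rank
pynccl_comm.broadcast(tensor, src=0)
# The op may be launched on the communicator's own stream, so synchronize
# before reading the result on the host.
torch.cuda.synchronize()
assert torch.all(tensor == 0).cpu().item()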
@@ -189,6 +189,15 @@ class NCCLLibrary:
             ncclComm_t, cudaStream_t
         ]),
 
+        # ncclResult_t ncclBroadcast(
+        # const void* sendbuff, void* recvbuff, size_t count,
+        # ncclDataType_t datatype, int root, ncclComm_t comm,
+        # cudaStream_t stream);
+        Function("ncclBroadcast", ncclResult_t, [
+            buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t,
+            ctypes.c_int, ncclComm_t, cudaStream_t
+        ]),
+
         # be cautious! this is a collective call, it will block until all
         # processes in the communicator have called this function.
         # because Python object destruction can happen in random order,
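The Function entry above only records the C signature of ncclBroadcast; a rough sketch of how such an entry is typically bound with ctypes (a simplified illustration, not the wrapper's actual loading code; the library name and type mappings follow standard NCCL/ctypes conventions):

import ctypes

# Look the symbol up in the NCCL shared library and attach the recorded
# return/argument types so ctypes can marshal the Python arguments.
lib = ctypes.CDLL("libnccl.so.2")
ncclBroadcast = lib.ncclBroadcast
ncclBroadcast.restype = ctypes.c_int  # ncclResult_t is an enum
ncclBroadcast.argtypes = [
    ctypes.c_void_p,   # const void* sendbuff
    ctypes.c_void_p,   # void* recvbuff
    ctypes.c_size_t,   # size_t count
    ctypes.c_int,      # ncclDataType_t datatype
    ctypes.c_int,      # int root
    ctypes.c_void_p,   # ncclComm_t comm (opaque handle)
    ctypes.c_void_p,   # cudaStream_t stream (opaque handle)
]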
@@ -312,6 +321,13 @@ class NCCLLibrary:
         self.NCCL_CHECK(self._funcs["ncclRecv"](recvbuff, count, datatype, src,
                                                 comm, stream))
 
+    def ncclBroadcast(self, sendbuff: buffer_type, recvbuff: buffer_type,
+                      count: int, datatype: int, root: int, comm: ncclComm_t,
+                      stream: cudaStream_t) -> None:
+        self.NCCL_CHECK(self._funcs["ncclBroadcast"](sendbuff, recvbuff, count,
+                                                     datatype, root, comm,
+                                                     stream))
+
     def ncclCommDestroy(self, comm: ncclComm_t) -> None:
         self.NCCL_CHECK(self._funcs["ncclCommDestroy"](comm))
 
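Finally, a sketch of how the low-level wrapper added above is driven from Python: the tensor's device pointer is wrapped in buffer_type, the dtype is translated with ncclDataTypeEnum.from_torch, and the CUDA stream handle is passed as cudaStream_t. This mirrors what the new PyNcclCommunicator.broadcast does internally (the helper function and the import path are illustrative assumptions, not part of the commit):

import torch
from vllm.distributed.device_communicators.pynccl_wrapper import (
    NCCLLibrary, buffer_type, cudaStream_t, ncclComm_t, ncclDataTypeEnum)

def raw_broadcast(lib: NCCLLibrary, comm: ncclComm_t, tensor: torch.Tensor,
                  root: int, rank: int, stream: torch.cuda.Stream) -> None:
    # Only the root supplies a send buffer; NCCL still expects a receive
    # buffer on the root, so the same device pointer is passed for both.
    if rank == root:
        sendbuff = buffer_type(tensor.data_ptr())
    else:
        sendbuff = buffer_type()
    recvbuff = buffer_type(tensor.data_ptr())
    lib.ncclBroadcast(sendbuff, recvbuff, tensor.numel(),
                      ncclDataTypeEnum.from_torch(tensor.dtype), root,
                      comm, cudaStream_t(stream.cuda_stream))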