[Core][1/N] Support send/recv in PyNCCL Groups (#4988)
Signed-off-by: Muralidhar Andoorveedu <muralidhar.andoorveedu@centml.ai>
This commit is contained in:
parent
2ba80bed27
commit
5eda2ea02a
@ -3,6 +3,7 @@ import os
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.distributed
|
||||
|
||||
from vllm.distributed.communication_op import ( # noqa
|
||||
graph_capture, tensor_model_parallel_all_reduce)
|
||||
@ -68,7 +69,7 @@ def test_pynccl():
|
||||
|
||||
|
||||
@worker_fn_wrapper
|
||||
def multiple_tp_worker_fn():
|
||||
def multiple_allreduce_worker_fn():
|
||||
device = torch.device(f"cuda:{torch.distributed.get_rank()}")
|
||||
groups = [
|
||||
torch.distributed.new_group(ranks=[0, 1], backend="gloo"),
|
||||
@ -92,14 +93,14 @@ def multiple_tp_worker_fn():
|
||||
|
||||
@pytest.mark.skipif(torch.cuda.device_count() < 4,
|
||||
reason="Need at least 4 GPUs to run the test.")
|
||||
def test_pynccl_multiple_tp():
|
||||
def test_pynccl_multiple_allreduce():
|
||||
# this tests pynccl for multiple tp groups, in a standalone way
|
||||
# i.e. call `pynccl_comm.all_reduce` directly
|
||||
distributed_run(multiple_tp_worker_fn, 4)
|
||||
distributed_run(multiple_allreduce_worker_fn, 4)
|
||||
|
||||
|
||||
@worker_fn_wrapper
|
||||
def multiple_tp_with_vllm_worker_fn():
|
||||
def multiple_allreduce_with_vllm_worker_fn():
|
||||
device = torch.device(f"cuda:{torch.distributed.get_rank()}")
|
||||
ensure_model_parallel_initialized(2, 2)
|
||||
tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device)
|
||||
@ -118,10 +119,10 @@ def multiple_tp_with_vllm_worker_fn():
|
||||
|
||||
@pytest.mark.skipif(torch.cuda.device_count() < 4,
|
||||
reason="Need at least 4 GPUs to run the test.")
|
||||
def test_pynccl_multiple_tp_with_vllm():
|
||||
def test_pynccl_multiple_allreduce_with_vllm():
|
||||
# this tests pynccl for multiple tp groups, together with vllm
|
||||
# i.e. call `tensor_model_parallel_all_reduce`
|
||||
distributed_run(multiple_tp_with_vllm_worker_fn, 4)
|
||||
distributed_run(multiple_allreduce_with_vllm_worker_fn, 4)
|
||||
|
||||
|
||||
@worker_fn_wrapper
|
||||
@ -151,6 +152,68 @@ def test_pynccl_with_cudagraph():
|
||||
distributed_run(worker_fn_with_cudagraph, 2)
|
||||
|
||||
|
||||
@worker_fn_wrapper
|
||||
def send_recv_worker_fn():
|
||||
pynccl_comm = PyNcclCommunicator()
|
||||
if pynccl_comm.rank == 0:
|
||||
tensor = torch.ones(16, 1024, 1024,
|
||||
dtype=torch.float32).cuda(pynccl_comm.rank)
|
||||
else:
|
||||
tensor = torch.empty(16, 1024, 1024,
|
||||
dtype=torch.float32).cuda(pynccl_comm.rank)
|
||||
with pynccl_comm.change_state(enable=True):
|
||||
if pynccl_comm.rank == 0:
|
||||
pynccl_comm.send(tensor)
|
||||
else:
|
||||
pynccl_comm.recv(tensor)
|
||||
result = tensor.mean().cpu().item()
|
||||
assert result == 1
|
||||
|
||||
|
||||
@pytest.mark.skipif(torch.cuda.device_count() < 2,
|
||||
reason="Need at least 2 GPUs to run the test.")
|
||||
def test_pynccl_send_recv():
|
||||
distributed_run(send_recv_worker_fn, 2)
|
||||
|
||||
|
||||
@worker_fn_wrapper
|
||||
def multiple_send_recv_worker_fn():
|
||||
device = torch.device(f"cuda:{torch.distributed.get_rank()}")
|
||||
groups = [
|
||||
torch.distributed.new_group(ranks=[0, 2], backend="gloo"),
|
||||
torch.distributed.new_group(ranks=[1, 3], backend="gloo")
|
||||
]
|
||||
group = groups[0] if torch.distributed.get_rank() in [0, 2] else groups[1]
|
||||
pynccl_comm = PyNcclCommunicator(group=group, device=device)
|
||||
if torch.distributed.get_rank() == 0:
|
||||
tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device)
|
||||
elif torch.distributed.get_rank() == 1:
|
||||
tensor = 2 * torch.ones(
|
||||
16, 1024, 1024, dtype=torch.float32, device=device)
|
||||
else:
|
||||
tensor = torch.empty(16,
|
||||
1024,
|
||||
1024,
|
||||
dtype=torch.float32,
|
||||
device=device)
|
||||
with pynccl_comm.change_state(enable=True):
|
||||
if torch.distributed.get_rank() in [0, 1]:
|
||||
pynccl_comm.send(tensor)
|
||||
else:
|
||||
pynccl_comm.recv(tensor)
|
||||
result = tensor.mean().cpu().item()
|
||||
if torch.distributed.get_rank() in [0, 2]:
|
||||
assert result == 1
|
||||
else:
|
||||
assert result == 2
|
||||
|
||||
|
||||
@pytest.mark.skipif(torch.cuda.device_count() < 4,
|
||||
reason="Need at least 4 GPUs to run the test.")
|
||||
def test_pynccl_multiple_send_recv():
|
||||
distributed_run(multiple_send_recv_worker_fn, 4)
|
||||
|
||||
|
||||
def test_ncclGetUniqueId():
|
||||
lib = NCCLLibrary()
|
||||
unique_id = lib.ncclGetUniqueId()
|
||||
|
@ -6,7 +6,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
import torch
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
from .parallel_state import (get_cpu_world_group,
|
||||
from .parallel_state import (get_cpu_world_group, get_pp_pynccl_communicator,
|
||||
get_tensor_model_parallel_group,
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
@ -54,13 +54,19 @@ def graph_capture():
|
||||
# graph, we use either custom all-reduce kernel or PyTorch NCCL.
|
||||
# We always prioritize using custom all-reduce kernel but fall back
|
||||
# to PyTorch or pynccl if it is disabled or not supported.
|
||||
pynccl_comm = get_tp_pynccl_communicator()
|
||||
if pynccl_comm is None:
|
||||
maybe_pynccl_context = nullcontext()
|
||||
tp_pynccl_comm = get_tp_pynccl_communicator()
|
||||
pp_pynccl_comm = get_pp_pynccl_communicator()
|
||||
if not tp_pynccl_comm:
|
||||
maybe_tp_pynccl_context = nullcontext()
|
||||
else:
|
||||
maybe_pynccl_context = pynccl_comm.change_state(
|
||||
maybe_tp_pynccl_context = tp_pynccl_comm.change_state(
|
||||
enable=True, stream=torch.cuda.current_stream())
|
||||
with maybe_pynccl_context:
|
||||
if not pp_pynccl_comm:
|
||||
maybe_pp_pynccl_context = nullcontext()
|
||||
else:
|
||||
maybe_pp_pynccl_context = pp_pynccl_comm.change_state(
|
||||
enable=True, stream=torch.cuda.current_stream())
|
||||
with maybe_tp_pynccl_context, maybe_pp_pynccl_context:
|
||||
yield graph_capture_context
|
||||
|
||||
|
||||
|
@ -126,6 +126,40 @@ class PyNcclCommunicator:
|
||||
ncclRedOpTypeEnum.from_torch(op), self.comm,
|
||||
cudaStream_t(stream.cuda_stream))
|
||||
|
||||
def send(self,
|
||||
tensor: torch.Tensor,
|
||||
dst: Optional[int] = None,
|
||||
stream=None):
|
||||
if self.disabled:
|
||||
return
|
||||
assert tensor.device == self.device, (
|
||||
f"this nccl communicator is created to work on {self.device}, "
|
||||
f"but the input tensor is on {tensor.device}")
|
||||
if stream is None:
|
||||
stream = self.stream
|
||||
if dst is None:
|
||||
dst = (self.rank + 1) % self.world_size
|
||||
self.nccl.ncclSend(buffer_type(tensor.data_ptr()), tensor.numel(),
|
||||
ncclDataTypeEnum.from_torch(tensor.dtype), dst,
|
||||
self.comm, cudaStream_t(stream.cuda_stream))
|
||||
|
||||
def recv(self,
|
||||
tensor: torch.Tensor,
|
||||
src: Optional[int] = None,
|
||||
stream=None):
|
||||
if self.disabled:
|
||||
return
|
||||
assert tensor.device == self.device, (
|
||||
f"this nccl communicator is created to work on {self.device}, "
|
||||
f"but the input tensor is on {tensor.device}")
|
||||
if stream is None:
|
||||
stream = self.stream
|
||||
if src is None:
|
||||
src = (self.rank - 1) % self.world_size
|
||||
self.nccl.ncclRecv(buffer_type(tensor.data_ptr()), tensor.numel(),
|
||||
ncclDataTypeEnum.from_torch(tensor.dtype), src,
|
||||
self.comm, cudaStream_t(stream.cuda_stream))
|
||||
|
||||
@contextmanager
|
||||
def change_state(self,
|
||||
enable: Optional[bool] = None,
|
||||
|
@ -151,6 +151,22 @@ class NCCLLibrary:
|
||||
ncclRedOp_t, ncclComm_t, cudaStream_t
|
||||
]),
|
||||
|
||||
# ncclResult_t ncclSend(
|
||||
# const void* sendbuff, size_t count, ncclDataType_t datatype,
|
||||
# int dest, ncclComm_t comm, cudaStream_t stream);
|
||||
Function("ncclSend", ncclResult_t, [
|
||||
buffer_type, ctypes.c_size_t, ncclDataType_t, ctypes.c_int,
|
||||
ncclComm_t, cudaStream_t
|
||||
]),
|
||||
|
||||
# ncclResult_t ncclRecv(
|
||||
# void* recvbuff, size_t count, ncclDataType_t datatype,
|
||||
# int src, ncclComm_t comm, cudaStream_t stream);
|
||||
Function("ncclRecv", ncclResult_t, [
|
||||
buffer_type, ctypes.c_size_t, ncclDataType_t, ctypes.c_int,
|
||||
ncclComm_t, cudaStream_t
|
||||
]),
|
||||
|
||||
# be cautious! this is a collective call, it will block until all
|
||||
# processes in the communicator have called this function.
|
||||
# because Python object destruction can happen in random order,
|
||||
@ -248,6 +264,16 @@ class NCCLLibrary:
|
||||
datatype, op, comm,
|
||||
stream))
|
||||
|
||||
def ncclSend(self, sendbuff: buffer_type, count: int, datatype: int,
|
||||
dest: int, comm: ncclComm_t, stream: cudaStream_t) -> None:
|
||||
self.NCCL_CHECK(self._funcs["ncclSend"](sendbuff, count, datatype,
|
||||
dest, comm, stream))
|
||||
|
||||
def ncclRecv(self, recvbuff: buffer_type, count: int, datatype: int,
|
||||
src: int, comm: ncclComm_t, stream: cudaStream_t) -> None:
|
||||
self.NCCL_CHECK(self._funcs["ncclRecv"](recvbuff, count, datatype, src,
|
||||
comm, stream))
|
||||
|
||||
def ncclCommDestroy(self, comm: ncclComm_t) -> None:
|
||||
self.NCCL_CHECK(self._funcs["ncclCommDestroy"](comm))
|
||||
|
||||
|
@ -22,6 +22,8 @@ _TP_PYNCCL_COMMUNICATOR = None
|
||||
_TP_CA_COMMUNICATOR = None
|
||||
# Pipeline model parallel group that the current rank belongs to.
|
||||
_PP_DEVICE_GROUP: Optional[ProcessGroup] = None
|
||||
_PP_CPU_GROUP: Optional[ProcessGroup] = None
|
||||
_PP_PYNCCL_COMMUNICATOR = None
|
||||
|
||||
# when people blindly call `torch.distributed.all_reduce` etc,
|
||||
# it will use this group. It is initialized with the `backend`
|
||||
@ -55,6 +57,11 @@ def set_custom_all_reduce(enable: bool):
|
||||
_ENABLE_CUSTOM_ALL_REDUCE = enable
|
||||
|
||||
|
||||
def get_pp_pynccl_communicator():
|
||||
global _PP_PYNCCL_COMMUNICATOR
|
||||
return _PP_PYNCCL_COMMUNICATOR
|
||||
|
||||
|
||||
def get_tp_pynccl_communicator():
|
||||
global _TP_PYNCCL_COMMUNICATOR
|
||||
return _TP_PYNCCL_COMMUNICATOR
|
||||
@ -180,6 +187,7 @@ def initialize_model_parallel(
|
||||
_TP_CPU_GROUP = cpu_group
|
||||
|
||||
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
|
||||
if tensor_model_parallel_size > 1:
|
||||
_TP_PYNCCL_COMMUNICATOR = PyNcclCommunicator(
|
||||
group=_TP_CPU_GROUP,
|
||||
device=_LOCAL_RANK,
|
||||
@ -195,17 +203,26 @@ def initialize_model_parallel(
|
||||
)
|
||||
|
||||
# Build the pipeline model-parallel groups.
|
||||
global _PP_DEVICE_GROUP
|
||||
global _PP_DEVICE_GROUP, _PP_CPU_GROUP
|
||||
global _PP_PYNCCL_COMMUNICATOR
|
||||
global _PP_GLOBAL_RANKS
|
||||
assert _PP_DEVICE_GROUP is None, (
|
||||
"pipeline model parallel group is already initialized")
|
||||
for i in range(num_pipeline_model_parallel_groups):
|
||||
ranks = list(range(i, world_size, num_pipeline_model_parallel_groups))
|
||||
group = torch.distributed.new_group(ranks, backend=backend)
|
||||
cpu_group = torch.distributed.new_group(ranks, backend="gloo")
|
||||
if rank in ranks:
|
||||
_PP_DEVICE_GROUP = group
|
||||
_PP_CPU_GROUP = cpu_group
|
||||
_PP_GLOBAL_RANKS = ranks
|
||||
|
||||
if pipeline_model_parallel_size > 1:
|
||||
_PP_PYNCCL_COMMUNICATOR = PyNcclCommunicator(
|
||||
group=_PP_CPU_GROUP,
|
||||
device=_LOCAL_RANK,
|
||||
)
|
||||
|
||||
|
||||
def ensure_model_parallel_initialized(
|
||||
tensor_model_parallel_size: int,
|
||||
@ -267,6 +284,13 @@ def get_pipeline_model_parallel_group():
|
||||
return _PP_DEVICE_GROUP
|
||||
|
||||
|
||||
def get_pipeline_model_parallel_cpu_group():
|
||||
"""Get the pipeline model parallel cpu group the caller rank belongs to."""
|
||||
assert _PP_CPU_GROUP is not None, (
|
||||
"pipeline model parallel cpu group is not initialized")
|
||||
return _PP_CPU_GROUP
|
||||
|
||||
|
||||
def get_tensor_model_parallel_world_size():
|
||||
"""Return world size for the tensor model parallel group."""
|
||||
return torch.distributed.get_world_size(
|
||||
|
Loading…
x
Reference in New Issue
Block a user