[Core][1/N] Support send/recv in PyNCCL Groups (#4988)
Signed-off-by: Muralidhar Andoorveedu <muralidhar.andoorveedu@centml.ai>
This commit is contained in:
parent 2ba80bed27
commit 5eda2ea02a
@@ -3,6 +3,7 @@ import os

 import pytest
 import torch
+import torch.distributed

 from vllm.distributed.communication_op import (  # noqa
     graph_capture, tensor_model_parallel_all_reduce)
@@ -68,7 +69,7 @@ def test_pynccl():


 @worker_fn_wrapper
-def multiple_tp_worker_fn():
+def multiple_allreduce_worker_fn():
     device = torch.device(f"cuda:{torch.distributed.get_rank()}")
     groups = [
         torch.distributed.new_group(ranks=[0, 1], backend="gloo"),
@@ -92,14 +93,14 @@ def multiple_tp_worker_fn():

 @pytest.mark.skipif(torch.cuda.device_count() < 4,
                     reason="Need at least 4 GPUs to run the test.")
-def test_pynccl_multiple_tp():
+def test_pynccl_multiple_allreduce():
     # this tests pynccl for multiple tp groups, in a standalone way
     # i.e. call `pynccl_comm.all_reduce` directly
-    distributed_run(multiple_tp_worker_fn, 4)
+    distributed_run(multiple_allreduce_worker_fn, 4)


 @worker_fn_wrapper
-def multiple_tp_with_vllm_worker_fn():
+def multiple_allreduce_with_vllm_worker_fn():
     device = torch.device(f"cuda:{torch.distributed.get_rank()}")
     ensure_model_parallel_initialized(2, 2)
     tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device)
@@ -118,10 +119,10 @@ def multiple_tp_with_vllm_worker_fn():

 @pytest.mark.skipif(torch.cuda.device_count() < 4,
                     reason="Need at least 4 GPUs to run the test.")
-def test_pynccl_multiple_tp_with_vllm():
+def test_pynccl_multiple_allreduce_with_vllm():
     # this tests pynccl for multiple tp groups, together with vllm
     # i.e. call `tensor_model_parallel_all_reduce`
-    distributed_run(multiple_tp_with_vllm_worker_fn, 4)
+    distributed_run(multiple_allreduce_with_vllm_worker_fn, 4)


 @worker_fn_wrapper
@@ -151,6 +152,68 @@ def test_pynccl_with_cudagraph():
     distributed_run(worker_fn_with_cudagraph, 2)


+@worker_fn_wrapper
+def send_recv_worker_fn():
+    pynccl_comm = PyNcclCommunicator()
+    if pynccl_comm.rank == 0:
+        tensor = torch.ones(16, 1024, 1024,
+                            dtype=torch.float32).cuda(pynccl_comm.rank)
+    else:
+        tensor = torch.empty(16, 1024, 1024,
+                             dtype=torch.float32).cuda(pynccl_comm.rank)
+    with pynccl_comm.change_state(enable=True):
+        if pynccl_comm.rank == 0:
+            pynccl_comm.send(tensor)
+        else:
+            pynccl_comm.recv(tensor)
+    result = tensor.mean().cpu().item()
+    assert result == 1
+
+
+@pytest.mark.skipif(torch.cuda.device_count() < 2,
+                    reason="Need at least 2 GPUs to run the test.")
+def test_pynccl_send_recv():
+    distributed_run(send_recv_worker_fn, 2)
+
+
+@worker_fn_wrapper
+def multiple_send_recv_worker_fn():
+    device = torch.device(f"cuda:{torch.distributed.get_rank()}")
+    groups = [
+        torch.distributed.new_group(ranks=[0, 2], backend="gloo"),
+        torch.distributed.new_group(ranks=[1, 3], backend="gloo")
+    ]
+    group = groups[0] if torch.distributed.get_rank() in [0, 2] else groups[1]
+    pynccl_comm = PyNcclCommunicator(group=group, device=device)
+    if torch.distributed.get_rank() == 0:
+        tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device)
+    elif torch.distributed.get_rank() == 1:
+        tensor = 2 * torch.ones(
+            16, 1024, 1024, dtype=torch.float32, device=device)
+    else:
+        tensor = torch.empty(16,
+                             1024,
+                             1024,
+                             dtype=torch.float32,
+                             device=device)
+    with pynccl_comm.change_state(enable=True):
+        if torch.distributed.get_rank() in [0, 1]:
+            pynccl_comm.send(tensor)
+        else:
+            pynccl_comm.recv(tensor)
+    result = tensor.mean().cpu().item()
+    if torch.distributed.get_rank() in [0, 2]:
+        assert result == 1
+    else:
+        assert result == 2
+
+
+@pytest.mark.skipif(torch.cuda.device_count() < 4,
+                    reason="Need at least 4 GPUs to run the test.")
+def test_pynccl_multiple_send_recv():
+    distributed_run(multiple_send_recv_worker_fn, 4)
+
+
 def test_ncclGetUniqueId():
     lib = NCCLLibrary()
     unique_id = lib.ncclGetUniqueId()
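Note (added for readability, not part of the original diff): the new tests above (presumably in tests/distributed/test_pynccl.py) exercise `PyNcclCommunicator.send`/`recv` directly. A minimal sketch of the same pattern outside the test harness, assuming a two-rank `torch.distributed` world with a gloo backend and one GPU per rank has already been set up, which is what `worker_fn_wrapper`/`distributed_run` do in these tests:

import torch
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator

def ring_ping():
    # With no arguments the communicator is built on the default (world) group.
    comm = PyNcclCommunicator()
    tensor = torch.ones(4, dtype=torch.float32).cuda(comm.rank)
    with comm.change_state(enable=True):
        if comm.rank == 0:
            comm.send(tensor)  # dst defaults to (rank + 1) % world_size
        else:
            comm.recv(tensor)  # src defaults to (rank - 1) % world_size
    # On rank 1 the tensor now holds rank 0's ones.
    assert tensor.mean().item() == 1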
@@ -6,7 +6,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
 import torch
 from torch.distributed import ProcessGroup

-from .parallel_state import (get_cpu_world_group,
+from .parallel_state import (get_cpu_world_group, get_pp_pynccl_communicator,
                              get_tensor_model_parallel_group,
                              get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size,
@@ -54,13 +54,19 @@ def graph_capture():
     # graph, we use either custom all-reduce kernel or PyTorch NCCL.
     # We always prioritize using custom all-reduce kernel but fall back
     # to PyTorch or pynccl if it is disabled or not supported.
-    pynccl_comm = get_tp_pynccl_communicator()
-    if pynccl_comm is None:
-        maybe_pynccl_context = nullcontext()
+    tp_pynccl_comm = get_tp_pynccl_communicator()
+    pp_pynccl_comm = get_pp_pynccl_communicator()
+    if not tp_pynccl_comm:
+        maybe_tp_pynccl_context = nullcontext()
     else:
-        maybe_pynccl_context = pynccl_comm.change_state(
+        maybe_tp_pynccl_context = tp_pynccl_comm.change_state(
             enable=True, stream=torch.cuda.current_stream())
-    with maybe_pynccl_context:
+    if not pp_pynccl_comm:
+        maybe_pp_pynccl_context = nullcontext()
+    else:
+        maybe_pp_pynccl_context = pp_pynccl_comm.change_state(
+            enable=True, stream=torch.cuda.current_stream())
+    with maybe_tp_pynccl_context, maybe_pp_pynccl_context:
         yield graph_capture_context


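Note (not part of the diff): the hunk above replaces the single pynccl context with one per communicator; since `change_state(...)` and `nullcontext()` are both context managers, the combined `with a, b:` enters whichever of the two exist. A small self-contained illustration of that composition:

from contextlib import nullcontext

# Stand-ins: in graph_capture() each of these is either nullcontext() or
# <communicator>.change_state(enable=True, stream=torch.cuda.current_stream()).
maybe_tp_pynccl_context = nullcontext()
maybe_pp_pynccl_context = nullcontext()

with maybe_tp_pynccl_context, maybe_pp_pynccl_context:
    # CUDA graph capture happens here; any existing TP/PP pynccl communicator
    # stays enabled for the duration of the block, and missing ones are no-ops.
    pass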
@@ -126,6 +126,40 @@ class PyNcclCommunicator:
                                 ncclRedOpTypeEnum.from_torch(op), self.comm,
                                 cudaStream_t(stream.cuda_stream))

+    def send(self,
+             tensor: torch.Tensor,
+             dst: Optional[int] = None,
+             stream=None):
+        if self.disabled:
+            return
+        assert tensor.device == self.device, (
+            f"this nccl communicator is created to work on {self.device}, "
+            f"but the input tensor is on {tensor.device}")
+        if stream is None:
+            stream = self.stream
+        if dst is None:
+            dst = (self.rank + 1) % self.world_size
+        self.nccl.ncclSend(buffer_type(tensor.data_ptr()), tensor.numel(),
+                           ncclDataTypeEnum.from_torch(tensor.dtype), dst,
+                           self.comm, cudaStream_t(stream.cuda_stream))
+
+    def recv(self,
+             tensor: torch.Tensor,
+             src: Optional[int] = None,
+             stream=None):
+        if self.disabled:
+            return
+        assert tensor.device == self.device, (
+            f"this nccl communicator is created to work on {self.device}, "
+            f"but the input tensor is on {tensor.device}")
+        if stream is None:
+            stream = self.stream
+        if src is None:
+            src = (self.rank - 1) % self.world_size
+        self.nccl.ncclRecv(buffer_type(tensor.data_ptr()), tensor.numel(),
+                           ncclDataTypeEnum.from_torch(tensor.dtype), src,
+                           self.comm, cudaStream_t(stream.cuda_stream))
+
     @contextmanager
     def change_state(self,
                      enable: Optional[bool] = None,
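Note (not part of the diff): the new `send`/`recv` methods are thin wrappers over `ncclSend`/`ncclRecv` issued on the communicator's stream, and the peer arguments are optional. A hedged usage sketch, assuming `comm` is an already-initialized `PyNcclCommunicator` shared by the participating ranks (the helper name below is made up for illustration):

import torch
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator

def handoff(comm: PyNcclCommunicator, tensor: torch.Tensor) -> None:
    # Even ranks push the tensor forward; odd ranks receive it. Passing dst/src
    # explicitly is optional: send defaults to (rank + 1) % world_size and recv
    # to (rank - 1) % world_size, i.e. a ring / pipeline-neighbor pattern.
    with comm.change_state(enable=True):
        if comm.rank % 2 == 0:
            comm.send(tensor, dst=(comm.rank + 1) % comm.world_size)
        else:
            comm.recv(tensor, src=(comm.rank - 1) % comm.world_size)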
@@ -151,6 +151,22 @@ class NCCLLibrary:
             ncclRedOp_t, ncclComm_t, cudaStream_t
         ]),

+        # ncclResult_t ncclSend(
+        #   const void* sendbuff, size_t count, ncclDataType_t datatype,
+        #   int dest, ncclComm_t comm, cudaStream_t stream);
+        Function("ncclSend", ncclResult_t, [
+            buffer_type, ctypes.c_size_t, ncclDataType_t, ctypes.c_int,
+            ncclComm_t, cudaStream_t
+        ]),
+
+        # ncclResult_t ncclRecv(
+        #   void* recvbuff, size_t count, ncclDataType_t datatype,
+        #   int src, ncclComm_t comm, cudaStream_t stream);
+        Function("ncclRecv", ncclResult_t, [
+            buffer_type, ctypes.c_size_t, ncclDataType_t, ctypes.c_int,
+            ncclComm_t, cudaStream_t
+        ]),
+
         # be cautious! this is a collective call, it will block until all
         # processes in the communicator have called this function.
         # because Python object destruction can happen in random order,
@@ -248,6 +264,16 @@ class NCCLLibrary:
                                                      datatype, op, comm,
                                                      stream))

+    def ncclSend(self, sendbuff: buffer_type, count: int, datatype: int,
+                 dest: int, comm: ncclComm_t, stream: cudaStream_t) -> None:
+        self.NCCL_CHECK(self._funcs["ncclSend"](sendbuff, count, datatype,
+                                                dest, comm, stream))
+
+    def ncclRecv(self, recvbuff: buffer_type, count: int, datatype: int,
+                 src: int, comm: ncclComm_t, stream: cudaStream_t) -> None:
+        self.NCCL_CHECK(self._funcs["ncclRecv"](recvbuff, count, datatype, src,
+                                                comm, stream))
+
     def ncclCommDestroy(self, comm: ncclComm_t) -> None:
         self.NCCL_CHECK(self._funcs["ncclCommDestroy"](comm))
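Note (not part of the diff): the `Function("ncclSend", ...)`/`Function("ncclRecv", ...)` entries register the C signatures with ctypes, and the `ncclSend`/`ncclRecv` methods wrap the raw calls in `NCCL_CHECK`, which raises if the returned `ncclResult_t` is non-zero. A generic, self-contained illustration of that ctypes pattern, using libc's `strlen` as a stand-in for an NCCL symbol (this is not vLLM code):

import ctypes

lib = ctypes.CDLL(None)  # stand-in for the loaded libnccl handle
# Declare the C signature once, like Function("ncclSend", ncclResult_t, [...]).
lib.strlen.restype = ctypes.c_size_t
lib.strlen.argtypes = [ctypes.c_char_p]

def checked_strlen(s: bytes) -> int:
    # NCCL_CHECK would inspect an ncclResult_t here and raise on failure;
    # strlen has no error code, so we simply return the result.
    return lib.strlen(s)

assert checked_strlen(b"nccl") == 4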
@@ -22,6 +22,8 @@ _TP_PYNCCL_COMMUNICATOR = None
 _TP_CA_COMMUNICATOR = None
 # Pipeline model parallel group that the current rank belongs to.
 _PP_DEVICE_GROUP: Optional[ProcessGroup] = None
+_PP_CPU_GROUP: Optional[ProcessGroup] = None
+_PP_PYNCCL_COMMUNICATOR = None

 # when people blindly call `torch.distributed.all_reduce` etc,
 # it will use this group. It is initialized with the `backend`
@@ -55,6 +57,11 @@ def set_custom_all_reduce(enable: bool):
     _ENABLE_CUSTOM_ALL_REDUCE = enable


+def get_pp_pynccl_communicator():
+    global _PP_PYNCCL_COMMUNICATOR
+    return _PP_PYNCCL_COMMUNICATOR
+
+
 def get_tp_pynccl_communicator():
     global _TP_PYNCCL_COMMUNICATOR
     return _TP_PYNCCL_COMMUNICATOR
@@ -180,6 +187,7 @@ def initialize_model_parallel(
     _TP_CPU_GROUP = cpu_group

     from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
+    if tensor_model_parallel_size > 1:
         _TP_PYNCCL_COMMUNICATOR = PyNcclCommunicator(
             group=_TP_CPU_GROUP,
             device=_LOCAL_RANK,
@@ -195,17 +203,26 @@ def initialize_model_parallel(
         )

     # Build the pipeline model-parallel groups.
-    global _PP_DEVICE_GROUP
+    global _PP_DEVICE_GROUP, _PP_CPU_GROUP
+    global _PP_PYNCCL_COMMUNICATOR
     global _PP_GLOBAL_RANKS
     assert _PP_DEVICE_GROUP is None, (
         "pipeline model parallel group is already initialized")
     for i in range(num_pipeline_model_parallel_groups):
         ranks = list(range(i, world_size, num_pipeline_model_parallel_groups))
         group = torch.distributed.new_group(ranks, backend=backend)
+        cpu_group = torch.distributed.new_group(ranks, backend="gloo")
         if rank in ranks:
             _PP_DEVICE_GROUP = group
+            _PP_CPU_GROUP = cpu_group
             _PP_GLOBAL_RANKS = ranks

+    if pipeline_model_parallel_size > 1:
+        _PP_PYNCCL_COMMUNICATOR = PyNcclCommunicator(
+            group=_PP_CPU_GROUP,
+            device=_LOCAL_RANK,
+        )
+

 def ensure_model_parallel_initialized(
     tensor_model_parallel_size: int,
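Note (not part of the diff): with `_PP_CPU_GROUP` and `_PP_PYNCCL_COMMUNICATOR` wired up here, later changes in this series (the title marks this as PR 1/N) can fetch the pipeline-parallel communicator via `get_pp_pynccl_communicator()` and pass activations between stages. A hypothetical sketch of such a follow-up use, assuming model parallelism has already been initialized (the helper name is made up):

import torch
from vllm.distributed.parallel_state import get_pp_pynccl_communicator

def send_to_next_pp_rank(hidden_states: torch.Tensor) -> None:
    pp_comm = get_pp_pynccl_communicator()
    if pp_comm is None:
        # pipeline_model_parallel_size == 1: nothing to hand off.
        return
    with pp_comm.change_state(enable=True):
        # dst defaults to (rank + 1) % world_size within the PP group.
        pp_comm.send(hidden_states)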
@@ -267,6 +284,13 @@ def get_pipeline_model_parallel_group():
     return _PP_DEVICE_GROUP


+def get_pipeline_model_parallel_cpu_group():
+    """Get the pipeline model parallel cpu group the caller rank belongs to."""
+    assert _PP_CPU_GROUP is not None, (
+        "pipeline model parallel cpu group is not initialized")
+    return _PP_CPU_GROUP
+
+
 def get_tensor_model_parallel_world_size():
     """Return world size for the tensor model parallel group."""
     return torch.distributed.get_world_size(