[Distributed] Add send and recv helpers (#5719)

Murali Andoorveedu authored on 2024-06-23 17:42:28 -04:00, committed by GitHub
parent 6c916ac8a8
commit 5d4d90536f
6 changed files with 278 additions and 24 deletions

View File

@@ -8,12 +8,11 @@ import pytest
import ray
import torch
-from vllm.distributed import (broadcast_tensor_dict,
+from vllm.distributed import (broadcast_tensor_dict, get_pp_group,
                              tensor_model_parallel_all_gather,
                              tensor_model_parallel_all_reduce)
-from ..utils import (init_test_distributed_environment,
-                     multi_process_tensor_parallel)
+from ..utils import init_test_distributed_environment, multi_process_parallel
@ray.remote(num_gpus=1, max_calls=1)
@@ -105,6 +104,68 @@ def broadcast_tensor_dict_test_worker(tp_size: int, pp_size: int, rank: int,
assert torch.allclose(recv_dict["f"], test_dict["f"])
@ray.remote(num_gpus=1, max_calls=1)
def send_recv_tensor_dict_test_worker(tp_size: int, pp_size: int, rank: int,
distributed_init_port: str):
del os.environ["CUDA_VISIBLE_DEVICES"]
device = torch.device(f"cuda:{rank}")
torch.cuda.set_device(device)
init_test_distributed_environment(tp_size, pp_size, rank,
distributed_init_port)
test_dict = {
# device tensor
"a": torch.arange(8, dtype=torch.float32, device="cuda"),
# CPU tensor
"b": torch.arange(16, dtype=torch.int8, device="cpu"),
"c": "test",
"d": [1, 2, 3],
"e": {
"a": 1,
"b": 2
},
# empty tensor
"f": torch.tensor([], dtype=torch.float32, device="cuda"),
}
if not get_pp_group().is_first_rank:
recv_dict = get_pp_group().recv_tensor_dict()
if not get_pp_group().is_last_rank:
get_pp_group().send_tensor_dict(test_dict)
if not get_pp_group().is_first_rank:
assert len(recv_dict) == len(test_dict)
assert torch.allclose(recv_dict["a"], test_dict["a"])
assert torch.allclose(recv_dict["b"], test_dict["b"])
assert recv_dict["c"] == test_dict["c"]
assert recv_dict["d"] == test_dict["d"]
assert recv_dict["e"] == test_dict["e"]
assert torch.allclose(recv_dict["f"], test_dict["f"])
@ray.remote(num_gpus=1, max_calls=1)
def send_recv_test_worker(tp_size: int, pp_size: int, rank: int,
distributed_init_port: str):
del os.environ["CUDA_VISIBLE_DEVICES"]
device = torch.device(f"cuda:{rank}")
torch.cuda.set_device(device)
init_test_distributed_environment(tp_size, pp_size, rank,
distributed_init_port)
size = 64
test_tensor = torch.arange(64, dtype=torch.float32, device="cuda")
if not get_pp_group().is_first_rank:
recv_tensor = get_pp_group().recv(size, dtype=torch.float32)
if not get_pp_group().is_last_rank:
get_pp_group().send(test_tensor)
if not get_pp_group().is_first_rank:
assert torch.allclose(test_tensor, recv_tensor)
@pytest.mark.skipif(torch.cuda.device_count() < 2,
reason="Need at least 2 GPUs to run the test.")
@pytest.mark.parametrize("tp_size", [2])
@@ -113,4 +174,13 @@ def broadcast_tensor_dict_test_worker(tp_size: int, pp_size: int, rank: int,
    broadcast_tensor_dict_test_worker
])
def test_multi_process_tensor_parallel(tp_size, test_target):
-    multi_process_tensor_parallel(tp_size, 1, test_target)
+    multi_process_parallel(tp_size, 1, test_target)
@pytest.mark.skipif(torch.cuda.device_count() < 2,
reason="Need at least 2 GPUs to run the test.")
@pytest.mark.parametrize("pp_size", [2])
@pytest.mark.parametrize(
"test_target", [send_recv_test_worker, send_recv_tensor_dict_test_worker])
def test_multi_process_pipeline_parallel(pp_size, test_target):
multi_process_parallel(1, pp_size, test_target)
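Note on the new pipeline tests above: every rank except the first receives from its predecessor and every rank except the last sends to its successor, so a payload entering at the first rank is relayed stage by stage to the last one. Below is a minimal sketch of that relay outside the test harness; it assumes an already-initialized pipeline-parallel group, that torch.cuda.set_device has been called for this rank, and that both sides agree out of band on the tensor's shape and dtype.

# Illustrative relay across a pipeline-parallel group of any depth.
# Intermediate ranks receive from the previous stage before sending on.
import torch
from vllm.distributed import get_pp_group

def relay_forward() -> torch.Tensor:
    pp_group = get_pp_group()
    if pp_group.is_first_rank:
        # Hypothetical payload produced by the first stage.
        tensor = torch.ones(64, dtype=torch.float32, device="cuda")
    else:
        # Shape and dtype must match what the previous stage sends.
        tensor = pp_group.recv((64, ), dtype=torch.float32)
    if not pp_group.is_last_rank:
        pp_group.send(tensor)
    return tensor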

View File

@@ -12,8 +12,7 @@ from vllm.distributed.parallel_state import (get_tensor_model_parallel_group,
                                             get_tp_group, graph_capture)
from ..utils import (ensure_model_parallel_initialized,
-                     init_test_distributed_environment,
-                     multi_process_tensor_parallel)
+                     init_test_distributed_environment, multi_process_parallel)
random.seed(42)
test_sizes = [random.randint(1024, 2048 * 1024) for _ in range(8)]
@@ -113,4 +112,4 @@ def test_custom_allreduce(tp_size, pipeline_parallel_size, test_target):
    world_size = tp_size * pipeline_parallel_size
    if world_size > torch.cuda.device_count():
        pytest.skip("Not enough GPUs to run the test.")
-    multi_process_tensor_parallel(tp_size, pipeline_parallel_size, test_target)
+    multi_process_parallel(tp_size, pipeline_parallel_size, test_target)

View File

@@ -168,9 +168,13 @@ def send_recv_worker_fn():
                        dtype=torch.float32).cuda(pynccl_comm.rank)
    with pynccl_comm.change_state(enable=True):
        if pynccl_comm.rank == 0:
-            pynccl_comm.send(tensor)
+            pynccl_comm.send(tensor,
+                             dst=(pynccl_comm.rank + 1) %
+                             pynccl_comm.world_size)
        else:
-            pynccl_comm.recv(tensor)
+            pynccl_comm.recv(tensor,
+                             src=(pynccl_comm.rank - 1) %
+                             pynccl_comm.world_size)
    result = tensor.mean().cpu().item()
    assert result == 1
@@ -203,9 +207,13 @@ def multiple_send_recv_worker_fn():
                        device=device)
    with pynccl_comm.change_state(enable=True):
        if torch.distributed.get_rank() in [0, 1]:
-            pynccl_comm.send(tensor)
+            pynccl_comm.send(tensor,
+                             dst=(pynccl_comm.rank + 1) %
+                             pynccl_comm.world_size)
        else:
-            pynccl_comm.recv(tensor)
+            pynccl_comm.recv(tensor,
+                             src=(pynccl_comm.rank - 1) %
+                             pynccl_comm.world_size)
    result = tensor.mean().cpu().item()
    if torch.distributed.get_rank() in [0, 2]:
        assert result == 1

View File

@@ -129,7 +129,7 @@ def init_test_distributed_environment(
    ensure_model_parallel_initialized(tp_size, pp_size)
-def multi_process_tensor_parallel(
+def multi_process_parallel(
    tp_size: int,
    pp_size: int,
    test_target,

View File

@@ -121,10 +121,7 @@ class PyNcclCommunicator:
                           ncclRedOpTypeEnum.from_torch(op), self.comm,
                           cudaStream_t(stream.cuda_stream))
-    def send(self,
-             tensor: torch.Tensor,
-             dst: Optional[int] = None,
-             stream=None):
+    def send(self, tensor: torch.Tensor, dst: int, stream=None):
        if self.disabled:
            return
        assert tensor.device == self.device, (
@@ -132,16 +129,11 @@ class PyNcclCommunicator:
            f"but the input tensor is on {tensor.device}")
        if stream is None:
            stream = self.stream
-        if dst is None:
-            dst = (self.rank + 1) % self.world_size
        self.nccl.ncclSend(buffer_type(tensor.data_ptr()), tensor.numel(),
                           ncclDataTypeEnum.from_torch(tensor.dtype), dst,
                           self.comm, cudaStream_t(stream.cuda_stream))
-    def recv(self,
-             tensor: torch.Tensor,
-             src: Optional[int] = None,
-             stream=None):
+    def recv(self, tensor: torch.Tensor, src: int, stream=None):
        if self.disabled:
            return
        assert tensor.device == self.device, (
@@ -149,8 +141,6 @@ class PyNcclCommunicator:
            f"but the input tensor is on {tensor.device}")
        if stream is None:
            stream = self.stream
-        if src is None:
-            src = (self.rank - 1) % self.world_size
        self.nccl.ncclRecv(buffer_type(tensor.data_ptr()), tensor.numel(),
                           ncclDataTypeEnum.from_torch(tensor.dtype), src,
                           self.comm, cudaStream_t(stream.cuda_stream))
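With the Optional defaults removed, PyNcclCommunicator.send/recv no longer pick a neighbor on their own; the caller must name the peer's rank inside the communicator (the next/previous-rank defaulting now lives in GroupCoordinator, shown in the last file below). A minimal sketch for a two-rank communicator, where `pynccl_comm` and `tensor` stand for an already-initialized PyNcclCommunicator and a CUDA tensor on this rank's device:

# Explicit peer selection, assuming a two-rank communicator.
if pynccl_comm.rank == 0:
    pynccl_comm.send(tensor, dst=1)   # enqueued on the communicator's CUDA stream
else:
    pynccl_comm.recv(tensor, src=0)   # received in place into `tensor`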

View File

@@ -20,6 +20,7 @@ If you only need to use the distributed environment without model/pipeline
steps.
"""
import contextlib
import pickle
from collections import namedtuple
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
@@ -28,6 +29,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
from unittest.mock import patch
import torch
import torch.distributed
from torch.distributed import Backend, ProcessGroup
import vllm.envs as envs
@@ -180,6 +182,16 @@ class GroupCoordinator:
"""Return the global rank of the last process in the group"""
return self.ranks[-1]
@property
def is_first_rank(self):
"""Return whether the caller is the first process in the group"""
return self.rank == self.first_rank
@property
def is_last_rank(self):
"""Return whether the caller is the last process in the group"""
return self.rank == self.last_rank
@property
def next_rank(self):
"""Return the global rank of the process that follows the caller"""
@@ -374,6 +386,70 @@ class GroupCoordinator:
group=self.device_group)
return obj_list
def send_object(self, obj: Any, dst: int) -> None:
"""Send the input object list to the destination rank."""
"""NOTE: `dst` is the local rank of the destination rank."""
assert dst < self.world_size, f"Invalid dst rank ({dst})"
assert dst != self.rank, (
"Invalid destination rank. Destination rank is the same "
"as the current rank.")
# Serialize object to tensor and get the size as well
object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8)
size_tensor = torch.tensor([object_tensor.numel()],
dtype=torch.long,
device="cpu")
# Send object size
torch.distributed.send(size_tensor,
dst=self.ranks[dst],
group=self.cpu_group)
# Send object
torch.distributed.send(object_tensor,
dst=self.ranks[dst],
group=self.cpu_group)
return None
def recv_object(self, src: int) -> Any:
"""Receive the input object list from the source rank."""
"""NOTE: `src` is the local rank of the source rank."""
assert src < self.world_size, f"Invalid src rank ({src})"
assert src != self.rank, (
"Invalid source rank. Source rank is the same as the current rank."
)
size_tensor = torch.empty(1, dtype=torch.long, device="cpu")
# Receive object size
rank_size = torch.distributed.recv(size_tensor,
src=src,
group=self.cpu_group)
# Tensor to receive serialized objects into.
object_tensor = torch.empty( # type: ignore[call-overload]
size_tensor.item(), # type: ignore[arg-type]
dtype=torch.uint8,
device="cpu")
rank_object = torch.distributed.recv(object_tensor,
src=src,
group=self.cpu_group)
        assert rank_object == rank_size, (
            "The rank that sent the object does not match the rank that "
            "sent its size.")
obj = pickle.loads(object_tensor.numpy().tobytes())
return obj
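send_object/recv_object above use a two-message protocol on the CPU group: a torch.long tensor carrying the byte size of the pickled payload goes first, so the receiver can allocate an exactly-sized uint8 buffer before the payload itself arrives. The serialization round-trip they rely on can be checked in isolation (single process, no process group needed):

# Round-trip of the pickle <-> uint8-tensor encoding used by send_object/recv_object.
import pickle
import torch

obj = {"layer": 3, "shape": (4, 8)}
buf = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8)  # what the sender ships
assert pickle.loads(buf.numpy().tobytes()) == obj             # what the receiver decodes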
def broadcast_tensor_dict(
self,
tensor_dict: Optional[Dict[Any, Union[torch.Tensor, Any]]] = None,
@@ -459,6 +535,88 @@ class GroupCoordinator:
async_handle.wait()
return tensor_dict
def send_tensor_dict(
self,
tensor_dict: Dict[Any, Union[torch.Tensor, Any]],
dst: Optional[int] = None
) -> Optional[Dict[Any, Union[torch.Tensor, Any]]]:
"""Send the input tensor dictionary.
NOTE: `dst` is the local rank of the source rank.
"""
# Bypass the function if we are using only 1 GPU.
if not torch.distributed.is_initialized() or self.world_size == 1:
return tensor_dict
group = self.device_group
metadata_group = self.cpu_group
if dst is None:
dst = self.next_rank
assert dst < self.world_size, f"Invalid dst rank ({dst})"
metadata_list: List[Tuple[Any, Any]] = []
assert isinstance(
tensor_dict,
dict), f"Expecting a dictionary, got {type(tensor_dict)}"
metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
        # `metadata_list` lives in CPU memory.
        # `send_object` serializes it on the CPU,
        # so the CPU group can be used for the metadata.
self.send_object(metadata_list, dst=dst)
for tensor in tensor_list:
if tensor.numel() == 0:
# Skip sending empty tensors.
continue
if tensor.is_cpu:
# use metadata_group for CPU tensors
torch.distributed.send(tensor, dst=dst, group=metadata_group)
else:
# use group for GPU tensors
torch.distributed.send(tensor, dst=dst, group=group)
return None
def recv_tensor_dict(
self,
src: Optional[int] = None
) -> Optional[Dict[Any, Union[torch.Tensor, Any]]]:
"""Recv the input tensor dictionary.
NOTE: `src` is the local rank of the source rank.
"""
# Bypass the function if we are using only 1 GPU.
if not torch.distributed.is_initialized() or self.world_size == 1:
return None
group = self.device_group
metadata_group = self.cpu_group
if src is None:
src = self.prev_rank
assert src < self.world_size, f"Invalid src rank ({src})"
recv_metadata_list = self.recv_object(src=src)
tensor_dict = {}
for key, value in recv_metadata_list:
if isinstance(value, TensorMetadata):
tensor = torch.empty(value.size,
dtype=value.dtype,
device=value.device)
if tensor.numel() == 0:
# Skip receiving empty tensors; just keep the placeholder.
tensor_dict[key] = tensor
continue
if tensor.is_cpu:
# use metadata_group for CPU tensors
torch.distributed.recv(tensor,
src=src,
group=metadata_group)
else:
# use group for GPU tensors
torch.distributed.recv(tensor, src=src, group=group)
tensor_dict[key] = tensor
else:
tensor_dict[key] = value
return tensor_dict
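send_tensor_dict/recv_tensor_dict split the dictionary with _split_tensor_dict: the metadata travels through send_object on the CPU group, GPU tensors go over the device group, CPU tensors over the CPU group, and empty tensors are rebuilt from metadata alone without any transfer. A minimal sketch of a stage-to-stage handoff with the default (adjacent-rank) peers, as in the new tests; the payload contents are illustrative assumptions:

# Illustrative handoff of activations between adjacent pipeline stages.
import torch
from vllm.distributed import get_pp_group

pp_group = get_pp_group()
if pp_group.is_first_rank:
    # Hypothetical activations produced by the first stage.
    payload = {
        "hidden_states": torch.randn(2, 16, device="cuda"),
        "metadata": {"step": 0},
    }
else:
    # Tensors are reconstructed from the received metadata.
    payload = pp_group.recv_tensor_dict()
if not pp_group.is_last_rank:
    pp_group.send_tensor_dict(payload)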
def barrier(self):
"""Barrier synchronization among the group.
NOTE: don't use `device_group` here! `barrier` in NCCL is
@@ -468,6 +626,35 @@ class GroupCoordinator:
"""
torch.distributed.barrier(group=self.cpu_group)
def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None:
"""Sends a tensor to the destination rank in a non-blocking way"""
"""NOTE: `dst` is the local rank of the destination rank."""
if dst is None:
dst = self.next_rank
pynccl_comm = self.pynccl_comm
if pynccl_comm is not None and not pynccl_comm.disabled:
pynccl_comm.send(tensor, dst)
else:
torch.distributed.send(tensor, self.ranks[dst], self.device_group)
def recv(self,
size: torch.Size,
dtype: torch.dtype,
src: Optional[int] = None) -> torch.Tensor:
"""Receives a tensor from the src rank."""
"""NOTE: `src` is the local rank of the destination rank."""
if src is None:
src = self.prev_rank
tensor = torch.empty(size, dtype=dtype, device=self.device)
pynccl_comm = self.pynccl_comm
if pynccl_comm is not None and not pynccl_comm.disabled:
pynccl_comm.recv(tensor, src)
else:
torch.distributed.recv(tensor, self.ranks[src], self.device_group)
return tensor
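The raw send/recv above prefer the PyNcclCommunicator fast path when it is enabled and fall back to torch.distributed otherwise. Note the asymmetry with the pynccl API: GroupCoordinator.recv allocates and returns a fresh tensor (the caller supplies size and dtype), whereas PyNcclCommunicator.recv fills a caller-provided tensor in place. An illustrative contrast, with `pynccl_comm` standing for an already-initialized communicator:

# Two receive styles added in this commit (illustrative).
import torch
from vllm.distributed import get_pp_group

buf = torch.empty(64, dtype=torch.float32, device="cuda")
pynccl_comm.recv(buf, src=0)                                  # fills `buf` in place
out = get_pp_group().recv(torch.Size([64]), torch.float32)    # allocates and returns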
def destroy(self):
if self.device_group is not None:
torch.distributed.destroy_process_group(self.device_group)