[Distributed] Add send and recv helpers (#5719)

Author: Murali Andoorveedu, 2024-06-23 17:42:28 -04:00 (committed by GitHub)
Parent: 6c916ac8a8
Commit: 5d4d90536f
6 changed files with 278 additions and 24 deletions
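
The tests below exercise the new GroupCoordinator point-to-point helpers. A minimal usage sketch (not part of this diff; it assumes the vLLM distributed environment and the pipeline-parallel group are already initialized, and the stage function and the "hidden_states" key are illustrative only):

import torch

from vllm.distributed import get_pp_group


def run_pipeline_stage(hidden_states: torch.Tensor) -> torch.Tensor:
    pp_group = get_pp_group()
    # Every stage except the first receives activations from the previous
    # stage (recv_tensor_dict defaults to the previous rank in the group).
    if not pp_group.is_first_rank:
        hidden_states = pp_group.recv_tensor_dict()["hidden_states"]
    # ... run this stage's layers on hidden_states here ...
    # Every stage except the last forwards activations to the next stage
    # (send_tensor_dict defaults to the next rank in the group).
    if not pp_group.is_last_rank:
        pp_group.send_tensor_dict({"hidden_states": hidden_states})
    return hidden_states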

tests/distributed/test_comm_ops.py

@@ -8,12 +8,11 @@ import pytest
import ray
import torch

-from vllm.distributed import (broadcast_tensor_dict,
+from vllm.distributed import (broadcast_tensor_dict, get_pp_group,
                              tensor_model_parallel_all_gather,
                              tensor_model_parallel_all_reduce)
-from ..utils import (init_test_distributed_environment,
-                     multi_process_tensor_parallel)
+from ..utils import init_test_distributed_environment, multi_process_parallel


@ray.remote(num_gpus=1, max_calls=1)
@@ -105,6 +104,68 @@ def broadcast_tensor_dict_test_worker(tp_size: int, pp_size: int, rank: int,
    assert torch.allclose(recv_dict["f"], test_dict["f"])

@ray.remote(num_gpus=1, max_calls=1)
def send_recv_tensor_dict_test_worker(tp_size: int, pp_size: int, rank: int,
distributed_init_port: str):
del os.environ["CUDA_VISIBLE_DEVICES"]
device = torch.device(f"cuda:{rank}")
torch.cuda.set_device(device)
init_test_distributed_environment(tp_size, pp_size, rank,
distributed_init_port)
test_dict = {
# device tensor
"a": torch.arange(8, dtype=torch.float32, device="cuda"),
# CPU tensor
"b": torch.arange(16, dtype=torch.int8, device="cpu"),
"c": "test",
"d": [1, 2, 3],
"e": {
"a": 1,
"b": 2
},
# empty tensor
"f": torch.tensor([], dtype=torch.float32, device="cuda"),
}
if not get_pp_group().is_first_rank:
recv_dict = get_pp_group().recv_tensor_dict()
if not get_pp_group().is_last_rank:
get_pp_group().send_tensor_dict(test_dict)
if not get_pp_group().is_first_rank:
assert len(recv_dict) == len(test_dict)
assert torch.allclose(recv_dict["a"], test_dict["a"])
assert torch.allclose(recv_dict["b"], test_dict["b"])
assert recv_dict["c"] == test_dict["c"]
assert recv_dict["d"] == test_dict["d"]
assert recv_dict["e"] == test_dict["e"]
assert torch.allclose(recv_dict["f"], test_dict["f"])
@ray.remote(num_gpus=1, max_calls=1)
def send_recv_test_worker(tp_size: int, pp_size: int, rank: int,
distributed_init_port: str):
del os.environ["CUDA_VISIBLE_DEVICES"]
device = torch.device(f"cuda:{rank}")
torch.cuda.set_device(device)
init_test_distributed_environment(tp_size, pp_size, rank,
distributed_init_port)
size = 64
test_tensor = torch.arange(64, dtype=torch.float32, device="cuda")
if not get_pp_group().is_first_rank:
recv_tensor = get_pp_group().recv(size, dtype=torch.float32)
if not get_pp_group().is_last_rank:
get_pp_group().send(test_tensor)
if not get_pp_group().is_first_rank:
assert torch.allclose(test_tensor, recv_tensor)
@pytest.mark.skipif(torch.cuda.device_count() < 2,
                    reason="Need at least 2 GPUs to run the test.")
@pytest.mark.parametrize("tp_size", [2])
@@ -113,4 +174,13 @@ def broadcast_tensor_dict_test_worker(tp_size: int, pp_size: int, rank: int,
    broadcast_tensor_dict_test_worker
])
def test_multi_process_tensor_parallel(tp_size, test_target):
-    multi_process_tensor_parallel(tp_size, 1, test_target)
+    multi_process_parallel(tp_size, 1, test_target)

@pytest.mark.skipif(torch.cuda.device_count() < 2,
reason="Need at least 2 GPUs to run the test.")
@pytest.mark.parametrize("pp_size", [2])
@pytest.mark.parametrize(
"test_target", [send_recv_test_worker, send_recv_tensor_dict_test_worker])
def test_multi_process_pipeline_parallel(pp_size, test_target):
multi_process_parallel(1, pp_size, test_target)

tests/distributed/test_custom_all_reduce.py

@@ -12,8 +12,7 @@ from vllm.distributed.parallel_state import (get_tensor_model_parallel_group,
                                             get_tp_group, graph_capture)
from ..utils import (ensure_model_parallel_initialized,
-                     init_test_distributed_environment,
-                     multi_process_tensor_parallel)
+                     init_test_distributed_environment, multi_process_parallel)

random.seed(42)
test_sizes = [random.randint(1024, 2048 * 1024) for _ in range(8)]
@@ -113,4 +112,4 @@ def test_custom_allreduce(tp_size, pipeline_parallel_size, test_target):
    world_size = tp_size * pipeline_parallel_size
    if world_size > torch.cuda.device_count():
        pytest.skip("Not enough GPUs to run the test.")
-    multi_process_tensor_parallel(tp_size, pipeline_parallel_size, test_target)
+    multi_process_parallel(tp_size, pipeline_parallel_size, test_target)

tests/distributed/test_pynccl.py

@@ -168,9 +168,13 @@ def send_recv_worker_fn():
                         dtype=torch.float32).cuda(pynccl_comm.rank)
    with pynccl_comm.change_state(enable=True):
        if pynccl_comm.rank == 0:
-            pynccl_comm.send(tensor)
+            pynccl_comm.send(tensor,
+                             dst=(pynccl_comm.rank + 1) %
+                             pynccl_comm.world_size)
        else:
-            pynccl_comm.recv(tensor)
+            pynccl_comm.recv(tensor,
+                             src=(pynccl_comm.rank - 1) %
+                             pynccl_comm.world_size)
    result = tensor.mean().cpu().item()
    assert result == 1

@@ -203,9 +207,13 @@ def multiple_send_recv_worker_fn():
                            device=device)
    with pynccl_comm.change_state(enable=True):
        if torch.distributed.get_rank() in [0, 1]:
-            pynccl_comm.send(tensor)
+            pynccl_comm.send(tensor,
+                             dst=(pynccl_comm.rank + 1) %
+                             pynccl_comm.world_size)
        else:
-            pynccl_comm.recv(tensor)
+            pynccl_comm.recv(tensor,
+                             src=(pynccl_comm.rank - 1) %
+                             pynccl_comm.world_size)
    result = tensor.mean().cpu().item()
    if torch.distributed.get_rank() in [0, 2]:
        assert result == 1

tests/utils.py

@@ -129,7 +129,7 @@ def init_test_distributed_environment(
    ensure_model_parallel_initialized(tp_size, pp_size)


-def multi_process_tensor_parallel(
+def multi_process_parallel(
    tp_size: int,
    pp_size: int,
    test_target,

vllm/distributed/device_communicators/pynccl.py

@@ -121,10 +121,7 @@ class PyNcclCommunicator:
                                   ncclRedOpTypeEnum.from_torch(op), self.comm,
                                   cudaStream_t(stream.cuda_stream))

-    def send(self,
-             tensor: torch.Tensor,
-             dst: Optional[int] = None,
-             stream=None):
+    def send(self, tensor: torch.Tensor, dst: int, stream=None):
        if self.disabled:
            return
        assert tensor.device == self.device, (
@@ -132,16 +129,11 @@
            f"but the input tensor is on {tensor.device}")
        if stream is None:
            stream = self.stream
-        if dst is None:
-            dst = (self.rank + 1) % self.world_size
        self.nccl.ncclSend(buffer_type(tensor.data_ptr()), tensor.numel(),
                           ncclDataTypeEnum.from_torch(tensor.dtype), dst,
                           self.comm, cudaStream_t(stream.cuda_stream))

-    def recv(self,
-             tensor: torch.Tensor,
-             src: Optional[int] = None,
-             stream=None):
+    def recv(self, tensor: torch.Tensor, src: int, stream=None):
        if self.disabled:
            return
        assert tensor.device == self.device, (
@@ -149,8 +141,6 @@
            f"but the input tensor is on {tensor.device}")
        if stream is None:
            stream = self.stream
-        if src is None:
-            src = (self.rank - 1) % self.world_size
        self.nccl.ncclRecv(buffer_type(tensor.data_ptr()), tensor.numel(),
                           ncclDataTypeEnum.from_torch(tensor.dtype), src,
                           self.comm, cudaStream_t(stream.cuda_stream))
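
PyNcclCommunicator.send and recv now take an explicit peer rank instead of defaulting to the neighbouring rank; the caller (e.g. GroupCoordinator.send/recv in parallel_state.py below, or the updated tests above) computes the peer itself. A small sketch of the updated calling convention (illustrative only; assumes a PyNcclCommunicator instance pynccl_comm and a CUDA tensor already exist):

# Peers must now be passed explicitly to the pynccl wrapper.
dst = (pynccl_comm.rank + 1) % pynccl_comm.world_size  # next rank in the group
src = (pynccl_comm.rank - 1) % pynccl_comm.world_size  # previous rank in the group
pynccl_comm.send(tensor, dst=dst)
pynccl_comm.recv(tensor, src=src)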

vllm/distributed/parallel_state.py

@@ -20,6 +20,7 @@ If you only need to use the distributed environment without model/pipeline
steps.
"""
import contextlib
+import pickle
from collections import namedtuple
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
@@ -28,6 +29,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
from unittest.mock import patch

import torch
+import torch.distributed
from torch.distributed import Backend, ProcessGroup

import vllm.envs as envs
@@ -180,6 +182,16 @@
        """Return the global rank of the last process in the group"""
        return self.ranks[-1]
@property
def is_first_rank(self):
"""Return whether the caller is the first process in the group"""
return self.rank == self.first_rank
@property
def is_last_rank(self):
"""Return whether the caller is the last process in the group"""
return self.rank == self.last_rank
    @property
    def next_rank(self):
        """Return the global rank of the process that follows the caller"""
@@ -374,6 +386,70 @@
                                                group=self.device_group)
        return obj_list

def send_object(self, obj: Any, dst: int) -> None:
"""Send the input object list to the destination rank."""
"""NOTE: `dst` is the local rank of the destination rank."""
assert dst < self.world_size, f"Invalid dst rank ({dst})"
assert dst != self.rank, (
"Invalid destination rank. Destination rank is the same "
"as the current rank.")
# Serialize object to tensor and get the size as well
object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8)
size_tensor = torch.tensor([object_tensor.numel()],
dtype=torch.long,
device="cpu")
# Send object size
torch.distributed.send(size_tensor,
dst=self.ranks[dst],
group=self.cpu_group)
# Send object
torch.distributed.send(object_tensor,
dst=self.ranks[dst],
group=self.cpu_group)
return None
def recv_object(self, src: int) -> Any:
"""Receive the input object list from the source rank."""
"""NOTE: `src` is the local rank of the source rank."""
assert src < self.world_size, f"Invalid src rank ({src})"
assert src != self.rank, (
"Invalid source rank. Source rank is the same as the current rank."
)
size_tensor = torch.empty(1, dtype=torch.long, device="cpu")
# Receive object size
rank_size = torch.distributed.recv(size_tensor,
src=src,
group=self.cpu_group)
# Tensor to receive serialized objects into.
object_tensor = torch.empty( # type: ignore[call-overload]
size_tensor.item(), # type: ignore[arg-type]
dtype=torch.uint8,
device="cpu")
rank_object = torch.distributed.recv(object_tensor,
src=src,
group=self.cpu_group)
assert rank_object == rank_size, (
"Received object sender rank does not match the size sender rank.")
obj = pickle.loads(object_tensor.numpy().tobytes())
return obj
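
    # Illustrative pairing of the new object helpers (not part of this diff).
    # The sender pickles an arbitrary Python object, sends its byte length
    # first and then the payload, both over the CPU group; `dst`/`src` are
    # ranks local to the group. Assuming a two-rank pipeline group (so the
    # last rank is local rank 1) and a hypothetical metadata payload:
    #
    #     pp_group = get_pp_group()
    #     if pp_group.is_first_rank:
    #         pp_group.send_object({"seq_lens": [3, 7, 5]}, dst=1)
    #     elif pp_group.is_last_rank:
    #         metadata = pp_group.recv_object(src=0)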
    def broadcast_tensor_dict(
        self,
        tensor_dict: Optional[Dict[Any, Union[torch.Tensor, Any]]] = None,
@@ -459,6 +535,88 @@
                async_handle.wait()
        return tensor_dict

def send_tensor_dict(
self,
tensor_dict: Dict[Any, Union[torch.Tensor, Any]],
dst: Optional[int] = None
) -> Optional[Dict[Any, Union[torch.Tensor, Any]]]:
"""Send the input tensor dictionary.
        NOTE: `dst` is the local rank of the destination rank.
"""
# Bypass the function if we are using only 1 GPU.
if not torch.distributed.is_initialized() or self.world_size == 1:
return tensor_dict
group = self.device_group
metadata_group = self.cpu_group
if dst is None:
dst = self.next_rank
assert dst < self.world_size, f"Invalid dst rank ({dst})"
metadata_list: List[Tuple[Any, Any]] = []
assert isinstance(
tensor_dict,
dict), f"Expecting a dictionary, got {type(tensor_dict)}"
metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
# `metadata_list` lives in CPU memory.
# `send_object_list` has serialization & deserialization,
# all happening on CPU. Therefore, we can use the CPU group.
self.send_object(metadata_list, dst=dst)
for tensor in tensor_list:
if tensor.numel() == 0:
# Skip sending empty tensors.
continue
if tensor.is_cpu:
# use metadata_group for CPU tensors
torch.distributed.send(tensor, dst=dst, group=metadata_group)
else:
# use group for GPU tensors
torch.distributed.send(tensor, dst=dst, group=group)
return None
def recv_tensor_dict(
self,
src: Optional[int] = None
) -> Optional[Dict[Any, Union[torch.Tensor, Any]]]:
"""Recv the input tensor dictionary.
NOTE: `src` is the local rank of the source rank.
"""
# Bypass the function if we are using only 1 GPU.
if not torch.distributed.is_initialized() or self.world_size == 1:
return None
group = self.device_group
metadata_group = self.cpu_group
if src is None:
src = self.prev_rank
assert src < self.world_size, f"Invalid src rank ({src})"
recv_metadata_list = self.recv_object(src=src)
tensor_dict = {}
for key, value in recv_metadata_list:
if isinstance(value, TensorMetadata):
tensor = torch.empty(value.size,
dtype=value.dtype,
device=value.device)
if tensor.numel() == 0:
                # Skip receiving empty tensors.
tensor_dict[key] = tensor
continue
if tensor.is_cpu:
# use metadata_group for CPU tensors
torch.distributed.recv(tensor,
src=src,
group=metadata_group)
else:
# use group for GPU tensors
torch.distributed.recv(tensor, src=src, group=group)
tensor_dict[key] = tensor
else:
tensor_dict[key] = value
return tensor_dict
    def barrier(self):
        """Barrier synchronization among the group.
        NOTE: don't use `device_group` here! `barrier` in NCCL is
@@ -468,6 +626,35 @@
        """
        torch.distributed.barrier(group=self.cpu_group)

def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None:
"""Sends a tensor to the destination rank in a non-blocking way"""
"""NOTE: `dst` is the local rank of the destination rank."""
if dst is None:
dst = self.next_rank
pynccl_comm = self.pynccl_comm
if pynccl_comm is not None and not pynccl_comm.disabled:
pynccl_comm.send(tensor, dst)
else:
torch.distributed.send(tensor, self.ranks[dst], self.device_group)
def recv(self,
size: torch.Size,
dtype: torch.dtype,
src: Optional[int] = None) -> torch.Tensor:
"""Receives a tensor from the src rank."""
"""NOTE: `src` is the local rank of the destination rank."""
if src is None:
src = self.prev_rank
tensor = torch.empty(size, dtype=dtype, device=self.device)
pynccl_comm = self.pynccl_comm
if pynccl_comm is not None and not pynccl_comm.disabled:
pynccl_comm.recv(tensor, src)
else:
torch.distributed.recv(tensor, self.ranks[src], self.device_group)
return tensor
    def destroy(self):
        if self.device_group is not None:
            torch.distributed.destroy_process_group(self.device_group)
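
The raw send/recv path, exercised by send_recv_test_worker above, requires the receiver to know the shape and dtype up front, since only the tensor bytes travel. A small sketch (not part of this diff; assumes an initialized pipeline-parallel group and a CUDA device):

import torch

from vllm.distributed import get_pp_group

pp_group = get_pp_group()
activations = torch.arange(64, dtype=torch.float32, device="cuda")
if not pp_group.is_last_rank:
    # Defaults to the next rank in the group; dispatches to pynccl when it is
    # enabled, otherwise falls back to torch.distributed.send on the device group.
    pp_group.send(activations)
if not pp_group.is_first_rank:
    # The receiver allocates the buffer itself from the known size and dtype.
    received = pp_group.recv(activations.size(), dtype=torch.float32)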