[core][distributed] add zmq fallback for broadcasting large objects (#6183)

youkaichao 2024-07-09 18:49:11 -07:00 committed by GitHub
parent 2416b26e11
commit da78caecfa
6 changed files with 274 additions and 80 deletions

View File

@@ -21,3 +21,4 @@ lm-format-enforcer == 0.10.1
outlines >= 0.0.43 # Requires torch >= 2.1.0
typing_extensions
filelock >= 3.10.4 # filelock starts to support `mode` argument from 3.10.4
pyzmq
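
The new pyzmq dependency backs the large-object fallback introduced below: payloads that do not fit in the shared-memory ring buffer are published over a ZeroMQ PUB socket, with a REQ/REP handshake so the writer knows every subscriber is attached before it starts publishing. A minimal standalone sketch of that pattern (not part of the commit; ports, thread layout, and payload are illustrative only):

import pickle
import threading
import time

import zmq

PUB_PORT, SYNC_PORT, N_READERS = 5555, 5556, 2  # arbitrary values for this sketch

def writer():
    ctx = zmq.Context()
    pub = ctx.socket(zmq.PUB)
    pub.bind(f"tcp://*:{PUB_PORT}")
    rep = ctx.socket(zmq.REP)
    rep.bind(f"tcp://*:{SYNC_PORT}")
    for _ in range(N_READERS):      # REQ/REP handshake with each reader
        assert rep.recv() == b"READY"
        rep.send(b"READY")
    time.sleep(0.1)                 # crude slow-joiner guard for this sketch; the commit
                                    # instead sends a READY frame on the PUB socket
    pub.send(pickle.dumps(list(range(1_000_000))))  # the "large object"

def reader(idx):
    ctx = zmq.Context()
    sub = ctx.socket(zmq.SUB)
    sub.setsockopt_string(zmq.SUBSCRIBE, "")
    sub.connect(f"tcp://localhost:{PUB_PORT}")
    req = ctx.socket(zmq.REQ)
    req.connect(f"tcp://localhost:{SYNC_PORT}")
    req.send(b"READY")
    assert req.recv() == b"READY"
    obj = pickle.loads(sub.recv())
    print(f"reader {idx} received {len(obj)} items")

threads = [threading.Thread(target=reader, args=(i,)) for i in range(N_READERS)]
for t in threads:
    t.start()
writer()
for t in threads:
    t.join()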

View File

@@ -2,10 +2,11 @@ import os
import torch
from vllm.distributed.parallel_state import is_in_the_same_node
from vllm.distributed.parallel_state import in_the_same_node_as
torch.distributed.init_process_group(backend="gloo")
test_result = is_in_the_same_node(torch.distributed.group.WORLD)
test_result = all(
in_the_same_node_as(torch.distributed.group.WORLD, source_rank=0))
expected = os.environ.get("VLLM_TEST_SAME_HOST", "1") == "1"
assert test_result == expected, f"Expected {expected}, got {test_result}"

View File

@@ -6,8 +6,7 @@ from typing import List
import numpy as np
import torch.distributed as dist
from vllm.distributed.device_communicators.shm_broadcast import (
ShmRingBuffer, ShmRingBufferIO)
from vllm.distributed.device_communicators.shm_broadcast import MessageQueue
from vllm.utils import update_environment_variables
@@ -56,8 +55,8 @@ def worker_fn_wrapper(fn):
@worker_fn_wrapper
def worker_fn():
writer_rank = 2
broadcaster = ShmRingBufferIO.create_from_process_group(
dist.group.WORLD, 1024 * 1024, 2, writer_rank)
broadcaster = MessageQueue.create_from_process_group(
dist.group.WORLD, 40 * 1024, 2, writer_rank)
if dist.get_rank() == writer_rank:
seed = random.randint(0, 1000)
dist.broadcast_object_list([seed], writer_rank)
@@ -87,13 +86,3 @@ def worker_fn():
def test_shm_broadcast():
distributed_run(worker_fn, 4)
def test_singe_process():
buffer = ShmRingBuffer(1, 1024, 4)
reader = ShmRingBufferIO(buffer, reader_rank=0)
writer = ShmRingBufferIO(buffer, reader_rank=-1)
writer.enqueue([0])
writer.enqueue([1])
assert reader.dequeue() == [0]
assert reader.dequeue() == [1]

View File

@@ -9,7 +9,7 @@ import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.distributed.device_communicators.custom_all_reduce_utils import (
gpu_p2p_access_check)
from vllm.distributed.parallel_state import is_in_the_same_node
from vllm.distributed.parallel_state import in_the_same_node_as
from vllm.logger import init_logger
from vllm.utils import cuda_device_count_stateless, is_full_nvlink
@@ -64,7 +64,7 @@ class CustomAllreduce:
assert dist.get_backend(group) != dist.Backend.NCCL, (
"CustomAllreduce should be attached to a non-NCCL group.")
if not is_in_the_same_node(group):
if not all(in_the_same_node_as(group, source_rank=0)):
# No need to initialize custom allreduce for multi-node case.
logger.warning(
"Custom allreduce is disabled because this process group"

View File

@@ -1,16 +1,19 @@
import pickle
import time
from contextlib import contextmanager
from dataclasses import dataclass, field
from multiprocessing import shared_memory
from typing import Optional
from typing import List, Optional
from unittest.mock import patch
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
from zmq import PUB, REP, REQ, SUB, SUBSCRIBE, Context # type: ignore
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.utils import get_ip, get_open_port
VLLM_RINGBUFFER_WARNING_INTERVAL = envs.VLLM_RINGBUFFER_WARNING_INTERVAL
@@ -135,18 +138,183 @@ class ShmRingBuffer:
yield buf
class ShmRingBufferIO:
@dataclass
class Handle:
connect_ip: str
local_reader_ranks: List[int] = field(default_factory=list)
def __init__(self, buffer: ShmRingBuffer, reader_rank: int):
self.buffer = buffer
self.reader_rank = reader_rank
self._is_writer = self.reader_rank == -1
self._is_reader = not self._is_writer
if self._is_reader:
assert 0 <= self.reader_rank < buffer.n_reader, \
(f"Invalid reader rank {self.reader_rank} for buffer"
f" created with {buffer.n_reader} readers")
self.current_idx = 0
buffer: Optional[ShmRingBuffer] = None
local_subscribe_port: Optional[int] = None
local_sync_port: Optional[int] = None
remote_subscribe_port: Optional[int] = None
remote_sync_port: Optional[int] = None
class MessageQueue:
def __init__(
self,
n_reader, # number of all readers
n_local_reader, # number of local readers through shared memory
local_reader_ranks: Optional[List[int]] = None,
max_chunk_bytes: int = 1024 * 1024 * 10,
max_chunks: int = 10,
connect_ip: Optional[str] = None,
):
if local_reader_ranks is None:
local_reader_ranks = list(range(n_local_reader))
else:
assert len(local_reader_ranks) == n_local_reader
self.n_local_reader = n_local_reader
n_remote_reader = n_reader - n_local_reader
self.n_remote_reader = n_remote_reader
if connect_ip is None:
connect_ip = get_ip()
context = Context()
if n_local_reader > 0:
# for local readers, we will:
# 1. create a shared memory ring buffer to communicate small data
# 2. create a publish-subscribe socket to communicate large data
self.buffer = ShmRingBuffer(n_local_reader, max_chunk_bytes,
max_chunks)
self.local_socket = context.socket(PUB)
local_subscribe_port = get_open_port()
self.local_socket.bind(f"tcp://*:{local_subscribe_port}")
self.local_sync_socket = context.socket(REP)
local_sync_port = get_open_port()
self.local_sync_socket.bind(f"tcp://*:{local_sync_port}")
self.current_idx = 0
else:
self.buffer = None # type: ignore
local_subscribe_port = None
local_sync_port = None
self.local_socket = None
self.local_sync_socket = None
self.current_idx = -1
if n_remote_reader > 0:
# for remote readers, we will:
# create a publish-subscribe socket to communicate large data
self.remote_socket = context.socket(PUB)
remote_subscribe_port = get_open_port()
self.remote_socket.bind(f"tcp://*:{remote_subscribe_port}")
self.remote_sync_socket = context.socket(REP)
remote_sync_port = get_open_port()
self.remote_sync_socket.bind(f"tcp://*:{remote_sync_port}")
else:
remote_subscribe_port = None
remote_sync_port = None
self.remote_socket = None
self.remote_sync_socket = None
self._is_writer = True
self._is_local_reader = False
self.local_reader_rank = -1
# rank does not matter for remote readers
self._is_remote_reader = False
self.handle = Handle(
connect_ip=connect_ip,
local_reader_ranks=local_reader_ranks,
buffer=self.buffer,
local_subscribe_port=local_subscribe_port,
local_sync_port=local_sync_port,
remote_subscribe_port=remote_subscribe_port,
remote_sync_port=remote_sync_port,
)
def export_handle(self) -> Handle:
return self.handle
@staticmethod
def create_from_handle(handle: Handle, rank) -> "MessageQueue":
self = MessageQueue.__new__(MessageQueue)
self.handle = handle
self._is_writer = False
context = Context()
if rank in handle.local_reader_ranks:
assert handle.buffer is not None
self.buffer = handle.buffer
self.current_idx = 0
self.local_reader_rank = handle.local_reader_ranks.index(rank)
self._is_local_reader = True
self._is_remote_reader = False
self.local_socket = context.socket(SUB)
self.local_socket.setsockopt_string(SUBSCRIBE, "")
self.local_socket.connect(
f"tcp://{handle.connect_ip}:{handle.local_subscribe_port}")
self.local_sync_socket = context.socket(REQ)
self.local_sync_socket.connect(
f"tcp://{handle.connect_ip}:{handle.local_sync_port}")
self.remote_socket = None
self.remote_sync_socket = None
else:
self.buffer = None # type: ignore
self.current_idx = -1
self.local_reader_rank = -1
self._is_local_reader = False
self._is_remote_reader = True
self.local_socket = None
self.local_sync_socket = None
self.remote_socket = context.socket(SUB)
self.remote_socket.setsockopt_string(SUBSCRIBE, "")
self.remote_socket.connect(
f"tcp://{handle.connect_ip}:{handle.remote_subscribe_port}")
self.remote_sync_socket = context.socket(REQ)
self.remote_sync_socket.connect(
f"tcp://{handle.connect_ip}:{handle.remote_sync_port}")
return self
def wait_until_ready(self):
"""This is a collective operation. All processes (including the
readers and the writer) should call this function.
"""
if self._is_writer:
# wait for all readers to connect
# local readers
for i in range(self.n_local_reader):
recv = self.local_sync_socket.recv()
assert recv == b"READY"
self.local_sync_socket.send(b"READY")
if self.n_local_reader > 0:
self.local_socket.send(b"READY")
# remote readers
for i in range(self.n_remote_reader):
recv = self.remote_sync_socket.recv()
assert recv == b"READY"
self.remote_sync_socket.send(b"READY")
if self.n_remote_reader > 0:
self.remote_socket.send(b"READY")
elif self._is_local_reader:
self.local_sync_socket.send(b"READY")
recv = self.local_sync_socket.recv()
assert recv == b"READY"
recv = self.local_socket.recv()
assert recv == b"READY"
elif self._is_remote_reader:
self.remote_sync_socket.send(b"READY")
recv = self.remote_sync_socket.recv()
assert recv == b"READY"
recv = self.remote_socket.recv()
assert recv == b"READY"
@contextmanager
def acquire_write(self):
@@ -201,12 +369,12 @@ class ShmRingBufferIO:
@contextmanager
def acquire_read(self):
assert self._is_reader, "Only readers can acquire read"
assert self._is_local_reader, "Only readers can acquire read"
start_time = time.monotonic()
n_warning = 1
while True:
with self.buffer.get_metadata(self.current_idx) as metadata_buffer:
read_flag = metadata_buffer[self.reader_rank + 1]
read_flag = metadata_buffer[self.local_reader_rank + 1]
written_flag = metadata_buffer[0]
if not written_flag or read_flag:
# this block is either
@@ -236,7 +404,7 @@ class ShmRingBufferIO:
# caller has read from the buffer
# set the read flag
metadata_buffer[self.reader_rank + 1] = 1
metadata_buffer[self.local_reader_rank + 1] = 1
self.current_idx = (self.current_idx +
1) % self.buffer.max_chunks
break
@@ -244,21 +412,36 @@ class ShmRingBufferIO:
def enqueue(self, obj):
assert self._is_writer, "Only writers can enqueue"
serialized_obj = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
if len(serialized_obj) > self.buffer.max_chunk_bytes:
raise RuntimeError(
f"{len(serialized_obj)=} larger than the allowed value "
f"{self.buffer.max_chunk_bytes},"
"Please increase the max_chunk_bytes parameter.")
with self.acquire_write() as buf:
buf[:len(serialized_obj)] = serialized_obj
if self.n_local_reader > 0:
if len(serialized_obj) >= self.buffer.max_chunk_bytes:
with self.acquire_write() as buf:
buf[0] = 1 # overflow
self.local_socket.send(serialized_obj)
else:
with self.acquire_write() as buf:
buf[0] = 0 # not overflow
buf[1:len(serialized_obj) + 1] = serialized_obj
if self.n_remote_reader > 0:
self.remote_socket.send(serialized_obj)
def dequeue(self):
assert self._is_reader, "Only readers can dequeue"
with self.acquire_read() as buf:
# no need to know the size of serialized object
# pickle format itself contains the size information internally
# see https://docs.python.org/3/library/pickle.html
obj = pickle.loads(buf)
if self._is_local_reader:
overflow = False
with self.acquire_read() as buf:
overflow = buf[0] == 1
if not overflow:
# no need to know the size of serialized object
# pickle format contains the size information internally
# see https://docs.python.org/3/library/pickle.html
obj = pickle.loads(buf[1:])
if overflow:
recv = self.local_socket.recv()
obj = pickle.loads(recv)
elif self._is_remote_reader:
recv = self.remote_socket.recv()
obj = pickle.loads(recv)
else:
raise RuntimeError("Only readers can dequeue")
return obj
def broadcast_object(self, obj=None):
@@ -272,24 +455,36 @@ class ShmRingBufferIO:
def create_from_process_group(pg: ProcessGroup,
max_chunk_bytes,
max_chunks,
writer_rank=0) -> "ShmRingBufferIO":
writer_rank=0) -> "MessageQueue":
group_rank = dist.get_rank(pg)
group_world_size = dist.get_world_size(pg)
ranks_inside_group = list(range(group_world_size))
global_ranks = dist.get_process_group_ranks(pg)
from vllm.distributed.parallel_state import in_the_same_node_as
status = in_the_same_node_as(pg, source_rank=writer_rank)
same_node_ranks = [i for i, s in enumerate(status) if s]
n_reader = group_world_size - 1
buffer: ShmRingBuffer
n_local_reader = len(same_node_ranks) - 1
local_reader_ranks = [i for i in same_node_ranks if i != writer_rank]
buffer_io: MessageQueue
if group_rank == writer_rank:
buffer = ShmRingBuffer(n_reader, max_chunk_bytes, max_chunks)
dist.broadcast_object_list([buffer],
buffer_io = MessageQueue(
n_reader=n_reader,
n_local_reader=n_local_reader,
local_reader_ranks=local_reader_ranks,
max_chunk_bytes=max_chunk_bytes,
max_chunks=max_chunks,
)
handle = buffer_io.export_handle()
dist.broadcast_object_list([handle],
src=global_ranks[writer_rank],
group=pg)
return ShmRingBufferIO(buffer, -1)
else:
recv = [None]
dist.broadcast_object_list(recv,
src=global_ranks[writer_rank],
group=pg)
buffer = recv[0] # type: ignore
rest_ranks = [r for r in ranks_inside_group if r != writer_rank]
return ShmRingBufferIO(buffer, rest_ranks.index(group_rank))
handle = recv[0] # type: ignore
buffer_io = MessageQueue.create_from_handle(handle, group_rank)
buffer_io.wait_until_ready()
return buffer_io
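
For reference, a hedged usage sketch of the new MessageQueue API without a torch process group (create_from_process_group above wraps exactly these steps). It assumes the handle, including the ShmRingBuffer, pickles across processes, which the broadcast_object_list path above already relies on; class, method, and parameter names come from the diff.

import multiprocessing as mp

from vllm.distributed.device_communicators.shm_broadcast import MessageQueue

def reader_main(handle, rank):
    queue = MessageQueue.create_from_handle(handle, rank)
    queue.wait_until_ready()                     # collective handshake with the writer
    small = queue.dequeue()                      # served from the shared-memory ring buffer
    large = queue.dequeue()                      # overflow: served from the zmq SUB socket
    print(f"rank {rank}: {small}, {len(large)} bytes")

if __name__ == "__main__":
    # one writer and two readers, all on the same host
    writer = MessageQueue(n_reader=2, n_local_reader=2)  # default max_chunk_bytes = 10 MiB
    handle = writer.export_handle()              # in vLLM this travels via
                                                 # torch.distributed.broadcast_object_list
    readers = [mp.Process(target=reader_main, args=(handle, r)) for r in range(2)]
    for p in readers:
        p.start()
    writer.wait_until_ready()
    writer.enqueue({"step": 1})                  # small object -> ring buffer
    writer.enqueue(bytes(64 * 1024 * 1024))      # 64 MiB -> ZeroMQ PUB/SUB fallback
    for p in readers:
        p.join()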

View File

@@ -124,7 +124,7 @@ class GroupCoordinator:
# communicators are only created for world size > 1
pynccl_comm: Optional[Any] # PyNccl communicator
ca_comm: Optional[Any] # Custom allreduce communicator
shm_broadcaster: Optional[Any] # shared memory broadcaster
mq_broadcaster: Optional[Any] # shared memory broadcaster
def __init__(
self,
@@ -133,6 +133,7 @@
torch_distributed_backend: Union[str, Backend],
use_pynccl: bool,
use_custom_allreduce: bool,
use_message_queue_broadcaster: bool = False,
):
self.rank = torch.distributed.get_rank()
@@ -190,10 +191,10 @@
self.ca_comm = None
from vllm.distributed.device_communicators.shm_broadcast import (
ShmRingBufferIO)
self.shm_broadcaster: Optional[ShmRingBufferIO] = None
if self.world_size > 1 and is_in_the_same_node(self.cpu_group):
self.shm_broadcaster = ShmRingBufferIO.create_from_process_group(
MessageQueue)
self.mq_broadcaster: Optional[MessageQueue] = None
if use_message_queue_broadcaster and self.world_size > 1:
self.mq_broadcaster = MessageQueue.create_from_process_group(
self.cpu_group, 1 << 22, 6)
@property
@@ -377,9 +378,9 @@
# Bypass the function if we are using only 1 GPU.
if self.world_size == 1:
return obj
if self.shm_broadcaster is not None:
assert src == 0, "Shared memory broadcaster only supports src=0"
return self.shm_broadcaster.broadcast_object(obj)
if self.mq_broadcaster is not None:
assert src == 0, "Message queue broadcaster only supports src=0"
return self.mq_broadcaster.broadcast_object(obj)
if self.rank_in_group == src:
torch.distributed.broadcast_object_list([obj],
src=self.ranks[src],
@@ -696,8 +697,8 @@ class GroupCoordinator:
self.pynccl_comm = None
if self.ca_comm is not None:
self.ca_comm = None
if self.shm_broadcaster is not None:
self.shm_broadcaster = None
if self.mq_broadcaster is not None:
self.mq_broadcaster = None
_WORLD: Optional[GroupCoordinator] = None
@@ -720,10 +721,12 @@ def init_world_group(ranks: List[int], local_rank: int,
def init_model_parallel_group(
group_ranks: List[List[int]],
local_rank: int,
backend: str,
use_custom_allreduce: Optional[bool] = None) -> GroupCoordinator:
group_ranks: List[List[int]],
local_rank: int,
backend: str,
use_custom_allreduce: Optional[bool] = None,
use_message_queue_broadcaster: bool = False,
) -> GroupCoordinator:
if use_custom_allreduce is None:
use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE
return GroupCoordinator(
@@ -732,6 +735,7 @@ def init_model_parallel_group(
torch_distributed_backend=backend,
use_pynccl=True,
use_custom_allreduce=use_custom_allreduce,
use_message_queue_broadcaster=use_message_queue_broadcaster,
)
@@ -880,8 +884,12 @@ def initialize_model_parallel(
range(i * tensor_model_parallel_size,
(i + 1) * tensor_model_parallel_size))
group_ranks.append(ranks)
# message queue broadcaster is only used in tensor model parallel group
_TP = init_model_parallel_group(group_ranks,
get_world_group().local_rank, backend)
get_world_group().local_rank,
backend,
use_message_queue_broadcaster=True)
# Build the pipeline model-parallel groups.
num_pipeline_model_parallel_groups: int = (world_size //
@@ -993,15 +1001,15 @@ def destroy_distributed_environment():
torch.distributed.destroy_process_group()
def is_in_the_same_node(pg: ProcessGroup):
def in_the_same_node_as(pg: ProcessGroup, source_rank: int = 0) -> List[bool]:
"""
This is a collective operation that checks if all processes in the group
are in the same node. It tests if all processes are attached to the same
This is a collective operation that returns if each rank is in the same node
as the source rank. It tests if processes are attached to the same
memory system (shared access to shared memory).
"""
assert torch.distributed.get_backend(
pg) != torch.distributed.Backend.NCCL, (
"is_in_the_same_node should be tested with a non-NCCL group.")
"in_the_same_node_as should be tested with a non-NCCL group.")
# local rank inside the group
rank = torch.distributed.get_rank(group=pg)
world_size = torch.distributed.get_world_size(group=pg)
@@ -1017,19 +1025,19 @@ def is_in_the_same_node(pg: ProcessGroup):
try:
with contextlib.suppress(OSError):
if rank == 0:
if rank == source_rank:
# create a shared memory segment
shm = shared_memory.SharedMemory(create=True, size=128)
shm.buf[:len(magic_message)] = magic_message
torch.distributed.broadcast_object_list([shm.name],
src=ranks[0],
src=ranks[source_rank],
group=pg)
is_in_the_same_node[0] = 1
is_in_the_same_node[rank] = 1
else:
# try to open the shared memory segment
recv = [None]
torch.distributed.broadcast_object_list(recv,
src=ranks[0],
src=ranks[source_rank],
group=pg)
name = recv[0]
# fix to https://stackoverflow.com/q/62748654/9191338
@@ -1050,8 +1058,8 @@ def is_in_the_same_node(pg: ProcessGroup):
# clean up the shared memory segment
with contextlib.suppress(OSError):
if rank == 0 and shm:
if rank == source_rank and shm:
shm.unlink()
torch.distributed.all_reduce(is_in_the_same_node, group=pg)
return is_in_the_same_node.sum().item() == world_size
return [x == 1 for x in is_in_the_same_node.tolist()]
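
The return-type change from a single bool to a per-rank list is what lets the message queue pick out which readers are co-located with the writer. A hypothetical illustration of the new contract (values are made up; every rank receives the same list thanks to the final all_reduce):

# Hypothetical 4-rank gloo group: ranks 0 and 1 share a host with the source,
# ranks 2 and 3 live on another node.
status = [True, True, False, False]       # = in_the_same_node_as(pg, source_rank=0)

same_host = all(status)                   # recovers the old is_in_the_same_node() answer
same_node_ranks = [i for i, s in enumerate(status) if s]     # [0, 1]
local_reader_ranks = [i for i in same_node_ranks if i != 0]  # [1]: readers co-located
                                                             # with writer_rank=0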