[core][distributed] initialization from StatelessProcessGroup (#10986)

Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
youkaichao 2024-12-12 01:04:19 -08:00 committed by GitHub
parent 8195824206
commit 62de37a38e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 153 additions and 69 deletions

View File

@ -432,11 +432,11 @@ steps:
- tests/distributed/
commands:
- # the following commands are for the first node, with ip 192.168.10.10 (ray environment already set up)
- VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep -q 'Same node test passed'
- VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep 'Same node test passed'
- VLLM_MULTI_NODE=1 pytest -v -s distributed/test_multi_node_assignment.py
- VLLM_MULTI_NODE=1 pytest -v -s distributed/test_pipeline_parallel.py
- # the following commands are for the second node, with ip 192.168.10.11 (ray environment already set up)
- VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep -q 'Same node test passed'
- VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep 'Same node test passed'
- label: Distributed Tests (2 GPUs) # 40min
#mirror_hardwares: [amd]
@ -455,7 +455,7 @@ steps:
commands:
- pytest -v -s ./compile/test_basic_correctness.py
- pytest -v -s ./compile/test_wrapper.py
- VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep -q 'Same node test passed'
- VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep 'Same node test passed'
- TARGET_TEST_SUITE=L4 pytest basic_correctness/ -v -s -m 'distributed(num_gpus=2)'
# Avoid importing model tests that cause CUDA reinitialization error
- pytest models/encoder_decoder/language/test_bart.py -v -s -m 'distributed(num_gpus=2)'

View File

@ -3,11 +3,32 @@ import os
import torch.distributed as dist
from vllm.distributed.parallel_state import in_the_same_node_as
from vllm.distributed.utils import StatelessProcessGroup
from vllm.utils import get_ip, get_open_port
if __name__ == "__main__":
dist.init_process_group(backend="gloo")
test_result = all(in_the_same_node_as(dist.group.WORLD, source_rank=0))
expected = os.environ.get("VLLM_TEST_SAME_HOST", "1") == "1"
assert test_result == expected, f"Expected {expected}, got {test_result}"
print("Same node test passed!")
rank = dist.get_rank()
if rank == 0:
port = get_open_port()
ip = get_ip()
dist.broadcast_object_list([ip, port], src=0)
else:
recv = [None, None]
dist.broadcast_object_list(recv, src=0)
ip, port = recv
stateless_pg = StatelessProcessGroup.create(ip, port, rank,
dist.get_world_size())
for pg in [dist.group.WORLD, stateless_pg]:
test_result = all(in_the_same_node_as(pg, source_rank=0))
expected = os.environ.get("VLLM_TEST_SAME_HOST", "1") == "1"
assert test_result == expected, \
f"Expected {expected}, got {test_result}"
if pg == dist.group.WORLD:
print("Same node test passed! when using torch distributed!")
else:
print("Same node test passed! when using StatelessProcessGroup!")

View File

@ -7,7 +7,8 @@ import numpy as np
import torch.distributed as dist
from vllm.distributed.device_communicators.shm_broadcast import MessageQueue
from vllm.utils import update_environment_variables
from vllm.distributed.utils import StatelessProcessGroup
from vllm.utils import get_ip, get_open_port, update_environment_variables
def get_arrays(n: int, seed: int = 0) -> List[np.ndarray]:
@ -54,34 +55,61 @@ def worker_fn_wrapper(fn):
@worker_fn_wrapper
def worker_fn():
writer_rank = 2
broadcaster = MessageQueue.create_from_process_group(
dist.group.WORLD, 40 * 1024, 2, writer_rank)
if dist.get_rank() == writer_rank:
seed = random.randint(0, 1000)
dist.broadcast_object_list([seed], writer_rank)
rank = dist.get_rank()
if rank == 0:
port = get_open_port()
ip = get_ip()
dist.broadcast_object_list([ip, port], src=0)
else:
recv = [None]
dist.broadcast_object_list(recv, writer_rank)
seed = recv[0] # type: ignore
dist.barrier()
# in case we find a race condition
# print the seed so that we can reproduce the error
print(f"Rank {dist.get_rank()} got seed {seed}")
# test broadcasting with about 400MB of data
N = 10_000
if dist.get_rank() == writer_rank:
arrs = get_arrays(N, seed)
for x in arrs:
broadcaster.broadcast_object(x)
time.sleep(random.random() / 1000)
else:
arrs = get_arrays(N, seed)
for x in arrs:
y = broadcaster.broadcast_object(None)
assert np.array_equal(x, y)
time.sleep(random.random() / 1000)
dist.barrier()
recv = [None, None]
dist.broadcast_object_list(recv, src=0)
ip, port = recv
stateless_pg = StatelessProcessGroup.create(ip, port, rank,
dist.get_world_size())
for pg in [dist.group.WORLD, stateless_pg]:
writer_rank = 2
broadcaster = MessageQueue.create_from_process_group(
pg, 40 * 1024, 2, writer_rank)
if rank == writer_rank:
seed = random.randint(0, 1000)
dist.broadcast_object_list([seed], writer_rank)
else:
recv = [None]
dist.broadcast_object_list(recv, writer_rank)
seed = recv[0] # type: ignore
if pg == dist.group.WORLD:
dist.barrier()
else:
pg.barrier()
# in case we find a race condition
# print the seed so that we can reproduce the error
print(f"Rank {rank} got seed {seed}")
# test broadcasting with about 400MB of data
N = 10_000
if rank == writer_rank:
arrs = get_arrays(N, seed)
for x in arrs:
broadcaster.broadcast_object(x)
time.sleep(random.random() / 1000)
else:
arrs = get_arrays(N, seed)
for x in arrs:
y = broadcaster.broadcast_object(None)
assert np.array_equal(x, y)
time.sleep(random.random() / 1000)
if pg == dist.group.WORLD:
dist.barrier()
print("torch distributed passed the test!")
else:
pg.barrier()
print("StatelessProcessGroup passed the test!")
def test_shm_broadcast():

View File

@ -5,7 +5,7 @@ import time
from contextlib import contextmanager
from dataclasses import dataclass, field
from multiprocessing import shared_memory
from typing import List, Optional, Tuple
from typing import List, Optional, Tuple, Union
from unittest.mock import patch
import torch
@ -15,6 +15,7 @@ from zmq import IPV6 # type: ignore
from zmq import SUB, SUBSCRIBE, XPUB, XPUB_VERBOSE, Context # type: ignore
import vllm.envs as envs
from vllm.distributed.utils import StatelessProcessGroup
from vllm.logger import init_logger
from vllm.utils import get_ip, get_open_port, is_valid_ipv6_address
@ -476,13 +477,19 @@ class MessageQueue:
return self.dequeue()
@staticmethod
def create_from_process_group(pg: ProcessGroup,
def create_from_process_group(pg: Union[ProcessGroup,
StatelessProcessGroup],
max_chunk_bytes,
max_chunks,
writer_rank=0) -> "MessageQueue":
group_rank = dist.get_rank(pg)
group_world_size = dist.get_world_size(pg)
global_ranks = dist.get_process_group_ranks(pg)
if isinstance(pg, ProcessGroup):
group_rank = dist.get_rank(pg)
group_world_size = dist.get_world_size(pg)
global_ranks = dist.get_process_group_ranks(pg)
else:
group_rank = pg.rank
group_world_size = pg.world_size
global_ranks = list(range(pg.world_size))
from vllm.distributed.parallel_state import in_the_same_node_as
status = in_the_same_node_as(pg, source_rank=writer_rank)
@ -500,15 +507,21 @@ class MessageQueue:
max_chunks=max_chunks,
)
handle = buffer_io.export_handle()
dist.broadcast_object_list([handle],
src=global_ranks[writer_rank],
group=pg)
if isinstance(pg, ProcessGroup):
dist.broadcast_object_list([handle],
src=global_ranks[writer_rank],
group=pg)
else:
pg.broadcast_obj(handle, writer_rank)
else:
recv = [None]
dist.broadcast_object_list(recv,
src=global_ranks[writer_rank],
group=pg)
handle = recv[0] # type: ignore
if isinstance(pg, ProcessGroup):
recv = [None]
dist.broadcast_object_list(recv,
src=global_ranks[writer_rank],
group=pg)
handle = recv[0] # type: ignore
else:
handle = pg.broadcast_obj(None, writer_rank)
buffer_io = MessageQueue.create_from_handle(handle, group_rank)
buffer_io.wait_until_ready()
return buffer_io

View File

@ -37,6 +37,7 @@ from torch.distributed import Backend, ProcessGroup
import vllm.distributed.kv_transfer.kv_transfer_agent as kv_transfer
import vllm.envs as envs
from vllm.distributed.utils import StatelessProcessGroup
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op, supports_custom_op
@ -1191,25 +1192,31 @@ def cleanup_dist_env_and_memory(shutdown_ray: bool = False):
torch.cuda.empty_cache()
def in_the_same_node_as(pg: ProcessGroup, source_rank: int = 0) -> List[bool]:
def in_the_same_node_as(pg: Union[ProcessGroup, StatelessProcessGroup],
source_rank: int = 0) -> List[bool]:
"""
This is a collective operation that returns if each rank is in the same node
as the source rank. It tests if processes are attached to the same
memory system (shared access to shared memory).
"""
assert torch.distributed.get_backend(
pg) != torch.distributed.Backend.NCCL, (
"in_the_same_node_as should be tested with a non-NCCL group.")
# local rank inside the group
rank = torch.distributed.get_rank(group=pg)
world_size = torch.distributed.get_world_size(group=pg)
if isinstance(pg, ProcessGroup):
assert torch.distributed.get_backend(
pg) != torch.distributed.Backend.NCCL, (
"in_the_same_node_as should be tested with a non-NCCL group.")
# local rank inside the group
rank = torch.distributed.get_rank(group=pg)
world_size = torch.distributed.get_world_size(group=pg)
# global ranks of the processes in the group
ranks = torch.distributed.get_process_group_ranks(pg)
else:
rank = pg.rank
world_size = pg.world_size
ranks = list(range(world_size))
# local tensor in each process to store the result
is_in_the_same_node = torch.tensor([0] * world_size, dtype=torch.int32)
# global ranks of the processes in the group
ranks = torch.distributed.get_process_group_ranks(pg)
magic_message = b"magic_message"
shm = None
@ -1219,17 +1226,21 @@ def in_the_same_node_as(pg: ProcessGroup, source_rank: int = 0) -> List[bool]:
# create a shared memory segment
shm = shared_memory.SharedMemory(create=True, size=128)
shm.buf[:len(magic_message)] = magic_message
torch.distributed.broadcast_object_list([shm.name],
src=ranks[source_rank],
group=pg)
if isinstance(pg, ProcessGroup):
torch.distributed.broadcast_object_list(
[shm.name], src=ranks[source_rank], group=pg)
else:
pg.broadcast_obj(shm.name, src=source_rank)
is_in_the_same_node[rank] = 1
else:
# try to open the shared memory segment
recv = [None]
torch.distributed.broadcast_object_list(recv,
src=ranks[source_rank],
group=pg)
name = recv[0]
if isinstance(pg, ProcessGroup):
recv = [None]
torch.distributed.broadcast_object_list(
recv, src=ranks[source_rank], group=pg)
name = recv[0]
else:
name = pg.broadcast_obj(None, src=source_rank)
# fix to https://stackoverflow.com/q/62748654/9191338
# Python incorrectly tracks shared memory even if it is not
# created by the process. The following patch is a workaround.
@ -1244,12 +1255,23 @@ def in_the_same_node_as(pg: ProcessGroup, source_rank: int = 0) -> List[bool]:
if shm:
shm.close()
torch.distributed.barrier(group=pg)
if isinstance(pg, ProcessGroup):
torch.distributed.barrier(group=pg)
else:
pg.barrier()
# clean up the shared memory segment
with contextlib.suppress(OSError):
if rank == source_rank and shm:
shm.unlink()
torch.distributed.all_reduce(is_in_the_same_node, group=pg)
return [x == 1 for x in is_in_the_same_node.tolist()]
if isinstance(pg, ProcessGroup):
torch.distributed.all_reduce(is_in_the_same_node, group=pg)
aggregated_data = is_in_the_same_node
else:
aggregated_data = torch.zeros_like(is_in_the_same_node)
for i in range(world_size):
rank_data = pg.broadcast_obj(is_in_the_same_node, src=i)
aggregated_data += rank_data
return [x == 1 for x in aggregated_data.tolist()]