[core][distributed] initialization from StatelessProcessGroup (#10986)

Signed-off-by: youkaichao <youkaichao@gmail.com>
youkaichao 2024-12-12 01:04:19 -08:00 committed by GitHub
parent 8195824206
commit 62de37a38e
5 changed files with 153 additions and 69 deletions
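For orientation, here is a minimal sketch (not part of this commit) of the new code path: build a StatelessProcessGroup from a host/port agreed on out of band and pass it straight to in_the_same_node_as, with no call to torch.distributed.init_process_group. The helper name check_same_node and the way each process learns its rank, host and port are illustrative assumptions.

from vllm.distributed.parallel_state import in_the_same_node_as
from vllm.distributed.utils import StatelessProcessGroup


def check_same_node(rank: int, world_size: int, host: str, port: int) -> bool:
    # All processes rendezvous at (host, port); no torch.distributed global
    # state is created or required.
    pg = StatelessProcessGroup.create(host, port, rank, world_size)
    # in_the_same_node_as() now accepts the stateless group directly and
    # returns one boolean per rank in the group.
    same_node = in_the_same_node_as(pg, source_rank=0)
    return same_node[rank]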

View File

@@ -432,11 +432,11 @@ steps:
   - tests/distributed/
   commands:
   - # the following commands are for the first node, with ip 192.168.10.10 (ray environment already set up)
-  - VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep -q 'Same node test passed'
+  - VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep 'Same node test passed'
   - VLLM_MULTI_NODE=1 pytest -v -s distributed/test_multi_node_assignment.py
   - VLLM_MULTI_NODE=1 pytest -v -s distributed/test_pipeline_parallel.py
   - # the following commands are for the second node, with ip 192.168.10.11 (ray environment already set up)
-  - VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep -q 'Same node test passed'
+  - VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep 'Same node test passed'
 
 - label: Distributed Tests (2 GPUs) # 40min
   #mirror_hardwares: [amd]
@@ -455,7 +455,7 @@ steps:
   commands:
   - pytest -v -s ./compile/test_basic_correctness.py
   - pytest -v -s ./compile/test_wrapper.py
-  - VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep -q 'Same node test passed'
+  - VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep 'Same node test passed'
   - TARGET_TEST_SUITE=L4 pytest basic_correctness/ -v -s -m 'distributed(num_gpus=2)'
   # Avoid importing model tests that cause CUDA reinitialization error
   - pytest models/encoder_decoder/language/test_bart.py -v -s -m 'distributed(num_gpus=2)'

View File

@@ -3,11 +3,32 @@ import os
 import torch.distributed as dist
 
 from vllm.distributed.parallel_state import in_the_same_node_as
+from vllm.distributed.utils import StatelessProcessGroup
+from vllm.utils import get_ip, get_open_port
 
 if __name__ == "__main__":
     dist.init_process_group(backend="gloo")
-    test_result = all(in_the_same_node_as(dist.group.WORLD, source_rank=0))
 
-    expected = os.environ.get("VLLM_TEST_SAME_HOST", "1") == "1"
-    assert test_result == expected, f"Expected {expected}, got {test_result}"
-    print("Same node test passed!")
+    rank = dist.get_rank()
+    if rank == 0:
+        port = get_open_port()
+        ip = get_ip()
+        dist.broadcast_object_list([ip, port], src=0)
+    else:
+        recv = [None, None]
+        dist.broadcast_object_list(recv, src=0)
+        ip, port = recv
+
+    stateless_pg = StatelessProcessGroup.create(ip, port, rank,
+                                                dist.get_world_size())
+
+    for pg in [dist.group.WORLD, stateless_pg]:
+        test_result = all(in_the_same_node_as(pg, source_rank=0))
+
+        expected = os.environ.get("VLLM_TEST_SAME_HOST", "1") == "1"
+        assert test_result == expected, \
+            f"Expected {expected}, got {test_result}"
+        if pg == dist.group.WORLD:
+            print("Same node test passed! when using torch distributed!")
+        else:
+            print("Same node test passed! when using StatelessProcessGroup!")

View File

@@ -7,7 +7,8 @@ import numpy as np
 import torch.distributed as dist
 
 from vllm.distributed.device_communicators.shm_broadcast import MessageQueue
-from vllm.utils import update_environment_variables
+from vllm.distributed.utils import StatelessProcessGroup
+from vllm.utils import get_ip, get_open_port, update_environment_variables
 
 
 def get_arrays(n: int, seed: int = 0) -> List[np.ndarray]:
@@ -54,23 +55,44 @@ def worker_fn_wrapper(fn):
 @worker_fn_wrapper
 def worker_fn():
-    writer_rank = 2
-    broadcaster = MessageQueue.create_from_process_group(
-        dist.group.WORLD, 40 * 1024, 2, writer_rank)
-    if dist.get_rank() == writer_rank:
-        seed = random.randint(0, 1000)
-        dist.broadcast_object_list([seed], writer_rank)
-    else:
-        recv = [None]
-        dist.broadcast_object_list(recv, writer_rank)
-        seed = recv[0]  # type: ignore
-    dist.barrier()
-    # in case we find a race condition
-    # print the seed so that we can reproduce the error
-    print(f"Rank {dist.get_rank()} got seed {seed}")
-
-    # test broadcasting with about 400MB of data
-    N = 10_000
-    if dist.get_rank() == writer_rank:
-        arrs = get_arrays(N, seed)
-        for x in arrs:
-            broadcaster.broadcast_object(x)
+
+    rank = dist.get_rank()
+    if rank == 0:
+        port = get_open_port()
+        ip = get_ip()
+        dist.broadcast_object_list([ip, port], src=0)
+    else:
+        recv = [None, None]
+        dist.broadcast_object_list(recv, src=0)
+        ip, port = recv
+
+    stateless_pg = StatelessProcessGroup.create(ip, port, rank,
+                                                dist.get_world_size())
+
+    for pg in [dist.group.WORLD, stateless_pg]:
+
+        writer_rank = 2
+        broadcaster = MessageQueue.create_from_process_group(
+            pg, 40 * 1024, 2, writer_rank)
+        if rank == writer_rank:
+            seed = random.randint(0, 1000)
+            dist.broadcast_object_list([seed], writer_rank)
+        else:
+            recv = [None]
+            dist.broadcast_object_list(recv, writer_rank)
+            seed = recv[0]  # type: ignore
+
+        if pg == dist.group.WORLD:
+            dist.barrier()
+        else:
+            pg.barrier()
+
+        # in case we find a race condition
+        # print the seed so that we can reproduce the error
+        print(f"Rank {rank} got seed {seed}")
+
+        # test broadcasting with about 400MB of data
+        N = 10_000
+        if rank == writer_rank:
+            arrs = get_arrays(N, seed)
+            for x in arrs:
+                broadcaster.broadcast_object(x)
@@ -81,7 +103,13 @@ def worker_fn():
-            y = broadcaster.broadcast_object(None)
-            assert np.array_equal(x, y)
-            time.sleep(random.random() / 1000)
-    dist.barrier()
+                y = broadcaster.broadcast_object(None)
+                assert np.array_equal(x, y)
+                time.sleep(random.random() / 1000)
+
+        if pg == dist.group.WORLD:
+            dist.barrier()
+            print("torch distributed passed the test!")
+        else:
+            pg.barrier()
+            print("StatelessProcessGroup passed the test!")
 
 
 def test_shm_broadcast():
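For reference, a compact sketch (not part of the diff) of the shared-memory broadcaster driven purely by a StatelessProcessGroup, mirroring what this test exercises; the address, port, world size, payloads and the function name run_worker are illustrative:

from vllm.distributed.device_communicators.shm_broadcast import MessageQueue
from vllm.distributed.utils import StatelessProcessGroup


def run_worker(rank: int, world_size: int = 4, writer_rank: int = 2):
    pg = StatelessProcessGroup.create("127.0.0.1", 29600, rank, world_size)
    # Same queue parameters as the test: max_chunk_bytes=40 * 1024, max_chunks=2.
    mq = MessageQueue.create_from_process_group(pg, 40 * 1024, 2, writer_rank)
    if rank == writer_rank:
        # The writer enqueues objects; readers pass None and receive them.
        for i in range(10):
            mq.broadcast_object({"step": i})
    else:
        for _ in range(10):
            obj = mq.broadcast_object(None)
            assert obj is not None
    pg.barrier()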

View File

@@ -5,7 +5,7 @@ import time
 from contextlib import contextmanager
 from dataclasses import dataclass, field
 from multiprocessing import shared_memory
-from typing import List, Optional, Tuple
+from typing import List, Optional, Tuple, Union
 from unittest.mock import patch
 
 import torch
@@ -15,6 +15,7 @@ from zmq import IPV6  # type: ignore
 from zmq import SUB, SUBSCRIBE, XPUB, XPUB_VERBOSE, Context  # type: ignore
 
 import vllm.envs as envs
+from vllm.distributed.utils import StatelessProcessGroup
 from vllm.logger import init_logger
 from vllm.utils import get_ip, get_open_port, is_valid_ipv6_address
 
@@ -476,13 +477,19 @@ class MessageQueue:
         return self.dequeue()
 
     @staticmethod
-    def create_from_process_group(pg: ProcessGroup,
+    def create_from_process_group(pg: Union[ProcessGroup,
+                                            StatelessProcessGroup],
                                   max_chunk_bytes,
                                   max_chunks,
                                   writer_rank=0) -> "MessageQueue":
-        group_rank = dist.get_rank(pg)
-        group_world_size = dist.get_world_size(pg)
-        global_ranks = dist.get_process_group_ranks(pg)
+        if isinstance(pg, ProcessGroup):
+            group_rank = dist.get_rank(pg)
+            group_world_size = dist.get_world_size(pg)
+            global_ranks = dist.get_process_group_ranks(pg)
+        else:
+            group_rank = pg.rank
+            group_world_size = pg.world_size
+            global_ranks = list(range(pg.world_size))
 
         from vllm.distributed.parallel_state import in_the_same_node_as
         status = in_the_same_node_as(pg, source_rank=writer_rank)
@@ -500,15 +507,21 @@ class MessageQueue:
                 max_chunks=max_chunks,
             )
             handle = buffer_io.export_handle()
-            dist.broadcast_object_list([handle],
-                                       src=global_ranks[writer_rank],
-                                       group=pg)
+            if isinstance(pg, ProcessGroup):
+                dist.broadcast_object_list([handle],
+                                           src=global_ranks[writer_rank],
+                                           group=pg)
+            else:
+                pg.broadcast_obj(handle, writer_rank)
         else:
-            recv = [None]
-            dist.broadcast_object_list(recv,
-                                       src=global_ranks[writer_rank],
-                                       group=pg)
-            handle = recv[0]  # type: ignore
+            if isinstance(pg, ProcessGroup):
+                recv = [None]
+                dist.broadcast_object_list(recv,
+                                           src=global_ranks[writer_rank],
+                                           group=pg)
+                handle = recv[0]  # type: ignore
+            else:
+                handle = pg.broadcast_obj(None, writer_rank)
         buffer_io = MessageQueue.create_from_handle(handle, group_rank)
         buffer_io.wait_until_ready()
        return buffer_io
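The handle exchange above repeats a pattern that also appears in parallel_state.py below: branch on the group type, use dist.broadcast_object_list for a torch ProcessGroup, and StatelessProcessGroup.broadcast_obj otherwise (the source passes the payload, the other ranks pass None). A sketch of that pattern written once as a helper; broadcast_any itself is not vLLM API:

from typing import Any, Union

import torch.distributed as dist
from torch.distributed import ProcessGroup

from vllm.distributed.utils import StatelessProcessGroup


def broadcast_any(pg: Union[ProcessGroup, StatelessProcessGroup], obj: Any,
                  src_group_rank: int) -> Any:
    if isinstance(pg, ProcessGroup):
        # torch.distributed wants the *global* rank of the source and a
        # one-element list as the in/out container.
        global_ranks = dist.get_process_group_ranks(pg)
        container = [obj] if dist.get_rank(pg) == src_group_rank else [None]
        dist.broadcast_object_list(container,
                                   src=global_ranks[src_group_rank],
                                   group=pg)
        return container[0]
    # StatelessProcessGroup: the source passes the payload, everyone else
    # passes None; the broadcast object is returned on every rank.
    payload = obj if pg.rank == src_group_rank else None
    return pg.broadcast_obj(payload, src=src_group_rank)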

View File

@@ -37,6 +37,7 @@ from torch.distributed import Backend, ProcessGroup
 
 import vllm.distributed.kv_transfer.kv_transfer_agent as kv_transfer
 import vllm.envs as envs
+from vllm.distributed.utils import StatelessProcessGroup
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
 from vllm.utils import direct_register_custom_op, supports_custom_op
@@ -1191,12 +1192,14 @@ def cleanup_dist_env_and_memory(shutdown_ray: bool = False):
         torch.cuda.empty_cache()
 
 
-def in_the_same_node_as(pg: ProcessGroup, source_rank: int = 0) -> List[bool]:
+def in_the_same_node_as(pg: Union[ProcessGroup, StatelessProcessGroup],
+                        source_rank: int = 0) -> List[bool]:
     """
     This is a collective operation that returns if each rank is in the same node
    as the source rank. It tests if processes are attached to the same
    memory system (shared access to shared memory).
    """
-    assert torch.distributed.get_backend(
-        pg) != torch.distributed.Backend.NCCL, (
-            "in_the_same_node_as should be tested with a non-NCCL group.")
+    if isinstance(pg, ProcessGroup):
+        assert torch.distributed.get_backend(
+            pg) != torch.distributed.Backend.NCCL, (
+                "in_the_same_node_as should be tested with a non-NCCL group.")
@@ -1204,11 +1207,15 @@ def in_the_same_node_as(pg: ProcessGroup, source_rank: int = 0) -> List[bool]:
-    rank = torch.distributed.get_rank(group=pg)
-    world_size = torch.distributed.get_world_size(group=pg)
-
-    # local tensor in each process to store the result
-    is_in_the_same_node = torch.tensor([0] * world_size, dtype=torch.int32)
+        rank = torch.distributed.get_rank(group=pg)
+        world_size = torch.distributed.get_world_size(group=pg)
 
-    # global ranks of the processes in the group
-    ranks = torch.distributed.get_process_group_ranks(pg)
+        # global ranks of the processes in the group
+        ranks = torch.distributed.get_process_group_ranks(pg)
+    else:
+        rank = pg.rank
+        world_size = pg.world_size
+        ranks = list(range(world_size))
+
+    # local tensor in each process to store the result
+    is_in_the_same_node = torch.tensor([0] * world_size, dtype=torch.int32)
 
     magic_message = b"magic_message"
     shm = None
@@ -1219,17 +1226,21 @@ def in_the_same_node_as(pg: ProcessGroup, source_rank: int = 0) -> List[bool]:
                 # create a shared memory segment
                 shm = shared_memory.SharedMemory(create=True, size=128)
                 shm.buf[:len(magic_message)] = magic_message
-                torch.distributed.broadcast_object_list([shm.name],
-                                                        src=ranks[source_rank],
-                                                        group=pg)
+                if isinstance(pg, ProcessGroup):
+                    torch.distributed.broadcast_object_list(
+                        [shm.name], src=ranks[source_rank], group=pg)
+                else:
+                    pg.broadcast_obj(shm.name, src=source_rank)
                 is_in_the_same_node[rank] = 1
             else:
                 # try to open the shared memory segment
-                recv = [None]
-                torch.distributed.broadcast_object_list(recv,
-                                                        src=ranks[source_rank],
-                                                        group=pg)
-                name = recv[0]
+                if isinstance(pg, ProcessGroup):
+                    recv = [None]
+                    torch.distributed.broadcast_object_list(
+                        recv, src=ranks[source_rank], group=pg)
+                    name = recv[0]
+                else:
+                    name = pg.broadcast_obj(None, src=source_rank)
                 # fix to https://stackoverflow.com/q/62748654/9191338
                 # Python incorrectly tracks shared memory even if it is not
                 # created by the process. The following patch is a workaround.
@@ -1244,12 +1255,23 @@ def in_the_same_node_as(pg: ProcessGroup, source_rank: int = 0) -> List[bool]:
         if shm:
             shm.close()
-        torch.distributed.barrier(group=pg)
+        if isinstance(pg, ProcessGroup):
+            torch.distributed.barrier(group=pg)
+        else:
+            pg.barrier()
 
     # clean up the shared memory segment
     with contextlib.suppress(OSError):
         if rank == source_rank and shm:
             shm.unlink()
 
-    torch.distributed.all_reduce(is_in_the_same_node, group=pg)
-    return [x == 1 for x in is_in_the_same_node.tolist()]
+    if isinstance(pg, ProcessGroup):
+        torch.distributed.all_reduce(is_in_the_same_node, group=pg)
+        aggregated_data = is_in_the_same_node
+    else:
+        aggregated_data = torch.zeros_like(is_in_the_same_node)
+        for i in range(world_size):
+            rank_data = pg.broadcast_obj(is_in_the_same_node, src=i)
+            aggregated_data += rank_data
+
+    return [x == 1 for x in aggregated_data.tolist()]
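The stateless branch at the end emulates torch.distributed.all_reduce with world_size rounds of broadcast_obj: each rank takes a turn as the source and every rank sums what it receives. As a standalone sketch (the function name is ours, not vLLM API):

import torch

from vllm.distributed.utils import StatelessProcessGroup


def stateless_all_reduce_sum(pg: StatelessProcessGroup,
                             local: torch.Tensor) -> torch.Tensor:
    # O(world_size) broadcast rounds; cheap for the small per-rank flag
    # tensor used by in_the_same_node_as, but not meant for large tensors.
    total = torch.zeros_like(local)
    for src in range(pg.world_size):
        # On the source rank broadcast_obj returns `local` itself, so every
        # rank's own contribution is counted exactly once.
        total += pg.broadcast_obj(local, src=src)
    return total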