[core][distributed] initialization from StatelessProcessGroup (#10986)

Signed-off-by: youkaichao <youkaichao@gmail.com>

This commit is contained in:
parent 8195824206
commit 62de37a38e
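
For orientation, the rendezvous pattern the tests below adopt looks like this in isolation — a minimal sketch assembled from the diff, assuming a gloo-backed torch.distributed group is already initialized (as both tests do):

import torch.distributed as dist

from vllm.distributed.utils import StatelessProcessGroup
from vllm.utils import get_ip, get_open_port

dist.init_process_group(backend="gloo")
rank = dist.get_rank()

# rank 0 picks a rendezvous address and shares it over the existing
# torch.distributed group; every rank then joins a side-channel group
# that is independent of torch.distributed's global state
if rank == 0:
    ip, port = get_ip(), get_open_port()
    dist.broadcast_object_list([ip, port], src=0)
else:
    recv = [None, None]
    dist.broadcast_object_list(recv, src=0)
    ip, port = recv

stateless_pg = StatelessProcessGroup.create(ip, port, rank,
                                            dist.get_world_size())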
.buildkite/test-pipeline.yaml

@@ -432,11 +432,11 @@ steps:
   - tests/distributed/
   commands:
   - # the following commands are for the first node, with ip 192.168.10.10 (ray environment already set up)
-  - VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep -q 'Same node test passed'
+  - VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep 'Same node test passed'
   - VLLM_MULTI_NODE=1 pytest -v -s distributed/test_multi_node_assignment.py
   - VLLM_MULTI_NODE=1 pytest -v -s distributed/test_pipeline_parallel.py
   - # the following commands are for the second node, with ip 192.168.10.11 (ray environment already set up)
-  - VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep -q 'Same node test passed'
+  - VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep 'Same node test passed'
 
 - label: Distributed Tests (2 GPUs) # 40min
   #mirror_hardwares: [amd]

@@ -455,7 +455,7 @@ steps:
   commands:
   - pytest -v -s ./compile/test_basic_correctness.py
   - pytest -v -s ./compile/test_wrapper.py
-  - VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep -q 'Same node test passed'
+  - VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep 'Same node test passed'
   - TARGET_TEST_SUITE=L4 pytest basic_correctness/ -v -s -m 'distributed(num_gpus=2)'
   # Avoid importing model tests that cause CUDA reinitialization error
   - pytest models/encoder_decoder/language/test_bart.py -v -s -m 'distributed(num_gpus=2)'
tests/distributed/test_same_node.py

@@ -3,11 +3,32 @@ import os
 import torch.distributed as dist
 
 from vllm.distributed.parallel_state import in_the_same_node_as
+from vllm.distributed.utils import StatelessProcessGroup
+from vllm.utils import get_ip, get_open_port
 
 if __name__ == "__main__":
     dist.init_process_group(backend="gloo")
-    test_result = all(in_the_same_node_as(dist.group.WORLD, source_rank=0))
 
-    expected = os.environ.get("VLLM_TEST_SAME_HOST", "1") == "1"
-    assert test_result == expected, f"Expected {expected}, got {test_result}"
-    print("Same node test passed!")
+    rank = dist.get_rank()
+    if rank == 0:
+        port = get_open_port()
+        ip = get_ip()
+        dist.broadcast_object_list([ip, port], src=0)
+    else:
+        recv = [None, None]
+        dist.broadcast_object_list(recv, src=0)
+        ip, port = recv
+
+    stateless_pg = StatelessProcessGroup.create(ip, port, rank,
+                                                dist.get_world_size())
+
+    for pg in [dist.group.WORLD, stateless_pg]:
+        test_result = all(in_the_same_node_as(pg, source_rank=0))
+
+        expected = os.environ.get("VLLM_TEST_SAME_HOST", "1") == "1"
+        assert test_result == expected, \
+            f"Expected {expected}, got {test_result}"
+        if pg == dist.group.WORLD:
+            print("Same node test passed! when using torch distributed!")
+        else:
+            print("Same node test passed! when using StatelessProcessGroup!")
tests/distributed/test_shm_broadcast.py

@@ -7,7 +7,8 @@ import numpy as np
 import torch.distributed as dist
 
 from vllm.distributed.device_communicators.shm_broadcast import MessageQueue
-from vllm.utils import update_environment_variables
+from vllm.distributed.utils import StatelessProcessGroup
+from vllm.utils import get_ip, get_open_port, update_environment_variables
 
 
 def get_arrays(n: int, seed: int = 0) -> List[np.ndarray]:

@@ -54,23 +55,44 @@ def worker_fn_wrapper(fn):
 
 @worker_fn_wrapper
 def worker_fn():
-    writer_rank = 2
-    broadcaster = MessageQueue.create_from_process_group(
-        dist.group.WORLD, 40 * 1024, 2, writer_rank)
-    if dist.get_rank() == writer_rank:
-        seed = random.randint(0, 1000)
-        dist.broadcast_object_list([seed], writer_rank)
-    else:
-        recv = [None]
-        dist.broadcast_object_list(recv, writer_rank)
-        seed = recv[0]  # type: ignore
-    dist.barrier()
-    # in case we find a race condition
-    # print the seed so that we can reproduce the error
-    print(f"Rank {dist.get_rank()} got seed {seed}")
-    # test broadcasting with about 400MB of data
-    N = 10_000
-    if dist.get_rank() == writer_rank:
-        arrs = get_arrays(N, seed)
-        for x in arrs:
-            broadcaster.broadcast_object(x)
+
+    rank = dist.get_rank()
+    if rank == 0:
+        port = get_open_port()
+        ip = get_ip()
+        dist.broadcast_object_list([ip, port], src=0)
+    else:
+        recv = [None, None]
+        dist.broadcast_object_list(recv, src=0)
+        ip, port = recv
+
+    stateless_pg = StatelessProcessGroup.create(ip, port, rank,
+                                                dist.get_world_size())
+
+    for pg in [dist.group.WORLD, stateless_pg]:
+
+        writer_rank = 2
+        broadcaster = MessageQueue.create_from_process_group(
+            pg, 40 * 1024, 2, writer_rank)
+        if rank == writer_rank:
+            seed = random.randint(0, 1000)
+            dist.broadcast_object_list([seed], writer_rank)
+        else:
+            recv = [None]
+            dist.broadcast_object_list(recv, writer_rank)
+            seed = recv[0]  # type: ignore
+
+        if pg == dist.group.WORLD:
+            dist.barrier()
+        else:
+            pg.barrier()
+
+        # in case we find a race condition
+        # print the seed so that we can reproduce the error
+        print(f"Rank {rank} got seed {seed}")
+        # test broadcasting with about 400MB of data
+        N = 10_000
+        if rank == writer_rank:
+            arrs = get_arrays(N, seed)
+            for x in arrs:
+                broadcaster.broadcast_object(x)

@@ -81,7 +103,13 @@ def worker_fn():
             y = broadcaster.broadcast_object(None)
             assert np.array_equal(x, y)
         time.sleep(random.random() / 1000)
-    dist.barrier()
+
+        if pg == dist.group.WORLD:
+            dist.barrier()
+            print("torch distributed passed the test!")
+        else:
+            pg.barrier()
+            print("StatelessProcessGroup passed the test!")
 
 
 def test_shm_broadcast():
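Both tests exercise the same APIs through either group type. A compressed usage sketch of the message-queue path, assuming stateless_pg was built as in the bootstrap sketch above with four processes and reusing the test's writer_rank of 2 (the payload value is illustrative):

from vllm.distributed.device_communicators.shm_broadcast import MessageQueue

writer_rank = 2
mq = MessageQueue.create_from_process_group(stateless_pg, 40 * 1024, 2,
                                            writer_rank)
if stateless_pg.rank == writer_rank:
    mq.broadcast_object({"payload": 123})  # the writer enqueues one object
else:
    obj = mq.broadcast_object(None)        # every reader receives that object
stateless_pg.barrier()                     # stateless groups sync via their own barrier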
vllm/distributed/device_communicators/shm_broadcast.py

@@ -5,7 +5,7 @@ import time
 from contextlib import contextmanager
 from dataclasses import dataclass, field
 from multiprocessing import shared_memory
-from typing import List, Optional, Tuple
+from typing import List, Optional, Tuple, Union
 from unittest.mock import patch
 
 import torch

@@ -15,6 +15,7 @@ from zmq import IPV6  # type: ignore
 from zmq import SUB, SUBSCRIBE, XPUB, XPUB_VERBOSE, Context  # type: ignore
 
 import vllm.envs as envs
+from vllm.distributed.utils import StatelessProcessGroup
 from vllm.logger import init_logger
 from vllm.utils import get_ip, get_open_port, is_valid_ipv6_address
 

@@ -476,13 +477,19 @@ class MessageQueue:
         return self.dequeue()
 
     @staticmethod
-    def create_from_process_group(pg: ProcessGroup,
+    def create_from_process_group(pg: Union[ProcessGroup,
+                                            StatelessProcessGroup],
                                   max_chunk_bytes,
                                   max_chunks,
                                   writer_rank=0) -> "MessageQueue":
-        group_rank = dist.get_rank(pg)
-        group_world_size = dist.get_world_size(pg)
-        global_ranks = dist.get_process_group_ranks(pg)
+        if isinstance(pg, ProcessGroup):
+            group_rank = dist.get_rank(pg)
+            group_world_size = dist.get_world_size(pg)
+            global_ranks = dist.get_process_group_ranks(pg)
+        else:
+            group_rank = pg.rank
+            group_world_size = pg.world_size
+            global_ranks = list(range(pg.world_size))
 
         from vllm.distributed.parallel_state import in_the_same_node_as
         status = in_the_same_node_as(pg, source_rank=writer_rank)

@@ -500,15 +507,21 @@ class MessageQueue:
                 max_chunks=max_chunks,
             )
             handle = buffer_io.export_handle()
-            dist.broadcast_object_list([handle],
-                                       src=global_ranks[writer_rank],
-                                       group=pg)
+            if isinstance(pg, ProcessGroup):
+                dist.broadcast_object_list([handle],
+                                           src=global_ranks[writer_rank],
+                                           group=pg)
+            else:
+                pg.broadcast_obj(handle, writer_rank)
         else:
-            recv = [None]
-            dist.broadcast_object_list(recv,
-                                       src=global_ranks[writer_rank],
-                                       group=pg)
-            handle = recv[0]  # type: ignore
+            if isinstance(pg, ProcessGroup):
+                recv = [None]
+                dist.broadcast_object_list(recv,
+                                           src=global_ranks[writer_rank],
+                                           group=pg)
+                handle = recv[0]  # type: ignore
+            else:
+                handle = pg.broadcast_obj(None, writer_rank)
             buffer_io = MessageQueue.create_from_handle(handle, group_rank)
         buffer_io.wait_until_ready()
         return buffer_io
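Every branch pair added above reduces to one primitive: broadcast a picklable object from the writer to the whole group. A hypothetical helper (broadcast_from is an illustrative name, not vLLM's) capturing the dispatch the diff writes inline:

from typing import Any, List, Union

import torch.distributed as dist
from torch.distributed import ProcessGroup

from vllm.distributed.utils import StatelessProcessGroup


def broadcast_from(pg: Union[ProcessGroup, StatelessProcessGroup], obj: Any,
                   src_group_rank: int, global_ranks: List[int]) -> Any:
    # Broadcast obj from the given group rank; every rank returns the object.
    if isinstance(pg, ProcessGroup):
        # torch.distributed addresses ranks globally, hence the translation
        container = [obj]
        dist.broadcast_object_list(container,
                                   src=global_ranks[src_group_rank],
                                   group=pg)
        return container[0]
    # StatelessProcessGroup addresses ranks within the group directly
    return pg.broadcast_obj(obj, src_group_rank)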
vllm/distributed/parallel_state.py

@@ -37,6 +37,7 @@ from torch.distributed import Backend, ProcessGroup
 
 import vllm.distributed.kv_transfer.kv_transfer_agent as kv_transfer
 import vllm.envs as envs
+from vllm.distributed.utils import StatelessProcessGroup
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
 from vllm.utils import direct_register_custom_op, supports_custom_op

@@ -1191,12 +1192,14 @@ def cleanup_dist_env_and_memory(shutdown_ray: bool = False):
     torch.cuda.empty_cache()
 
 
-def in_the_same_node_as(pg: ProcessGroup, source_rank: int = 0) -> List[bool]:
+def in_the_same_node_as(pg: Union[ProcessGroup, StatelessProcessGroup],
+                        source_rank: int = 0) -> List[bool]:
     """
     This is a collective operation that returns if each rank is in the same node
     as the source rank. It tests if processes are attached to the same
     memory system (shared access to shared memory).
     """
-    assert torch.distributed.get_backend(
-        pg) != torch.distributed.Backend.NCCL, (
-            "in_the_same_node_as should be tested with a non-NCCL group.")
+    if isinstance(pg, ProcessGroup):
+        assert torch.distributed.get_backend(
+            pg) != torch.distributed.Backend.NCCL, (
+                "in_the_same_node_as should be tested with a non-NCCL group.")

@@ -1204,11 +1207,15 @@ def in_the_same_node_as(pg: ProcessGroup, source_rank: int = 0) -> List[bool]:
-    rank = torch.distributed.get_rank(group=pg)
-    world_size = torch.distributed.get_world_size(group=pg)
-
-    # local tensor in each process to store the result
-    is_in_the_same_node = torch.tensor([0] * world_size, dtype=torch.int32)
-
-    # global ranks of the processes in the group
-    ranks = torch.distributed.get_process_group_ranks(pg)
+        rank = torch.distributed.get_rank(group=pg)
+        world_size = torch.distributed.get_world_size(group=pg)
+
+        # global ranks of the processes in the group
+        ranks = torch.distributed.get_process_group_ranks(pg)
+    else:
+        rank = pg.rank
+        world_size = pg.world_size
+        ranks = list(range(world_size))
+
+    # local tensor in each process to store the result
+    is_in_the_same_node = torch.tensor([0] * world_size, dtype=torch.int32)
 
     magic_message = b"magic_message"
     shm = None

@@ -1219,17 +1226,21 @@ def in_the_same_node_as(pg: ProcessGroup, source_rank: int = 0) -> List[bool]:
             # create a shared memory segment
             shm = shared_memory.SharedMemory(create=True, size=128)
             shm.buf[:len(magic_message)] = magic_message
-            torch.distributed.broadcast_object_list([shm.name],
-                                                    src=ranks[source_rank],
-                                                    group=pg)
+            if isinstance(pg, ProcessGroup):
+                torch.distributed.broadcast_object_list(
+                    [shm.name], src=ranks[source_rank], group=pg)
+            else:
+                pg.broadcast_obj(shm.name, src=source_rank)
             is_in_the_same_node[rank] = 1
         else:
             # try to open the shared memory segment
-            recv = [None]
-            torch.distributed.broadcast_object_list(recv,
-                                                    src=ranks[source_rank],
-                                                    group=pg)
-            name = recv[0]
+            if isinstance(pg, ProcessGroup):
+                recv = [None]
+                torch.distributed.broadcast_object_list(
+                    recv, src=ranks[source_rank], group=pg)
+                name = recv[0]
+            else:
+                name = pg.broadcast_obj(None, src=source_rank)
             # fix to https://stackoverflow.com/q/62748654/9191338
             # Python incorrectly tracks shared memory even if it is not
             # created by the process. The following patch is a workaround.

@@ -1244,12 +1255,23 @@ def in_the_same_node_as(pg: ProcessGroup, source_rank: int = 0) -> List[bool]:
         if shm:
             shm.close()
 
-    torch.distributed.barrier(group=pg)
+    if isinstance(pg, ProcessGroup):
+        torch.distributed.barrier(group=pg)
+    else:
+        pg.barrier()
 
     # clean up the shared memory segment
     with contextlib.suppress(OSError):
         if rank == source_rank and shm:
             shm.unlink()
-    torch.distributed.all_reduce(is_in_the_same_node, group=pg)
 
-    return [x == 1 for x in is_in_the_same_node.tolist()]
+    if isinstance(pg, ProcessGroup):
+        torch.distributed.all_reduce(is_in_the_same_node, group=pg)
+        aggregated_data = is_in_the_same_node
+    else:
+        aggregated_data = torch.zeros_like(is_in_the_same_node)
+        for i in range(world_size):
+            rank_data = pg.broadcast_obj(is_in_the_same_node, src=i)
+            aggregated_data += rank_data
+
+    return [x == 1 for x in aggregated_data.tolist()]
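The last hunk is the subtle one: StatelessProcessGroup exposes no all_reduce, so the sum-reduction over the per-rank result tensor is emulated with one broadcast_obj round per rank; every rank contributes its local tensor once and all ranks accumulate the world_size contributions, which matches all_reduce(SUM). The same idea in isolation, as a minimal sketch (stateless_all_reduce_sum is an illustrative name, not part of vLLM):

import torch

from vllm.distributed.utils import StatelessProcessGroup


def stateless_all_reduce_sum(pg: StatelessProcessGroup,
                             local: torch.Tensor) -> torch.Tensor:
    # world_size broadcast rounds replace a single all_reduce: rank i sends
    # its tensor in round i, everyone else receives it, and all ranks sum
    total = torch.zeros_like(local)
    for i in range(pg.world_size):
        total += pg.broadcast_obj(local, src=i)
    return total

This costs world_size rounds instead of one collective, which is acceptable for the small bookkeeping tensor used here.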