# SPDX-License-Identifier: Apache-2.0

import socket

import pytest
import ray
import torch

import vllm.envs as envs
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
from vllm.distributed.utils import StatelessProcessGroup
from vllm.utils import (cuda_device_count_stateless, get_open_port,
                        update_environment_variables)

from ..utils import multi_gpu_test


@ray.remote
class _CUDADeviceCountStatelessTestActor:
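    """Ray actor used to probe cuda_device_count_stateless while
    CUDA_VISIBLE_DEVICES is changed at runtime."""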

    def get_count(self):
        return cuda_device_count_stateless()

    def set_cuda_visible_devices(self, cuda_visible_devices: str):
        update_environment_variables(
            {"CUDA_VISIBLE_DEVICES": cuda_visible_devices})

    def get_cuda_visible_devices(self):
        return envs.CUDA_VISIBLE_DEVICES


def test_cuda_device_count_stateless():
    """Test that cuda_device_count_stateless changes return value if
    CUDA_VISIBLE_DEVICES is changed."""
    actor = _CUDADeviceCountStatelessTestActor.options(  # type: ignore
        num_gpus=2).remote()
    assert len(
        sorted(ray.get(
            actor.get_cuda_visible_devices.remote()).split(","))) == 2
    assert ray.get(actor.get_count.remote()) == 2
    ray.get(actor.set_cuda_visible_devices.remote("0"))
    assert ray.get(actor.get_count.remote()) == 1
    ray.get(actor.set_cuda_visible_devices.remote(""))
    assert ray.get(actor.get_count.remote()) == 0


def cpu_worker(rank, WORLD_SIZE, port1, port2):
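    """CPU broadcast_obj check: all four ranks join pg1, ranks 0-2 also join
    pg2. Each group broadcasts from rank 2, so pg1 delivers 2 and pg2
    delivers 3."""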
    pg1 = StatelessProcessGroup.create(host="127.0.0.1",
                                       port=port1,
                                       rank=rank,
                                       world_size=WORLD_SIZE)
    if rank <= 2:
        pg2 = StatelessProcessGroup.create(host="127.0.0.1",
                                           port=port2,
                                           rank=rank,
                                           world_size=3)
    data = torch.tensor([rank])
    data = pg1.broadcast_obj(data, src=2)
    assert data.item() == 2
    if rank <= 2:
        data = torch.tensor([rank + 1])
        data = pg2.broadcast_obj(data, src=2)
        assert data.item() == 3
        pg2.barrier()
    pg1.barrier()


def gpu_worker(rank, WORLD_SIZE, port1, port2):
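    """GPU all_reduce check via PyNcclCommunicator: all four ranks reduce
    their rank over pg1 (0 + 1 + 2 + 3 = 6); ranks 0-2 then reduce that
    result again over pg2 (6 + 6 + 6 = 18)."""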
    torch.cuda.set_device(rank)
    pg1 = StatelessProcessGroup.create(host="127.0.0.1",
                                       port=port1,
                                       rank=rank,
                                       world_size=WORLD_SIZE)
    pynccl1 = PyNcclCommunicator(pg1, device=rank)
    if rank <= 2:
        pg2 = StatelessProcessGroup.create(host="127.0.0.1",
                                           port=port2,
                                           rank=rank,
                                           world_size=3)
        pynccl2 = PyNcclCommunicator(pg2, device=rank)
    data = torch.tensor([rank]).cuda()
    pynccl1.all_reduce(data)
    pg1.barrier()
    torch.cuda.synchronize()
    if rank <= 2:
        pynccl2.all_reduce(data)
        pg2.barrier()
        torch.cuda.synchronize()
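    # Rank 3 only took part in the pg1 all-reduce (0+1+2+3 = 6); ranks 0-2
    # additionally reduced that result over pg2 (6+6+6 = 18).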
    item = data[0].item()
    print(f"rank: {rank}, item: {item}")
    if rank == 3:
        assert item == 6
    else:
        assert item == 18


def broadcast_worker(rank, WORLD_SIZE, port1, port2):
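    """broadcast_obj check where only the source rank supplies the payload:
    rank 2 sends "secret" and every other rank must receive it."""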
    pg1 = StatelessProcessGroup.create(host="127.0.0.1",
                                       port=port1,
                                       rank=rank,
                                       world_size=WORLD_SIZE)
    if rank == 2:
        pg1.broadcast_obj("secret", src=2)
    else:
        obj = pg1.broadcast_obj(None, src=2)
        assert obj == "secret"
    pg1.barrier()


def allgather_worker(rank, WORLD_SIZE, port1, port2):
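    """all_gather_obj check: every rank contributes its own rank and should
    get back [0, 1, ..., WORLD_SIZE - 1]."""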
    pg1 = StatelessProcessGroup.create(host="127.0.0.1",
                                       port=port1,
                                       rank=rank,
                                       world_size=WORLD_SIZE)
    data = pg1.all_gather_obj(rank)
    assert data == list(range(WORLD_SIZE))
    pg1.barrier()


@pytest.mark.skip(reason="This test is flaky and prone to hang.")
@multi_gpu_test(num_gpus=4)
@pytest.mark.parametrize(
    "worker", [cpu_worker, gpu_worker, broadcast_worker, allgather_worker])
def test_stateless_process_group(worker):
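    """Fork WORLD_SIZE processes running the given worker and check that they
    all exit cleanly."""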
    port1 = get_open_port()
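    # Keep port1 bound while asking for a second port so that get_open_port
    # cannot hand back the same port number twice.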
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("", port1))
        port2 = get_open_port()
    WORLD_SIZE = 4
    from multiprocessing import get_context
    ctx = get_context("fork")
    processes = []
    for i in range(WORLD_SIZE):
        rank = i
        processes.append(
            ctx.Process(target=worker, args=(rank, WORLD_SIZE, port1, port2)))
    for p in processes:
        p.start()
    for p in processes:
        p.join()
    for p in processes:
        assert not p.exitcode
    print("All processes finished.")