vllm/tests/distributed/test_utils.py
youkaichao 719c1ca468
[core][distributed] add stateless_init_process_group (#10072)
Signed-off-by: youkaichao <youkaichao@gmail.com>
2024-11-06 16:42:09 -08:00

105 lines
3.5 KiB
Python

import pytest
import ray
import torch
import torch.distributed as dist
import vllm.envs as envs
from vllm.distributed.utils import stateless_init_process_group
from vllm.utils import (cuda_device_count_stateless,
update_environment_variables)
from ..utils import multi_gpu_test
@ray.remote
class _CUDADeviceCountStatelessTestActor:
    """Ray remote actor exposing helpers around ``CUDA_VISIBLE_DEVICES``.

    The test in this module drives it to change the environment variable
    and then re-read the device count from the same actor.
    """

    def get_count(self):
        """Return the value reported by ``cuda_device_count_stateless()``."""
        device_count = cuda_device_count_stateless()
        return device_count

    def set_cuda_visible_devices(self, cuda_visible_devices: str):
        """Update ``CUDA_VISIBLE_DEVICES`` to ``cuda_visible_devices``.

        The change is applied through ``update_environment_variables``.
        """
        new_env = {"CUDA_VISIBLE_DEVICES": cuda_visible_devices}
        update_environment_variables(new_env)

    def get_cuda_visible_devices(self):
        """Return ``envs.CUDA_VISIBLE_DEVICES`` as read by this actor."""
        return envs.CUDA_VISIBLE_DEVICES
def test_cuda_device_count_stateless():
    """Check that the stateless CUDA device count tracks
    CUDA_VISIBLE_DEVICES as it is changed inside the actor."""
    actor_cls = _CUDADeviceCountStatelessTestActor.options(  # type: ignore
        num_gpus=2)
    actor = actor_cls.remote()

    def current_count():
        # Ask the actor for its device count and wait for the answer.
        return ray.get(actor.get_count.remote())

    # The actor was requested with two GPUs, so it should see two
    # comma-separated entries and report a count of two.
    visible = ray.get(actor.get_cuda_visible_devices.remote())
    assert len(visible.split(",")) == 2
    assert current_count() == 2

    # Shrink the visible set step by step; the count must follow.
    for devices, expected in (("0", 1), ("", 0)):
        ray.get(actor.set_cuda_visible_devices.remote(devices))
        assert current_count() == expected
def cpu_worker(rank, WORLD_SIZE):
    """Run the gloo all-reduce check for one rank.

    Every rank joins a group of ``WORLD_SIZE`` members; ranks 0-2 also join
    a second group whose world size is fixed at 3. Each rank contributes a
    tensor holding its own rank number.

    The asserted values (6 and 18) correspond to the four-rank launch used
    by the test in this file: 0 + 1 + 2 + 3 = 6 after the first reduction,
    and 6 + 6 + 6 = 18 after the second one for the three sub-group ranks.
    """
    in_subgroup = rank <= 2
    # Only rank 3 skips the second reduction and therefore keeps the
    # result of the first one.
    expected = 6 if rank == 3 else 18

    full_group = stateless_init_process_group(
        init_method="tcp://127.0.0.1:29500",
        rank=rank,
        world_size=WORLD_SIZE,
        backend="gloo")
    sub_group = None
    if in_subgroup:
        sub_group = stateless_init_process_group(
            init_method="tcp://127.0.0.1:29501",
            rank=rank,
            world_size=3,
            backend="gloo")

    data = torch.tensor([rank])
    dist.all_reduce(data, op=dist.ReduceOp.SUM, group=full_group)
    if in_subgroup:
        dist.all_reduce(data, op=dist.ReduceOp.SUM, group=sub_group)

    item = data[0].item()
    print(f"rank: {rank}, item: {item}")
    assert item == expected
def gpu_worker(rank, WORLD_SIZE):
    """Run the nccl all-reduce check for one rank.

    Every rank joins a group of ``WORLD_SIZE`` members; ranks 0-2 also join
    a second group whose world size is fixed at 3. After the groups are
    created the rank selects CUDA device ``rank`` and contributes a CUDA
    tensor holding its own rank number.

    The asserted values (6 and 18) correspond to the four-rank launch used
    by the test in this file: 0 + 1 + 2 + 3 = 6 after the first reduction,
    and 6 + 6 + 6 = 18 after the second one for the three sub-group ranks.
    """
    in_subgroup = rank <= 2
    # Only rank 3 skips the second reduction and therefore keeps the
    # result of the first one.
    expected = 6 if rank == 3 else 18

    full_group = stateless_init_process_group(
        init_method="tcp://127.0.0.1:29502",
        rank=rank,
        world_size=WORLD_SIZE,
        backend="nccl")
    sub_group = None
    if in_subgroup:
        sub_group = stateless_init_process_group(
            init_method="tcp://127.0.0.1:29503",
            rank=rank,
            world_size=3,
            backend="nccl")

    # Device selection stays after group creation, as in the original
    # statement order.
    torch.cuda.set_device(rank)
    data = torch.tensor([rank]).cuda()
    dist.all_reduce(data, op=dist.ReduceOp.SUM, group=full_group)
    if in_subgroup:
        dist.all_reduce(data, op=dist.ReduceOp.SUM, group=sub_group)

    item = data[0].item()
    print(f"rank: {rank}, item: {item}")
    assert item == expected
@multi_gpu_test(num_gpus=4)
@pytest.mark.parametrize("worker", [cpu_worker, gpu_worker])
def test_stateless_init_process_group(worker):
    """Launch four forked processes running ``worker`` and check them all.

    Each process receives its rank and the world size as positional
    arguments. The test passes only if every process ends with a falsy
    exit code.
    """
    from multiprocessing import get_context

    world_size = 4
    fork_ctx = get_context("fork")

    # One process per rank, created up front and started afterwards.
    procs = [
        fork_ctx.Process(target=worker, args=(rank, world_size))
        for rank in range(world_size)
    ]
    for proc in procs:
        proc.start()
    for proc in procs:
        proc.join()

    # A failed assertion inside a worker surfaces here as a non-zero
    # exit code.
    for proc in procs:
        assert not proc.exitcode
    print("All processes finished.")