vllm/tests/distributed/test_custom_all_reduce.py
youkaichao 63e7176f26
[Core][Refactor] move parallel_utils into vllm/distributed (#3950)
[WIP][Core][Refactor] move vllm/model_executor/parallel_utils into vllm/distributed and vllm/device_communicators (#3950)
2024-04-10 15:33:30 -07:00

85 lines
3.2 KiB
Python

import os
import random
import pytest
import ray
import torch
import torch.distributed as dist
from vllm.distributed import tensor_model_parallel_all_reduce
from vllm.distributed.device_communicators import custom_all_reduce
from vllm.test_utils import (init_test_distributed_environment,
multi_process_tensor_parallel)
# Fixed seed so that every worker process derives the same list of sizes.
random.seed(42)
_raw_sizes = [random.randint(1024, 2048 * 1024) for _ in range(8)]
# Round each size down to the nearest multiple of 8.
test_sizes = [size - size % 8 for size in _raw_sizes]
@ray.remote(num_gpus=1, max_calls=1)
def graph_allreduce(world_size, rank, distributed_init_port):
    """Compare the custom all-reduce with ``dist.all_reduce`` in a CUDA graph.

    Declared as a Ray remote task requesting one GPU. For every size in
    ``test_sizes`` and each of float32 / float16 / bfloat16, two
    ``tensor_model_parallel_all_reduce`` calls are captured into a CUDA graph,
    each immediately followed by an in-place ``dist.all_reduce`` on the same
    input. After replaying the graph, each captured output must match its
    (now reduced) input buffer.

    Args:
        world_size: Number of participating ranks.
        rank: Rank of this worker; also used as the CUDA device index.
        distributed_init_port: Port forwarded to
            ``init_test_distributed_environment``.

    Raises:
        AssertionError: If a captured output differs from the reference.
    """
    # Remove the device mask so ``cuda:{rank}`` can be addressed directly
    # (presumably Ray sets it for ``num_gpus=1`` tasks -- confirm).
    # pop() rather than ``del``: if the variable is not set, ``del`` raises
    # KeyError and the worker dies before any test logic runs.
    os.environ.pop("CUDA_VISIBLE_DEVICES", None)
    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)
    init_test_distributed_environment(1, world_size, rank,
                                      distributed_init_port)

    custom_all_reduce.init_custom_all_reduce()
    dtypes = (torch.float32, torch.float16, torch.bfloat16)
    for sz in test_sizes:
        for dtype in dtypes:
            with custom_all_reduce.capture():
                # use integers so result matches NCCL exactly
                inp1 = torch.randint(1,
                                     16, (sz, ),
                                     dtype=dtype,
                                     device=torch.cuda.current_device())
                inp2 = torch.randint(1,
                                     16, (sz, ),
                                     dtype=dtype,
                                     device=torch.cuda.current_device())
                torch.cuda.synchronize()
                graph = torch.cuda.CUDAGraph()
                with torch.cuda.graph(graph):
                    out1 = tensor_model_parallel_all_reduce(inp1)
                    # the input buffer is immediately modified to test
                    # synchronization
                    dist.all_reduce(inp1)
                    out2 = tensor_model_parallel_all_reduce(inp2)
                    dist.all_reduce(inp2)
            # NOTE(review): replay happens after leaving capture(); this
            # assumes capture() finalizes its bookkeeping on exit -- confirm
            # against custom_all_reduce.
            graph.replay()
            # inp1 / inp2 were reduced in place by dist.all_reduce during the
            # replay, so they now hold the reference sums.
            assert torch.allclose(out1, inp1)
            assert torch.allclose(out2, inp2)
@ray.remote(num_gpus=1, max_calls=1)
def eager_allreduce(world_size, rank, distributed_init_port):
    """Check ``all_reduce_unreg`` on all-ones tensors outside a CUDA graph.

    Declared as a Ray remote task requesting one GPU. Reduces a float32
    tensor of 1024 ones and a bfloat16 tensor of 4096 ones through the custom
    all-reduce handle and expects every element to equal ``world_size``.

    Args:
        world_size: Number of participating ranks.
        rank: Rank of this worker; also used as the CUDA device index.
        distributed_init_port: Port forwarded to
            ``init_test_distributed_environment``.

    Raises:
        AssertionError: If a reduced tensor differs from ``inp * world_size``.
    """
    # Remove the device mask so ``cuda:{rank}`` can be addressed directly
    # (presumably Ray sets it for ``num_gpus=1`` tasks -- confirm).
    # pop() rather than ``del``: if the variable is not set, ``del`` raises
    # KeyError and the worker dies before any test logic runs.
    os.environ.pop("CUDA_VISIBLE_DEVICES", None)
    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)
    init_test_distributed_environment(1, world_size, rank,
                                      distributed_init_port)

    sz = 1024
    custom_all_reduce.init_custom_all_reduce()
    fa = custom_all_reduce.get_handle()

    # (element count, dtype) pairs, checked in the original order.
    cases = ((sz, torch.float32), (sz * 4, torch.bfloat16))
    for numel, dtype in cases:
        inp = torch.ones(numel, dtype=dtype, device=device)
        out = fa.all_reduce_unreg(inp)
        # Summing ones across all ranks yields world_size in every slot.
        assert torch.allclose(out, inp * world_size)
@pytest.mark.skipif(
    torch.cuda.device_count() < 2,
    reason="Need at least 2 GPUs to run the test.",
)
@pytest.mark.parametrize("tensor_parallel_size", [2])
@pytest.mark.parametrize("test_target", [eager_allreduce, graph_allreduce])
def test_multi_process_tensor_parallel(tensor_parallel_size, test_target):
    """Hand each all-reduce worker to ``multi_process_tensor_parallel``.

    Parametrized over the eager and CUDA-graph workers with a tensor-parallel
    size of 2; skipped when fewer than two CUDA devices are reported.
    """
    multi_process_tensor_parallel(tensor_parallel_size, test_target)
if __name__ == "__main__":
    # Running the file directly exercises only the CUDA-graph worker, with a
    # tensor-parallel size of 2.
    multi_process_tensor_parallel(2, graph_allreduce)