# SPDX-License-Identifier: Apache-2.0
import multiprocessing
import random
import time
from typing import List

import numpy as np
import torch.distributed as dist

from vllm.distributed.device_communicators.shm_broadcast import MessageQueue
from vllm.distributed.utils import StatelessProcessGroup
from vllm.utils import get_ip, get_open_port, update_environment_variables


def get_arrays(n: int, seed: int = 0) -> List[np.ndarray]:
    """Generate ``n`` random int arrays, reproducibly for a given ``seed``.

    Array lengths are drawn uniformly from [1, 10_000), so on average each
    array has ~5k elements — with int64 that is roughly 40 KB per array.
    Writer and readers call this with the same seed to get identical data.
    """
    np.random.seed(seed)
    sizes = np.random.randint(1, 10_000, n)
    # on average, each array will have 5k elements
    # with int64, each array will have 40kb
    return [np.random.randint(1, 100, i) for i in sizes]


def distributed_run(fn, world_size):
    """Spawn ``world_size`` processes running ``fn(env)`` and wait for all.

    Each worker gets the torch.distributed rendezvous settings (rank, world
    size, master addr/port) as a dict argument; the worker is expected to
    export them itself (see ``worker_fn_wrapper``).  Asserts that every
    worker exited with code 0.
    """
    number_of_processes = world_size
    processes = []
    for i in range(number_of_processes):
        env = {}
        env['RANK'] = str(i)
        env['LOCAL_RANK'] = str(i)
        env['WORLD_SIZE'] = str(number_of_processes)
        env['LOCAL_WORLD_SIZE'] = str(number_of_processes)
        env['MASTER_ADDR'] = 'localhost'
        env['MASTER_PORT'] = '12345'
        p = multiprocessing.Process(target=fn, args=(env, ))
        processes.append(p)
        p.start()

    for p in processes:
        p.join()

    for p in processes:
        assert p.exitcode == 0


def worker_fn_wrapper(fn):
    # `multiprocessing.Process` cannot accept environment variables directly
    # so we need to pass the environment variables as arguments
    # and update the environment variables in the function
    def wrapped_fn(env):
        update_environment_variables(env)
        dist.init_process_group(backend="gloo")
        fn()

    return wrapped_fn


@worker_fn_wrapper
def worker_fn():
    """Exercise MessageQueue broadcast over both group implementations.

    Rank 0 picks an (ip, port) pair and broadcasts it so every rank can
    build a ``StatelessProcessGroup``; the test is then run once against
    the regular torch.distributed WORLD group and once against the
    stateless group.  Rank ``writer_rank`` streams ~400 MB of arrays
    through the shared-memory MessageQueue and the readers verify each
    array against locally regenerated expected data (same seed).
    """
    rank = dist.get_rank()
    if rank == 0:
        port = get_open_port()
        ip = get_ip()
        dist.broadcast_object_list([ip, port], src=0)
    else:
        recv = [None, None]
        dist.broadcast_object_list(recv, src=0)
        ip, port = recv

    stateless_pg = StatelessProcessGroup.create(ip, port, rank,
                                                dist.get_world_size())

    for pg in [dist.group.WORLD, stateless_pg]:

        writer_rank = 2
        broadcaster = MessageQueue.create_from_process_group(
            pg, 40 * 1024, 2, writer_rank)
        # The writer picks a fresh random seed and shares it so that the
        # readers can regenerate the exact arrays it will broadcast.
        if rank == writer_rank:
            seed = random.randint(0, 1000)
            dist.broadcast_object_list([seed], writer_rank)
        else:
            recv = [None]
            dist.broadcast_object_list(recv, writer_rank)
            seed = recv[0]  # type: ignore

        # NOTE: `is`, not `==` — we are distinguishing which of the two
        # group objects this iteration runs against (identity check).
        if pg is dist.group.WORLD:
            dist.barrier()
        else:
            pg.barrier()

        # in case we find a race condition
        # print the seed so that we can reproduce the error
        print(f"Rank {rank} got seed {seed}")
        # test broadcasting with about 400MB of data
        N = 10_000
        if rank == writer_rank:
            arrs = get_arrays(N, seed)
            for x in arrs:
                broadcaster.broadcast_object(x)
                # random jitter to shake out race conditions
                time.sleep(random.random() / 1000)
        else:
            arrs = get_arrays(N, seed)
            for x in arrs:
                y = broadcaster.broadcast_object(None)
                assert np.array_equal(x, y)
                time.sleep(random.random() / 1000)

        if pg is dist.group.WORLD:
            dist.barrier()
            print("torch distributed passed the test!")
        else:
            pg.barrier()
            print("StatelessProcessGroup passed the test!")


def test_shm_broadcast():
    """Run the broadcast worker across 4 local processes."""
    distributed_run(worker_fn, 4)