# can only run on machines with p2p access across GPUs
# can only run with torchrun:
# torchrun --nproc_per_node=2 tests/distributed/test_ca_buffer_sharing.py

import ctypes

import torch
import torch.distributed as dist

from vllm.distributed.device_communicators.cuda_wrapper import CudaRTLibrary
from vllm.distributed.device_communicators.custom_all_reduce import (  # noqa
    CustomAllreduce)

# create a cpu process group for communicating metadata (ipc handle)
dist.init_process_group(backend="gloo")
rank = local_rank = dist.get_rank()
world_size = dist.get_world_size()

# every process sets its own device (differently)
lib = CudaRTLibrary()
lib.cudaSetDevice(rank)
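# setting the device by global rank assumes a single-node launch
# (rank == local_rank), matching the torchrun command at the top of this file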

buffer_size_in_bytes = 1024
byte_value = 2  # the value we write to the buffer for verification

pointers = CustomAllreduce.create_shared_buffer(buffer_size_in_bytes)
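# each entry in `pointers` should be a device pointer to one rank's buffer:
# the local allocation plus the peers' buffers opened from the IPC handles
# exchanged over the gloo group
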
print(f"Rank {rank} has pointers {pointers}")

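# make sure buffer creation has completed on every rank before rank 0 writes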
dist.barrier()
torch.cuda.synchronize()

if rank == 0:
    # the first rank tries to write to all buffers
    for p in pointers:
        pointer = ctypes.c_void_p(p)
        lib.cudaMemset(pointer, byte_value, buffer_size_in_bytes)

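# wait for rank 0's writes to land before any rank reads the buffers back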
dist.barrier()
torch.cuda.synchronize()

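# host-side staging buffer used to copy each device buffer back for checking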
host_data = (ctypes.c_char * buffer_size_in_bytes)()

# all ranks read from all buffers, and check if the data is correct
for p in pointers:
    pointer = ctypes.c_void_p(p)
    lib.cudaMemcpy(host_data, pointer, buffer_size_in_bytes)
    for i in range(buffer_size_in_bytes):
        assert ord(host_data[i]) == byte_value, (
            f"Rank {rank} failed"
            f" to verify buffer {p}. Expected {byte_value}, "
            f"got {ord(host_data[i])}")

print(f"Rank {rank} verified all buffers")

dist.barrier()
torch.cuda.synchronize()

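# clean up: free the local buffer and release any peer buffers opened via IPC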
CustomAllreduce.free_shared_buffer(pointers)