[distributed] add function to create ipc buffers directly (#10064)
Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in: parent 4089985552, commit 4be3a45158
@@ -510,6 +510,7 @@ steps:
   # NOTE: don't test llama model here, it seems hf implementation is buggy
   # see https://github.com/vllm-project/vllm/pull/5689 for details
   - pytest -v -s distributed/test_custom_all_reduce.py
+  - torchrun --nproc_per_node=2 distributed/test_ca_buffer_sharing.py
   - TARGET_TEST_SUITE=A100 pytest basic_correctness/ -v -s -m distributed_2_gpus
   - pytest -v -s -x lora/test_mixtral.py
tests/distributed/test_ca_buffer_sharing.py (new file, 59 lines)
@@ -0,0 +1,59 @@
# can only run on machines with p2p access across GPUs
# can only run with torchrun:
# torchrun --nproc_per_node=2 tests/distributed/test_ca_buffer_sharing.py

import ctypes

import torch
import torch.distributed as dist

from vllm.distributed.device_communicators.cuda_wrapper import CudaRTLibrary
from vllm.distributed.device_communicators.custom_all_reduce import (  # noqa
    CustomAllreduce)

# create a cpu process group for communicating metadata (ipc handle)
dist.init_process_group(backend="gloo")
rank = local_rank = dist.get_rank()
world_size = dist.get_world_size()

# every process sets its own device (differently)
lib = CudaRTLibrary()
lib.cudaSetDevice(rank)

buffer_size_in_bytes = 1024
byte_value = 2  # the value we write to the buffer for verification

pointers = CustomAllreduce.create_shared_buffer(buffer_size_in_bytes)

print(f"Rank {rank} has pointers {pointers}")

dist.barrier()
torch.cuda.synchronize()

if rank == 0:
    # the first rank tries to write to all buffers
    for p in pointers:
        pointer = ctypes.c_void_p(p)
        lib.cudaMemset(pointer, byte_value, buffer_size_in_bytes)

dist.barrier()
torch.cuda.synchronize()

host_data = (ctypes.c_char * buffer_size_in_bytes)()

# all ranks read from all buffers, and check if the data is correct
for p in pointers:
    pointer = ctypes.c_void_p(p)
    lib.cudaMemcpy(host_data, pointer, buffer_size_in_bytes)
    for i in range(buffer_size_in_bytes):
        assert ord(host_data[i]) == byte_value, (
            f"Rank {rank} failed"
            f" to verify buffer {p}. Expected {byte_value}, "
            f"got {ord(host_data[i])}")

print(f"Rank {rank} verified all buffers")

dist.barrier()
torch.cuda.synchronize()

CustomAllreduce.free_shared_buffer(pointers)
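A complementary check (not part of this PR) is the symmetric case: every rank writes its own buffer and all ranks then verify every peer's data. A minimal sketch under the same torchrun launch, reusing only the calls already shown above; the per-rank byte value rank + 1 is an illustrative choice:

# sketch: each rank fills its own buffer; every rank verifies all buffers
import ctypes

import torch
import torch.distributed as dist

from vllm.distributed.device_communicators.cuda_wrapper import CudaRTLibrary
from vllm.distributed.device_communicators.custom_all_reduce import (  # noqa
    CustomAllreduce)

dist.init_process_group(backend="gloo")
rank = dist.get_rank()
lib = CudaRTLibrary()
lib.cudaSetDevice(rank)

size_in_bytes = 1024
pointers = CustomAllreduce.create_shared_buffer(size_in_bytes)

# write a rank-specific byte into this rank's own allocation
lib.cudaMemset(ctypes.c_void_p(pointers[rank]), rank + 1, size_in_bytes)
dist.barrier()
torch.cuda.synchronize()

# read every rank's buffer through the shared pointers and check the value
host_data = (ctypes.c_char * size_in_bytes)()
for owner, p in enumerate(pointers):
    lib.cudaMemcpy(host_data, ctypes.c_void_p(p), size_in_bytes)
    assert all(ord(b) == owner + 1 for b in host_data), (
        f"Rank {rank} read unexpected data from rank {owner}'s buffer")

dist.barrier()
CustomAllreduce.free_shared_buffer(pointers)

Unlike the test above, writes here go only through each rank's locally owned pointer; reads still go through the IPC-opened peer pointers.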
vllm/distributed/device_communicators/custom_all_reduce.py

@@ -1,3 +1,4 @@
+import ctypes
 from contextlib import contextmanager
 from typing import Any, List, Optional, Union
 
@@ -7,6 +8,7 @@ from torch.distributed import ProcessGroup
 
 import vllm.envs as envs
 from vllm import _custom_ops as ops
+from vllm.distributed.device_communicators.cuda_wrapper import CudaRTLibrary
 from vllm.distributed.device_communicators.custom_all_reduce_utils import (
     gpu_p2p_access_check)
 from vllm.distributed.parallel_state import in_the_same_node_as
@@ -174,6 +176,35 @@ class CustomAllreduce:
                                        offsets, rank, self.full_nvlink)
         self.register_buffer(self.buffer)
 
+    @staticmethod
+    def create_shared_buffer(
+            size_in_bytes: int,
+            group: Optional[ProcessGroup] = None) -> List[int]:
+        lib = CudaRTLibrary()
+        pointer = lib.cudaMalloc(size_in_bytes)
+        handle = lib.cudaIpcGetMemHandle(pointer)
+        world_size = dist.get_world_size(group=group)
+        rank = dist.get_rank(group=group)
+        handles = [None] * world_size
+        dist.all_gather_object(handles, handle, group=group)
+
+        pointers: List[int] = []
+        for i, h in enumerate(handles):
+            if i == rank:
+                pointers.append(pointer.value)  # type: ignore
+            else:
+                pointers.append(
+                    lib.cudaIpcOpenMemHandle(h).value)  # type: ignore
+
+        return pointers
+
+    @staticmethod
+    def free_shared_buffer(pointers: List[int],
+                           group: Optional[ProcessGroup] = None) -> None:
+        rank = dist.get_rank(group=group)
+        lib = CudaRTLibrary()
+        lib.cudaFree(ctypes.c_void_p(pointers[rank]))
+
     @contextmanager
     def capture(self):
         """
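Since create_shared_buffer and free_shared_buffer form an allocate/release pair, a caller may want a guard that always frees the local allocation. A minimal sketch building only on the two methods added above; the shared_ipc_buffer helper name is not part of the PR:

from contextlib import contextmanager
from typing import Iterator, List, Optional

from torch.distributed import ProcessGroup

from vllm.distributed.device_communicators.custom_all_reduce import (
    CustomAllreduce)


@contextmanager
def shared_ipc_buffer(size_in_bytes: int,
                      group: Optional[ProcessGroup] = None
                      ) -> Iterator[List[int]]:
    # hypothetical convenience wrapper, not part of vLLM: yields per-rank
    # pointers to the IPC-shared buffer and frees the local allocation on exit
    pointers = CustomAllreduce.create_shared_buffer(size_in_bytes, group=group)
    try:
        yield pointers
    finally:
        CustomAllreduce.free_shared_buffer(pointers, group=group)

Usage mirrors the explicit create/free pair in the test above, e.g. `with shared_ipc_buffer(1024) as pointers: ...` inside a torchrun-launched process group.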