[core][distributed] exact ray placement control (#12732)
Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
parent
022bcc701a
commit
bc1bdecebf
@ -128,6 +128,7 @@ steps:
|
||||
- tests/spec_decode/e2e/test_integration_dist_tp4
|
||||
- tests/compile
|
||||
- examples/offline_inference/rlhf.py
|
||||
- examples/offline_inference/ray_placement.py
|
||||
commands:
|
||||
- pytest -v -s distributed/test_utils.py
|
||||
- pytest -v -s compile/test_basic_correctness.py
|
||||
@ -136,6 +137,7 @@ steps:
|
||||
# TODO: create a dedicated test section for multi-GPU example tests
|
||||
# when we have multiple distributed example tests
|
||||
- python3 ../examples/offline_inference/rlhf.py
|
||||
- RAY_DEDUP_LOGS=0 python3 ../examples/offline_inference/ray_placement.py
|
||||
|
||||
- label: Metrics, Tracing Test # 10min
|
||||
num_gpus: 2
|
||||
|
121
examples/offline_inference/ray_placement.py
Normal file
121
examples/offline_inference/ray_placement.py
Normal file
@ -0,0 +1,121 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""
|
||||
a simple demonstration to show how to control
|
||||
the placement of the vLLM workers with Ray.
|
||||
The key is to set VLLM_RAY_PER_WORKER_GPUS and
|
||||
VLLM_RAY_BUNDLE_INDICES properly.
|
||||
"""
|
||||
import os
|
||||
|
||||
import ray
|
||||
from ray.util.placement_group import placement_group
|
||||
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
|
||||
|
||||
from vllm import LLM
|
||||
from vllm.worker.worker import Worker
|
||||
|
||||
|
||||
class MyWorker(Worker):
|
||||
|
||||
def report_device_id(self) -> str:
|
||||
from vllm.platforms import current_platform
|
||||
return current_platform.get_device_uuid(self.device.index)
|
||||
|
||||
|
||||
class MyLLM(LLM):
|
||||
|
||||
def __init__(self, *args, bundle_indices: list, **kwargs):
|
||||
# a hack to make the script work.
|
||||
# stop ray from manipulating CUDA_VISIBLE_DEVICES
|
||||
# at the top-level
|
||||
del os.environ["CUDA_VISIBLE_DEVICES"]
|
||||
# every worker will use 0.4 GPU, so that we can schedule
|
||||
# 2 instances on the same GPUs.
|
||||
os.environ["VLLM_RAY_PER_WORKER_GPUS"] = "0.4"
|
||||
os.environ["VLLM_RAY_BUNDLE_INDICES"] = ",".join(
|
||||
map(str, bundle_indices))
|
||||
print(f"creating LLM with bundle_indices={bundle_indices}")
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
|
||||
class RayTrainingActor:
|
||||
|
||||
def report_device_id(self) -> str:
|
||||
# the argument for get_device_uuid is the index
|
||||
# of the GPU in the visible devices.
|
||||
# ray will set CUDA_VISIBLE_DEVICES to the assigned GPUs
|
||||
from vllm.platforms import current_platform
|
||||
return current_platform.get_device_uuid(0)
|
||||
|
||||
|
||||
# ray manages 4 GPUs
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
|
||||
ray.init()
|
||||
|
||||
# we want to co-locate vLLM instance and the training actor
|
||||
# on the same set of GPUs.
|
||||
# the placement plan is as follows:
|
||||
# GPU 0 and 1: training actor 0, 1, and vLLM instance 0 (with TP=2)
|
||||
# GPU 2 and 3: training actor 2, 3, and vLLM instance 1 (with TP=2)
|
||||
|
||||
pg = placement_group([{"GPU": 1, "CPU": 0}] * 4)
|
||||
ray.get(pg.ready())
|
||||
print(f"placement group has bundles {pg.bundle_specs=}")
|
||||
|
||||
training_actors = []
|
||||
training_actor_device_ids = []
|
||||
inference_engines = []
|
||||
inference_engine_device_ids = []
|
||||
|
||||
for bundle_index in [0, 1, 2, 3]:
|
||||
training_actor = ray.remote(
|
||||
num_cpus=0,
|
||||
num_gpus=0.4,
|
||||
scheduling_strategy=PlacementGroupSchedulingStrategy(
|
||||
placement_group=pg,
|
||||
placement_group_capture_child_tasks=True,
|
||||
placement_group_bundle_index=bundle_index,
|
||||
),
|
||||
)(RayTrainingActor).remote()
|
||||
training_actors.append(training_actor)
|
||||
device_id = ray.get(training_actor.report_device_id.remote())
|
||||
print(f"training actor {bundle_index} is on {device_id}")
|
||||
training_actor_device_ids.append(device_id)
|
||||
|
||||
for (i, bundle_indices) in enumerate([[0, 1], [2, 3]]):
|
||||
# IMPORTANT: when creating vLLM instances, we need to
|
||||
# make sure there are no GPU activities on the target GPUs,
|
||||
# otherwise, they will interfere with the vLLM memory profiling,
|
||||
# and cause unexpected behaviors.
|
||||
llm = ray.remote(
|
||||
num_cpus=0,
|
||||
num_gpus=0,
|
||||
scheduling_strategy=PlacementGroupSchedulingStrategy(
|
||||
placement_group=pg,
|
||||
placement_group_capture_child_tasks=True,
|
||||
),
|
||||
)(MyLLM).remote(
|
||||
model="facebook/opt-125m",
|
||||
enforce_eager=True,
|
||||
worker_cls=MyWorker,
|
||||
tensor_parallel_size=2,
|
||||
distributed_executor_backend="ray",
|
||||
gpu_memory_utilization=0.4,
|
||||
bundle_indices=bundle_indices,
|
||||
)
|
||||
inference_engines.append(llm)
|
||||
# don't call any method on the inference engine here,
|
||||
# otherwise it will block until the vLLM instance is created.
|
||||
|
||||
for i, llm in enumerate(inference_engines):
|
||||
inference_engine_device_ids.append(
|
||||
ray.get(llm.collective_rpc.remote("report_device_id", args=tuple())))
|
||||
print(f"inference engine {i} is on {inference_engine_device_ids[-1]}")
|
||||
|
||||
# check the placement
|
||||
# the first two training actors should be
|
||||
# on the same GPUs as the first inference engine
|
||||
assert training_actor_device_ids[:2] == inference_engine_device_ids[0]
|
||||
# the last two training actors should be
|
||||
# on the same GPUs as the second inference engine
|
||||
assert training_actor_device_ids[2:] == inference_engine_device_ids[1]
|
14
vllm/envs.py
14
vllm/envs.py
@ -85,6 +85,8 @@ if TYPE_CHECKING:
|
||||
VLLM_MLA_DISABLE_REQUANTIZATION: bool = False
|
||||
VLLM_MLA_CUDA_MEM_ALIGN_KV_CACHE: bool = True
|
||||
VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON: bool = False
|
||||
VLLM_RAY_PER_WORKER_GPUS: float = 1.0
|
||||
VLLM_RAY_BUNDLE_INDICES: str = ""
|
||||
|
||||
|
||||
def get_default_cache_root():
|
||||
@ -550,6 +552,18 @@ environment_variables: Dict[str, Callable[[], Any]] = {
|
||||
lambda: bool(int(os.getenv("VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON", "0"))
|
||||
),
|
||||
|
||||
# Number of GPUs per worker in Ray, if it is set to be a fraction,
|
||||
# it allows ray to schedule multiple actors on a single GPU,
|
||||
# so that users can colocate other actors on the same GPUs as vLLM.
|
||||
"VLLM_RAY_PER_WORKER_GPUS":
|
||||
lambda: float(os.getenv("VLLM_RAY_PER_WORKER_GPUS", "1.0")),
|
||||
|
||||
# Bundle indices for Ray, if it is set, it can control precisely
|
||||
# which indices are used for the Ray bundle, for every worker.
|
||||
# Format: comma-separated list of integers, e.g. "0,1,2,3"
|
||||
"VLLM_RAY_BUNDLE_INDICES":
|
||||
lambda: os.getenv("VLLM_RAY_BUNDLE_INDICES", ""),
|
||||
|
||||
# When on a Nvidia GPU aligns single entries (within a page) so they are 256
|
||||
# byte aligned for better performance, this increases the memory usage of
|
||||
# the cache. Currently this only affects MLA that results in non-256
|
||||
|
@ -129,13 +129,7 @@ class RayDistributedExecutor(DistributedExecutorBase):
|
||||
|
||||
def _init_workers_ray(self, placement_group: "PlacementGroup",
|
||||
**ray_remote_kwargs):
|
||||
if (self.parallel_config.tensor_parallel_size == 1
|
||||
and self.parallel_config.pipeline_parallel_size == 1):
|
||||
# For single GPU case, we use a ray worker with constrained memory.
|
||||
num_gpus = self.cache_config.gpu_memory_utilization
|
||||
else:
|
||||
# Otherwise, the ray workers are allocated with a full GPU.
|
||||
num_gpus = 1
|
||||
num_gpus = envs.VLLM_RAY_PER_WORKER_GPUS
|
||||
|
||||
# The driver dummy worker does not actually use any resources.
|
||||
# It holds the resource for the driver worker.
|
||||
@ -155,12 +149,29 @@ class RayDistributedExecutor(DistributedExecutorBase):
|
||||
logger.info("use_ray_spmd_worker: %s", self.use_ray_spmd_worker)
|
||||
|
||||
# Create the workers.
|
||||
driver_ip = get_ip()
|
||||
rank = 0
|
||||
bundle_indices: List[int]
|
||||
if envs.VLLM_RAY_BUNDLE_INDICES:
|
||||
# Use the bundle indices specified by the user.
|
||||
bundle_indices = list(
|
||||
map(int, envs.VLLM_RAY_BUNDLE_INDICES.split(",")))
|
||||
assert len(bundle_indices) == self.parallel_config.world_size, \
|
||||
("VLLM_RAY_BUNDLE_INDICES must have the same size"
|
||||
f" as the world size, but got {bundle_indices=} "
|
||||
f"and {self.parallel_config.world_size=}")
|
||||
assert len(set(bundle_indices)) == len(bundle_indices), \
|
||||
("VLLM_RAY_BUNDLE_INDICES cannot have duplicate values,"
|
||||
f" but got {bundle_indices=}")
|
||||
else:
|
||||
# use the first N bundles that have GPU resources.
|
||||
bundle_indices = []
|
||||
for bundle_id, bundle in enumerate(placement_group.bundle_specs):
|
||||
if bundle.get(current_platform.ray_device_key, 0):
|
||||
bundle_indices.append(bundle_id)
|
||||
bundle_indices = bundle_indices[:self.parallel_config.world_size]
|
||||
|
||||
worker_metadata: List[RayWorkerMetaData] = []
|
||||
for bundle_id, bundle in enumerate(placement_group.bundle_specs):
|
||||
if not bundle.get(current_platform.ray_device_key, 0):
|
||||
continue
|
||||
driver_ip = get_ip()
|
||||
for rank, bundle_id in enumerate(bundle_indices):
|
||||
scheduling_strategy = PlacementGroupSchedulingStrategy(
|
||||
placement_group=placement_group,
|
||||
placement_group_capture_child_tasks=True,
|
||||
@ -187,7 +198,6 @@ class RayDistributedExecutor(DistributedExecutorBase):
|
||||
rpc_rank=rank)
|
||||
worker_metadata.append(
|
||||
RayWorkerMetaData(worker=worker, created_rank=rank))
|
||||
rank += 1
|
||||
|
||||
worker_ips = ray.get([
|
||||
each.worker.get_node_ip.remote() # type: ignore[attr-defined]
|
||||
|
@ -275,6 +275,14 @@ class NvmlCudaPlatform(CudaPlatformBase):
|
||||
physical_device_id = device_id_to_physical_device_id(device_id)
|
||||
return cls._get_physical_device_name(physical_device_id)
|
||||
|
||||
@classmethod
|
||||
@lru_cache(maxsize=8)
|
||||
@with_nvml_context
|
||||
def get_device_uuid(cls, device_id: int = 0) -> str:
|
||||
physical_device_id = device_id_to_physical_device_id(device_id)
|
||||
handle = pynvml.nvmlDeviceGetHandleByIndex(physical_device_id)
|
||||
return pynvml.nvmlDeviceGetUUID(handle)
|
||||
|
||||
@classmethod
|
||||
@lru_cache(maxsize=8)
|
||||
@with_nvml_context
|
||||
|
@ -183,6 +183,11 @@ class Platform:
|
||||
"""Get the name of a device."""
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def get_device_uuid(cls, device_id: int = 0) -> str:
|
||||
"""Get the uuid of a device, e.g. the PCI bus ID."""
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def get_device_total_memory(cls, device_id: int = 0) -> int:
|
||||
"""Get the total memory of a device in bytes."""
|
||||
|
Loading…
x
Reference in New Issue
Block a user