[Misc][TPU] Support TPU in initialize_ray_cluster (#6812)
This commit is contained in:
parent
71734f1bf2
commit
aa4867791e
@ -3,7 +3,7 @@ from typing import List, Optional, Tuple
|
||||
from vllm.config import ParallelConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.sequence import ExecuteModelRequest
|
||||
from vllm.utils import get_ip, is_hip, is_xpu
|
||||
from vllm.utils import get_ip, is_hip, is_tpu, is_xpu
|
||||
from vllm.worker.worker_base import WorkerWrapperBase
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@ -93,32 +93,38 @@ def initialize_ray_cluster(
|
||||
# Placement group is already set.
|
||||
return
|
||||
|
||||
device_str = "GPU" if not is_tpu() else "TPU"
|
||||
# Create placement group for worker processes
|
||||
current_placement_group = ray.util.get_current_placement_group()
|
||||
if current_placement_group:
|
||||
# We are in a placement group
|
||||
bundles = current_placement_group.bundle_specs
|
||||
# Verify that we can use the placement group.
|
||||
gpu_bundles = 0
|
||||
device_bundles = 0
|
||||
for bundle in bundles:
|
||||
bundle_gpus = bundle.get("GPU", 0)
|
||||
if bundle_gpus > 1:
|
||||
bundle_devices = bundle.get(device_str, 0)
|
||||
if bundle_devices > 1:
|
||||
raise ValueError(
|
||||
"Placement group bundle cannot have more than 1 GPU.")
|
||||
if bundle_gpus:
|
||||
gpu_bundles += 1
|
||||
if parallel_config.world_size > gpu_bundles:
|
||||
"Placement group bundle cannot have more than 1 "
|
||||
f"{device_str}.")
|
||||
if bundle_devices:
|
||||
device_bundles += 1
|
||||
if parallel_config.world_size > device_bundles:
|
||||
raise ValueError(
|
||||
"The number of required GPUs exceeds the total number of "
|
||||
"available GPUs in the placement group.")
|
||||
f"The number of required {device_str}s exceeds the total "
|
||||
f"number of available {device_str}s in the placement group."
|
||||
f"Required number of devices: {parallel_config.world_size}. "
|
||||
f"Total number of devices: {device_bundles}.")
|
||||
else:
|
||||
num_gpus_in_cluster = ray.cluster_resources().get("GPU", 0)
|
||||
if parallel_config.world_size > num_gpus_in_cluster:
|
||||
num_devices_in_cluster = ray.cluster_resources().get(device_str, 0)
|
||||
if parallel_config.world_size > num_devices_in_cluster:
|
||||
raise ValueError(
|
||||
"The number of required GPUs exceeds the total number of "
|
||||
"available GPUs in the cluster.")
|
||||
f"The number of required {device_str}s exceeds the total "
|
||||
f"number of available {device_str}s in the placement group.")
|
||||
# Create a new placement group
|
||||
placement_group_specs = ([{"GPU": 1}] * parallel_config.world_size)
|
||||
placement_group_specs = ([{
|
||||
device_str: 1
|
||||
}] * parallel_config.world_size)
|
||||
current_placement_group = ray.util.placement_group(
|
||||
placement_group_specs)
|
||||
# Wait until PG is ready - this will block until all
|
||||
|
Loading…
x
Reference in New Issue
Block a user