[Misc][TPU] Support TPU in initialize_ray_cluster (#6812)
commit aa4867791e
parent 71734f1bf2
@@ -3,7 +3,7 @@ from typing import List, Optional, Tuple
 from vllm.config import ParallelConfig
 from vllm.logger import init_logger
 from vllm.sequence import ExecuteModelRequest
-from vllm.utils import get_ip, is_hip, is_xpu
+from vllm.utils import get_ip, is_hip, is_tpu, is_xpu
 from vllm.worker.worker_base import WorkerWrapperBase
 
 logger = init_logger(__name__)
@@ -93,32 +93,38 @@ def initialize_ray_cluster(
         # Placement group is already set.
         return
 
+    device_str = "GPU" if not is_tpu() else "TPU"
     # Create placement group for worker processes
     current_placement_group = ray.util.get_current_placement_group()
     if current_placement_group:
         # We are in a placement group
         bundles = current_placement_group.bundle_specs
         # Verify that we can use the placement group.
-        gpu_bundles = 0
+        device_bundles = 0
         for bundle in bundles:
-            bundle_gpus = bundle.get("GPU", 0)
-            if bundle_gpus > 1:
+            bundle_devices = bundle.get(device_str, 0)
+            if bundle_devices > 1:
                 raise ValueError(
-                    "Placement group bundle cannot have more than 1 GPU.")
-            if bundle_gpus:
-                gpu_bundles += 1
-        if parallel_config.world_size > gpu_bundles:
+                    "Placement group bundle cannot have more than 1 "
+                    f"{device_str}.")
+            if bundle_devices:
+                device_bundles += 1
+        if parallel_config.world_size > device_bundles:
             raise ValueError(
-                "The number of required GPUs exceeds the total number of "
-                "available GPUs in the placement group.")
+                f"The number of required {device_str}s exceeds the total "
+                f"number of available {device_str}s in the placement group."
+                f"Required number of devices: {parallel_config.world_size}. "
+                f"Total number of devices: {device_bundles}.")
     else:
-        num_gpus_in_cluster = ray.cluster_resources().get("GPU", 0)
-        if parallel_config.world_size > num_gpus_in_cluster:
+        num_devices_in_cluster = ray.cluster_resources().get(device_str, 0)
+        if parallel_config.world_size > num_devices_in_cluster:
             raise ValueError(
-                "The number of required GPUs exceeds the total number of "
-                "available GPUs in the cluster.")
+                f"The number of required {device_str}s exceeds the total "
+                f"number of available {device_str}s in the placement group.")
     # Create a new placement group
-    placement_group_specs = ([{"GPU": 1}] * parallel_config.world_size)
+    placement_group_specs = ([{
+        device_str: 1
+    }] * parallel_config.world_size)
     current_placement_group = ray.util.placement_group(
         placement_group_specs)
     # Wait until PG is ready - this will block until all
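For context, a minimal, Ray-free sketch of the device-agnostic bundle check introduced above. The count_device_bundles helper name and the inline example data are hypothetical and only for illustration; in the actual change this logic lives inside initialize_ray_cluster, with device_str chosen via is_tpu() from vllm.utils and the bundles coming from ray.util.get_current_placement_group().bundle_specs.

from typing import Dict, List


def count_device_bundles(bundles: List[Dict[str, float]],
                         device_str: str) -> int:
    """Hypothetical standalone helper mirroring the check in the diff above.

    Counts placement-group bundles that provide one unit of ``device_str``
    ("GPU" or "TPU") and rejects any bundle that requests more than one.
    """
    device_bundles = 0
    for bundle in bundles:
        bundle_devices = bundle.get(device_str, 0)
        if bundle_devices > 1:
            raise ValueError(
                f"Placement group bundle cannot have more than 1 {device_str}.")
        if bundle_devices:
            device_bundles += 1
    return device_bundles


if __name__ == "__main__":
    # Example: a 4-worker TPU placement group, one device per bundle
    # (in vLLM these specs come from the current Ray placement group).
    device_str = "TPU"
    bundles = [{device_str: 1.0, "CPU": 1.0}] * 4
    world_size = 4

    available = count_device_bundles(bundles, device_str)
    if world_size > available:
        raise ValueError(
            f"The number of required {device_str}s exceeds the total "
            f"number of available {device_str}s in the placement group.")
    print(f"OK: {available} {device_str} bundles cover world_size={world_size}.")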