vllm/cacheflow/worker/controller.py

104 lines
3.2 KiB
Python
Raw Normal View History

2023-05-10 00:58:31 -07:00
from typing import List, Optional, Tuple, Union
2023-03-22 04:45:42 +08:00
try:
import ray
except ImportError:
ray = None
2023-02-23 09:32:19 +00:00
2023-05-09 15:30:12 -07:00
from cacheflow.core.scheduler import Scheduler
2023-02-23 09:32:19 +00:00
from cacheflow.worker.worker import Worker
2023-03-22 04:45:42 +08:00
# Placement of one worker of a pipeline stage, as unpacked by Controller:
#   rank          -- passed to the Worker as its `rank`;
#   node resource -- the node's IP, used by Controller as a Ray custom-resource
#                    name to pin the worker's actor to that node;
#   device id     -- unpacked by Controller but not otherwise used there.
DeviceID = Tuple[int, str, int]  # rank, node resource (node IP), device id
2023-02-23 09:32:19 +00:00
class Controller:
    """Drives the workers that make up one pipeline stage.

    One ``Worker`` is created per entry of ``stage_devices``, either in this
    process or as a Ray actor (``use_ray=True``). ``execute_stage`` issues the
    same call to every worker, checks that all workers returned the same
    result, and forwards that result to the next node set via ``set_next``.
    """

    def __init__(
        self,
        stage_id: int,
        stage_devices: List[DeviceID],
        world_size: int,
        tensor_parallel_size: int,
        pipeline_parallel_size: int,
        distributed_init_method: str,
        model_name: str,
        block_size: int,
        num_gpu_blocks: int,
        num_cpu_blocks: int,
        dtype: str,
        seed: int,
        cache_dir: Optional[str],
        use_dummy_weights: bool,
        use_np_cache: bool,
        max_num_batched_tokens: int,
        use_ray: bool,
    ) -> None:
        """Create one worker per device assigned to this stage.

        Args:
            stage_id: Index of this pipeline stage; stage 0 is the first one.
            stage_devices: ``(rank, node_resource, device_id)`` per worker.
                ``rank`` is forwarded to the worker; ``node_resource`` is used
                as a Ray custom-resource name to pin the actor to a node.
            world_size, tensor_parallel_size, pipeline_parallel_size,
            distributed_init_method, model_name, block_size, num_gpu_blocks,
            num_cpu_blocks, dtype, seed, cache_dir, use_dummy_weights,
            use_np_cache, max_num_batched_tokens: Forwarded unchanged to each
                ``Worker``.
            use_ray: If True, workers are created as Ray actors; otherwise
                they are plain in-process ``Worker`` objects.

        Raises:
            ImportError: If ``use_ray`` is True but ``import ray`` failed.
        """
        # `ray` is set to None at import time when Ray is missing; fail with a
        # clear message instead of an AttributeError on `None.remote`.
        if use_ray and ray is None:
            raise ImportError(
                'use_ray=True requires Ray, but it could not be imported. '
                'Install Ray or set use_ray=False.')

        self.stage_id = stage_id
        self.stage_devices = stage_devices
        self.model_name = model_name
        self.block_size = block_size
        self.num_gpu_blocks = num_gpu_blocks
        self.num_cpu_blocks = num_cpu_blocks
        self.use_ray = use_ray

        # Which pipeline stage is this node assigned to?
        self.is_first_stage = stage_id == 0
        # Whether this is the last stage is only known once set_next() runs.
        self.is_last_stage = False
        self.next_node: Optional[Union['Controller', 'Scheduler']] = None

        self.workers: List[Worker] = []
        # device_id is not used here; with Ray, a GPU is requested through
        # num_gpus=1 on the actor instead.
        for rank, node_resource, _device_id in stage_devices:
            if self.use_ray:
                # A tiny fraction of the node's custom resource pins the actor
                # to that node without consuming a meaningful share of it.
                worker_cls = ray.remote(num_cpus=0,
                                        num_gpus=1,
                                        resources={node_resource: 1e-5})(Worker).remote
            else:
                worker_cls = Worker
            worker = worker_cls(
                model_name=model_name,
                block_size=block_size,
                num_gpu_blocks=num_gpu_blocks,
                num_cpu_blocks=num_cpu_blocks,
                dtype=dtype,
                seed=seed,
                distributed_init_method=distributed_init_method,
                rank=rank,
                world_size=world_size,
                tensor_parallel_size=tensor_parallel_size,
                pipeline_parallel_size=pipeline_parallel_size,
                cache_dir=cache_dir,
                use_dummy_weights=use_dummy_weights,
                use_np_cache=use_np_cache,
                max_num_batched_tokens=max_num_batched_tokens,
            )
            self.workers.append(worker)

    def set_next(
        self,
        next_node: Union['Controller', 'Scheduler'],
    ) -> None:
        """Connect this stage to the node that consumes its output.

        When ``next_node`` is the ``Scheduler``, this controller is marked as
        the last pipeline stage.
        """
        self.next_node = next_node
        self.is_last_stage = isinstance(next_node, Scheduler)

    def execute_stage(self, *args, **kwargs) -> None:
        """Run one step on every worker and forward the result downstream.

        All positional and keyword arguments are passed unchanged to each
        worker's ``execute_stage``. The (identical) result is handed to the
        next node's ``post_step``.

        Raises:
            RuntimeError: If ``set_next`` has not been called or this stage
                has no workers.
            NotImplementedError: If this is not the last stage (pipeline
                parallelism is not supported yet).
        """
        # Validate the pipeline wiring before running a step whose output
        # could not be delivered anywhere.
        if self.next_node is None:
            raise RuntimeError(
                'set_next() must be called before execute_stage().')
        if not self.is_last_stage:
            # TODO: Support pipeline parallelism.
            raise NotImplementedError(
                'Pipeline parallelism is not supported yet.')
        if not self.workers:
            raise RuntimeError('This pipeline stage has no workers.')

        all_outputs = []
        for worker in self.workers:
            executor = (worker.execute_stage.remote
                        if self.use_ray else worker.execute_stage)
            all_outputs.append(executor(*args, **kwargs))
        if self.use_ray:
            # Every remote call is issued before waiting on any of them.
            all_outputs = ray.get(all_outputs)

        # Make sure all workers have the same results; only one copy is
        # forwarded downstream.
        output = all_outputs[0]
        for other_output in all_outputs[1:]:
            assert output == other_output, (
                'Workers of the same pipeline stage returned different '
                'outputs.')

        self.next_node.post_step(output)