[Hardware][TPU] Refactor TPU backend (#5831)
This commit is contained in:
parent dd248f7675
commit bc34937d68
@@ -1,4 +1,4 @@
-from typing import List, Set, Tuple
+from typing import Any, Dict, List, Optional, Set, Tuple
 
 import torch
 
@@ -26,30 +26,46 @@ class TPUExecutor(ExecutorBase):
         self.model_config.dtype = torch.bfloat16
 
         # Instantiate the worker and load the model to the device.
-        self._init_worker()
-
-    def _init_worker(self):
-        from vllm.worker.tpu_worker import TPUWorker
-
-        assert self.parallel_config.world_size == 1, (
-            "TPUExecutor currently only supports a single TPU chip.")
-        distributed_init_method = get_distributed_init_method(
-            get_ip(), get_open_port())
-        self.driver_worker = TPUWorker(
-            self.model_config,
-            self.parallel_config,
-            self.scheduler_config,
-            self.device_config,
-            self.cache_config,
-            self.load_config,
-            self.vision_language_config,
-            local_rank=0,
-            rank=0,
-            distributed_init_method=distributed_init_method,
-        )
+        self.driver_worker = self._create_worker()
         self.driver_worker.init_device()
         self.driver_worker.load_model()
 
+    def _get_worker_kwargs(
+        self,
+        local_rank: int = 0,
+        rank: int = 0,
+        distributed_init_method: Optional[str] = None,
+    ) -> Dict[str, Any]:
+        """Return worker init args for a given rank."""
+        if distributed_init_method is None:
+            distributed_init_method = get_distributed_init_method(
+                get_ip(), get_open_port())
+        return dict(
+            model_config=self.model_config,
+            parallel_config=self.parallel_config,
+            scheduler_config=self.scheduler_config,
+            device_config=self.device_config,
+            cache_config=self.cache_config,
+            load_config=self.load_config,
+            local_rank=local_rank,
+            rank=rank,
+            distributed_init_method=distributed_init_method,
+            vision_language_config=self.vision_language_config,
+            is_driver_worker=rank == 0,
+        )
+
+    def _create_worker(
+        self,
+        local_rank: int = 0,
+        rank: int = 0,
+        distributed_init_method: Optional[str] = None,
+    ):
+        from vllm.worker.tpu_worker import TPUWorker
+
+        worker = TPUWorker(**self._get_worker_kwargs(local_rank, rank,
+                                                     distributed_init_method))
+        return worker
+
     def initialize_cache(
         self,
         num_gpu_blocks: int,
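Note on the hunk above: the hard-coded TPUWorker construction in _init_worker is replaced by the _get_worker_kwargs / _create_worker pair, which derives is_driver_worker from the rank (rank == 0) instead of assuming a single fixed worker. Below is a minimal, self-contained sketch of that factory pattern; the Executor/Worker classes and the placeholder init address are illustrative stand-ins, not vLLM code.

from typing import Any, Dict, Optional


class Worker:
    """Illustrative stand-in for TPUWorker (not the real class)."""

    def __init__(self, local_rank: int, rank: int,
                 distributed_init_method: str,
                 is_driver_worker: bool) -> None:
        self.local_rank = local_rank
        self.rank = rank
        self.distributed_init_method = distributed_init_method
        self.is_driver_worker = is_driver_worker


class Executor:
    """Illustrative stand-in for TPUExecutor (not the real class)."""

    def _get_worker_kwargs(
        self,
        local_rank: int = 0,
        rank: int = 0,
        distributed_init_method: Optional[str] = None,
    ) -> Dict[str, Any]:
        # Centralize the per-rank init arguments so a multi-chip executor
        # could reuse them with different ranks.
        if distributed_init_method is None:
            distributed_init_method = "tcp://127.0.0.1:29500"  # placeholder
        return dict(local_rank=local_rank,
                    rank=rank,
                    distributed_init_method=distributed_init_method,
                    is_driver_worker=rank == 0)

    def _create_worker(self, local_rank: int = 0, rank: int = 0) -> Worker:
        return Worker(**self._get_worker_kwargs(local_rank, rank))


driver = Executor()._create_worker()
assert driver.is_driver_worker  # rank 0 is always the driver

Keeping the kwargs in one place means creating additional workers is just a matter of calling _create_worker with different ranks.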
@@ -33,6 +33,7 @@ class TPUModelRunner:
         cache_config: CacheConfig,
         load_config: LoadConfig,
         vision_language_config: Optional[VisionLanguageConfig] = None,
+        is_driver_worker: bool = False,
     ):
         self.model_config = model_config
         self.parallel_config = parallel_config
@@ -41,6 +42,7 @@ class TPUModelRunner:
         self.cache_config = cache_config
         self.load_config = load_config
         self.vision_language_config = vision_language_config
+        self.is_driver_worker = is_driver_worker
 
         self.block_size = self.cache_config.block_size
         self.max_num_blocks_per_seq = (self.model_config.max_model_len //
@@ -373,6 +375,8 @@ class TPUModelRunner:
         inputs = self.prepare_inputs(seq_group_metadata_list)
         next_token_ids = self.model(inputs[0], inputs[1], kv_caches,
                                     *inputs[2:])
+        if not self.is_driver_worker:
+            return []
         next_token_ids = next_token_ids.cpu().tolist()
 
         i = 0
@@ -34,6 +34,7 @@ class TPUWorker(LoraNotSupportedWorkerBase):
         local_rank: int,
         rank: int,
         distributed_init_method: str,
+        is_driver_worker: bool,
     ) -> None:
         self.model_config = model_config
         self.parallel_config = parallel_config
@@ -45,6 +46,7 @@ class TPUWorker(LoraNotSupportedWorkerBase):
         self.local_rank = local_rank
         self.rank = rank
         self.distributed_init_method = distributed_init_method
+        self.is_driver_worker = is_driver_worker
 
         assert self.device_config.device_type == "tpu"
         if self.cache_config.cache_dtype == "auto":
@@ -53,10 +55,14 @@ class TPUWorker(LoraNotSupportedWorkerBase):
             self.cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
                 self.cache_config.cache_dtype]
 
-        self.model_runner = TPUModelRunner(model_config, parallel_config,
-                                           scheduler_config, device_config,
-                                           cache_config, load_config,
-                                           vision_language_config)
+        self.model_runner = TPUModelRunner(model_config,
+                                           parallel_config,
+                                           scheduler_config,
+                                           device_config,
+                                           cache_config,
+                                           load_config,
+                                           vision_language_config,
+                                           is_driver_worker=is_driver_worker)
 
     def init_device(self) -> None:
         os.environ["PJRT_DEVICE"] = "TPU"
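The is_driver_worker flag added above is threaded from the executor through TPUWorker into TPUModelRunner, where its only effect (see the @@ -373,6 +375,8 @@ hunk) is to skip the device-to-host copy of the sampled token IDs on non-driver ranks. A small sketch of that threading, using hypothetical RunnerSketch/WorkerSketch stand-ins rather than the real classes:

from typing import List


class RunnerSketch:
    """Illustrative stand-in for TPUModelRunner (not the real class)."""

    def __init__(self, is_driver_worker: bool = False) -> None:
        self.is_driver_worker = is_driver_worker

    def execute_model(self, sampled_token_ids: List[int]) -> List[int]:
        # Every rank would run the forward pass and sampling (represented by
        # the input list here); only the driver materializes results on host.
        if not self.is_driver_worker:
            return []
        return list(sampled_token_ids)  # stand-in for .cpu().tolist()


class WorkerSketch:
    """Illustrative stand-in for TPUWorker (not the real class)."""

    def __init__(self, is_driver_worker: bool) -> None:
        self.is_driver_worker = is_driver_worker
        # The worker threads its own flag into the runner it owns.
        self.model_runner = RunnerSketch(is_driver_worker=is_driver_worker)


print(WorkerSketch(True).model_runner.execute_model([1, 2]))   # [1, 2]
print(WorkerSketch(False).model_runner.execute_model([1, 2]))  # []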
@@ -175,16 +181,13 @@ class TPUWorker(LoraNotSupportedWorkerBase):
 
     def execute_model(
         self,
-        execute_model_req: Optional[ExecuteModelRequest] = None
+        execute_model_req: Optional[ExecuteModelRequest] = None,
     ) -> List[SamplerOutput]:
-        if execute_model_req is None:
-            return []
-
-        seq_group_metadata_list = execute_model_req.seq_group_metadata_list
-        num_seq_groups = len(seq_group_metadata_list)
-        if num_seq_groups == 0:
+        if not self.is_driver_worker:
+            self._execute_model_non_driver()
             return []
 
+        assert execute_model_req is not None
         # Currently, TPUWorker does not support swapping.
         # TODO(woosuk): Support block copying.
         assert len(execute_model_req.blocks_to_swap_in) == 0, (
@@ -193,6 +196,16 @@ class TPUWorker(LoraNotSupportedWorkerBase):
             "Swapping is not supported for the TPU backend.")
         assert len(execute_model_req.blocks_to_copy) == 0
 
+        seq_group_metadata_list = execute_model_req.seq_group_metadata_list
+        assert len(seq_group_metadata_list) > 0
         output = self.model_runner.execute_model(seq_group_metadata_list,
                                                  self.tpu_cache)
         return [output]
+
+    def start_worker_execution_loop(self) -> None:
+        while self._execute_model_non_driver():
+            pass
+
+    def _execute_model_non_driver(self) -> bool:
+        self.model_runner.execute_model(None, self.tpu_cache)
+        return True
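The start_worker_execution_loop / _execute_model_non_driver pair added at the end is not exercised by the single-chip executor in this commit; it appears to lay the groundwork for a multi-chip setup in which only the driver rank receives ExecuteModelRequests and follower ranks spin in an execution loop. A self-contained sketch of that protocol, using hypothetical DriverAwareWorker/FakeRequest stand-ins and a simplified stop condition so the example terminates:

from dataclasses import dataclass
from typing import List, Optional


@dataclass
class FakeRequest:
    """Hypothetical stand-in for ExecuteModelRequest."""
    prompt_len: int


class DriverAwareWorker:
    """Illustrative stand-in for TPUWorker's driver/non-driver split."""

    def __init__(self, is_driver_worker: bool) -> None:
        self.is_driver_worker = is_driver_worker

    def execute_model(self,
                      req: Optional[FakeRequest] = None) -> List[List[int]]:
        if not self.is_driver_worker:
            # Non-driver ranks run one step and never return outputs.
            self._execute_model_non_driver()
            return []
        assert req is not None
        return [[0] * req.prompt_len]  # dummy "sampler output"

    def start_worker_execution_loop(self) -> None:
        # Spin until a step reports there is nothing left to do.
        while self._execute_model_non_driver():
            pass

    def _execute_model_non_driver(self) -> bool:
        # The real TPUWorker always returns True here and relies on the
        # driver to keep broadcasting work; this sketch stops immediately.
        return False


print(DriverAwareWorker(True).execute_model(FakeRequest(prompt_len=3)))  # [[0, 0, 0]]
DriverAwareWorker(False).start_worker_execution_loop()  # returns right away here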