
Signed-off-by: Aaron Pham <contact@aarnphm.xyz> Signed-off-by: Russell Bryant <rbryant@redhat.com> Co-authored-by: Russell Bryant <rbryant@redhat.com> Co-authored-by: Michael Goin <mgoin64@gmail.com> Co-authored-by: Nick Hill <nhill@redhat.com>
242 lines
9.2 KiB
Python
242 lines
9.2 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
|
|
from collections.abc import Mapping
|
|
from typing import Optional, Union
|
|
|
|
from typing_extensions import TypeVar
|
|
|
|
import vllm.envs as envs
|
|
from vllm.config import ParallelConfig, VllmConfig
|
|
from vllm.engine.arg_utils import EngineArgs
|
|
from vllm.engine.metrics_types import StatLoggerBase
|
|
from vllm.inputs import INPUT_REGISTRY, InputRegistry, PromptType
|
|
from vllm.logger import init_logger
|
|
from vllm.lora.request import LoRARequest
|
|
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
|
|
from vllm.outputs import RequestOutput
|
|
from vllm.pooling_params import PoolingParams
|
|
from vllm.prompt_adapter.request import PromptAdapterRequest
|
|
from vllm.sampling_params import SamplingParams
|
|
from vllm.transformers_utils.tokenizer_group import (
|
|
BaseTokenizerGroup, init_tokenizer_from_configs)
|
|
from vllm.usage.usage_lib import UsageContext
|
|
from vllm.v1.engine.core_client import EngineCoreClient
|
|
from vllm.v1.engine.output_processor import OutputProcessor
|
|
from vllm.v1.engine.parallel_sampling import ParentRequest
|
|
from vllm.v1.engine.processor import Processor
|
|
from vllm.v1.executor.abstract import Executor
|
|
|
|
# Module-level logger (used e.g. by LLMEngine.from_engine_args).
logger = init_logger(__name__)
|
|
|
|
# Type variable for LLMEngine.get_tokenizer_group(): the caller passes the
# expected tokenizer-group class and gets that type back. It is bounded by,
# and defaults to, BaseTokenizerGroup.
_G = TypeVar("_G", bound=BaseTokenizerGroup, default=BaseTokenizerGroup)
|
|
|
|
|
|
class LLMEngine:
    """Legacy LLMEngine for backwards compatibility.

    Wires the V1 components together behind the older ``LLMEngine`` API:

    * a :class:`Processor` converts raw inputs into EngineCoreRequests,
    * an :class:`EngineCoreClient` runs them and yields EngineCoreOutputs,
    * an :class:`OutputProcessor` converts those into :class:`RequestOutput`.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        executor_class: type[Executor],
        log_stats: bool,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[dict[str, StatLoggerBase]] = None,
        input_registry: InputRegistry = INPUT_REGISTRY,
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
        use_cached_outputs: bool = False,
        multiprocess_mode: bool = False,
    ) -> None:
        """Create the tokenizer, processors and engine-core client.

        Args:
            vllm_config: Full engine configuration.
            executor_class: Executor implementation handed to the engine
                core client.
            log_stats: Currently not used by this class. Both the
                OutputProcessor and the engine core are created with
                ``log_stats=False`` (see FIXME below).
            usage_context: Currently not used by this class.
            stat_loggers: Currently not used by this class.
            input_registry: Input registry passed to the :class:`Processor`.
            mm_registry: Multimodal registry passed to the
                :class:`Processor`.
            use_cached_outputs: Currently not used by this class.
            multiprocess_mode: Forwarded to
                ``EngineCoreClient.make_client``. When False, the in-process
                engine core's ``model_executor`` is also exposed as
                ``self.model_executor``.
        """
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config

        # Important: init the DP group before initializing the engine core.
        self.parallel_config = vllm_config.parallel_config
        self.dp_enabled = self.parallel_config.data_parallel_size > 1  # noqa
        # Set by has_unfinished_requests_dp() when this rank has no work of
        # its own but the DP group as a whole does; consumed by step().
        self.should_execute_dummy_batch = False
        if self.dp_enabled:
            self.dp_group = self.parallel_config.stateless_init_dp_group()

        # Tokenizer (+ ensure liveness if running in another process).
        self.tokenizer = init_tokenizer_from_configs(
            model_config=vllm_config.model_config,
            scheduler_config=vllm_config.scheduler_config,
            parallel_config=vllm_config.parallel_config,
            lora_config=vllm_config.lora_config)
        self.tokenizer.ping()

        # Processor (convert Inputs --> EngineCoreRequests).
        self.processor = Processor(vllm_config=vllm_config,
                                   tokenizer=self.tokenizer,
                                   input_registry=input_registry,
                                   mm_registry=mm_registry)

        # OutputProcessor (convert EngineCoreOutputs --> RequestOutput).
        self.output_processor = OutputProcessor(self.tokenizer,
                                                log_stats=False)

        # EngineCore (gets EngineCoreRequests and gives EngineCoreOutputs).
        self.engine_core = EngineCoreClient.make_client(
            multiprocess_mode=multiprocess_mode,
            asyncio_mode=False,
            vllm_config=vllm_config,
            executor_class=executor_class,
            log_stats=False,  # FIXME: implement
        )

        if not multiprocess_mode:
            # For V0 compatibility: expose the in-process model executor.
            self.model_executor = self.engine_core.engine_core.model_executor  # type: ignore

    @classmethod
    def from_engine_args(
        cls,
        engine_args: EngineArgs,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[dict[str, StatLoggerBase]] = None,
        enable_multiprocessing: bool = False,
    ) -> "LLMEngine":
        """Creates an LLM engine from the engine arguments.

        Args:
            engine_args: Arguments used to build the :class:`VllmConfig`.
            usage_context: Passed to ``create_engine_config`` and to the
                constructor.
            stat_loggers: Passed through to the constructor.
            enable_multiprocessing: Run the engine core in another process.
                Forced to True when ``envs.VLLM_ENABLE_V1_MULTIPROCESSING``
                is set.

        Returns:
            A new :class:`LLMEngine`.
        """
        # Create the engine configs.
        vllm_config = engine_args.create_engine_config(usage_context)
        executor_class = Executor.get_class(vllm_config)

        # The environment flag can only turn multiprocessing on, never off.
        if envs.VLLM_ENABLE_V1_MULTIPROCESSING:
            logger.debug("Enabling multiprocessing for LLMEngine.")
            enable_multiprocessing = True

        # Create the LLMEngine.
        return cls(vllm_config=vllm_config,
                   executor_class=executor_class,
                   log_stats=not engine_args.disable_log_stats,
                   usage_context=usage_context,
                   stat_loggers=stat_loggers,
                   multiprocess_mode=enable_multiprocessing)

    def get_num_unfinished_requests(self) -> int:
        """Return the number of requests tracked by the OutputProcessor."""
        return self.output_processor.get_num_unfinished_requests()

    def has_unfinished_requests(self) -> bool:
        """Return whether any request is still in flight.

        With data parallelism enabled, the answer is aggregated across the
        DP group (see :meth:`has_unfinished_requests_dp`).
        """
        has_unfinished = self.output_processor.has_unfinished_requests()
        if not self.dp_enabled:
            return has_unfinished
        return self.has_unfinished_requests_dp(has_unfinished)

    def has_unfinished_requests_dp(self, has_unfinished: bool) -> bool:
        """Aggregate the local unfinished flag across the DP group.

        If this rank has nothing to do but the group as a whole still has
        work, flag that the next :meth:`step` must run a dummy batch instead
        of pulling outputs.

        Args:
            has_unfinished: Whether this rank has unfinished requests.

        Returns:
            The value returned by ``ParallelConfig.has_unfinished_dp``.
        """
        aggregated_has_unfinished = ParallelConfig.has_unfinished_dp(
            self.dp_group, has_unfinished)
        if not has_unfinished and aggregated_has_unfinished:
            self.should_execute_dummy_batch = True
        return aggregated_has_unfinished

    @classmethod
    def validate_outputs(cls, outputs, output_type):
        """Return ``outputs`` unchanged (no validation is performed here)."""
        return outputs

    def abort_request(self, request_ids: Union[str, list[str]]) -> None:
        """Remove requests from the EngineCore and the OutputProcessor.

        Args:
            request_ids: The IDs of the requests to abort. A single request
                ID string is also accepted and is wrapped in a list, so it
                is not iterated character by character.
        """
        if isinstance(request_ids, str):
            request_ids = [request_ids]

        self.engine_core.abort_requests(request_ids)
        self.output_processor.abort_requests(request_ids)

    def add_request(
        self,
        request_id: str,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> None:
        """Process a prompt and submit it to the engine core.

        For :class:`SamplingParams` with ``n > 1`` the request is fanned out
        into ``n`` child requests via :class:`ParentRequest`. Each child is
        registered with the OutputProcessor and sent to the EngineCore.

        Args:
            request_id: ID of the (parent) request.
            prompt: The prompt to process.
            params: Sampling or pooling parameters.
            arrival_time: Forwarded to ``Processor.process_inputs``.
            lora_request: Forwarded to ``Processor.process_inputs``.
            trace_headers: Forwarded to ``Processor.process_inputs``.
            prompt_adapter_request: Forwarded to
                ``Processor.process_inputs``.
            priority: Forwarded to ``Processor.process_inputs``.
        """
        # 1) Fan out child requests (for n>1).
        parent_req = ParentRequest.from_params(request_id, params)
        n = params.n if isinstance(params, SamplingParams) else 1
        for idx in range(n):
            if parent_req is None:
                child_id, child_params = request_id, params
            else:
                child_id, child_params = parent_req.get_child_info(idx)

            # 2) Process raw inputs into the request.
            request = self.processor.process_inputs(child_id, prompt,
                                                    child_params,
                                                    arrival_time,
                                                    lora_request,
                                                    trace_headers,
                                                    prompt_adapter_request,
                                                    priority)

            # 3) Make a new RequestState and queue.
            self.output_processor.add_request(request, parent_req, idx)

            # 4) Add the request to EngineCore.
            self.engine_core.add_request(request)

    def step(self) -> list[RequestOutput]:
        """Run one engine iteration.

        Returns:
            The ``request_outputs`` produced by the OutputProcessor for this
            iteration, or an empty list if a DP dummy batch was executed
            instead.
        """
        # Another DP rank still has work: keep this rank in lock-step by
        # running a dummy batch rather than pulling outputs.
        if self.should_execute_dummy_batch:
            self.should_execute_dummy_batch = False
            self.engine_core.execute_dummy_batch()
            return []

        # 1) Get EngineCoreOutput from the EngineCore.
        outputs = self.engine_core.get_output()

        # 2) Process EngineCoreOutputs.
        processed_outputs = self.output_processor.process_outputs(
            outputs.outputs)

        # 3) Abort any reqs that finished due to stop strings.
        self.engine_core.abort_requests(processed_outputs.reqs_to_abort)

        return processed_outputs.request_outputs

    def get_model_config(self):
        """Return the model config taken from ``vllm_config``."""
        return self.model_config

    def start_profile(self):
        """Start profiling in the engine core."""
        self.engine_core.profile(True)

    def stop_profile(self):
        """Stop profiling in the engine core."""
        self.engine_core.profile(False)

    def reset_prefix_cache(self):
        """Ask the engine core to reset its prefix cache."""
        self.engine_core.reset_prefix_cache()

    def sleep(self, level: int = 1):
        """Forward a sleep request with the given level to the engine core."""
        self.engine_core.sleep(level)

    def wake_up(self):
        """Forward a wake-up request to the engine core."""
        self.engine_core.wake_up()

    def get_tokenizer_group(
        self,
        group_type: type[_G] = BaseTokenizerGroup,
    ) -> _G:
        """Return the tokenizer group, checked against ``group_type``.

        Raises:
            ValueError: If no tokenizer is available.
            TypeError: If the tokenizer group is not a ``group_type``.
        """
        tokenizer_group = self.tokenizer

        if tokenizer_group is None:
            raise ValueError("Unable to get tokenizer because "
                             "skip_tokenizer_init is True")
        if not isinstance(tokenizer_group, group_type):
            raise TypeError("Invalid type of tokenizer group. "
                            f"Expected type: {group_type}, but "
                            f"found type: {type(tokenizer_group)}")

        return tokenizer_group

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Load a new LoRA adapter into the engine for future requests."""
        return self.engine_core.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        """Remove an already loaded LoRA adapter."""
        return self.engine_core.remove_lora(lora_id)

    def list_loras(self) -> set[int]:
        """List all registered adapters."""
        return self.engine_core.list_loras()

    def pin_lora(self, lora_id: int) -> bool:
        """Prevent an adapter from being evicted."""
        return self.engine_core.pin_lora(lora_id)