vllm/vllm/v1/engine/llm_engine.py
Tyler Michael Smith 28b3a1c7e5
[V1] Multiprocessing Tensor Parallel Support for v1 (#9856)
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
2024-12-10 06:28:14 +00:00

199 lines
7.3 KiB
Python

from typing import Dict, List, Mapping, Optional, Type, Union
from typing_extensions import TypeVar
from vllm.config import VllmConfig
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.metrics_types import StatLoggerBase
from vllm.envs import VLLM_ENABLE_V1_MULTIPROCESSING
from vllm.inputs import INPUT_REGISTRY, InputRegistry, PromptType
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.outputs import RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer_group import (
BaseTokenizerGroup, init_tokenizer_from_configs)
from vllm.usage.usage_lib import UsageContext
from vllm.v1.engine.core_client import EngineCoreClient
from vllm.v1.engine.detokenizer import Detokenizer
from vllm.v1.engine.processor import Processor
from vllm.v1.executor.abstract import Executor
# Module-level logger, created via init_logger with this module's name.
logger = init_logger(__name__)
# Type variable for tokenizer-group classes: bounded by BaseTokenizerGroup and
# defaulting to BaseTokenizerGroup when not explicitly parameterized (the
# ``default=`` argument is accepted by the typing_extensions TypeVar imported
# above). Used as the parameter/return type of LLMEngine.get_tokenizer_group.
_G = TypeVar("_G", bound=BaseTokenizerGroup, default=BaseTokenizerGroup)
class LLMEngine:
    """Legacy LLMEngine for backwards compatibility.

    Wires together four components: a tokenizer group, a ``Processor``
    (inputs --> EngineCoreRequests), a ``Detokenizer`` (EngineCoreOutputs -->
    RequestOutput) and an ``EngineCoreClient`` used to talk to the EngineCore.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        executor_class: Type[Executor],
        log_stats: bool,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
        input_registry: InputRegistry = INPUT_REGISTRY,
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
        use_cached_outputs: bool = False,
        multiprocess_mode: bool = False,
    ) -> None:
        """Build the tokenizer, processor, detokenizer and engine-core client.

        Args:
            vllm_config: Source of the model, scheduler, parallel and LoRA
                configs used below.
            executor_class: Executor type forwarded to the EngineCoreClient.
            log_stats: Accepted for interface compatibility; not read by this
                constructor.
            usage_context: Forwarded to the EngineCoreClient.
            stat_loggers: Accepted for interface compatibility; not read by
                this constructor.
            input_registry: Forwarded to the Processor.
            mm_registry: Forwarded to the Processor.
            use_cached_outputs: Accepted for interface compatibility; not
                read by this constructor.
            multiprocess_mode: Forwarded to ``EngineCoreClient.make_client``.
        """
        # TODO: Can we avoid this?
        self.model_config = vllm_config.model_config

        # Tokenizer (+ ensure liveness if running in another process).
        self.tokenizer = init_tokenizer_from_configs(
            model_config=vllm_config.model_config,
            scheduler_config=vllm_config.scheduler_config,
            parallel_config=vllm_config.parallel_config,
            lora_config=vllm_config.lora_config)
        self.tokenizer.ping()

        # Processor (convert Inputs --> EngineCoreRequests)
        self.processor = Processor(vllm_config.model_config,
                                   vllm_config.lora_config, self.tokenizer,
                                   input_registry, mm_registry)

        # Detokenizer (converts EngineCoreOutputs --> RequestOutput)
        self.detokenizer = Detokenizer(
            tokenizer_name=vllm_config.model_config.tokenizer,
            tokenizer_mode=vllm_config.model_config.tokenizer_mode,
            trust_remote_code=vllm_config.model_config.trust_remote_code,
            revision=vllm_config.model_config.tokenizer_revision,
        )

        # EngineCore (gets EngineCoreRequests and gives EngineCoreOutputs)
        self.engine_core = EngineCoreClient.make_client(
            vllm_config,
            executor_class,
            usage_context,
            multiprocess_mode=multiprocess_mode,
            asyncio_mode=False,
        )

    @classmethod
    def from_engine_args(
        cls,
        engine_args: EngineArgs,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
        enable_multiprocessing: bool = False,
    ) -> "LLMEngine":
        """Creates an LLM engine from the engine arguments.

        Args:
            engine_args: Used to create the engine config and to derive
                ``log_stats`` (the negation of ``disable_log_stats``).
            usage_context: Passed to config creation and to the constructor.
            stat_loggers: Passed through to the constructor.
            enable_multiprocessing: Requested multiprocess mode. Forced to
                True when ``VLLM_ENABLE_V1_MULTIPROCESSING`` is truthy.

        Returns:
            A new instance of ``cls``.
        """
        # Create the engine configs.
        vllm_config = engine_args.create_engine_config(usage_context)
        executor_class = cls._get_executor_cls(vllm_config)

        # The environment flag can only turn multiprocessing on; it never
        # turns off an explicit enable_multiprocessing=True.
        if VLLM_ENABLE_V1_MULTIPROCESSING:
            logger.debug("Enabling multiprocessing for LLMEngine.")
            enable_multiprocessing = True

        # Create the LLMEngine.
        return cls(vllm_config=vllm_config,
                   executor_class=executor_class,
                   log_stats=not engine_args.disable_log_stats,
                   usage_context=usage_context,
                   stat_loggers=stat_loggers,
                   multiprocess_mode=enable_multiprocessing)

    @classmethod
    def _get_executor_cls(cls, vllm_config: VllmConfig):
        """Select the executor class for the configured distributed backend.

        Returns ``MultiprocExecutor`` when
        ``parallel_config.distributed_executor_backend`` is ``"mp"`` and
        ``UniprocExecutor`` when it is ``None``.

        Raises:
            ValueError: If the backend is anything other than ``"mp"`` or
                ``None``.
        """
        distributed_executor_backend = (
            vllm_config.parallel_config.distributed_executor_backend)

        # Executor modules are imported lazily so only the selected one is
        # loaded.
        if distributed_executor_backend == "mp":
            from vllm.v1.executor.multiproc_executor import MultiprocExecutor
            return MultiprocExecutor

        if distributed_executor_backend is None:
            from vllm.v1.executor.uniproc_executor import UniprocExecutor
            return UniprocExecutor

        # Raise explicitly instead of using `assert`: assertions are stripped
        # under `python -O`, which would silently select UniprocExecutor for
        # an unsupported backend.
        raise ValueError(
            "Unsupported distributed_executor_backend: "
            f"{distributed_executor_backend!r}. "
            "Expected 'mp' or None.")

    def get_num_unfinished_requests(self) -> int:
        """Return the number of unfinished requests per the Detokenizer."""
        return self.detokenizer.get_num_unfinished_requests()

    def has_unfinished_requests(self) -> bool:
        """Return whether the Detokenizer reports unfinished requests."""
        return self.detokenizer.has_unfinished_requests()

    @classmethod
    def validate_outputs(cls, outputs, output_type):
        """Return ``outputs`` unchanged; ``output_type`` is not inspected."""
        return outputs

    def abort_request(self, request_ids: List[str]) -> None:
        """Remove request_ids from EngineCore and Detokenizer."""
        self.engine_core.abort_requests(request_ids)
        self.detokenizer.abort_requests(request_ids)

    def add_request(
        self,
        request_id: str,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> None:
        """Process a raw request and register it with both components.

        All arguments are forwarded positionally, in declaration order, to
        ``Processor.process_inputs``; the two requests it returns are handed
        to the Detokenizer and then to the EngineCore.
        """
        # 1) Process raw inputs into the request.
        detokenizer_req, engine_core_req = self.processor.process_inputs(
            request_id, prompt, params, arrival_time, lora_request,
            trace_headers, prompt_adapter_request, priority)

        # 2) Add the request to Detokenizer.
        self.detokenizer.add_request(detokenizer_req)

        # 3) Add the request to EngineCore.
        self.engine_core.add_request(engine_core_req)

    def step(self) -> List[RequestOutput]:
        """Fetch EngineCore output, detokenize it and return the results.

        Any requests the Detokenizer reports as needing an abort are aborted
        (in both EngineCore and Detokenizer) before returning.
        """
        # 1) Get EngineCoreOutput from the EngineCore.
        engine_core_outputs = self.engine_core.get_output()

        # 2) Detokenize the EngineCoreOutput.
        request_outputs, requests_to_abort = self.detokenizer.step(
            engine_core_outputs)

        # 3) Abort requests that finished due to stopping criteria.
        if requests_to_abort:
            self.abort_request(requests_to_abort)

        return request_outputs

    # TODO(rob): Can we get rid of these?

    def get_model_config(self):
        """Return the model config captured at construction time."""
        return self.model_config

    def start_profile(self):
        """Ask the EngineCore to start profiling (``profile(True)``)."""
        self.engine_core.profile(True)

    def stop_profile(self):
        """Ask the EngineCore to stop profiling (``profile(False)``)."""
        self.engine_core.profile(False)

    def get_tokenizer_group(
        self,
        group_type: Type[_G] = BaseTokenizerGroup,
    ) -> _G:
        """Return the tokenizer group, checked against ``group_type``.

        Raises:
            ValueError: If ``self.tokenizer`` is None.
            TypeError: If the tokenizer group is not an instance of
                ``group_type``.
        """
        tokenizer_group = self.tokenizer

        if tokenizer_group is None:
            raise ValueError("Unable to get tokenizer because "
                             "skip_tokenizer_init is True")
        if not isinstance(tokenizer_group, group_type):
            raise TypeError("Invalid type of tokenizer group. "
                            f"Expected type: {group_type}, but "
                            f"found type: {type(tokenizer_group)}")

        return tokenizer_group