diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py
index 0557d0c6..1264e43c 100644
--- a/vllm/v1/engine/__init__.py
+++ b/vllm/v1/engine/__init__.py
@@ -2,6 +2,7 @@
 
 import enum
 import time
+from collections.abc import Sequence
 from typing import Any, Optional, Union
 
 import msgspec
@@ -52,7 +53,7 @@ class EngineCoreRequest(
     # Detokenizer, but set to None when it is added to EngineCoreClient.
     prompt: Optional[str]
     prompt_token_ids: list[int]
-    mm_inputs: Optional[list[MultiModalKwargs]]
+    mm_inputs: Optional[Sequence[Optional[MultiModalKwargs]]]
    mm_hashes: Optional[list[str]]
     mm_placeholders: Optional[list[PlaceholderRange]]
     sampling_params: SamplingParams
diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py
index f58c77e4..077d4998 100644
--- a/vllm/v1/engine/core.py
+++ b/vllm/v1/engine/core.py
@@ -31,7 +31,7 @@ from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.core.sched.scheduler import Scheduler as V1Scheduler
 from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
                             EngineCoreRequestType, UtilityOutput)
-from vllm.v1.engine.mm_input_cache import MMInputCacheServer
+from vllm.v1.engine.mm_input_cache import MirroredProcessingCache
 from vllm.v1.executor.abstract import Executor
 from vllm.v1.kv_cache_interface import KVCacheConfig
 from vllm.v1.outputs import ModelRunnerOutput
@@ -105,7 +105,7 @@ class EngineCore:
         )
 
         # Setup MM Input Mapper.
-        self.mm_input_cache_server = MMInputCacheServer(
+        self.mm_input_cache_server = MirroredProcessingCache(
             vllm_config.model_config)
 
         # Setup batch queue for pipeline parallelism.
@@ -173,7 +173,7 @@ class EngineCore:
             # anything that has a hash must have a HIT cache entry here
             # as well.
             assert request.mm_inputs is not None
-            request.mm_inputs = self.mm_input_cache_server.get_and_update(
+            request.mm_inputs = self.mm_input_cache_server.get_and_update_p1(
                 request.mm_inputs, request.mm_hashes)
 
         req = Request.from_engine_core_request(request)
diff --git a/vllm/v1/engine/mm_input_cache.py b/vllm/v1/engine/mm_input_cache.py
index 61a55d24..ef5a2e5a 100644
--- a/vllm/v1/engine/mm_input_cache.py
+++ b/vllm/v1/engine/mm_input_cache.py
@@ -1,8 +1,11 @@
 # SPDX-License-Identifier: Apache-2.0
+from collections.abc import Sequence
+from typing import Optional
 
 from vllm.envs import VLLM_MM_INPUT_CACHE_GIB
 from vllm.multimodal import MultiModalKwargs
 from vllm.multimodal.processing import ProcessingCache
+from vllm.utils import is_list_of
 
 # The idea of multimodal preprocessing caching is based on having a client and
 # a server, where the client executes in the frontend process (=P0) and the
@@ -11,9 +14,11 @@ from vllm.multimodal.processing import ProcessingCache
 # -- Client:
 #  - BaseMultiModalProcessor to process MultiModalData into MultiModalKwargs
 #    with built-in caching functionality, with mm_hash as its identifier.
+#  - MirroredProcessingCache to keep track of the cached entries and
+#    determine whether to send the MultiModalKwargs to P1.
 #
 # -- Server:
-#  - MMInputCacheServer to perform caching of the received MultiModalKwargs.
+#  - MirroredProcessingCache to store the MultiModalKwargs from P0.
 #
 # The caching for both client and server is mirrored, and this allows us
 # to avoid the serialization of "mm_inputs" (like pixel values) between
@@ -25,26 +30,48 @@ from vllm.multimodal.processing import ProcessingCache
 # variable VLLM_MM_INPUT_CACHE_GIB.
 
 
-class MMInputCacheServer:
+class MirroredProcessingCache:
 
     def __init__(self, model_config):
         self.use_cache = not model_config.disable_mm_preprocessor_cache
         self.mm_cache = ProcessingCache.get_lru_cache(VLLM_MM_INPUT_CACHE_GIB,
                                                       MultiModalKwargs)
 
-    def get_and_update(
+    def get_and_update_p0(
         self,
-        mm_inputs: list[MultiModalKwargs],
+        mm_inputs: Sequence[MultiModalKwargs],
         mm_hashes: list[str],
-    ) -> list[MultiModalKwargs]:
+    ) -> Sequence[Optional[MultiModalKwargs]]:
         assert len(mm_inputs) == len(mm_hashes)
 
         if not self.use_cache:
+            assert is_list_of(mm_inputs, MultiModalKwargs)
             return mm_inputs
 
-        full_mm_inputs = []
+        full_mm_inputs = list[Optional[MultiModalKwargs]]()
+        for mm_input, mm_hash in zip(mm_inputs, mm_hashes):
+            if mm_hash in self.mm_cache:
+                mm_input = None
+            else:
+                self.mm_cache[mm_hash] = mm_input
+
+            full_mm_inputs.append(mm_input)
+
+        return full_mm_inputs
+
+    def get_and_update_p1(
+        self,
+        mm_inputs: Sequence[Optional[MultiModalKwargs]],
+        mm_hashes: list[str],
+    ) -> Sequence[MultiModalKwargs]:
+        assert len(mm_inputs) == len(mm_hashes)
+
+        if not self.use_cache:
+            assert is_list_of(mm_inputs, MultiModalKwargs)
+            return mm_inputs
+
+        full_mm_inputs = list[MultiModalKwargs]()
         for mm_input, mm_hash in zip(mm_inputs, mm_hashes):
-            assert mm_hash is not None
             if mm_input is None:
                 mm_input = self.mm_cache[mm_hash]
             else:
diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py
index bc5c53b8..5f9c8ea4 100644
--- a/vllm/v1/engine/processor.py
+++ b/vllm/v1/engine/processor.py
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import time
-from collections.abc import Mapping
+from collections.abc import Mapping, Sequence
 from typing import Literal, Optional, Union
 
 from vllm.config import VllmConfig
@@ -19,6 +19,7 @@ from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sampling_params import SamplingParams
 from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
 from vllm.v1.engine import EngineCoreRequest
+from vllm.v1.engine.mm_input_cache import MirroredProcessingCache
 from vllm.v1.structured_output.backend_guidance import (
     validate_guidance_grammar)
 from vllm.v1.structured_output.utils import (
@@ -47,6 +48,8 @@ class Processor:
                                                     self.tokenizer,
                                                     mm_registry)
 
+        self.mm_input_cache_client = MirroredProcessingCache(self.model_config)
+
         # Multi-modal hasher (for images)
         self.use_hash = (
             not self.model_config.disable_mm_preprocessor_cache) or \
@@ -231,7 +234,7 @@ class Processor:
             self.tokenizer.get_lora_tokenizer(lora_request))
 
         # Multimodal related.
-        sorted_mm_inputs: Optional[list[MultiModalKwargs]] = None
+        sorted_mm_inputs: Optional[Sequence[Optional[MultiModalKwargs]]] = None
         sorted_mm_positions: Optional[list[PlaceholderRange]] = None
         sorted_mm_hashes: Optional[list[str]] = None
         if decoder_inputs["type"] == "multimodal":
@@ -256,20 +259,28 @@ class Processor:
             # are multiple modalities.
             unique_modalities = set(sorted_item_modalities)
             if len(unique_modalities) > 1:
-                sorted_mm_inputs = []
+                orig_sorted_mm_inputs = []
                 used_indices = {modality: 0 for modality in unique_modalities}
+
                 for modality in sorted_item_modalities:
                     items = decoder_mm_inputs.get_items(modality)
                     item = items[used_indices[modality]]
-                    sorted_mm_inputs.append(MultiModalKwargs.from_items([item
-                                                                         ]))
+
+                    orig_sorted_mm_inputs.append(
+                        MultiModalKwargs.from_items([item]))
                     used_indices[modality] += 1
             else:
-                sorted_mm_inputs = [
+                orig_sorted_mm_inputs = [
                     MultiModalKwargs.from_items([item]) for item in
                     decoder_mm_inputs.get_items(sorted_item_modalities[0])
                 ]
 
+            if sorted_mm_hashes is not None:
+                sorted_mm_inputs = self.mm_input_cache_client.get_and_update_p0(
+                    orig_sorted_mm_inputs, sorted_mm_hashes)
+            else:
+                sorted_mm_inputs = orig_sorted_mm_inputs
+
         return EngineCoreRequest(
             request_id=request_id,
             prompt=decoder_inputs.get("prompt"),
diff --git a/vllm/v1/request.py b/vllm/v1/request.py
index daf59fd7..6be72431 100644
--- a/vllm/v1/request.py
+++ b/vllm/v1/request.py
@@ -3,16 +3,16 @@
 import enum
 from typing import TYPE_CHECKING, Optional, Union
 
+from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
 from vllm.sampling_params import SamplingParams
+from vllm.utils import is_list_of
 from vllm.v1.engine import (EngineCoreEvent, EngineCoreEventType,
                             EngineCoreRequest, FinishReason)
 from vllm.v1.structured_output.request import StructuredOutputRequest
 from vllm.v1.utils import ConstantList
 
 if TYPE_CHECKING:
     from vllm.lora.request import LoRARequest
-    from vllm.multimodal import MultiModalKwargs
-    from vllm.multimodal.inputs import PlaceholderRange
 
 
 class Request:
@@ -23,9 +23,9 @@ class Request:
         request_id: str,
         prompt: Optional[str],
         prompt_token_ids: list[int],
-        multi_modal_inputs: Optional[list["MultiModalKwargs"]],
+        multi_modal_inputs: Optional[list[MultiModalKwargs]],
         multi_modal_hashes: Optional[list[str]],
-        multi_modal_placeholders: Optional[list["PlaceholderRange"]],
+        multi_modal_placeholders: Optional[list[PlaceholderRange]],
         sampling_params: SamplingParams,
         eos_token_id: Optional[int],
         arrival_time: float,
@@ -75,6 +75,11 @@ class Request:
 
     @classmethod
     def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request":
+        if request.mm_inputs is not None:
+            assert isinstance(request.mm_inputs, list)
+            assert is_list_of(request.mm_inputs, MultiModalKwargs), (
+                "mm_inputs was not updated in EngineCore.add_request")
+
         return cls(
             request_id=request.request_id,
             prompt=request.prompt,
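
To make the mirrored-cache protocol behind get_and_update_p0/get_and_update_p1 concrete, here is a minimal, self-contained sketch of the P0/P1 round trip. It substitutes a small OrderedDict-based LRU for vLLM's ProcessingCache and plain dicts for MultiModalKwargs; TinyLRU and MirroredCache are illustrative names only, not part of the vLLM API. The key invariant is that both processes apply the same insert/lookup/evict rule to the same stream of mm_hashes, so P0 can safely send None for any entry that P1 is guaranteed to already hold, and the heavy tensors are serialized across the process boundary at most once per cache lifetime.

# mirrored_cache_sketch.py -- illustrative only; TinyLRU and MirroredCache
# are stand-ins for vLLM's ProcessingCache and MirroredProcessingCache, and
# plain dicts stand in for MultiModalKwargs.
from collections import OrderedDict
from typing import Optional


class TinyLRU:
    """A tiny LRU cache keyed by mm_hash (stand-in for ProcessingCache)."""

    def __init__(self, capacity: int = 4):
        self.capacity = capacity
        self._data = OrderedDict()  # mm_hash -> kwargs dict

    def __contains__(self, key: str) -> bool:
        if key in self._data:
            self._data.move_to_end(key)  # a hit also refreshes recency
            return True
        return False

    def __getitem__(self, key: str) -> dict:
        self._data.move_to_end(key)
        return self._data[key]

    def __setitem__(self, key: str, value: dict) -> None:
        self._data[key] = value
        self._data.move_to_end(key)
        if len(self._data) > self.capacity:
            self._data.popitem(last=False)  # evict the least recently used


class MirroredCache:
    """Both P0 and P1 run the same update rule, so their caches stay in sync."""

    def __init__(self):
        self.mm_cache = TinyLRU()

    # P0 (frontend): replace inputs whose hash is already cached with None,
    # so only first-seen inputs are serialized to the engine process.
    def get_and_update_p0(self, mm_inputs, mm_hashes):
        out: list[Optional[dict]] = []
        for mm_input, mm_hash in zip(mm_inputs, mm_hashes):
            if mm_hash in self.mm_cache:
                out.append(None)  # P1 already holds this entry
            else:
                self.mm_cache[mm_hash] = mm_input
                out.append(mm_input)  # first sighting: send it once
        return out

    # P1 (engine core): restore None placeholders from the mirrored cache,
    # and store any newly received inputs for future requests.
    def get_and_update_p1(self, mm_inputs, mm_hashes):
        out: list[dict] = []
        for mm_input, mm_hash in zip(mm_inputs, mm_hashes):
            if mm_input is None:
                mm_input = self.mm_cache[mm_hash]
            else:
                self.mm_cache[mm_hash] = mm_input
            out.append(mm_input)
        return out


if __name__ == "__main__":
    p0, p1 = MirroredCache(), MirroredCache()
    image = {"pixel_values": [0.1, 0.2, 0.3]}  # stand-in for heavy tensors

    # First request: the image crosses the process boundary once.
    sent = p0.get_and_update_p0([image], ["hash-a"])
    assert sent == [image]
    assert p1.get_and_update_p1(sent, ["hash-a"]) == [image]

    # Repeat request with the same image: only a None placeholder is sent,
    # and P1 reconstructs the full input from its mirrored cache.
    sent = p0.get_and_update_p0([image], ["hash-a"])
    assert sent == [None]
    assert p1.get_and_update_p1(sent, ["hash-a"]) == [image]
    print("mirrored cache round trip OK")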