[Bugfix] Avoid transferring cached multi-modal items from P0 to P1 (#16273)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung 2025-04-09 15:51:27 +08:00 committed by GitHub
parent 24f6b9a713
commit e484e02857
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 65 additions and 22 deletions

View File

@ -2,6 +2,7 @@
import enum
import time
from collections.abc import Sequence
from typing import Any, Optional, Union
import msgspec
@ -52,7 +53,7 @@ class EngineCoreRequest(
# Detokenizer, but set to None when it is added to EngineCoreClient.
prompt: Optional[str]
prompt_token_ids: list[int]
mm_inputs: Optional[list[MultiModalKwargs]]
mm_inputs: Optional[Sequence[Optional[MultiModalKwargs]]]
mm_hashes: Optional[list[str]]
mm_placeholders: Optional[list[PlaceholderRange]]
sampling_params: SamplingParams

View File

@ -31,7 +31,7 @@ from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.core.sched.scheduler import Scheduler as V1Scheduler
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
EngineCoreRequestType, UtilityOutput)
from vllm.v1.engine.mm_input_cache import MMInputCacheServer
from vllm.v1.engine.mm_input_cache import MirroredProcessingCache
from vllm.v1.executor.abstract import Executor
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.outputs import ModelRunnerOutput
@ -105,7 +105,7 @@ class EngineCore:
)
# Setup MM Input Mapper.
self.mm_input_cache_server = MMInputCacheServer(
self.mm_input_cache_server = MirroredProcessingCache(
vllm_config.model_config)
# Setup batch queue for pipeline parallelism.
@ -173,7 +173,7 @@ class EngineCore:
# anything that has a hash must have a HIT cache entry here
# as well.
assert request.mm_inputs is not None
request.mm_inputs = self.mm_input_cache_server.get_and_update(
request.mm_inputs = self.mm_input_cache_server.get_and_update_p1(
request.mm_inputs, request.mm_hashes)
req = Request.from_engine_core_request(request)

View File

@ -1,8 +1,11 @@
# SPDX-License-Identifier: Apache-2.0
from collections.abc import Sequence
from typing import Optional
from vllm.envs import VLLM_MM_INPUT_CACHE_GIB
from vllm.multimodal import MultiModalKwargs
from vllm.multimodal.processing import ProcessingCache
from vllm.utils import is_list_of
# The idea of multimodal preprocessing caching is based on having a client and
# a server, where the client executes in the frontend process (=P0) and the
@ -11,9 +14,11 @@ from vllm.multimodal.processing import ProcessingCache
# -- Client:
# - BaseMultiModalProcessor to process MultiModalData into MultiModalKwargs
# with built-in caching functionality, with mm_hash as its identifier.
# - MirroredProcessingCache to keep track of the cached entries and
# determine whether to send the MultiModalKwargs to P1.
#
# -- Server:
# - MMInputCacheServer to perform caching of the received MultiModalKwargs.
# - MirroredProcessingCache to store the MultiModalKwargs from P0.
#
# The caching for both client and server is mirrored, and this allows us
# to avoid the serialization of "mm_inputs" (like pixel values) between
@ -25,26 +30,48 @@ from vllm.multimodal.processing import ProcessingCache
# variable VLLM_MM_INPUT_CACHE_GIB.
class MMInputCacheServer:
class MirroredProcessingCache:
def __init__(self, model_config):
self.use_cache = not model_config.disable_mm_preprocessor_cache
self.mm_cache = ProcessingCache.get_lru_cache(VLLM_MM_INPUT_CACHE_GIB,
MultiModalKwargs)
def get_and_update(
def get_and_update_p0(
self,
mm_inputs: list[MultiModalKwargs],
mm_inputs: Sequence[MultiModalKwargs],
mm_hashes: list[str],
) -> list[MultiModalKwargs]:
) -> Sequence[Optional[MultiModalKwargs]]:
assert len(mm_inputs) == len(mm_hashes)
if not self.use_cache:
assert is_list_of(mm_inputs, MultiModalKwargs)
return mm_inputs
full_mm_inputs = []
full_mm_inputs = list[Optional[MultiModalKwargs]]()
for mm_input, mm_hash in zip(mm_inputs, mm_hashes):
if mm_hash in self.mm_cache:
mm_input = None
else:
self.mm_cache[mm_hash] = mm_input
full_mm_inputs.append(mm_input)
return full_mm_inputs
def get_and_update_p1(
self,
mm_inputs: Sequence[Optional[MultiModalKwargs]],
mm_hashes: list[str],
) -> Sequence[MultiModalKwargs]:
assert len(mm_inputs) == len(mm_hashes)
if not self.use_cache:
assert is_list_of(mm_inputs, MultiModalKwargs)
return mm_inputs
full_mm_inputs = list[MultiModalKwargs]()
for mm_input, mm_hash in zip(mm_inputs, mm_hashes):
assert mm_hash is not None
if mm_input is None:
mm_input = self.mm_cache[mm_hash]
else:

View File

@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
import time
from collections.abc import Mapping
from collections.abc import Mapping, Sequence
from typing import Literal, Optional, Union
from vllm.config import VllmConfig
@ -19,6 +19,7 @@ from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.mm_input_cache import MirroredProcessingCache
from vllm.v1.structured_output.backend_guidance import (
validate_guidance_grammar)
from vllm.v1.structured_output.utils import (
@ -47,6 +48,8 @@ class Processor:
self.tokenizer,
mm_registry)
self.mm_input_cache_client = MirroredProcessingCache(self.model_config)
# Multi-modal hasher (for images)
self.use_hash = (
not self.model_config.disable_mm_preprocessor_cache) or \
@ -231,7 +234,7 @@ class Processor:
self.tokenizer.get_lora_tokenizer(lora_request))
# Multimodal related.
sorted_mm_inputs: Optional[list[MultiModalKwargs]] = None
sorted_mm_inputs: Optional[Sequence[Optional[MultiModalKwargs]]] = None
sorted_mm_positions: Optional[list[PlaceholderRange]] = None
sorted_mm_hashes: Optional[list[str]] = None
if decoder_inputs["type"] == "multimodal":
@ -256,20 +259,28 @@ class Processor:
# are multiple modalities.
unique_modalities = set(sorted_item_modalities)
if len(unique_modalities) > 1:
sorted_mm_inputs = []
orig_sorted_mm_inputs = []
used_indices = {modality: 0 for modality in unique_modalities}
for modality in sorted_item_modalities:
items = decoder_mm_inputs.get_items(modality)
item = items[used_indices[modality]]
sorted_mm_inputs.append(MultiModalKwargs.from_items([item
]))
orig_sorted_mm_inputs.append(
MultiModalKwargs.from_items([item]))
used_indices[modality] += 1
else:
sorted_mm_inputs = [
orig_sorted_mm_inputs = [
MultiModalKwargs.from_items([item]) for item in
decoder_mm_inputs.get_items(sorted_item_modalities[0])
]
if sorted_mm_hashes is not None:
sorted_mm_inputs = self.mm_input_cache_client.get_and_update_p0(
orig_sorted_mm_inputs, sorted_mm_hashes)
else:
sorted_mm_inputs = orig_sorted_mm_inputs
return EngineCoreRequest(
request_id=request_id,
prompt=decoder_inputs.get("prompt"),

View File

@ -3,17 +3,16 @@
import enum
from typing import TYPE_CHECKING, Optional, Union
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.sampling_params import SamplingParams
from vllm.utils import is_list_of
from vllm.v1.engine import (EngineCoreEvent, EngineCoreEventType,
EngineCoreRequest, FinishReason)
from vllm.v1.structured_output.request import StructuredOutputRequest
from vllm.v1.utils import ConstantList
if TYPE_CHECKING:
from vllm.lora.request import LoRARequest
from vllm.multimodal import MultiModalKwargs
from vllm.multimodal.inputs import PlaceholderRange
class Request:
@ -23,9 +22,9 @@ class Request:
request_id: str,
prompt: Optional[str],
prompt_token_ids: list[int],
multi_modal_inputs: Optional[list["MultiModalKwargs"]],
multi_modal_inputs: Optional[list[MultiModalKwargs]],
multi_modal_hashes: Optional[list[str]],
multi_modal_placeholders: Optional[list["PlaceholderRange"]],
multi_modal_placeholders: Optional[list[PlaceholderRange]],
sampling_params: SamplingParams,
eos_token_id: Optional[int],
arrival_time: float,
@ -75,6 +74,11 @@ class Request:
@classmethod
def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request":
if request.mm_inputs is not None:
assert isinstance(request.mm_inputs, list)
assert is_list_of(request.mm_inputs, MultiModalKwargs), (
"mm_inputs was not updated in EngineCore.add_request")
return cls(
request_id=request.request_id,
prompt=request.prompt,