[Bugfix] Avoid transferring cached multi-modal items from P0 to P1 (#16273)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent 24f6b9a713
commit e484e02857
@@ -2,6 +2,7 @@

 import enum
 import time
+from collections.abc import Sequence
 from typing import Any, Optional, Union

 import msgspec
@@ -52,7 +53,7 @@ class EngineCoreRequest(
     # Detokenizer, but set to None when it is added to EngineCoreClient.
     prompt: Optional[str]
     prompt_token_ids: list[int]
-    mm_inputs: Optional[list[MultiModalKwargs]]
+    mm_inputs: Optional[Sequence[Optional[MultiModalKwargs]]]
     mm_hashes: Optional[list[str]]
     mm_placeholders: Optional[list[PlaceholderRange]]
     sampling_params: SamplingParams
@@ -31,7 +31,7 @@ from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.core.sched.scheduler import Scheduler as V1Scheduler
 from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
                             EngineCoreRequestType, UtilityOutput)
-from vllm.v1.engine.mm_input_cache import MMInputCacheServer
+from vllm.v1.engine.mm_input_cache import MirroredProcessingCache
 from vllm.v1.executor.abstract import Executor
 from vllm.v1.kv_cache_interface import KVCacheConfig
 from vllm.v1.outputs import ModelRunnerOutput
@@ -105,7 +105,7 @@ class EngineCore:
         )

         # Setup MM Input Mapper.
-        self.mm_input_cache_server = MMInputCacheServer(
+        self.mm_input_cache_server = MirroredProcessingCache(
             vllm_config.model_config)

         # Setup batch queue for pipeline parallelism.
@@ -173,7 +173,7 @@ class EngineCore:
             # anything that has a hash must have a HIT cache entry here
             # as well.
             assert request.mm_inputs is not None
-            request.mm_inputs = self.mm_input_cache_server.get_and_update(
+            request.mm_inputs = self.mm_input_cache_server.get_and_update_p1(
                 request.mm_inputs, request.mm_hashes)

         req = Request.from_engine_core_request(request)
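The assertion in this hunk relies on the P0 and P1 caches staying in lockstep: both sides keep an LRU cache of the same size (VLLM_MM_INPUT_CACHE_GIB) keyed by mm_hash and observe the same request stream, so any item P0 chose not to resend must still be resident on P1. A minimal self-contained sketch of that invariant (toy `TinyLRU` class and a made-up hash stream; not vLLM code):

```python
from collections import OrderedDict


class TinyLRU:
    """Toy LRU cache keyed by mm_hash (a stand-in for the real caches)."""

    def __init__(self, capacity: int) -> None:
        self.capacity = capacity
        self.data: OrderedDict = OrderedDict()

    def __contains__(self, key: str) -> bool:
        if key in self.data:
            self.data.move_to_end(key)  # accessing an entry refreshes it
            return True
        return False

    def put(self, key: str, value) -> None:
        self.data[key] = value
        self.data.move_to_end(key)
        if len(self.data) > self.capacity:
            self.data.popitem(last=False)  # evict the least recently used


# Both processes are configured with the same capacity and observe the same
# stream of mm_hashes, so their contents evolve identically.
p0_cache, p1_cache = TinyLRU(capacity=2), TinyLRU(capacity=2)

for mm_hash in ["img_a", "img_b", "img_a", "img_c", "img_a"]:
    if mm_hash in p0_cache:
        # P0 would send None instead of the tensors; P1 must already have them.
        assert mm_hash in p1_cache
    else:
        p0_cache.put(mm_hash, "mm_kwargs")
    # P1 mirrors every update it receives (or refreshes the entry it holds).
    if mm_hash not in p1_cache:
        p1_cache.put(mm_hash, "mm_kwargs")
```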
@@ -1,8 +1,11 @@
 # SPDX-License-Identifier: Apache-2.0
+from collections.abc import Sequence
+from typing import Optional
+
 from vllm.envs import VLLM_MM_INPUT_CACHE_GIB
 from vllm.multimodal import MultiModalKwargs
 from vllm.multimodal.processing import ProcessingCache
 from vllm.utils import is_list_of

 # The idea of multimodal preprocessing caching is based on having a client and
 # a server, where the client executes in the frontend process (=P0) and the
@@ -11,9 +14,11 @@ from vllm.multimodal.processing import ProcessingCache
 # -- Client:
 #  - BaseMultiModalProcessor to process MultiModalData into MultiModalKwargs
 #    with built-in caching functionality, with mm_hash as its identifier.
+#  - MirroredProcessingCache to keep track of the cached entries and
+#    determine whether to send the MultiModalKwargs to P1.
 #
 # -- Server:
-#  - MMInputCacheServer to perform caching of the received MultiModalKwargs.
+#  - MirroredProcessingCache to store the MultiModalKwargs from P0.
 #
 # The caching for both client and server is mirrored, and this allows us
 # to avoid the serialization of "mm_inputs" (like pixel values) between
@@ -25,26 +30,48 @@ from vllm.multimodal.processing import ProcessingCache
 # variable VLLM_MM_INPUT_CACHE_GIB.


-class MMInputCacheServer:
+class MirroredProcessingCache:

     def __init__(self, model_config):
         self.use_cache = not model_config.disable_mm_preprocessor_cache
         self.mm_cache = ProcessingCache.get_lru_cache(VLLM_MM_INPUT_CACHE_GIB,
                                                       MultiModalKwargs)

-    def get_and_update(
+    def get_and_update_p0(
         self,
-        mm_inputs: list[MultiModalKwargs],
+        mm_inputs: Sequence[MultiModalKwargs],
         mm_hashes: list[str],
-    ) -> list[MultiModalKwargs]:
+    ) -> Sequence[Optional[MultiModalKwargs]]:
         assert len(mm_inputs) == len(mm_hashes)

         if not self.use_cache:
             assert is_list_of(mm_inputs, MultiModalKwargs)
             return mm_inputs

-        full_mm_inputs = []
+        full_mm_inputs = list[Optional[MultiModalKwargs]]()
         for mm_input, mm_hash in zip(mm_inputs, mm_hashes):
             if mm_hash in self.mm_cache:
                 mm_input = None
             else:
                 self.mm_cache[mm_hash] = mm_input

             full_mm_inputs.append(mm_input)

         return full_mm_inputs
+
+    def get_and_update_p1(
+        self,
+        mm_inputs: Sequence[Optional[MultiModalKwargs]],
+        mm_hashes: list[str],
+    ) -> Sequence[MultiModalKwargs]:
+        assert len(mm_inputs) == len(mm_hashes)
+
+        if not self.use_cache:
+            assert is_list_of(mm_inputs, MultiModalKwargs)
+            return mm_inputs
+
+        full_mm_inputs = list[MultiModalKwargs]()
+        for mm_input, mm_hash in zip(mm_inputs, mm_hashes):
+            assert mm_hash is not None
+            if mm_input is None:
+                mm_input = self.mm_cache[mm_hash]
+            else:
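Together, the two methods above implement the transfer-avoidance handshake on each side of the process boundary. A rough end-to-end sketch of the round trip, with plain dicts standing in for MultiModalKwargs and for the real LRU caches (toy stand-ins, not the vLLM classes in this diff):

```python
# Frontend-side (P0) and engine-core-side (P1) caches; in vLLM both are
# same-size LRU caches, so their contents stay mirrored.
p0_seen: dict = {}
p1_seen: dict = {}


def p0_get_and_update(mm_inputs: list, mm_hashes: list) -> list:
    """P0: replace items the server is known to have with None before sending."""
    out = []
    for mm_input, mm_hash in zip(mm_inputs, mm_hashes):
        if mm_hash in p0_seen:
            out.append(None)  # already transferred once; skip the heavy payload
        else:
            p0_seen[mm_hash] = mm_input
            out.append(mm_input)
    return out


def p1_get_and_update(mm_inputs: list, mm_hashes: list) -> list:
    """P1: fill the None entries back in from the mirrored server-side cache."""
    out = []
    for mm_input, mm_hash in zip(mm_inputs, mm_hashes):
        if mm_input is None:
            mm_input = p1_seen[mm_hash]  # guaranteed hit when the caches mirror
        else:
            p1_seen[mm_hash] = mm_input
        out.append(mm_input)
    return out


# Two requests referencing the same image: only the first transfer carries it.
image = {"pixel_values": [0.1, 0.2]}
first = p0_get_and_update([image], ["hash_img"])   # -> [image]
second = p0_get_and_update([image], ["hash_img"])  # -> [None]

assert p1_get_and_update(first, ["hash_img"]) == [image]
assert p1_get_and_update(second, ["hash_img"]) == [image]
```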
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0

 import time
-from collections.abc import Mapping
+from collections.abc import Mapping, Sequence
 from typing import Literal, Optional, Union

 from vllm.config import VllmConfig
@@ -19,6 +19,7 @@ from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sampling_params import SamplingParams
 from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
 from vllm.v1.engine import EngineCoreRequest
+from vllm.v1.engine.mm_input_cache import MirroredProcessingCache
 from vllm.v1.structured_output.backend_guidance import (
     validate_guidance_grammar)
 from vllm.v1.structured_output.utils import (
@@ -47,6 +48,8 @@ class Processor:
             self.tokenizer,
             mm_registry)

+        self.mm_input_cache_client = MirroredProcessingCache(self.model_config)
+
         # Multi-modal hasher (for images)
         self.use_hash = (
             not self.model_config.disable_mm_preprocessor_cache) or \
@@ -231,7 +234,7 @@ class Processor:
             self.tokenizer.get_lora_tokenizer(lora_request))

         # Multimodal related.
-        sorted_mm_inputs: Optional[list[MultiModalKwargs]] = None
+        sorted_mm_inputs: Optional[Sequence[Optional[MultiModalKwargs]]] = None
         sorted_mm_positions: Optional[list[PlaceholderRange]] = None
         sorted_mm_hashes: Optional[list[str]] = None
         if decoder_inputs["type"] == "multimodal":
@@ -256,20 +259,28 @@ class Processor:
             # are multiple modalities.
             unique_modalities = set(sorted_item_modalities)
             if len(unique_modalities) > 1:
-                sorted_mm_inputs = []
+                orig_sorted_mm_inputs = []
                 used_indices = {modality: 0 for modality in unique_modalities}

                 for modality in sorted_item_modalities:
                     items = decoder_mm_inputs.get_items(modality)
                     item = items[used_indices[modality]]
-                    sorted_mm_inputs.append(MultiModalKwargs.from_items([item
-                                                                         ]))
+
+                    orig_sorted_mm_inputs.append(
+                        MultiModalKwargs.from_items([item]))
                     used_indices[modality] += 1
             else:
-                sorted_mm_inputs = [
+                orig_sorted_mm_inputs = [
                     MultiModalKwargs.from_items([item]) for item in
                     decoder_mm_inputs.get_items(sorted_item_modalities[0])
                 ]

+            if sorted_mm_hashes is not None:
+                sorted_mm_inputs = self.mm_input_cache_client.get_and_update_p0(
+                    orig_sorted_mm_inputs, sorted_mm_hashes)
+            else:
+                sorted_mm_inputs = orig_sorted_mm_inputs
+
         return EngineCoreRequest(
             request_id=request_id,
             prompt=decoder_inputs.get("prompt"),
@@ -3,17 +3,16 @@
 import enum
 from typing import TYPE_CHECKING, Optional, Union

+from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
 from vllm.sampling_params import SamplingParams
+from vllm.utils import is_list_of
 from vllm.v1.engine import (EngineCoreEvent, EngineCoreEventType,
                             EngineCoreRequest, FinishReason)
 from vllm.v1.structured_output.request import StructuredOutputRequest
 from vllm.v1.utils import ConstantList

 if TYPE_CHECKING:
-
     from vllm.lora.request import LoRARequest
-    from vllm.multimodal import MultiModalKwargs
-    from vllm.multimodal.inputs import PlaceholderRange


 class Request:
@@ -23,9 +22,9 @@ class Request:
         request_id: str,
         prompt: Optional[str],
         prompt_token_ids: list[int],
-        multi_modal_inputs: Optional[list["MultiModalKwargs"]],
+        multi_modal_inputs: Optional[list[MultiModalKwargs]],
         multi_modal_hashes: Optional[list[str]],
-        multi_modal_placeholders: Optional[list["PlaceholderRange"]],
+        multi_modal_placeholders: Optional[list[PlaceholderRange]],
         sampling_params: SamplingParams,
         eos_token_id: Optional[int],
         arrival_time: float,
@@ -75,6 +74,11 @@ class Request:

     @classmethod
     def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request":
+        if request.mm_inputs is not None:
+            assert isinstance(request.mm_inputs, list)
+            assert is_list_of(request.mm_inputs, MultiModalKwargs), (
+                "mm_inputs was not updated in EngineCore.add_request")
+
         return cls(
             request_id=request.request_id,
             prompt=request.prompt,