[V1] VLM preprocessor hashing (#11020)

Signed-off-by: Roger Wang <ywang@roblox.com>
Signed-off-by: Alexander Matveev <alexm@neuralmagic.com>
Co-authored-by: Michael Goin <michael@neuralmagic.com>
Co-authored-by: Roger Wang <ywang@roblox.com>
Alexander Matveev 2024-12-11 19:55:30 -05:00 committed by GitHub
parent 452a723bf2
commit 4e11683368
11 changed files with 332 additions and 48 deletions
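
The commit threads a new mm_cache_preprocessor option through the LLM constructor and the example scripts below. A minimal usage sketch, modeled on the LLaVA example changes in this diff (the gray dummy image, prompt, and generation settings are illustrative, not part of the commit):

from PIL import Image

from vllm import LLM, SamplingParams

# Enable caching of the multi-modal preprocessor/mapper.
llm = LLM(model="llava-hf/llava-1.5-7b-hf",
          max_model_len=4096,
          mm_cache_preprocessor=True)

# Reusing the same image across prompts should hit the preprocessor cache.
image = Image.new("RGB", (336, 336))
inputs = [{
    "prompt": "USER: <image>\nDescribe the image.\nASSISTANT:",
    "multi_modal_data": {"image": image},
} for _ in range(4)]

outputs = llm.generate(inputs, SamplingParams(max_tokens=32))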


@ -5,6 +5,8 @@ the correct prompt format on vision language models for text generation.
For most models, the prompt format should follow corresponding examples
on HuggingFace model repository.
"""
import random
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
@ -23,7 +25,9 @@ def run_llava(question: str, modality: str):
prompt = f"USER: <image>\n{question}\nASSISTANT:"
llm = LLM(model="llava-hf/llava-1.5-7b-hf", max_model_len=4096)
llm = LLM(model="llava-hf/llava-1.5-7b-hf",
max_model_len=4096,
mm_cache_preprocessor=args.mm_cache_preprocessor)
stop_token_ids = None
return llm, prompt, stop_token_ids
@ -33,7 +37,9 @@ def run_llava_next(question: str, modality: str):
assert modality == "image"
prompt = f"[INST] <image>\n{question} [/INST]"
llm = LLM(model="llava-hf/llava-v1.6-mistral-7b-hf", max_model_len=8192)
llm = LLM(model="llava-hf/llava-v1.6-mistral-7b-hf",
max_model_len=8192,
mm_cache_preprocessor=args.mm_cache_preprocessor)
stop_token_ids = None
return llm, prompt, stop_token_ids
@ -44,7 +50,9 @@ def run_llava_next_video(question: str, modality: str):
assert modality == "video"
prompt = f"USER: <video>\n{question} ASSISTANT:"
llm = LLM(model="llava-hf/LLaVA-NeXT-Video-7B-hf", max_model_len=8192)
llm = LLM(model="llava-hf/LLaVA-NeXT-Video-7B-hf",
max_model_len=8192,
mm_cache_preprocessor=args.mm_cache_preprocessor)
stop_token_ids = None
return llm, prompt, stop_token_ids
@ -61,7 +69,8 @@ def run_llava_onevision(question: str, modality: str):
<|im_start|>assistant\n"
llm = LLM(model="llava-hf/llava-onevision-qwen2-7b-ov-hf",
max_model_len=16384)
max_model_len=16384,
mm_cache_preprocessor=args.mm_cache_preprocessor)
stop_token_ids = None
return llm, prompt, stop_token_ids
@ -71,7 +80,10 @@ def run_fuyu(question: str, modality: str):
assert modality == "image"
prompt = f"{question}\n"
llm = LLM(model="adept/fuyu-8b", max_model_len=2048, max_num_seqs=2)
llm = LLM(model="adept/fuyu-8b",
max_model_len=2048,
max_num_seqs=2,
mm_cache_preprocessor=args.mm_cache_preprocessor)
stop_token_ids = None
return llm, prompt, stop_token_ids
@ -107,6 +119,7 @@ def run_phi3v(question: str, modality: str):
max_num_seqs=2,
# Note - mm_processor_kwargs can also be passed to generate/chat calls
mm_processor_kwargs={"num_crops": 16},
mm_cache_preprocessor=args.mm_cache_preprocessor,
)
stop_token_ids = None
return llm, prompt, stop_token_ids
@ -118,7 +131,8 @@ def run_paligemma(question: str, modality: str):
# PaliGemma has special prompt format for VQA
prompt = "caption en"
llm = LLM(model="google/paligemma-3b-mix-224")
llm = LLM(model="google/paligemma-3b-mix-224",
mm_cache_preprocessor=args.mm_cache_preprocessor)
stop_token_ids = None
return llm, prompt, stop_token_ids
@ -128,7 +142,9 @@ def run_chameleon(question: str, modality: str):
assert modality == "image"
prompt = f"{question}<image>"
llm = LLM(model="facebook/chameleon-7b", max_model_len=4096)
llm = LLM(model="facebook/chameleon-7b",
max_model_len=4096,
mm_cache_preprocessor=args.mm_cache_preprocessor)
stop_token_ids = None
return llm, prompt, stop_token_ids
@ -154,6 +170,7 @@ def run_minicpmv(question: str, modality: str):
max_model_len=4096,
max_num_seqs=2,
trust_remote_code=True,
mm_cache_preprocessor=args.mm_cache_preprocessor,
)
# NOTE The stop_token_ids are different for various versions of MiniCPM-V
# 2.0
@ -186,6 +203,7 @@ def run_h2ovl(question: str, modality: str):
model=model_name,
trust_remote_code=True,
max_model_len=8192,
mm_cache_preprocessor=args.mm_cache_preprocessor,
)
tokenizer = AutoTokenizer.from_pretrained(model_name,
@ -211,6 +229,7 @@ def run_internvl(question: str, modality: str):
model=model_name,
trust_remote_code=True,
max_model_len=4096,
mm_cache_preprocessor=args.mm_cache_preprocessor,
)
tokenizer = AutoTokenizer.from_pretrained(model_name,
@ -241,6 +260,7 @@ def run_nvlm_d(question: str, modality: str):
trust_remote_code=True,
max_model_len=4096,
tensor_parallel_size=4,
mm_cache_preprocessor=args.mm_cache_preprocessor,
)
tokenizer = AutoTokenizer.from_pretrained(model_name,
@ -260,7 +280,8 @@ def run_blip2(question: str, modality: str):
# BLIP-2 prompt format is inaccurate on HuggingFace model repository.
# See https://huggingface.co/Salesforce/blip2-opt-2.7b/discussions/15#64ff02f3f8cf9e4f5b038262 #noqa
prompt = f"Question: {question} Answer:"
llm = LLM(model="Salesforce/blip2-opt-2.7b")
llm = LLM(model="Salesforce/blip2-opt-2.7b",
mm_cache_preprocessor=args.mm_cache_preprocessor)
stop_token_ids = None
return llm, prompt, stop_token_ids
@ -274,6 +295,7 @@ def run_qwen_vl(question: str, modality: str):
trust_remote_code=True,
max_model_len=1024,
max_num_seqs=2,
mm_cache_preprocessor=args.mm_cache_preprocessor,
)
prompt = f"{question}Picture 1: <img></img>\n"
@ -296,6 +318,7 @@ def run_qwen2_vl(question: str, modality: str):
"min_pixels": 28 * 28,
"max_pixels": 1280 * 28 * 28,
},
mm_cache_preprocessor=args.mm_cache_preprocessor,
)
prompt = ("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
@ -315,6 +338,7 @@ def run_pixtral_hf(question: str, modality: str):
llm = LLM(
model=model_name,
max_model_len=8192,
mm_cache_preprocessor=args.mm_cache_preprocessor,
)
prompt = f"<s>[INST]{question}\n[IMG][/INST]"
@ -338,6 +362,7 @@ def run_mllama(question: str, modality: str):
max_model_len=4096,
max_num_seqs=16,
enforce_eager=True,
mm_cache_preprocessor=args.mm_cache_preprocessor,
)
prompt = f"<|image|><|begin_of_text|>{question}"
@ -355,6 +380,7 @@ def run_molmo(question, modality):
model=model_name,
trust_remote_code=True,
dtype="bfloat16",
mm_cache_preprocessor=args.mm_cache_preprocessor,
)
prompt = question
@ -371,7 +397,8 @@ def run_glm4v(question: str, modality: str):
max_model_len=2048,
max_num_seqs=2,
trust_remote_code=True,
enforce_eager=True)
enforce_eager=True,
mm_cache_preprocessor=args.mm_cache_preprocessor)
prompt = question
stop_token_ids = [151329, 151336, 151338]
return llm, prompt, stop_token_ids
@ -394,6 +421,7 @@ def run_idefics3(question: str, modality: str):
"longest_edge": 3 * 364
},
},
mm_cache_preprocessor=args.mm_cache_preprocessor,
)
prompt = (
f"<|begin_of_text|>User:<image>{question}<end_of_utterance>\nAssistant:"
@ -410,7 +438,8 @@ def run_aria(question: str, modality: str):
llm = LLM(model=model_name,
tokenizer_mode="slow",
trust_remote_code=True,
dtype="bfloat16")
dtype="bfloat16",
mm_cache_preprocessor=args.mm_cache_preprocessor)
prompt = (f"<|im_start|>user\n<fim_prefix><|img|><fim_suffix>\n{question}"
"<|im_end|>\n<|im_start|>assistant\n")
@ -430,6 +459,7 @@ def run_mantis(question: str, modality: str):
model="TIGER-Lab/Mantis-8B-siglip-llama3",
max_model_len=4096,
hf_overrides={"architectures": ["MantisForConditionalGeneration"]},
mm_cache_preprocessor=args.mm_cache_preprocessor,
)
stop_token_ids = [128009]
return llm, prompt, stop_token_ids
@ -494,6 +524,35 @@ def get_multi_modal_input(args):
raise ValueError(msg)
def apply_image_repeat(image_repeat_prob, num_prompts, data, prompt, modality):
"""Repeats images with provided probability of "image_repeat_prob".
Used to simulate hit/miss for the MM preprocessor cache.
"""
assert (image_repeat_prob <= 1.0 and image_repeat_prob >= 0)
no_yes = [0, 1]
probs = [1.0 - image_repeat_prob, image_repeat_prob]
inputs = []
cur_image = data
for i in range(num_prompts):
if image_repeat_prob is not None:
res = random.choices(no_yes, probs)[0]
if res == 0:
# No repeat => Modify one pixel
cur_image = cur_image.copy()
new_val = (i // 256 // 256, i // 256, i % 256)
cur_image.putpixel((0, 0), new_val)
inputs.append({
"prompt": prompt,
"multi_modal_data": {
modality: cur_image
}
})
return inputs
def main(args):
model = args.model_type
if model not in model_example_map:
@ -524,14 +583,29 @@ def main(args):
else:
# Batch inference
inputs = [{
"prompt": prompt,
"multi_modal_data": {
modality: data
},
} for _ in range(args.num_prompts)]
if args.image_repeat_prob is not None:
# Repeat images with the specified probability of "image_repeat_prob"
inputs = apply_image_repeat(args.image_repeat_prob,
args.num_prompts, data, prompt,
modality)
else:
# Use the same image for all prompts
inputs = [{
"prompt": prompt,
"multi_modal_data": {
modality: data
},
} for _ in range(args.num_prompts)]
outputs = llm.generate(inputs, sampling_params=sampling_params)
if args.time_generate:
import time
start_time = time.time()
outputs = llm.generate(inputs, sampling_params=sampling_params)
elapsed_time = time.time() - start_time
print("-- generate time = {}".format(elapsed_time))
else:
outputs = llm.generate(inputs, sampling_params=sampling_params)
for o in outputs:
generated_text = o.outputs[0].text
@ -561,5 +635,23 @@ if __name__ == "__main__":
type=int,
default=16,
help='Number of frames to extract from the video.')
parser.add_argument(
'--image-repeat-prob',
type=float,
default=None,
help='Simulates the hit ratio of the multi-modal preprocessor cache'
' (if enabled).')
parser.add_argument(
'--mm-cache-preprocessor',
action='store_true',
help='If set, enables caching of the multi-modal preprocessor/mapper.')
parser.add_argument(
'--time-generate',
action='store_true',
help='If set, print the total generate() call time.')
args = parser.parse_args()
main(args)


@ -3,6 +3,7 @@ sentencepiece # Required for LLaMA tokenizer.
numpy < 2.0.0
requests >= 2.26.0
tqdm
blake3
py-cpuinfo
transformers >= 4.45.2 # Required for Llama 3.2 and Qwen2-VL.
tokenizers >= 0.19.1 # Required for Llama 3.


@ -28,6 +28,7 @@ def make_request() -> EngineCoreRequest:
prompt=PROMPT,
prompt_token_ids=PROMPT_TOKENS,
mm_inputs=None,
mm_hashes=None,
mm_placeholders=None,
sampling_params=SamplingParams(),
eos_token_id=None,


@ -30,6 +30,7 @@ def make_request(params: SamplingParams) -> EngineCoreRequest:
prompt=PROMPT,
prompt_token_ids=PROMPT_TOKENS,
mm_inputs=None,
mm_hashes=None,
mm_placeholders=None,
sampling_params=params,
eos_token_id=None,


@ -147,6 +147,9 @@ class ModelConfig:
HuggingFace config.
mm_processor_kwargs: Arguments to be forwarded to the model's processor
for multi-modal data, e.g., image processor.
mm_cache_preprocessor: If true, enables caching of the multi-modal
preprocessor/mapper. Otherwise, the mapper executes on every request;
for better performance, consider enabling the frontend process.
override_neuron_config: Initialize non default neuron config or
override default neuron config that are specific to Neuron devices,
this argument will be used to configure the neuron config that
@ -185,6 +188,7 @@ class ModelConfig:
config_format: ConfigFormat = ConfigFormat.AUTO,
hf_overrides: Optional[HfOverrides] = None,
mm_processor_kwargs: Optional[Dict[str, Any]] = None,
mm_cache_preprocessor: bool = False,
override_neuron_config: Optional[Dict[str, Any]] = None,
override_pooler_config: Optional["PoolerConfig"] = None) -> None:
self.model = model
@ -251,6 +255,7 @@ class ModelConfig:
self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
self.use_async_output_proc = use_async_output_proc
self.mm_processor_kwargs = mm_processor_kwargs
self.mm_cache_preprocessor = mm_cache_preprocessor
# Set enforce_eager to False if the value is unset.
if self.enforce_eager is None:
@ -2686,9 +2691,10 @@ class VllmConfig:
f"enable_prefix_caching={self.cache_config.enable_prefix_caching}, "
f"chunked_prefill_enabled={self.scheduler_config.chunked_prefill_enabled}, " # noqa
f"use_async_output_proc={self.model_config.use_async_output_proc}, "
f"mm_cache_preprocessor={self.model_config.mm_cache_preprocessor!r}, " # noqa
f"mm_processor_kwargs={self.model_config.mm_processor_kwargs}, "
f"pooler_config={self.model_config.pooler_config!r},"
f" compilation_config={self.compilation_config!r}")
f"pooler_config={self.model_config.pooler_config!r}, "
f"compilation_config={self.compilation_config!r}")
_current_vllm_config: Optional[VllmConfig] = None


@ -143,6 +143,7 @@ class EngineArgs:
tokenizer_pool_extra_config: Optional[Dict[str, Any]] = None
limit_mm_per_prompt: Optional[Mapping[str, int]] = None
mm_processor_kwargs: Optional[Dict[str, Any]] = None
mm_cache_preprocessor: bool = False
enable_lora: bool = False
enable_lora_bias: bool = False
max_loras: int = 1
@ -593,6 +594,12 @@ class EngineArgs:
type=json.loads,
help=('Overrides for the multimodal input mapping/processing, '
'e.g., image processor. For example: {"num_crops": 4}.'))
parser.add_argument(
'--mm-cache-preprocessor',
action='store_true',
help='If set, enables caching of the multi-modal '
'preprocessor/mapper. Otherwise, the mapper executes on every '
'request; for better performance, consider enabling the frontend '
'process.')
# LoRA related configs
parser.add_argument('--enable-lora',
@ -965,6 +972,7 @@ class EngineArgs:
use_async_output_proc=not self.disable_async_output_proc,
config_format=self.config_format,
mm_processor_kwargs=self.mm_processor_kwargs,
mm_cache_preprocessor=self.mm_cache_preprocessor,
override_neuron_config=self.override_neuron_config,
override_pooler_config=self.override_pooler_config,
)


@ -35,7 +35,8 @@ class EngineCoreRequest:
# always be tokenized?
prompt: Optional[str]
prompt_token_ids: List[int]
mm_inputs: Optional[List[MultiModalKwargs]]
mm_inputs: Optional[List[Optional[MultiModalKwargs]]]
mm_hashes: Optional[List[Optional[str]]]
mm_placeholders: Optional[MultiModalPlaceholderDict]
sampling_params: SamplingParams
eos_token_id: Optional[int]


@ -18,7 +18,7 @@ from vllm.v1.core.scheduler import Scheduler
from vllm.v1.engine import (EngineCoreOutput, EngineCoreOutputs,
EngineCoreProfile, EngineCoreRequest,
EngineCoreRequestType)
from vllm.v1.engine.mm_input_mapper import MMInputMapper
from vllm.v1.engine.mm_input_mapper import MMInputMapperServer
from vllm.v1.executor.abstract import Executor
from vllm.v1.request import Request, RequestStatus
from vllm.v1.serial_utils import PickleEncoder
@ -55,9 +55,6 @@ class EngineCore:
vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks
vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks
# Set up multimodal input mapper (e.g., convert PIL images to tensors).
self.mm_input_mapper = MMInputMapper(vllm_config.model_config)
# Setup scheduler.
self.scheduler = Scheduler(vllm_config.scheduler_config,
vllm_config.cache_config,
@ -65,6 +62,8 @@ class EngineCore:
self._last_logging_time = time.time()
self.mm_input_mapper_server = MMInputMapperServer()
def _initialize_kv_caches(self,
cache_config: CacheConfig) -> Tuple[int, int]:
start = time.time()
@ -88,7 +87,18 @@ class EngineCore:
def add_request(self, request: EngineCoreRequest):
"""Add request to the scheduler."""
if request.mm_hashes is not None:
# Here, if an input was omitted by the client (a client-side cache
# hit), it is fetched from the cache by its hash; otherwise it is
# added to the cache.
# Note that the cache here is mirrored with the client side of the
# MM mapper, so any omitted input must have a HIT cache entry here
# as well.
request.mm_inputs = self.mm_input_mapper_server.process_inputs(
request.mm_inputs, request.mm_hashes)
req = Request.from_engine_core_request(request)
self.scheduler.add_request(req)
def abort_requests(self, request_ids: List[str]):


@ -1,11 +1,35 @@
from typing import Any, Dict, List, Optional, Tuple
import PIL
from blake3 import blake3
from vllm.config import ModelConfig
from vllm.inputs import PromptType
from vllm.logger import init_logger
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalDataDict,
MultiModalKwargs, MultiModalRegistry)
from vllm.v1.utils import LRUDictCache
logger = init_logger(__name__)
# The idea of MM preprocessor caching is based on having a client and a server,
# where the client executes in the frontend process (=P0) and the server in the
# core process (=P1).
#
# -- Client: Executes the MM mapper and performs caching of the results.
# -- Server: Performs caching of the results.
#
# The caching on the client and the server is mirrored, which allows us to
# avoid serializing the "mm_inputs" (e.g., pixel values) between the
# client (=P0) and server (=P1) processes.
# Both the client and the server must use the same cache size
# (to keep the caches mirrored).
# TODO: Tune the MM cache size
MM_CACHE_SIZE = 256
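
As a rough illustration of the mirrored-cache protocol described in the comment above, here is a self-contained toy sketch (the ToyClient/ToyServer classes are hypothetical stand-ins, not the vLLM classes defined below): the client only ships a payload the first time a hash is seen, and the server reconstructs omitted payloads from its own mirrored cache.

from collections import OrderedDict


class ToyClient:
    """Caches processed inputs and omits payloads the server already has."""

    def __init__(self, size=4):
        self.cache = OrderedDict()
        self.size = size

    def process(self, value, key):
        if key in self.cache:
            self.cache.move_to_end(key)
            return None, key          # cache hit: send only the hash
        self.cache[key] = value
        if len(self.cache) > self.size:
            self.cache.popitem(last=False)
        return value, key             # cache miss: send the full payload


class ToyServer:
    """Mirrors the client cache and fills in omitted payloads."""

    def __init__(self, size=4):
        self.cache = OrderedDict()
        self.size = size

    def process(self, value, key):
        if value is None:
            return self.cache[key]    # must be a hit, mirrored with the client
        self.cache[key] = value
        if len(self.cache) > self.size:
            self.cache.popitem(last=False)
        return value


client, server = ToyClient(), ToyServer()
payload, key = client.process("pixel_values", "hash-a")   # miss: payload sent
assert server.process(payload, key) == "pixel_values"
payload, key = client.process("pixel_values", "hash-a")   # hit: payload omitted
assert payload is None
assert server.process(payload, key) == "pixel_values"     # recovered from cache
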
class MMInputMapper:
class MMInputMapperClient:
def __init__(
self,
@ -18,23 +42,131 @@ class MMInputMapper:
model_config)
self.mm_registry.init_mm_limits_per_prompt(model_config)
self.mm_cache = LRUDictCache(MM_CACHE_SIZE)
# DEBUG: Set to None to disable
self.mm_debug_cache_hit_ratio_steps = None
self.mm_cache_hits = 0
self.mm_cache_total = 0
def cache_hit_ratio(self, steps: int) -> None:
if self.mm_cache_total > 0 and self.mm_cache_total % steps == 0:
logger.debug("MMInputMapper: cache_hit_ratio = %.2f",
self.mm_cache_hits / self.mm_cache_total)
def process_inputs(
self,
mm_data: MultiModalDataDict,
mm_hashes: Optional[List[str]],
mm_processor_kwargs: Optional[Dict[str, Any]],
precomputed_mm_inputs: Optional[List[MultiModalKwargs]],
) -> Tuple[List[Optional[MultiModalKwargs]], Optional[List[str]]]:
if precomputed_mm_inputs is None:
image_inputs = mm_data["image"]
if not isinstance(image_inputs, list):
image_inputs = [image_inputs]
num_inputs = len(image_inputs)
else:
num_inputs = len(precomputed_mm_inputs)
# Check if hash is enabled
use_hash = mm_hashes is not None
if use_hash:
assert num_inputs == len(
mm_hashes), "num_inputs = {} len(mm_hashes) = {}".format(
num_inputs, len(mm_hashes))
# Process each image input separately, so that later we can schedule
# them in a fine-grained manner.
# Apply caching (if enabled) and reuse precomputed inputs (if provided)
ret_hashes = [] if use_hash else None
ret_inputs: List[MultiModalKwargs] = []
for input_id in range(num_inputs):
if self.mm_debug_cache_hit_ratio_steps is not None:
self.cache_hit_ratio(self.mm_debug_cache_hit_ratio_steps)
mm_hash = None
mm_input = None
if use_hash:
mm_hash = mm_hashes[input_id]
mm_input = self.mm_cache.get(mm_hash)
self.mm_cache_total += 1
if mm_input is None:
if precomputed_mm_inputs is not None:
# Reuse precomputed input (for merged preprocessor)
mm_input = precomputed_mm_inputs[input_id]
else:
# Apply MM mapper
mm_input = self.multi_modal_input_mapper(
{"image": [image_inputs[input_id]]},
mm_processor_kwargs=mm_processor_kwargs,
)
if use_hash:
# Add to cache
self.mm_cache.put(mm_hash, mm_input)
else:
self.mm_cache_hits += 1
mm_input = None # Avoids sending mm_input to Server
if use_hash:
ret_hashes.append(mm_hash)
ret_inputs.append(mm_input)
return ret_inputs, ret_hashes
class MMInputMapperServer:
def __init__(self):
self.mm_cache = LRUDictCache(MM_CACHE_SIZE)
def process_inputs(
self,
mm_inputs: List[Optional[MultiModalKwargs]],
mm_hashes: List[Optional[str]],
) -> List[MultiModalKwargs]:
assert len(mm_inputs) == len(mm_hashes)
full_mm_inputs = []
for mm_input, mm_hash in zip(mm_inputs, mm_hashes):
if mm_input is None:
mm_input = self.mm_cache.get(mm_hash)
assert mm_input is not None
else:
self.mm_cache.put(mm_hash, mm_input)
full_mm_inputs.append(mm_input)
return full_mm_inputs
class MMHasher:
def __init__(self):
pass
def hash(self, prompt: PromptType) -> Optional[List[str]]:
if "multi_modal_data" not in prompt:
return None
mm_data = prompt["multi_modal_data"]
image_inputs = mm_data["image"]
if not isinstance(image_inputs, list):
image_inputs = [image_inputs]
assert len(image_inputs) > 0
# Process each image input separately so that later we can schedule
# them in a fine-grained manner.
mm_inputs: List[MultiModalKwargs] = []
num_images = len(image_inputs)
for i in range(num_images):
mm_input = self.multi_modal_input_mapper(
{"image": image_inputs[i]},
mm_processor_kwargs=mm_processor_kwargs,
)
mm_inputs.append(mm_input)
return mm_inputs
ret = []
for image in image_inputs:
assert isinstance(image, PIL.Image.Image)
# Convert image to bytes
image_bytes = image.tobytes()
# Hash image bytes
hasher = blake3()
hasher.update(image_bytes)
ret.append(hasher.hexdigest())
return ret
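
MMHasher keys the cache on a blake3 digest of the raw image bytes, so identical images map to the same key, while the one-pixel perturbation used by apply_image_repeat in the example script yields a new key. A small standalone sketch of that behavior (the synthetic images are illustrative):

from blake3 import blake3
from PIL import Image


def hash_image(image: Image.Image) -> str:
    hasher = blake3()
    hasher.update(image.tobytes())
    return hasher.hexdigest()


img_a = Image.new("RGB", (64, 64), color=(10, 20, 30))
img_b = img_a.copy()
img_c = img_a.copy()
img_c.putpixel((0, 0), (0, 0, 0))               # the "no repeat" case: modify one pixel

assert hash_image(img_a) == hash_image(img_b)   # identical bytes -> cache hit
assert hash_image(img_a) != hash_image(img_c)   # changed pixel -> cache miss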


@ -15,7 +15,7 @@ from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.config import try_get_generation_config
from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
from vllm.v1.engine import DetokenizerRequest, EngineCoreRequest
from vllm.v1.engine.mm_input_mapper import MMInputMapper
from vllm.v1.engine.mm_input_mapper import MMHasher, MMInputMapperClient
class Processor:
@ -42,7 +42,11 @@ class Processor:
model_config)
# Multi-modal (huggingface) input mapper
self.mm_input_mapper = MMInputMapper(model_config)
self.mm_input_mapper_client = MMInputMapperClient(model_config)
# Multi-modal hasher (for images)
self.mm_hasher = MMHasher(
) if model_config.mm_cache_preprocessor else None
# TODO: run in a ThreadPoolExecutor or BackgroundProcess.
# This ideally should release the GIL, so we should not block the
@ -71,6 +75,11 @@ class Processor:
assert priority == 0, "vLLM V1 does not support priority at the moment."
assert trace_headers is None, "vLLM V1 does not support tracing yet."
# Compute MM hashes (if enabled)
mm_hashes = None
if self.mm_hasher is not None:
mm_hashes = self.mm_hasher.hash(prompt)
# Process inputs.
preprocessed_inputs = self.input_preprocessor.preprocess(
prompt,
@ -101,16 +110,17 @@ class Processor:
sampling_params.update_from_generation_config(
self.generation_config_fields, eos_token_id)
# Preprocess multi-modal data
if len(decoder_inputs.multi_modal_data) == 0:
mm_inputs = None
elif isinstance(decoder_inputs.multi_modal_data, MultiModalKwargs):
mm_inputs = [decoder_inputs.multi_modal_data]
else:
mm_inputs = self.mm_input_mapper.process_inputs(
decoder_inputs.multi_modal_data,
decoder_inputs.mm_processor_kwargs,
)
# For the merged preprocessor, mm_data is already mm_inputs
precomputed_mm_inputs = None
if isinstance(decoder_inputs.multi_modal_data, MultiModalKwargs):
precomputed_mm_inputs = [decoder_inputs.multi_modal_data]
# Apply MM mapper
mm_inputs = None
if len(decoder_inputs.multi_modal_data) > 0:
mm_inputs, mm_hashes = self.mm_input_mapper_client.process_inputs(
decoder_inputs.multi_modal_data, mm_hashes,
decoder_inputs.mm_processor_kwargs, precomputed_mm_inputs)
# Make Request for Detokenizer.
detokenizer_request = DetokenizerRequest(
@ -130,6 +140,7 @@ class Processor:
decoder_inputs.prompt,
decoder_inputs.prompt_token_ids,
mm_inputs,
mm_hashes,
decoder_inputs.multi_modal_placeholders,
sampling_params,
eos_token_id,


@ -1,3 +1,4 @@
from collections import OrderedDict
from contextlib import contextmanager
from typing import Any, Generic, Iterator, List, TypeVar, overload
@ -93,3 +94,23 @@ def make_zmq_socket(path: str, type: Any) -> Iterator[zmq.Socket]:
finally:
ctx.destroy(linger=0)
class LRUDictCache:
def __init__(self, size: int):
self.cache = OrderedDict()
self.size = size
def get(self, key, default=None):
if key not in self.cache:
return default
self.cache.move_to_end(key)
return self.cache[key]
def put(self, key, value):
self.cache[key] = value
self.cache.move_to_end(key)
if len(self.cache) > self.size:
self.cache.popitem(last=False)
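
A short usage sketch of the LRUDictCache above, showing the least-recently-used eviction order (assuming the class is imported from vllm.v1.utils, where this diff adds it):

from vllm.v1.utils import LRUDictCache

cache = LRUDictCache(size=2)
cache.put("a", 1)
cache.put("b", 2)

assert cache.get("a") == 1      # touching "a" makes "b" the LRU entry
cache.put("c", 3)               # exceeds size=2, so "b" is evicted

assert cache.get("b") is None
assert cache.get("a") == 1
assert cache.get("c") == 3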