[V1] VLM - enable processor cache by default (#11305)

Signed-off-by: Alexander Matveev <alexm@neuralmagic.com>
Alexander Matveev authored on 2024-12-18 18:54:46 -05:00, committed by GitHub
parent ca5f54a9b9
commit fdea8ec167
7 changed files with 72 additions and 48 deletions


@@ -28,7 +28,7 @@ def run_aria(question: str, modality: str):
               tokenizer_mode="slow",
               trust_remote_code=True,
               dtype="bfloat16",
-              mm_cache_preprocessor=args.mm_cache_preprocessor)
+              disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
     prompt = (f"<|im_start|>user\n<fim_prefix><|img|><fim_suffix>\n{question}"
               "<|im_end|>\n<|im_start|>assistant\n")
@@ -45,7 +45,7 @@ def run_blip2(question: str, modality: str):
     # See https://huggingface.co/Salesforce/blip2-opt-2.7b/discussions/15#64ff02f3f8cf9e4f5b038262 #noqa
     prompt = f"Question: {question} Answer:"
     llm = LLM(model="Salesforce/blip2-opt-2.7b",
-              mm_cache_preprocessor=args.mm_cache_preprocessor)
+              disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
     stop_token_ids = None
     return llm, prompt, stop_token_ids
@@ -57,7 +57,7 @@ def run_chameleon(question: str, modality: str):
     prompt = f"{question}<image>"
     llm = LLM(model="facebook/chameleon-7b",
               max_model_len=4096,
-              mm_cache_preprocessor=args.mm_cache_preprocessor)
+              disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
     stop_token_ids = None
     return llm, prompt, stop_token_ids
@@ -70,7 +70,7 @@ def run_fuyu(question: str, modality: str):
     llm = LLM(model="adept/fuyu-8b",
               max_model_len=2048,
               max_num_seqs=2,
-              mm_cache_preprocessor=args.mm_cache_preprocessor)
+              disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
     stop_token_ids = None
     return llm, prompt, stop_token_ids
@@ -85,7 +85,7 @@ def run_glm4v(question: str, modality: str):
               max_num_seqs=2,
               trust_remote_code=True,
               enforce_eager=True,
-              mm_cache_preprocessor=args.mm_cache_preprocessor)
+              disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
     prompt = question
     stop_token_ids = [151329, 151336, 151338]
     return llm, prompt, stop_token_ids
@@ -101,7 +101,7 @@ def run_h2ovl(question: str, modality: str):
         model=model_name,
         trust_remote_code=True,
         max_model_len=8192,
-        mm_cache_preprocessor=args.mm_cache_preprocessor,
+        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
     )
     tokenizer = AutoTokenizer.from_pretrained(model_name,
@@ -134,7 +134,7 @@ def run_idefics3(question: str, modality: str):
                 "longest_edge": 3 * 364
             },
         },
-        mm_cache_preprocessor=args.mm_cache_preprocessor,
+        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
     )
     prompt = (
         f"<|begin_of_text|>User:<image>{question}<end_of_utterance>\nAssistant:"
@@ -153,7 +153,7 @@ def run_internvl(question: str, modality: str):
         model=model_name,
         trust_remote_code=True,
         max_model_len=4096,
-        mm_cache_preprocessor=args.mm_cache_preprocessor,
+        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
     )
     tokenizer = AutoTokenizer.from_pretrained(model_name,
@@ -180,7 +180,7 @@ def run_llava(question: str, modality: str):
     llm = LLM(model="llava-hf/llava-1.5-7b-hf",
               max_model_len=4096,
-              mm_cache_preprocessor=args.mm_cache_preprocessor)
+              disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
     stop_token_ids = None
     return llm, prompt, stop_token_ids
@@ -192,7 +192,7 @@ def run_llava_next(question: str, modality: str):
     prompt = f"[INST] <image>\n{question} [/INST]"
     llm = LLM(model="llava-hf/llava-v1.6-mistral-7b-hf",
               max_model_len=8192,
-              mm_cache_preprocessor=args.mm_cache_preprocessor)
+              disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
     stop_token_ids = None
     return llm, prompt, stop_token_ids
@@ -205,7 +205,7 @@ def run_llava_next_video(question: str, modality: str):
     prompt = f"USER: <video>\n{question} ASSISTANT:"
     llm = LLM(model="llava-hf/LLaVA-NeXT-Video-7B-hf",
               max_model_len=8192,
-              mm_cache_preprocessor=args.mm_cache_preprocessor)
+              disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
     stop_token_ids = None
     return llm, prompt, stop_token_ids
@@ -223,7 +223,7 @@ def run_llava_onevision(question: str, modality: str):
     llm = LLM(model="llava-hf/llava-onevision-qwen2-7b-ov-hf",
               max_model_len=16384,
-              mm_cache_preprocessor=args.mm_cache_preprocessor)
+              disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
     stop_token_ids = None
     return llm, prompt, stop_token_ids
@@ -239,7 +239,7 @@ def run_mantis(question: str, modality: str):
         model="TIGER-Lab/Mantis-8B-siglip-llama3",
         max_model_len=4096,
         hf_overrides={"architectures": ["MantisForConditionalGeneration"]},
-        mm_cache_preprocessor=args.mm_cache_preprocessor,
+        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
     )
     stop_token_ids = [128009]
     return llm, prompt, stop_token_ids
@@ -266,7 +266,7 @@ def run_minicpmv(question: str, modality: str):
         max_model_len=4096,
         max_num_seqs=2,
         trust_remote_code=True,
-        mm_cache_preprocessor=args.mm_cache_preprocessor,
+        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
     )
     # NOTE The stop_token_ids are different for various versions of MiniCPM-V
     # 2.0
@@ -305,7 +305,7 @@ def run_mllama(question: str, modality: str):
         max_model_len=4096,
         max_num_seqs=16,
         enforce_eager=True,
-        mm_cache_preprocessor=args.mm_cache_preprocessor,
+        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
     )
     prompt = f"<|image|><|begin_of_text|>{question}"
@@ -323,7 +323,7 @@ def run_molmo(question, modality):
         model=model_name,
         trust_remote_code=True,
         dtype="bfloat16",
-        mm_cache_preprocessor=args.mm_cache_preprocessor,
+        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
     )
     prompt = question
@@ -343,7 +343,7 @@ def run_nvlm_d(question: str, modality: str):
         trust_remote_code=True,
         max_model_len=4096,
         tensor_parallel_size=4,
-        mm_cache_preprocessor=args.mm_cache_preprocessor,
+        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
     )
     tokenizer = AutoTokenizer.from_pretrained(model_name,
@@ -363,7 +363,7 @@ def run_paligemma(question: str, modality: str):
     # PaliGemma has special prompt format for VQA
     prompt = "caption en"
     llm = LLM(model="google/paligemma-3b-mix-224",
-              mm_cache_preprocessor=args.mm_cache_preprocessor)
+              disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
     stop_token_ids = None
     return llm, prompt, stop_token_ids
@@ -375,7 +375,7 @@ def run_paligemma2(question: str, modality: str):
     # PaliGemma 2 has special prompt format for VQA
     prompt = "caption en"
     llm = LLM(model="google/paligemma2-3b-ft-docci-448",
-              mm_cache_preprocessor=args.mm_cache_preprocessor)
+              disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
     stop_token_ids = None
     return llm, prompt, stop_token_ids
@@ -405,7 +405,7 @@ def run_phi3v(question: str, modality: str):
         max_num_seqs=2,
         # Note - mm_processor_kwargs can also be passed to generate/chat calls
         mm_processor_kwargs={"num_crops": 16},
-        mm_cache_preprocessor=args.mm_cache_preprocessor,
+        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
     )
     stop_token_ids = None
     return llm, prompt, stop_token_ids
@@ -420,7 +420,7 @@ def run_pixtral_hf(question: str, modality: str):
     llm = LLM(
         model=model_name,
         max_model_len=8192,
-        mm_cache_preprocessor=args.mm_cache_preprocessor,
+        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
     )
     prompt = f"<s>[INST]{question}\n[IMG][/INST]"
@@ -437,7 +437,7 @@ def run_qwen_vl(question: str, modality: str):
         trust_remote_code=True,
         max_model_len=1024,
         max_num_seqs=2,
-        mm_cache_preprocessor=args.mm_cache_preprocessor,
+        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
     )
     prompt = f"{question}Picture 1: <img></img>\n"
@@ -460,7 +460,7 @@ def run_qwen2_vl(question: str, modality: str):
             "min_pixels": 28 * 28,
             "max_pixels": 1280 * 28 * 28,
         },
-        mm_cache_preprocessor=args.mm_cache_preprocessor,
+        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
     )
     prompt = ("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
@@ -651,9 +651,9 @@ if __name__ == "__main__":
                         ' (if enabled)')
     parser.add_argument(
-        '--mm-cache-preprocessor',
+        '--disable-mm-preprocessor-cache',
         action='store_true',
-        help='If True, enable caching of multi-modal preprocessor/mapper.')
+        help='If True, disables caching of multi-modal preprocessor/mapper.')
     parser.add_argument(
         '--time-generate',
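For context, a minimal usage sketch of the renamed option (illustrative, not part of the commit): after this change the multi-modal preprocessor cache is on by default, and the run_* helpers above simply forward args.disable_mm_preprocessor_cache into the LLM constructor, so opting back out looks like this:

# Illustrative sketch, not part of the diff: the cache stays enabled unless
# the flag is set, mirroring how the run_* helpers pass it through.
from vllm import LLM

llm = LLM(
    model="llava-hf/llava-1.5-7b-hf",
    max_model_len=4096,
    disable_mm_preprocessor_cache=True,  # opt out of the default-on cache
)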


@@ -148,9 +148,8 @@ class ModelConfig:
             HuggingFace config.
         mm_processor_kwargs: Arguments to be forwarded to the model's processor
             for multi-modal data, e.g., image processor.
-        mm_cache_preprocessor: If true, then enables caching of the multi-modal
-            preprocessor/mapper. Otherwise, the mapper executes each time, and
-            for better performance consider enabling frontend process.
+        disable_mm_preprocessor_cache: If true, then disables caching of the
+            multi-modal preprocessor/mapper. (not recommended)
         override_neuron_config: Initialize non default neuron config or
             override default neuron config that are specific to Neuron devices,
             this argument will be used to configure the neuron config that
@@ -216,7 +215,7 @@ class ModelConfig:
                 config_format: ConfigFormat = ConfigFormat.AUTO,
                 hf_overrides: Optional[HfOverrides] = None,
                 mm_processor_kwargs: Optional[Dict[str, Any]] = None,
-                mm_cache_preprocessor: bool = False,
+                disable_mm_preprocessor_cache: bool = False,
                 override_neuron_config: Optional[Dict[str, Any]] = None,
                 override_pooler_config: Optional["PoolerConfig"] = None,
                 logits_processor_pattern: Optional[str] = None) -> None:
@@ -286,7 +285,7 @@ class ModelConfig:
         self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
         self.use_async_output_proc = use_async_output_proc
         self.mm_processor_kwargs = mm_processor_kwargs
-        self.mm_cache_preprocessor = mm_cache_preprocessor
+        self.disable_mm_preprocessor_cache = disable_mm_preprocessor_cache
         # Set enforce_eager to False if the value is unset.
         if self.enforce_eager is None:
@@ -3155,7 +3154,7 @@ class VllmConfig:
             f"enable_prefix_caching={self.cache_config.enable_prefix_caching}, "
             f"chunked_prefill_enabled={self.scheduler_config.chunked_prefill_enabled}, " # noqa
             f"use_async_output_proc={self.model_config.use_async_output_proc}, "
-            f"mm_cache_preprocessor={self.model_config.mm_cache_preprocessor!r}, " # noqa
+            f"disable_mm_preprocessor_cache={self.model_config.disable_mm_preprocessor_cache!r}, " # noqa
             f"mm_processor_kwargs={self.model_config.mm_processor_kwargs}, "
             f"pooler_config={self.model_config.pooler_config!r}, "
             f"compilation_config={self.compilation_config!r}")


@@ -141,7 +141,7 @@ class EngineArgs:
     tokenizer_pool_extra_config: Optional[Dict[str, Any]] = None
     limit_mm_per_prompt: Optional[Mapping[str, int]] = None
     mm_processor_kwargs: Optional[Dict[str, Any]] = None
-    mm_cache_preprocessor: bool = False
+    disable_mm_preprocessor_cache: bool = False
     enable_lora: bool = False
     enable_lora_bias: bool = False
     max_loras: int = 1
@@ -606,11 +606,10 @@ class EngineArgs:
             help=('Overrides for the multimodal input mapping/processing, '
                   'e.g., image processor. For example: {"num_crops": 4}.'))
         parser.add_argument(
-            '--mm-cache-preprocessor',
+            '--disable-mm-preprocessor-cache',
             action='store_true',
-            help='If true, then enables caching of the multi-modal '
-            'preprocessor/mapper. Otherwise, the mapper executes each time'
-            ', and for better performance consider enabling frontend process.')
+            help='If true, then disables caching of the multi-modal '
+            'preprocessor/mapper. (not recommended)')
         # LoRA related configs
         parser.add_argument('--enable-lora',
@@ -983,7 +982,7 @@ class EngineArgs:
            use_async_output_proc=not self.disable_async_output_proc,
            config_format=self.config_format,
            mm_processor_kwargs=self.mm_processor_kwargs,
-           mm_cache_preprocessor=self.mm_cache_preprocessor,
+           disable_mm_preprocessor_cache=self.disable_mm_preprocessor_cache,
            override_neuron_config=self.override_neuron_config,
            override_pooler_config=self.override_pooler_config,
            logits_processor_pattern=self.logits_processor_pattern)
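A hedged usage sketch of the new engine argument (the keyword itself is what this hunk adds; the rest is standard EngineArgs usage): on the command line the spelling is --disable-mm-preprocessor-cache, and programmatically it is the keyword below, with the default False keeping the cache on.

from vllm.engine.arg_utils import EngineArgs

# Setting the flag reproduces the old uncached behaviour; omitting it keeps
# the multi-modal preprocessor cache enabled (the new default).
engine_args = EngineArgs(
    model="llava-hf/llava-1.5-7b-hf",
    disable_mm_preprocessor_cache=True,
)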


@@ -191,7 +191,7 @@ def generate_block_hash_extra_keys(
         raise ValueError(
             "The number of multi-modal positions and hashes must match. This "
             "is likely because you do not enable MM preprocessor hashing. "
-            "Please set mm_cache_preprocessor=True.")
+            "Please set disable_mm_preprocessor_cache=False.")
     # Note that we assume mm_positions is sorted by offset.
     # We do not need to check all mm inputs if the start token index is out of


@@ -43,7 +43,7 @@ class MMInputMapperClient:
         self.mm_registry.init_mm_limits_per_prompt(model_config)
         # Init cache
-        self.use_cache = model_config.mm_cache_preprocessor
+        self.use_cache = not model_config.disable_mm_preprocessor_cache
         self.mm_cache = LRUDictCache[str, MultiModalKwargs](MM_CACHE_SIZE)
         # DEBUG: Set to None to disable
@@ -119,7 +119,7 @@ class MMInputMapperClient:
 class MMInputMapperServer:
     def __init__(self, model_config):
-        self.use_cache = model_config.mm_cache_preprocessor
+        self.use_cache = not model_config.disable_mm_preprocessor_cache
         self.mm_cache = LRUDictCache[str, MultiModalKwargs](MM_CACHE_SIZE)
     def process_inputs(
@@ -151,12 +151,26 @@ class MMHasher:
     def __init__(self):
         pass
-    def hash(self, prompt: PromptType) -> Optional[List[str]]:
+    def hash_mm_data(
+            self,
+            mm_data: Optional[MultiModalDataDict]) -> Optional[List[str]]:
+        if mm_data is None:
+            return None
+        image_inputs = mm_data['image']
+        return self.hash_images(image_inputs)
+    def hash_prompt(self, prompt: PromptType) -> Optional[List[str]]:
         if "multi_modal_data" not in prompt:
             return None
         mm_data = prompt["multi_modal_data"]
         image_inputs = mm_data["image"]
         return self.hash_images(image_inputs)
     def hash_images(self, image_inputs) -> Optional[List[str]]:
         if not isinstance(image_inputs, list):
             image_inputs = [image_inputs]
         assert len(image_inputs) > 0
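To make the split concrete, a hedged usage sketch of the two entry points added above (assuming MMHasher accepts PIL images, as the V1 image path does); both calls should produce the same hashes for the same image.

from PIL import Image

from vllm.v1.engine.mm_input_mapper import MMHasher

image = Image.new("RGB", (336, 336))  # placeholder image for illustration
hasher = MMHasher()
prompt = {
    "prompt": "USER: <image>\nWhat is shown? ASSISTANT:",
    "multi_modal_data": {"image": image},
}

# hash_prompt() extracts the multi-modal data from a full prompt dict, while
# hash_mm_data() takes the data directly (as the profiling path now does).
assert hasher.hash_prompt(prompt) == hasher.hash_mm_data(prompt["multi_modal_data"])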


@@ -46,7 +46,7 @@ class Processor:
         self.mm_input_mapper_client = MMInputMapperClient(model_config)
         # Multi-modal hasher (for images)
-        self.use_hash = model_config.mm_cache_preprocessor or \
+        self.use_hash = (not model_config.disable_mm_preprocessor_cache) or \
             cache_config.enable_prefix_caching
         self.mm_hasher = MMHasher()
@@ -80,7 +80,7 @@ class Processor:
         # Compute MM hashes (if enabled)
         mm_hashes = None
         if self.use_hash:
-            mm_hashes = self.mm_hasher.hash(prompt)
+            mm_hashes = self.mm_hasher.hash_prompt(prompt)
         # Process inputs.
         preprocessed_inputs = self.input_preprocessor.preprocess(
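The gating above boils down to a small predicate (names here are illustrative, not vLLM's API): hashes are computed whenever the preprocessor cache is enabled, which is now the default, or whenever prefix caching needs them.

def _needs_mm_hashing(disable_mm_preprocessor_cache: bool,
                      enable_prefix_caching: bool) -> bool:
    # Mirrors Processor.use_hash after this change.
    return (not disable_mm_preprocessor_cache) or enable_prefix_caching

assert _needs_mm_hashing(False, False)      # default config: hashing on
assert _needs_mm_hashing(True, True)        # prefix caching still needs hashes
assert not _needs_mm_hashing(True, False)   # the only case with no hashing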


@@ -19,7 +19,7 @@ from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
                         LayerBlockType, cdiv, is_pin_memory_available)
 from vllm.v1.attention.backends.flash_attn import (FlashAttentionBackend,
                                                     FlashAttentionMetadata)
-from vllm.v1.engine.mm_input_mapper import MMInputMapperClient
+from vllm.v1.engine.mm_input_mapper import MMHasher, MMInputMapperClient
 from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
@@ -79,8 +79,14 @@ class GPUModelRunner:
         # Multi-modal data support
         self.input_registry = INPUT_REGISTRY
         self.mm_registry = MULTIMODAL_REGISTRY
-        # NOTE: mm_input_mapper is only used for memory profiling.
-        self.mm_input_mapper = MMInputMapperClient(self.model_config)
+        # NOTE: mm_input_mapper_client and mm_hasher are only used for memory
+        # profiling.
+        self.mm_input_mapper_client = MMInputMapperClient(self.model_config)
+        self.mm_hasher = MMHasher()
+        self.use_hash = (not model_config.disable_mm_preprocessor_cache) or \
+            cache_config.enable_prefix_caching
         self.max_num_encoder_input_tokens = self.scheduler_config.max_num_encoder_input_tokens # noqa: E501
         self.encoder_cache_size = self.scheduler_config.encoder_cache_size
@@ -628,9 +634,15 @@ class GPUModelRunner:
             mm_registry=self.mm_registry,
         )
         dummy_mm_data = dummy_request_data.multi_modal_data
-        dummy_mm_kwargs, _ = self.mm_input_mapper.process_inputs(
+        # Compute MM hashes (if enabled)
+        mm_hashes = None
+        if self.use_hash:
+            mm_hashes = self.mm_hasher.hash_mm_data(dummy_mm_data)
+        dummy_mm_kwargs = self.mm_input_mapper_client.process_inputs(
             mm_data=dummy_mm_data,
-            mm_hashes=None,
+            mm_hashes=mm_hashes,
             mm_processor_kwargs=None,
             precomputed_mm_inputs=None)