[V1] VLM - enable processor cache by default (#11305)
Signed-off-by: Alexander Matveev <alexm@neuralmagic.com>
parent ca5f54a9b9
commit fdea8ec167
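With this change the multi-modal preprocessor cache is on by default: the old `mm_cache_preprocessor` opt-in becomes a `disable_mm_preprocessor_cache` opt-out across the config, engine args, and V1 front end. A minimal sketch of the resulting API, using only names that appear in the diff below (the model choice is illustrative, not part of this commit):

    from vllm import LLM

    # The preprocessor cache is now enabled unless explicitly turned off.
    llm = LLM(model="llava-hf/llava-1.5-7b-hf",
              max_model_len=4096,
              disable_mm_preprocessor_cache=True)  # opt out of the default-on cache
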
@@ -28,7 +28,7 @@ def run_aria(question: str, modality: str):
 tokenizer_mode="slow",
 trust_remote_code=True,
 dtype="bfloat16",
-mm_cache_preprocessor=args.mm_cache_preprocessor)
+disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)

 prompt = (f"<|im_start|>user\n<fim_prefix><|img|><fim_suffix>\n{question}"
 "<|im_end|>\n<|im_start|>assistant\n")
@@ -45,7 +45,7 @@ def run_blip2(question: str, modality: str):
 # See https://huggingface.co/Salesforce/blip2-opt-2.7b/discussions/15#64ff02f3f8cf9e4f5b038262 #noqa
 prompt = f"Question: {question} Answer:"
 llm = LLM(model="Salesforce/blip2-opt-2.7b",
-mm_cache_preprocessor=args.mm_cache_preprocessor)
+disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
 stop_token_ids = None
 return llm, prompt, stop_token_ids

@@ -57,7 +57,7 @@ def run_chameleon(question: str, modality: str):
 prompt = f"{question}<image>"
 llm = LLM(model="facebook/chameleon-7b",
 max_model_len=4096,
-mm_cache_preprocessor=args.mm_cache_preprocessor)
+disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
 stop_token_ids = None
 return llm, prompt, stop_token_ids

@@ -70,7 +70,7 @@ def run_fuyu(question: str, modality: str):
 llm = LLM(model="adept/fuyu-8b",
 max_model_len=2048,
 max_num_seqs=2,
-mm_cache_preprocessor=args.mm_cache_preprocessor)
+disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
 stop_token_ids = None
 return llm, prompt, stop_token_ids

@@ -85,7 +85,7 @@ def run_glm4v(question: str, modality: str):
 max_num_seqs=2,
 trust_remote_code=True,
 enforce_eager=True,
-mm_cache_preprocessor=args.mm_cache_preprocessor)
+disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
 prompt = question
 stop_token_ids = [151329, 151336, 151338]
 return llm, prompt, stop_token_ids
@@ -101,7 +101,7 @@ def run_h2ovl(question: str, modality: str):
 model=model_name,
 trust_remote_code=True,
 max_model_len=8192,
-mm_cache_preprocessor=args.mm_cache_preprocessor,
+disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
 )

 tokenizer = AutoTokenizer.from_pretrained(model_name,
@@ -134,7 +134,7 @@ def run_idefics3(question: str, modality: str):
 "longest_edge": 3 * 364
 },
 },
-mm_cache_preprocessor=args.mm_cache_preprocessor,
+disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
 )
 prompt = (
 f"<|begin_of_text|>User:<image>{question}<end_of_utterance>\nAssistant:"
@@ -153,7 +153,7 @@ def run_internvl(question: str, modality: str):
 model=model_name,
 trust_remote_code=True,
 max_model_len=4096,
-mm_cache_preprocessor=args.mm_cache_preprocessor,
+disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
 )

 tokenizer = AutoTokenizer.from_pretrained(model_name,
@@ -180,7 +180,7 @@ def run_llava(question: str, modality: str):

 llm = LLM(model="llava-hf/llava-1.5-7b-hf",
 max_model_len=4096,
-mm_cache_preprocessor=args.mm_cache_preprocessor)
+disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
 stop_token_ids = None
 return llm, prompt, stop_token_ids

@@ -192,7 +192,7 @@ def run_llava_next(question: str, modality: str):
 prompt = f"[INST] <image>\n{question} [/INST]"
 llm = LLM(model="llava-hf/llava-v1.6-mistral-7b-hf",
 max_model_len=8192,
-mm_cache_preprocessor=args.mm_cache_preprocessor)
+disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
 stop_token_ids = None
 return llm, prompt, stop_token_ids

@@ -205,7 +205,7 @@ def run_llava_next_video(question: str, modality: str):
 prompt = f"USER: <video>\n{question} ASSISTANT:"
 llm = LLM(model="llava-hf/LLaVA-NeXT-Video-7B-hf",
 max_model_len=8192,
-mm_cache_preprocessor=args.mm_cache_preprocessor)
+disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
 stop_token_ids = None
 return llm, prompt, stop_token_ids

@@ -223,7 +223,7 @@ def run_llava_onevision(question: str, modality: str):

 llm = LLM(model="llava-hf/llava-onevision-qwen2-7b-ov-hf",
 max_model_len=16384,
-mm_cache_preprocessor=args.mm_cache_preprocessor)
+disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
 stop_token_ids = None
 return llm, prompt, stop_token_ids

@@ -239,7 +239,7 @@ def run_mantis(question: str, modality: str):
 model="TIGER-Lab/Mantis-8B-siglip-llama3",
 max_model_len=4096,
 hf_overrides={"architectures": ["MantisForConditionalGeneration"]},
-mm_cache_preprocessor=args.mm_cache_preprocessor,
+disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
 )
 stop_token_ids = [128009]
 return llm, prompt, stop_token_ids
@@ -266,7 +266,7 @@ def run_minicpmv(question: str, modality: str):
 max_model_len=4096,
 max_num_seqs=2,
 trust_remote_code=True,
-mm_cache_preprocessor=args.mm_cache_preprocessor,
+disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
 )
 # NOTE The stop_token_ids are different for various versions of MiniCPM-V
 # 2.0
@@ -305,7 +305,7 @@ def run_mllama(question: str, modality: str):
 max_model_len=4096,
 max_num_seqs=16,
 enforce_eager=True,
-mm_cache_preprocessor=args.mm_cache_preprocessor,
+disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
 )

 prompt = f"<|image|><|begin_of_text|>{question}"
@@ -323,7 +323,7 @@ def run_molmo(question, modality):
 model=model_name,
 trust_remote_code=True,
 dtype="bfloat16",
-mm_cache_preprocessor=args.mm_cache_preprocessor,
+disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
 )

 prompt = question
@@ -343,7 +343,7 @@ def run_nvlm_d(question: str, modality: str):
 trust_remote_code=True,
 max_model_len=4096,
 tensor_parallel_size=4,
-mm_cache_preprocessor=args.mm_cache_preprocessor,
+disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
 )

 tokenizer = AutoTokenizer.from_pretrained(model_name,
@@ -363,7 +363,7 @@ def run_paligemma(question: str, modality: str):
 # PaliGemma has special prompt format for VQA
 prompt = "caption en"
 llm = LLM(model="google/paligemma-3b-mix-224",
-mm_cache_preprocessor=args.mm_cache_preprocessor)
+disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
 stop_token_ids = None
 return llm, prompt, stop_token_ids

@@ -375,7 +375,7 @@ def run_paligemma2(question: str, modality: str):
 # PaliGemma 2 has special prompt format for VQA
 prompt = "caption en"
 llm = LLM(model="google/paligemma2-3b-ft-docci-448",
-mm_cache_preprocessor=args.mm_cache_preprocessor)
+disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
 stop_token_ids = None
 return llm, prompt, stop_token_ids

@@ -405,7 +405,7 @@ def run_phi3v(question: str, modality: str):
 max_num_seqs=2,
 # Note - mm_processor_kwargs can also be passed to generate/chat calls
 mm_processor_kwargs={"num_crops": 16},
-mm_cache_preprocessor=args.mm_cache_preprocessor,
+disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
 )
 stop_token_ids = None
 return llm, prompt, stop_token_ids
@@ -420,7 +420,7 @@ def run_pixtral_hf(question: str, modality: str):
 llm = LLM(
 model=model_name,
 max_model_len=8192,
-mm_cache_preprocessor=args.mm_cache_preprocessor,
+disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
 )

 prompt = f"<s>[INST]{question}\n[IMG][/INST]"
@@ -437,7 +437,7 @@ def run_qwen_vl(question: str, modality: str):
 trust_remote_code=True,
 max_model_len=1024,
 max_num_seqs=2,
-mm_cache_preprocessor=args.mm_cache_preprocessor,
+disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
 )

 prompt = f"{question}Picture 1: <img></img>\n"
@@ -460,7 +460,7 @@ def run_qwen2_vl(question: str, modality: str):
 "min_pixels": 28 * 28,
 "max_pixels": 1280 * 28 * 28,
 },
-mm_cache_preprocessor=args.mm_cache_preprocessor,
+disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
 )

 prompt = ("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
@@ -651,9 +651,9 @@ if __name__ == "__main__":
 ' (if enabled)')

 parser.add_argument(
-'--mm-cache-preprocessor',
+'--disable-mm-preprocessor-cache',
 action='store_true',
-help='If True, enable caching of multi-modal preprocessor/mapper.')
+help='If True, disables caching of multi-modal preprocessor/mapper.')

 parser.add_argument(
 '--time-generate',

@@ -148,9 +148,8 @@ class ModelConfig:
 HuggingFace config.
 mm_processor_kwargs: Arguments to be forwarded to the model's processor
 for multi-modal data, e.g., image processor.
-mm_cache_preprocessor: If true, then enables caching of the multi-modal
-preprocessor/mapper. Otherwise, the mapper executes each time, and
-for better performance consider enabling frontend process.
+disable_mm_preprocessor_cache: If true, then disables caching of the
+multi-modal preprocessor/mapper. (not recommended)
 override_neuron_config: Initialize non default neuron config or
 override default neuron config that are specific to Neuron devices,
 this argument will be used to configure the neuron config that
@@ -216,7 +215,7 @@ class ModelConfig:
 config_format: ConfigFormat = ConfigFormat.AUTO,
 hf_overrides: Optional[HfOverrides] = None,
 mm_processor_kwargs: Optional[Dict[str, Any]] = None,
-mm_cache_preprocessor: bool = False,
+disable_mm_preprocessor_cache: bool = False,
 override_neuron_config: Optional[Dict[str, Any]] = None,
 override_pooler_config: Optional["PoolerConfig"] = None,
 logits_processor_pattern: Optional[str] = None) -> None:
@@ -286,7 +285,7 @@ class ModelConfig:
 self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
 self.use_async_output_proc = use_async_output_proc
 self.mm_processor_kwargs = mm_processor_kwargs
-self.mm_cache_preprocessor = mm_cache_preprocessor
+self.disable_mm_preprocessor_cache = disable_mm_preprocessor_cache

 # Set enforce_eager to False if the value is unset.
 if self.enforce_eager is None:
@@ -3155,7 +3154,7 @@ class VllmConfig:
 f"enable_prefix_caching={self.cache_config.enable_prefix_caching}, "
 f"chunked_prefill_enabled={self.scheduler_config.chunked_prefill_enabled}, " # noqa
 f"use_async_output_proc={self.model_config.use_async_output_proc}, "
-f"mm_cache_preprocessor={self.model_config.mm_cache_preprocessor!r}, " # noqa
+f"disable_mm_preprocessor_cache={self.model_config.disable_mm_preprocessor_cache!r}, " # noqa
 f"mm_processor_kwargs={self.model_config.mm_processor_kwargs}, "
 f"pooler_config={self.model_config.pooler_config!r}, "
 f"compilation_config={self.compilation_config!r}")

@@ -141,7 +141,7 @@ class EngineArgs:
 tokenizer_pool_extra_config: Optional[Dict[str, Any]] = None
 limit_mm_per_prompt: Optional[Mapping[str, int]] = None
 mm_processor_kwargs: Optional[Dict[str, Any]] = None
-mm_cache_preprocessor: bool = False
+disable_mm_preprocessor_cache: bool = False
 enable_lora: bool = False
 enable_lora_bias: bool = False
 max_loras: int = 1
@@ -606,11 +606,10 @@ class EngineArgs:
 help=('Overrides for the multimodal input mapping/processing, '
 'e.g., image processor. For example: {"num_crops": 4}.'))
 parser.add_argument(
-'--mm-cache-preprocessor',
+'--disable-mm-preprocessor-cache',
 action='store_true',
-help='If true, then enables caching of the multi-modal '
-'preprocessor/mapper. Otherwise, the mapper executes each time'
-', and for better performance consider enabling frontend process.')
+help='If true, then disables caching of the multi-modal '
+'preprocessor/mapper. (not recommended)')

 # LoRA related configs
 parser.add_argument('--enable-lora',
@@ -983,7 +982,7 @@ class EngineArgs:
 use_async_output_proc=not self.disable_async_output_proc,
 config_format=self.config_format,
 mm_processor_kwargs=self.mm_processor_kwargs,
-mm_cache_preprocessor=self.mm_cache_preprocessor,
+disable_mm_preprocessor_cache=self.disable_mm_preprocessor_cache,
 override_neuron_config=self.override_neuron_config,
 override_pooler_config=self.override_pooler_config,
 logits_processor_pattern=self.logits_processor_pattern)

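A short sketch of the same toggle at the engine-args level, matching the dataclass field and CLI flag in the hunks above (the model name is illustrative; the flag is forwarded to ModelConfig when the engine config is built):

    from vllm.engine.arg_utils import EngineArgs

    # Defaults to False, i.e. the multi-modal preprocessor cache stays enabled.
    engine_args = EngineArgs(model="llava-hf/llava-1.5-7b-hf",
                             disable_mm_preprocessor_cache=True)

    # On the command line, the equivalent switch is --disable-mm-preprocessor-cache,
    # per the argparse change above.
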
@@ -191,7 +191,7 @@ def generate_block_hash_extra_keys(
 raise ValueError(
 "The number of multi-modal positions and hashes must match. This "
 "is likely because you do not enable MM preprocessor hashing. "
-"Please set mm_cache_preprocessor=True.")
+"Please set disable_mm_preprocessor_cache=False.")

 # Note that we assume mm_positions is sorted by offset.
 # We do not need to check all mm inputs if the start token index is out of

@@ -43,7 +43,7 @@ class MMInputMapperClient:
 self.mm_registry.init_mm_limits_per_prompt(model_config)

 # Init cache
-self.use_cache = model_config.mm_cache_preprocessor
+self.use_cache = not model_config.disable_mm_preprocessor_cache
 self.mm_cache = LRUDictCache[str, MultiModalKwargs](MM_CACHE_SIZE)

 # DEBUG: Set to None to disable
@@ -119,7 +119,7 @@ class MMInputMapperClient:
 class MMInputMapperServer:

 def __init__(self, model_config):
-self.use_cache = model_config.mm_cache_preprocessor
+self.use_cache = not model_config.disable_mm_preprocessor_cache
 self.mm_cache = LRUDictCache[str, MultiModalKwargs](MM_CACHE_SIZE)

 def process_inputs(
@@ -151,12 +151,26 @@ class MMHasher:
 def __init__(self):
 pass

-def hash(self, prompt: PromptType) -> Optional[List[str]]:
+def hash_mm_data(
+self,
+mm_data: Optional[MultiModalDataDict]) -> Optional[List[str]]:
+if mm_data is None:
+return None
+
+image_inputs = mm_data['image']
+
+return self.hash_images(image_inputs)
+
+def hash_prompt(self, prompt: PromptType) -> Optional[List[str]]:
 if "multi_modal_data" not in prompt:
 return None

 mm_data = prompt["multi_modal_data"]
 image_inputs = mm_data["image"]

+return self.hash_images(image_inputs)
+
+def hash_images(self, image_inputs) -> Optional[List[str]]:
 if not isinstance(image_inputs, list):
 image_inputs = [image_inputs]
 assert len(image_inputs) > 0

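The old single hash() entry point is split into hash_prompt, hash_mm_data, and hash_images above. A hedged usage sketch (the dummy image and prompt text are illustrative; both calls reduce to hash_images, so they should yield the same per-image hashes):

    from PIL import Image
    from vllm.v1.engine.mm_input_mapper import MMHasher

    hasher = MMHasher()
    image = Image.new("RGB", (64, 64))  # illustrative dummy image

    # Hash a raw multi-modal data dict (the path used for memory profiling below).
    hashes_from_data = hasher.hash_mm_data({"image": image})

    # Hash a full prompt dict (the path used by the V1 Processor).
    hashes_from_prompt = hasher.hash_prompt({
        "prompt": "USER: <image>\nWhat is shown here? ASSISTANT:",
        "multi_modal_data": {"image": image},
    })

    assert hashes_from_data == hashes_from_prompt
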
@@ -46,7 +46,7 @@ class Processor:
 self.mm_input_mapper_client = MMInputMapperClient(model_config)

 # Multi-modal hasher (for images)
-self.use_hash = model_config.mm_cache_preprocessor or \
+self.use_hash = (not model_config.disable_mm_preprocessor_cache) or \
 cache_config.enable_prefix_caching
 self.mm_hasher = MMHasher()

@@ -80,7 +80,7 @@ class Processor:
 # Compute MM hashes (if enabled)
 mm_hashes = None
 if self.use_hash:
-mm_hashes = self.mm_hasher.hash(prompt)
+mm_hashes = self.mm_hasher.hash_prompt(prompt)

 # Process inputs.
 preprocessed_inputs = self.input_preprocessor.preprocess(

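As a hedged restatement of the gating introduced above (the helper name here is mine, not from the diff): hashing is now skipped only when the preprocessor cache is explicitly disabled and prefix caching is also off.

    def needs_mm_hashing(disable_mm_preprocessor_cache: bool,
                         enable_prefix_caching: bool) -> bool:
        # Mirrors Processor.use_hash in the hunk above.
        return (not disable_mm_preprocessor_cache) or enable_prefix_caching

    assert needs_mm_hashing(False, False)      # default: cache on -> hash inputs
    assert needs_mm_hashing(True, True)        # cache off, prefix caching on -> still hash
    assert not needs_mm_hashing(True, False)   # both off -> skip hashing
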
@@ -19,7 +19,7 @@ from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
 LayerBlockType, cdiv, is_pin_memory_available)
 from vllm.v1.attention.backends.flash_attn import (FlashAttentionBackend,
 FlashAttentionMetadata)
-from vllm.v1.engine.mm_input_mapper import MMInputMapperClient
+from vllm.v1.engine.mm_input_mapper import MMHasher, MMInputMapperClient
 from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
@@ -79,8 +79,14 @@ class GPUModelRunner:
 # Multi-modal data support
 self.input_registry = INPUT_REGISTRY
 self.mm_registry = MULTIMODAL_REGISTRY
-# NOTE: mm_input_mapper is only used for memory profiling.
-self.mm_input_mapper = MMInputMapperClient(self.model_config)
+# NOTE: mm_input_mapper_client and mm_hasher are only used for memory
+# profiling.
+self.mm_input_mapper_client = MMInputMapperClient(self.model_config)
+self.mm_hasher = MMHasher()
+self.use_hash = (not model_config.disable_mm_preprocessor_cache) or \
+cache_config.enable_prefix_caching

 self.max_num_encoder_input_tokens = self.scheduler_config.max_num_encoder_input_tokens # noqa: E501
 self.encoder_cache_size = self.scheduler_config.encoder_cache_size

@@ -628,9 +634,15 @@ class GPUModelRunner:
 mm_registry=self.mm_registry,
 )
 dummy_mm_data = dummy_request_data.multi_modal_data
-dummy_mm_kwargs, _ = self.mm_input_mapper.process_inputs(
+
+# Compute MM hashes (if enabled)
+mm_hashes = None
+if self.use_hash:
+mm_hashes = self.mm_hasher.hash_mm_data(dummy_mm_data)
+
+dummy_mm_kwargs = self.mm_input_mapper_client.process_inputs(
 mm_data=dummy_mm_data,
-mm_hashes=None,
+mm_hashes=mm_hashes,
 mm_processor_kwargs=None,
 precomputed_mm_inputs=None)