[Bugfix] Proper input validation for multi-modal encoder-decoder models (#16156)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Cyrus Leung 2025-04-09 00:45:21 +08:00 committed by GitHub
parent dc96fd54c6
commit 4ebc0b9640
10 changed files with 113 additions and 62 deletions

View File

@@ -56,7 +56,7 @@ def run_florence2():
 def run_mllama():
     engine_args = EngineArgs(
         model="meta-llama/Llama-3.2-11B-Vision-Instruct",
-        max_model_len=4096,
+        max_model_len=8192,
         max_num_seqs=2,
         limit_mm_per_prompt={"image": 1},
         dtype="half",

View File

@@ -556,7 +556,7 @@ def run_mllama(questions: list[str], modality: str) -> ModelRequestData:
     # The configuration below has been confirmed to launch on a single L40 GPU.
     engine_args = EngineArgs(
         model=model_name,
-        max_model_len=4096,
+        max_model_len=8192,
         max_num_seqs=2,
         disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
     )

View File

@@ -318,8 +318,8 @@ def load_mllama(question: str, image_urls: list[str]) -> ModelRequestData:
     # The configuration below has been confirmed to launch on a single L40 GPU.
     engine_args = EngineArgs(
         model=model_name,
-        max_model_len=4096,
-        max_num_seqs=16,
+        max_model_len=8192,
+        max_num_seqs=2,
         limit_mm_per_prompt={"image": len(image_urls)},
     )

View File

@@ -18,7 +18,8 @@ models = ["llava-hf/llava-1.5-7b-hf"]
 def test_context_length_too_short(vllm_runner, image_assets, model):
     images = [asset.pil_image for asset in image_assets]

-    with pytest.raises(ValueError, match="too long to fit into the model"):
+    with pytest.raises(ValueError,
+                       match="longer than the maximum model length"):
         vllm_model = vllm_runner(
             model,
             max_model_len=128,  # LLaVA has a feature size of 576

View File

@@ -15,7 +15,7 @@ def v1(run_with_both_engines):
 def test_empty_prompt():
     llm = LLM(model="openai-community/gpt2", enforce_eager=True)
-    with pytest.raises(ValueError, match='Prompt cannot be empty'):
+    with pytest.raises(ValueError, match='decoder prompt cannot be empty'):
         llm.generate([""])

View File

@@ -17,7 +17,7 @@ async def test_empty_prompt():
         client = remote_server.get_async_client()
         with pytest.raises(openai.BadRequestError,
-                           match=re.compile('.+Prompt cannot be empty.+')):
+                           match="decoder prompt cannot be empty"):
             await client.completions.create(model=model_name,
                                             prompt="",
                                             max_tokens=5,

View File

@@ -211,7 +211,7 @@ def _run_test(
     # max_model_len should be greater than image_feature_size
     with vllm_runner(model,
                      dtype=dtype,
-                     max_model_len=4096,
+                     max_model_len=8192,
                      max_num_seqs=3,
                      tensor_parallel_size=tensor_parallel_size,
                      distributed_executor_backend=distributed_executor_backend,
@@ -422,7 +422,7 @@ def test_bnb_regression(
     llm = LLM(
         model=model,
         dtype=dtype,
-        max_model_len=4096,
+        max_model_len=8192,
         max_num_seqs=2,
         quantization="bitsandbytes",
     )
@@ -475,7 +475,7 @@ def test_explicit_implicit_prompt(
     llm = LLM(
         model=model,
         dtype=dtype,
-        max_model_len=4096,
+        max_model_len=8192,
         max_num_seqs=2,
         tensor_parallel_size=1,
     )
@@ -506,7 +506,7 @@ def test_regression(vllm_runner, image_assets, model, dtype, max_tokens,
     with global_force_attn_backend_context_manager(attn_backend), vllm_runner(
             model,
             dtype=dtype,
-            max_model_len=4096,
+            max_model_len=8192,
             max_num_seqs=2,
             tensor_parallel_size=1,
             limit_mm_per_prompt={"image":

View File

@@ -8,7 +8,8 @@ from contextlib import contextmanager
 from dataclasses import dataclass
 from functools import partial
 from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Deque, Dict,
-                    Iterable, List, Mapping, NamedTuple, Optional)
+                    Iterable, List, Literal, Mapping, NamedTuple, Optional)
 from typing import Sequence as GenericSequence
 from typing import Set, Type, Union, cast, overload
@@ -30,7 +30,7 @@ from vllm.entrypoints.openai.logits_processors import (
     get_logits_processors as get_openai_logits_processors)
 from vllm.executor.executor_base import ExecutorBase
 from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs,
-                         PromptType)
+                         PromptType, SingletonInputs)
 from vllm.inputs.parse import is_token_prompt, split_enc_dec_inputs
 from vllm.inputs.preprocess import InputPreprocessor
 from vllm.logger import init_logger
@@ -40,6 +40,7 @@ from vllm.model_executor.guided_decoding import (
     get_local_guided_decoding_logits_processor)
 from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
+from vllm.multimodal.processing import EncDecMultiModalProcessor
 from vllm.outputs import (PoolingRequestOutput, RequestOutput,
                           RequestOutputFactory)
 from vllm.pooling_params import PoolingParams
@@ -2029,29 +2030,57 @@ class LLMEngine:
                                lora_request: Optional[LoRARequest]):
         encoder_inputs, decoder_inputs = split_enc_dec_inputs(inputs)

-        # For encoder-decoder multimodal models, the max_prompt_len
-        # restricts the decoder prompt length
-        if self.model_config.is_multimodal_model:
-            prompt_inputs = decoder_inputs
-        else:
-            prompt_inputs = encoder_inputs or decoder_inputs
+        if encoder_inputs is not None:
+            self._validate_model_input(encoder_inputs,
+                                       lora_request,
+                                       prompt_type="encoder")
+
+        self._validate_model_input(decoder_inputs,
+                                   lora_request,
+                                   prompt_type="decoder")
+
+    def _validate_model_input(
+        self,
+        prompt_inputs: SingletonInputs,
+        lora_request: Optional[LoRARequest],
+        *,
+        prompt_type: Literal["encoder", "decoder"],
+    ):
+        if prompt_type == "encoder" and self.tokenizer is not None:
+            tokenizer = self.tokenizer.get_lora_tokenizer(lora_request)
+            model_config = self.model_config
+
+            if model_config.is_multimodal_model:
+                mm_registry = self.input_preprocessor.mm_registry
+                mm_processor = mm_registry.create_processor(
+                    model_config, tokenizer=tokenizer)
+                assert isinstance(mm_processor, EncDecMultiModalProcessor)
+
+                if mm_processor.pad_dummy_encoder_prompt:
+                    return  # Skip encoder length check for Whisper

         prompt_ids = prompt_inputs["prompt_token_ids"]

-        if prompt_ids is None or len(prompt_ids) == 0:
-            raise ValueError("Prompt cannot be empty")
+        if not prompt_ids:
+            raise ValueError(f"The {prompt_type} prompt cannot be empty")

-        if self.model_config.is_multimodal_model:
-            max_prompt_len = self.model_config.max_model_len
-
-            if len(prompt_ids) > max_prompt_len:
-                raise ValueError(
-                    f"The prompt (total length {len(prompt_ids)}) is too long "
-                    f"to fit into the model (context length {max_prompt_len}). "
+        max_prompt_len = self.model_config.max_model_len
+        if len(prompt_ids) >= max_prompt_len:
+            if self.model_config.is_multimodal_model:
+                suggestion = (
+                    "Make sure that `max_model_len` is no smaller than the "
+                    "number of text tokens plus multimodal tokens. For image "
+                    "inputs, the number of image tokens depends on the number "
+                    "of images, and possibly their aspect ratios as well.")
+            else:
+                suggestion = (
+                    "Make sure that `max_model_len` is no smaller than the "
+                    "number of text tokens.")
+
+            raise ValueError(
+                f"The {prompt_type} prompt (length {len(prompt_ids)}) is "
+                f"longer than the maximum model length of {max_prompt_len}. "
+                f"{suggestion}")

         # TODO: Find out how many placeholder tokens are there so we can
         # check that chunked prefill does not truncate them
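
For orientation, here is a minimal standalone sketch of how the reworked validation behaves from a caller's point of view. The validate_prompt helper below is an illustrative stand-in for LLMEngine._validate_model_input, not part of this diff, and the config values are assumptions.

from typing import Literal, Optional

def validate_prompt(prompt_token_ids: Optional[list[int]],
                    *,
                    prompt_type: Literal["encoder", "decoder"],
                    max_model_len: int,
                    is_multimodal_model: bool) -> None:
    # The empty-prompt check now names the offending prompt type.
    if not prompt_token_ids:
        raise ValueError(f"The {prompt_type} prompt cannot be empty")
    # The length check applies regardless of modality, with a suggestion
    # tailored to whether the model is multimodal.
    if len(prompt_token_ids) >= max_model_len:
        detail = ("text tokens plus multimodal tokens"
                  if is_multimodal_model else "text tokens")
        raise ValueError(
            f"The {prompt_type} prompt (length {len(prompt_token_ids)}) is "
            f"longer than the maximum model length of {max_model_len}. "
            f"Make sure that `max_model_len` is no smaller than the "
            f"number of {detail}.")

try:
    validate_prompt([], prompt_type="decoder",
                    max_model_len=1024, is_multimodal_model=False)
except ValueError as err:
    print(err)  # The decoder prompt cannot be empty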

View File

@@ -213,8 +213,12 @@ class MultiModalProfiler(Generic[_I]):
         total_len = len(encoder_prompt_token_ids)

-        # Encoder-decoder multimodal models only support v0
-        if total_len > seq_len:
+        processor = cast(EncDecMultiModalProcessor, self.processor)
+        if processor.pad_dummy_encoder_prompt:
+            num_tokens_to_pad = max(total_len, seq_len) - total_len
+            encoder_prompt_token_ids.extend([0] * num_tokens_to_pad)
+        # NOTE: Whisper allows total_len > seq_len.
+        elif total_len > seq_len and not envs.VLLM_USE_V1:
             # `max_num_batched_tokens` is defined by `SchedulerConfig`
             logger.warning_once(
                 "The encoder sequence length used for profiling ("
@@ -229,11 +233,6 @@ class MultiModalProfiler(Generic[_I]):
                 "increase `max_model_len`, reduce `max_num_seqs`, "
                 "and/or reduce `mm_counts`.")

-        processor = cast(EncDecMultiModalProcessor, self.processor)
-        if processor.pad_dummy_encoder_prompt:
-            num_tokens_to_pad = max(total_len, seq_len) - total_len
-            encoder_prompt_token_ids.extend([0] * num_tokens_to_pad)
-
         return DummyEncoderData(encoder_prompt_token_ids)

     def get_decoder_dummy_data(
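
The change above is a reordering: the dummy-prompt padding branch now runs first, so processors with pad_dummy_encoder_prompt (the Whisper case) never reach the length warning, and the warning itself is limited to v0. A minimal sketch of the resulting control flow, using simplified stand-in names rather than vLLM's actual classes:

def build_dummy_encoder_prompt(encoder_prompt_token_ids: list[int],
                               seq_len: int,
                               pads_dummy_encoder_prompt: bool,
                               use_v1: bool) -> list[int]:
    total_len = len(encoder_prompt_token_ids)
    if pads_dummy_encoder_prompt:
        # Whisper-style processors: pad up to seq_len when shorter; a longer
        # encoder prompt is allowed and no warning is emitted.
        num_tokens_to_pad = max(total_len, seq_len) - total_len
        return encoder_prompt_token_ids + [0] * num_tokens_to_pad
    if total_len > seq_len and not use_v1:
        # v0 only: emit the warning shown in the hunk above.
        print("warning: encoder profiling length exceeds the scheduled "
              "sequence length")
    return encoder_prompt_token_ids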

View File

@@ -2,16 +2,17 @@
 import time
 from collections.abc import Mapping
-from typing import Optional, Union
+from typing import Literal, Optional, Union

 from vllm.config import VllmConfig
-from vllm.inputs import ProcessorInputs, PromptType
+from vllm.inputs import ProcessorInputs, PromptType, SingletonInputs
 from vllm.inputs.parse import split_enc_dec_inputs
 from vllm.inputs.preprocess import InputPreprocessor
 from vllm.lora.request import LoRARequest
 from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs,
                              MultiModalRegistry)
 from vllm.multimodal.inputs import PlaceholderRange
+from vllm.multimodal.processing import EncDecMultiModalProcessor
 from vllm.multimodal.utils import merge_and_sort_multimodal_metadata
 from vllm.pooling_params import PoolingParams
 from vllm.prompt_adapter.request import PromptAdapterRequest
@@ -287,41 +288,62 @@ class Processor:
                               lora_request: Optional[LoRARequest] = None):
         encoder_inputs, decoder_inputs = split_enc_dec_inputs(inputs)

-        # For encoder-decoder multimodal models, the max_prompt_len
-        # restricts the decoder prompt length
-        if self.model_config.is_multimodal_model:
-            prompt_inputs = decoder_inputs
-        else:
-            prompt_inputs = encoder_inputs or decoder_inputs
+        if encoder_inputs is not None:
+            self._validate_model_input(encoder_inputs,
+                                       lora_request,
+                                       prompt_type="encoder")
+
+        self._validate_model_input(decoder_inputs,
+                                   lora_request,
+                                   prompt_type="decoder")
+
+    def _validate_model_input(
+        self,
+        prompt_inputs: SingletonInputs,
+        lora_request: Optional[LoRARequest],
+        *,
+        prompt_type: Literal["encoder", "decoder"],
+    ):
+        tokenizer = self.tokenizer.get_lora_tokenizer(lora_request)
+
+        if prompt_type == "encoder":
+            model_config = self.model_config
+
+            if model_config.is_multimodal_model:
+                mm_registry = self.input_preprocessor.mm_registry
+                mm_processor = mm_registry.create_processor(
+                    model_config, tokenizer=tokenizer)
+                assert isinstance(mm_processor, EncDecMultiModalProcessor)
+
+                if mm_processor.pad_dummy_encoder_prompt:
+                    return  # Skip encoder length check for Whisper

         prompt_ids = prompt_inputs["prompt_token_ids"]

-        if prompt_ids is None or len(prompt_ids) == 0:
-            raise ValueError("Prompt cannot be empty")
+        if not prompt_ids:
+            raise ValueError(f"The {prompt_type} prompt cannot be empty")

         max_input_id = max(prompt_ids)
-        max_allowed = self.tokenizer.get_lora_tokenizer(
-            lora_request).max_token_id
-        if max_input_id > max_allowed:
-            raise ValueError(
-                "Token id {} is out of vocabulary".format(max_input_id))
+        if max_input_id > tokenizer.max_token_id:
+            raise ValueError(f"Token id {max_input_id} is out of vocabulary")

-        if len(prompt_ids) >= self.model_config.max_model_len:
-            raise ValueError(
-                f"Prompt length of {len(prompt_ids)} is longer than the "
-                f"maximum model length of {self.model_config.max_model_len}.")
-
-        if self.model_config.is_multimodal_model:
-            max_prompt_len = self.model_config.max_model_len
-
-            if len(prompt_ids) > max_prompt_len:
-                raise ValueError(
-                    f"The prompt (total length {len(prompt_ids)}) is too long "
-                    f"to fit into the model (context length {max_prompt_len}). "
+        max_prompt_len = self.model_config.max_model_len
+        if len(prompt_ids) >= max_prompt_len:
+            if self.model_config.is_multimodal_model:
+                suggestion = (
+                    "Make sure that `max_model_len` is no smaller than the "
+                    "number of text tokens plus multimodal tokens. For image "
+                    "inputs, the number of image tokens depends on the number "
+                    "of images, and possibly their aspect ratios as well.")
+            else:
+                suggestion = (
+                    "Make sure that `max_model_len` is no smaller than the "
+                    "number of text tokens.")
+
+            raise ValueError(
+                f"The {prompt_type} prompt (length {len(prompt_ids)}) is "
+                f"longer than the maximum model length of {max_prompt_len}. "
+                f"{suggestion}")

         # TODO: Find out how many placeholder tokens are there so we can
         # check that chunked prefill does not truncate them
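
A hedged end-to-end sketch of the three checks in the reworked _validate_model_input above, written as pytest-style expectations. It assumes a small decoder-only model with its default context length; the model name, the oversized token id, and the repeated-prompt length are illustrative assumptions, not values taken from this diff.

import pytest
from vllm import LLM
from vllm.inputs import TokensPrompt

llm = LLM(model="openai-community/gpt2", enforce_eager=True)

# 1. Empty prompts are now reported as an empty *decoder* prompt.
with pytest.raises(ValueError, match="decoder prompt cannot be empty"):
    llm.generate([""])

# 2. Token ids beyond the tokenizer vocabulary are rejected up front.
with pytest.raises(ValueError, match="is out of vocabulary"):
    llm.generate(TokensPrompt(prompt_token_ids=[2**31 - 1]))

# 3. Prompts at or above max_model_len (1024 for gpt2) hit the new message.
with pytest.raises(ValueError, match="longer than the maximum model length"):
    llm.generate("hi " * 2048)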