[Bugfix] Proper input validation for multi-modal encoder-decoder models (#16156)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent dc96fd54c6
commit 4ebc0b9640
@@ -56,7 +56,7 @@ def run_florence2():
 def run_mllama():
     engine_args = EngineArgs(
         model="meta-llama/Llama-3.2-11B-Vision-Instruct",
-        max_model_len=4096,
+        max_model_len=8192,
         max_num_seqs=2,
         limit_mm_per_prompt={"image": 1},
         dtype="half",
@@ -556,7 +556,7 @@ def run_mllama(questions: list[str], modality: str) -> ModelRequestData:
     # The configuration below has been confirmed to launch on a single L40 GPU.
     engine_args = EngineArgs(
         model=model_name,
-        max_model_len=4096,
+        max_model_len=8192,
         max_num_seqs=2,
         disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
     )
@@ -318,8 +318,8 @@ def load_mllama(question: str, image_urls: list[str]) -> ModelRequestData:
     # The configuration below has been confirmed to launch on a single L40 GPU.
     engine_args = EngineArgs(
         model=model_name,
-        max_model_len=4096,
-        max_num_seqs=16,
+        max_model_len=8192,
+        max_num_seqs=2,
         limit_mm_per_prompt={"image": len(image_urls)},
     )
 
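For context, here is a minimal offline sketch of the first updated example configuration above (a sketch only: it assumes vLLM is installed, access to the gated meta-llama checkpoint, and a single GPU; the prompt text and image asset are illustrative and not part of this diff):

from vllm import LLM, SamplingParams
from vllm.assets.image import ImageAsset

# Mirrors the raised limits from the example: 8192-token context window,
# 2 sequences, and at most one image per prompt.
llm = LLM(
    model="meta-llama/Llama-3.2-11B-Vision-Instruct",
    max_model_len=8192,
    max_num_seqs=2,
    limit_mm_per_prompt={"image": 1},
    dtype="half",
)

image = ImageAsset("stop_sign").pil_image
prompt = "<|image|><|begin_of_text|>What is shown in this image?"

outputs = llm.generate(
    {"prompt": prompt, "multi_modal_data": {"image": image}},
    SamplingParams(max_tokens=64),
)
print(outputs[0].outputs[0].text)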
@@ -18,7 +18,8 @@ models = ["llava-hf/llava-1.5-7b-hf"]
 def test_context_length_too_short(vllm_runner, image_assets, model):
     images = [asset.pil_image for asset in image_assets]
 
-    with pytest.raises(ValueError, match="too long to fit into the model"):
+    with pytest.raises(ValueError,
+                       match="longer than the maximum model length"):
         vllm_model = vllm_runner(
             model,
             max_model_len=128,  # LLaVA has a feature size of 576
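A side note on the shortened pattern: pytest.raises(..., match=...) runs re.search against the string of the raised exception, so the test only needs a stable substring of the new, longer message. A small self-contained sketch (the message text is paraphrased from the validation code later in this diff):

import re

# The match pattern only has to appear somewhere in the exception text.
message = ("The decoder prompt (length 592) is longer than the maximum "
           "model length of 128. Make sure that `max_model_len` is no "
           "smaller than the number of text tokens plus multimodal tokens.")
assert re.search("longer than the maximum model length", message)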
@@ -15,7 +15,7 @@ def v1(run_with_both_engines):
 
 def test_empty_prompt():
     llm = LLM(model="openai-community/gpt2", enforce_eager=True)
-    with pytest.raises(ValueError, match='Prompt cannot be empty'):
+    with pytest.raises(ValueError, match='decoder prompt cannot be empty'):
         llm.generate([""])
 
 
@@ -17,7 +17,7 @@ async def test_empty_prompt():
     client = remote_server.get_async_client()
 
     with pytest.raises(openai.BadRequestError,
-                       match=re.compile('.+Prompt cannot be empty.+')):
+                       match="decoder prompt cannot be empty"):
         await client.completions.create(model=model_name,
                                         prompt="",
                                         max_tokens=5,
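A rough client-side sketch of how the re-worded error surfaces over the OpenAI-compatible API (assumptions: a vLLM server for an encoder-decoder model is already running, and the base URL and model name below are placeholders):

import asyncio

import openai


async def main():
    client = openai.AsyncOpenAI(base_url="http://localhost:8000/v1",
                                api_key="EMPTY")
    try:
        await client.completions.create(model="my-enc-dec-model",
                                        prompt="",
                                        max_tokens=5)
    except openai.BadRequestError as err:
        # The server wraps the ValueError, so the 400 body now mentions
        # "decoder prompt cannot be empty" rather than "Prompt cannot be empty".
        print(err)


asyncio.run(main())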
@@ -211,7 +211,7 @@ def _run_test(
     # max_model_len should be greater than image_feature_size
     with vllm_runner(model,
                      dtype=dtype,
-                     max_model_len=4096,
+                     max_model_len=8192,
                      max_num_seqs=3,
                      tensor_parallel_size=tensor_parallel_size,
                      distributed_executor_backend=distributed_executor_backend,
@@ -422,7 +422,7 @@ def test_bnb_regression(
     llm = LLM(
         model=model,
         dtype=dtype,
-        max_model_len=4096,
+        max_model_len=8192,
         max_num_seqs=2,
         quantization="bitsandbytes",
     )
@@ -475,7 +475,7 @@ def test_explicit_implicit_prompt(
     llm = LLM(
         model=model,
         dtype=dtype,
-        max_model_len=4096,
+        max_model_len=8192,
         max_num_seqs=2,
         tensor_parallel_size=1,
     )
@@ -506,7 +506,7 @@ def test_regression(vllm_runner, image_assets, model, dtype, max_tokens,
     with global_force_attn_backend_context_manager(attn_backend), vllm_runner(
             model,
             dtype=dtype,
-            max_model_len=4096,
+            max_model_len=8192,
             max_num_seqs=2,
             tensor_parallel_size=1,
             limit_mm_per_prompt={"image":
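Besides doubling max_model_len in these fixtures, note that the multimodal length check rewritten later in this diff also tightens its boundary from > to >=, so a prompt of exactly max_model_len tokens is now rejected. A toy illustration of that boundary change:

# Old multimodal check: rejected only len(prompt_ids) > max_model_len.
# New check: rejects len(prompt_ids) >= max_model_len.
max_model_len = 4096
prompt_ids = list(range(4096))  # exactly max_model_len tokens

old_rejected = len(prompt_ids) > max_model_len
new_rejected = len(prompt_ids) >= max_model_len
assert (old_rejected, new_rejected) == (False, True)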
@@ -8,7 +8,8 @@ from contextlib import contextmanager
 from dataclasses import dataclass
 from functools import partial
 from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Deque, Dict,
-                    Iterable, List, Mapping, NamedTuple, Optional)
+                    Iterable, List, Literal, Mapping, NamedTuple, Optional)
 from typing import Sequence as GenericSequence
 from typing import Set, Type, Union, cast, overload
 
@@ -30,7 +30,7 @@ from vllm.entrypoints.openai.logits_processors import (
     get_logits_processors as get_openai_logits_processors)
 from vllm.executor.executor_base import ExecutorBase
 from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs,
-                         PromptType)
+                         PromptType, SingletonInputs)
 from vllm.inputs.parse import is_token_prompt, split_enc_dec_inputs
 from vllm.inputs.preprocess import InputPreprocessor
 from vllm.logger import init_logger
@@ -40,6 +40,7 @@ from vllm.model_executor.guided_decoding import (
     get_local_guided_decoding_logits_processor)
 from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
+from vllm.multimodal.processing import EncDecMultiModalProcessor
 from vllm.outputs import (PoolingRequestOutput, RequestOutput,
                           RequestOutputFactory)
 from vllm.pooling_params import PoolingParams
@@ -2029,29 +2030,57 @@ class LLMEngine:
                                lora_request: Optional[LoRARequest]):
         encoder_inputs, decoder_inputs = split_enc_dec_inputs(inputs)
 
-        # For encoder-decoder multimodal models, the max_prompt_len
-        # restricts the decoder prompt length
-        if self.model_config.is_multimodal_model:
-            prompt_inputs = decoder_inputs
-        else:
-            prompt_inputs = encoder_inputs or decoder_inputs
+        if encoder_inputs is not None:
+            self._validate_model_input(encoder_inputs,
+                                       lora_request,
+                                       prompt_type="encoder")
+
+        self._validate_model_input(decoder_inputs,
+                                   lora_request,
+                                   prompt_type="decoder")
+
+    def _validate_model_input(
+        self,
+        prompt_inputs: SingletonInputs,
+        lora_request: Optional[LoRARequest],
+        *,
+        prompt_type: Literal["encoder", "decoder"],
+    ):
+        if prompt_type == "encoder" and self.tokenizer is not None:
+            tokenizer = self.tokenizer.get_lora_tokenizer(lora_request)
+            model_config = self.model_config
+
+            if model_config.is_multimodal_model:
+                mm_registry = self.input_preprocessor.mm_registry
+                mm_processor = mm_registry.create_processor(
+                    model_config, tokenizer=tokenizer)
+                assert isinstance(mm_processor, EncDecMultiModalProcessor)
+
+                if mm_processor.pad_dummy_encoder_prompt:
+                    return  # Skip encoder length check for Whisper
 
         prompt_ids = prompt_inputs["prompt_token_ids"]
 
-        if prompt_ids is None or len(prompt_ids) == 0:
-            raise ValueError("Prompt cannot be empty")
+        if not prompt_ids:
+            raise ValueError(f"The {prompt_type} prompt cannot be empty")
 
-        if self.model_config.is_multimodal_model:
-            max_prompt_len = self.model_config.max_model_len
-
-            if len(prompt_ids) > max_prompt_len:
-                raise ValueError(
-                    f"The prompt (total length {len(prompt_ids)}) is too long "
-                    f"to fit into the model (context length {max_prompt_len}). "
+        max_prompt_len = self.model_config.max_model_len
+        if len(prompt_ids) >= max_prompt_len:
+            if self.model_config.is_multimodal_model:
+                suggestion = (
+                    "Make sure that `max_model_len` is no smaller than the "
+                    "number of text tokens plus multimodal tokens. For image "
+                    "inputs, the number of image tokens depends on the number "
+                    "of images, and possibly their aspect ratios as well.")
+            else:
+                suggestion = (
+                    "Make sure that `max_model_len` is no smaller than the "
+                    "number of text tokens.")
+
+            raise ValueError(
+                f"The {prompt_type} prompt (length {len(prompt_ids)}) is "
+                f"longer than the maximum model length of {max_prompt_len}. "
+                f"{suggestion}")
 
         # TODO: Find out how many placeholder tokens are there so we can
         # check that chunked prefill does not truncate them
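A standalone sketch of the decision flow the new _validate_model_input implements (a simplified free-standing function with made-up parameter names, not the actual engine method; the Whisper skip corresponds to the pad_dummy_encoder_prompt flag checked above):

from typing import List, Literal, Optional


def validate_prompt(
    prompt_ids: Optional[List[int]],
    *,
    prompt_type: Literal["encoder", "decoder"],
    max_model_len: int,
    is_multimodal: bool,
    pad_dummy_encoder_prompt: bool = False,
) -> None:
    # Models that pad a dummy encoder prompt (e.g. Whisper) skip the
    # encoder-side checks entirely.
    if prompt_type == "encoder" and is_multimodal and pad_dummy_encoder_prompt:
        return

    if not prompt_ids:
        raise ValueError(f"The {prompt_type} prompt cannot be empty")

    if len(prompt_ids) >= max_model_len:
        if is_multimodal:
            suggestion = ("Make sure that `max_model_len` is no smaller than "
                          "the number of text tokens plus multimodal tokens.")
        else:
            suggestion = ("Make sure that `max_model_len` is no smaller than "
                          "the number of text tokens.")
        raise ValueError(
            f"The {prompt_type} prompt (length {len(prompt_ids)}) is longer "
            f"than the maximum model length of {max_model_len}. {suggestion}")


# Example: a decoder prompt that exactly fills the context window is rejected.
try:
    validate_prompt(list(range(8192)),
                    prompt_type="decoder",
                    max_model_len=8192,
                    is_multimodal=True)
except ValueError as err:
    print(err)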
@@ -213,8 +213,12 @@ class MultiModalProfiler(Generic[_I]):
 
         total_len = len(encoder_prompt_token_ids)
 
-        # Encoder-decoder multimodal models only support v0
-        if total_len > seq_len:
+        processor = cast(EncDecMultiModalProcessor, self.processor)
+        if processor.pad_dummy_encoder_prompt:
+            num_tokens_to_pad = max(total_len, seq_len) - total_len
+            encoder_prompt_token_ids.extend([0] * num_tokens_to_pad)
+        # NOTE: Whisper allows total_len > seq_len.
+        elif total_len > seq_len and not envs.VLLM_USE_V1:
             # `max_num_batched_tokens` is defined by `SchedulerConfig`
             logger.warning_once(
                 "The encoder sequence length used for profiling ("
@@ -229,11 +233,6 @@ class MultiModalProfiler(Generic[_I]):
                 "increase `max_model_len`, reduce `max_num_seqs`, "
                 "and/or reduce `mm_counts`.")
 
-        processor = cast(EncDecMultiModalProcessor, self.processor)
-        if processor.pad_dummy_encoder_prompt:
-            num_tokens_to_pad = max(total_len, seq_len) - total_len
-            encoder_prompt_token_ids.extend([0] * num_tokens_to_pad)
-
         return DummyEncoderData(encoder_prompt_token_ids)
 
     def get_decoder_dummy_data(
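A small standalone sketch of the profiling behaviour after this move (seq_len and the token ids below are made-up values): dummy encoder prompts for models that set pad_dummy_encoder_prompt (e.g. Whisper) are zero-padded up to seq_len before the length warning is even considered, and the warning itself now only applies to v0.

def pad_dummy_encoder_prompt(encoder_prompt_token_ids, seq_len, pad_dummy):
    """Zero-pad the dummy encoder prompt up to seq_len, if requested."""
    total_len = len(encoder_prompt_token_ids)
    if pad_dummy:
        num_tokens_to_pad = max(total_len, seq_len) - total_len
        encoder_prompt_token_ids.extend([0] * num_tokens_to_pad)
    return encoder_prompt_token_ids


# A short dummy prompt padded out to the profiling sequence length.
ids = pad_dummy_encoder_prompt([1, 2, 3], seq_len=16, pad_dummy=True)
assert len(ids) == 16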
@@ -2,16 +2,17 @@
 
 import time
 from collections.abc import Mapping
-from typing import Optional, Union
+from typing import Literal, Optional, Union
 
 from vllm.config import VllmConfig
-from vllm.inputs import ProcessorInputs, PromptType
+from vllm.inputs import ProcessorInputs, PromptType, SingletonInputs
 from vllm.inputs.parse import split_enc_dec_inputs
 from vllm.inputs.preprocess import InputPreprocessor
 from vllm.lora.request import LoRARequest
 from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs,
                              MultiModalRegistry)
 from vllm.multimodal.inputs import PlaceholderRange
+from vllm.multimodal.processing import EncDecMultiModalProcessor
 from vllm.multimodal.utils import merge_and_sort_multimodal_metadata
 from vllm.pooling_params import PoolingParams
 from vllm.prompt_adapter.request import PromptAdapterRequest
@@ -287,41 +288,62 @@ class Processor:
                                lora_request: Optional[LoRARequest] = None):
         encoder_inputs, decoder_inputs = split_enc_dec_inputs(inputs)
 
-        # For encoder-decoder multimodal models, the max_prompt_len
-        # restricts the decoder prompt length
-        if self.model_config.is_multimodal_model:
-            prompt_inputs = decoder_inputs
-        else:
-            prompt_inputs = encoder_inputs or decoder_inputs
+        if encoder_inputs is not None:
+            self._validate_model_input(encoder_inputs,
+                                       lora_request,
+                                       prompt_type="encoder")
+
+        self._validate_model_input(decoder_inputs,
+                                   lora_request,
+                                   prompt_type="decoder")
+
+    def _validate_model_input(
+        self,
+        prompt_inputs: SingletonInputs,
+        lora_request: Optional[LoRARequest],
+        *,
+        prompt_type: Literal["encoder", "decoder"],
+    ):
+        tokenizer = self.tokenizer.get_lora_tokenizer(lora_request)
+
+        if prompt_type == "encoder":
+            model_config = self.model_config
+
+            if model_config.is_multimodal_model:
+                mm_registry = self.input_preprocessor.mm_registry
+                mm_processor = mm_registry.create_processor(
+                    model_config, tokenizer=tokenizer)
+                assert isinstance(mm_processor, EncDecMultiModalProcessor)
+
+                if mm_processor.pad_dummy_encoder_prompt:
+                    return  # Skip encoder length check for Whisper
 
         prompt_ids = prompt_inputs["prompt_token_ids"]
 
-        if prompt_ids is None or len(prompt_ids) == 0:
-            raise ValueError("Prompt cannot be empty")
+        if not prompt_ids:
+            raise ValueError(f"The {prompt_type} prompt cannot be empty")
 
         max_input_id = max(prompt_ids)
-        max_allowed = self.tokenizer.get_lora_tokenizer(
-            lora_request).max_token_id
-        if max_input_id > max_allowed:
-            raise ValueError(
-                "Token id {} is out of vocabulary".format(max_input_id))
+        if max_input_id > tokenizer.max_token_id:
+            raise ValueError(f"Token id {max_input_id} is out of vocabulary")
 
-        if len(prompt_ids) >= self.model_config.max_model_len:
-            raise ValueError(
-                f"Prompt length of {len(prompt_ids)} is longer than the "
-                f"maximum model length of {self.model_config.max_model_len}.")
-
-        if self.model_config.is_multimodal_model:
-            max_prompt_len = self.model_config.max_model_len
-
-            if len(prompt_ids) > max_prompt_len:
-                raise ValueError(
-                    f"The prompt (total length {len(prompt_ids)}) is too long "
-                    f"to fit into the model (context length {max_prompt_len}). "
+        max_prompt_len = self.model_config.max_model_len
+        if len(prompt_ids) >= max_prompt_len:
+            if self.model_config.is_multimodal_model:
+                suggestion = (
+                    "Make sure that `max_model_len` is no smaller than the "
+                    "number of text tokens plus multimodal tokens. For image "
+                    "inputs, the number of image tokens depends on the number "
+                    "of images, and possibly their aspect ratios as well.")
+            else:
+                suggestion = (
+                    "Make sure that `max_model_len` is no smaller than the "
+                    "number of text tokens.")
+
+            raise ValueError(
+                f"The {prompt_type} prompt (length {len(prompt_ids)}) is "
+                f"longer than the maximum model length of {max_prompt_len}. "
+                f"{suggestion}")
 
         # TODO: Find out how many placeholder tokens are there so we can
         # check that chunked prefill does not truncate them
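For the v1 path, the out-of-vocabulary check now reuses the tokenizer fetched at the top of the method instead of fetching it a second time. A minimal sketch of the check itself (using a Hugging Face tokenizer as a stand-in for vLLM's wrapper; max_token_id is approximated here from the vocabulary):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
max_token_id = max(tokenizer.get_vocab().values())

prompt_ids = [15496, 995, 999_999]  # the last id is deliberately out of range
max_input_id = max(prompt_ids)
try:
    if max_input_id > max_token_id:
        raise ValueError(f"Token id {max_input_id} is out of vocabulary")
except ValueError as err:
    print(err)  # Token id 999999 is out of vocabulary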