diff --git a/examples/offline_inference/encoder_decoder_multimodal.py b/examples/offline_inference/encoder_decoder_multimodal.py
index 6d0c3ac1..b2f2386d 100644
--- a/examples/offline_inference/encoder_decoder_multimodal.py
+++ b/examples/offline_inference/encoder_decoder_multimodal.py
@@ -56,7 +56,7 @@ def run_florence2():
 def run_mllama():
     engine_args = EngineArgs(
         model="meta-llama/Llama-3.2-11B-Vision-Instruct",
-        max_model_len=4096,
+        max_model_len=8192,
         max_num_seqs=2,
         limit_mm_per_prompt={"image": 1},
         dtype="half",
diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py
index a944260c..1f3c5757 100644
--- a/examples/offline_inference/vision_language.py
+++ b/examples/offline_inference/vision_language.py
@@ -556,7 +556,7 @@ def run_mllama(questions: list[str], modality: str) -> ModelRequestData:
     # The configuration below has been confirmed to launch on a single L40 GPU.
     engine_args = EngineArgs(
         model=model_name,
-        max_model_len=4096,
+        max_model_len=8192,
         max_num_seqs=2,
         disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
     )
diff --git a/examples/offline_inference/vision_language_multi_image.py b/examples/offline_inference/vision_language_multi_image.py
index 39465c9b..23994a61 100644
--- a/examples/offline_inference/vision_language_multi_image.py
+++ b/examples/offline_inference/vision_language_multi_image.py
@@ -318,8 +318,8 @@ def load_mllama(question: str, image_urls: list[str]) -> ModelRequestData:
     # The configuration below has been confirmed to launch on a single L40 GPU.
     engine_args = EngineArgs(
         model=model_name,
-        max_model_len=4096,
-        max_num_seqs=16,
+        max_model_len=8192,
+        max_num_seqs=2,
         limit_mm_per_prompt={"image": len(image_urls)},
     )

diff --git a/tests/engine/test_short_mm_context.py b/tests/engine/test_short_mm_context.py
index d5111e3f..b29d6362 100644
--- a/tests/engine/test_short_mm_context.py
+++ b/tests/engine/test_short_mm_context.py
@@ -18,7 +18,8 @@ models = ["llava-hf/llava-1.5-7b-hf"]
 def test_context_length_too_short(vllm_runner, image_assets, model):
     images = [asset.pil_image for asset in image_assets]

-    with pytest.raises(ValueError, match="too long to fit into the model"):
+    with pytest.raises(ValueError,
+                       match="longer than the maximum model length"):
         vllm_model = vllm_runner(
             model,
             max_model_len=128,  # LLaVA has a feature size of 576
diff --git a/tests/entrypoints/llm/test_prompt_validation.py b/tests/entrypoints/llm/test_prompt_validation.py
index 61bd1d46..665c6ea1 100644
--- a/tests/entrypoints/llm/test_prompt_validation.py
+++ b/tests/entrypoints/llm/test_prompt_validation.py
@@ -15,7 +15,7 @@ def v1(run_with_both_engines):

 def test_empty_prompt():
     llm = LLM(model="openai-community/gpt2", enforce_eager=True)
-    with pytest.raises(ValueError, match='Prompt cannot be empty'):
+    with pytest.raises(ValueError, match='decoder prompt cannot be empty'):
         llm.generate([""])


diff --git a/tests/entrypoints/openai/test_prompt_validation.py b/tests/entrypoints/openai/test_prompt_validation.py
index 64a1eb6a..f889189a 100644
--- a/tests/entrypoints/openai/test_prompt_validation.py
+++ b/tests/entrypoints/openai/test_prompt_validation.py
@@ -17,7 +17,7 @@ async def test_empty_prompt():
         client = remote_server.get_async_client()

         with pytest.raises(openai.BadRequestError,
-                           match=re.compile('.+Prompt cannot be empty.+')):
+                           match="decoder prompt cannot be empty"):
             await client.completions.create(model=model_name,
                                             prompt="",
                                             max_tokens=5,
diff --git a/tests/models/encoder_decoder/vision_language/test_mllama.py b/tests/models/encoder_decoder/vision_language/test_mllama.py
index c6886558..a9f0de76 100644
--- a/tests/models/encoder_decoder/vision_language/test_mllama.py
+++ b/tests/models/encoder_decoder/vision_language/test_mllama.py
@@ -211,7 +211,7 @@ def _run_test(
     # max_model_len should be greater than image_feature_size
     with vllm_runner(model,
                      dtype=dtype,
-                     max_model_len=4096,
+                     max_model_len=8192,
                      max_num_seqs=3,
                      tensor_parallel_size=tensor_parallel_size,
                      distributed_executor_backend=distributed_executor_backend,
@@ -422,7 +422,7 @@ def test_bnb_regression(
     llm = LLM(
         model=model,
         dtype=dtype,
-        max_model_len=4096,
+        max_model_len=8192,
         max_num_seqs=2,
         quantization="bitsandbytes",
     )
@@ -475,7 +475,7 @@ def test_explicit_implicit_prompt(
     llm = LLM(
         model=model,
         dtype=dtype,
-        max_model_len=4096,
+        max_model_len=8192,
         max_num_seqs=2,
         tensor_parallel_size=1,
     )
@@ -506,7 +506,7 @@ def test_regression(vllm_runner, image_assets, model, dtype, max_tokens,
     with global_force_attn_backend_context_manager(attn_backend), vllm_runner(
             model,
             dtype=dtype,
-            max_model_len=4096,
+            max_model_len=8192,
             max_num_seqs=2,
             tensor_parallel_size=1,
             limit_mm_per_prompt={"image":
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index f842581b..3ac39887 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -8,7 +8,7 @@ from contextlib import contextmanager
 from dataclasses import dataclass
 from functools import partial
 from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Deque, Dict,
-                    Iterable, List, Mapping, NamedTuple, Optional)
+                    Iterable, List, Literal, Mapping, NamedTuple, Optional)
 from typing import Sequence as GenericSequence
 from typing import Set, Type, Union, cast, overload

@@ -30,7 +30,7 @@ from vllm.entrypoints.openai.logits_processors import (
     get_logits_processors as get_openai_logits_processors)
 from vllm.executor.executor_base import ExecutorBase
 from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs,
-                         PromptType)
+                         PromptType, SingletonInputs)
 from vllm.inputs.parse import is_token_prompt, split_enc_dec_inputs
 from vllm.inputs.preprocess import InputPreprocessor
 from vllm.logger import init_logger
@@ -40,6 +40,7 @@ from vllm.model_executor.guided_decoding import (
     get_local_guided_decoding_logits_processor)
 from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
+from vllm.multimodal.processing import EncDecMultiModalProcessor
 from vllm.outputs import (PoolingRequestOutput, RequestOutput,
                           RequestOutputFactory)
 from vllm.pooling_params import PoolingParams
@@ -2029,29 +2030,57 @@ class LLMEngine:
                                lora_request: Optional[LoRARequest]):
         encoder_inputs, decoder_inputs = split_enc_dec_inputs(inputs)

-        # For encoder-decoder multimodal models, the max_prompt_len
-        # restricts the decoder prompt length
-        if self.model_config.is_multimodal_model:
-            prompt_inputs = decoder_inputs
-        else:
-            prompt_inputs = encoder_inputs or decoder_inputs
+        if encoder_inputs is not None:
+            self._validate_model_input(encoder_inputs,
+                                       lora_request,
+                                       prompt_type="encoder")
+
+        self._validate_model_input(decoder_inputs,
+                                   lora_request,
+                                   prompt_type="decoder")
+
+    def _validate_model_input(
+        self,
+        prompt_inputs: SingletonInputs,
+        lora_request: Optional[LoRARequest],
+        *,
+        prompt_type: Literal["encoder", "decoder"],
+    ):
+        if prompt_type == "encoder" and self.tokenizer is not None:
+            tokenizer = self.tokenizer.get_lora_tokenizer(lora_request)
+            model_config = self.model_config
+
+            if model_config.is_multimodal_model:
+                mm_registry = self.input_preprocessor.mm_registry
+                mm_processor = mm_registry.create_processor(
+                    model_config, tokenizer=tokenizer)
+                assert isinstance(mm_processor, EncDecMultiModalProcessor)
+
+                if mm_processor.pad_dummy_encoder_prompt:
+                    return  # Skip encoder length check for Whisper

         prompt_ids = prompt_inputs["prompt_token_ids"]

-        if prompt_ids is None or len(prompt_ids) == 0:
-            raise ValueError("Prompt cannot be empty")
+        if not prompt_ids:
+            raise ValueError(f"The {prompt_type} prompt cannot be empty")

-        if self.model_config.is_multimodal_model:
-            max_prompt_len = self.model_config.max_model_len
-
-            if len(prompt_ids) > max_prompt_len:
-                raise ValueError(
-                    f"The prompt (total length {len(prompt_ids)}) is too long "
-                    f"to fit into the model (context length {max_prompt_len}). "
+        max_prompt_len = self.model_config.max_model_len
+        if len(prompt_ids) >= max_prompt_len:
+            if self.model_config.is_multimodal_model:
+                suggestion = (
                     "Make sure that `max_model_len` is no smaller than the "
                     "number of text tokens plus multimodal tokens. For image "
                     "inputs, the number of image tokens depends on the number "
                     "of images, and possibly their aspect ratios as well.")
+            else:
+                suggestion = (
+                    "Make sure that `max_model_len` is no smaller than the "
+                    "number of text tokens.")
+
+            raise ValueError(
+                f"The {prompt_type} prompt (length {len(prompt_ids)}) is "
+                f"longer than the maximum model length of {max_prompt_len}. "
+                f"{suggestion}")

         # TODO: Find out how many placeholder tokens are there so we can
         # check that chunked prefill does not truncate them
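For reference, the reworded validation errors surface to callers roughly as the updated tests above expect. A minimal sketch of both cases (illustrative only, reusing the `openai-community/gpt2` test model from `tests/entrypoints/llm/test_prompt_validation.py`; the long-prompt case is an assumption extrapolated from the new wording, not one of the tests in this patch):

import pytest

from vllm import LLM

llm = LLM(model="openai-community/gpt2", enforce_eager=True)

# Empty prompts are now reported per prompt type ("encoder" or "decoder").
with pytest.raises(ValueError, match="decoder prompt cannot be empty"):
    llm.generate([""])

# Prompts whose token count reaches max_model_len are rejected with the
# shared wording used by LLMEngine (above) and the v1 Processor (below).
with pytest.raises(ValueError, match="longer than the maximum model length"):
    llm.generate(["hi " * 2048])  # well beyond gpt2's 1024-token context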
diff --git a/vllm/multimodal/profiling.py b/vllm/multimodal/profiling.py
index 485a90a2..9a733d3b 100644
--- a/vllm/multimodal/profiling.py
+++ b/vllm/multimodal/profiling.py
@@ -213,8 +213,12 @@ class MultiModalProfiler(Generic[_I]):

         total_len = len(encoder_prompt_token_ids)

-        # Encoder-decoder multimodal models only support v0
-        if total_len > seq_len:
+        processor = cast(EncDecMultiModalProcessor, self.processor)
+        if processor.pad_dummy_encoder_prompt:
+            num_tokens_to_pad = max(total_len, seq_len) - total_len
+            encoder_prompt_token_ids.extend([0] * num_tokens_to_pad)
+        # NOTE: Whisper allows total_len > seq_len.
+        elif total_len > seq_len and not envs.VLLM_USE_V1:
             # `max_num_batched_tokens` is defined by `SchedulerConfig`
             logger.warning_once(
                 "The encoder sequence length used for profiling ("
@@ -229,11 +233,6 @@ class MultiModalProfiler(Generic[_I]):
                 "increase `max_model_len`, reduce `max_num_seqs`, "
                 "and/or reduce `mm_counts`.")

-        processor = cast(EncDecMultiModalProcessor, self.processor)
-        if processor.pad_dummy_encoder_prompt:
-            num_tokens_to_pad = max(total_len, seq_len) - total_len
-            encoder_prompt_token_ids.extend([0] * num_tokens_to_pad)
-
         return DummyEncoderData(encoder_prompt_token_ids)

     def get_decoder_dummy_data(
diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py
index 403edddf..bc5c53b8 100644
--- a/vllm/v1/engine/processor.py
+++ b/vllm/v1/engine/processor.py
@@ -2,16 +2,17 @@

 import time
 from collections.abc import Mapping
-from typing import Optional, Union
+from typing import Literal, Optional, Union

 from vllm.config import VllmConfig
-from vllm.inputs import ProcessorInputs, PromptType
+from vllm.inputs import ProcessorInputs, PromptType, SingletonInputs
 from vllm.inputs.parse import split_enc_dec_inputs
 from vllm.inputs.preprocess import InputPreprocessor
 from vllm.lora.request import LoRARequest
 from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs,
                              MultiModalRegistry)
 from vllm.multimodal.inputs import PlaceholderRange
+from vllm.multimodal.processing import EncDecMultiModalProcessor
 from vllm.multimodal.utils import merge_and_sort_multimodal_metadata
 from vllm.pooling_params import PoolingParams
 from vllm.prompt_adapter.request import PromptAdapterRequest
@@ -287,41 +288,62 @@ class Processor:
                                lora_request: Optional[LoRARequest] = None):
         encoder_inputs, decoder_inputs = split_enc_dec_inputs(inputs)

-        # For encoder-decoder multimodal models, the max_prompt_len
-        # restricts the decoder prompt length
-        if self.model_config.is_multimodal_model:
-            prompt_inputs = decoder_inputs
-        else:
-            prompt_inputs = encoder_inputs or decoder_inputs
+        if encoder_inputs is not None:
+            self._validate_model_input(encoder_inputs,
+                                       lora_request,
+                                       prompt_type="encoder")
+
+        self._validate_model_input(decoder_inputs,
+                                   lora_request,
+                                   prompt_type="decoder")
+
+    def _validate_model_input(
+        self,
+        prompt_inputs: SingletonInputs,
+        lora_request: Optional[LoRARequest],
+        *,
+        prompt_type: Literal["encoder", "decoder"],
+    ):
+        tokenizer = self.tokenizer.get_lora_tokenizer(lora_request)
+
+        if prompt_type == "encoder":
+            model_config = self.model_config
+
+            if model_config.is_multimodal_model:
+                mm_registry = self.input_preprocessor.mm_registry
+                mm_processor = mm_registry.create_processor(
+                    model_config, tokenizer=tokenizer)
+                assert isinstance(mm_processor, EncDecMultiModalProcessor)
+
+                if mm_processor.pad_dummy_encoder_prompt:
+                    return  # Skip encoder length check for Whisper

         prompt_ids = prompt_inputs["prompt_token_ids"]

-        if prompt_ids is None or len(prompt_ids) == 0:
-            raise ValueError("Prompt cannot be empty")
+        if not prompt_ids:
+            raise ValueError(f"The {prompt_type} prompt cannot be empty")

         max_input_id = max(prompt_ids)
-        max_allowed = self.tokenizer.get_lora_tokenizer(
-            lora_request).max_token_id
-        if max_input_id > max_allowed:
-            raise ValueError(
-                "Token id {} is out of vocabulary".format(max_input_id))
+        if max_input_id > tokenizer.max_token_id:
+            raise ValueError(f"Token id {max_input_id} is out of vocabulary")

-        if len(prompt_ids) >= self.model_config.max_model_len:
-            raise ValueError(
-                f"Prompt length of {len(prompt_ids)} is longer than the "
-                f"maximum model length of {self.model_config.max_model_len}.")
-
-        if self.model_config.is_multimodal_model:
-            max_prompt_len = self.model_config.max_model_len
-
-            if len(prompt_ids) > max_prompt_len:
-                raise ValueError(
-                    f"The prompt (total length {len(prompt_ids)}) is too long "
-                    f"to fit into the model (context length {max_prompt_len}). "
+        max_prompt_len = self.model_config.max_model_len
+        if len(prompt_ids) >= max_prompt_len:
+            if self.model_config.is_multimodal_model:
+                suggestion = (
                     "Make sure that `max_model_len` is no smaller than the "
                     "number of text tokens plus multimodal tokens. For image "
                     "inputs, the number of image tokens depends on the number "
                     "of images, and possibly their aspect ratios as well.")
+            else:
+                suggestion = (
+                    "Make sure that `max_model_len` is no smaller than the "
+                    "number of text tokens.")
+
+            raise ValueError(
+                f"The {prompt_type} prompt (length {len(prompt_ids)}) is "
+                f"longer than the maximum model length of {max_prompt_len}. "
+                f"{suggestion}")

         # TODO: Find out how many placeholder tokens are there so we can
         # check that chunked prefill does not truncate them