From d9fc8cd9da4a69cb4171efb7cb5a46308680c83c Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Sat, 12 Apr 2025 16:52:39 +0800 Subject: [PATCH] [V1] Enable multi-input by default (#15799) Signed-off-by: DarkLight1337 --- docs/source/models/supported_models.md | 4 +- docs/source/serving/offline_inference.md | 24 ++++++ examples/offline_inference/audio_language.py | 5 ++ .../encoder_decoder_multimodal.py | 5 ++ examples/offline_inference/vision_language.py | 79 +++++++++++-------- .../vision_language_embedding.py | 7 ++ .../vision_language_multi_image.py | 5 ++ tests/entrypoints/openai/test_audio.py | 61 +++++++------- .../vision_language/vlm_utils/core.py | 4 + tests/models/test_oot_registration.py | 1 + tests/multimodal/test_processing.py | 7 +- vllm/config.py | 12 ++- vllm/engine/arg_utils.py | 6 +- vllm/entrypoints/chat_utils.py | 29 +++++-- vllm/model_executor/models/minicpmo.py | 2 +- vllm/model_executor/models/minicpmv.py | 4 +- vllm/model_executor/models/qwen2_vl.py | 4 +- vllm/multimodal/processing.py | 34 ++++++-- vllm/multimodal/profiling.py | 18 +---- vllm/multimodal/registry.py | 6 +- vllm/multimodal/utils.py | 2 +- 21 files changed, 214 insertions(+), 105 deletions(-) diff --git a/docs/source/models/supported_models.md b/docs/source/models/supported_models.md index c5029d85..ffedd5b0 100644 --- a/docs/source/models/supported_models.md +++ b/docs/source/models/supported_models.md @@ -759,7 +759,7 @@ On the other hand, modalities separated by `/` are mutually exclusive. See [this page](#multimodal-inputs) on how to pass multi-modal inputs to the model. :::{important} -To enable multiple multi-modal items per text prompt, you have to set `limit_mm_per_prompt` (offline inference) +**To enable multiple multi-modal items per text prompt in vLLM V0**, you have to set `limit_mm_per_prompt` (offline inference) or `--limit-mm-per-prompt` (online serving). For example, to enable passing up to 4 images per text prompt: Offline inference: @@ -777,6 +777,8 @@ Online serving: vllm serve Qwen/Qwen2-VL-7B-Instruct --limit-mm-per-prompt image=4 ``` +**This is no longer required if you are using vLLM V1.** + ::: :::{note} diff --git a/docs/source/serving/offline_inference.md b/docs/source/serving/offline_inference.md index 2fa19332..85f2cafa 100644 --- a/docs/source/serving/offline_inference.md +++ b/docs/source/serving/offline_inference.md @@ -110,6 +110,30 @@ If you run out of CPU RAM, try the following options: - (Multi-modal models only) you can set the size of multi-modal input cache using `VLLM_MM_INPUT_CACHE_GIB` environment variable (default 4 GiB). - (CPU backend only) you can set the size of KV cache using `VLLM_CPU_KVCACHE_SPACE` environment variable (default 4 GiB). +#### Disable unused modalities + +You can disable unused modalities (except for text) by setting its limit to zero. + +For example, if your application only accepts image input, there is no need to allocate any memory for videos. + +```python +from vllm import LLM + +# Accept images but not videos +llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct", + limit_mm_per_prompt={"video": 0}) +``` + +You can even run a multi-modal model for text-only inference: + +```python +from vllm import LLM + +# Don't accept images. Just text. +llm = LLM(model="google/gemma-3-27b-it", + limit_mm_per_prompt={"image": 0}) +``` + ### Performance optimization and tuning You can potentially improve the performance of vLLM by finetuning various options. 
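As an illustration of the documentation change above: with the V1 defaults introduced by this patch, multiple multi-modal items per prompt work without any extra flag, while V0 still requires the explicit limit. The sketch below is not part of the diff; the model name comes from the docs above, the image URLs are placeholders, and the `LLM.chat` usage follows vLLM's existing offline-inference docs rather than anything changed here.

```python
from vllm import LLM

# Illustrative sketch: pass two images in a single chat prompt.
# On V1 the default per-modality limit already allows this; on V0 the
# explicit limit_mm_per_prompt below is still required.
llm = LLM(
    model="Qwen/Qwen2-VL-7B-Instruct",
    limit_mm_per_prompt={"image": 4},  # optional on V1, required on V0
)

image_urls = [  # placeholder URLs
    "https://example.com/duck.jpg",
    "https://example.com/lion.jpg",
]

messages = [{
    "role": "user",
    "content": [
        *({"type": "image_url", "image_url": {"url": url}} for url in image_urls),
        {"type": "text", "text": "What are the animals in these images?"},
    ],
}]

outputs = llm.chat(messages)
print(outputs[0].outputs[0].text)
```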
diff --git a/examples/offline_inference/audio_language.py b/examples/offline_inference/audio_language.py index 9d758591..24809047 100644 --- a/examples/offline_inference/audio_language.py +++ b/examples/offline_inference/audio_language.py @@ -196,6 +196,11 @@ def main(args): req_data = model_example_map[model](question_per_audio_count[audio_count], audio_count) + # Disable other modalities to save memory + default_limits = {"image": 0, "video": 0, "audio": 0} + req_data.engine_args.limit_mm_per_prompt = default_limits | dict( + req_data.engine_args.limit_mm_per_prompt or {}) + engine_args = asdict(req_data.engine_args) | {"seed": args.seed} llm = LLM(**engine_args) diff --git a/examples/offline_inference/encoder_decoder_multimodal.py b/examples/offline_inference/encoder_decoder_multimodal.py index b2f2386d..456ee60e 100644 --- a/examples/offline_inference/encoder_decoder_multimodal.py +++ b/examples/offline_inference/encoder_decoder_multimodal.py @@ -133,6 +133,11 @@ def main(args): req_data = model_example_map[model]() + # Disable other modalities to save memory + default_limits = {"image": 0, "video": 0, "audio": 0} + req_data.engine_args.limit_mm_per_prompt = default_limits | dict( + req_data.engine_args.limit_mm_per_prompt or {}) + engine_args = asdict(req_data.engine_args) | {"seed": args.seed} llm = LLM(**engine_args) diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py index 7b587f29..c0799bde 100644 --- a/examples/offline_inference/vision_language.py +++ b/examples/offline_inference/vision_language.py @@ -45,7 +45,7 @@ def run_aria(questions: list[str], modality: str) -> ModelRequestData: max_model_len=4096, max_num_seqs=2, dtype="bfloat16", - disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache, + limit_mm_per_prompt={"image": 1}, ) prompts = [(f"<|im_start|>user\n<|img|>{question}" @@ -71,7 +71,7 @@ def run_aya_vision(questions: list[str], modality: str) -> ModelRequestData: max_model_len=2048, max_num_seqs=2, mm_processor_kwargs={"crop_to_patches": True}, - disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache, + limit_mm_per_prompt={"image": 1}, ) prompts = [ f"<|START_OF_TURN_TOKEN|><|USER_TOKEN|>{question}<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>" @@ -92,7 +92,7 @@ def run_blip2(questions: list[str], modality: str) -> ModelRequestData: prompts = [f"Question: {question} Answer:" for question in questions] engine_args = EngineArgs( model="Salesforce/blip2-opt-6.7b", - disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache, + limit_mm_per_prompt={"image": 1}, ) return ModelRequestData( @@ -110,7 +110,7 @@ def run_chameleon(questions: list[str], modality: str) -> ModelRequestData: model="facebook/chameleon-7b", max_model_len=4096, max_num_seqs=2, - disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache, + limit_mm_per_prompt={"image": 1}, ) return ModelRequestData( @@ -129,8 +129,8 @@ def run_deepseek_vl2(questions: list[str], modality: str) -> ModelRequestData: model=model_name, max_model_len=4096, max_num_seqs=2, - disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache, hf_overrides={"architectures": ["DeepseekVLV2ForCausalLM"]}, + limit_mm_per_prompt={"image": 1}, ) prompts = [ @@ -155,7 +155,7 @@ def run_florence2(questions: list[str], modality: str) -> ModelRequestData: max_num_seqs=2, trust_remote_code=True, dtype="bfloat16", - disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache, + limit_mm_per_prompt={"image": 1}, ) prompts = ["" for _ 
in questions] @@ -175,7 +175,7 @@ def run_fuyu(questions: list[str], modality: str) -> ModelRequestData: model="adept/fuyu-8b", max_model_len=2048, max_num_seqs=2, - disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache, + limit_mm_per_prompt={"image": 1}, ) return ModelRequestData( @@ -194,7 +194,7 @@ def run_gemma3(questions: list[str], modality: str) -> ModelRequestData: max_model_len=2048, max_num_seqs=2, mm_processor_kwargs={"do_pan_and_scan": True}, - disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache, + limit_mm_per_prompt={"image": 1}, ) prompts = [("user\n" @@ -219,7 +219,7 @@ def run_glm4v(questions: list[str], modality: str) -> ModelRequestData: trust_remote_code=True, enforce_eager=True, hf_overrides={"architectures": ["GLM4VForCausalLM"]}, - disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache, + limit_mm_per_prompt={"image": 1}, ) prompts = [ @@ -246,7 +246,7 @@ def run_h2ovl(questions: list[str], modality: str) -> ModelRequestData: model=model_name, trust_remote_code=True, max_model_len=8192, - disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache, + limit_mm_per_prompt={"image": 1}, ) tokenizer = AutoTokenizer.from_pretrained(model_name, @@ -287,7 +287,7 @@ def run_idefics3(questions: list[str], modality: str) -> ModelRequestData: "longest_edge": 3 * 364 }, }, - disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache, + limit_mm_per_prompt={"image": 1}, ) prompts = [( f"<|begin_of_text|>User:{question}\nAssistant:" @@ -314,7 +314,7 @@ def run_smolvlm(questions: list[str], modality: str) -> ModelRequestData: "longest_edge": 384 }, }, - disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache, + limit_mm_per_prompt={"image": 1}, ) prompts = [ (f"<|im_start|>User:{question}\nAssistant:") @@ -337,7 +337,7 @@ def run_internvl(questions: list[str], modality: str) -> ModelRequestData: model=model_name, trust_remote_code=True, max_model_len=4096, - disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache, + limit_mm_per_prompt={"image": 1}, ) tokenizer = AutoTokenizer.from_pretrained(model_name, @@ -375,7 +375,7 @@ def run_llava(questions: list[str], modality: str) -> ModelRequestData: engine_args = EngineArgs( model="llava-hf/llava-1.5-7b-hf", max_model_len=4096, - disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache, + limit_mm_per_prompt={"image": 1}, ) return ModelRequestData( @@ -392,7 +392,7 @@ def run_llava_next(questions: list[str], modality: str) -> ModelRequestData: engine_args = EngineArgs( model="llava-hf/llava-v1.6-mistral-7b-hf", max_model_len=8192, - disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache, + limit_mm_per_prompt={"image": 1}, ) return ModelRequestData( @@ -414,7 +414,7 @@ def run_llava_next_video(questions: list[str], model="llava-hf/LLaVA-NeXT-Video-7B-hf", max_model_len=8192, max_num_seqs=2, - disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache, + limit_mm_per_prompt={"image": 1}, ) return ModelRequestData( @@ -442,7 +442,7 @@ def run_llava_onevision(questions: list[str], engine_args = EngineArgs( model="llava-hf/llava-onevision-qwen2-7b-ov-hf", max_model_len=16384, - disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache, + limit_mm_per_prompt={"image": 1}, ) return ModelRequestData( @@ -465,7 +465,7 @@ def run_mantis(questions: list[str], modality: str) -> ModelRequestData: model="TIGER-Lab/Mantis-8B-siglip-llama3", max_model_len=4096, hf_overrides={"architectures": ["MantisForConditionalGeneration"]}, - 
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache, + limit_mm_per_prompt={"image": 1}, ) stop_token_ids = [128009] @@ -506,7 +506,7 @@ def run_minicpmv_base(questions: list[str], modality: str, model_name): max_model_len=4096, max_num_seqs=2, trust_remote_code=True, - disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache, + limit_mm_per_prompt={"image": 1}, ) # NOTE The stop_token_ids are different for various versions of MiniCPM-V # 2.0 @@ -561,7 +561,7 @@ def run_mistral3(questions: list[str], modality: str) -> ModelRequestData: max_model_len=8192, max_num_seqs=2, tensor_parallel_size=2, - disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache, + limit_mm_per_prompt={"image": 1}, ) prompts = [f"[INST]{question}\n[IMG][/INST]" for question in questions] @@ -587,7 +587,7 @@ def run_mllama(questions: list[str], modality: str) -> ModelRequestData: model=model_name, max_model_len=8192, max_num_seqs=2, - disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache, + limit_mm_per_prompt={"image": 1}, ) tokenizer = AutoTokenizer.from_pretrained(model_name) @@ -611,7 +611,7 @@ def run_mllama(questions: list[str], modality: str) -> ModelRequestData: ) -def run_llama4(questions: list[str], modality: str): +def run_llama4(questions: list[str], modality: str) -> ModelRequestData: assert modality == "image" model_name = "meta-llama/Llama-4-Scout-17B-16E-Instruct" @@ -621,8 +621,8 @@ def run_llama4(questions: list[str], modality: str): max_model_len=8192, max_num_seqs=4, tensor_parallel_size=8, - disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache, gpu_memory_utilization=0.4, + limit_mm_per_prompt={"image": 1}, ) tokenizer = AutoTokenizer.from_pretrained(model_name) @@ -657,7 +657,7 @@ def run_molmo(questions: list[str], modality: str) -> ModelRequestData: model=model_name, trust_remote_code=True, dtype="bfloat16", - disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache, + limit_mm_per_prompt={"image": 1}, ) prompts = [ @@ -683,7 +683,7 @@ def run_nvlm_d(questions: list[str], modality: str) -> ModelRequestData: trust_remote_code=True, max_model_len=4096, tensor_parallel_size=4, - disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache, + limit_mm_per_prompt={"image": 1}, ) tokenizer = AutoTokenizer.from_pretrained(model_name, @@ -710,7 +710,8 @@ def run_paligemma(questions: list[str], modality: str) -> ModelRequestData: prompts = ["caption en" for _ in questions] engine_args = EngineArgs( model="google/paligemma-3b-mix-224", - disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache) + limit_mm_per_prompt={"image": 1}, + ) return ModelRequestData( engine_args=engine_args, @@ -726,7 +727,8 @@ def run_paligemma2(questions: list[str], modality: str) -> ModelRequestData: prompts = ["caption en" for _ in questions] engine_args = EngineArgs( model="google/paligemma2-3b-ft-docci-448", - disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache) + limit_mm_per_prompt={"image": 1}, + ) return ModelRequestData( engine_args=engine_args, @@ -762,7 +764,7 @@ def run_phi3v(questions: list[str], modality: str) -> ModelRequestData: max_num_seqs=2, # Note - mm_processor_kwargs can also be passed to generate/chat calls mm_processor_kwargs={"num_crops": 16}, - disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache, + limit_mm_per_prompt={"image": 1}, ) return ModelRequestData( @@ -793,6 +795,7 @@ def run_phi4mm(questions: list[str], modality: str) -> ModelRequestData: max_num_seqs=2, enable_lora=True, max_lora_rank=320, + 
limit_mm_per_prompt={"image": 1}, ) return ModelRequestData( @@ -813,7 +816,7 @@ def run_pixtral_hf(questions: list[str], modality: str) -> ModelRequestData: model=model_name, max_model_len=6144, max_num_seqs=2, - disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache, + limit_mm_per_prompt={"image": 1}, ) prompts = [f"[INST]{question}\n[IMG][/INST]" for question in questions] @@ -834,7 +837,7 @@ def run_qwen_vl(questions: list[str], modality: str) -> ModelRequestData: max_model_len=1024, max_num_seqs=2, hf_overrides={"architectures": ["QwenVLForConditionalGeneration"]}, - disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache, + limit_mm_per_prompt={"image": 1}, ) prompts = [f"{question}Picture 1: \n" for question in questions] @@ -859,7 +862,7 @@ def run_qwen2_vl(questions: list[str], modality: str) -> ModelRequestData: "min_pixels": 28 * 28, "max_pixels": 1280 * 28 * 28, }, - disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache, + limit_mm_per_prompt={"image": 1}, ) if modality == "image": @@ -894,7 +897,7 @@ def run_qwen2_5_vl(questions: list[str], modality: str) -> ModelRequestData: "max_pixels": 1280 * 28 * 28, "fps": 1, }, - disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache, + limit_mm_per_prompt={"image": 1}, ) if modality == "image": @@ -925,7 +928,7 @@ def run_skyworkr1v(questions: list[str], modality: str) -> ModelRequestData: model=model_name, trust_remote_code=True, max_model_len=4096, - disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache, + limit_mm_per_prompt={"image": 1}, ) tokenizer = AutoTokenizer.from_pretrained(model_name, @@ -1082,7 +1085,15 @@ def main(args): req_data = model_example_map[model](questions, modality) - engine_args = asdict(req_data.engine_args) | {"seed": args.seed} + # Disable other modalities to save memory + default_limits = {"image": 0, "video": 0, "audio": 0} + req_data.engine_args.limit_mm_per_prompt = default_limits | dict( + req_data.engine_args.limit_mm_per_prompt or {}) + + engine_args = asdict(req_data.engine_args) | { + "seed": args.seed, + "disable_mm_preprocessor_cache": args.disable_mm_preprocessor_cache, + } llm = LLM(**engine_args) # To maintain code compatibility in this script, we add LoRA here. 
diff --git a/examples/offline_inference/vision_language_embedding.py b/examples/offline_inference/vision_language_embedding.py index 8321d3e2..ad3c5ae0 100644 --- a/examples/offline_inference/vision_language_embedding.py +++ b/examples/offline_inference/vision_language_embedding.py @@ -63,6 +63,7 @@ def run_e5_v(query: Query) -> ModelRequestData: model="royokong/e5-v", task="embed", max_model_len=4096, + limit_mm_per_prompt={"image": 1}, ) return ModelRequestData( @@ -93,6 +94,7 @@ def run_vlm2vec(query: Query) -> ModelRequestData: task="embed", trust_remote_code=True, mm_processor_kwargs={"num_crops": 4}, + limit_mm_per_prompt={"image": 1}, ) return ModelRequestData( @@ -131,6 +133,11 @@ def run_encode(model: str, modality: QueryModality, seed: Optional[int]): query = get_query(modality) req_data = model_example_map[model](query) + # Disable other modalities to save memory + default_limits = {"image": 0, "video": 0, "audio": 0} + req_data.engine_args.limit_mm_per_prompt = default_limits | dict( + req_data.engine_args.limit_mm_per_prompt or {}) + engine_args = asdict(req_data.engine_args) | {"seed": seed} llm = LLM(**engine_args) diff --git a/examples/offline_inference/vision_language_multi_image.py b/examples/offline_inference/vision_language_multi_image.py index 1ac141d8..7aff5fd0 100644 --- a/examples/offline_inference/vision_language_multi_image.py +++ b/examples/offline_inference/vision_language_multi_image.py @@ -687,6 +687,11 @@ def run_chat(model: str, question: str, image_urls: list[str], seed: Optional[int]): req_data = model_example_map[model](question, image_urls) + # Disable other modalities to save memory + default_limits = {"image": 0, "video": 0, "audio": 0} + req_data.engine_args.limit_mm_per_prompt = default_limits | dict( + req_data.engine_args.limit_mm_per_prompt or {}) + engine_args = asdict(req_data.engine_args) | {"seed": seed} llm = LLM(**engine_args) diff --git a/tests/entrypoints/openai/test_audio.py b/tests/entrypoints/openai/test_audio.py index 3267dcc1..b13002a5 100644 --- a/tests/entrypoints/openai/test_audio.py +++ b/tests/entrypoints/openai/test_audio.py @@ -12,7 +12,9 @@ from ...utils import RemoteOpenAIServer MODEL_NAME = "fixie-ai/ultravox-v0_5-llama-3_2-1b" TEST_AUDIO_URLS = [ AudioAsset("winning_call").url, + AudioAsset("mary_had_lamb").url, ] +MAXIMUM_AUDIOS = 2 @pytest.fixture(scope="module") @@ -24,6 +26,8 @@ def server(): "5", "--enforce-eager", "--trust-remote-code", + "--limit-mm-per-prompt", + f"audio={MAXIMUM_AUDIOS}", ] with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: @@ -46,7 +50,7 @@ def base64_encoded_audio() -> dict[str, str]: @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) -@pytest.mark.parametrize("audio_url", TEST_AUDIO_URLS) +@pytest.mark.parametrize("audio_url", [TEST_AUDIO_URLS[0]]) async def test_single_chat_session_audio(client: openai.AsyncOpenAI, model_name: str, audio_url: str): messages = [{ @@ -100,7 +104,7 @@ async def test_single_chat_session_audio(client: openai.AsyncOpenAI, @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) -@pytest.mark.parametrize("audio_url", TEST_AUDIO_URLS) +@pytest.mark.parametrize("audio_url", [TEST_AUDIO_URLS[0]]) async def test_single_chat_session_audio_base64encoded( client: openai.AsyncOpenAI, model_name: str, audio_url: str, base64_encoded_audio: dict[str, str]): @@ -158,7 +162,7 @@ async def test_single_chat_session_audio_base64encoded( @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) 
-@pytest.mark.parametrize("audio_url", TEST_AUDIO_URLS) +@pytest.mark.parametrize("audio_url", [TEST_AUDIO_URLS[0]]) async def test_single_chat_session_input_audio( client: openai.AsyncOpenAI, model_name: str, audio_url: str, base64_encoded_audio: dict[str, str]): @@ -330,28 +334,21 @@ async def test_chat_streaming_input_audio(client: openai.AsyncOpenAI, @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) -@pytest.mark.parametrize("audio_url", TEST_AUDIO_URLS) +@pytest.mark.parametrize( + "audio_urls", [TEST_AUDIO_URLS, TEST_AUDIO_URLS + [TEST_AUDIO_URLS[0]]]) async def test_multi_audio_input(client: openai.AsyncOpenAI, model_name: str, - audio_url: str, - base64_encoded_audio: dict[str, str]): + audio_urls: list[str]): messages = [{ "role": "user", "content": [ - { + *({ "type": "audio_url", "audio_url": { "url": audio_url } - }, - { - "type": "input_audio", - "input_audio": { - "data": base64_encoded_audio[audio_url], - "format": "wav" - } - }, + } for audio_url in audio_urls), { "type": "text", "text": "What's happening in this audio?" @@ -359,20 +356,30 @@ async def test_multi_audio_input(client: openai.AsyncOpenAI, model_name: str, ], }] - with pytest.raises(openai.BadRequestError): # test multi-audio input - await client.chat.completions.create( + if len(audio_urls) > MAXIMUM_AUDIOS: + with pytest.raises(openai.BadRequestError): # test multi-audio input + await client.chat.completions.create( + model=model_name, + messages=messages, + max_completion_tokens=10, + temperature=0.0, + ) + + # the server should still work afterwards + completion = await client.completions.create( + model=model_name, + prompt=[0, 0, 0, 0, 0], + max_tokens=5, + temperature=0.0, + ) + completion = completion.choices[0].text + assert completion is not None and len(completion) >= 0 + else: + chat_completion = await client.chat.completions.create( model=model_name, messages=messages, max_completion_tokens=10, temperature=0.0, ) - - # the server should still work afterwards - completion = await client.completions.create( - model=model_name, - prompt=[0, 0, 0, 0, 0], - max_tokens=5, - temperature=0.0, - ) - completion = completion.choices[0].text - assert completion is not None and len(completion) >= 0 + message = chat_completion.choices[0].message + assert message.content is not None and len(message.content) >= 0 diff --git a/tests/models/decoder_only/vision_language/vlm_utils/core.py b/tests/models/decoder_only/vision_language/vlm_utils/core.py index 2eae643f..fd046f3c 100644 --- a/tests/models/decoder_only/vision_language/vlm_utils/core.py +++ b/tests/models/decoder_only/vision_language/vlm_utils/core.py @@ -51,6 +51,10 @@ def run_test( model_info.check_available_online(on_fail="skip") model_info.check_transformers_version(on_fail="skip") + # Disable other modalities to save memory + default_limits = {"image": 0, "video": 0, "audio": 0} + limit_mm_per_prompt = default_limits | limit_mm_per_prompt + vllm_outputs_per_mm = [] hf_outputs_per_mm = [] diff --git a/tests/models/test_oot_registration.py b/tests/models/test_oot_registration.py index e6141b97..f1ed8a04 100644 --- a/tests/models/test_oot_registration.py +++ b/tests/models/test_oot_registration.py @@ -90,6 +90,7 @@ def test_oot_registration_multimodal( max_model_len=4096, enforce_eager=True, limit_mm_per_prompt={"image": 1}) + first_token = llm.get_tokenizer().decode(0) outputs = llm.generate(prompts, sampling_params) diff --git a/tests/multimodal/test_processing.py b/tests/multimodal/test_processing.py index fa9588a0..59f7bf8f 
100644 --- a/tests/multimodal/test_processing.py +++ b/tests/multimodal/test_processing.py @@ -972,10 +972,13 @@ def test_limit_mm_per_prompt_dummy(model_id, limit, num_supported, is_valid): if is_valid: exc_ctx = nullcontext() else: - exc_ctx = pytest.raises(ValueError, match="this model only supports") + exc_ctx = pytest.raises(ValueError, match="The model only supports") with exc_ctx: - profiler.get_decoder_dummy_data(model_config.max_model_len) + profiler.get_decoder_dummy_data( + model_config.max_model_len, + mm_counts=limit_mm_per_prompt, + ) @pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"]) diff --git a/vllm/config.py b/vllm/config.py index d3e224a6..2912361e 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2667,14 +2667,20 @@ class MultiModalConfig: usedforsecurity=False).hexdigest() return hash_str + def get_default_limit_per_prompt(self) -> int: + """ + Return the default number of input items allowed per prompt + for any modality if not specified by the user. + """ + return 999 if envs.VLLM_USE_V1 else 1 + def get_limit_per_prompt(self, modality: str) -> int: """ Get the maximum number of input items allowed per prompt for the given modality. - - If not set by the user, this defaults to `1`. """ - return self.limit_per_prompt.get(modality, 1) + default = self.get_default_limit_per_prompt() + return self.limit_per_prompt.get(modality, default) # TODO: Add configs to init vision tower or not. diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 70e628ed..975afe5a 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -671,13 +671,13 @@ class EngineArgs: type=nullable_kvs, default=EngineArgs.limit_mm_per_prompt, # The default value is given in - # MultiModalConfig.get_limit_per_prompt + # MultiModalConfig.get_default_limit_per_prompt help=('For each multimodal plugin, limit how many ' 'input instances to allow for each prompt. ' 'Expects a comma-separated list of items, ' 'e.g.: `image=16,video=2` allows a maximum of 16 ' - 'images and 2 videos per prompt. Defaults to 1 for ' - 'each modality.')) + 'images and 2 videos per prompt. 
Defaults to ' + '1 (V0) or 999 (V1) for each modality.')) parser.add_argument( '--mm-processor-kwargs', default=None, diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index 23c2c3cf..6fb7dc2c 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -35,7 +35,7 @@ from typing_extensions import Required, TypeAlias, TypedDict from vllm.config import ModelConfig from vllm.logger import init_logger -from vllm.multimodal import MultiModalDataDict +from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict from vllm.multimodal.utils import MediaConnector from vllm.transformers_utils.processor import cached_get_processor from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer @@ -452,8 +452,6 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]): self._model_config = model_config self._tokenizer = tokenizer - self._allowed_items = (model_config.multimodal_config.limit_per_prompt - if model_config.multimodal_config else {}) self._items_by_modality = defaultdict[str, list[_T]](list) @@ -465,6 +463,10 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]): def allowed_local_media_path(self): return self._model_config.allowed_local_media_path + @property + def mm_registry(self): + return MULTIMODAL_REGISTRY + @staticmethod @cache def _cached_token_str(tokenizer: AnyTokenizer, token_index: int) -> str: @@ -540,12 +542,29 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]): Add a multi-modal item to the current prompt and returns the placeholder string to use, if any. """ - allowed_count = self._allowed_items.get(modality, 1) + mm_registry = self.mm_registry + model_config = self.model_config + + input_modality = modality.replace("_embeds", "") + + if mm_registry.has_processor(model_config): + mm_processor = mm_registry.create_processor(model_config) + allowed_counts = mm_processor.info.get_allowed_mm_limits() + allowed_count = allowed_counts.get(input_modality, 0) + else: + mm_config = model_config.multimodal_config + if mm_config is None: + msg = "This model does not support multi-modal inputs" + raise ValueError(msg) + + allowed_count = mm_config.get_limit_per_prompt(input_modality) + current_count = len(self._items_by_modality[modality]) + 1 if current_count > allowed_count: raise ValueError( f"At most {allowed_count} {modality}(s) may be provided in " - "one request.") + "one request. 
You can set `--limit-mm-per-prompt` to " + "increase this limit if the model supports it.") self._items_by_modality[modality].append(item) diff --git a/vllm/model_executor/models/minicpmo.py b/vllm/model_executor/models/minicpmo.py index 29c3cc5e..a2ca92cd 100644 --- a/vllm/model_executor/models/minicpmo.py +++ b/vllm/model_executor/models/minicpmo.py @@ -126,7 +126,7 @@ class MiniCPMOMultiModalDataParser(MiniCPMVMultiModalDataParser): def _parse_audio_data( self, data: Union[dict[str, torch.Tensor], ModalityData[AudioItem]], - ) -> ModalityDataItems[Any, Any]: + ) -> Optional[ModalityDataItems[Any, Any]]: if isinstance(data, dict): return MiniCPMOAudioEmbeddingItems( data, diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index c504737e..1a91cf9b 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -290,7 +290,7 @@ class MiniCPMVMultiModalDataParser(MultiModalDataParser): def _parse_image_data( self, data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]], - ) -> ModalityDataItems[Any, Any]: + ) -> Optional[ModalityDataItems[Any, Any]]: if isinstance(data, dict): return MiniCPMVImageEmbeddingItems( data, @@ -302,7 +302,7 @@ class MiniCPMVMultiModalDataParser(MultiModalDataParser): def _parse_video_data( self, data: Union[dict[str, torch.Tensor], ModalityData[VideoItem]], - ) -> ModalityDataItems[Any, Any]: + ) -> Optional[ModalityDataItems[Any, Any]]: if isinstance(data, dict): return MiniCPMVVideoEmbeddingItems( data, diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index 11950f78..8c24b8f7 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -720,7 +720,7 @@ class Qwen2VLMultiModalDataParser(MultiModalDataParser): def _parse_image_data( self, data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]], - ) -> ModalityDataItems[Any, Any]: + ) -> Optional[ModalityDataItems[Any, Any]]: if isinstance(data, dict): return DictEmbeddingItems( data, @@ -734,7 +734,7 @@ class Qwen2VLMultiModalDataParser(MultiModalDataParser): def _parse_video_data( self, data: Union[dict[str, torch.Tensor], ModalityData[VideoItem]], - ) -> ModalityDataItems[Any, Any]: + ) -> Optional[ModalityDataItems[Any, Any]]: if isinstance(data, dict): return DictEmbeddingItems( data, diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py index f531314a..7f289426 100644 --- a/vllm/multimodal/processing.py +++ b/vllm/multimodal/processing.py @@ -1034,6 +1034,20 @@ class BaseProcessingInfo: """ raise NotImplementedError + def get_allowed_mm_limits(self) -> Mapping[str, int]: + """Return the maximum allowed number of items for each modality.""" + supported_mm_limits = self.get_supported_mm_limits() + mm_config = self.ctx.get_mm_config() + + allowed_limits = dict[str, int]() + for modality, supported_limit in supported_mm_limits.items(): + user_limit = mm_config.get_limit_per_prompt(modality) + + allowed_limits[modality] = (user_limit if supported_limit is None + else min(user_limit, supported_limit)) + + return allowed_limits + _I = TypeVar("_I", bound=BaseProcessingInfo) @@ -1087,14 +1101,24 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): before passing them to :meth:`_get_hf_mm_data`. 
""" mm_items = self.data_parser.parse_mm_data(mm_data) - mm_config = self.info.ctx.get_mm_config() + supported_mm_limits = self.info.get_supported_mm_limits() + allowed_mm_limits = self.info.get_allowed_mm_limits() for modality, items in mm_items.items(): - limit = mm_config.get_limit_per_prompt(modality) - if len(items) > limit: + supported_limit = supported_mm_limits.get(modality, 0) + allowed_limit = allowed_mm_limits.get(modality, 0) + num_items = len(items) + + if supported_limit is not None and num_items > supported_limit: raise ValueError( - f"You set {modality}={limit} (or defaulted to 1) in " - f"`--limit-mm-per-prompt`, but passed {len(items)} " + f"The model only supports at most {supported_limit} " + f"{modality} items, but you passed {num_items} " + f"{modality} items in the same prompt.") + + if num_items > allowed_limit: + raise ValueError( + f"You set or defaulted to {modality}={allowed_limit} " + f"in --limit-mm-per-prompt`, but passed {num_items} " f"{modality} items in the same prompt.") return mm_items diff --git a/vllm/multimodal/profiling.py b/vllm/multimodal/profiling.py index 29de9b7c..a173487c 100644 --- a/vllm/multimodal/profiling.py +++ b/vllm/multimodal/profiling.py @@ -162,23 +162,7 @@ class MultiModalProfiler(Generic[_I]): return self.processor.dummy_inputs def get_mm_limits(self) -> Mapping[str, int]: - mm_config = self.processing_info.ctx.get_mm_config() - supported_mm_limits = self.processing_info.get_supported_mm_limits() - - mm_limits = { - modality: mm_config.get_limit_per_prompt(modality) - for modality in supported_mm_limits - } - - for modality, supported_limit in supported_mm_limits.items(): - limit = mm_limits[modality] - if supported_limit is not None and supported_limit < limit: - raise ValueError( - f"You set {modality}={limit} (or defaulted to 1) in " - f"`--limit-mm-per-prompt`, but this model only supports " - f"at most {supported_limit} {modality} items.") - - return mm_limits + return self.processing_info.get_allowed_mm_limits() def _get_dummy_mm_inputs( self, diff --git a/vllm/multimodal/registry.py b/vllm/multimodal/registry.py index eafa28d6..def05950 100644 --- a/vllm/multimodal/registry.py +++ b/vllm/multimodal/registry.py @@ -265,8 +265,10 @@ class MultiModalRegistry: return profiler.get_mm_max_tokens( seq_len, - {modality: 1 - for modality in mm_limits}, + { + modality: 1 + for modality, limit in mm_limits.items() if limit > 0 + }, ) return { diff --git a/vllm/multimodal/utils.py b/vllm/multimodal/utils.py index 77c83f0c..3f9b5be2 100644 --- a/vllm/multimodal/utils.py +++ b/vllm/multimodal/utils.py @@ -264,7 +264,7 @@ fetch_video = global_media_connector.fetch_video def encode_audio_base64( audio: np.ndarray, - sampling_rate: int, + sampling_rate: float, ) -> str: """Encode audio as base64.""" audio_io = AudioMediaIO()