[V1] Enable multi-input by default (#15799)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung 2025-04-12 16:52:39 +08:00 committed by GitHub
parent f069f3ea74
commit d9fc8cd9da
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
21 changed files with 214 additions and 105 deletions

View File

@ -759,7 +759,7 @@ On the other hand, modalities separated by `/` are mutually exclusive.
See [this page](#multimodal-inputs) on how to pass multi-modal inputs to the model.
:::{important}
To enable multiple multi-modal items per text prompt, you have to set `limit_mm_per_prompt` (offline inference)
**To enable multiple multi-modal items per text prompt in vLLM V0**, you have to set `limit_mm_per_prompt` (offline inference)
or `--limit-mm-per-prompt` (online serving). For example, to enable passing up to 4 images per text prompt:
Offline inference:
@ -777,6 +777,8 @@ Online serving:
vllm serve Qwen/Qwen2-VL-7B-Instruct --limit-mm-per-prompt image=4
```
**This is no longer required if you are using vLLM V1.**
:::
:::{note}

View File

@ -110,6 +110,30 @@ If you run out of CPU RAM, try the following options:
- (Multi-modal models only) you can set the size of multi-modal input cache using `VLLM_MM_INPUT_CACHE_GIB` environment variable (default 4 GiB).
- (CPU backend only) you can set the size of KV cache using `VLLM_CPU_KVCACHE_SPACE` environment variable (default 4 GiB).
#### Disable unused modalities
You can disable unused modalities (except for text) by setting its limit to zero.
For example, if your application only accepts image input, there is no need to allocate any memory for videos.
```python
from vllm import LLM
# Accept images but not videos
llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct",
limit_mm_per_prompt={"video": 0})
```
You can even run a multi-modal model for text-only inference:
```python
from vllm import LLM
# Don't accept images. Just text.
llm = LLM(model="google/gemma-3-27b-it",
limit_mm_per_prompt={"image": 0})
```
### Performance optimization and tuning
You can potentially improve the performance of vLLM by finetuning various options.

View File

@ -196,6 +196,11 @@ def main(args):
req_data = model_example_map[model](question_per_audio_count[audio_count],
audio_count)
# Disable other modalities to save memory
default_limits = {"image": 0, "video": 0, "audio": 0}
req_data.engine_args.limit_mm_per_prompt = default_limits | dict(
req_data.engine_args.limit_mm_per_prompt or {})
engine_args = asdict(req_data.engine_args) | {"seed": args.seed}
llm = LLM(**engine_args)

View File

@ -133,6 +133,11 @@ def main(args):
req_data = model_example_map[model]()
# Disable other modalities to save memory
default_limits = {"image": 0, "video": 0, "audio": 0}
req_data.engine_args.limit_mm_per_prompt = default_limits | dict(
req_data.engine_args.limit_mm_per_prompt or {})
engine_args = asdict(req_data.engine_args) | {"seed": args.seed}
llm = LLM(**engine_args)

View File

@ -45,7 +45,7 @@ def run_aria(questions: list[str], modality: str) -> ModelRequestData:
max_model_len=4096,
max_num_seqs=2,
dtype="bfloat16",
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
limit_mm_per_prompt={"image": 1},
)
prompts = [(f"<|im_start|>user\n<fim_prefix><|img|><fim_suffix>{question}"
@ -71,7 +71,7 @@ def run_aya_vision(questions: list[str], modality: str) -> ModelRequestData:
max_model_len=2048,
max_num_seqs=2,
mm_processor_kwargs={"crop_to_patches": True},
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
limit_mm_per_prompt={"image": 1},
)
prompts = [
f"<|START_OF_TURN_TOKEN|><|USER_TOKEN|><image>{question}<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>"
@ -92,7 +92,7 @@ def run_blip2(questions: list[str], modality: str) -> ModelRequestData:
prompts = [f"Question: {question} Answer:" for question in questions]
engine_args = EngineArgs(
model="Salesforce/blip2-opt-6.7b",
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
limit_mm_per_prompt={"image": 1},
)
return ModelRequestData(
@ -110,7 +110,7 @@ def run_chameleon(questions: list[str], modality: str) -> ModelRequestData:
model="facebook/chameleon-7b",
max_model_len=4096,
max_num_seqs=2,
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
limit_mm_per_prompt={"image": 1},
)
return ModelRequestData(
@ -129,8 +129,8 @@ def run_deepseek_vl2(questions: list[str], modality: str) -> ModelRequestData:
model=model_name,
max_model_len=4096,
max_num_seqs=2,
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
hf_overrides={"architectures": ["DeepseekVLV2ForCausalLM"]},
limit_mm_per_prompt={"image": 1},
)
prompts = [
@ -155,7 +155,7 @@ def run_florence2(questions: list[str], modality: str) -> ModelRequestData:
max_num_seqs=2,
trust_remote_code=True,
dtype="bfloat16",
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
limit_mm_per_prompt={"image": 1},
)
prompts = ["<MORE_DETAILED_CAPTION>" for _ in questions]
@ -175,7 +175,7 @@ def run_fuyu(questions: list[str], modality: str) -> ModelRequestData:
model="adept/fuyu-8b",
max_model_len=2048,
max_num_seqs=2,
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
limit_mm_per_prompt={"image": 1},
)
return ModelRequestData(
@ -194,7 +194,7 @@ def run_gemma3(questions: list[str], modality: str) -> ModelRequestData:
max_model_len=2048,
max_num_seqs=2,
mm_processor_kwargs={"do_pan_and_scan": True},
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
limit_mm_per_prompt={"image": 1},
)
prompts = [("<bos><start_of_turn>user\n"
@ -219,7 +219,7 @@ def run_glm4v(questions: list[str], modality: str) -> ModelRequestData:
trust_remote_code=True,
enforce_eager=True,
hf_overrides={"architectures": ["GLM4VForCausalLM"]},
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
limit_mm_per_prompt={"image": 1},
)
prompts = [
@ -246,7 +246,7 @@ def run_h2ovl(questions: list[str], modality: str) -> ModelRequestData:
model=model_name,
trust_remote_code=True,
max_model_len=8192,
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
limit_mm_per_prompt={"image": 1},
)
tokenizer = AutoTokenizer.from_pretrained(model_name,
@ -287,7 +287,7 @@ def run_idefics3(questions: list[str], modality: str) -> ModelRequestData:
"longest_edge": 3 * 364
},
},
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
limit_mm_per_prompt={"image": 1},
)
prompts = [(
f"<|begin_of_text|>User:<image>{question}<end_of_utterance>\nAssistant:"
@ -314,7 +314,7 @@ def run_smolvlm(questions: list[str], modality: str) -> ModelRequestData:
"longest_edge": 384
},
},
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
limit_mm_per_prompt={"image": 1},
)
prompts = [
(f"<|im_start|>User:<image>{question}<end_of_utterance>\nAssistant:")
@ -337,7 +337,7 @@ def run_internvl(questions: list[str], modality: str) -> ModelRequestData:
model=model_name,
trust_remote_code=True,
max_model_len=4096,
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
limit_mm_per_prompt={"image": 1},
)
tokenizer = AutoTokenizer.from_pretrained(model_name,
@ -375,7 +375,7 @@ def run_llava(questions: list[str], modality: str) -> ModelRequestData:
engine_args = EngineArgs(
model="llava-hf/llava-1.5-7b-hf",
max_model_len=4096,
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
limit_mm_per_prompt={"image": 1},
)
return ModelRequestData(
@ -392,7 +392,7 @@ def run_llava_next(questions: list[str], modality: str) -> ModelRequestData:
engine_args = EngineArgs(
model="llava-hf/llava-v1.6-mistral-7b-hf",
max_model_len=8192,
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
limit_mm_per_prompt={"image": 1},
)
return ModelRequestData(
@ -414,7 +414,7 @@ def run_llava_next_video(questions: list[str],
model="llava-hf/LLaVA-NeXT-Video-7B-hf",
max_model_len=8192,
max_num_seqs=2,
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
limit_mm_per_prompt={"image": 1},
)
return ModelRequestData(
@ -442,7 +442,7 @@ def run_llava_onevision(questions: list[str],
engine_args = EngineArgs(
model="llava-hf/llava-onevision-qwen2-7b-ov-hf",
max_model_len=16384,
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
limit_mm_per_prompt={"image": 1},
)
return ModelRequestData(
@ -465,7 +465,7 @@ def run_mantis(questions: list[str], modality: str) -> ModelRequestData:
model="TIGER-Lab/Mantis-8B-siglip-llama3",
max_model_len=4096,
hf_overrides={"architectures": ["MantisForConditionalGeneration"]},
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
limit_mm_per_prompt={"image": 1},
)
stop_token_ids = [128009]
@ -506,7 +506,7 @@ def run_minicpmv_base(questions: list[str], modality: str, model_name):
max_model_len=4096,
max_num_seqs=2,
trust_remote_code=True,
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
limit_mm_per_prompt={"image": 1},
)
# NOTE The stop_token_ids are different for various versions of MiniCPM-V
# 2.0
@ -561,7 +561,7 @@ def run_mistral3(questions: list[str], modality: str) -> ModelRequestData:
max_model_len=8192,
max_num_seqs=2,
tensor_parallel_size=2,
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
limit_mm_per_prompt={"image": 1},
)
prompts = [f"<s>[INST]{question}\n[IMG][/INST]" for question in questions]
@ -587,7 +587,7 @@ def run_mllama(questions: list[str], modality: str) -> ModelRequestData:
model=model_name,
max_model_len=8192,
max_num_seqs=2,
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
limit_mm_per_prompt={"image": 1},
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
@ -611,7 +611,7 @@ def run_mllama(questions: list[str], modality: str) -> ModelRequestData:
)
def run_llama4(questions: list[str], modality: str):
def run_llama4(questions: list[str], modality: str) -> ModelRequestData:
assert modality == "image"
model_name = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
@ -621,8 +621,8 @@ def run_llama4(questions: list[str], modality: str):
max_model_len=8192,
max_num_seqs=4,
tensor_parallel_size=8,
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
gpu_memory_utilization=0.4,
limit_mm_per_prompt={"image": 1},
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
@ -657,7 +657,7 @@ def run_molmo(questions: list[str], modality: str) -> ModelRequestData:
model=model_name,
trust_remote_code=True,
dtype="bfloat16",
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
limit_mm_per_prompt={"image": 1},
)
prompts = [
@ -683,7 +683,7 @@ def run_nvlm_d(questions: list[str], modality: str) -> ModelRequestData:
trust_remote_code=True,
max_model_len=4096,
tensor_parallel_size=4,
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
limit_mm_per_prompt={"image": 1},
)
tokenizer = AutoTokenizer.from_pretrained(model_name,
@ -710,7 +710,8 @@ def run_paligemma(questions: list[str], modality: str) -> ModelRequestData:
prompts = ["caption en" for _ in questions]
engine_args = EngineArgs(
model="google/paligemma-3b-mix-224",
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
limit_mm_per_prompt={"image": 1},
)
return ModelRequestData(
engine_args=engine_args,
@ -726,7 +727,8 @@ def run_paligemma2(questions: list[str], modality: str) -> ModelRequestData:
prompts = ["caption en" for _ in questions]
engine_args = EngineArgs(
model="google/paligemma2-3b-ft-docci-448",
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
limit_mm_per_prompt={"image": 1},
)
return ModelRequestData(
engine_args=engine_args,
@ -762,7 +764,7 @@ def run_phi3v(questions: list[str], modality: str) -> ModelRequestData:
max_num_seqs=2,
# Note - mm_processor_kwargs can also be passed to generate/chat calls
mm_processor_kwargs={"num_crops": 16},
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
limit_mm_per_prompt={"image": 1},
)
return ModelRequestData(
@ -793,6 +795,7 @@ def run_phi4mm(questions: list[str], modality: str) -> ModelRequestData:
max_num_seqs=2,
enable_lora=True,
max_lora_rank=320,
limit_mm_per_prompt={"image": 1},
)
return ModelRequestData(
@ -813,7 +816,7 @@ def run_pixtral_hf(questions: list[str], modality: str) -> ModelRequestData:
model=model_name,
max_model_len=6144,
max_num_seqs=2,
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
limit_mm_per_prompt={"image": 1},
)
prompts = [f"<s>[INST]{question}\n[IMG][/INST]" for question in questions]
@ -834,7 +837,7 @@ def run_qwen_vl(questions: list[str], modality: str) -> ModelRequestData:
max_model_len=1024,
max_num_seqs=2,
hf_overrides={"architectures": ["QwenVLForConditionalGeneration"]},
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
limit_mm_per_prompt={"image": 1},
)
prompts = [f"{question}Picture 1: <img></img>\n" for question in questions]
@ -859,7 +862,7 @@ def run_qwen2_vl(questions: list[str], modality: str) -> ModelRequestData:
"min_pixels": 28 * 28,
"max_pixels": 1280 * 28 * 28,
},
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
limit_mm_per_prompt={"image": 1},
)
if modality == "image":
@ -894,7 +897,7 @@ def run_qwen2_5_vl(questions: list[str], modality: str) -> ModelRequestData:
"max_pixels": 1280 * 28 * 28,
"fps": 1,
},
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
limit_mm_per_prompt={"image": 1},
)
if modality == "image":
@ -925,7 +928,7 @@ def run_skyworkr1v(questions: list[str], modality: str) -> ModelRequestData:
model=model_name,
trust_remote_code=True,
max_model_len=4096,
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
limit_mm_per_prompt={"image": 1},
)
tokenizer = AutoTokenizer.from_pretrained(model_name,
@ -1082,7 +1085,15 @@ def main(args):
req_data = model_example_map[model](questions, modality)
engine_args = asdict(req_data.engine_args) | {"seed": args.seed}
# Disable other modalities to save memory
default_limits = {"image": 0, "video": 0, "audio": 0}
req_data.engine_args.limit_mm_per_prompt = default_limits | dict(
req_data.engine_args.limit_mm_per_prompt or {})
engine_args = asdict(req_data.engine_args) | {
"seed": args.seed,
"disable_mm_preprocessor_cache": args.disable_mm_preprocessor_cache,
}
llm = LLM(**engine_args)
# To maintain code compatibility in this script, we add LoRA here.

View File

@ -63,6 +63,7 @@ def run_e5_v(query: Query) -> ModelRequestData:
model="royokong/e5-v",
task="embed",
max_model_len=4096,
limit_mm_per_prompt={"image": 1},
)
return ModelRequestData(
@ -93,6 +94,7 @@ def run_vlm2vec(query: Query) -> ModelRequestData:
task="embed",
trust_remote_code=True,
mm_processor_kwargs={"num_crops": 4},
limit_mm_per_prompt={"image": 1},
)
return ModelRequestData(
@ -131,6 +133,11 @@ def run_encode(model: str, modality: QueryModality, seed: Optional[int]):
query = get_query(modality)
req_data = model_example_map[model](query)
# Disable other modalities to save memory
default_limits = {"image": 0, "video": 0, "audio": 0}
req_data.engine_args.limit_mm_per_prompt = default_limits | dict(
req_data.engine_args.limit_mm_per_prompt or {})
engine_args = asdict(req_data.engine_args) | {"seed": seed}
llm = LLM(**engine_args)

View File

@ -687,6 +687,11 @@ def run_chat(model: str, question: str, image_urls: list[str],
seed: Optional[int]):
req_data = model_example_map[model](question, image_urls)
# Disable other modalities to save memory
default_limits = {"image": 0, "video": 0, "audio": 0}
req_data.engine_args.limit_mm_per_prompt = default_limits | dict(
req_data.engine_args.limit_mm_per_prompt or {})
engine_args = asdict(req_data.engine_args) | {"seed": seed}
llm = LLM(**engine_args)

View File

@ -12,7 +12,9 @@ from ...utils import RemoteOpenAIServer
MODEL_NAME = "fixie-ai/ultravox-v0_5-llama-3_2-1b"
TEST_AUDIO_URLS = [
AudioAsset("winning_call").url,
AudioAsset("mary_had_lamb").url,
]
MAXIMUM_AUDIOS = 2
@pytest.fixture(scope="module")
@ -24,6 +26,8 @@ def server():
"5",
"--enforce-eager",
"--trust-remote-code",
"--limit-mm-per-prompt",
f"audio={MAXIMUM_AUDIOS}",
]
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
@ -46,7 +50,7 @@ def base64_encoded_audio() -> dict[str, str]:
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("audio_url", TEST_AUDIO_URLS)
@pytest.mark.parametrize("audio_url", [TEST_AUDIO_URLS[0]])
async def test_single_chat_session_audio(client: openai.AsyncOpenAI,
model_name: str, audio_url: str):
messages = [{
@ -100,7 +104,7 @@ async def test_single_chat_session_audio(client: openai.AsyncOpenAI,
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("audio_url", TEST_AUDIO_URLS)
@pytest.mark.parametrize("audio_url", [TEST_AUDIO_URLS[0]])
async def test_single_chat_session_audio_base64encoded(
client: openai.AsyncOpenAI, model_name: str, audio_url: str,
base64_encoded_audio: dict[str, str]):
@ -158,7 +162,7 @@ async def test_single_chat_session_audio_base64encoded(
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("audio_url", TEST_AUDIO_URLS)
@pytest.mark.parametrize("audio_url", [TEST_AUDIO_URLS[0]])
async def test_single_chat_session_input_audio(
client: openai.AsyncOpenAI, model_name: str, audio_url: str,
base64_encoded_audio: dict[str, str]):
@ -330,28 +334,21 @@ async def test_chat_streaming_input_audio(client: openai.AsyncOpenAI,
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("audio_url", TEST_AUDIO_URLS)
@pytest.mark.parametrize(
"audio_urls", [TEST_AUDIO_URLS, TEST_AUDIO_URLS + [TEST_AUDIO_URLS[0]]])
async def test_multi_audio_input(client: openai.AsyncOpenAI, model_name: str,
audio_url: str,
base64_encoded_audio: dict[str, str]):
audio_urls: list[str]):
messages = [{
"role":
"user",
"content": [
{
*({
"type": "audio_url",
"audio_url": {
"url": audio_url
}
},
{
"type": "input_audio",
"input_audio": {
"data": base64_encoded_audio[audio_url],
"format": "wav"
}
},
} for audio_url in audio_urls),
{
"type": "text",
"text": "What's happening in this audio?"
@ -359,20 +356,30 @@ async def test_multi_audio_input(client: openai.AsyncOpenAI, model_name: str,
],
}]
with pytest.raises(openai.BadRequestError): # test multi-audio input
await client.chat.completions.create(
if len(audio_urls) > MAXIMUM_AUDIOS:
with pytest.raises(openai.BadRequestError): # test multi-audio input
await client.chat.completions.create(
model=model_name,
messages=messages,
max_completion_tokens=10,
temperature=0.0,
)
# the server should still work afterwards
completion = await client.completions.create(
model=model_name,
prompt=[0, 0, 0, 0, 0],
max_tokens=5,
temperature=0.0,
)
completion = completion.choices[0].text
assert completion is not None and len(completion) >= 0
else:
chat_completion = await client.chat.completions.create(
model=model_name,
messages=messages,
max_completion_tokens=10,
temperature=0.0,
)
# the server should still work afterwards
completion = await client.completions.create(
model=model_name,
prompt=[0, 0, 0, 0, 0],
max_tokens=5,
temperature=0.0,
)
completion = completion.choices[0].text
assert completion is not None and len(completion) >= 0
message = chat_completion.choices[0].message
assert message.content is not None and len(message.content) >= 0

View File

@ -51,6 +51,10 @@ def run_test(
model_info.check_available_online(on_fail="skip")
model_info.check_transformers_version(on_fail="skip")
# Disable other modalities to save memory
default_limits = {"image": 0, "video": 0, "audio": 0}
limit_mm_per_prompt = default_limits | limit_mm_per_prompt
vllm_outputs_per_mm = []
hf_outputs_per_mm = []

View File

@ -90,6 +90,7 @@ def test_oot_registration_multimodal(
max_model_len=4096,
enforce_eager=True,
limit_mm_per_prompt={"image": 1})
first_token = llm.get_tokenizer().decode(0)
outputs = llm.generate(prompts, sampling_params)

View File

@ -972,10 +972,13 @@ def test_limit_mm_per_prompt_dummy(model_id, limit, num_supported, is_valid):
if is_valid:
exc_ctx = nullcontext()
else:
exc_ctx = pytest.raises(ValueError, match="this model only supports")
exc_ctx = pytest.raises(ValueError, match="The model only supports")
with exc_ctx:
profiler.get_decoder_dummy_data(model_config.max_model_len)
profiler.get_decoder_dummy_data(
model_config.max_model_len,
mm_counts=limit_mm_per_prompt,
)
@pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"])

View File

@ -2667,14 +2667,20 @@ class MultiModalConfig:
usedforsecurity=False).hexdigest()
return hash_str
def get_default_limit_per_prompt(self) -> int:
"""
Return the default number of input items allowed per prompt
for any modality if not specified by the user.
"""
return 999 if envs.VLLM_USE_V1 else 1
def get_limit_per_prompt(self, modality: str) -> int:
"""
Get the maximum number of input items allowed per prompt
for the given modality.
If not set by the user, this defaults to `1`.
"""
return self.limit_per_prompt.get(modality, 1)
default = self.get_default_limit_per_prompt()
return self.limit_per_prompt.get(modality, default)
# TODO: Add configs to init vision tower or not.

View File

@ -671,13 +671,13 @@ class EngineArgs:
type=nullable_kvs,
default=EngineArgs.limit_mm_per_prompt,
# The default value is given in
# MultiModalConfig.get_limit_per_prompt
# MultiModalConfig.get_default_limit_per_prompt
help=('For each multimodal plugin, limit how many '
'input instances to allow for each prompt. '
'Expects a comma-separated list of items, '
'e.g.: `image=16,video=2` allows a maximum of 16 '
'images and 2 videos per prompt. Defaults to 1 for '
'each modality.'))
'images and 2 videos per prompt. Defaults to '
'1 (V0) or 999 (V1) for each modality.'))
parser.add_argument(
'--mm-processor-kwargs',
default=None,

View File

@ -35,7 +35,7 @@ from typing_extensions import Required, TypeAlias, TypedDict
from vllm.config import ModelConfig
from vllm.logger import init_logger
from vllm.multimodal import MultiModalDataDict
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict
from vllm.multimodal.utils import MediaConnector
from vllm.transformers_utils.processor import cached_get_processor
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
@ -452,8 +452,6 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
self._model_config = model_config
self._tokenizer = tokenizer
self._allowed_items = (model_config.multimodal_config.limit_per_prompt
if model_config.multimodal_config else {})
self._items_by_modality = defaultdict[str, list[_T]](list)
@ -465,6 +463,10 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
def allowed_local_media_path(self):
return self._model_config.allowed_local_media_path
@property
def mm_registry(self):
return MULTIMODAL_REGISTRY
@staticmethod
@cache
def _cached_token_str(tokenizer: AnyTokenizer, token_index: int) -> str:
@ -540,12 +542,29 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
Add a multi-modal item to the current prompt and returns the
placeholder string to use, if any.
"""
allowed_count = self._allowed_items.get(modality, 1)
mm_registry = self.mm_registry
model_config = self.model_config
input_modality = modality.replace("_embeds", "")
if mm_registry.has_processor(model_config):
mm_processor = mm_registry.create_processor(model_config)
allowed_counts = mm_processor.info.get_allowed_mm_limits()
allowed_count = allowed_counts.get(input_modality, 0)
else:
mm_config = model_config.multimodal_config
if mm_config is None:
msg = "This model does not support multi-modal inputs"
raise ValueError(msg)
allowed_count = mm_config.get_limit_per_prompt(input_modality)
current_count = len(self._items_by_modality[modality]) + 1
if current_count > allowed_count:
raise ValueError(
f"At most {allowed_count} {modality}(s) may be provided in "
"one request.")
"one request. You can set `--limit-mm-per-prompt` to "
"increase this limit if the model supports it.")
self._items_by_modality[modality].append(item)

View File

@ -126,7 +126,7 @@ class MiniCPMOMultiModalDataParser(MiniCPMVMultiModalDataParser):
def _parse_audio_data(
self,
data: Union[dict[str, torch.Tensor], ModalityData[AudioItem]],
) -> ModalityDataItems[Any, Any]:
) -> Optional[ModalityDataItems[Any, Any]]:
if isinstance(data, dict):
return MiniCPMOAudioEmbeddingItems(
data,

View File

@ -290,7 +290,7 @@ class MiniCPMVMultiModalDataParser(MultiModalDataParser):
def _parse_image_data(
self,
data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]],
) -> ModalityDataItems[Any, Any]:
) -> Optional[ModalityDataItems[Any, Any]]:
if isinstance(data, dict):
return MiniCPMVImageEmbeddingItems(
data,
@ -302,7 +302,7 @@ class MiniCPMVMultiModalDataParser(MultiModalDataParser):
def _parse_video_data(
self,
data: Union[dict[str, torch.Tensor], ModalityData[VideoItem]],
) -> ModalityDataItems[Any, Any]:
) -> Optional[ModalityDataItems[Any, Any]]:
if isinstance(data, dict):
return MiniCPMVVideoEmbeddingItems(
data,

View File

@ -720,7 +720,7 @@ class Qwen2VLMultiModalDataParser(MultiModalDataParser):
def _parse_image_data(
self,
data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]],
) -> ModalityDataItems[Any, Any]:
) -> Optional[ModalityDataItems[Any, Any]]:
if isinstance(data, dict):
return DictEmbeddingItems(
data,
@ -734,7 +734,7 @@ class Qwen2VLMultiModalDataParser(MultiModalDataParser):
def _parse_video_data(
self,
data: Union[dict[str, torch.Tensor], ModalityData[VideoItem]],
) -> ModalityDataItems[Any, Any]:
) -> Optional[ModalityDataItems[Any, Any]]:
if isinstance(data, dict):
return DictEmbeddingItems(
data,

View File

@ -1034,6 +1034,20 @@ class BaseProcessingInfo:
"""
raise NotImplementedError
def get_allowed_mm_limits(self) -> Mapping[str, int]:
"""Return the maximum allowed number of items for each modality."""
supported_mm_limits = self.get_supported_mm_limits()
mm_config = self.ctx.get_mm_config()
allowed_limits = dict[str, int]()
for modality, supported_limit in supported_mm_limits.items():
user_limit = mm_config.get_limit_per_prompt(modality)
allowed_limits[modality] = (user_limit if supported_limit is None
else min(user_limit, supported_limit))
return allowed_limits
_I = TypeVar("_I", bound=BaseProcessingInfo)
@ -1087,14 +1101,24 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
before passing them to :meth:`_get_hf_mm_data`.
"""
mm_items = self.data_parser.parse_mm_data(mm_data)
mm_config = self.info.ctx.get_mm_config()
supported_mm_limits = self.info.get_supported_mm_limits()
allowed_mm_limits = self.info.get_allowed_mm_limits()
for modality, items in mm_items.items():
limit = mm_config.get_limit_per_prompt(modality)
if len(items) > limit:
supported_limit = supported_mm_limits.get(modality, 0)
allowed_limit = allowed_mm_limits.get(modality, 0)
num_items = len(items)
if supported_limit is not None and num_items > supported_limit:
raise ValueError(
f"You set {modality}={limit} (or defaulted to 1) in "
f"`--limit-mm-per-prompt`, but passed {len(items)} "
f"The model only supports at most {supported_limit} "
f"{modality} items, but you passed {num_items} "
f"{modality} items in the same prompt.")
if num_items > allowed_limit:
raise ValueError(
f"You set or defaulted to {modality}={allowed_limit} "
f"in --limit-mm-per-prompt`, but passed {num_items} "
f"{modality} items in the same prompt.")
return mm_items

View File

@ -162,23 +162,7 @@ class MultiModalProfiler(Generic[_I]):
return self.processor.dummy_inputs
def get_mm_limits(self) -> Mapping[str, int]:
mm_config = self.processing_info.ctx.get_mm_config()
supported_mm_limits = self.processing_info.get_supported_mm_limits()
mm_limits = {
modality: mm_config.get_limit_per_prompt(modality)
for modality in supported_mm_limits
}
for modality, supported_limit in supported_mm_limits.items():
limit = mm_limits[modality]
if supported_limit is not None and supported_limit < limit:
raise ValueError(
f"You set {modality}={limit} (or defaulted to 1) in "
f"`--limit-mm-per-prompt`, but this model only supports "
f"at most {supported_limit} {modality} items.")
return mm_limits
return self.processing_info.get_allowed_mm_limits()
def _get_dummy_mm_inputs(
self,

View File

@ -265,8 +265,10 @@ class MultiModalRegistry:
return profiler.get_mm_max_tokens(
seq_len,
{modality: 1
for modality in mm_limits},
{
modality: 1
for modality, limit in mm_limits.items() if limit > 0
},
)
return {

View File

@ -264,7 +264,7 @@ fetch_video = global_media_connector.fetch_video
def encode_audio_base64(
audio: np.ndarray,
sampling_rate: int,
sampling_rate: float,
) -> str:
"""Encode audio as base64."""
audio_io = AudioMediaIO()