Improve-mm-and-pooler-and-decoding-configs (#16789)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Authored by Harry Mellor on 2025-04-18 06:13:32 +01:00, committed by GitHub
parent 7eb4255628
commit e78587a64c
14 changed files with 84 additions and 78 deletions

View File

@@ -788,7 +788,7 @@ llm = LLM(
 Online serving:
 ```bash
-vllm serve Qwen/Qwen2-VL-7B-Instruct --limit-mm-per-prompt image=4
+vllm serve Qwen/Qwen2-VL-7B-Instruct --limit-mm-per-prompt '{"image":4}'
 ```
 **This is no longer required if you are using vLLM V1.**
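This change (and the similar doc edits below) switches `--limit-mm-per-prompt` from comma-separated `key=value` pairs to a JSON string. A minimal sketch of how the two forms map to the same dictionary, using plain Python and independent of vLLM's actual parser:

```python
import json

# New form: --limit-mm-per-prompt '{"image":4}'
limits = json.loads('{"image": 4}')
assert limits == {"image": 4}

# Old (deprecated) form: --limit-mm-per-prompt image=4
key, _, value = "image=4".partition("=")
assert {key: int(value)} == limits
```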

View File

@@ -228,7 +228,7 @@ First, launch the OpenAI-compatible server:
 ```bash
 vllm serve microsoft/Phi-3.5-vision-instruct --task generate \
-  --trust-remote-code --max-model-len 4096 --limit-mm-per-prompt image=2
+  --trust-remote-code --max-model-len 4096 --limit-mm-per-prompt '{"image":2}'
 ```
 Then, you can use the OpenAI client as follows:

View File

@ -16,11 +16,11 @@ from vllm.sampling_params import SamplingParams
# # Mistral format # # Mistral format
# vllm serve mistralai/Mistral-Small-3.1-24B-Instruct-2503 \ # vllm serve mistralai/Mistral-Small-3.1-24B-Instruct-2503 \
# --tokenizer-mode mistral --config-format mistral --load-format mistral \ # --tokenizer-mode mistral --config-format mistral --load-format mistral \
# --limit-mm-per-prompt 'image=4' --max-model-len 16384 # --limit-mm-per-prompt '{"image":4}' --max-model-len 16384
# #
# # HF format # # HF format
# vllm serve mistralai/Mistral-Small-3.1-24B-Instruct-2503 \ # vllm serve mistralai/Mistral-Small-3.1-24B-Instruct-2503 \
# --limit-mm-per-prompt 'image=4' --max-model-len 16384 # --limit-mm-per-prompt '{"image":4}' --max-model-len 16384
# ``` # ```
# #
# - Client: # - Client:

View File

@@ -9,7 +9,7 @@ vllm serve llava-hf/llava-1.5-7b-hf --chat-template template_llava.jinja
 (multi-image inference with Phi-3.5-vision-instruct)
 vllm serve microsoft/Phi-3.5-vision-instruct --task generate \
-  --trust-remote-code --max-model-len 4096 --limit-mm-per-prompt image=2
+  --trust-remote-code --max-model-len 4096 --limit-mm-per-prompt '{"image":2}'
 (audio inference with Ultravox)
 vllm serve fixie-ai/ultravox-v0_5-llama-3_2-1b --max-model-len 4096

View File

@@ -24,6 +24,10 @@ from vllm.utils import FlexibleArgumentParser
     }),
 ])
 def test_limit_mm_per_prompt_parser(arg, expected):
+    """This functionality is deprecated and will be removed in the future.
+    This argument should be passed as JSON string instead.
+
+    TODO: Remove with nullable_kvs."""
     parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
     if arg is None:
         args = parser.parse_args([])
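For context, a hypothetical sketch of the comma-separated `key=value` parsing that this deprecated test still exercises (not the actual `nullable_kvs` implementation):

```python
def parse_kvs(arg):
    """Parse 'image=16,video=2' into {'image': 16, 'video': 2}; None/'' -> None."""
    if not arg:
        return None
    out = {}
    for item in arg.split(","):
        key, _, value = item.partition("=")
        out[key.strip()] = int(value)
    return out

assert parse_kvs("image=16,video=2") == {"image": 16, "video": 2}
assert parse_kvs(None) is None
```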

View File

@ -27,7 +27,7 @@ def server():
"--enforce-eager", "--enforce-eager",
"--trust-remote-code", "--trust-remote-code",
"--limit-mm-per-prompt", "--limit-mm-per-prompt",
f"audio={MAXIMUM_AUDIOS}", str({"audio": MAXIMUM_AUDIOS}),
] ]
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:

View File

@@ -31,7 +31,7 @@ def server():
         "--enforce-eager",
         "--trust-remote-code",
         "--limit-mm-per-prompt",
-        f"video={MAXIMUM_VIDEOS}",
+        str({"video": MAXIMUM_VIDEOS}),
     ]

     with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:

View File

@@ -35,7 +35,7 @@ def server():
         "--enforce-eager",
         "--trust-remote-code",
         "--limit-mm-per-prompt",
-        f"image={MAXIMUM_IMAGES}",
+        str({"image": MAXIMUM_IMAGES}),
     ]

     with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:

View File

@@ -37,7 +37,7 @@ def server():
         "--enforce-eager",
         "--trust-remote-code",
         "--limit-mm-per-prompt",
-        f"image={MAXIMUM_IMAGES}",
+        str({"image": MAXIMUM_IMAGES}),
         "--chat-template",
         str(vlm2vec_jinja_path),
     ]

View File

@@ -48,9 +48,9 @@ def audio(request):
 ])
 def server(request, audio_assets):
     args = [
-        "--dtype=bfloat16", "--max-model-len=4096", "--enforce-eager",
-        f"--limit-mm-per-prompt=audio={len(audio_assets)}",
-        "--trust-remote-code"
+        "--dtype", "bfloat16", "--max-model-len", "4096", "--enforce-eager",
+        "--limit-mm-per-prompt",
+        str({"audio": len(audio_assets)}), "--trust-remote-code"
     ] + [
         f"--{key.replace('_','-')}={value}"
         for key, value in request.param.items()
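Several of the fixtures above build the CLI value with `str(...)`, which yields a Python-literal string with single quotes rather than strict JSON; presumably the argument parser tolerates both forms, but the difference is worth seeing. The constant below is only an example value:

```python
import json

num_audios = 2  # example value standing in for the test fixtures' counts

print(str({"audio": num_audios}))         # {'audio': 2}  (Python repr, single quotes)
print(json.dumps({"audio": num_audios}))  # {"audio": 2}  (strict JSON)
```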

View File

@@ -17,7 +17,7 @@ from dataclasses import (MISSING, dataclass, field, fields, is_dataclass,
 from importlib.util import find_spec
 from pathlib import Path
 from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Final, Literal,
-                    Optional, Protocol, TypeVar, Union)
+                    Optional, Protocol, TypeVar, Union, get_args)
 
 import torch
 from pydantic import BaseModel, Field, PrivateAttr
@@ -2725,6 +2725,7 @@ class PromptAdapterConfig:
                              self.prompt_adapter_dtype)
 
 
+@config
 @dataclass
 class MultiModalConfig:
     """Controls the behavior of multimodal models."""
@@ -2732,6 +2733,8 @@ class MultiModalConfig:
     limit_per_prompt: Mapping[str, int] = field(default_factory=dict)
     """
     The maximum number of input items allowed per prompt for each modality.
+    This should be a JSON string that will be parsed into a dictionary.
+    Defaults to 1 (V0) or 999 (V1) for each modality.
     """
 
     def compute_hash(self) -> str:
@@ -2753,24 +2756,20 @@ class MultiModalConfig:
                         usedforsecurity=False).hexdigest()
         return hash_str
 
-    def get_default_limit_per_prompt(self) -> int:
-        """
-        Return the default number of input items allowed per prompt
-        for any modality if not specified by the user.
-        """
-        return 999 if envs.VLLM_USE_V1 else 1
-
     def get_limit_per_prompt(self, modality: str) -> int:
         """
        Get the maximum number of input items allowed per prompt
        for the given modality.
        """
-        default = self.get_default_limit_per_prompt()
-        return self.limit_per_prompt.get(modality, default)
+        return self.limit_per_prompt.get(
+            modality,
+            999 if envs.VLLM_USE_V1 else 1,
+        )
 
 
 # TODO: Add configs to init vision tower or not.
+@config
 @dataclass
 class PoolerConfig:
     """Controls the behavior of output pooling in pooling models."""
@@ -3095,15 +3094,28 @@ def get_served_model_name(model: str,
     return served_model_name
 
 
+GuidedDecodingBackendV0 = Literal["auto", "outlines", "lm-format-enforcer",
+                                  "xgrammar"]
+GuidedDecodingBackendV1 = Literal["auto", "xgrammar", "guidance"]
+
+
+@config
 @dataclass
 class DecodingConfig:
-    """Dataclass which contains the decoding strategy of the engine"""
+    """Dataclass which contains the decoding strategy of the engine."""
 
-    # Which guided decoding algo to use.
-    # 'outlines' / 'lm-format-enforcer' / 'xgrammar'
-    guided_decoding_backend: str = "auto" if envs.VLLM_USE_V1 else "xgrammar"
+    guided_decoding_backend: Union[
+        GuidedDecodingBackendV0,
+        GuidedDecodingBackendV1] = "auto" if envs.VLLM_USE_V1 else "xgrammar"
+    """Which engine will be used for guided decoding (JSON schema / regex etc)
+    by default. With "auto", we will make opinionated choices based on request
+    contents and what the backend libraries currently support, so the behavior
+    is subject to change in each release."""
 
     reasoning_backend: Optional[str] = None
+    """Select the reasoning parser depending on the model that you're using.
+    This is used to parse the reasoning content into OpenAI API format.
+    Required for `--enable-reasoning`."""
 
     def compute_hash(self) -> str:
         """
@@ -3125,17 +3137,12 @@ class DecodingConfig:
         return hash_str
 
     def __post_init__(self):
-        v0_valid_guided_backends = [
-            'outlines', 'lm-format-enforcer', 'xgrammar', 'auto'
-        ]
-        v1_valid_guided_backends = ['xgrammar', 'guidance', 'auto']
-
         backend = GuidedDecodingParams(
             backend=self.guided_decoding_backend).backend_name
         if envs.VLLM_USE_V1:
-            valid_guided_backends = v1_valid_guided_backends
+            valid_guided_backends = get_args(GuidedDecodingBackendV1)
         else:
-            valid_guided_backends = v0_valid_guided_backends
+            valid_guided_backends = get_args(GuidedDecodingBackendV0)
         if backend not in valid_guided_backends:
             raise ValueError(f"Invalid guided_decoding_backend '{backend}',"
                              f" must be one of {valid_guided_backends}")

View File

@@ -20,11 +20,12 @@ from vllm.config import (CacheConfig, CompilationConfig, Config, ConfigFormat,
                          DecodingConfig, Device, DeviceConfig,
                          DistributedExecutorBackend, HfOverrides,
                          KVTransferConfig, LoadConfig, LoadFormat, LoRAConfig,
-                         ModelConfig, ModelImpl, ObservabilityConfig,
-                         ParallelConfig, PoolerConfig, PoolType,
-                         PromptAdapterConfig, SchedulerConfig, SchedulerPolicy,
-                         SpeculativeConfig, TaskOption, TokenizerPoolConfig,
-                         VllmConfig, get_attr_docs, get_field)
+                         ModelConfig, ModelImpl, MultiModalConfig,
+                         ObservabilityConfig, ParallelConfig, PoolerConfig,
+                         PoolType, PromptAdapterConfig, SchedulerConfig,
+                         SchedulerPolicy, SpeculativeConfig, TaskOption,
+                         TokenizerPoolConfig, VllmConfig, get_attr_docs,
+                         get_field)
 from vllm.executor.executor_base import ExecutorBase
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
@@ -190,7 +191,8 @@ class EngineArgs:
         TokenizerPoolConfig.pool_type
     tokenizer_pool_extra_config: dict[str, Any] = \
         get_field(TokenizerPoolConfig, "extra_config")
-    limit_mm_per_prompt: Optional[Mapping[str, int]] = None
+    limit_mm_per_prompt: Mapping[str, int] = \
+        get_field(MultiModalConfig, "limit_per_prompt")
     mm_processor_kwargs: Optional[Dict[str, Any]] = None
     disable_mm_preprocessor_cache: bool = False
     enable_lora: bool = False
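`limit_mm_per_prompt` now reuses the field definition from `MultiModalConfig` instead of defaulting to `None`. Because the default is a mutable dict, it has to come from a `default_factory`; the sketch below shows the general pattern with a hypothetical `clone_field` helper (the real `get_field` lives in `vllm/config.py` and may differ in detail):

```python
from dataclasses import dataclass, field, fields

@dataclass
class MultiModalConfigSketch:
    limit_per_prompt: dict = field(default_factory=dict)

def clone_field(cls, name):
    """Copy a field's default_factory so another dataclass can reuse it."""
    src = next(f for f in fields(cls) if f.name == name)
    return field(default_factory=src.default_factory)

@dataclass
class EngineArgsSketch:
    limit_mm_per_prompt: dict = clone_field(MultiModalConfigSketch,
                                            "limit_per_prompt")

a, b = EngineArgsSketch(), EngineArgsSketch()
a.limit_mm_per_prompt["image"] = 4
assert b.limit_mm_per_prompt == {}  # each instance gets its own dict
```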
@@ -252,7 +254,7 @@ class EngineArgs:
     additional_config: Optional[Dict[str, Any]] = None
     enable_reasoning: Optional[bool] = None
-    reasoning_parser: Optional[str] = None
+    reasoning_parser: Optional[str] = DecodingConfig.reasoning_backend
     use_tqdm_on_load: bool = LoadConfig.use_tqdm_on_load
 
     def __post_init__(self):
@@ -478,18 +480,22 @@ class EngineArgs:
             'Examples:\n'
             '- 1k → 1000\n'
             '- 1K → 1024\n')
-        parser.add_argument(
+
+        # Guided decoding arguments
+        guided_decoding_kwargs = get_kwargs(DecodingConfig)
+        guided_decoding_group = parser.add_argument_group(
+            title="DecodingConfig",
+            description=DecodingConfig.__doc__,
+        )
+        guided_decoding_group.add_argument(
             '--guided-decoding-backend',
-            type=str,
-            default=DecodingConfig.guided_decoding_backend,
-            help='Which engine will be used for guided decoding'
-            ' (JSON schema / regex etc) by default. Currently support '
-            'https://github.com/mlc-ai/xgrammar and '
-            'https://github.com/guidance-ai/llguidance.'
-            'Valid backend values are "xgrammar", "guidance", and "auto". '
-            'With "auto", we will make opinionated choices based on request '
-            'contents and what the backend libraries currently support, so '
-            'the behavior is subject to change in each release.')
+            **guided_decoding_kwargs["guided_decoding_backend"])
+        guided_decoding_group.add_argument(
+            "--reasoning-parser",
+            # This choices is a special case because it's not static
+            choices=list(ReasoningParserManager.reasoning_parsers),
+            **guided_decoding_kwargs["reasoning_backend"])
 
         parser.add_argument(
             '--logits-processor-pattern',
             type=optional_str,
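The refactor replaces hand-written `--guided-decoding-backend` and `--reasoning-parser` definitions with an argument group whose title, description, and defaults are derived from `DecodingConfig`. A rough, self-contained illustration of that pattern using only argparse (`get_kwargs` is vLLM's own helper and is not reproduced here):

```python
import argparse
from dataclasses import dataclass
from typing import Optional

@dataclass
class DecodingConfigSketch:
    """Dataclass which contains the decoding strategy of the engine."""
    guided_decoding_backend: str = "auto"
    reasoning_backend: Optional[str] = None

parser = argparse.ArgumentParser()
group = parser.add_argument_group(title="DecodingConfig",
                                  description=DecodingConfigSketch.__doc__)
group.add_argument("--guided-decoding-backend",
                   default=DecodingConfigSketch.guided_decoding_backend)
group.add_argument("--reasoning-parser",
                   dest="reasoning_backend",
                   default=DecodingConfigSketch.reasoning_backend)

args = parser.parse_args(["--guided-decoding-backend", "xgrammar"])
assert args.guided_decoding_backend == "xgrammar"
assert args.reasoning_backend is None
```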
@@ -697,18 +703,14 @@ class EngineArgs:
             **tokenizer_kwargs["extra_config"])
 
         # Multimodal related configs
-        parser.add_argument(
-            '--limit-mm-per-prompt',
-            type=nullable_kvs,
-            default=EngineArgs.limit_mm_per_prompt,
-            # The default value is given in
-            # MultiModalConfig.get_default_limit_per_prompt
-            help=('For each multimodal plugin, limit how many '
-                  'input instances to allow for each prompt. '
-                  'Expects a comma-separated list of items, '
-                  'e.g.: `image=16,video=2` allows a maximum of 16 '
-                  'images and 2 videos per prompt. Defaults to '
-                  '1 (V0) or 999 (V1) for each modality.'))
+        multimodal_kwargs = get_kwargs(MultiModalConfig)
+        multimodal_group = parser.add_argument_group(
+            title="MultiModalConfig",
+            description=MultiModalConfig.__doc__,
+        )
+        multimodal_group.add_argument('--limit-mm-per-prompt',
+                                      **multimodal_kwargs["limit_per_prompt"])
 
         parser.add_argument(
             '--mm-processor-kwargs',
             default=None,
@@ -1018,16 +1020,6 @@ class EngineArgs:
             "If enabled, the model will be able to generate reasoning content."
         )
-        parser.add_argument(
-            "--reasoning-parser",
-            type=str,
-            choices=list(ReasoningParserManager.reasoning_parsers),
-            default=None,
-            help=
-            "Select the reasoning parser depending on the model that you're "
-            "using. This is used to parse the reasoning content into OpenAI "
-            "API format. Required for ``--enable-reasoning``.")
         parser.add_argument(
             "--disable-cascade-attn",
             action="store_true",

View File

@@ -1,4 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
+import json
 import re
 import sys
 from abc import ABC, abstractmethod
@@ -1117,8 +1118,9 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
 
         if num_items > allowed_limit:
             raise ValueError(
-                f"You set or defaulted to {modality}={allowed_limit} "
-                f"in --limit-mm-per-prompt`, but passed {num_items} "
+                "You set or defaulted to "
+                f"'{json.dumps({modality: allowed_limit})}' in "
+                f"`--limit-mm-per-prompt`, but passed {num_items} "
                 f"{modality} items in the same prompt.")
 
         return mm_items
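With `json.dumps`, the error now quotes the limit in the same JSON shape the CLI flag accepts, so the message can be pasted straight back into `--limit-mm-per-prompt`. Reproducing the new f-string with example values only:

```python
import json

modality, allowed_limit, num_items = "image", 4, 6  # example values only

msg = ("You set or defaulted to "
       f"'{json.dumps({modality: allowed_limit})}' in "
       f"`--limit-mm-per-prompt`, but passed {num_items} "
       f"{modality} items in the same prompt.")
print(msg)
# You set or defaulted to '{"image": 4}' in `--limit-mm-per-prompt`, but passed 6 image items in the same prompt.
```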

View File

@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import functools
+import json
 from collections import UserDict
 from collections.abc import Mapping, Sequence
 from dataclasses import dataclass
@@ -194,9 +195,9 @@ class MultiModalRegistry:
             max_items = self._limits_by_model[model_config][data_key]
             if num_items > max_items:
                 raise ValueError(
-                    f"You set {data_key}={max_items} (or defaulted to 1) in "
-                    f"`--limit-mm-per-prompt`, but found {num_items} items "
-                    "in the same prompt.")
+                    f"You set '{json.dumps({data_key: max_items})}' (or "
+                    "defaulted to 1) in `--limit-mm-per-prompt`, but found "
+                    f"{num_items} items in the same prompt.")
 
             input_dict = plugin.map_input(model_config, data_value,
                                           mm_processor_kwargs)