Improve mm and pooler and decoding configs (#16789)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
parent 7eb4255628
commit e78587a64c
@@ -788,7 +788,7 @@ llm = LLM(
 Online serving:
 
 ```bash
-vllm serve Qwen/Qwen2-VL-7B-Instruct --limit-mm-per-prompt image=4
+vllm serve Qwen/Qwen2-VL-7B-Instruct --limit-mm-per-prompt '{"image":4}'
 ```
 
 **This is no longer required if you are using vLLM V1.**
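
Throughout this change, the value accepted by `--limit-mm-per-prompt` moves from the comma-separated `key=value` form to a JSON string. As a rough illustration (not the exact parsing code vLLM uses), the new value is simply a JSON object mapping each modality to its per-prompt limit:

```python
import json

# Illustration only: the new CLI value is a JSON string that decodes to a
# {modality: max_items} mapping.
assert json.loads('{"image":4}') == {"image": 4}

# Several modalities can be limited in a single value, e.g.:
assert json.loads('{"image": 16, "video": 2}') == {"image": 16, "video": 2}
```
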
@@ -228,7 +228,7 @@ First, launch the OpenAI-compatible server:
 
 ```bash
 vllm serve microsoft/Phi-3.5-vision-instruct --task generate \
-  --trust-remote-code --max-model-len 4096 --limit-mm-per-prompt image=2
+  --trust-remote-code --max-model-len 4096 --limit-mm-per-prompt '{"image":2}'
 ```
 
 Then, you can use the OpenAI client as follows:
@@ -16,11 +16,11 @@ from vllm.sampling_params import SamplingParams
 # # Mistral format
 # vllm serve mistralai/Mistral-Small-3.1-24B-Instruct-2503 \
 #   --tokenizer-mode mistral --config-format mistral --load-format mistral \
-#   --limit-mm-per-prompt 'image=4' --max-model-len 16384
+#   --limit-mm-per-prompt '{"image":4}' --max-model-len 16384
 #
 # # HF format
 # vllm serve mistralai/Mistral-Small-3.1-24B-Instruct-2503 \
-#   --limit-mm-per-prompt 'image=4' --max-model-len 16384
+#   --limit-mm-per-prompt '{"image":4}' --max-model-len 16384
 # ```
 #
 # - Client:
@@ -9,7 +9,7 @@ vllm serve llava-hf/llava-1.5-7b-hf --chat-template template_llava.jinja
 
 (multi-image inference with Phi-3.5-vision-instruct)
 vllm serve microsoft/Phi-3.5-vision-instruct --task generate \
-    --trust-remote-code --max-model-len 4096 --limit-mm-per-prompt image=2
+    --trust-remote-code --max-model-len 4096 --limit-mm-per-prompt '{"image":2}'
 
 (audio inference with Ultravox)
 vllm serve fixie-ai/ultravox-v0_5-llama-3_2-1b --max-model-len 4096
@@ -24,6 +24,10 @@ from vllm.utils import FlexibleArgumentParser
     }),
 ])
 def test_limit_mm_per_prompt_parser(arg, expected):
+    """This functionality is deprecated and will be removed in the future.
+    This argument should be passed as JSON string instead.
+
+    TODO: Remove with nullable_kvs."""
     parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
     if arg is None:
         args = parser.parse_args([])
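
For context, here is a minimal sketch of how the flag could be exercised end to end through the CLI parser built above. The parametrized cases of this test are not shown in the hunk, and whether the parsed value surfaces as a string or a mapping depends on vLLM's argument machinery, so treat this as an assumption rather than a transcript of the test:

```python
from vllm.engine.arg_utils import EngineArgs
from vllm.utils import FlexibleArgumentParser

# Sketch: feed the new JSON-style value through the generated CLI parser.
parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
args = parser.parse_args(["--limit-mm-per-prompt", '{"image":4}'])
# Expected to reflect {"image": 4}; the exact representation (str vs dict)
# is determined by vLLM's handling of dataclass-backed arguments.
print(args.limit_mm_per_prompt)
```
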
@@ -27,7 +27,7 @@ def server():
         "--enforce-eager",
         "--trust-remote-code",
         "--limit-mm-per-prompt",
-        f"audio={MAXIMUM_AUDIOS}",
+        str({"audio": MAXIMUM_AUDIOS}),
     ]
 
     with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
@@ -31,7 +31,7 @@ def server():
         "--enforce-eager",
         "--trust-remote-code",
         "--limit-mm-per-prompt",
-        f"video={MAXIMUM_VIDEOS}",
+        str({"video": MAXIMUM_VIDEOS}),
     ]
 
     with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
@@ -35,7 +35,7 @@ def server():
         "--enforce-eager",
         "--trust-remote-code",
         "--limit-mm-per-prompt",
-        f"image={MAXIMUM_IMAGES}",
+        str({"image": MAXIMUM_IMAGES}),
     ]
 
     with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
@@ -37,7 +37,7 @@ def server():
         "--enforce-eager",
         "--trust-remote-code",
         "--limit-mm-per-prompt",
-        f"image={MAXIMUM_IMAGES}",
+        str({"image": MAXIMUM_IMAGES}),
         "--chat-template",
         str(vlm2vec_jinja_path),
     ]
@@ -48,9 +48,9 @@ def audio(request):
 ])
 def server(request, audio_assets):
     args = [
-        "--dtype=bfloat16", "--max-model-len=4096", "--enforce-eager",
-        f"--limit-mm-per-prompt=audio={len(audio_assets)}",
-        "--trust-remote-code"
+        "--dtype", "bfloat16", "--max-model-len", "4096", "--enforce-eager",
+        "--limit-mm-per-prompt",
+        str({"audio": len(audio_assets)}), "--trust-remote-code"
     ] + [
         f"--{key.replace('_','-')}={value}"
         for key, value in request.param.items()
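
The updated fixtures build the CLI value with `str(...)` on a dict, which produces Python-repr output with single quotes rather than strict JSON; the error messages later in this diff use `json.dumps`, which produces the double-quoted form. Since the fixtures pass the repr-style string straight to the server, the new parser is evidently expected to tolerate both. A small illustration of the difference:

```python
import json

limit = {"audio": 2}
print(str(limit))         # {'audio': 2}   (Python repr, single quotes)
print(json.dumps(limit))  # {"audio": 2}   (strict JSON, double quotes)
```
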
@@ -17,7 +17,7 @@ from dataclasses import (MISSING, dataclass, field, fields, is_dataclass,
 from importlib.util import find_spec
 from pathlib import Path
 from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Final, Literal,
-                    Optional, Protocol, TypeVar, Union)
+                    Optional, Protocol, TypeVar, Union, get_args)
 
 import torch
 from pydantic import BaseModel, Field, PrivateAttr
@@ -2725,6 +2725,7 @@ class PromptAdapterConfig:
                          self.prompt_adapter_dtype)
 
 
+@config
 @dataclass
 class MultiModalConfig:
     """Controls the behavior of multimodal models."""
@@ -2732,6 +2733,8 @@ class MultiModalConfig:
     limit_per_prompt: Mapping[str, int] = field(default_factory=dict)
     """
     The maximum number of input items allowed per prompt for each modality.
+    This should be a JSON string that will be parsed into a dictionary.
+    Defaults to 1 (V0) or 999 (V1) for each modality.
     """
 
     def compute_hash(self) -> str:
@@ -2753,24 +2756,20 @@ class MultiModalConfig:
                               usedforsecurity=False).hexdigest()
         return hash_str
 
-    def get_default_limit_per_prompt(self) -> int:
-        """
-        Return the default number of input items allowed per prompt
-        for any modality if not specified by the user.
-        """
-        return 999 if envs.VLLM_USE_V1 else 1
-
     def get_limit_per_prompt(self, modality: str) -> int:
         """
         Get the maximum number of input items allowed per prompt
        for the given modality.
         """
-        default = self.get_default_limit_per_prompt()
-        return self.limit_per_prompt.get(modality, default)
+        return self.limit_per_prompt.get(
+            modality,
+            999 if envs.VLLM_USE_V1 else 1,
+        )
 
     # TODO: Add configs to init vision tower or not.
 
 
+@config
 @dataclass
 class PoolerConfig:
     """Controls the behavior of output pooling in pooling models."""
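
The collapsed `get_limit_per_prompt` is just a `dict.get` with an engine-dependent fallback. A standalone sketch of the resulting behavior, using a plain dict in place of the config field:

```python
def get_limit_per_prompt(limit_per_prompt: dict, modality: str,
                         use_v1: bool) -> int:
    # Mirrors the new method: an explicit user limit wins, otherwise the
    # default is 999 on V1 and 1 on V0.
    return limit_per_prompt.get(modality, 999 if use_v1 else 1)

assert get_limit_per_prompt({"image": 4}, "image", use_v1=True) == 4
assert get_limit_per_prompt({}, "video", use_v1=True) == 999
assert get_limit_per_prompt({}, "video", use_v1=False) == 1
```
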
@@ -3095,15 +3094,28 @@ def get_served_model_name(model: str,
     return served_model_name
 
 
+GuidedDecodingBackendV0 = Literal["auto", "outlines", "lm-format-enforcer",
+                                  "xgrammar"]
+GuidedDecodingBackendV1 = Literal["auto", "xgrammar", "guidance"]
+
+
+@config
 @dataclass
 class DecodingConfig:
-    """Dataclass which contains the decoding strategy of the engine"""
+    """Dataclass which contains the decoding strategy of the engine."""
 
-    # Which guided decoding algo to use.
-    # 'outlines' / 'lm-format-enforcer' / 'xgrammar'
-    guided_decoding_backend: str = "auto" if envs.VLLM_USE_V1 else "xgrammar"
+    guided_decoding_backend: Union[
+        GuidedDecodingBackendV0,
+        GuidedDecodingBackendV1] = "auto" if envs.VLLM_USE_V1 else "xgrammar"
+    """Which engine will be used for guided decoding (JSON schema / regex etc)
+    by default. With "auto", we will make opinionated choices based on request
+    contents and what the backend libraries currently support, so the behavior
+    is subject to change in each release."""
+
+    reasoning_backend: Optional[str] = None
+    """Select the reasoning parser depending on the model that you're using.
+    This is used to parse the reasoning content into OpenAI API format.
+    Required for `--enable-reasoning`."""
 
     def compute_hash(self) -> str:
         """
@@ -3125,17 +3137,12 @@ class DecodingConfig:
         return hash_str
 
     def __post_init__(self):
-        v0_valid_guided_backends = [
-            'outlines', 'lm-format-enforcer', 'xgrammar', 'auto'
-        ]
-        v1_valid_guided_backends = ['xgrammar', 'guidance', 'auto']
-
         backend = GuidedDecodingParams(
             backend=self.guided_decoding_backend).backend_name
         if envs.VLLM_USE_V1:
-            valid_guided_backends = v1_valid_guided_backends
+            valid_guided_backends = get_args(GuidedDecodingBackendV1)
         else:
-            valid_guided_backends = v0_valid_guided_backends
+            valid_guided_backends = get_args(GuidedDecodingBackendV0)
         if backend not in valid_guided_backends:
             raise ValueError(f"Invalid guided_decoding_backend '{backend}',"
                              f" must be one of {valid_guided_backends}")
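
`typing.get_args` recovers the allowed values directly from the `Literal` aliases defined above, so the validation list can no longer drift out of sync with the type annotation:

```python
from typing import Literal, get_args

GuidedDecodingBackendV1 = Literal["auto", "xgrammar", "guidance"]

# get_args returns the Literal's members as a tuple, in declaration order.
assert get_args(GuidedDecodingBackendV1) == ("auto", "xgrammar", "guidance")
```
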
@@ -20,11 +20,12 @@ from vllm.config import (CacheConfig, CompilationConfig, Config, ConfigFormat,
                          DecodingConfig, Device, DeviceConfig,
                          DistributedExecutorBackend, HfOverrides,
                          KVTransferConfig, LoadConfig, LoadFormat, LoRAConfig,
-                         ModelConfig, ModelImpl, ObservabilityConfig,
-                         ParallelConfig, PoolerConfig, PoolType,
-                         PromptAdapterConfig, SchedulerConfig, SchedulerPolicy,
-                         SpeculativeConfig, TaskOption, TokenizerPoolConfig,
-                         VllmConfig, get_attr_docs, get_field)
+                         ModelConfig, ModelImpl, MultiModalConfig,
+                         ObservabilityConfig, ParallelConfig, PoolerConfig,
+                         PoolType, PromptAdapterConfig, SchedulerConfig,
+                         SchedulerPolicy, SpeculativeConfig, TaskOption,
+                         TokenizerPoolConfig, VllmConfig, get_attr_docs,
+                         get_field)
 from vllm.executor.executor_base import ExecutorBase
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
@@ -190,7 +191,8 @@ class EngineArgs:
         TokenizerPoolConfig.pool_type
     tokenizer_pool_extra_config: dict[str, Any] = \
         get_field(TokenizerPoolConfig, "extra_config")
-    limit_mm_per_prompt: Optional[Mapping[str, int]] = None
+    limit_mm_per_prompt: Mapping[str, int] = \
+        get_field(MultiModalConfig, "limit_per_prompt")
     mm_processor_kwargs: Optional[Dict[str, Any]] = None
     disable_mm_preprocessor_cache: bool = False
     enable_lora: bool = False
@@ -252,7 +254,7 @@ class EngineArgs:
 
     additional_config: Optional[Dict[str, Any]] = None
     enable_reasoning: Optional[bool] = None
-    reasoning_parser: Optional[str] = None
+    reasoning_parser: Optional[str] = DecodingConfig.reasoning_backend
     use_tqdm_on_load: bool = LoadConfig.use_tqdm_on_load
 
     def __post_init__(self):
@@ -478,18 +480,22 @@ class EngineArgs:
             'Examples:\n'
             '- 1k → 1000\n'
             '- 1K → 1024\n')
-        parser.add_argument(
+
+        # Guided decoding arguments
+        guided_decoding_kwargs = get_kwargs(DecodingConfig)
+        guided_decoding_group = parser.add_argument_group(
+            title="DecodingConfig",
+            description=DecodingConfig.__doc__,
+        )
+        guided_decoding_group.add_argument(
             '--guided-decoding-backend',
-            type=str,
-            default=DecodingConfig.guided_decoding_backend,
-            help='Which engine will be used for guided decoding'
-            ' (JSON schema / regex etc) by default. Currently support '
-            'https://github.com/mlc-ai/xgrammar and '
-            'https://github.com/guidance-ai/llguidance.'
-            'Valid backend values are "xgrammar", "guidance", and "auto". '
-            'With "auto", we will make opinionated choices based on request '
-            'contents and what the backend libraries currently support, so '
-            'the behavior is subject to change in each release.')
+            **guided_decoding_kwargs["guided_decoding_backend"])
+        guided_decoding_group.add_argument(
+            "--reasoning-parser",
+            # This choices is a special case because it's not static
+            choices=list(ReasoningParserManager.reasoning_parsers),
+            **guided_decoding_kwargs["reasoning_backend"])
 
         parser.add_argument(
             '--logits-processor-pattern',
             type=optional_str,
@@ -697,18 +703,14 @@ class EngineArgs:
             **tokenizer_kwargs["extra_config"])
 
         # Multimodal related configs
-        parser.add_argument(
-            '--limit-mm-per-prompt',
-            type=nullable_kvs,
-            default=EngineArgs.limit_mm_per_prompt,
-            # The default value is given in
-            # MultiModalConfig.get_default_limit_per_prompt
-            help=('For each multimodal plugin, limit how many '
-                  'input instances to allow for each prompt. '
-                  'Expects a comma-separated list of items, '
-                  'e.g.: `image=16,video=2` allows a maximum of 16 '
-                  'images and 2 videos per prompt. Defaults to '
-                  '1 (V0) or 999 (V1) for each modality.'))
+        multimodal_kwargs = get_kwargs(MultiModalConfig)
+        multimodal_group = parser.add_argument_group(
+            title="MultiModalConfig",
+            description=MultiModalConfig.__doc__,
+        )
+        multimodal_group.add_argument('--limit-mm-per-prompt',
+                                      **multimodal_kwargs["limit_per_prompt"])
 
         parser.add_argument(
             '--mm-processor-kwargs',
             default=None,
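
Both `--guided-decoding-backend` and `--limit-mm-per-prompt` are now registered through `parser.add_argument_group`, with the `add_argument` keyword arguments derived from the corresponding config dataclass by `get_kwargs` (whose implementation is not part of this diff). A stdlib-only illustration of the grouping pattern, with placeholder kwargs rather than the derived ones:

```python
import argparse

parser = argparse.ArgumentParser()
# The group title/description are shown together in --help output.
group = parser.add_argument_group(
    title="MultiModalConfig",
    description="Controls the behavior of multimodal models.",
)
group.add_argument("--limit-mm-per-prompt", default="{}",
                   help="JSON mapping of modality to max items per prompt.")

args = parser.parse_args(["--limit-mm-per-prompt", '{"image":4}'])
print(args.limit_mm_per_prompt)  # '{"image":4}' (still a string in this sketch)
```
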
@@ -1018,16 +1020,6 @@ class EngineArgs:
             "If enabled, the model will be able to generate reasoning content."
         )
 
-        parser.add_argument(
-            "--reasoning-parser",
-            type=str,
-            choices=list(ReasoningParserManager.reasoning_parsers),
-            default=None,
-            help=
-            "Select the reasoning parser depending on the model that you're "
-            "using. This is used to parse the reasoning content into OpenAI "
-            "API format. Required for ``--enable-reasoning``.")
-
         parser.add_argument(
             "--disable-cascade-attn",
             action="store_true",
@@ -1,4 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
+import json
 import re
 import sys
 from abc import ABC, abstractmethod
@@ -1117,8 +1118,9 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
 
         if num_items > allowed_limit:
             raise ValueError(
-                f"You set or defaulted to {modality}={allowed_limit} "
-                f"in --limit-mm-per-prompt`, but passed {num_items} "
+                "You set or defaulted to "
+                f"'{json.dumps({modality: allowed_limit})}' in "
+                f"`--limit-mm-per-prompt`, but passed {num_items} "
                 f"{modality} items in the same prompt.")
 
         return mm_items
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import functools
+import json
 from collections import UserDict
 from collections.abc import Mapping, Sequence
 from dataclasses import dataclass
@@ -194,9 +195,9 @@ class MultiModalRegistry:
             max_items = self._limits_by_model[model_config][data_key]
             if num_items > max_items:
                 raise ValueError(
-                    f"You set {data_key}={max_items} (or defaulted to 1) in "
-                    f"`--limit-mm-per-prompt`, but found {num_items} items "
-                    "in the same prompt.")
+                    f"You set '{json.dumps({data_key: max_items})}' (or "
+                    "defaulted to 1) in `--limit-mm-per-prompt`, but found "
+                    f"{num_items} items in the same prompt.")
 
             input_dict = plugin.map_input(model_config, data_value,
                                           mm_processor_kwargs)
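
With `json.dumps`, the limit echoed back in these errors now matches the JSON syntax the flag expects. For example, exceeding a default image limit of 1 would produce:

```python
import json

modality, allowed_limit, num_items = "image", 1, 3
msg = ("You set or defaulted to "
       f"'{json.dumps({modality: allowed_limit})}' in "
       f"`--limit-mm-per-prompt`, but passed {num_items} "
       f"{modality} items in the same prompt.")
print(msg)
# You set or defaulted to '{"image": 1}' in `--limit-mm-per-prompt`, but
# passed 3 image items in the same prompt.
```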