[V1] Enable multi-input by default (#15799)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in: parent f069f3ea74 · commit d9fc8cd9da
@@ -759,7 +759,7 @@ On the other hand, modalities separated by `/` are mutually exclusive.
 See [this page](#multimodal-inputs) on how to pass multi-modal inputs to the model.
 
 :::{important}
-To enable multiple multi-modal items per text prompt, you have to set `limit_mm_per_prompt` (offline inference)
+**To enable multiple multi-modal items per text prompt in vLLM V0**, you have to set `limit_mm_per_prompt` (offline inference)
 or `--limit-mm-per-prompt` (online serving). For example, to enable passing up to 4 images per text prompt:
 
 Offline inference:
@@ -777,6 +777,8 @@ Online serving:
 vllm serve Qwen/Qwen2-VL-7B-Instruct --limit-mm-per-prompt image=4
 ```
 
+**This is no longer required if you are using vLLM V1.**
+
 :::
 
 :::{note}
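The offline-inference snippet referenced by this hunk falls outside the diff context. As a rough sketch of what the docs describe (the model name is reused from the serving example and 4 is the limit mentioned in the prose), it would look roughly like this:

```python
from vllm import LLM

# Allow up to 4 images per text prompt.
# On V0 this raises the default of 1; on V1 (default 999) it is only
# needed if you want to cap the number of images.
llm = LLM(
    model="Qwen/Qwen2-VL-7B-Instruct",
    limit_mm_per_prompt={"image": 4},
)
```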
@@ -110,6 +110,30 @@ If you run out of CPU RAM, try the following options:
 - (Multi-modal models only) you can set the size of multi-modal input cache using `VLLM_MM_INPUT_CACHE_GIB` environment variable (default 4 GiB).
 - (CPU backend only) you can set the size of KV cache using `VLLM_CPU_KVCACHE_SPACE` environment variable (default 4 GiB).
 
+#### Disable unused modalities
+
+You can disable unused modalities (except for text) by setting its limit to zero.
+
+For example, if your application only accepts image input, there is no need to allocate any memory for videos.
+
+```python
+from vllm import LLM
+
+# Accept images but not videos
+llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct",
+          limit_mm_per_prompt={"video": 0})
+```
+
+You can even run a multi-modal model for text-only inference:
+
+```python
+from vllm import LLM
+
+# Don't accept images. Just text.
+llm = LLM(model="google/gemma-3-27b-it",
+          limit_mm_per_prompt={"image": 0})
+```
+
 ### Performance optimization and tuning
 
 You can potentially improve the performance of vLLM by finetuning various options.
@@ -196,6 +196,11 @@ def main(args):
     req_data = model_example_map[model](question_per_audio_count[audio_count],
                                         audio_count)
 
+    # Disable other modalities to save memory
+    default_limits = {"image": 0, "video": 0, "audio": 0}
+    req_data.engine_args.limit_mm_per_prompt = default_limits | dict(
+        req_data.engine_args.limit_mm_per_prompt or {})
+
     engine_args = asdict(req_data.engine_args) | {"seed": args.seed}
     llm = LLM(**engine_args)
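The `default_limits | dict(...)` pattern added here relies on plain dict-union semantics: keys on the right-hand side win, so any limit the example already declares overrides the zeroed default. A small standalone illustration (ordinary Python, not vLLM-specific):

```python
default_limits = {"image": 0, "video": 0, "audio": 0}

# Simulate an example whose engine args only declare an audio limit.
user_limits = {"audio": 2}

# dict union: keys from the right-hand operand take precedence, so the
# explicit audio limit survives while image/video stay disabled at 0.
merged = default_limits | dict(user_limits or {})
assert merged == {"image": 0, "video": 0, "audio": 2}

# If the example declares nothing (None), every modality stays disabled.
assert (default_limits | dict(None or {})) == default_limits
```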
@@ -133,6 +133,11 @@ def main(args):
 
     req_data = model_example_map[model]()
 
+    # Disable other modalities to save memory
+    default_limits = {"image": 0, "video": 0, "audio": 0}
+    req_data.engine_args.limit_mm_per_prompt = default_limits | dict(
+        req_data.engine_args.limit_mm_per_prompt or {})
+
     engine_args = asdict(req_data.engine_args) | {"seed": args.seed}
     llm = LLM(**engine_args)
@@ -45,7 +45,7 @@ def run_aria(questions: list[str], modality: str) -> ModelRequestData:
         max_model_len=4096,
         max_num_seqs=2,
         dtype="bfloat16",
-        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
+        limit_mm_per_prompt={"image": 1},
     )
 
     prompts = [(f"<|im_start|>user\n<fim_prefix><|img|><fim_suffix>{question}"
@@ -71,7 +71,7 @@ def run_aya_vision(questions: list[str], modality: str) -> ModelRequestData:
         max_model_len=2048,
         max_num_seqs=2,
         mm_processor_kwargs={"crop_to_patches": True},
-        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
+        limit_mm_per_prompt={"image": 1},
     )
     prompts = [
         f"<|START_OF_TURN_TOKEN|><|USER_TOKEN|><image>{question}<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>"
@@ -92,7 +92,7 @@ def run_blip2(questions: list[str], modality: str) -> ModelRequestData:
     prompts = [f"Question: {question} Answer:" for question in questions]
     engine_args = EngineArgs(
         model="Salesforce/blip2-opt-6.7b",
-        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
+        limit_mm_per_prompt={"image": 1},
     )
 
     return ModelRequestData(
@@ -110,7 +110,7 @@ def run_chameleon(questions: list[str], modality: str) -> ModelRequestData:
         model="facebook/chameleon-7b",
         max_model_len=4096,
         max_num_seqs=2,
-        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
+        limit_mm_per_prompt={"image": 1},
     )
 
     return ModelRequestData(
@@ -129,8 +129,8 @@ def run_deepseek_vl2(questions: list[str], modality: str) -> ModelRequestData:
         model=model_name,
         max_model_len=4096,
         max_num_seqs=2,
-        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
         hf_overrides={"architectures": ["DeepseekVLV2ForCausalLM"]},
+        limit_mm_per_prompt={"image": 1},
     )
 
     prompts = [
@@ -155,7 +155,7 @@ def run_florence2(questions: list[str], modality: str) -> ModelRequestData:
         max_num_seqs=2,
         trust_remote_code=True,
         dtype="bfloat16",
-        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
+        limit_mm_per_prompt={"image": 1},
     )
 
     prompts = ["<MORE_DETAILED_CAPTION>" for _ in questions]
@@ -175,7 +175,7 @@ def run_fuyu(questions: list[str], modality: str) -> ModelRequestData:
         model="adept/fuyu-8b",
         max_model_len=2048,
         max_num_seqs=2,
-        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
+        limit_mm_per_prompt={"image": 1},
     )
 
     return ModelRequestData(
@@ -194,7 +194,7 @@ def run_gemma3(questions: list[str], modality: str) -> ModelRequestData:
         max_model_len=2048,
         max_num_seqs=2,
         mm_processor_kwargs={"do_pan_and_scan": True},
-        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
+        limit_mm_per_prompt={"image": 1},
     )
 
     prompts = [("<bos><start_of_turn>user\n"
@@ -219,7 +219,7 @@ def run_glm4v(questions: list[str], modality: str) -> ModelRequestData:
         trust_remote_code=True,
         enforce_eager=True,
         hf_overrides={"architectures": ["GLM4VForCausalLM"]},
-        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
+        limit_mm_per_prompt={"image": 1},
     )
 
     prompts = [
@@ -246,7 +246,7 @@ def run_h2ovl(questions: list[str], modality: str) -> ModelRequestData:
         model=model_name,
         trust_remote_code=True,
         max_model_len=8192,
-        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
+        limit_mm_per_prompt={"image": 1},
     )
 
     tokenizer = AutoTokenizer.from_pretrained(model_name,
@@ -287,7 +287,7 @@ def run_idefics3(questions: list[str], modality: str) -> ModelRequestData:
                 "longest_edge": 3 * 364
             },
         },
-        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
+        limit_mm_per_prompt={"image": 1},
     )
     prompts = [(
         f"<|begin_of_text|>User:<image>{question}<end_of_utterance>\nAssistant:"
@@ -314,7 +314,7 @@ def run_smolvlm(questions: list[str], modality: str) -> ModelRequestData:
                 "longest_edge": 384
             },
         },
-        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
+        limit_mm_per_prompt={"image": 1},
     )
     prompts = [
         (f"<|im_start|>User:<image>{question}<end_of_utterance>\nAssistant:")
@@ -337,7 +337,7 @@ def run_internvl(questions: list[str], modality: str) -> ModelRequestData:
         model=model_name,
         trust_remote_code=True,
         max_model_len=4096,
-        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
+        limit_mm_per_prompt={"image": 1},
    )
 
     tokenizer = AutoTokenizer.from_pretrained(model_name,
@@ -375,7 +375,7 @@ def run_llava(questions: list[str], modality: str) -> ModelRequestData:
     engine_args = EngineArgs(
         model="llava-hf/llava-1.5-7b-hf",
         max_model_len=4096,
-        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
+        limit_mm_per_prompt={"image": 1},
     )
 
     return ModelRequestData(
@@ -392,7 +392,7 @@ def run_llava_next(questions: list[str], modality: str) -> ModelRequestData:
     engine_args = EngineArgs(
         model="llava-hf/llava-v1.6-mistral-7b-hf",
         max_model_len=8192,
-        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
+        limit_mm_per_prompt={"image": 1},
     )
 
     return ModelRequestData(
@@ -414,7 +414,7 @@ def run_llava_next_video(questions: list[str],
         model="llava-hf/LLaVA-NeXT-Video-7B-hf",
         max_model_len=8192,
         max_num_seqs=2,
-        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
+        limit_mm_per_prompt={"image": 1},
     )
 
     return ModelRequestData(
@@ -442,7 +442,7 @@ def run_llava_onevision(questions: list[str],
     engine_args = EngineArgs(
         model="llava-hf/llava-onevision-qwen2-7b-ov-hf",
         max_model_len=16384,
-        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
+        limit_mm_per_prompt={"image": 1},
     )
 
     return ModelRequestData(
@@ -465,7 +465,7 @@ def run_mantis(questions: list[str], modality: str) -> ModelRequestData:
         model="TIGER-Lab/Mantis-8B-siglip-llama3",
         max_model_len=4096,
         hf_overrides={"architectures": ["MantisForConditionalGeneration"]},
-        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
+        limit_mm_per_prompt={"image": 1},
     )
     stop_token_ids = [128009]
 
@@ -506,7 +506,7 @@ def run_minicpmv_base(questions: list[str], modality: str, model_name):
         max_model_len=4096,
         max_num_seqs=2,
         trust_remote_code=True,
-        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
+        limit_mm_per_prompt={"image": 1},
     )
     # NOTE The stop_token_ids are different for various versions of MiniCPM-V
     # 2.0
@@ -561,7 +561,7 @@ def run_mistral3(questions: list[str], modality: str) -> ModelRequestData:
         max_model_len=8192,
         max_num_seqs=2,
         tensor_parallel_size=2,
-        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
+        limit_mm_per_prompt={"image": 1},
     )
 
     prompts = [f"<s>[INST]{question}\n[IMG][/INST]" for question in questions]
@@ -587,7 +587,7 @@ def run_mllama(questions: list[str], modality: str) -> ModelRequestData:
         model=model_name,
         max_model_len=8192,
         max_num_seqs=2,
-        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
+        limit_mm_per_prompt={"image": 1},
     )
 
     tokenizer = AutoTokenizer.from_pretrained(model_name)
@@ -611,7 +611,7 @@ def run_mllama(questions: list[str], modality: str) -> ModelRequestData:
     )
 
 
-def run_llama4(questions: list[str], modality: str):
+def run_llama4(questions: list[str], modality: str) -> ModelRequestData:
     assert modality == "image"
 
     model_name = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
@@ -621,8 +621,8 @@ def run_llama4(questions: list[str], modality: str):
         max_model_len=8192,
         max_num_seqs=4,
         tensor_parallel_size=8,
-        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
         gpu_memory_utilization=0.4,
+        limit_mm_per_prompt={"image": 1},
     )
 
     tokenizer = AutoTokenizer.from_pretrained(model_name)
@@ -657,7 +657,7 @@ def run_molmo(questions: list[str], modality: str) -> ModelRequestData:
         model=model_name,
         trust_remote_code=True,
         dtype="bfloat16",
-        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
+        limit_mm_per_prompt={"image": 1},
     )
 
     prompts = [
@@ -683,7 +683,7 @@ def run_nvlm_d(questions: list[str], modality: str) -> ModelRequestData:
         trust_remote_code=True,
         max_model_len=4096,
         tensor_parallel_size=4,
-        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
+        limit_mm_per_prompt={"image": 1},
     )
 
     tokenizer = AutoTokenizer.from_pretrained(model_name,
@@ -710,7 +710,8 @@ def run_paligemma(questions: list[str], modality: str) -> ModelRequestData:
     prompts = ["caption en" for _ in questions]
     engine_args = EngineArgs(
         model="google/paligemma-3b-mix-224",
-        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
+        limit_mm_per_prompt={"image": 1},
+    )
 
     return ModelRequestData(
         engine_args=engine_args,
@@ -726,7 +727,8 @@ def run_paligemma2(questions: list[str], modality: str) -> ModelRequestData:
     prompts = ["caption en" for _ in questions]
     engine_args = EngineArgs(
         model="google/paligemma2-3b-ft-docci-448",
-        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
+        limit_mm_per_prompt={"image": 1},
+    )
 
     return ModelRequestData(
         engine_args=engine_args,
@@ -762,7 +764,7 @@ def run_phi3v(questions: list[str], modality: str) -> ModelRequestData:
         max_num_seqs=2,
         # Note - mm_processor_kwargs can also be passed to generate/chat calls
         mm_processor_kwargs={"num_crops": 16},
-        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
+        limit_mm_per_prompt={"image": 1},
     )
 
     return ModelRequestData(
@@ -793,6 +795,7 @@ def run_phi4mm(questions: list[str], modality: str) -> ModelRequestData:
         max_num_seqs=2,
         enable_lora=True,
         max_lora_rank=320,
+        limit_mm_per_prompt={"image": 1},
     )
 
     return ModelRequestData(
@@ -813,7 +816,7 @@ def run_pixtral_hf(questions: list[str], modality: str) -> ModelRequestData:
         model=model_name,
         max_model_len=6144,
         max_num_seqs=2,
-        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
+        limit_mm_per_prompt={"image": 1},
     )
 
     prompts = [f"<s>[INST]{question}\n[IMG][/INST]" for question in questions]
@@ -834,7 +837,7 @@ def run_qwen_vl(questions: list[str], modality: str) -> ModelRequestData:
         max_model_len=1024,
         max_num_seqs=2,
         hf_overrides={"architectures": ["QwenVLForConditionalGeneration"]},
-        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
+        limit_mm_per_prompt={"image": 1},
     )
 
     prompts = [f"{question}Picture 1: <img></img>\n" for question in questions]
@@ -859,7 +862,7 @@ def run_qwen2_vl(questions: list[str], modality: str) -> ModelRequestData:
             "min_pixels": 28 * 28,
            "max_pixels": 1280 * 28 * 28,
         },
-        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
+        limit_mm_per_prompt={"image": 1},
     )
 
     if modality == "image":
@@ -894,7 +897,7 @@ def run_qwen2_5_vl(questions: list[str], modality: str) -> ModelRequestData:
             "max_pixels": 1280 * 28 * 28,
             "fps": 1,
         },
-        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
+        limit_mm_per_prompt={"image": 1},
     )
 
     if modality == "image":
@@ -925,7 +928,7 @@ def run_skyworkr1v(questions: list[str], modality: str) -> ModelRequestData:
         model=model_name,
         trust_remote_code=True,
         max_model_len=4096,
-        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
+        limit_mm_per_prompt={"image": 1},
     )
 
     tokenizer = AutoTokenizer.from_pretrained(model_name,
@@ -1082,7 +1085,15 @@ def main(args):
 
     req_data = model_example_map[model](questions, modality)
 
-    engine_args = asdict(req_data.engine_args) | {"seed": args.seed}
+    # Disable other modalities to save memory
+    default_limits = {"image": 0, "video": 0, "audio": 0}
+    req_data.engine_args.limit_mm_per_prompt = default_limits | dict(
+        req_data.engine_args.limit_mm_per_prompt or {})
+
+    engine_args = asdict(req_data.engine_args) | {
+        "seed": args.seed,
+        "disable_mm_preprocessor_cache": args.disable_mm_preprocessor_cache,
+    }
     llm = LLM(**engine_args)
 
     # To maintain code compatibility in this script, we add LoRA here.
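This hunk explains why the per-function `disable_mm_preprocessor_cache=...` kwargs could be dropped above: `main()` now injects that flag once through the `asdict(...) | {...}` merge. A self-contained sketch of that pattern (the dataclass below is a stand-in for illustration, not vLLM's real `EngineArgs`):

```python
from dataclasses import asdict, dataclass
from typing import Optional


@dataclass
class EngineArgsSketch:  # hypothetical stand-in for EngineArgs
    model: str
    seed: Optional[int] = None
    disable_mm_preprocessor_cache: bool = False


per_model = EngineArgsSketch(model="some/model")

# asdict() turns the dataclass into a plain dict; the union on the right
# then overrides/extends it with the values chosen centrally in main().
engine_kwargs = asdict(per_model) | {
    "seed": 0,
    "disable_mm_preprocessor_cache": True,
}
assert engine_kwargs["disable_mm_preprocessor_cache"] is True
```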
@@ -63,6 +63,7 @@ def run_e5_v(query: Query) -> ModelRequestData:
         model="royokong/e5-v",
         task="embed",
         max_model_len=4096,
+        limit_mm_per_prompt={"image": 1},
     )
 
     return ModelRequestData(
@@ -93,6 +94,7 @@ def run_vlm2vec(query: Query) -> ModelRequestData:
         task="embed",
         trust_remote_code=True,
         mm_processor_kwargs={"num_crops": 4},
+        limit_mm_per_prompt={"image": 1},
     )
 
     return ModelRequestData(
@@ -131,6 +133,11 @@ def run_encode(model: str, modality: QueryModality, seed: Optional[int]):
     query = get_query(modality)
     req_data = model_example_map[model](query)
 
+    # Disable other modalities to save memory
+    default_limits = {"image": 0, "video": 0, "audio": 0}
+    req_data.engine_args.limit_mm_per_prompt = default_limits | dict(
+        req_data.engine_args.limit_mm_per_prompt or {})
+
     engine_args = asdict(req_data.engine_args) | {"seed": seed}
     llm = LLM(**engine_args)
@@ -687,6 +687,11 @@ def run_chat(model: str, question: str, image_urls: list[str],
              seed: Optional[int]):
     req_data = model_example_map[model](question, image_urls)
 
+    # Disable other modalities to save memory
+    default_limits = {"image": 0, "video": 0, "audio": 0}
+    req_data.engine_args.limit_mm_per_prompt = default_limits | dict(
+        req_data.engine_args.limit_mm_per_prompt or {})
+
     engine_args = asdict(req_data.engine_args) | {"seed": seed}
     llm = LLM(**engine_args)
@@ -12,7 +12,9 @@ from ...utils import RemoteOpenAIServer
 MODEL_NAME = "fixie-ai/ultravox-v0_5-llama-3_2-1b"
 TEST_AUDIO_URLS = [
     AudioAsset("winning_call").url,
+    AudioAsset("mary_had_lamb").url,
 ]
+MAXIMUM_AUDIOS = 2
 
 
 @pytest.fixture(scope="module")
@@ -24,6 +26,8 @@ def server():
         "5",
         "--enforce-eager",
         "--trust-remote-code",
+        "--limit-mm-per-prompt",
+        f"audio={MAXIMUM_AUDIOS}",
     ]
 
     with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
@@ -46,7 +50,7 @@ def base64_encoded_audio() -> dict[str, str]:
 
 @pytest.mark.asyncio
 @pytest.mark.parametrize("model_name", [MODEL_NAME])
-@pytest.mark.parametrize("audio_url", TEST_AUDIO_URLS)
+@pytest.mark.parametrize("audio_url", [TEST_AUDIO_URLS[0]])
 async def test_single_chat_session_audio(client: openai.AsyncOpenAI,
                                          model_name: str, audio_url: str):
     messages = [{
@@ -100,7 +104,7 @@ async def test_single_chat_session_audio(client: openai.AsyncOpenAI,
 
 @pytest.mark.asyncio
 @pytest.mark.parametrize("model_name", [MODEL_NAME])
-@pytest.mark.parametrize("audio_url", TEST_AUDIO_URLS)
+@pytest.mark.parametrize("audio_url", [TEST_AUDIO_URLS[0]])
 async def test_single_chat_session_audio_base64encoded(
         client: openai.AsyncOpenAI, model_name: str, audio_url: str,
         base64_encoded_audio: dict[str, str]):
@@ -158,7 +162,7 @@ async def test_single_chat_session_audio_base64encoded(
 
 @pytest.mark.asyncio
 @pytest.mark.parametrize("model_name", [MODEL_NAME])
-@pytest.mark.parametrize("audio_url", TEST_AUDIO_URLS)
+@pytest.mark.parametrize("audio_url", [TEST_AUDIO_URLS[0]])
 async def test_single_chat_session_input_audio(
         client: openai.AsyncOpenAI, model_name: str, audio_url: str,
         base64_encoded_audio: dict[str, str]):
@@ -330,28 +334,21 @@ async def test_chat_streaming_input_audio(client: openai.AsyncOpenAI,
 
 @pytest.mark.asyncio
 @pytest.mark.parametrize("model_name", [MODEL_NAME])
-@pytest.mark.parametrize("audio_url", TEST_AUDIO_URLS)
+@pytest.mark.parametrize(
+    "audio_urls", [TEST_AUDIO_URLS, TEST_AUDIO_URLS + [TEST_AUDIO_URLS[0]]])
 async def test_multi_audio_input(client: openai.AsyncOpenAI, model_name: str,
-                                 audio_url: str,
-                                 base64_encoded_audio: dict[str, str]):
+                                 audio_urls: list[str]):
 
     messages = [{
         "role":
         "user",
         "content": [
-            {
+            *({
                 "type": "audio_url",
                 "audio_url": {
                     "url": audio_url
                 }
-            },
-            {
-                "type": "input_audio",
-                "input_audio": {
-                    "data": base64_encoded_audio[audio_url],
-                    "format": "wav"
-                }
-            },
+            } for audio_url in audio_urls),
             {
                 "type": "text",
                 "text": "What's happening in this audio?"
@@ -359,6 +356,7 @@ async def test_multi_audio_input(client: openai.AsyncOpenAI, model_name: str,
         ],
     }]
 
+    if len(audio_urls) > MAXIMUM_AUDIOS:
         with pytest.raises(openai.BadRequestError):  # test multi-audio input
             await client.chat.completions.create(
                 model=model_name,
@@ -376,3 +374,12 @@ async def test_multi_audio_input(client: openai.AsyncOpenAI, model_name: str,
             )
         completion = completion.choices[0].text
         assert completion is not None and len(completion) >= 0
+    else:
+        chat_completion = await client.chat.completions.create(
+            model=model_name,
+            messages=messages,
+            max_completion_tokens=10,
+            temperature=0.0,
+        )
+        message = chat_completion.choices[0].message
+        assert message.content is not None and len(message.content) >= 0
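The `*( ... for audio_url in audio_urls)` construct in the rewritten test splices one `audio_url` content part per URL into the message's content list. A minimal illustration of that unpacking, detached from the OpenAI client (the URLs here are placeholders):

```python
audio_urls = ["https://example.com/a.wav", "https://example.com/b.wav"]

content = [
    # Unpack one generated dict per URL directly into the list literal.
    *({
        "type": "audio_url",
        "audio_url": {"url": url},
    } for url in audio_urls),
    {"type": "text", "text": "What's happening in this audio?"},
]

assert len(content) == len(audio_urls) + 1
```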
@@ -51,6 +51,10 @@ def run_test(
     model_info.check_available_online(on_fail="skip")
     model_info.check_transformers_version(on_fail="skip")
 
+    # Disable other modalities to save memory
+    default_limits = {"image": 0, "video": 0, "audio": 0}
+    limit_mm_per_prompt = default_limits | limit_mm_per_prompt
+
     vllm_outputs_per_mm = []
     hf_outputs_per_mm = []
 
@@ -90,6 +90,7 @@ def test_oot_registration_multimodal(
               max_model_len=4096,
               enforce_eager=True,
+              limit_mm_per_prompt={"image": 1})
 
     first_token = llm.get_tokenizer().decode(0)
     outputs = llm.generate(prompts, sampling_params)
 
@@ -972,10 +972,13 @@ def test_limit_mm_per_prompt_dummy(model_id, limit, num_supported, is_valid):
     if is_valid:
         exc_ctx = nullcontext()
     else:
-        exc_ctx = pytest.raises(ValueError, match="this model only supports")
+        exc_ctx = pytest.raises(ValueError, match="The model only supports")
 
     with exc_ctx:
-        profiler.get_decoder_dummy_data(model_config.max_model_len)
+        profiler.get_decoder_dummy_data(
+            model_config.max_model_len,
+            mm_counts=limit_mm_per_prompt,
+        )
 
 
 @pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"])
@@ -2667,14 +2667,20 @@ class MultiModalConfig:
                                 usedforsecurity=False).hexdigest()
         return hash_str
 
+    def get_default_limit_per_prompt(self) -> int:
+        """
+        Return the default number of input items allowed per prompt
+        for any modality if not specified by the user.
+        """
+        return 999 if envs.VLLM_USE_V1 else 1
+
     def get_limit_per_prompt(self, modality: str) -> int:
         """
         Get the maximum number of input items allowed per prompt
         for the given modality.
-
-        If not set by the user, this defaults to `1`.
         """
-        return self.limit_per_prompt.get(modality, 1)
+        default = self.get_default_limit_per_prompt()
+        return self.limit_per_prompt.get(modality, default)
 
     # TODO: Add configs to init vision tower or not.
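The net effect of this config change is that an unspecified modality limit now resolves to 999 on V1 and 1 on V0, while anything the user passes via `limit_mm_per_prompt` still wins. A rough sketch of that resolution logic (simplified and standalone; `use_v1` stands in for the `VLLM_USE_V1` environment check):

```python
def resolve_limit(limit_per_prompt: dict[str, int], modality: str,
                  use_v1: bool) -> int:
    """Simplified model of MultiModalConfig.get_limit_per_prompt."""
    default = 999 if use_v1 else 1  # V1 enables multi-input by default
    return limit_per_prompt.get(modality, default)


# User said nothing: V1 effectively allows many items, V0 allows one.
assert resolve_limit({}, "image", use_v1=True) == 999
assert resolve_limit({}, "image", use_v1=False) == 1

# An explicit user limit always takes precedence.
assert resolve_limit({"image": 4}, "image", use_v1=True) == 4
```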
@@ -671,13 +671,13 @@ class EngineArgs:
             type=nullable_kvs,
             default=EngineArgs.limit_mm_per_prompt,
             # The default value is given in
-            # MultiModalConfig.get_limit_per_prompt
+            # MultiModalConfig.get_default_limit_per_prompt
             help=('For each multimodal plugin, limit how many '
                   'input instances to allow for each prompt. '
                   'Expects a comma-separated list of items, '
                   'e.g.: `image=16,video=2` allows a maximum of 16 '
-                  'images and 2 videos per prompt. Defaults to 1 for '
-                  'each modality.'))
+                  'images and 2 videos per prompt. Defaults to '
+                  '1 (V0) or 999 (V1) for each modality.'))
         parser.add_argument(
             '--mm-processor-kwargs',
             default=None,
@@ -35,7 +35,7 @@ from typing_extensions import Required, TypeAlias, TypedDict
 
 from vllm.config import ModelConfig
 from vllm.logger import init_logger
-from vllm.multimodal import MultiModalDataDict
+from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict
 from vllm.multimodal.utils import MediaConnector
 from vllm.transformers_utils.processor import cached_get_processor
 from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
@@ -452,8 +452,6 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
 
         self._model_config = model_config
         self._tokenizer = tokenizer
-        self._allowed_items = (model_config.multimodal_config.limit_per_prompt
-                               if model_config.multimodal_config else {})
 
         self._items_by_modality = defaultdict[str, list[_T]](list)
 
@@ -465,6 +463,10 @@
     def allowed_local_media_path(self):
         return self._model_config.allowed_local_media_path
 
+    @property
+    def mm_registry(self):
+        return MULTIMODAL_REGISTRY
+
     @staticmethod
     @cache
     def _cached_token_str(tokenizer: AnyTokenizer, token_index: int) -> str:
@@ -540,12 +542,29 @@
         Add a multi-modal item to the current prompt and returns the
         placeholder string to use, if any.
         """
-        allowed_count = self._allowed_items.get(modality, 1)
+        mm_registry = self.mm_registry
+        model_config = self.model_config
+
+        input_modality = modality.replace("_embeds", "")
+
+        if mm_registry.has_processor(model_config):
+            mm_processor = mm_registry.create_processor(model_config)
+            allowed_counts = mm_processor.info.get_allowed_mm_limits()
+            allowed_count = allowed_counts.get(input_modality, 0)
+        else:
+            mm_config = model_config.multimodal_config
+            if mm_config is None:
+                msg = "This model does not support multi-modal inputs"
+                raise ValueError(msg)
+
+            allowed_count = mm_config.get_limit_per_prompt(input_modality)
+
         current_count = len(self._items_by_modality[modality]) + 1
         if current_count > allowed_count:
             raise ValueError(
                 f"At most {allowed_count} {modality}(s) may be provided in "
-                "one request.")
+                "one request. You can set `--limit-mm-per-prompt` to "
+                "increase this limit if the model supports it.")
 
         self._items_by_modality[modality].append(item)
@@ -126,7 +126,7 @@ class MiniCPMOMultiModalDataParser(MiniCPMVMultiModalDataParser):
     def _parse_audio_data(
         self,
         data: Union[dict[str, torch.Tensor], ModalityData[AudioItem]],
-    ) -> ModalityDataItems[Any, Any]:
+    ) -> Optional[ModalityDataItems[Any, Any]]:
         if isinstance(data, dict):
             return MiniCPMOAudioEmbeddingItems(
                 data,
@@ -290,7 +290,7 @@ class MiniCPMVMultiModalDataParser(MultiModalDataParser):
     def _parse_image_data(
         self,
         data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]],
-    ) -> ModalityDataItems[Any, Any]:
+    ) -> Optional[ModalityDataItems[Any, Any]]:
         if isinstance(data, dict):
             return MiniCPMVImageEmbeddingItems(
                 data,
@@ -302,7 +302,7 @@ class MiniCPMVMultiModalDataParser(MultiModalDataParser):
     def _parse_video_data(
         self,
         data: Union[dict[str, torch.Tensor], ModalityData[VideoItem]],
-    ) -> ModalityDataItems[Any, Any]:
+    ) -> Optional[ModalityDataItems[Any, Any]]:
         if isinstance(data, dict):
             return MiniCPMVVideoEmbeddingItems(
                 data,
@@ -720,7 +720,7 @@ class Qwen2VLMultiModalDataParser(MultiModalDataParser):
     def _parse_image_data(
         self,
         data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]],
-    ) -> ModalityDataItems[Any, Any]:
+    ) -> Optional[ModalityDataItems[Any, Any]]:
         if isinstance(data, dict):
             return DictEmbeddingItems(
                 data,
@@ -734,7 +734,7 @@ class Qwen2VLMultiModalDataParser(MultiModalDataParser):
     def _parse_video_data(
         self,
         data: Union[dict[str, torch.Tensor], ModalityData[VideoItem]],
-    ) -> ModalityDataItems[Any, Any]:
+    ) -> Optional[ModalityDataItems[Any, Any]]:
         if isinstance(data, dict):
             return DictEmbeddingItems(
                 data,
@@ -1034,6 +1034,20 @@ class BaseProcessingInfo:
         """
         raise NotImplementedError
 
+    def get_allowed_mm_limits(self) -> Mapping[str, int]:
+        """Return the maximum allowed number of items for each modality."""
+        supported_mm_limits = self.get_supported_mm_limits()
+        mm_config = self.ctx.get_mm_config()
+
+        allowed_limits = dict[str, int]()
+        for modality, supported_limit in supported_mm_limits.items():
+            user_limit = mm_config.get_limit_per_prompt(modality)
+
+            allowed_limits[modality] = (user_limit if supported_limit is None
+                                        else min(user_limit, supported_limit))
+
+        return allowed_limits
+
 
 _I = TypeVar("_I", bound=BaseProcessingInfo)
 
@@ -1087,14 +1101,24 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
         before passing them to :meth:`_get_hf_mm_data`.
         """
         mm_items = self.data_parser.parse_mm_data(mm_data)
-        mm_config = self.info.ctx.get_mm_config()
+        supported_mm_limits = self.info.get_supported_mm_limits()
+        allowed_mm_limits = self.info.get_allowed_mm_limits()
 
         for modality, items in mm_items.items():
-            limit = mm_config.get_limit_per_prompt(modality)
-            if len(items) > limit:
+            supported_limit = supported_mm_limits.get(modality, 0)
+            allowed_limit = allowed_mm_limits.get(modality, 0)
+            num_items = len(items)
+
+            if supported_limit is not None and num_items > supported_limit:
                 raise ValueError(
-                    f"You set {modality}={limit} (or defaulted to 1) in "
-                    f"`--limit-mm-per-prompt`, but passed {len(items)} "
+                    f"The model only supports at most {supported_limit} "
+                    f"{modality} items, but you passed {num_items} "
                     f"{modality} items in the same prompt.")
 
+            if num_items > allowed_limit:
+                raise ValueError(
+                    f"You set or defaulted to {modality}={allowed_limit} "
+                    f"in --limit-mm-per-prompt`, but passed {num_items} "
+                    f"{modality} items in the same prompt.")
+
         return mm_items
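The new `get_allowed_mm_limits()` caps the user-configured limit at whatever the model itself supports (a `None` supported limit means the user limit is taken as-is). A condensed sketch of that computation, independent of the vLLM classes:

```python
from typing import Mapping, Optional


def allowed_mm_limits(
    supported: Mapping[str, Optional[int]],
    user_limit: Mapping[str, int],
    default_limit: int,
) -> dict[str, int]:
    """Simplified model of BaseProcessingInfo.get_allowed_mm_limits."""
    allowed: dict[str, int] = {}
    for modality, supported_limit in supported.items():
        user = user_limit.get(modality, default_limit)
        # None means the model places no cap of its own.
        allowed[modality] = user if supported_limit is None else min(
            user, supported_limit)
    return allowed


# A model that supports any number of images but at most one video,
# running on V1 (default limit 999) with no user overrides:
assert allowed_mm_limits({"image": None, "video": 1}, {}, 999) == {
    "image": 999,
    "video": 1,
}
```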
@@ -162,23 +162,7 @@ class MultiModalProfiler(Generic[_I]):
         return self.processor.dummy_inputs
 
     def get_mm_limits(self) -> Mapping[str, int]:
-        mm_config = self.processing_info.ctx.get_mm_config()
-        supported_mm_limits = self.processing_info.get_supported_mm_limits()
-
-        mm_limits = {
-            modality: mm_config.get_limit_per_prompt(modality)
-            for modality in supported_mm_limits
-        }
-
-        for modality, supported_limit in supported_mm_limits.items():
-            limit = mm_limits[modality]
-            if supported_limit is not None and supported_limit < limit:
-                raise ValueError(
-                    f"You set {modality}={limit} (or defaulted to 1) in "
-                    f"`--limit-mm-per-prompt`, but this model only supports "
-                    f"at most {supported_limit} {modality} items.")
-
-        return mm_limits
+        return self.processing_info.get_allowed_mm_limits()
 
     def _get_dummy_mm_inputs(
         self,
@@ -265,8 +265,10 @@ class MultiModalRegistry:
 
         return profiler.get_mm_max_tokens(
             seq_len,
-            {modality: 1
-             for modality in mm_limits},
+            {
+                modality: 1
+                for modality, limit in mm_limits.items() if limit > 0
+            },
         )
 
         return {
@@ -264,7 +264,7 @@ fetch_video = global_media_connector.fetch_video
 
 def encode_audio_base64(
     audio: np.ndarray,
-    sampling_rate: int,
+    sampling_rate: float,
 ) -> str:
     """Encode audio as base64."""
     audio_io = AudioMediaIO()