[Misc] Add --seed option to offline multi-modal examples (#14934)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Cyrus Leung 2025-03-17 18:00:17 +08:00 committed by GitHub
parent 868a8c5b2c
commit 6eaf1e5c52
6 changed files with 537 additions and 315 deletions
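
Every example in this commit is refactored along the same lines: the per-model `run_*`/`load_*` helpers now return an `EngineArgs` instance (plus prompts, stop token IDs, and optional LoRA requests) instead of a fully constructed `LLM`, and `main` merges the `--seed` value into those arguments before building the engine. A minimal sketch of that pattern, using `facebook/opt-125m` as a stand-in model with illustrative settings:

from dataclasses import asdict
from typing import NamedTuple, Optional

from vllm import LLM, EngineArgs, SamplingParams
from vllm.lora.request import LoRARequest


class ModelRequestData(NamedTuple):
    engine_args: EngineArgs
    prompt: str
    stop_token_ids: Optional[list[int]] = None
    lora_requests: Optional[list[LoRARequest]] = None


def run_example(question: str) -> ModelRequestData:
    # Helpers only describe the engine; they no longer construct it.
    engine_args = EngineArgs(model="facebook/opt-125m", max_model_len=2048)
    return ModelRequestData(engine_args=engine_args, prompt=question)


def main(seed: Optional[int] = None):
    req_data = run_example("What is the capital of France?")
    # EngineArgs is a dataclass, so asdict() turns it into keyword arguments
    # for vllm.LLM; the CLI seed is merged in before construction.
    engine_args = asdict(req_data.engine_args) | {"seed": seed}
    llm = LLM(**engine_args)
    # LoRA adapters, if any, are attached after the engine is built.
    if req_data.lora_requests:
        for lora_request in req_data.lora_requests:
            llm.llm_engine.add_lora(lora_request=lora_request)
    sampling_params = SamplingParams(temperature=0.2,
                                     max_tokens=64,
                                     stop_token_ids=req_data.stop_token_ids)
    outputs = llm.generate(req_data.prompt, sampling_params)
    print(outputs[0].outputs[0].text)


if __name__ == "__main__":
    main(seed=0)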

View File

@ -226,10 +226,13 @@ steps:
- python3 offline_inference/basic/chat.py
- python3 offline_inference/prefix_caching.py
- python3 offline_inference/llm_engine_example.py
- python3 offline_inference/vision_language.py
- python3 offline_inference/vision_language_multi_image.py
- python3 offline_inference/audio_language.py --seed 0
- python3 offline_inference/vision_language.py --seed 0
- python3 offline_inference/vision_language_embedding.py --seed 0
- python3 offline_inference/vision_language_multi_image.py --seed 0
- VLLM_USE_V1=0 python3 other/tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 other/tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors
- python3 offline_inference/encoder_decoder.py
- python3 offline_inference/encoder_decoder_multimodal.py --model-type whisper --seed 0
- python3 offline_inference/basic/classify.py
- python3 offline_inference/basic/embed.py
- python3 offline_inference/basic/score.py

View File

@ -7,11 +7,13 @@ For most models, the prompt format should follow corresponding examples
on HuggingFace model repository.
"""
import os
from dataclasses import asdict
from typing import NamedTuple, Optional
from huggingface_hub import snapshot_download
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
from vllm import LLM, EngineArgs, SamplingParams
from vllm.assets.audio import AudioAsset
from vllm.lora.request import LoRARequest
from vllm.utils import FlexibleArgumentParser
@ -23,21 +25,31 @@ question_per_audio_count = {
2: "What sport and what nursery rhyme are referenced?"
}
class ModelRequestData(NamedTuple):
engine_args: EngineArgs
prompt: str
stop_token_ids: Optional[list[int]] = None
lora_requests: Optional[list[LoRARequest]] = None
# NOTE: The default `max_num_seqs` and `max_model_len` may result in OOM on
# lower-end GPUs.
# Unless specified, these settings have been tested to work on a single L4.
# MiniCPM-O
def run_minicpmo(question: str, audio_count: int):
def run_minicpmo(question: str, audio_count: int) -> ModelRequestData:
model_name = "openbmb/MiniCPM-o-2_6"
tokenizer = AutoTokenizer.from_pretrained(model_name,
trust_remote_code=True)
llm = LLM(model=model_name,
engine_args = EngineArgs(
model=model_name,
trust_remote_code=True,
max_model_len=4096,
max_num_seqs=5,
limit_mm_per_prompt={"audio": audio_count})
limit_mm_per_prompt={"audio": audio_count},
)
stop_tokens = ['<|im_end|>', '<|endoftext|>']
stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]
@ -52,11 +64,16 @@ def run_minicpmo(question: str, audio_count: int):
tokenize=False,
add_generation_prompt=True,
chat_template=audio_chat_template)
return llm, prompt, stop_token_ids
return ModelRequestData(
engine_args=engine_args,
prompt=prompt,
stop_token_ids=stop_token_ids,
)
# Phi-4-multimodal-instruct
def run_phi4mm(questions: str, audio_count: int):
def run_phi4mm(question: str, audio_count: int) -> ModelRequestData:
"""
Phi-4-multimodal-instruct supports both image and audio inputs. Here, we
show how to process audio inputs.
@ -67,9 +84,9 @@ def run_phi4mm(questions: str, audio_count: int):
speech_lora_path = os.path.join(model_path, "speech-lora")
placeholders = "".join([f"<|audio_{i+1}|>" for i in range(audio_count)])
prompts = f"<|user|>{placeholders}{questions}<|end|><|assistant|>"
prompts = f"<|user|>{placeholders}{question}<|end|><|assistant|>"
llm = LLM(
engine_args = EngineArgs(
model=model_path,
trust_remote_code=True,
max_model_len=4096,
@ -79,24 +96,24 @@ def run_phi4mm(questions: str, audio_count: int):
lora_extra_vocab_size=0,
limit_mm_per_prompt={"audio": audio_count},
)
lora_request = LoRARequest("speech", 1, speech_lora_path)
# To maintain code compatibility in this script, we add LoRA here.
llm.llm_engine.add_lora(lora_request=lora_request)
# You can also add LoRA using:
# llm.generate(prompts, lora_request=lora_request,...)
stop_token_ids = None
return llm, prompts, stop_token_ids
return ModelRequestData(
engine_args=engine_args,
prompt=prompts,
lora_requests=[LoRARequest("speech", 1, speech_lora_path)],
)
# Qwen2-Audio
def run_qwen2_audio(question: str, audio_count: int):
def run_qwen2_audio(question: str, audio_count: int) -> ModelRequestData:
model_name = "Qwen/Qwen2-Audio-7B-Instruct"
llm = LLM(model=model_name,
engine_args = EngineArgs(
model=model_name,
max_model_len=4096,
max_num_seqs=5,
limit_mm_per_prompt={"audio": audio_count})
limit_mm_per_prompt={"audio": audio_count},
)
audio_in_prompt = "".join([
f"Audio {idx+1}: "
@ -107,12 +124,15 @@ def run_qwen2_audio(question: str, audio_count: int):
"<|im_start|>user\n"
f"{audio_in_prompt}{question}<|im_end|>\n"
"<|im_start|>assistant\n")
stop_token_ids = None
return llm, prompt, stop_token_ids
return ModelRequestData(
engine_args=engine_args,
prompt=prompt,
)
# Ultravox 0.5-1B
def run_ultravox(question: str, audio_count: int):
def run_ultravox(question: str, audio_count: int) -> ModelRequestData:
model_name = "fixie-ai/ultravox-v0_5-llama-3_2-1b"
tokenizer = AutoTokenizer.from_pretrained(model_name)
@ -124,29 +144,39 @@ def run_ultravox(question: str, audio_count: int):
tokenize=False,
add_generation_prompt=True)
llm = LLM(model=model_name,
engine_args = EngineArgs(
model=model_name,
max_model_len=4096,
max_num_seqs=5,
trust_remote_code=True,
limit_mm_per_prompt={"audio": audio_count})
stop_token_ids = None
return llm, prompt, stop_token_ids
limit_mm_per_prompt={"audio": audio_count},
)
return ModelRequestData(
engine_args=engine_args,
prompt=prompt,
)
# Whisper
def run_whisper(question: str, audio_count: int):
def run_whisper(question: str, audio_count: int) -> ModelRequestData:
assert audio_count == 1, (
"Whisper only support single audio input per prompt")
model_name = "openai/whisper-large-v3-turbo"
prompt = "<|startoftranscript|>"
llm = LLM(model=model_name,
engine_args = EngineArgs(
model=model_name,
max_model_len=448,
max_num_seqs=5,
limit_mm_per_prompt={"audio": audio_count})
stop_token_ids = None
return llm, prompt, stop_token_ids
limit_mm_per_prompt={"audio": audio_count},
)
return ModelRequestData(
engine_args=engine_args,
prompt=prompt,
)
model_example_map = {
@ -164,14 +194,24 @@ def main(args):
raise ValueError(f"Model type {model} is not supported.")
audio_count = args.num_audios
llm, prompt, stop_token_ids = model_example_map[model](
question_per_audio_count[audio_count], audio_count)
req_data = model_example_map[model](question_per_audio_count[audio_count],
audio_count)
engine_args = asdict(req_data.engine_args) | {"seed": args.seed}
llm = LLM(**engine_args)
# To maintain code compatibility in this script, we add LoRA here.
# You can also add LoRA using:
# llm.generate(prompts, lora_request=lora_request,...)
if req_data.lora_requests:
for lora_request in req_data.lora_requests:
llm.llm_engine.add_lora(lora_request=lora_request)
# We set temperature to 0.2 so that outputs can be different
# even when all prompts are identical when running batch inference.
sampling_params = SamplingParams(temperature=0.2,
max_tokens=64,
stop_token_ids=stop_token_ids)
stop_token_ids=req_data.stop_token_ids)
mm_data = {}
if audio_count > 0:
@ -183,7 +223,7 @@ def main(args):
}
assert args.num_prompts > 0
inputs = {"prompt": prompt, "multi_modal_data": mm_data}
inputs = {"prompt": req_data.prompt, "multi_modal_data": mm_data}
if args.num_prompts > 1:
# Batch inference
inputs = [inputs] * args.num_prompts
@ -214,6 +254,10 @@ if __name__ == "__main__":
default=1,
choices=[0, 1, 2],
help="Number of audio items per prompt.")
parser.add_argument("--seed",
type=int,
default=None,
help="Set the seed when initializing `vllm.LLM`.")
args = parser.parse_args()
main(args)
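
As the comment in `main` notes, a LoRA adapter can also be supplied per request instead of being registered on the engine. A short sketch of that alternative, reusing `llm`, `mm_data`, and `sampling_params` from `main` and assuming `req_data` comes from `run_phi4mm` (the only helper here that returns a LoRA request):

# Alternative to llm.llm_engine.add_lora(...): attach the adapter per request.
outputs = llm.generate(
    {"prompt": req_data.prompt, "multi_modal_data": mm_data},
    sampling_params=sampling_params,
    lora_request=req_data.lora_requests[0],
)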

View File

@ -4,16 +4,23 @@ This example shows how to use vLLM for running offline inference with
the explicit/implicit prompt format on enc-dec LMMs for text generation.
"""
import time
from collections.abc import Sequence
from dataclasses import asdict
from typing import NamedTuple
from vllm import LLM, SamplingParams
from vllm import LLM, EngineArgs, PromptType, SamplingParams
from vllm.assets.audio import AudioAsset
from vllm.assets.image import ImageAsset
from vllm.utils import FlexibleArgumentParser
class ModelRequestData(NamedTuple):
engine_args: EngineArgs
prompts: Sequence[PromptType]
def run_florence2():
# Create a Florence-2 encoder/decoder model instance
llm = LLM(
engine_args = EngineArgs(
model="microsoft/Florence-2-large",
tokenizer="facebook/bart-large",
max_num_seqs=8,
@ -39,12 +46,15 @@ def run_florence2():
"decoder_prompt": "",
},
]
return llm, prompts
return ModelRequestData(
engine_args=engine_args,
prompts=prompts,
)
def run_mllama():
# Create a Mllama encoder/decoder model instance
llm = LLM(
engine_args = EngineArgs(
model="meta-llama/Llama-3.2-11B-Vision-Instruct",
max_model_len=4096,
max_num_seqs=2,
@ -69,12 +79,15 @@ def run_mllama():
"decoder_prompt": "<|image|><|begin_of_text|>Please describe the image.", # noqa: E501
},
]
return llm, prompts
return ModelRequestData(
engine_args=engine_args,
prompts=prompts,
)
def run_whisper():
# Create a Whisper encoder/decoder model instance
llm = LLM(
engine_args = EngineArgs(
model="openai/whisper-large-v3-turbo",
max_model_len=448,
max_num_seqs=16,
@ -99,7 +112,11 @@ def run_whisper():
"decoder_prompt": "<|startoftranscript|>",
}
]
return llm, prompts
return ModelRequestData(
engine_args=engine_args,
prompts=prompts,
)
model_example_map = {
@ -114,7 +131,12 @@ def main(args):
if model not in model_example_map:
raise ValueError(f"Model type {model} is not supported.")
llm, prompts = model_example_map[model]()
req_data = model_example_map[model]()
engine_args = asdict(req_data.engine_args) | {"seed": args.seed}
llm = LLM(**engine_args)
prompts = req_data.prompts
# Create a sampling params object.
sampling_params = SamplingParams(
@ -153,6 +175,10 @@ if __name__ == "__main__":
default="mllama",
choices=model_example_map.keys(),
help='Huggingface "model_type".')
parser.add_argument("--seed",
type=int,
default=None,
help="Set the seed when initializing `vllm.LLM`.")
args = parser.parse_args()
main(args)
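
For reference, a condensed sketch of how the Whisper example fits together once the seed is merged in; the `limit_mm_per_prompt` setting and the audio asset name are assumptions taken from the neighbouring audio example rather than from this file:

from dataclasses import asdict

from vllm import LLM, EngineArgs, SamplingParams
from vllm.assets.audio import AudioAsset

engine_args = EngineArgs(
    model="openai/whisper-large-v3-turbo",
    max_model_len=448,
    max_num_seqs=16,
    limit_mm_per_prompt={"audio": 1},  # assumed, not shown in the hunk
)
llm = LLM(**(asdict(engine_args) | {"seed": 0}))

# Explicit encoder/decoder prompt: the audio clip goes to the encoder, while
# the decoder is primed with the transcription start token.
prompt = {
    "encoder_prompt": {
        "prompt": "",
        "multi_modal_data": {
            "audio": AudioAsset("winning_call").audio_and_sample_rate,
        },
    },
    "decoder_prompt": "<|startoftranscript|>",
}
outputs = llm.generate(prompt, SamplingParams(temperature=0, max_tokens=64))
print(outputs[0].outputs[0].text)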

View File

@ -8,122 +8,164 @@ on HuggingFace model repository.
"""
import os
import random
from dataclasses import asdict
from typing import NamedTuple, Optional
from huggingface_hub import snapshot_download
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
from vllm import LLM, EngineArgs, SamplingParams
from vllm.assets.image import ImageAsset
from vllm.assets.video import VideoAsset
from vllm.lora.request import LoRARequest
from vllm.utils import FlexibleArgumentParser
class ModelRequestData(NamedTuple):
engine_args: EngineArgs
prompts: list[str]
stop_token_ids: Optional[list[int]] = None
lora_requests: Optional[list[LoRARequest]] = None
# NOTE: The default `max_num_seqs` and `max_model_len` may result in OOM on
# lower-end GPUs.
# Unless specified, these settings have been tested to work on a single L4.
# Aria
def run_aria(questions: list[str], modality: str):
def run_aria(questions: list[str], modality: str) -> ModelRequestData:
assert modality == "image"
model_name = "rhymes-ai/Aria"
# NOTE: Need L40 (or equivalent) to avoid OOM
llm = LLM(model=model_name,
engine_args = EngineArgs(
model=model_name,
max_model_len=4096,
max_num_seqs=2,
dtype="bfloat16",
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
)
prompts = [(f"<|im_start|>user\n<fim_prefix><|img|><fim_suffix>{question}"
"<|im_end|>\n<|im_start|>assistant\n")
for question in questions]
stop_token_ids = [93532, 93653, 944, 93421, 1019, 93653, 93519]
return llm, prompts, stop_token_ids
return ModelRequestData(
engine_args=engine_args,
prompts=prompts,
stop_token_ids=stop_token_ids,
)
# BLIP-2
def run_blip2(questions: list[str], modality: str):
def run_blip2(questions: list[str], modality: str) -> ModelRequestData:
assert modality == "image"
# BLIP-2 prompt format is inaccurate on HuggingFace model repository.
# See https://huggingface.co/Salesforce/blip2-opt-2.7b/discussions/15#64ff02f3f8cf9e4f5b038262 #noqa
prompts = [f"Question: {question} Answer:" for question in questions]
llm = LLM(model="Salesforce/blip2-opt-2.7b",
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
stop_token_ids = None
return llm, prompts, stop_token_ids
engine_args = EngineArgs(
model="Salesforce/blip2-opt-2.7b",
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
)
return ModelRequestData(
engine_args=engine_args,
prompts=prompts,
)
# Chameleon
def run_chameleon(questions: list[str], modality: str):
def run_chameleon(questions: list[str], modality: str) -> ModelRequestData:
assert modality == "image"
prompts = [f"{question}<image>" for question in questions]
llm = LLM(model="facebook/chameleon-7b",
engine_args = EngineArgs(
model="facebook/chameleon-7b",
max_model_len=4096,
max_num_seqs=2,
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
stop_token_ids = None
return llm, prompts, stop_token_ids
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
)
return ModelRequestData(
engine_args=engine_args,
prompts=prompts,
)
# Deepseek-VL2
def run_deepseek_vl2(questions: list[str], modality: str):
def run_deepseek_vl2(questions: list[str], modality: str) -> ModelRequestData:
assert modality == "image"
model_name = "deepseek-ai/deepseek-vl2-tiny"
llm = LLM(model=model_name,
engine_args = EngineArgs(
model=model_name,
max_model_len=4096,
max_num_seqs=2,
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
hf_overrides={"architectures": ["DeepseekVLV2ForCausalLM"]})
hf_overrides={"architectures": ["DeepseekVLV2ForCausalLM"]},
)
prompts = [
f"<|User|>: <image>\n{question}\n\n<|Assistant|>:"
for question in questions
]
stop_token_ids = None
return llm, prompts, stop_token_ids
return ModelRequestData(
engine_args=engine_args,
prompts=prompts,
)
# Florence2
def run_florence2(question: str, modality: str):
def run_florence2(questions: list[str], modality: str) -> ModelRequestData:
assert modality == "image"
llm = LLM(model="microsoft/Florence-2-large",
engine_args = EngineArgs(
model="microsoft/Florence-2-large",
tokenizer="facebook/bart-large",
max_num_seqs=8,
trust_remote_code=True,
dtype="bfloat16",
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
)
prompt = "<MORE_DETAILED_CAPTION>"
stop_token_ids = None
return llm, prompt, stop_token_ids
prompts = ["<MORE_DETAILED_CAPTION>" for _ in questions]
return ModelRequestData(
engine_args=engine_args,
prompts=prompts,
)
# Fuyu
def run_fuyu(questions: list[str], modality: str):
def run_fuyu(questions: list[str], modality: str) -> ModelRequestData:
assert modality == "image"
prompts = [f"{question}\n" for question in questions]
llm = LLM(model="adept/fuyu-8b",
engine_args = EngineArgs(
model="adept/fuyu-8b",
max_model_len=2048,
max_num_seqs=2,
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
stop_token_ids = None
return llm, prompts, stop_token_ids
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
)
return ModelRequestData(
engine_args=engine_args,
prompts=prompts,
)
# Gemma 3
def run_gemma3(questions: list[str], modality: str):
def run_gemma3(questions: list[str], modality: str) -> ModelRequestData:
assert modality == "image"
model_name = "google/gemma-3-4b-it"
llm = LLM(
engine_args = EngineArgs(
model=model_name,
max_model_len=2048,
max_num_seqs=2,
@ -135,22 +177,27 @@ def run_gemma3(questions: list[str], modality: str):
prompts = [("<bos><start_of_turn>user\n"
f"<start_of_image>{question}<end_of_turn>\n"
"<start_of_turn>model\n") for question in questions]
stop_token_ids = None
return llm, prompts, stop_token_ids
return ModelRequestData(
engine_args=engine_args,
prompts=prompts,
)
# GLM-4v
def run_glm4v(questions: list[str], modality: str):
def run_glm4v(questions: list[str], modality: str) -> ModelRequestData:
assert modality == "image"
model_name = "THUDM/glm-4v-9b"
llm = LLM(model=model_name,
engine_args = EngineArgs(
model=model_name,
max_model_len=2048,
max_num_seqs=2,
trust_remote_code=True,
enforce_eager=True,
hf_overrides={"architectures": ["GLM4VForCausalLM"]},
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
)
prompts = [
f"<|user|>\n<|begin_of_image|><|endoftext|><|end_of_image|>\
@ -158,16 +205,21 @@ def run_glm4v(questions: list[str], modality: str):
]
stop_token_ids = [151329, 151336, 151338]
return llm, prompts, stop_token_ids
return ModelRequestData(
engine_args=engine_args,
prompts=prompts,
stop_token_ids=stop_token_ids,
)
# H2OVL-Mississippi
def run_h2ovl(questions: list[str], modality: str):
def run_h2ovl(questions: list[str], modality: str) -> ModelRequestData:
assert modality == "image"
model_name = "h2oai/h2ovl-mississippi-800m"
llm = LLM(
engine_args = EngineArgs(
model=model_name,
trust_remote_code=True,
max_model_len=8192,
@ -187,15 +239,20 @@ def run_h2ovl(questions: list[str], modality: str):
# Stop tokens for H2OVL-Mississippi
# https://huggingface.co/h2oai/h2ovl-mississippi-800m
stop_token_ids = [tokenizer.eos_token_id]
return llm, prompts, stop_token_ids
return ModelRequestData(
engine_args=engine_args,
prompts=prompts,
stop_token_ids=stop_token_ids,
)
# Idefics3-8B-Llama3
def run_idefics3(questions: list[str], modality: str):
def run_idefics3(questions: list[str], modality: str) -> ModelRequestData:
assert modality == "image"
model_name = "HuggingFaceM4/Idefics3-8B-Llama3"
llm = LLM(
engine_args = EngineArgs(
model=model_name,
max_model_len=8192,
max_num_seqs=2,
@ -212,17 +269,20 @@ def run_idefics3(questions: list[str], modality: str):
prompts = [(
f"<|begin_of_text|>User:<image>{question}<end_of_utterance>\nAssistant:"
) for question in questions]
stop_token_ids = None
return llm, prompts, stop_token_ids
return ModelRequestData(
engine_args=engine_args,
prompts=prompts,
)
# InternVL
def run_internvl(questions: list[str], modality: str):
def run_internvl(questions: list[str], modality: str) -> ModelRequestData:
assert modality == "image"
model_name = "OpenGVLab/InternVL2-2B"
llm = LLM(
engine_args = EngineArgs(
model=model_name,
trust_remote_code=True,
max_model_len=4096,
@ -245,53 +305,75 @@ def run_internvl(questions: list[str], modality: str):
# https://huggingface.co/OpenGVLab/InternVL2-2B/blob/main/conversation.py
stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|end|>"]
stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]
return llm, prompts, stop_token_ids
return ModelRequestData(
engine_args=engine_args,
prompts=prompts,
stop_token_ids=stop_token_ids,
)
# LLaVA-1.5
def run_llava(questions: list[str], modality: str):
def run_llava(questions: list[str], modality: str) -> ModelRequestData:
assert modality == "image"
prompts = [
f"USER: <image>\n{question}\nASSISTANT:" for question in questions
]
llm = LLM(model="llava-hf/llava-1.5-7b-hf",
engine_args = EngineArgs(
model="llava-hf/llava-1.5-7b-hf",
max_model_len=4096,
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
stop_token_ids = None
return llm, prompts, stop_token_ids
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
)
return ModelRequestData(
engine_args=engine_args,
prompts=prompts,
)
# LLaVA-1.6/LLaVA-NeXT
def run_llava_next(questions: list[str], modality: str):
def run_llava_next(questions: list[str], modality: str) -> ModelRequestData:
assert modality == "image"
prompts = [f"[INST] <image>\n{question} [/INST]" for question in questions]
llm = LLM(model="llava-hf/llava-v1.6-mistral-7b-hf",
engine_args = EngineArgs(
model="llava-hf/llava-v1.6-mistral-7b-hf",
max_model_len=8192,
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
stop_token_ids = None
return llm, prompts, stop_token_ids
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
)
return ModelRequestData(
engine_args=engine_args,
prompts=prompts,
)
# LLaVA-NeXT-Video
# Currently only supports video input
def run_llava_next_video(questions: list[str], modality: str):
def run_llava_next_video(questions: list[str],
modality: str) -> ModelRequestData:
assert modality == "video"
prompts = [
f"USER: <video>\n{question} ASSISTANT:" for question in questions
]
llm = LLM(model="llava-hf/LLaVA-NeXT-Video-7B-hf",
engine_args = EngineArgs(
model="llava-hf/LLaVA-NeXT-Video-7B-hf",
max_model_len=8192,
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
stop_token_ids = None
return llm, prompts, stop_token_ids
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
)
return ModelRequestData(
engine_args=engine_args,
prompts=prompts,
)
# LLaVA-OneVision
def run_llava_onevision(questions: list[str], modality: str):
def run_llava_onevision(questions: list[str],
modality: str) -> ModelRequestData:
if modality == "video":
prompts = [
@ -305,15 +387,20 @@ def run_llava_onevision(questions: list[str], modality: str):
<|im_start|>assistant\n" for question in questions
]
llm = LLM(model="llava-hf/llava-onevision-qwen2-7b-ov-hf",
engine_args = EngineArgs(
model="llava-hf/llava-onevision-qwen2-7b-ov-hf",
max_model_len=16384,
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
stop_token_ids = None
return llm, prompts, stop_token_ids
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
)
return ModelRequestData(
engine_args=engine_args,
prompts=prompts,
)
# Mantis
def run_mantis(questions: list[str], modality: str):
def run_mantis(questions: list[str], modality: str) -> ModelRequestData:
assert modality == "image"
llama3_template = '<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n' # noqa: E501
@ -322,14 +409,19 @@ def run_mantis(questions: list[str], modality: str):
for question in questions
]
llm = LLM(
engine_args = EngineArgs(
model="TIGER-Lab/Mantis-8B-siglip-llama3",
max_model_len=4096,
hf_overrides={"architectures": ["MantisForConditionalGeneration"]},
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
)
stop_token_ids = [128009]
return llm, prompts, stop_token_ids
return ModelRequestData(
engine_args=engine_args,
prompts=prompts,
stop_token_ids=stop_token_ids,
)
# MiniCPM-V
@ -357,7 +449,7 @@ def run_minicpmv_base(questions: list[str], modality: str, model_name):
# model_name = "openbmb/MiniCPM-o-2_6"
tokenizer = AutoTokenizer.from_pretrained(model_name,
trust_remote_code=True)
llm = LLM(
engine_args = EngineArgs(
model=model_name,
max_model_len=4096,
max_num_seqs=2,
@ -389,19 +481,24 @@ def run_minicpmv_base(questions: list[str], modality: str, model_name):
tokenize=False,
add_generation_prompt=True) for question in questions
]
return llm, prompts, stop_token_ids
return ModelRequestData(
engine_args=engine_args,
prompts=prompts,
stop_token_ids=stop_token_ids,
)
def run_minicpmo(questions: list[str], modality: str):
def run_minicpmo(questions: list[str], modality: str) -> ModelRequestData:
return run_minicpmv_base(questions, modality, "openbmb/MiniCPM-o-2_6")
def run_minicpmv(questions: list[str], modality: str):
def run_minicpmv(questions: list[str], modality: str) -> ModelRequestData:
return run_minicpmv_base(questions, modality, "openbmb/MiniCPM-V-2_6")
# Llama 3.2
def run_mllama(questions: list[str], modality: str):
def run_mllama(questions: list[str], modality: str) -> ModelRequestData:
assert modality == "image"
model_name = "meta-llama/Llama-3.2-11B-Vision-Instruct"
@ -411,7 +508,7 @@ def run_mllama(questions: list[str], modality: str):
# You may lower either to run this example on lower-end GPUs.
# The configuration below has been confirmed to launch on a single L40 GPU.
llm = LLM(
engine_args = EngineArgs(
model=model_name,
max_model_len=4096,
max_num_seqs=16,
@ -432,17 +529,20 @@ def run_mllama(questions: list[str], modality: str):
prompts = tokenizer.apply_chat_template(messages,
add_generation_prompt=True,
tokenize=False)
stop_token_ids = None
return llm, prompts, stop_token_ids
return ModelRequestData(
engine_args=engine_args,
prompts=prompts,
)
# Molmo
def run_molmo(questions: list[str], modality: str):
def run_molmo(questions: list[str], modality: str) -> ModelRequestData:
assert modality == "image"
model_name = "allenai/Molmo-7B-D-0924"
llm = LLM(
engine_args = EngineArgs(
model=model_name,
trust_remote_code=True,
dtype="bfloat16",
@ -453,18 +553,21 @@ def run_molmo(questions: list[str], modality: str):
f"<|im_start|>user <image>\n{question}<|im_end|> \
<|im_start|>assistant\n" for question in questions
]
stop_token_ids = None
return llm, prompts, stop_token_ids
return ModelRequestData(
engine_args=engine_args,
prompts=prompts,
)
# NVLM-D
def run_nvlm_d(questions: list[str], modality: str):
def run_nvlm_d(questions: list[str], modality: str) -> ModelRequestData:
assert modality == "image"
model_name = "nvidia/NVLM-D-72B"
# Adjust this as necessary to fit in GPU
llm = LLM(
engine_args = EngineArgs(
model=model_name,
trust_remote_code=True,
max_model_len=4096,
@ -481,36 +584,47 @@ def run_nvlm_d(questions: list[str], modality: str):
prompts = tokenizer.apply_chat_template(messages,
tokenize=False,
add_generation_prompt=True)
stop_token_ids = None
return llm, prompts, stop_token_ids
return ModelRequestData(
engine_args=engine_args,
prompts=prompts,
)
# PaliGemma
def run_paligemma(question: str, modality: str):
def run_paligemma(questions: list[str], modality: str) -> ModelRequestData:
assert modality == "image"
# PaliGemma has special prompt format for VQA
prompt = ["caption en"]
llm = LLM(model="google/paligemma-3b-mix-224",
prompts = ["caption en" for _ in questions]
engine_args = EngineArgs(
model="google/paligemma-3b-mix-224",
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
stop_token_ids = None
return llm, prompt, stop_token_ids
return ModelRequestData(
engine_args=engine_args,
prompts=prompts,
)
# PaliGemma 2
def run_paligemma2(question: str, modality: str):
def run_paligemma2(questions: list[str], modality: str) -> ModelRequestData:
assert modality == "image"
# PaliGemma 2 has special prompt format for VQA
prompt = ["caption en"]
llm = LLM(model="google/paligemma2-3b-ft-docci-448",
prompts = ["caption en" for _ in questions]
engine_args = EngineArgs(
model="google/paligemma2-3b-ft-docci-448",
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
stop_token_ids = None
return llm, prompt, stop_token_ids
return ModelRequestData(
engine_args=engine_args,
prompts=prompts,
)
# Phi-3-Vision
def run_phi3v(questions: list[str], modality: str):
def run_phi3v(questions: list[str], modality: str) -> ModelRequestData:
assert modality == "image"
prompts = [
@ -530,7 +644,7 @@ def run_phi3v(questions: list[str], modality: str):
#
# https://huggingface.co/microsoft/Phi-3.5-vision-instruct#loading-the-model-locally
# https://huggingface.co/microsoft/Phi-3.5-vision-instruct/blob/main/processing_phi3_v.py#L194
llm = LLM(
engine_args = EngineArgs(
model="microsoft/Phi-3.5-vision-instruct",
trust_remote_code=True,
max_model_len=4096,
@ -539,12 +653,15 @@ def run_phi3v(questions: list[str], modality: str):
mm_processor_kwargs={"num_crops": 16},
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
)
stop_token_ids = None
return llm, prompts, stop_token_ids
return ModelRequestData(
engine_args=engine_args,
prompts=prompts,
)
# Phi-4-multimodal-instruct
def run_phi4mm(questions: list[str], modality: str):
def run_phi4mm(questions: list[str], modality: str) -> ModelRequestData:
"""
Phi-4-multimodal-instruct supports both image and audio inputs. Here, we
show how to process image inputs.
@ -558,7 +675,7 @@ def run_phi4mm(questions: list[str], modality: str):
f"<|user|><|image_1|>{question}<|end|><|assistant|>"
for question in questions
]
llm = LLM(
engine_args = EngineArgs(
model=model_path,
trust_remote_code=True,
max_model_len=4096,
@ -567,24 +684,22 @@ def run_phi4mm(questions: list[str], modality: str):
max_lora_rank=320,
lora_extra_vocab_size=0,
)
lora_request = LoRARequest("vision", 1, vision_lora_path)
# To maintain code compatibility in this script, we add LoRA here.
llm.llm_engine.add_lora(lora_request=lora_request)
# You can also add LoRA using:
# llm.generate(prompts, lora_request=lora_request,...)
stop_token_ids = None
return llm, prompts, stop_token_ids
return ModelRequestData(
engine_args=engine_args,
prompts=prompts,
lora_requests=[LoRARequest("vision", 1, vision_lora_path)],
)
# Pixtral HF-format
def run_pixtral_hf(questions: list[str], modality: str):
def run_pixtral_hf(questions: list[str], modality: str) -> ModelRequestData:
assert modality == "image"
model_name = "mistral-community/pixtral-12b"
# NOTE: Need L40 (or equivalent) to avoid OOM
llm = LLM(
engine_args = EngineArgs(
model=model_name,
max_model_len=8192,
max_num_seqs=2,
@ -592,15 +707,18 @@ def run_pixtral_hf(questions: list[str], modality: str):
)
prompts = [f"<s>[INST]{question}\n[IMG][/INST]" for question in questions]
stop_token_ids = None
return llm, prompts, stop_token_ids
return ModelRequestData(
engine_args=engine_args,
prompts=prompts,
)
# Qwen
def run_qwen_vl(questions: list[str], modality: str):
def run_qwen_vl(questions: list[str], modality: str) -> ModelRequestData:
assert modality == "image"
llm = LLM(
engine_args = EngineArgs(
model="Qwen/Qwen-VL",
trust_remote_code=True,
max_model_len=1024,
@ -610,16 +728,19 @@ def run_qwen_vl(questions: list[str], modality: str):
)
prompts = [f"{question}Picture 1: <img></img>\n" for question in questions]
stop_token_ids = None
return llm, prompts, stop_token_ids
return ModelRequestData(
engine_args=engine_args,
prompts=prompts,
)
# Qwen2-VL
def run_qwen2_vl(questions: list[str], modality: str):
def run_qwen2_vl(questions: list[str], modality: str) -> ModelRequestData:
model_name = "Qwen/Qwen2-VL-7B-Instruct"
llm = LLM(
engine_args = EngineArgs(
model=model_name,
max_model_len=4096,
max_num_seqs=5,
@ -642,16 +763,19 @@ def run_qwen2_vl(questions: list[str], modality: str):
f"{question}<|im_end|>\n"
"<|im_start|>assistant\n") for question in questions
]
stop_token_ids = None
return llm, prompts, stop_token_ids
return ModelRequestData(
engine_args=engine_args,
prompts=prompts,
)
# Qwen2.5-VL
def run_qwen2_5_vl(questions: list[str], modality: str):
def run_qwen2_5_vl(questions: list[str], modality: str) -> ModelRequestData:
model_name = "Qwen/Qwen2.5-VL-3B-Instruct"
llm = LLM(
engine_args = EngineArgs(
model=model_name,
max_model_len=4096,
max_num_seqs=5,
@ -674,8 +798,11 @@ def run_qwen2_5_vl(questions: list[str], modality: str):
f"{question}<|im_end|>\n"
"<|im_start|>assistant\n") for question in questions
]
stop_token_ids = None
return llm, prompts, stop_token_ids
return ModelRequestData(
engine_args=engine_args,
prompts=prompts,
)
model_example_map = {
@ -789,18 +916,28 @@ def main(args):
data = mm_input["data"]
questions = mm_input["questions"]
llm, prompts, stop_token_ids = model_example_map[model](questions,
modality)
req_data = model_example_map[model](questions, modality)
engine_args = asdict(req_data.engine_args) | {"seed": args.seed}
llm = LLM(**engine_args)
# To maintain code compatibility in this script, we add LoRA here.
# You can also add LoRA using:
# llm.generate(prompts, lora_request=lora_request,...)
if req_data.lora_requests:
for lora_request in req_data.lora_requests:
llm.llm_engine.add_lora(lora_request=lora_request)
# Don't want to check the flag multiple times, so just hijack `prompts`.
prompts = prompts if args.use_different_prompt_per_request else [
prompts[0]
prompts = req_data.prompts if args.use_different_prompt_per_request else [
req_data.prompts[0]
]
# We set temperature to 0.2 so that outputs can be different
# even when all prompts are identical when running batch inference.
sampling_params = SamplingParams(temperature=0.2,
max_tokens=64,
stop_token_ids=stop_token_ids)
stop_token_ids=req_data.stop_token_ids)
assert args.num_prompts > 0
if args.num_prompts == 1:
@ -865,6 +1002,10 @@ if __name__ == "__main__":
type=int,
default=16,
help='Number of frames to extract from the video.')
parser.add_argument("--seed",
type=int,
default=None,
help="Set the seed when initializing `vllm.LLM`.")
parser.add_argument(
'--image-repeat-prob',

View File

@ -7,11 +7,12 @@ For most models, the prompt format should follow corresponding examples
on HuggingFace model repository.
"""
from argparse import Namespace
from dataclasses import asdict
from typing import Literal, NamedTuple, Optional, TypedDict, Union, get_args
from PIL.Image import Image
from vllm import LLM
from vllm import LLM, EngineArgs
from vllm.multimodal.utils import fetch_image
from vllm.utils import FlexibleArgumentParser
@ -37,12 +38,12 @@ Query = Union[TextQuery, ImageQuery, TextImageQuery]
class ModelRequestData(NamedTuple):
llm: LLM
engine_args: EngineArgs
prompt: str
image: Optional[Image]
def run_e5_v(query: Query):
def run_e5_v(query: Query) -> ModelRequestData:
llama3_template = '<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n \n' # noqa: E501
if query["modality"] == "text":
@ -58,20 +59,20 @@ def run_e5_v(query: Query):
modality = query['modality']
raise ValueError(f"Unsupported query modality: '{modality}'")
llm = LLM(
engine_args = EngineArgs(
model="royokong/e5-v",
task="embed",
max_model_len=4096,
)
return ModelRequestData(
llm=llm,
engine_args=engine_args,
prompt=prompt,
image=image,
)
def run_vlm2vec(query: Query):
def run_vlm2vec(query: Query) -> ModelRequestData:
if query["modality"] == "text":
text = query["text"]
prompt = f"Find me an everyday image that matches the given caption: {text}" # noqa: E501
@ -87,7 +88,7 @@ def run_vlm2vec(query: Query):
modality = query['modality']
raise ValueError(f"Unsupported query modality: '{modality}'")
llm = LLM(
engine_args = EngineArgs(
model="TIGER-Lab/VLM2Vec-Full",
task="embed",
trust_remote_code=True,
@ -95,7 +96,7 @@ def run_vlm2vec(query: Query):
)
return ModelRequestData(
llm=llm,
engine_args=engine_args,
prompt=prompt,
image=image,
)
@ -126,15 +127,18 @@ def get_query(modality: QueryModality):
raise ValueError(msg)
def run_encode(model: str, modality: QueryModality):
def run_encode(model: str, modality: QueryModality, seed: Optional[int]):
query = get_query(modality)
req_data = model_example_map[model](query)
engine_args = asdict(req_data.engine_args) | {"seed": seed}
llm = LLM(**engine_args)
mm_data = {}
if req_data.image is not None:
mm_data["image"] = req_data.image
outputs = req_data.llm.embed({
outputs = llm.embed({
"prompt": req_data.prompt,
"multi_modal_data": mm_data,
})
@ -144,7 +148,7 @@ def run_encode(model: str, modality: QueryModality):
def main(args: Namespace):
run_encode(args.model_name, args.modality)
run_encode(args.model_name, args.modality, args.seed)
model_example_map = {
@ -167,5 +171,10 @@ if __name__ == "__main__":
default="image",
choices=get_args(QueryModality),
help='Modality of the input.')
parser.add_argument("--seed",
type=int,
default=None,
help="Set the seed when initializing `vllm.LLM`.")
args = parser.parse_args()
main(args)
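
For completeness, a small sketch of consuming the embeddings returned by `llm.embed(...)` in `run_encode`, assuming vLLM's embedding output exposes `outputs.embedding` as in the other embedding examples:

# Assumes `outputs` is the list returned by llm.embed(...) above.
for output in outputs:
    embedding = output.outputs.embedding  # list[float]
    print(f"dim={len(embedding)}, first values: {embedding[:4]}")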

View File

@ -6,13 +6,14 @@ using the chat template defined by the model.
"""
import os
from argparse import Namespace
from dataclasses import asdict
from typing import NamedTuple, Optional
from huggingface_hub import snapshot_download
from PIL.Image import Image
from transformers import AutoProcessor, AutoTokenizer
from vllm import LLM, SamplingParams
from vllm import LLM, EngineArgs, SamplingParams
from vllm.lora.request import LoRARequest
from vllm.multimodal.utils import fetch_image
from vllm.utils import FlexibleArgumentParser
@ -25,11 +26,12 @@ IMAGE_URLS = [
class ModelRequestData(NamedTuple):
llm: LLM
engine_args: EngineArgs
prompt: str
stop_token_ids: Optional[list[int]]
image_data: list[Image]
chat_template: Optional[str]
stop_token_ids: Optional[list[int]] = None
chat_template: Optional[str] = None
lora_requests: Optional[list[LoRARequest]] = None
# NOTE: The default `max_num_seqs` and `max_model_len` may result in OOM on
@ -37,53 +39,55 @@ class ModelRequestData(NamedTuple):
# Unless specified, these settings have been tested to work on a single L4.
def load_aria(question, image_urls: list[str]) -> ModelRequestData:
def load_aria(question: str, image_urls: list[str]) -> ModelRequestData:
model_name = "rhymes-ai/Aria"
llm = LLM(model=model_name,
engine_args = EngineArgs(
model=model_name,
tokenizer_mode="slow",
trust_remote_code=True,
dtype="bfloat16",
limit_mm_per_prompt={"image": len(image_urls)})
limit_mm_per_prompt={"image": len(image_urls)},
)
placeholders = "<fim_prefix><|img|><fim_suffix>\n" * len(image_urls)
prompt = (f"<|im_start|>user\n{placeholders}{question}<|im_end|>\n"
"<|im_start|>assistant\n")
stop_token_ids = [93532, 93653, 944, 93421, 1019, 93653, 93519]
return ModelRequestData(
llm=llm,
engine_args=engine_args,
prompt=prompt,
stop_token_ids=stop_token_ids,
image_data=[fetch_image(url) for url in image_urls],
chat_template=None,
)
def load_deepseek_vl2(question: str, image_urls: list[str]):
def load_deepseek_vl2(question: str,
image_urls: list[str]) -> ModelRequestData:
model_name = "deepseek-ai/deepseek-vl2-tiny"
llm = LLM(model=model_name,
engine_args = EngineArgs(
model=model_name,
max_model_len=4096,
max_num_seqs=2,
hf_overrides={"architectures": ["DeepseekVLV2ForCausalLM"]},
limit_mm_per_prompt={"image": len(image_urls)})
limit_mm_per_prompt={"image": len(image_urls)},
)
placeholder = "".join(f"image_{i}:<image>\n"
for i, _ in enumerate(image_urls, start=1))
prompt = f"<|User|>: {placeholder}{question}\n\n<|Assistant|>:"
return ModelRequestData(
llm=llm,
engine_args=engine_args,
prompt=prompt,
stop_token_ids=None,
image_data=[fetch_image(url) for url in image_urls],
chat_template=None,
)
def load_gemma3(question, image_urls: list[str]) -> ModelRequestData:
def load_gemma3(question: str, image_urls: list[str]) -> ModelRequestData:
model_name = "google/gemma-3-4b-it"
llm = LLM(
engine_args = EngineArgs(
model=model_name,
max_model_len=8192,
max_num_seqs=2,
@ -112,18 +116,16 @@ def load_gemma3(question, image_urls: list[str]) -> ModelRequestData:
add_generation_prompt=True)
return ModelRequestData(
llm=llm,
engine_args=engine_args,
prompt=prompt,
stop_token_ids=None,
image_data=[fetch_image(url) for url in image_urls],
chat_template=None,
)
def load_h2ovl(question: str, image_urls: list[str]) -> ModelRequestData:
model_name = "h2oai/h2ovl-mississippi-800m"
llm = LLM(
engine_args = EngineArgs(
model=model_name,
trust_remote_code=True,
max_model_len=8192,
@ -146,19 +148,18 @@ def load_h2ovl(question: str, image_urls: list[str]) -> ModelRequestData:
stop_token_ids = [tokenizer.eos_token_id]
return ModelRequestData(
llm=llm,
engine_args=engine_args,
prompt=prompt,
stop_token_ids=stop_token_ids,
image_data=[fetch_image(url) for url in image_urls],
chat_template=None,
)
def load_idefics3(question, image_urls: list[str]) -> ModelRequestData:
def load_idefics3(question: str, image_urls: list[str]) -> ModelRequestData:
model_name = "HuggingFaceM4/Idefics3-8B-Llama3"
# The configuration below has been confirmed to launch on a single L40 GPU.
llm = LLM(
engine_args = EngineArgs(
model=model_name,
max_model_len=8192,
max_num_seqs=16,
@ -177,18 +178,16 @@ def load_idefics3(question, image_urls: list[str]) -> ModelRequestData:
for i, _ in enumerate(image_urls, start=1))
prompt = f"<|begin_of_text|>User:{placeholders}\n{question}<end_of_utterance>\nAssistant:" # noqa: E501
return ModelRequestData(
llm=llm,
engine_args=engine_args,
prompt=prompt,
stop_token_ids=None,
image_data=[fetch_image(url) for url in image_urls],
chat_template=None,
)
def load_internvl(question: str, image_urls: list[str]) -> ModelRequestData:
model_name = "OpenGVLab/InternVL2-2B"
llm = LLM(
engine_args = EngineArgs(
model=model_name,
trust_remote_code=True,
max_model_len=4096,
@ -214,19 +213,18 @@ def load_internvl(question: str, image_urls: list[str]) -> ModelRequestData:
stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]
return ModelRequestData(
llm=llm,
engine_args=engine_args,
prompt=prompt,
stop_token_ids=stop_token_ids,
image_data=[fetch_image(url) for url in image_urls],
chat_template=None,
)
def load_mllama(question, image_urls: list[str]) -> ModelRequestData:
def load_mllama(question: str, image_urls: list[str]) -> ModelRequestData:
model_name = "meta-llama/Llama-3.2-11B-Vision-Instruct"
# The configuration below has been confirmed to launch on a single L40 GPU.
llm = LLM(
engine_args = EngineArgs(
model=model_name,
max_model_len=4096,
max_num_seqs=16,
@ -236,19 +234,17 @@ def load_mllama(question, image_urls: list[str]) -> ModelRequestData:
placeholders = "<|image|>" * len(image_urls)
prompt = f"{placeholders}<|begin_of_text|>{question}"
return ModelRequestData(
llm=llm,
engine_args=engine_args,
prompt=prompt,
stop_token_ids=None,
image_data=[fetch_image(url) for url in image_urls],
chat_template=None,
)
def load_nvlm_d(question: str, image_urls: list[str]):
def load_nvlm_d(question: str, image_urls: list[str]) -> ModelRequestData:
model_name = "nvidia/NVLM-D-72B"
# Adjust this as necessary to fit in GPU
llm = LLM(
engine_args = EngineArgs(
model=model_name,
trust_remote_code=True,
max_model_len=8192,
@ -266,14 +262,11 @@ def load_nvlm_d(question: str, image_urls: list[str]):
prompt = tokenizer.apply_chat_template(messages,
tokenize=False,
add_generation_prompt=True)
stop_token_ids = None
return ModelRequestData(
llm=llm,
engine_args=engine_args,
prompt=prompt,
stop_token_ids=stop_token_ids,
image_data=[fetch_image(url) for url in image_urls],
chat_template=None,
)
@ -281,7 +274,7 @@ def load_pixtral_hf(question: str, image_urls: list[str]) -> ModelRequestData:
model_name = "mistral-community/pixtral-12b"
# Adjust this as necessary to fit in GPU
llm = LLM(
engine_args = EngineArgs(
model=model_name,
max_model_len=8192,
max_num_seqs=2,
@ -291,14 +284,11 @@ def load_pixtral_hf(question: str, image_urls: list[str]) -> ModelRequestData:
placeholders = "[IMG]" * len(image_urls)
prompt = f"<s>[INST]{question}\n{placeholders}[/INST]"
stop_token_ids = None
return ModelRequestData(
llm=llm,
engine_args=engine_args,
prompt=prompt,
stop_token_ids=stop_token_ids,
image_data=[fetch_image(url) for url in image_urls],
chat_template=None,
)
@ -315,7 +305,7 @@ def load_phi3v(question: str, image_urls: list[str]) -> ModelRequestData:
#
# https://huggingface.co/microsoft/Phi-3.5-vision-instruct#loading-the-model-locally
# https://huggingface.co/microsoft/Phi-3.5-vision-instruct/blob/main/processing_phi3_v.py#L194
llm = LLM(
engine_args = EngineArgs(
model="microsoft/Phi-3.5-vision-instruct",
trust_remote_code=True,
max_model_len=4096,
@ -326,14 +316,11 @@ def load_phi3v(question: str, image_urls: list[str]) -> ModelRequestData:
placeholders = "\n".join(f"<|image_{i}|>"
for i, _ in enumerate(image_urls, start=1))
prompt = f"<|user|>\n{placeholders}\n{question}<|end|>\n<|assistant|>\n"
stop_token_ids = None
return ModelRequestData(
llm=llm,
engine_args=engine_args,
prompt=prompt,
stop_token_ids=stop_token_ids,
image_data=[fetch_image(url) for url in image_urls],
chat_template=None,
)
@ -347,7 +334,7 @@ def load_phi4mm(question: str, image_urls: list[str]) -> ModelRequestData:
# Since the vision-lora and speech-lora co-exist with the base model,
# we have to manually specify the path of the lora weights.
vision_lora_path = os.path.join(model_path, "vision-lora")
llm = LLM(
engine_args = EngineArgs(
model=model_path,
trust_remote_code=True,
max_model_len=10000,
@ -357,30 +344,23 @@ def load_phi4mm(question: str, image_urls: list[str]) -> ModelRequestData:
max_lora_rank=320,
lora_extra_vocab_size=0,
)
lora_request = LoRARequest("vision", 1, vision_lora_path)
# To maintain code compatibility in this script, we add LoRA here.
llm.llm_engine.add_lora(lora_request=lora_request)
# You can also add LoRA using:
# llm.generate(prompts, lora_request=lora_request,...)
placeholders = "".join(f"<|image_{i}|>"
for i, _ in enumerate(image_urls, start=1))
prompt = f"<|user|>{placeholders}{question}<|end|><|assistant|>"
stop_token_ids = None
return ModelRequestData(
llm=llm,
engine_args=engine_args,
prompt=prompt,
stop_token_ids=stop_token_ids,
image_data=[fetch_image(url) for url in image_urls],
chat_template=None,
lora_requests=[LoRARequest("vision", 1, vision_lora_path)],
)
def load_qwen_vl_chat(question: str,
image_urls: list[str]) -> ModelRequestData:
model_name = "Qwen/Qwen-VL-Chat"
llm = LLM(
engine_args = EngineArgs(
model=model_name,
trust_remote_code=True,
max_model_len=1024,
@ -411,7 +391,7 @@ def load_qwen_vl_chat(question: str,
stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]
return ModelRequestData(
llm=llm,
engine_args=engine_args,
prompt=prompt,
stop_token_ids=stop_token_ids,
image_data=[fetch_image(url) for url in image_urls],
@ -419,7 +399,7 @@ def load_qwen_vl_chat(question: str,
)
def load_qwen2_vl(question, image_urls: list[str]) -> ModelRequestData:
def load_qwen2_vl(question: str, image_urls: list[str]) -> ModelRequestData:
try:
from qwen_vl_utils import process_vision_info
except ModuleNotFoundError:
@ -431,7 +411,7 @@ def load_qwen2_vl(question, image_urls: list[str]) -> ModelRequestData:
model_name = "Qwen/Qwen2-VL-7B-Instruct"
# Tested on L40
llm = LLM(
engine_args = EngineArgs(
model=model_name,
max_model_len=32768 if process_vision_info is None else 4096,
max_num_seqs=5,
@ -460,23 +440,19 @@ def load_qwen2_vl(question, image_urls: list[str]) -> ModelRequestData:
tokenize=False,
add_generation_prompt=True)
stop_token_ids = None
if process_vision_info is None:
image_data = [fetch_image(url) for url in image_urls]
else:
image_data, _ = process_vision_info(messages)
return ModelRequestData(
llm=llm,
engine_args=engine_args,
prompt=prompt,
stop_token_ids=stop_token_ids,
image_data=image_data,
chat_template=None,
)
def load_qwen2_5_vl(question, image_urls: list[str]) -> ModelRequestData:
def load_qwen2_5_vl(question: str, image_urls: list[str]) -> ModelRequestData:
try:
from qwen_vl_utils import process_vision_info
except ModuleNotFoundError:
@ -487,7 +463,7 @@ def load_qwen2_5_vl(question, image_urls: list[str]) -> ModelRequestData:
model_name = "Qwen/Qwen2.5-VL-3B-Instruct"
llm = LLM(
engine_args = EngineArgs(
model=model_name,
max_model_len=32768 if process_vision_info is None else 4096,
max_num_seqs=5,
@ -516,8 +492,6 @@ def load_qwen2_5_vl(question, image_urls: list[str]) -> ModelRequestData:
tokenize=False,
add_generation_prompt=True)
stop_token_ids = None
if process_vision_info is None:
image_data = [fetch_image(url) for url in image_urls]
else:
@ -525,11 +499,9 @@ def load_qwen2_5_vl(question, image_urls: list[str]) -> ModelRequestData:
return_video_kwargs=False)
return ModelRequestData(
llm=llm,
engine_args=engine_args,
prompt=prompt,
stop_token_ids=stop_token_ids,
image_data=image_data,
chat_template=None,
)
@ -551,14 +523,25 @@ model_example_map = {
}
def run_generate(model, question: str, image_urls: list[str]):
def run_generate(model, question: str, image_urls: list[str],
seed: Optional[int]):
req_data = model_example_map[model](question, image_urls)
engine_args = asdict(req_data.engine_args) | {"seed": seed}
llm = LLM(**engine_args)
# To maintain code compatibility in this script, we add LoRA here.
# You can also add LoRA using:
# llm.generate(prompts, lora_request=lora_request,...)
if req_data.lora_requests:
for lora_request in req_data.lora_requests:
llm.llm_engine.add_lora(lora_request=lora_request)
sampling_params = SamplingParams(temperature=0.0,
max_tokens=128,
stop_token_ids=req_data.stop_token_ids)
outputs = req_data.llm.generate(
outputs = llm.generate(
{
"prompt": req_data.prompt,
"multi_modal_data": {
@ -572,13 +555,24 @@ def run_generate(model, question: str, image_urls: list[str]):
print(generated_text)
def run_chat(model: str, question: str, image_urls: list[str]):
def run_chat(model: str, question: str, image_urls: list[str],
seed: Optional[int]):
req_data = model_example_map[model](question, image_urls)
engine_args = asdict(req_data.engine_args) | {"seed": seed}
llm = LLM(**engine_args)
# To maintain code compatibility in this script, we add LoRA here.
# You can also add LoRA using:
# llm.generate(prompts, lora_request=lora_request,...)
if req_data.lora_requests:
for lora_request in req_data.lora_requests:
llm.llm_engine.add_lora(lora_request=lora_request)
sampling_params = SamplingParams(temperature=0.0,
max_tokens=128,
stop_token_ids=req_data.stop_token_ids)
outputs = req_data.llm.chat(
outputs = llm.chat(
[{
"role":
"user",
@ -607,11 +601,12 @@ def run_chat(model: str, question: str, image_urls: list[str]):
def main(args: Namespace):
model = args.model_type
method = args.method
seed = args.seed
if method == "generate":
run_generate(model, QUESTION, IMAGE_URLS)
run_generate(model, QUESTION, IMAGE_URLS, seed)
elif method == "chat":
run_chat(model, QUESTION, IMAGE_URLS)
run_chat(model, QUESTION, IMAGE_URLS, seed)
else:
raise ValueError(f"Invalid method: {method}")
@ -632,6 +627,10 @@ if __name__ == "__main__":
default="generate",
choices=["generate", "chat"],
help="The method to run in `vllm.LLM`.")
parser.add_argument("--seed",
type=int,
default=None,
help="Set the seed when initializing `vllm.LLM`.")
args = parser.parse_args()
main(args)
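
For completeness, a brief sketch of the OpenAI-style message that `run_chat` sends, with the content layout assumed from vLLM's chat API; `question`, `image_urls`, `sampling_params`, and `req_data` are the values built above:

outputs = llm.chat(
    [{
        "role": "user",
        "content": [
            {"type": "text", "text": question},
            # One image_url entry per input image.
            *({"type": "image_url", "image_url": {"url": url}}
              for url in image_urls),
        ],
    }],
    sampling_params=sampling_params,
    chat_template=req_data.chat_template,
)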