[Misc] Add --seed
option to offline multi-modal examples (#14934)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
parent
868a8c5b2c
commit
6eaf1e5c52
@ -226,10 +226,13 @@ steps:
|
||||
- python3 offline_inference/basic/chat.py
|
||||
- python3 offline_inference/prefix_caching.py
|
||||
- python3 offline_inference/llm_engine_example.py
|
||||
- python3 offline_inference/vision_language.py
|
||||
- python3 offline_inference/vision_language_multi_image.py
|
||||
- python3 offline_inference/audio_language.py --seed 0
|
||||
- python3 offline_inference/vision_language.py --seed 0
|
||||
- python3 offline_inference/vision_language_embedding.py --seed 0
|
||||
- python3 offline_inference/vision_language_multi_image.py --seed 0
|
||||
- VLLM_USE_V1=0 python3 other/tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 other/tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors
|
||||
- python3 offline_inference/encoder_decoder.py
|
||||
- python3 offline_inference/encoder_decoder_multimodal.py --model-type whisper --seed 0
|
||||
- python3 offline_inference/basic/classify.py
|
||||
- python3 offline_inference/basic/embed.py
|
||||
- python3 offline_inference/basic/score.py
|
||||
|
@ -7,11 +7,13 @@ For most models, the prompt format should follow corresponding examples
|
||||
on HuggingFace model repository.
|
||||
"""
|
||||
import os
|
||||
from dataclasses import asdict
|
||||
from typing import NamedTuple, Optional
|
||||
|
||||
from huggingface_hub import snapshot_download
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm import LLM, EngineArgs, SamplingParams
|
||||
from vllm.assets.audio import AudioAsset
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
@ -23,21 +25,31 @@ question_per_audio_count = {
|
||||
2: "What sport and what nursery rhyme are referenced?"
|
||||
}
|
||||
|
||||
|
||||
class ModelRequestData(NamedTuple):
|
||||
engine_args: EngineArgs
|
||||
prompt: str
|
||||
stop_token_ids: Optional[list[int]] = None
|
||||
lora_requests: Optional[list[LoRARequest]] = None
|
||||
|
||||
|
||||
# NOTE: The default `max_num_seqs` and `max_model_len` may result in OOM on
|
||||
# lower-end GPUs.
|
||||
# Unless specified, these settings have been tested to work on a single L4.
|
||||
|
||||
|
||||
# MiniCPM-O
|
||||
def run_minicpmo(question: str, audio_count: int):
|
||||
def run_minicpmo(question: str, audio_count: int) -> ModelRequestData:
|
||||
model_name = "openbmb/MiniCPM-o-2_6"
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name,
|
||||
trust_remote_code=True)
|
||||
llm = LLM(model=model_name,
|
||||
trust_remote_code=True,
|
||||
max_model_len=4096,
|
||||
max_num_seqs=5,
|
||||
limit_mm_per_prompt={"audio": audio_count})
|
||||
engine_args = EngineArgs(
|
||||
model=model_name,
|
||||
trust_remote_code=True,
|
||||
max_model_len=4096,
|
||||
max_num_seqs=5,
|
||||
limit_mm_per_prompt={"audio": audio_count},
|
||||
)
|
||||
|
||||
stop_tokens = ['<|im_end|>', '<|endoftext|>']
|
||||
stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]
|
||||
@ -52,11 +64,16 @@ def run_minicpmo(question: str, audio_count: int):
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
chat_template=audio_chat_template)
|
||||
return llm, prompt, stop_token_ids
|
||||
|
||||
return ModelRequestData(
|
||||
engine_args=engine_args,
|
||||
prompt=prompt,
|
||||
stop_token_ids=stop_token_ids,
|
||||
)
|
||||
|
||||
|
||||
# Phi-4-multimodal-instruct
|
||||
def run_phi4mm(questions: str, audio_count: int):
|
||||
def run_phi4mm(question: str, audio_count: int) -> ModelRequestData:
|
||||
"""
|
||||
Phi-4-multimodal-instruct supports both image and audio inputs. Here, we
|
||||
show how to process audio inputs.
|
||||
@ -67,9 +84,9 @@ def run_phi4mm(questions: str, audio_count: int):
|
||||
speech_lora_path = os.path.join(model_path, "speech-lora")
|
||||
placeholders = "".join([f"<|audio_{i+1}|>" for i in range(audio_count)])
|
||||
|
||||
prompts = f"<|user|>{placeholders}{questions}<|end|><|assistant|>"
|
||||
prompts = f"<|user|>{placeholders}{question}<|end|><|assistant|>"
|
||||
|
||||
llm = LLM(
|
||||
engine_args = EngineArgs(
|
||||
model=model_path,
|
||||
trust_remote_code=True,
|
||||
max_model_len=4096,
|
||||
@ -79,24 +96,24 @@ def run_phi4mm(questions: str, audio_count: int):
|
||||
lora_extra_vocab_size=0,
|
||||
limit_mm_per_prompt={"audio": audio_count},
|
||||
)
|
||||
lora_request = LoRARequest("speech", 1, speech_lora_path)
|
||||
# To maintain code compatibility in this script, we add LoRA here.
|
||||
llm.llm_engine.add_lora(lora_request=lora_request)
|
||||
# You can also add LoRA using:
|
||||
# llm.generate(prompts, lora_request=lora_request,...)
|
||||
|
||||
stop_token_ids = None
|
||||
return llm, prompts, stop_token_ids
|
||||
return ModelRequestData(
|
||||
engine_args=engine_args,
|
||||
prompt=prompts,
|
||||
lora_requests=[LoRARequest("speech", 1, speech_lora_path)],
|
||||
)
|
||||
|
||||
|
||||
# Qwen2-Audio
|
||||
def run_qwen2_audio(question: str, audio_count: int):
|
||||
def run_qwen2_audio(question: str, audio_count: int) -> ModelRequestData:
|
||||
model_name = "Qwen/Qwen2-Audio-7B-Instruct"
|
||||
|
||||
llm = LLM(model=model_name,
|
||||
max_model_len=4096,
|
||||
max_num_seqs=5,
|
||||
limit_mm_per_prompt={"audio": audio_count})
|
||||
engine_args = EngineArgs(
|
||||
model=model_name,
|
||||
max_model_len=4096,
|
||||
max_num_seqs=5,
|
||||
limit_mm_per_prompt={"audio": audio_count},
|
||||
)
|
||||
|
||||
audio_in_prompt = "".join([
|
||||
f"Audio {idx+1}: "
|
||||
@ -107,12 +124,15 @@ def run_qwen2_audio(question: str, audio_count: int):
|
||||
"<|im_start|>user\n"
|
||||
f"{audio_in_prompt}{question}<|im_end|>\n"
|
||||
"<|im_start|>assistant\n")
|
||||
stop_token_ids = None
|
||||
return llm, prompt, stop_token_ids
|
||||
|
||||
return ModelRequestData(
|
||||
engine_args=engine_args,
|
||||
prompt=prompt,
|
||||
)
|
||||
|
||||
|
||||
# Ultravox 0.5-1B
|
||||
def run_ultravox(question: str, audio_count: int):
|
||||
def run_ultravox(question: str, audio_count: int) -> ModelRequestData:
|
||||
model_name = "fixie-ai/ultravox-v0_5-llama-3_2-1b"
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
@ -124,29 +144,39 @@ def run_ultravox(question: str, audio_count: int):
|
||||
tokenize=False,
|
||||
add_generation_prompt=True)
|
||||
|
||||
llm = LLM(model=model_name,
|
||||
max_model_len=4096,
|
||||
max_num_seqs=5,
|
||||
trust_remote_code=True,
|
||||
limit_mm_per_prompt={"audio": audio_count})
|
||||
stop_token_ids = None
|
||||
return llm, prompt, stop_token_ids
|
||||
engine_args = EngineArgs(
|
||||
model=model_name,
|
||||
max_model_len=4096,
|
||||
max_num_seqs=5,
|
||||
trust_remote_code=True,
|
||||
limit_mm_per_prompt={"audio": audio_count},
|
||||
)
|
||||
|
||||
return ModelRequestData(
|
||||
engine_args=engine_args,
|
||||
prompt=prompt,
|
||||
)
|
||||
|
||||
|
||||
# Whisper
|
||||
def run_whisper(question: str, audio_count: int):
|
||||
def run_whisper(question: str, audio_count: int) -> ModelRequestData:
|
||||
assert audio_count == 1, (
|
||||
"Whisper only support single audio input per prompt")
|
||||
model_name = "openai/whisper-large-v3-turbo"
|
||||
|
||||
prompt = "<|startoftranscript|>"
|
||||
|
||||
llm = LLM(model=model_name,
|
||||
max_model_len=448,
|
||||
max_num_seqs=5,
|
||||
limit_mm_per_prompt={"audio": audio_count})
|
||||
stop_token_ids = None
|
||||
return llm, prompt, stop_token_ids
|
||||
engine_args = EngineArgs(
|
||||
model=model_name,
|
||||
max_model_len=448,
|
||||
max_num_seqs=5,
|
||||
limit_mm_per_prompt={"audio": audio_count},
|
||||
)
|
||||
|
||||
return ModelRequestData(
|
||||
engine_args=engine_args,
|
||||
prompt=prompt,
|
||||
)
|
||||
|
||||
|
||||
model_example_map = {
|
||||
@ -164,14 +194,24 @@ def main(args):
|
||||
raise ValueError(f"Model type {model} is not supported.")
|
||||
|
||||
audio_count = args.num_audios
|
||||
llm, prompt, stop_token_ids = model_example_map[model](
|
||||
question_per_audio_count[audio_count], audio_count)
|
||||
req_data = model_example_map[model](question_per_audio_count[audio_count],
|
||||
audio_count)
|
||||
|
||||
engine_args = asdict(req_data.engine_args) | {"seed": args.seed}
|
||||
llm = LLM(**engine_args)
|
||||
|
||||
# To maintain code compatibility in this script, we add LoRA here.
|
||||
# You can also add LoRA using:
|
||||
# llm.generate(prompts, lora_request=lora_request,...)
|
||||
if req_data.lora_requests:
|
||||
for lora_request in req_data.lora_requests:
|
||||
llm.llm_engine.add_lora(lora_request=lora_request)
|
||||
|
||||
# We set temperature to 0.2 so that outputs can be different
|
||||
# even when all prompts are identical when running batch inference.
|
||||
sampling_params = SamplingParams(temperature=0.2,
|
||||
max_tokens=64,
|
||||
stop_token_ids=stop_token_ids)
|
||||
stop_token_ids=req_data.stop_token_ids)
|
||||
|
||||
mm_data = {}
|
||||
if audio_count > 0:
|
||||
@ -183,7 +223,7 @@ def main(args):
|
||||
}
|
||||
|
||||
assert args.num_prompts > 0
|
||||
inputs = {"prompt": prompt, "multi_modal_data": mm_data}
|
||||
inputs = {"prompt": req_data.prompt, "multi_modal_data": mm_data}
|
||||
if args.num_prompts > 1:
|
||||
# Batch inference
|
||||
inputs = [inputs] * args.num_prompts
|
||||
@ -214,6 +254,10 @@ if __name__ == "__main__":
|
||||
default=1,
|
||||
choices=[0, 1, 2],
|
||||
help="Number of audio items per prompt.")
|
||||
parser.add_argument("--seed",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Set the seed when initializing `vllm.LLM`.")
|
||||
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
|
@ -4,16 +4,23 @@ This example shows how to use vLLM for running offline inference with
|
||||
the explicit/implicit prompt format on enc-dec LMMs for text generation.
|
||||
"""
|
||||
import time
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import asdict
|
||||
from typing import NamedTuple
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm import LLM, EngineArgs, PromptType, SamplingParams
|
||||
from vllm.assets.audio import AudioAsset
|
||||
from vllm.assets.image import ImageAsset
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
|
||||
class ModelRequestData(NamedTuple):
|
||||
engine_args: EngineArgs
|
||||
prompts: Sequence[PromptType]
|
||||
|
||||
|
||||
def run_florence2():
|
||||
# Create a Florence-2 encoder/decoder model instance
|
||||
llm = LLM(
|
||||
engine_args = EngineArgs(
|
||||
model="microsoft/Florence-2-large",
|
||||
tokenizer="facebook/bart-large",
|
||||
max_num_seqs=8,
|
||||
@ -39,12 +46,15 @@ def run_florence2():
|
||||
"decoder_prompt": "",
|
||||
},
|
||||
]
|
||||
return llm, prompts
|
||||
|
||||
return ModelRequestData(
|
||||
engine_args=engine_args,
|
||||
prompts=prompts,
|
||||
)
|
||||
|
||||
|
||||
def run_mllama():
|
||||
# Create a Mllama encoder/decoder model instance
|
||||
llm = LLM(
|
||||
engine_args = EngineArgs(
|
||||
model="meta-llama/Llama-3.2-11B-Vision-Instruct",
|
||||
max_model_len=4096,
|
||||
max_num_seqs=2,
|
||||
@ -69,12 +79,15 @@ def run_mllama():
|
||||
"decoder_prompt": "<|image|><|begin_of_text|>Please describe the image.", # noqa: E501
|
||||
},
|
||||
]
|
||||
return llm, prompts
|
||||
|
||||
return ModelRequestData(
|
||||
engine_args=engine_args,
|
||||
prompts=prompts,
|
||||
)
|
||||
|
||||
|
||||
def run_whisper():
|
||||
# Create a Whisper encoder/decoder model instance
|
||||
llm = LLM(
|
||||
engine_args = EngineArgs(
|
||||
model="openai/whisper-large-v3-turbo",
|
||||
max_model_len=448,
|
||||
max_num_seqs=16,
|
||||
@ -99,7 +112,11 @@ def run_whisper():
|
||||
"decoder_prompt": "<|startoftranscript|>",
|
||||
}
|
||||
]
|
||||
return llm, prompts
|
||||
|
||||
return ModelRequestData(
|
||||
engine_args=engine_args,
|
||||
prompts=prompts,
|
||||
)
|
||||
|
||||
|
||||
model_example_map = {
|
||||
@ -114,7 +131,12 @@ def main(args):
|
||||
if model not in model_example_map:
|
||||
raise ValueError(f"Model type {model} is not supported.")
|
||||
|
||||
llm, prompts = model_example_map[model]()
|
||||
req_data = model_example_map[model]()
|
||||
|
||||
engine_args = asdict(req_data.engine_args) | {"seed": args.seed}
|
||||
llm = LLM(**engine_args)
|
||||
|
||||
prompts = req_data.prompts
|
||||
|
||||
# Create a sampling params object.
|
||||
sampling_params = SamplingParams(
|
||||
@ -153,6 +175,10 @@ if __name__ == "__main__":
|
||||
default="mllama",
|
||||
choices=model_example_map.keys(),
|
||||
help='Huggingface "model_type".')
|
||||
parser.add_argument("--seed",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Set the seed when initializing `vllm.LLM`.")
|
||||
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
|
@ -8,122 +8,164 @@ on HuggingFace model repository.
|
||||
"""
|
||||
import os
|
||||
import random
|
||||
from dataclasses import asdict
|
||||
from typing import NamedTuple, Optional
|
||||
|
||||
from huggingface_hub import snapshot_download
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm import LLM, EngineArgs, SamplingParams
|
||||
from vllm.assets.image import ImageAsset
|
||||
from vllm.assets.video import VideoAsset
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
|
||||
class ModelRequestData(NamedTuple):
|
||||
engine_args: EngineArgs
|
||||
prompts: list[str]
|
||||
stop_token_ids: Optional[list[int]] = None
|
||||
lora_requests: Optional[list[LoRARequest]] = None
|
||||
|
||||
|
||||
# NOTE: The default `max_num_seqs` and `max_model_len` may result in OOM on
|
||||
# lower-end GPUs.
|
||||
# Unless specified, these settings have been tested to work on a single L4.
|
||||
|
||||
|
||||
# Aria
|
||||
def run_aria(questions: list[str], modality: str):
|
||||
def run_aria(questions: list[str], modality: str) -> ModelRequestData:
|
||||
assert modality == "image"
|
||||
model_name = "rhymes-ai/Aria"
|
||||
|
||||
# NOTE: Need L40 (or equivalent) to avoid OOM
|
||||
llm = LLM(model=model_name,
|
||||
max_model_len=4096,
|
||||
max_num_seqs=2,
|
||||
dtype="bfloat16",
|
||||
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
|
||||
engine_args = EngineArgs(
|
||||
model=model_name,
|
||||
max_model_len=4096,
|
||||
max_num_seqs=2,
|
||||
dtype="bfloat16",
|
||||
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
|
||||
)
|
||||
|
||||
prompts = [(f"<|im_start|>user\n<fim_prefix><|img|><fim_suffix>{question}"
|
||||
"<|im_end|>\n<|im_start|>assistant\n")
|
||||
for question in questions]
|
||||
|
||||
stop_token_ids = [93532, 93653, 944, 93421, 1019, 93653, 93519]
|
||||
return llm, prompts, stop_token_ids
|
||||
|
||||
return ModelRequestData(
|
||||
engine_args=engine_args,
|
||||
prompts=prompts,
|
||||
stop_token_ids=stop_token_ids,
|
||||
)
|
||||
|
||||
|
||||
# BLIP-2
|
||||
def run_blip2(questions: list[str], modality: str):
|
||||
def run_blip2(questions: list[str], modality: str) -> ModelRequestData:
|
||||
assert modality == "image"
|
||||
|
||||
# BLIP-2 prompt format is inaccurate on HuggingFace model repository.
|
||||
# See https://huggingface.co/Salesforce/blip2-opt-2.7b/discussions/15#64ff02f3f8cf9e4f5b038262 #noqa
|
||||
prompts = [f"Question: {question} Answer:" for question in questions]
|
||||
llm = LLM(model="Salesforce/blip2-opt-2.7b",
|
||||
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
|
||||
stop_token_ids = None
|
||||
return llm, prompts, stop_token_ids
|
||||
engine_args = EngineArgs(
|
||||
model="Salesforce/blip2-opt-2.7b",
|
||||
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
|
||||
)
|
||||
|
||||
return ModelRequestData(
|
||||
engine_args=engine_args,
|
||||
prompts=prompts,
|
||||
)
|
||||
|
||||
|
||||
# Chameleon
|
||||
def run_chameleon(questions: list[str], modality: str):
|
||||
def run_chameleon(questions: list[str], modality: str) -> ModelRequestData:
|
||||
assert modality == "image"
|
||||
|
||||
prompts = [f"{question}<image>" for question in questions]
|
||||
llm = LLM(model="facebook/chameleon-7b",
|
||||
max_model_len=4096,
|
||||
max_num_seqs=2,
|
||||
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
|
||||
stop_token_ids = None
|
||||
return llm, prompts, stop_token_ids
|
||||
engine_args = EngineArgs(
|
||||
model="facebook/chameleon-7b",
|
||||
max_model_len=4096,
|
||||
max_num_seqs=2,
|
||||
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
|
||||
)
|
||||
|
||||
return ModelRequestData(
|
||||
engine_args=engine_args,
|
||||
prompts=prompts,
|
||||
)
|
||||
|
||||
|
||||
# Deepseek-VL2
|
||||
def run_deepseek_vl2(questions: list[str], modality: str):
|
||||
def run_deepseek_vl2(questions: list[str], modality: str) -> ModelRequestData:
|
||||
assert modality == "image"
|
||||
|
||||
model_name = "deepseek-ai/deepseek-vl2-tiny"
|
||||
|
||||
llm = LLM(model=model_name,
|
||||
max_model_len=4096,
|
||||
max_num_seqs=2,
|
||||
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
|
||||
hf_overrides={"architectures": ["DeepseekVLV2ForCausalLM"]})
|
||||
engine_args = EngineArgs(
|
||||
model=model_name,
|
||||
max_model_len=4096,
|
||||
max_num_seqs=2,
|
||||
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
|
||||
hf_overrides={"architectures": ["DeepseekVLV2ForCausalLM"]},
|
||||
)
|
||||
|
||||
prompts = [
|
||||
f"<|User|>: <image>\n{question}\n\n<|Assistant|>:"
|
||||
for question in questions
|
||||
]
|
||||
stop_token_ids = None
|
||||
return llm, prompts, stop_token_ids
|
||||
|
||||
return ModelRequestData(
|
||||
engine_args=engine_args,
|
||||
prompts=prompts,
|
||||
)
|
||||
|
||||
|
||||
# Florence2
|
||||
def run_florence2(question: str, modality: str):
|
||||
def run_florence2(questions: list[str], modality: str) -> ModelRequestData:
|
||||
assert modality == "image"
|
||||
|
||||
llm = LLM(model="microsoft/Florence-2-large",
|
||||
tokenizer="facebook/bart-large",
|
||||
max_num_seqs=8,
|
||||
trust_remote_code=True,
|
||||
dtype="bfloat16",
|
||||
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
|
||||
engine_args = EngineArgs(
|
||||
model="microsoft/Florence-2-large",
|
||||
tokenizer="facebook/bart-large",
|
||||
max_num_seqs=8,
|
||||
trust_remote_code=True,
|
||||
dtype="bfloat16",
|
||||
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
|
||||
)
|
||||
|
||||
prompt = "<MORE_DETAILED_CAPTION>"
|
||||
stop_token_ids = None
|
||||
return llm, prompt, stop_token_ids
|
||||
prompts = ["<MORE_DETAILED_CAPTION>" for _ in questions]
|
||||
|
||||
return ModelRequestData(
|
||||
engine_args=engine_args,
|
||||
prompts=prompts,
|
||||
)
|
||||
|
||||
|
||||
# Fuyu
|
||||
def run_fuyu(questions: list[str], modality: str):
|
||||
def run_fuyu(questions: list[str], modality: str) -> ModelRequestData:
|
||||
assert modality == "image"
|
||||
|
||||
prompts = [f"{question}\n" for question in questions]
|
||||
llm = LLM(model="adept/fuyu-8b",
|
||||
max_model_len=2048,
|
||||
max_num_seqs=2,
|
||||
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
|
||||
stop_token_ids = None
|
||||
return llm, prompts, stop_token_ids
|
||||
engine_args = EngineArgs(
|
||||
model="adept/fuyu-8b",
|
||||
max_model_len=2048,
|
||||
max_num_seqs=2,
|
||||
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
|
||||
)
|
||||
|
||||
return ModelRequestData(
|
||||
engine_args=engine_args,
|
||||
prompts=prompts,
|
||||
)
|
||||
|
||||
|
||||
# Gemma 3
|
||||
def run_gemma3(questions: list[str], modality: str):
|
||||
def run_gemma3(questions: list[str], modality: str) -> ModelRequestData:
|
||||
assert modality == "image"
|
||||
model_name = "google/gemma-3-4b-it"
|
||||
|
||||
llm = LLM(
|
||||
engine_args = EngineArgs(
|
||||
model=model_name,
|
||||
max_model_len=2048,
|
||||
max_num_seqs=2,
|
||||
@ -135,22 +177,27 @@ def run_gemma3(questions: list[str], modality: str):
|
||||
prompts = [("<bos><start_of_turn>user\n"
|
||||
f"<start_of_image>{question}<end_of_turn>\n"
|
||||
"<start_of_turn>model\n") for question in questions]
|
||||
stop_token_ids = None
|
||||
return llm, prompts, stop_token_ids
|
||||
|
||||
return ModelRequestData(
|
||||
engine_args=engine_args,
|
||||
prompts=prompts,
|
||||
)
|
||||
|
||||
|
||||
# GLM-4v
|
||||
def run_glm4v(questions: list[str], modality: str):
|
||||
def run_glm4v(questions: list[str], modality: str) -> ModelRequestData:
|
||||
assert modality == "image"
|
||||
model_name = "THUDM/glm-4v-9b"
|
||||
|
||||
llm = LLM(model=model_name,
|
||||
max_model_len=2048,
|
||||
max_num_seqs=2,
|
||||
trust_remote_code=True,
|
||||
enforce_eager=True,
|
||||
hf_overrides={"architectures": ["GLM4VForCausalLM"]},
|
||||
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
|
||||
engine_args = EngineArgs(
|
||||
model=model_name,
|
||||
max_model_len=2048,
|
||||
max_num_seqs=2,
|
||||
trust_remote_code=True,
|
||||
enforce_eager=True,
|
||||
hf_overrides={"architectures": ["GLM4VForCausalLM"]},
|
||||
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
|
||||
)
|
||||
|
||||
prompts = [
|
||||
f"<|user|>\n<|begin_of_image|><|endoftext|><|end_of_image|>\
|
||||
@ -158,16 +205,21 @@ def run_glm4v(questions: list[str], modality: str):
|
||||
]
|
||||
|
||||
stop_token_ids = [151329, 151336, 151338]
|
||||
return llm, prompts, stop_token_ids
|
||||
|
||||
return ModelRequestData(
|
||||
engine_args=engine_args,
|
||||
prompts=prompts,
|
||||
stop_token_ids=stop_token_ids,
|
||||
)
|
||||
|
||||
|
||||
# H2OVL-Mississippi
|
||||
def run_h2ovl(questions: list[str], modality: str):
|
||||
def run_h2ovl(questions: list[str], modality: str) -> ModelRequestData:
|
||||
assert modality == "image"
|
||||
|
||||
model_name = "h2oai/h2ovl-mississippi-800m"
|
||||
|
||||
llm = LLM(
|
||||
engine_args = EngineArgs(
|
||||
model=model_name,
|
||||
trust_remote_code=True,
|
||||
max_model_len=8192,
|
||||
@ -187,15 +239,20 @@ def run_h2ovl(questions: list[str], modality: str):
|
||||
# Stop tokens for H2OVL-Mississippi
|
||||
# https://huggingface.co/h2oai/h2ovl-mississippi-800m
|
||||
stop_token_ids = [tokenizer.eos_token_id]
|
||||
return llm, prompts, stop_token_ids
|
||||
|
||||
return ModelRequestData(
|
||||
engine_args=engine_args,
|
||||
prompts=prompts,
|
||||
stop_token_ids=stop_token_ids,
|
||||
)
|
||||
|
||||
|
||||
# Idefics3-8B-Llama3
|
||||
def run_idefics3(questions: list[str], modality: str):
|
||||
def run_idefics3(questions: list[str], modality: str) -> ModelRequestData:
|
||||
assert modality == "image"
|
||||
model_name = "HuggingFaceM4/Idefics3-8B-Llama3"
|
||||
|
||||
llm = LLM(
|
||||
engine_args = EngineArgs(
|
||||
model=model_name,
|
||||
max_model_len=8192,
|
||||
max_num_seqs=2,
|
||||
@ -212,17 +269,20 @@ def run_idefics3(questions: list[str], modality: str):
|
||||
prompts = [(
|
||||
f"<|begin_of_text|>User:<image>{question}<end_of_utterance>\nAssistant:"
|
||||
) for question in questions]
|
||||
stop_token_ids = None
|
||||
return llm, prompts, stop_token_ids
|
||||
|
||||
return ModelRequestData(
|
||||
engine_args=engine_args,
|
||||
prompts=prompts,
|
||||
)
|
||||
|
||||
|
||||
# InternVL
|
||||
def run_internvl(questions: list[str], modality: str):
|
||||
def run_internvl(questions: list[str], modality: str) -> ModelRequestData:
|
||||
assert modality == "image"
|
||||
|
||||
model_name = "OpenGVLab/InternVL2-2B"
|
||||
|
||||
llm = LLM(
|
||||
engine_args = EngineArgs(
|
||||
model=model_name,
|
||||
trust_remote_code=True,
|
||||
max_model_len=4096,
|
||||
@ -245,53 +305,75 @@ def run_internvl(questions: list[str], modality: str):
|
||||
# https://huggingface.co/OpenGVLab/InternVL2-2B/blob/main/conversation.py
|
||||
stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|end|>"]
|
||||
stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]
|
||||
return llm, prompts, stop_token_ids
|
||||
|
||||
return ModelRequestData(
|
||||
engine_args=engine_args,
|
||||
prompts=prompts,
|
||||
stop_token_ids=stop_token_ids,
|
||||
)
|
||||
|
||||
|
||||
# LLaVA-1.5
|
||||
def run_llava(questions: list[str], modality: str):
|
||||
def run_llava(questions: list[str], modality: str) -> ModelRequestData:
|
||||
assert modality == "image"
|
||||
|
||||
prompts = [
|
||||
f"USER: <image>\n{question}\nASSISTANT:" for question in questions
|
||||
]
|
||||
|
||||
llm = LLM(model="llava-hf/llava-1.5-7b-hf",
|
||||
max_model_len=4096,
|
||||
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
|
||||
stop_token_ids = None
|
||||
return llm, prompts, stop_token_ids
|
||||
engine_args = EngineArgs(
|
||||
model="llava-hf/llava-1.5-7b-hf",
|
||||
max_model_len=4096,
|
||||
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
|
||||
)
|
||||
|
||||
return ModelRequestData(
|
||||
engine_args=engine_args,
|
||||
prompts=prompts,
|
||||
)
|
||||
|
||||
|
||||
# LLaVA-1.6/LLaVA-NeXT
|
||||
def run_llava_next(questions: list[str], modality: str):
|
||||
def run_llava_next(questions: list[str], modality: str) -> ModelRequestData:
|
||||
assert modality == "image"
|
||||
|
||||
prompts = [f"[INST] <image>\n{question} [/INST]" for question in questions]
|
||||
llm = LLM(model="llava-hf/llava-v1.6-mistral-7b-hf",
|
||||
max_model_len=8192,
|
||||
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
|
||||
stop_token_ids = None
|
||||
return llm, prompts, stop_token_ids
|
||||
engine_args = EngineArgs(
|
||||
model="llava-hf/llava-v1.6-mistral-7b-hf",
|
||||
max_model_len=8192,
|
||||
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
|
||||
)
|
||||
|
||||
return ModelRequestData(
|
||||
engine_args=engine_args,
|
||||
prompts=prompts,
|
||||
)
|
||||
|
||||
|
||||
# LlaVA-NeXT-Video
|
||||
# Currently only support for video input
|
||||
def run_llava_next_video(questions: list[str], modality: str):
|
||||
def run_llava_next_video(questions: list[str],
|
||||
modality: str) -> ModelRequestData:
|
||||
assert modality == "video"
|
||||
|
||||
prompts = [
|
||||
f"USER: <video>\n{question} ASSISTANT:" for question in questions
|
||||
]
|
||||
llm = LLM(model="llava-hf/LLaVA-NeXT-Video-7B-hf",
|
||||
max_model_len=8192,
|
||||
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
|
||||
stop_token_ids = None
|
||||
return llm, prompts, stop_token_ids
|
||||
engine_args = EngineArgs(
|
||||
model="llava-hf/LLaVA-NeXT-Video-7B-hf",
|
||||
max_model_len=8192,
|
||||
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
|
||||
)
|
||||
|
||||
return ModelRequestData(
|
||||
engine_args=engine_args,
|
||||
prompts=prompts,
|
||||
)
|
||||
|
||||
|
||||
# LLaVA-OneVision
|
||||
def run_llava_onevision(questions: list[str], modality: str):
|
||||
def run_llava_onevision(questions: list[str],
|
||||
modality: str) -> ModelRequestData:
|
||||
|
||||
if modality == "video":
|
||||
prompts = [
|
||||
@ -305,15 +387,20 @@ def run_llava_onevision(questions: list[str], modality: str):
|
||||
<|im_start|>assistant\n" for question in questions
|
||||
]
|
||||
|
||||
llm = LLM(model="llava-hf/llava-onevision-qwen2-7b-ov-hf",
|
||||
max_model_len=16384,
|
||||
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
|
||||
stop_token_ids = None
|
||||
return llm, prompts, stop_token_ids
|
||||
engine_args = EngineArgs(
|
||||
model="llava-hf/llava-onevision-qwen2-7b-ov-hf",
|
||||
max_model_len=16384,
|
||||
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
|
||||
)
|
||||
|
||||
return ModelRequestData(
|
||||
engine_args=engine_args,
|
||||
prompts=prompts,
|
||||
)
|
||||
|
||||
|
||||
# Mantis
|
||||
def run_mantis(questions: list[str], modality: str):
|
||||
def run_mantis(questions: list[str], modality: str) -> ModelRequestData:
|
||||
assert modality == "image"
|
||||
|
||||
llama3_template = '<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n' # noqa: E501
|
||||
@ -322,14 +409,19 @@ def run_mantis(questions: list[str], modality: str):
|
||||
for question in questions
|
||||
]
|
||||
|
||||
llm = LLM(
|
||||
engine_args = EngineArgs(
|
||||
model="TIGER-Lab/Mantis-8B-siglip-llama3",
|
||||
max_model_len=4096,
|
||||
hf_overrides={"architectures": ["MantisForConditionalGeneration"]},
|
||||
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
|
||||
)
|
||||
stop_token_ids = [128009]
|
||||
return llm, prompts, stop_token_ids
|
||||
|
||||
return ModelRequestData(
|
||||
engine_args=engine_args,
|
||||
prompts=prompts,
|
||||
stop_token_ids=stop_token_ids,
|
||||
)
|
||||
|
||||
|
||||
# MiniCPM-V
|
||||
@ -357,7 +449,7 @@ def run_minicpmv_base(questions: list[str], modality: str, model_name):
|
||||
# model_name = "openbmb/MiniCPM-o-2_6"
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name,
|
||||
trust_remote_code=True)
|
||||
llm = LLM(
|
||||
engine_args = EngineArgs(
|
||||
model=model_name,
|
||||
max_model_len=4096,
|
||||
max_num_seqs=2,
|
||||
@ -389,19 +481,24 @@ def run_minicpmv_base(questions: list[str], modality: str, model_name):
|
||||
tokenize=False,
|
||||
add_generation_prompt=True) for question in questions
|
||||
]
|
||||
return llm, prompts, stop_token_ids
|
||||
|
||||
return ModelRequestData(
|
||||
engine_args=engine_args,
|
||||
prompts=prompts,
|
||||
stop_token_ids=stop_token_ids,
|
||||
)
|
||||
|
||||
|
||||
def run_minicpmo(questions: list[str], modality: str):
|
||||
def run_minicpmo(questions: list[str], modality: str) -> ModelRequestData:
|
||||
return run_minicpmv_base(questions, modality, "openbmb/MiniCPM-o-2_6")
|
||||
|
||||
|
||||
def run_minicpmv(questions: list[str], modality: str):
|
||||
def run_minicpmv(questions: list[str], modality: str) -> ModelRequestData:
|
||||
return run_minicpmv_base(questions, modality, "openbmb/MiniCPM-V-2_6")
|
||||
|
||||
|
||||
# LLama 3.2
|
||||
def run_mllama(questions: list[str], modality: str):
|
||||
def run_mllama(questions: list[str], modality: str) -> ModelRequestData:
|
||||
assert modality == "image"
|
||||
|
||||
model_name = "meta-llama/Llama-3.2-11B-Vision-Instruct"
|
||||
@ -411,7 +508,7 @@ def run_mllama(questions: list[str], modality: str):
|
||||
# You may lower either to run this example on lower-end GPUs.
|
||||
|
||||
# The configuration below has been confirmed to launch on a single L40 GPU.
|
||||
llm = LLM(
|
||||
engine_args = EngineArgs(
|
||||
model=model_name,
|
||||
max_model_len=4096,
|
||||
max_num_seqs=16,
|
||||
@ -432,17 +529,20 @@ def run_mllama(questions: list[str], modality: str):
|
||||
prompts = tokenizer.apply_chat_template(messages,
|
||||
add_generation_prompt=True,
|
||||
tokenize=False)
|
||||
stop_token_ids = None
|
||||
return llm, prompts, stop_token_ids
|
||||
|
||||
return ModelRequestData(
|
||||
engine_args=engine_args,
|
||||
prompts=prompts,
|
||||
)
|
||||
|
||||
|
||||
# Molmo
|
||||
def run_molmo(questions: list[str], modality: str):
|
||||
def run_molmo(questions: list[str], modality: str) -> ModelRequestData:
|
||||
assert modality == "image"
|
||||
|
||||
model_name = "allenai/Molmo-7B-D-0924"
|
||||
|
||||
llm = LLM(
|
||||
engine_args = EngineArgs(
|
||||
model=model_name,
|
||||
trust_remote_code=True,
|
||||
dtype="bfloat16",
|
||||
@ -453,18 +553,21 @@ def run_molmo(questions: list[str], modality: str):
|
||||
f"<|im_start|>user <image>\n{question}<|im_end|> \
|
||||
<|im_start|>assistant\n" for question in questions
|
||||
]
|
||||
stop_token_ids = None
|
||||
return llm, prompts, stop_token_ids
|
||||
|
||||
return ModelRequestData(
|
||||
engine_args=engine_args,
|
||||
prompts=prompts,
|
||||
)
|
||||
|
||||
|
||||
# NVLM-D
|
||||
def run_nvlm_d(questions: list[str], modality: str):
|
||||
def run_nvlm_d(questions: list[str], modality: str) -> ModelRequestData:
|
||||
assert modality == "image"
|
||||
|
||||
model_name = "nvidia/NVLM-D-72B"
|
||||
|
||||
# Adjust this as necessary to fit in GPU
|
||||
llm = LLM(
|
||||
engine_args = EngineArgs(
|
||||
model=model_name,
|
||||
trust_remote_code=True,
|
||||
max_model_len=4096,
|
||||
@ -481,36 +584,47 @@ def run_nvlm_d(questions: list[str], modality: str):
|
||||
prompts = tokenizer.apply_chat_template(messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True)
|
||||
stop_token_ids = None
|
||||
return llm, prompts, stop_token_ids
|
||||
|
||||
return ModelRequestData(
|
||||
engine_args=engine_args,
|
||||
prompts=prompts,
|
||||
)
|
||||
|
||||
|
||||
# PaliGemma
|
||||
def run_paligemma(question: str, modality: str):
|
||||
def run_paligemma(questions: list[str], modality: str) -> ModelRequestData:
|
||||
assert modality == "image"
|
||||
|
||||
# PaliGemma has special prompt format for VQA
|
||||
prompt = ["caption en"]
|
||||
llm = LLM(model="google/paligemma-3b-mix-224",
|
||||
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
|
||||
stop_token_ids = None
|
||||
return llm, prompt, stop_token_ids
|
||||
prompts = ["caption en" for _ in questions]
|
||||
engine_args = EngineArgs(
|
||||
model="google/paligemma-3b-mix-224",
|
||||
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
|
||||
|
||||
return ModelRequestData(
|
||||
engine_args=engine_args,
|
||||
prompts=prompts,
|
||||
)
|
||||
|
||||
|
||||
# PaliGemma 2
|
||||
def run_paligemma2(question: str, modality: str):
|
||||
def run_paligemma2(questions: list[str], modality: str) -> ModelRequestData:
|
||||
assert modality == "image"
|
||||
|
||||
# PaliGemma 2 has special prompt format for VQA
|
||||
prompt = ["caption en"]
|
||||
llm = LLM(model="google/paligemma2-3b-ft-docci-448",
|
||||
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
|
||||
stop_token_ids = None
|
||||
return llm, prompt, stop_token_ids
|
||||
prompts = ["caption en" for _ in questions]
|
||||
engine_args = EngineArgs(
|
||||
model="google/paligemma2-3b-ft-docci-448",
|
||||
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
|
||||
|
||||
return ModelRequestData(
|
||||
engine_args=engine_args,
|
||||
prompts=prompts,
|
||||
)
|
||||
|
||||
|
||||
# Phi-3-Vision
|
||||
def run_phi3v(questions: list[str], modality: str):
|
||||
def run_phi3v(questions: list[str], modality: str) -> ModelRequestData:
|
||||
assert modality == "image"
|
||||
|
||||
prompts = [
|
||||
@ -530,7 +644,7 @@ def run_phi3v(questions: list[str], modality: str):
|
||||
#
|
||||
# https://huggingface.co/microsoft/Phi-3.5-vision-instruct#loading-the-model-locally
|
||||
# https://huggingface.co/microsoft/Phi-3.5-vision-instruct/blob/main/processing_phi3_v.py#L194
|
||||
llm = LLM(
|
||||
engine_args = EngineArgs(
|
||||
model="microsoft/Phi-3.5-vision-instruct",
|
||||
trust_remote_code=True,
|
||||
max_model_len=4096,
|
||||
@ -539,12 +653,15 @@ def run_phi3v(questions: list[str], modality: str):
|
||||
mm_processor_kwargs={"num_crops": 16},
|
||||
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
|
||||
)
|
||||
stop_token_ids = None
|
||||
return llm, prompts, stop_token_ids
|
||||
|
||||
return ModelRequestData(
|
||||
engine_args=engine_args,
|
||||
prompts=prompts,
|
||||
)
|
||||
|
||||
|
||||
# Phi-4-multimodal-instruct
|
||||
def run_phi4mm(questions: list[str], modality: str):
|
||||
def run_phi4mm(questions: list[str], modality: str) -> ModelRequestData:
|
||||
"""
|
||||
Phi-4-multimodal-instruct supports both image and audio inputs. Here, we
|
||||
show how to process image inputs.
|
||||
@ -558,7 +675,7 @@ def run_phi4mm(questions: list[str], modality: str):
|
||||
f"<|user|><|image_1|>{question}<|end|><|assistant|>"
|
||||
for question in questions
|
||||
]
|
||||
llm = LLM(
|
||||
engine_args = EngineArgs(
|
||||
model=model_path,
|
||||
trust_remote_code=True,
|
||||
max_model_len=4096,
|
||||
@ -567,24 +684,22 @@ def run_phi4mm(questions: list[str], modality: str):
|
||||
max_lora_rank=320,
|
||||
lora_extra_vocab_size=0,
|
||||
)
|
||||
lora_request = LoRARequest("vision", 1, vision_lora_path)
|
||||
# To maintain code compatibility in this script, we add LoRA here.
|
||||
llm.llm_engine.add_lora(lora_request=lora_request)
|
||||
# You can also add LoRA using:
|
||||
# llm.generate(prompts, lora_request=lora_request,...)
|
||||
|
||||
stop_token_ids = None
|
||||
return llm, prompts, stop_token_ids
|
||||
return ModelRequestData(
|
||||
engine_args=engine_args,
|
||||
prompts=prompts,
|
||||
lora_requests=[LoRARequest("vision", 1, vision_lora_path)],
|
||||
)
|
||||
|
||||
|
||||
# Pixtral HF-format
|
||||
def run_pixtral_hf(questions: list[str], modality: str):
|
||||
def run_pixtral_hf(questions: list[str], modality: str) -> ModelRequestData:
|
||||
assert modality == "image"
|
||||
|
||||
model_name = "mistral-community/pixtral-12b"
|
||||
|
||||
# NOTE: Need L40 (or equivalent) to avoid OOM
|
||||
llm = LLM(
|
||||
engine_args = EngineArgs(
|
||||
model=model_name,
|
||||
max_model_len=8192,
|
||||
max_num_seqs=2,
|
||||
@ -592,15 +707,18 @@ def run_pixtral_hf(questions: list[str], modality: str):
|
||||
)
|
||||
|
||||
prompts = [f"<s>[INST]{question}\n[IMG][/INST]" for question in questions]
|
||||
stop_token_ids = None
|
||||
return llm, prompts, stop_token_ids
|
||||
|
||||
return ModelRequestData(
|
||||
engine_args=engine_args,
|
||||
prompts=prompts,
|
||||
)
|
||||
|
||||
|
||||
# Qwen
|
||||
def run_qwen_vl(questions: list[str], modality: str):
|
||||
def run_qwen_vl(questions: list[str], modality: str) -> ModelRequestData:
|
||||
assert modality == "image"
|
||||
|
||||
llm = LLM(
|
||||
engine_args = EngineArgs(
|
||||
model="Qwen/Qwen-VL",
|
||||
trust_remote_code=True,
|
||||
max_model_len=1024,
|
||||
@ -610,16 +728,19 @@ def run_qwen_vl(questions: list[str], modality: str):
|
||||
)
|
||||
|
||||
prompts = [f"{question}Picture 1: <img></img>\n" for question in questions]
|
||||
stop_token_ids = None
|
||||
return llm, prompts, stop_token_ids
|
||||
|
||||
return ModelRequestData(
|
||||
engine_args=engine_args,
|
||||
prompts=prompts,
|
||||
)
|
||||
|
||||
|
||||
# Qwen2-VL
|
||||
def run_qwen2_vl(questions: list[str], modality: str):
|
||||
def run_qwen2_vl(questions: list[str], modality: str) -> ModelRequestData:
|
||||
|
||||
model_name = "Qwen/Qwen2-VL-7B-Instruct"
|
||||
|
||||
llm = LLM(
|
||||
engine_args = EngineArgs(
|
||||
model=model_name,
|
||||
max_model_len=4096,
|
||||
max_num_seqs=5,
|
||||
@ -642,16 +763,19 @@ def run_qwen2_vl(questions: list[str], modality: str):
|
||||
f"{question}<|im_end|>\n"
|
||||
"<|im_start|>assistant\n") for question in questions
|
||||
]
|
||||
stop_token_ids = None
|
||||
return llm, prompts, stop_token_ids
|
||||
|
||||
return ModelRequestData(
|
||||
engine_args=engine_args,
|
||||
prompts=prompts,
|
||||
)
|
||||
|
||||
|
||||
# Qwen2.5-VL
|
||||
def run_qwen2_5_vl(questions: list[str], modality: str):
|
||||
def run_qwen2_5_vl(questions: list[str], modality: str) -> ModelRequestData:
|
||||
|
||||
model_name = "Qwen/Qwen2.5-VL-3B-Instruct"
|
||||
|
||||
llm = LLM(
|
||||
engine_args = EngineArgs(
|
||||
model=model_name,
|
||||
max_model_len=4096,
|
||||
max_num_seqs=5,
|
||||
@ -674,8 +798,11 @@ def run_qwen2_5_vl(questions: list[str], modality: str):
|
||||
f"{question}<|im_end|>\n"
|
||||
"<|im_start|>assistant\n") for question in questions
|
||||
]
|
||||
stop_token_ids = None
|
||||
return llm, prompts, stop_token_ids
|
||||
|
||||
return ModelRequestData(
|
||||
engine_args=engine_args,
|
||||
prompts=prompts,
|
||||
)
|
||||
|
||||
|
||||
model_example_map = {
|
||||
@ -789,18 +916,28 @@ def main(args):
|
||||
data = mm_input["data"]
|
||||
questions = mm_input["questions"]
|
||||
|
||||
llm, prompts, stop_token_ids = model_example_map[model](questions,
|
||||
modality)
|
||||
req_data = model_example_map[model](questions, modality)
|
||||
|
||||
engine_args = asdict(req_data.engine_args) | {"seed": args.seed}
|
||||
llm = LLM(**engine_args)
|
||||
|
||||
# To maintain code compatibility in this script, we add LoRA here.
|
||||
# You can also add LoRA using:
|
||||
# llm.generate(prompts, lora_request=lora_request,...)
|
||||
if req_data.lora_requests:
|
||||
for lora_request in req_data.lora_requests:
|
||||
llm.llm_engine.add_lora(lora_request=lora_request)
|
||||
|
||||
# Don't want to check the flag multiple times, so just hijack `prompts`.
|
||||
prompts = prompts if args.use_different_prompt_per_request else [
|
||||
prompts[0]
|
||||
prompts = req_data.prompts if args.use_different_prompt_per_request else [
|
||||
req_data.prompts[0]
|
||||
]
|
||||
|
||||
# We set temperature to 0.2 so that outputs can be different
|
||||
# even when all prompts are identical when running batch inference.
|
||||
sampling_params = SamplingParams(temperature=0.2,
|
||||
max_tokens=64,
|
||||
stop_token_ids=stop_token_ids)
|
||||
stop_token_ids=req_data.stop_token_ids)
|
||||
|
||||
assert args.num_prompts > 0
|
||||
if args.num_prompts == 1:
|
||||
@ -865,6 +1002,10 @@ if __name__ == "__main__":
|
||||
type=int,
|
||||
default=16,
|
||||
help='Number of frames to extract from the video.')
|
||||
parser.add_argument("--seed",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Set the seed when initializing `vllm.LLM`.")
|
||||
|
||||
parser.add_argument(
|
||||
'--image-repeat-prob',
|
||||
|
@ -7,11 +7,12 @@ For most models, the prompt format should follow corresponding examples
|
||||
on HuggingFace model repository.
|
||||
"""
|
||||
from argparse import Namespace
|
||||
from dataclasses import asdict
|
||||
from typing import Literal, NamedTuple, Optional, TypedDict, Union, get_args
|
||||
|
||||
from PIL.Image import Image
|
||||
|
||||
from vllm import LLM
|
||||
from vllm import LLM, EngineArgs
|
||||
from vllm.multimodal.utils import fetch_image
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
@ -37,12 +38,12 @@ Query = Union[TextQuery, ImageQuery, TextImageQuery]
|
||||
|
||||
|
||||
class ModelRequestData(NamedTuple):
|
||||
llm: LLM
|
||||
engine_args: EngineArgs
|
||||
prompt: str
|
||||
image: Optional[Image]
|
||||
|
||||
|
||||
def run_e5_v(query: Query):
|
||||
def run_e5_v(query: Query) -> ModelRequestData:
|
||||
llama3_template = '<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n \n' # noqa: E501
|
||||
|
||||
if query["modality"] == "text":
|
||||
@ -58,20 +59,20 @@ def run_e5_v(query: Query):
|
||||
modality = query['modality']
|
||||
raise ValueError(f"Unsupported query modality: '{modality}'")
|
||||
|
||||
llm = LLM(
|
||||
engine_args = EngineArgs(
|
||||
model="royokong/e5-v",
|
||||
task="embed",
|
||||
max_model_len=4096,
|
||||
)
|
||||
|
||||
return ModelRequestData(
|
||||
llm=llm,
|
||||
engine_args=engine_args,
|
||||
prompt=prompt,
|
||||
image=image,
|
||||
)
|
||||
|
||||
|
||||
def run_vlm2vec(query: Query):
|
||||
def run_vlm2vec(query: Query) -> ModelRequestData:
|
||||
if query["modality"] == "text":
|
||||
text = query["text"]
|
||||
prompt = f"Find me an everyday image that matches the given caption: {text}" # noqa: E501
|
||||
@ -87,7 +88,7 @@ def run_vlm2vec(query: Query):
|
||||
modality = query['modality']
|
||||
raise ValueError(f"Unsupported query modality: '{modality}'")
|
||||
|
||||
llm = LLM(
|
||||
engine_args = EngineArgs(
|
||||
model="TIGER-Lab/VLM2Vec-Full",
|
||||
task="embed",
|
||||
trust_remote_code=True,
|
||||
@ -95,7 +96,7 @@ def run_vlm2vec(query: Query):
|
||||
)
|
||||
|
||||
return ModelRequestData(
|
||||
llm=llm,
|
||||
engine_args=engine_args,
|
||||
prompt=prompt,
|
||||
image=image,
|
||||
)
|
||||
@ -126,15 +127,18 @@ def get_query(modality: QueryModality):
|
||||
raise ValueError(msg)
|
||||
|
||||
|
||||
def run_encode(model: str, modality: QueryModality):
|
||||
def run_encode(model: str, modality: QueryModality, seed: Optional[int]):
|
||||
query = get_query(modality)
|
||||
req_data = model_example_map[model](query)
|
||||
|
||||
engine_args = asdict(req_data.engine_args) | {"seed": seed}
|
||||
llm = LLM(**engine_args)
|
||||
|
||||
mm_data = {}
|
||||
if req_data.image is not None:
|
||||
mm_data["image"] = req_data.image
|
||||
|
||||
outputs = req_data.llm.embed({
|
||||
outputs = llm.embed({
|
||||
"prompt": req_data.prompt,
|
||||
"multi_modal_data": mm_data,
|
||||
})
|
||||
@ -144,7 +148,7 @@ def run_encode(model: str, modality: QueryModality):
|
||||
|
||||
|
||||
def main(args: Namespace):
|
||||
run_encode(args.model_name, args.modality)
|
||||
run_encode(args.model_name, args.modality, args.seed)
|
||||
|
||||
|
||||
model_example_map = {
|
||||
@ -167,5 +171,10 @@ if __name__ == "__main__":
|
||||
default="image",
|
||||
choices=get_args(QueryModality),
|
||||
help='Modality of the input.')
|
||||
parser.add_argument("--seed",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Set the seed when initializing `vllm.LLM`.")
|
||||
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
|
@ -6,13 +6,14 @@ using the chat template defined by the model.
|
||||
"""
|
||||
import os
|
||||
from argparse import Namespace
|
||||
from dataclasses import asdict
|
||||
from typing import NamedTuple, Optional
|
||||
|
||||
from huggingface_hub import snapshot_download
|
||||
from PIL.Image import Image
|
||||
from transformers import AutoProcessor, AutoTokenizer
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm import LLM, EngineArgs, SamplingParams
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.multimodal.utils import fetch_image
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
@ -25,11 +26,12 @@ IMAGE_URLS = [
|
||||
|
||||
|
||||
class ModelRequestData(NamedTuple):
|
||||
llm: LLM
|
||||
engine_args: EngineArgs
|
||||
prompt: str
|
||||
stop_token_ids: Optional[list[int]]
|
||||
image_data: list[Image]
|
||||
chat_template: Optional[str]
|
||||
stop_token_ids: Optional[list[int]] = None
|
||||
chat_template: Optional[str] = None
|
||||
lora_requests: Optional[list[LoRARequest]] = None
|
||||
|
||||
|
||||
# NOTE: The default `max_num_seqs` and `max_model_len` may result in OOM on
|
||||
@ -37,53 +39,55 @@ class ModelRequestData(NamedTuple):
|
||||
# Unless specified, these settings have been tested to work on a single L4.
|
||||
|
||||
|
||||
def load_aria(question, image_urls: list[str]) -> ModelRequestData:
|
||||
def load_aria(question: str, image_urls: list[str]) -> ModelRequestData:
|
||||
model_name = "rhymes-ai/Aria"
|
||||
llm = LLM(model=model_name,
|
||||
tokenizer_mode="slow",
|
||||
trust_remote_code=True,
|
||||
dtype="bfloat16",
|
||||
limit_mm_per_prompt={"image": len(image_urls)})
|
||||
engine_args = EngineArgs(
|
||||
model=model_name,
|
||||
tokenizer_mode="slow",
|
||||
trust_remote_code=True,
|
||||
dtype="bfloat16",
|
||||
limit_mm_per_prompt={"image": len(image_urls)},
|
||||
)
|
||||
placeholders = "<fim_prefix><|img|><fim_suffix>\n" * len(image_urls)
|
||||
prompt = (f"<|im_start|>user\n{placeholders}{question}<|im_end|>\n"
|
||||
"<|im_start|>assistant\n")
|
||||
stop_token_ids = [93532, 93653, 944, 93421, 1019, 93653, 93519]
|
||||
|
||||
return ModelRequestData(
|
||||
llm=llm,
|
||||
engine_args=engine_args,
|
||||
prompt=prompt,
|
||||
stop_token_ids=stop_token_ids,
|
||||
image_data=[fetch_image(url) for url in image_urls],
|
||||
chat_template=None,
|
||||
)
|
||||
|
||||
|
||||
def load_deepseek_vl2(question: str, image_urls: list[str]):
|
||||
def load_deepseek_vl2(question: str,
|
||||
image_urls: list[str]) -> ModelRequestData:
|
||||
model_name = "deepseek-ai/deepseek-vl2-tiny"
|
||||
|
||||
llm = LLM(model=model_name,
|
||||
max_model_len=4096,
|
||||
max_num_seqs=2,
|
||||
hf_overrides={"architectures": ["DeepseekVLV2ForCausalLM"]},
|
||||
limit_mm_per_prompt={"image": len(image_urls)})
|
||||
engine_args = EngineArgs(
|
||||
model=model_name,
|
||||
max_model_len=4096,
|
||||
max_num_seqs=2,
|
||||
hf_overrides={"architectures": ["DeepseekVLV2ForCausalLM"]},
|
||||
limit_mm_per_prompt={"image": len(image_urls)},
|
||||
)
|
||||
|
||||
placeholder = "".join(f"image_{i}:<image>\n"
|
||||
for i, _ in enumerate(image_urls, start=1))
|
||||
prompt = f"<|User|>: {placeholder}{question}\n\n<|Assistant|>:"
|
||||
|
||||
return ModelRequestData(
|
||||
llm=llm,
|
||||
engine_args=engine_args,
|
||||
prompt=prompt,
|
||||
stop_token_ids=None,
|
||||
image_data=[fetch_image(url) for url in image_urls],
|
||||
chat_template=None,
|
||||
)
|
||||
|
||||
|
||||
def load_gemma3(question, image_urls: list[str]) -> ModelRequestData:
|
||||
def load_gemma3(question: str, image_urls: list[str]) -> ModelRequestData:
|
||||
model_name = "google/gemma-3-4b-it"
|
||||
|
||||
llm = LLM(
|
||||
engine_args = EngineArgs(
|
||||
model=model_name,
|
||||
max_model_len=8192,
|
||||
max_num_seqs=2,
|
||||
@ -112,18 +116,16 @@ def load_gemma3(question, image_urls: list[str]) -> ModelRequestData:
|
||||
add_generation_prompt=True)
|
||||
|
||||
return ModelRequestData(
|
||||
llm=llm,
|
||||
engine_args=engine_args,
|
||||
prompt=prompt,
|
||||
stop_token_ids=None,
|
||||
image_data=[fetch_image(url) for url in image_urls],
|
||||
chat_template=None,
|
||||
)
|
||||
|
||||
|
||||
def load_h2ovl(question: str, image_urls: list[str]) -> ModelRequestData:
|
||||
model_name = "h2oai/h2ovl-mississippi-800m"
|
||||
|
||||
llm = LLM(
|
||||
engine_args = EngineArgs(
|
||||
model=model_name,
|
||||
trust_remote_code=True,
|
||||
max_model_len=8192,
|
||||
@ -146,19 +148,18 @@ def load_h2ovl(question: str, image_urls: list[str]) -> ModelRequestData:
|
||||
stop_token_ids = [tokenizer.eos_token_id]
|
||||
|
||||
return ModelRequestData(
|
||||
llm=llm,
|
||||
engine_args=engine_args,
|
||||
prompt=prompt,
|
||||
stop_token_ids=stop_token_ids,
|
||||
image_data=[fetch_image(url) for url in image_urls],
|
||||
chat_template=None,
|
||||
)
|
||||
|
||||
|
||||
def load_idefics3(question, image_urls: list[str]) -> ModelRequestData:
|
||||
def load_idefics3(question: str, image_urls: list[str]) -> ModelRequestData:
|
||||
model_name = "HuggingFaceM4/Idefics3-8B-Llama3"
|
||||
|
||||
# The configuration below has been confirmed to launch on a single L40 GPU.
|
||||
llm = LLM(
|
||||
engine_args = EngineArgs(
|
||||
model=model_name,
|
||||
max_model_len=8192,
|
||||
max_num_seqs=16,
|
||||
@ -177,18 +178,16 @@ def load_idefics3(question, image_urls: list[str]) -> ModelRequestData:
|
||||
for i, _ in enumerate(image_urls, start=1))
|
||||
prompt = f"<|begin_of_text|>User:{placeholders}\n{question}<end_of_utterance>\nAssistant:" # noqa: E501
|
||||
return ModelRequestData(
|
||||
llm=llm,
|
||||
engine_args=engine_args,
|
||||
prompt=prompt,
|
||||
stop_token_ids=None,
|
||||
image_data=[fetch_image(url) for url in image_urls],
|
||||
chat_template=None,
|
||||
)
|
||||
|
||||
|
||||
def load_internvl(question: str, image_urls: list[str]) -> ModelRequestData:
|
||||
model_name = "OpenGVLab/InternVL2-2B"
|
||||
|
||||
llm = LLM(
|
||||
engine_args = EngineArgs(
|
||||
model=model_name,
|
||||
trust_remote_code=True,
|
||||
max_model_len=4096,
|
||||
@ -214,19 +213,18 @@ def load_internvl(question: str, image_urls: list[str]) -> ModelRequestData:
|
||||
stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]
|
||||
|
||||
return ModelRequestData(
|
||||
llm=llm,
|
||||
engine_args=engine_args,
|
||||
prompt=prompt,
|
||||
stop_token_ids=stop_token_ids,
|
||||
image_data=[fetch_image(url) for url in image_urls],
|
||||
chat_template=None,
|
||||
)
|
||||
|
||||
|
||||
def load_mllama(question, image_urls: list[str]) -> ModelRequestData:
|
||||
def load_mllama(question: str, image_urls: list[str]) -> ModelRequestData:
|
||||
model_name = "meta-llama/Llama-3.2-11B-Vision-Instruct"
|
||||
|
||||
# The configuration below has been confirmed to launch on a single L40 GPU.
|
||||
llm = LLM(
|
||||
engine_args = EngineArgs(
|
||||
model=model_name,
|
||||
max_model_len=4096,
|
||||
max_num_seqs=16,
|
||||
@ -236,19 +234,17 @@ def load_mllama(question, image_urls: list[str]) -> ModelRequestData:
|
||||
placeholders = "<|image|>" * len(image_urls)
|
||||
prompt = f"{placeholders}<|begin_of_text|>{question}"
|
||||
return ModelRequestData(
|
||||
llm=llm,
|
||||
engine_args=engine_args,
|
||||
prompt=prompt,
|
||||
stop_token_ids=None,
|
||||
image_data=[fetch_image(url) for url in image_urls],
|
||||
chat_template=None,
|
||||
)
|
||||
|
||||
|
||||
def load_nvlm_d(question: str, image_urls: list[str]):
|
||||
def load_nvlm_d(question: str, image_urls: list[str]) -> ModelRequestData:
|
||||
model_name = "nvidia/NVLM-D-72B"
|
||||
|
||||
# Adjust this as necessary to fit in GPU
|
||||
llm = LLM(
|
||||
engine_args = EngineArgs(
|
||||
model=model_name,
|
||||
trust_remote_code=True,
|
||||
max_model_len=8192,
|
||||
@ -266,14 +262,11 @@ def load_nvlm_d(question: str, image_urls: list[str]):
|
||||
prompt = tokenizer.apply_chat_template(messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True)
|
||||
stop_token_ids = None
|
||||
|
||||
return ModelRequestData(
|
||||
llm=llm,
|
||||
engine_args=engine_args,
|
||||
prompt=prompt,
|
||||
stop_token_ids=stop_token_ids,
|
||||
image_data=[fetch_image(url) for url in image_urls],
|
||||
chat_template=None,
|
||||
)
|
||||
|
||||
|
||||
@ -281,7 +274,7 @@ def load_pixtral_hf(question: str, image_urls: list[str]) -> ModelRequestData:
|
||||
model_name = "mistral-community/pixtral-12b"
|
||||
|
||||
# Adjust this as necessary to fit in GPU
|
||||
llm = LLM(
|
||||
engine_args = EngineArgs(
|
||||
model=model_name,
|
||||
max_model_len=8192,
|
||||
max_num_seqs=2,
|
||||
@ -291,14 +284,11 @@ def load_pixtral_hf(question: str, image_urls: list[str]) -> ModelRequestData:
|
||||
|
||||
placeholders = "[IMG]" * len(image_urls)
|
||||
prompt = f"<s>[INST]{question}\n{placeholders}[/INST]"
|
||||
stop_token_ids = None
|
||||
|
||||
return ModelRequestData(
|
||||
llm=llm,
|
||||
engine_args=engine_args,
|
||||
prompt=prompt,
|
||||
stop_token_ids=stop_token_ids,
|
||||
image_data=[fetch_image(url) for url in image_urls],
|
||||
chat_template=None,
|
||||
)
|
||||
|
||||
|
||||
@ -315,7 +305,7 @@ def load_phi3v(question: str, image_urls: list[str]) -> ModelRequestData:
|
||||
#
|
||||
# https://huggingface.co/microsoft/Phi-3.5-vision-instruct#loading-the-model-locally
|
||||
# https://huggingface.co/microsoft/Phi-3.5-vision-instruct/blob/main/processing_phi3_v.py#L194
|
||||
llm = LLM(
|
||||
engine_args = EngineArgs(
|
||||
model="microsoft/Phi-3.5-vision-instruct",
|
||||
trust_remote_code=True,
|
||||
max_model_len=4096,
|
||||
@ -326,14 +316,11 @@ def load_phi3v(question: str, image_urls: list[str]) -> ModelRequestData:
|
||||
placeholders = "\n".join(f"<|image_{i}|>"
|
||||
for i, _ in enumerate(image_urls, start=1))
|
||||
prompt = f"<|user|>\n{placeholders}\n{question}<|end|>\n<|assistant|>\n"
|
||||
stop_token_ids = None
|
||||
|
||||
return ModelRequestData(
|
||||
llm=llm,
|
||||
engine_args=engine_args,
|
||||
prompt=prompt,
|
||||
stop_token_ids=stop_token_ids,
|
||||
image_data=[fetch_image(url) for url in image_urls],
|
||||
chat_template=None,
|
||||
)
|
||||
|
||||
|
||||
@ -347,7 +334,7 @@ def load_phi4mm(question: str, image_urls: list[str]) -> ModelRequestData:
|
||||
# Since the vision-lora and speech-lora co-exist with the base model,
|
||||
# we have to manually specify the path of the lora weights.
|
||||
vision_lora_path = os.path.join(model_path, "vision-lora")
|
||||
llm = LLM(
|
||||
engine_args = EngineArgs(
|
||||
model=model_path,
|
||||
trust_remote_code=True,
|
||||
max_model_len=10000,
|
||||
@ -357,30 +344,23 @@ def load_phi4mm(question: str, image_urls: list[str]) -> ModelRequestData:
|
||||
max_lora_rank=320,
|
||||
lora_extra_vocab_size=0,
|
||||
)
|
||||
lora_request = LoRARequest("vision", 1, vision_lora_path)
|
||||
# To maintain code compatibility in this script, we add LoRA here.
|
||||
llm.llm_engine.add_lora(lora_request=lora_request)
|
||||
# You can also add LoRA using:
|
||||
# llm.generate(prompts, lora_request=lora_request,...)
|
||||
|
||||
placeholders = "".join(f"<|image_{i}|>"
|
||||
for i, _ in enumerate(image_urls, start=1))
|
||||
prompt = f"<|user|>{placeholders}{question}<|end|><|assistant|>"
|
||||
stop_token_ids = None
|
||||
|
||||
return ModelRequestData(
|
||||
llm=llm,
|
||||
engine_args=engine_args,
|
||||
prompt=prompt,
|
||||
stop_token_ids=stop_token_ids,
|
||||
image_data=[fetch_image(url) for url in image_urls],
|
||||
chat_template=None,
|
||||
lora_requests=[LoRARequest("vision", 1, vision_lora_path)],
|
||||
)
|
||||
|
||||
|
||||
def load_qwen_vl_chat(question: str,
|
||||
image_urls: list[str]) -> ModelRequestData:
|
||||
model_name = "Qwen/Qwen-VL-Chat"
|
||||
llm = LLM(
|
||||
engine_args = EngineArgs(
|
||||
model=model_name,
|
||||
trust_remote_code=True,
|
||||
max_model_len=1024,
|
||||
@ -411,7 +391,7 @@ def load_qwen_vl_chat(question: str,
|
||||
stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]
|
||||
|
||||
return ModelRequestData(
|
||||
llm=llm,
|
||||
engine_args=engine_args,
|
||||
prompt=prompt,
|
||||
stop_token_ids=stop_token_ids,
|
||||
image_data=[fetch_image(url) for url in image_urls],
|
||||
@ -419,7 +399,7 @@ def load_qwen_vl_chat(question: str,
|
||||
)
|
||||
|
||||
|
||||
def load_qwen2_vl(question, image_urls: list[str]) -> ModelRequestData:
|
||||
def load_qwen2_vl(question: str, image_urls: list[str]) -> ModelRequestData:
|
||||
try:
|
||||
from qwen_vl_utils import process_vision_info
|
||||
except ModuleNotFoundError:
|
||||
@ -431,7 +411,7 @@ def load_qwen2_vl(question, image_urls: list[str]) -> ModelRequestData:
|
||||
model_name = "Qwen/Qwen2-VL-7B-Instruct"
|
||||
|
||||
# Tested on L40
|
||||
llm = LLM(
|
||||
engine_args = EngineArgs(
|
||||
model=model_name,
|
||||
max_model_len=32768 if process_vision_info is None else 4096,
|
||||
max_num_seqs=5,
|
||||
@ -460,23 +440,19 @@ def load_qwen2_vl(question, image_urls: list[str]) -> ModelRequestData:
|
||||
tokenize=False,
|
||||
add_generation_prompt=True)
|
||||
|
||||
stop_token_ids = None
|
||||
|
||||
if process_vision_info is None:
|
||||
image_data = [fetch_image(url) for url in image_urls]
|
||||
else:
|
||||
image_data, _ = process_vision_info(messages)
|
||||
|
||||
return ModelRequestData(
|
||||
llm=llm,
|
||||
engine_args=engine_args,
|
||||
prompt=prompt,
|
||||
stop_token_ids=stop_token_ids,
|
||||
image_data=image_data,
|
||||
chat_template=None,
|
||||
)
|
||||
|
||||
|
||||
def load_qwen2_5_vl(question, image_urls: list[str]) -> ModelRequestData:
|
||||
def load_qwen2_5_vl(question: str, image_urls: list[str]) -> ModelRequestData:
|
||||
try:
|
||||
from qwen_vl_utils import process_vision_info
|
||||
except ModuleNotFoundError:
|
||||
@ -487,7 +463,7 @@ def load_qwen2_5_vl(question, image_urls: list[str]) -> ModelRequestData:
|
||||
|
||||
model_name = "Qwen/Qwen2.5-VL-3B-Instruct"
|
||||
|
||||
llm = LLM(
|
||||
engine_args = EngineArgs(
|
||||
model=model_name,
|
||||
max_model_len=32768 if process_vision_info is None else 4096,
|
||||
max_num_seqs=5,
|
||||
@ -516,8 +492,6 @@ def load_qwen2_5_vl(question, image_urls: list[str]) -> ModelRequestData:
|
||||
tokenize=False,
|
||||
add_generation_prompt=True)
|
||||
|
||||
stop_token_ids = None
|
||||
|
||||
if process_vision_info is None:
|
||||
image_data = [fetch_image(url) for url in image_urls]
|
||||
else:
|
||||
@ -525,11 +499,9 @@ def load_qwen2_5_vl(question, image_urls: list[str]) -> ModelRequestData:
|
||||
return_video_kwargs=False)
|
||||
|
||||
return ModelRequestData(
|
||||
llm=llm,
|
||||
engine_args=engine_args,
|
||||
prompt=prompt,
|
||||
stop_token_ids=stop_token_ids,
|
||||
image_data=image_data,
|
||||
chat_template=None,
|
||||
)
|
||||
|
||||
|
||||
@ -551,14 +523,25 @@ model_example_map = {
|
||||
}
|
||||
|
||||
|
||||
def run_generate(model, question: str, image_urls: list[str]):
|
||||
def run_generate(model, question: str, image_urls: list[str],
|
||||
seed: Optional[int]):
|
||||
req_data = model_example_map[model](question, image_urls)
|
||||
|
||||
engine_args = asdict(req_data.engine_args) | {"seed": args.seed}
|
||||
llm = LLM(**engine_args)
|
||||
|
||||
# To maintain code compatibility in this script, we add LoRA here.
|
||||
# You can also add LoRA using:
|
||||
# llm.generate(prompts, lora_request=lora_request,...)
|
||||
if req_data.lora_requests:
|
||||
for lora_request in req_data.lora_requests:
|
||||
llm.llm_engine.add_lora(lora_request=lora_request)
|
||||
|
||||
sampling_params = SamplingParams(temperature=0.0,
|
||||
max_tokens=128,
|
||||
stop_token_ids=req_data.stop_token_ids)
|
||||
|
||||
outputs = req_data.llm.generate(
|
||||
outputs = llm.generate(
|
||||
{
|
||||
"prompt": req_data.prompt,
|
||||
"multi_modal_data": {
|
||||
@ -572,13 +555,24 @@ def run_generate(model, question: str, image_urls: list[str]):
|
||||
print(generated_text)
|
||||
|
||||
|
||||
def run_chat(model: str, question: str, image_urls: list[str]):
|
||||
def run_chat(model: str, question: str, image_urls: list[str],
|
||||
seed: Optional[int]):
|
||||
req_data = model_example_map[model](question, image_urls)
|
||||
|
||||
engine_args = asdict(req_data.engine_args) | {"seed": seed}
|
||||
llm = LLM(**engine_args)
|
||||
|
||||
# To maintain code compatibility in this script, we add LoRA here.
|
||||
# You can also add LoRA using:
|
||||
# llm.generate(prompts, lora_request=lora_request,...)
|
||||
if req_data.lora_requests:
|
||||
for lora_request in req_data.lora_requests:
|
||||
llm.llm_engine.add_lora(lora_request=lora_request)
|
||||
|
||||
sampling_params = SamplingParams(temperature=0.0,
|
||||
max_tokens=128,
|
||||
stop_token_ids=req_data.stop_token_ids)
|
||||
outputs = req_data.llm.chat(
|
||||
outputs = llm.chat(
|
||||
[{
|
||||
"role":
|
||||
"user",
|
||||
@ -607,11 +601,12 @@ def run_chat(model: str, question: str, image_urls: list[str]):
|
||||
def main(args: Namespace):
|
||||
model = args.model_type
|
||||
method = args.method
|
||||
seed = args.seed
|
||||
|
||||
if method == "generate":
|
||||
run_generate(model, QUESTION, IMAGE_URLS)
|
||||
run_generate(model, QUESTION, IMAGE_URLS, seed)
|
||||
elif method == "chat":
|
||||
run_chat(model, QUESTION, IMAGE_URLS)
|
||||
run_chat(model, QUESTION, IMAGE_URLS, seed)
|
||||
else:
|
||||
raise ValueError(f"Invalid method: {method}")
|
||||
|
||||
@ -632,6 +627,10 @@ if __name__ == "__main__":
|
||||
default="generate",
|
||||
choices=["generate", "chat"],
|
||||
help="The method to run in `vllm.LLM`.")
|
||||
parser.add_argument("--seed",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Set the seed when initializing `vllm.LLM`.")
|
||||
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
|
Loading…
x
Reference in New Issue
Block a user