2024-07-26 22:44:13 -07:00
|
|
|
"""
|
2024-10-23 11:35:29 +08:00
|
|
|
This example shows how to use vLLM for running offline inference with
|
|
|
|
the correct prompt format on vision language models for text generation.
|
2024-07-26 22:44:13 -07:00
|
|
|
|
|
|
|
For most models, the prompt format should follow the corresponding examples
in the model's HuggingFace repository.
|
|
|
|
"""
|
2024-12-11 19:55:30 -05:00
|
|
|
import random
|
|
|
|
|
2024-07-26 22:44:13 -07:00
|
|
|
from transformers import AutoTokenizer
|
|
|
|
|
|
|
|
from vllm import LLM, SamplingParams
|
|
|
|
from vllm.assets.image import ImageAsset
|
2024-09-11 13:21:36 +08:00
|
|
|
from vllm.assets.video import VideoAsset
|
2024-07-26 22:44:13 -07:00
|
|
|
from vllm.utils import FlexibleArgumentParser
|
|
|
|
|
2024-09-29 00:54:35 +08:00
|
|
|
# NOTE: The default `max_num_seqs` and `max_model_len` may result in OOM on
|
|
|
|
# lower-end GPUs.
|
|
|
|
# Unless specified, these settings have been tested to work on a single L4.
|
|
|
|
|
2024-07-26 22:44:13 -07:00
|
|
|
|
2024-12-16 19:23:33 +08:00
|
|
|
# Aria
|
|
|
|
def run_aria(question: str, modality: str):
    """Return ``(llm, prompt, stop_token_ids)`` for Aria (image input only).

    The prompt wraps the image placeholder ``<|img|>`` in ``<fim_prefix>`` /
    ``<fim_suffix>`` and puts the question after it, inside an
    ``<|im_start|>``/``<|im_end|>`` user turn followed by an open assistant
    turn.
    """
    assert modality == "image"

    engine = LLM(
        model="rhymes-ai/Aria",
        tokenizer_mode="slow",
        trust_remote_code=True,
        dtype="bfloat16",
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )

    prompt = ("<|im_start|>user\n<fim_prefix><|img|><fim_suffix>\n"
              f"{question}<|im_end|>\n<|im_start|>assistant\n")

    # Generation stops on any of these token ids (93653 is listed twice;
    # the duplicate has no effect).
    stop_ids = [93532, 93653, 944, 93421, 1019, 93653, 93519]
    return engine, prompt, stop_ids
|
|
|
|
|
|
|
|
|
|
|
|
# BLIP-2
|
|
|
|
def run_blip2(question: str, modality: str):
    """Return ``(llm, prompt, stop_token_ids)`` for BLIP-2 OPT-2.7B (image only).

    Uses a plain ``Question: ... Answer:`` prompt and no extra stop tokens.
    """
    assert modality == "image"

    # The prompt format given on the HuggingFace model card is inaccurate.
    # See https://huggingface.co/Salesforce/blip2-opt-2.7b/discussions/15#64ff02f3f8cf9e4f5b038262 #noqa
    prompt = f"Question: {question} Answer:"
    engine = LLM(
        model="Salesforce/blip2-opt-2.7b",
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )
    return engine, prompt, None
|
|
|
|
|
|
|
|
|
|
|
|
# Chameleon
|
|
|
|
def run_chameleon(question: str, modality: str):
    """Return ``(llm, prompt, stop_token_ids)`` for Chameleon-7B (image only).

    The ``<image>`` placeholder is appended after the question text.
    """
    assert modality == "image"

    prompt = f"{question}<image>"
    engine = LLM(
        model="facebook/chameleon-7b",
        max_model_len=4096,
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )
    return engine, prompt, None
|
|
|
|
|
|
|
|
|
|
|
|
# Fuyu
|
|
|
|
def run_fuyu(question: str, modality: str):
    """Return ``(llm, prompt, stop_token_ids)`` for Fuyu-8B (image only).

    The prompt is just the question followed by a newline.
    """
    assert modality == "image"

    prompt = f"{question}\n"
    engine = LLM(
        model="adept/fuyu-8b",
        max_model_len=2048,
        max_num_seqs=2,
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )
    return engine, prompt, None
|
|
|
|
|
|
|
|
|
|
|
|
# GLM-4v
|
|
|
|
def run_glm4v(question: str, modality: str):
    """Return ``(llm, prompt, stop_token_ids)`` for GLM-4V-9B (image only).

    The raw question is used as the prompt unchanged.
    """
    assert modality == "image"

    engine = LLM(
        model="THUDM/glm-4v-9b",
        max_model_len=2048,
        max_num_seqs=2,
        trust_remote_code=True,
        enforce_eager=True,
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )

    # Token ids that terminate generation for this model.
    stop_ids = [151329, 151336, 151338]
    return engine, question, stop_ids
|
|
|
|
|
|
|
|
|
|
|
|
# H2OVL-Mississippi
|
|
|
|
def run_h2ovl(question: str, modality: str):
    """Return ``(llm, prompt, stop_token_ids)`` for H2OVL-Mississippi-2B.

    Image input only. The prompt is rendered with the model's own chat
    template, and generation stops on the tokenizer's EOS token.
    """
    assert modality == "image"

    checkpoint = "h2oai/h2ovl-mississippi-2b"
    engine = LLM(
        model=checkpoint,
        trust_remote_code=True,
        max_model_len=8192,
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )

    tok = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)
    chat = [{'role': 'user', 'content': f"<image>\n{question}"}]
    prompt = tok.apply_chat_template(chat,
                                     add_generation_prompt=True,
                                     tokenize=False)

    # Stop tokens for H2OVL-Mississippi, see
    # https://huggingface.co/h2oai/h2ovl-mississippi-2b
    return engine, prompt, [tok.eos_token_id]
|
|
|
|
|
|
|
|
|
|
|
|
# Idefics3-8B-Llama3
|
|
|
|
def run_idefics3(question: str, modality: str):
    """Return ``(llm, prompt, stop_token_ids)`` for Idefics3-8B-Llama3.

    Image input only. The image processor's ``longest_edge`` is capped at
    3 * 364 pixels to bound memory use.
    """
    assert modality == "image"

    # If you run out of memory, reduce "longest_edge"; see
    # https://huggingface.co/HuggingFaceM4/Idefics3-8B-Llama3#model-optimizations
    processor_overrides = {"size": {"longest_edge": 3 * 364}}

    engine = LLM(
        model="HuggingFaceM4/Idefics3-8B-Llama3",
        max_model_len=8192,
        max_num_seqs=2,
        enforce_eager=True,
        mm_processor_kwargs=processor_overrides,
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )

    prompt = (f"<|begin_of_text|>User:<image>{question}<end_of_utterance>\n"
              "Assistant:")
    return engine, prompt, None
|
|
|
|
|
|
|
|
|
|
|
|
# InternVL
|
|
|
|
def run_internvl(question: str, modality: str):
    """Return ``(llm, prompt, stop_token_ids)`` for InternVL2-2B (image only).

    The prompt is rendered with the model's chat template. Stop token ids
    are looked up from a list of stop-word strings.
    """
    assert modality == "image"

    checkpoint = "OpenGVLab/InternVL2-2B"
    engine = LLM(
        model=checkpoint,
        trust_remote_code=True,
        max_model_len=4096,
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )

    tok = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)
    chat = [{'role': 'user', 'content': f"<image>\n{question}"}]
    prompt = tok.apply_chat_template(chat,
                                     add_generation_prompt=True,
                                     tokenize=False)

    # InternVL variants may use different stop words; check the model card
    # for the correct ones:
    # https://huggingface.co/OpenGVLab/InternVL2-2B/blob/main/conversation.py
    stop_words = ["<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|end|>"]
    stop_ids = list(map(tok.convert_tokens_to_ids, stop_words))
    return engine, prompt, stop_ids
|
|
|
|
|
|
|
|
|
2024-07-26 22:44:13 -07:00
|
|
|
# LLaVA-1.5
|
2024-10-07 19:55:12 +08:00
|
|
|
def run_llava(question: str, modality: str):
    """Return ``(llm, prompt, stop_token_ids)`` for LLaVA-1.5-7B (image only).

    Uses the ``USER: <image> ... ASSISTANT:`` conversation format.
    """
    assert modality == "image"

    prompt = f"USER: <image>\n{question}\nASSISTANT:"
    engine = LLM(
        model="llava-hf/llava-1.5-7b-hf",
        max_model_len=4096,
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )
    return engine, prompt, None
|
2024-07-26 22:44:13 -07:00
|
|
|
|
|
|
|
|
|
|
|
# LLaVA-1.6/LLaVA-NeXT
|
2024-10-07 19:55:12 +08:00
|
|
|
def run_llava_next(question: str, modality: str):
    """Return ``(llm, prompt, stop_token_ids)`` for LLaVA-NeXT (Mistral-7B).

    Image input only. Uses the Mistral ``[INST] ... [/INST]`` format.
    """
    assert modality == "image"

    prompt = f"[INST] <image>\n{question} [/INST]"
    engine = LLM(
        model="llava-hf/llava-v1.6-mistral-7b-hf",
        max_model_len=8192,
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )
    return engine, prompt, None
|
|
|
|
|
|
|
|
|
|
|
|
# LLaVA-NeXT-Video
# Currently only supports video input
|
2024-10-07 19:55:12 +08:00
|
|
|
def run_llava_next_video(question: str, modality: str):
    """Return ``(llm, prompt, stop_token_ids)`` for LLaVA-NeXT-Video-7B.

    Video input only; the ``<video>`` placeholder precedes the question.
    """
    assert modality == "video"

    prompt = f"USER: <video>\n{question} ASSISTANT:"
    engine = LLM(
        model="llava-hf/LLaVA-NeXT-Video-7B-hf",
        max_model_len=8192,
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )
    return engine, prompt, None
|
2024-07-26 22:44:13 -07:00
|
|
|
|
|
|
|
|
2024-09-23 01:51:44 +08:00
|
|
|
# LLaVA-OneVision
|
2024-10-07 19:55:12 +08:00
|
|
|
def run_llava_onevision(question: str, modality: str):
    """Return ``(llm, prompt, stop_token_ids)`` for LLaVA-OneVision (Qwen2-7B).

    Supports both ``"image"`` and ``"video"`` input; the matching
    ``<image>`` / ``<video>`` placeholder is placed before the question.

    Raises:
        ValueError: If ``modality`` is neither ``"image"`` nor ``"video"``.
            This is checked before the model is loaded.
    """
    if modality == "video":
        placeholder = "<video>"
    elif modality == "image":
        placeholder = "<image>"
    else:
        # Previously an unknown modality left ``prompt`` unbound and failed
        # with UnboundLocalError only after the model had been loaded.
        raise ValueError(f"Modality {modality} is not supported.")

    # Built from implicitly concatenated literals. A backslash line
    # continuation inside the f-string would copy the next source line's
    # indentation into the prompt between <|im_end|> and <|im_start|>.
    prompt = (f"<|im_start|>user {placeholder}\n{question}<|im_end|>\n"
              "<|im_start|>assistant\n")

    llm = LLM(model="llava-hf/llava-onevision-qwen2-7b-ov-hf",
              max_model_len=16384,
              disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)

    stop_token_ids = None
    return llm, prompt, stop_token_ids
|
|
|
|
|
|
|
|
|
2024-12-16 19:23:33 +08:00
|
|
|
# Mantis
|
|
|
|
def run_mantis(question: str, modality: str):
    """Return ``(llm, prompt, stop_token_ids)`` for Mantis-8B-SigLIP-Llama3.

    Image input only. The user turn, containing the question followed by the
    ``<image>`` placeholder, is wrapped in a Llama-3 style header template.
    The checkpoint is loaded with the ``MantisForConditionalGeneration``
    architecture override.
    """
    assert modality == "image"

    llama3_template = '<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n'  # noqa: E501
    user_turn = f"{question}\n<image>"
    prompt = llama3_template.format(user_turn)

    engine = LLM(
        model="TIGER-Lab/Mantis-8B-siglip-llama3",
        max_model_len=4096,
        hf_overrides={"architectures": ["MantisForConditionalGeneration"]},
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )
    return engine, prompt, [128009]
|
2024-07-26 22:44:13 -07:00
|
|
|
|
|
|
|
|
|
|
|
# MiniCPM-V
|
2024-10-07 19:55:12 +08:00
|
|
|
def run_minicpmv(question: str, modality: str):
    """Return ``(llm, prompt, stop_token_ids)`` for MiniCPM-V 2.6 (image only).

    Other MiniCPM-V versions need a different checkpoint and different stop
    tokens; the alternatives are kept as comments below.
    """
    assert modality == "image"

    # Checkpoint per MiniCPM-V version:
    # 2.0 - the official repo doesn't work yet, so a fork is needed for now.
    #       For more details, see:
    #       https://github.com/vllm-project/vllm/pull/4087#issuecomment-2250397630 # noqa
    #       model_name = "HwwwH/MiniCPM-V-2"
    # 2.5 - model_name = "openbmb/MiniCPM-Llama3-V-2_5"
    # 2.6 - used below.
    model_name = "openbmb/MiniCPM-V-2_6"

    tokenizer = AutoTokenizer.from_pretrained(model_name,
                                              trust_remote_code=True)
    engine = LLM(
        model=model_name,
        max_model_len=4096,
        max_num_seqs=2,
        trust_remote_code=True,
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )

    # NOTE: stop tokens differ between MiniCPM-V versions:
    # 2.0 - stop_token_ids = [tokenizer.eos_id]
    # 2.5 - stop_token_ids = [tokenizer.eos_id, tokenizer.eot_id]
    # 2.6 - the token ids of the strings below.
    stop_ids = list(
        map(tokenizer.convert_tokens_to_ids, ['<|im_end|>', '<|endoftext|>']))

    chat = [{'role': 'user', 'content': f'(<image>./</image>)\n{question}'}]
    prompt = tokenizer.apply_chat_template(chat,
                                           add_generation_prompt=True,
                                           tokenize=False)
    return engine, prompt, stop_ids
|
2024-07-26 22:44:13 -07:00
|
|
|
|
|
|
|
|
2024-12-16 19:23:33 +08:00
|
|
|
# Llama 3.2
|
|
|
|
def run_mllama(question: str, modality: str):
    """Return ``(llm, prompt, stop_token_ids)`` for Llama-3.2-11B-Vision-Instruct.

    Image input only. The prompt is rendered from a structured message (an
    image part followed by a text part) with the model's chat template.
    """
    assert modality == "image"

    model_name = "meta-llama/Llama-3.2-11B-Vision-Instruct"

    # The default max_num_seqs (256) and max_model_len (131072) for this
    # model may cause OOM; lower either to run on smaller GPUs. The values
    # below have been confirmed to launch on a single L40 GPU.
    engine = LLM(
        model=model_name,
        max_model_len=4096,
        max_num_seqs=16,
        enforce_eager=True,
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    image_part = {"type": "image"}
    text_part = {"type": "text", "text": f"{question}"}
    chat = [{"role": "user", "content": [image_part, text_part]}]
    prompt = tokenizer.apply_chat_template(chat,
                                           tokenize=False,
                                           add_generation_prompt=True)
    return engine, prompt, None
|
|
|
|
|
|
|
|
|
2024-12-16 19:23:33 +08:00
|
|
|
# Molmo
|
|
|
|
def run_molmo(question: str, modality: str):
    """Return ``(llm, prompt, stop_token_ids)`` for Molmo-7B-D (image only).

    The raw question is used as the prompt, the model runs in bfloat16, and
    no extra stop tokens are set.
    """
    # Signature annotated to match every other ``run_*`` builder.
    assert modality == "image"

    model_name = "allenai/Molmo-7B-D-0924"

    llm = LLM(
        model=model_name,
        trust_remote_code=True,
        dtype="bfloat16",
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )

    prompt = question
    stop_token_ids = None
    return llm, prompt, stop_token_ids
|
2024-07-29 18:16:30 +08:00
|
|
|
|
|
|
|
|
2024-10-07 19:55:12 +08:00
|
|
|
# NVLM-D
|
|
|
|
def run_nvlm_d(question: str, modality: str):
    """Return ``(llm, prompt, stop_token_ids)`` for NVLM-D-72B (image only).

    The model is sharded across 4 GPUs (``tensor_parallel_size=4``), and the
    prompt is rendered with the model's chat template.
    """
    assert modality == "image"

    checkpoint = "nvidia/NVLM-D-72B"

    # Adjust the parallelism / context length as necessary to fit your GPUs.
    engine = LLM(
        model=checkpoint,
        trust_remote_code=True,
        max_model_len=4096,
        tensor_parallel_size=4,
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )

    tok = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)
    chat = [{'role': 'user', 'content': f"<image>\n{question}"}]
    prompt = tok.apply_chat_template(chat,
                                     add_generation_prompt=True,
                                     tokenize=False)
    return engine, prompt, None
|
|
|
|
|
|
|
|
|
2024-12-16 19:23:33 +08:00
|
|
|
# PaliGemma
|
|
|
|
def run_paligemma(question: str, modality: str):
    """Return ``(llm, prompt, stop_token_ids)`` for PaliGemma-3B-mix-224.

    Image input only. ``question`` is NOT used: PaliGemma expects its own
    task-prefix prompts, so the fixed English captioning prompt
    ``"caption en"`` is sent instead.
    """
    assert modality == "image"

    engine = LLM(
        model="google/paligemma-3b-mix-224",
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )
    return engine, "caption en", None
|
2024-07-27 19:53:07 +08:00
|
|
|
|
|
|
|
|
2024-12-16 19:23:33 +08:00
|
|
|
# PaliGemma 2
|
|
|
|
def run_paligemma2(question: str, modality: str):
    """Return ``(llm, prompt, stop_token_ids)`` for PaliGemma 2 (3B, DOCCI 448).

    Image input only. Like PaliGemma, ``question`` is ignored and the fixed
    task-prefix prompt ``"caption en"`` is used.
    """
    assert modality == "image"

    engine = LLM(
        model="google/paligemma2-3b-ft-docci-448",
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )
    return engine, "caption en", None
|
|
|
|
|
|
|
|
|
2024-12-16 19:23:33 +08:00
|
|
|
# Phi-3-Vision
|
|
|
|
def run_phi3v(question: str, modality: str):
    """Return ``(llm, prompt, stop_token_ids)`` for Phi-3.5-vision-instruct.

    Image input only; the image processor runs with ``num_crops=16``.
    """
    assert modality == "image"

    prompt = f"<|user|>\n<|image_1|>\n{question}<|end|>\n<|assistant|>\n"

    # ``num_crops`` overrides a kwarg of the multi-modal image processor.
    # For Phi-3.5-vision-instruct, 16 is recommended for single-frame input
    # and 4 for multi-frame input. A larger value generally yields more
    # tokens per image, because the preprocessing may upscale the image
    # more. For the model docs and the image-token formula, see:
    # https://huggingface.co/microsoft/Phi-3.5-vision-instruct#loading-the-model-locally
    # https://huggingface.co/microsoft/Phi-3.5-vision-instruct/blob/main/processing_phi3_v.py#L194
    processor_overrides = {"num_crops": 16}

    engine = LLM(
        model="microsoft/Phi-3.5-vision-instruct",
        trust_remote_code=True,
        max_model_len=4096,
        max_num_seqs=2,
        # mm_processor_kwargs can also be passed to generate/chat calls.
        mm_processor_kwargs=processor_overrides,
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )
    return engine, prompt, None
|
|
|
|
|
|
|
|
|
2024-10-18 15:29:56 -04:00
|
|
|
# Pixtral HF-format
|
|
|
|
def run_pixtral_hf(question: str, modality: str):
    """Return ``(llm, prompt, stop_token_ids)`` for Pixtral-12B (HF format).

    Image input only; the ``[IMG]`` placeholder follows the question inside
    an ``[INST] ... [/INST]`` block.
    """
    assert modality == "image"

    engine = LLM(
        model="mistral-community/pixtral-12b",
        max_model_len=8192,
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )
    prompt = f"<s>[INST]{question}\n[IMG][/INST]"
    return engine, prompt, None
|
|
|
|
|
|
|
|
|
2024-12-16 19:23:33 +08:00
|
|
|
# Qwen
|
|
|
|
def run_qwen_vl(question: str, modality: str):
    """Return ``(llm, prompt, stop_token_ids)`` for Qwen-VL (image only).

    The question text is followed by a ``Picture 1: <img></img>`` tag.
    """
    assert modality == "image"

    engine = LLM(
        model="Qwen/Qwen-VL",
        trust_remote_code=True,
        max_model_len=1024,
        max_num_seqs=2,
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )
    prompt = f"{question}Picture 1: <img></img>\n"
    return engine, prompt, None
|
|
|
|
|
|
|
|
|
2024-12-16 19:23:33 +08:00
|
|
|
# Qwen2-VL
|
|
|
|
def run_qwen2_vl(question: str, modality: str):
    """Return ``(llm, prompt, stop_token_ids)`` for Qwen2-VL-7B-Instruct.

    Supports both ``"image"`` and ``"video"`` input. The matching pad token
    is wrapped in ``<|vision_start|>``/``<|vision_end|>`` inside a
    ChatML-style conversation with a fixed system message.

    Raises:
        ValueError: If ``modality`` is neither ``"image"`` nor ``"video"``.
            This is checked before the model is loaded.
    """
    if modality == "image":
        placeholder = "<|image_pad|>"
    elif modality == "video":
        placeholder = "<|video_pad|>"
    else:
        # Previously ``placeholder`` was left unbound here, which raised
        # UnboundLocalError only after the model had been loaded.
        raise ValueError(f"Modality {modality} is not supported.")

    model_name = "Qwen/Qwen2-VL-7B-Instruct"

    llm = LLM(
        model=model_name,
        max_model_len=4096,
        max_num_seqs=5,
        # Note - mm_processor_kwargs can also be passed to generate/chat calls
        mm_processor_kwargs={
            "min_pixels": 28 * 28,
            "max_pixels": 1280 * 28 * 28,
        },
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )

    prompt = ("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
              f"<|im_start|>user\n<|vision_start|>{placeholder}<|vision_end|>"
              f"{question}<|im_end|>\n"
              "<|im_start|>assistant\n")
    stop_token_ids = None
    return llm, prompt, stop_token_ids
|
|
|
|
|
|
|
|
|
2024-07-26 22:44:13 -07:00
|
|
|
# Registry of supported ``--model-type`` values. Each value is a builder
# called as ``builder(question, modality)`` that returns
# ``(llm, prompt, stop_token_ids)``. Insertion order is kept because it is
# the order in which argparse lists the choices.
model_example_map = {
    "aria": run_aria,
    "blip-2": run_blip2,
    "chameleon": run_chameleon,
    "fuyu": run_fuyu,
    "glm4v": run_glm4v,
    "h2ovl_chat": run_h2ovl,
    "idefics3": run_idefics3,
    "internvl_chat": run_internvl,
    "llava": run_llava,
    "llava-next": run_llava_next,
    "llava-next-video": run_llava_next_video,
    "llava-onevision": run_llava_onevision,
    "mantis": run_mantis,
    "minicpmv": run_minicpmv,
    "mllama": run_mllama,
    "molmo": run_molmo,
    "NVLM_D": run_nvlm_d,
    "paligemma": run_paligemma,
    "paligemma2": run_paligemma2,
    "phi3_v": run_phi3v,
    "pixtral_hf": run_pixtral_hf,
    "qwen_vl": run_qwen_vl,
    "qwen2_vl": run_qwen2_vl,
}
|
|
|
|
|
|
|
|
|
2024-09-11 13:21:36 +08:00
|
|
|
def get_multi_modal_input(args):
    """Load the demo input and matching question for ``args.modality``.

    Returns:
        A dict ``{"data": ..., "question": ...}``.
        - ``"image"``: ``data`` is the ``cherry_blossom`` asset's PIL image,
          converted to RGB.
        - ``"video"``: ``data`` is the ``np_ndarrays`` of the
          ``sample_demo_1.mp4`` asset, loaded with ``args.num_frames``
          frames.

    Raises:
        ValueError: If ``args.modality`` is neither ``"image"`` nor
            ``"video"``.
    """
    modality = args.modality

    if modality == "image":
        rgb_image = ImageAsset("cherry_blossom").pil_image.convert("RGB")
        return {
            "data": rgb_image,
            "question": "What is the content of this image?",
        }

    if modality == "video":
        asset = VideoAsset(name="sample_demo_1.mp4",
                           num_frames=args.num_frames)
        return {
            "data": asset.np_ndarrays,
            "question": "Why is this video funny?",
        }

    raise ValueError(f"Modality {modality} is not supported.")
|
|
|
|
|
|
|
|
|
2024-12-11 19:55:30 -05:00
|
|
|
def apply_image_repeat(image_repeat_prob, num_prompts, data, prompt, modality):
    """Build ``num_prompts`` requests, reusing the image with a given probability.

    Used to simulate hit/miss for the MM preprocessor cache. For each
    request, with probability ``image_repeat_prob`` the previous request's
    image object is reused unchanged. Otherwise a copy is made and its (0, 0)
    pixel is overwritten with a value derived from the request index, so the
    image differs from the previous one.

    Args:
        image_repeat_prob: Probability in [0, 1] of reusing the previous image.
        num_prompts: Number of requests to build.
        data: Input image; must support ``copy()`` and
            ``putpixel(xy, rgb_tuple)`` (e.g. a PIL RGB image).
        prompt: Text prompt shared by all requests.
        modality: Key for ``multi_modal_data``; only ``"image"`` is supported.

    Returns:
        A list of ``{"prompt": ..., "multi_modal_data": {modality: image}}``
        dicts.

    Raises:
        ValueError: If ``image_repeat_prob`` is None or outside [0, 1], or if
            ``modality`` is not ``"image"``.
    """
    # Explicit checks instead of ``assert`` so they survive ``python -O``.
    if image_repeat_prob is None or not 0.0 <= image_repeat_prob <= 1.0:
        raise ValueError("image_repeat_prob must be in [0, 1], got "
                         f"{image_repeat_prob!r}")
    if modality != "image":
        # Pixel modification relies on ``putpixel``, which non-image data
        # (e.g. video frame arrays) does not provide.
        raise ValueError("Image repetition is only supported for the 'image' "
                         f"modality, got {modality!r}")

    no_yes = [0, 1]
    probs = [1.0 - image_repeat_prob, image_repeat_prob]

    inputs = []
    cur_image = data
    for i in range(num_prompts):
        if random.choices(no_yes, probs)[0] == 0:
            # No repeat => modify one pixel. Each channel is masked to one
            # byte so the value stays in 0..255 even for i >= 65536.
            cur_image = cur_image.copy()
            new_val = ((i >> 16) & 0xFF, (i >> 8) & 0xFF, i & 0xFF)
            cur_image.putpixel((0, 0), new_val)

        inputs.append({
            "prompt": prompt,
            "multi_modal_data": {
                modality: cur_image
            }
        })

    return inputs
|
|
|
|
|
|
|
|
|
2024-07-26 22:44:13 -07:00
|
|
|
def main(args):
    """Run offline multi-modal generation for the model chosen in ``args``.

    Validates the CLI arguments, loads the demo input, builds the engine and
    prompt via ``model_example_map``, then prints each generated text. If
    ``args.time_generate`` is set, it also prints the wall time of the
    ``generate`` call.

    Raises:
        ValueError: If ``args.num_prompts`` is not positive,
            ``args.image_repeat_prob`` is outside [0, 1], or
            ``args.model_type`` is not a key of ``model_example_map``.
    """
    import time

    # Validate cheap arguments before loading the (large) model, so a bad
    # flag fails immediately instead of after a full model load.
    if args.num_prompts <= 0:
        raise ValueError(
            f"--num-prompts must be positive, got {args.num_prompts}.")
    if (args.image_repeat_prob is not None
            and not 0.0 <= args.image_repeat_prob <= 1.0):
        raise ValueError("--image-repeat-prob must be in [0, 1], got "
                         f"{args.image_repeat_prob}.")

    model = args.model_type
    if model not in model_example_map:
        raise ValueError(f"Model type {model} is not supported.")

    modality = args.modality
    mm_input = get_multi_modal_input(args)
    data = mm_input["data"]
    question = mm_input["question"]

    llm, prompt, stop_token_ids = model_example_map[model](question, modality)

    # We set temperature to 0.2 so that outputs can be different
    # even when all prompts are identical when running batch inference.
    sampling_params = SamplingParams(temperature=0.2,
                                     max_tokens=64,
                                     stop_token_ids=stop_token_ids)

    if args.num_prompts == 1:
        # Single inference
        inputs = {
            "prompt": prompt,
            "multi_modal_data": {
                modality: data
            },
        }
    elif args.image_repeat_prob is not None:
        # Batch inference; repeat images with probability "image_repeat_prob"
        inputs = apply_image_repeat(args.image_repeat_prob, args.num_prompts,
                                    data, prompt, modality)
    else:
        # Batch inference using the same input for all prompts
        inputs = [{
            "prompt": prompt,
            "multi_modal_data": {
                modality: data
            },
        } for _ in range(args.num_prompts)]

    # perf_counter is monotonic, unlike time.time(), so the measured
    # duration cannot be skewed by system clock adjustments.
    start_time = time.perf_counter()
    outputs = llm.generate(inputs, sampling_params=sampling_params)
    elapsed_time = time.perf_counter() - start_time
    if args.time_generate:
        print("-- generate time = {}".format(elapsed_time))

    for o in outputs:
        generated_text = o.outputs[0].text
        print(generated_text)
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # ``args`` must stay a module-level global: the ``run_*`` builders read
    # ``args.disable_mm_preprocessor_cache`` from it.
    parser = FlexibleArgumentParser(
        description='Demo on using vLLM for offline inference with '
        'vision language models for text generation')

    # Model and input selection.
    parser.add_argument(
        '--model-type',
        '-m',
        type=str,
        default="llava",
        choices=model_example_map.keys(),
        help='Huggingface "model_type".',
    )
    parser.add_argument(
        '--num-prompts',
        type=int,
        default=4,
        help='Number of prompts to run.',
    )
    parser.add_argument(
        '--modality',
        type=str,
        default="image",
        choices=['image', 'video'],
        help='Modality of the input.',
    )
    parser.add_argument(
        '--num-frames',
        type=int,
        default=16,
        help='Number of frames to extract from the video.',
    )

    # Multi-modal preprocessor cache experiments.
    parser.add_argument(
        '--image-repeat-prob',
        type=float,
        default=None,
        help='Simulates the hit-ratio for multi-modal preprocessor cache'
        ' (if enabled)',
    )
    parser.add_argument(
        '--disable-mm-preprocessor-cache',
        action='store_true',
        help='If True, disables caching of multi-modal preprocessor/mapper.',
    )

    # Timing.
    parser.add_argument(
        '--time-generate',
        action='store_true',
        help='If True, then print the total generate() call time',
    )

    args = parser.parse_args()
    main(args)
|