
"""
|
|
This example shows how to use vLLM for running offline inference with
|
|
multi-image input on vision language models, using the chat template defined
|
|
by the model.
|
|
"""
from argparse import Namespace
from typing import List

from transformers import AutoProcessor, AutoTokenizer

from vllm import LLM, SamplingParams
from vllm.multimodal.utils import fetch_image
from vllm.utils import FlexibleArgumentParser

QUESTION = "What is the content of each image?"
IMAGE_URLS = [
    "https://upload.wikimedia.org/wikipedia/commons/d/da/2015_Kaczka_krzy%C5%BCowka_w_wodzie_%28samiec%29.jpg",
    "https://upload.wikimedia.org/wikipedia/commons/7/77/002_The_lion_king_Snyggve_in_the_Serengeti_National_Park_Photo_by_Giles_Laurent.jpg",
]
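

# Each load_* helper below builds an LLM instance and a multi-image prompt for
# one model family. They all return a 5-tuple of
# (llm, prompt, stop_token_ids, image_data, chat_template); image_data and
# chat_template are None when run_generate/run_chat should fall back to their
# defaults (fetching the images directly / using the model's own template).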
def load_qwenvl_chat(question: str, image_urls: List[str]):
    model_name = "Qwen/Qwen-VL-Chat"
    llm = LLM(
        model=model_name,
        trust_remote_code=True,
        max_num_seqs=5,
        limit_mm_per_prompt={"image": len(image_urls)},
    )
    placeholders = "".join(f"Picture {i}: <img></img>\n"
                           for i, _ in enumerate(image_urls, start=1))

    # This model does not have a chat_template attribute on its tokenizer,
    # so we need to explicitly pass it. We use ChatML since it's used in the
    # generation utils of the model:
    # https://huggingface.co/Qwen/Qwen-VL-Chat/blob/main/qwen_generation_utils.py#L265
    tokenizer = AutoTokenizer.from_pretrained(model_name,
                                              trust_remote_code=True)

    # Copied from: https://huggingface.co/docs/transformers/main/en/chat_templating
    chat_template = "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"  # noqa: E501

    messages = [{'role': 'user', 'content': f"{placeholders}\n{question}"}]
    prompt = tokenizer.apply_chat_template(messages,
                                           tokenize=False,
                                           add_generation_prompt=True,
                                           chat_template=chat_template)

    stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>"]
    stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]
    return llm, prompt, stop_token_ids, None, chat_template


def load_phi3v(question: str, image_urls: List[str]):
    llm = LLM(
        model="microsoft/Phi-3.5-vision-instruct",
        trust_remote_code=True,
        max_model_len=4096,
        limit_mm_per_prompt={"image": len(image_urls)},
    )
    placeholders = "\n".join(f"<|image_{i}|>"
                             for i, _ in enumerate(image_urls, start=1))
    prompt = f"<|user|>\n{placeholders}\n{question}<|end|>\n<|assistant|>\n"
    stop_token_ids = None
    return llm, prompt, stop_token_ids, None, None


def load_internvl(question: str, image_urls: List[str]):
    model_name = "OpenGVLab/InternVL2-2B"

    llm = LLM(
        model=model_name,
        trust_remote_code=True,
        max_num_seqs=5,
        max_model_len=4096,
        limit_mm_per_prompt={"image": len(image_urls)},
    )

    placeholders = "\n".join(f"Image-{i}: <image>\n"
                             for i, _ in enumerate(image_urls, start=1))
    messages = [{'role': 'user', 'content': f"{placeholders}\n{question}"}]

    tokenizer = AutoTokenizer.from_pretrained(model_name,
                                              trust_remote_code=True)
    prompt = tokenizer.apply_chat_template(messages,
                                           tokenize=False,
                                           add_generation_prompt=True)

    # Stop tokens for InternVL.
    # Model variants may have different stop tokens;
    # please refer to the model card for the correct "stop words":
    # https://huggingface.co/OpenGVLab/InternVL2-2B#service
    stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|end|>"]
    stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]

    return llm, prompt, stop_token_ids, None, None


def load_qwen2_vl(question: str, image_urls: List[str]):
    try:
        from qwen_vl_utils import process_vision_info
    except ModuleNotFoundError:
        print('WARNING: `qwen-vl-utils` not installed, input images will not '
              'be automatically resized. You can enable this functionality by '
              '`pip install qwen-vl-utils`.')
        process_vision_info = None

    model_name = "Qwen/Qwen2-VL-7B-Instruct"

    llm = LLM(
        model=model_name,
        max_num_seqs=5,
        max_model_len=32768 if process_vision_info is None else 4096,
        limit_mm_per_prompt={"image": len(image_urls)},
    )

    placeholders = [{"type": "image", "image": url} for url in image_urls]
    messages = [{
        "role": "system",
        "content": "You are a helpful assistant."
    }, {
        "role": "user",
        "content": [
            *placeholders,
            {
                "type": "text",
                "text": question
            },
        ],
    }]

    processor = AutoProcessor.from_pretrained(model_name)

    prompt = processor.apply_chat_template(messages,
                                           tokenize=False,
                                           add_generation_prompt=True)

    stop_token_ids = None

    if process_vision_info is None:
        image_data = [fetch_image(url) for url in image_urls]
    else:
        image_data, _ = process_vision_info(messages)

    return llm, prompt, stop_token_ids, image_data, None
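

# Maps the --model-type CLI value (the HuggingFace "model_type") to the loader
# defined above for that model family.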
model_example_map = {
    "phi3_v": load_phi3v,
    "internvl_chat": load_internvl,
    "qwen2_vl": load_qwen2_vl,
    "qwen_vl_chat": load_qwenvl_chat,
}
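

# "generate" path: use the loader's pre-built prompt string and image data,
# fetching the images ourselves if the loader did not, then call LLM.generate()
# with multi_modal_data.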
def run_generate(model: str, question: str, image_urls: List[str]):
    llm, prompt, stop_token_ids, image_data, _ = model_example_map[model](
        question, image_urls)
    if image_data is None:
        image_data = [fetch_image(url) for url in image_urls]

    sampling_params = SamplingParams(temperature=0.0,
                                     max_tokens=128,
                                     stop_token_ids=stop_token_ids)

    outputs = llm.generate(
        {
            "prompt": prompt,
            "multi_modal_data": {
                "image": image_data
            },
        },
        sampling_params=sampling_params)

    for o in outputs:
        generated_text = o.outputs[0].text
        print(generated_text)
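

# "chat" path: pass OpenAI-style messages with image_url entries and let
# LLM.chat() apply the chat template and fetch the images.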
def run_chat(model: str, question: str, image_urls: List[str]):
    llm, _, stop_token_ids, _, chat_template = model_example_map[model](
        question, image_urls)

    sampling_params = SamplingParams(temperature=0.0,
                                     max_tokens=128,
                                     stop_token_ids=stop_token_ids)
    outputs = llm.chat(
        [{
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": question,
                },
                *({
                    "type": "image_url",
                    "image_url": {
                        "url": image_url
                    },
                } for image_url in image_urls),
            ],
        }],
        sampling_params=sampling_params,
        chat_template=chat_template,
    )

    for o in outputs:
        generated_text = o.outputs[0].text
        print(generated_text)


def main(args: Namespace):
    model = args.model_type
    method = args.method

    if method == "generate":
        run_generate(model, QUESTION, IMAGE_URLS)
    elif method == "chat":
        run_chat(model, QUESTION, IMAGE_URLS)
    else:
        raise ValueError(f"Invalid method: {method}")


if __name__ == "__main__":
    parser = FlexibleArgumentParser(
        description='Demo on using vLLM for offline inference with '
        'vision language models that support multi-image input')
    parser.add_argument('--model-type',
                        '-m',
                        type=str,
                        default="phi3_v",
                        choices=model_example_map.keys(),
                        help='Huggingface "model_type".')
    parser.add_argument("--method",
                        type=str,
                        default="generate",
                        choices=["generate", "chat"],
                        help="The method to run in `vllm.LLM`.")

    args = parser.parse_args()
    main(args)