vllm/examples/offline_inference_vision_language_multi_image.py
Alex Brooks c6202daeed
[Model] Support multiple images for qwen-vl (#8247)
Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
2024-09-12 10:10:54 -07:00

244 lines
8.4 KiB
Python

"""
This example shows how to use vLLM for running offline inference with
multi-image input on vision language models, using the chat template defined
by the model.
"""
from argparse import Namespace
from typing import List
from transformers import AutoProcessor, AutoTokenizer
from vllm import LLM, SamplingParams
from vllm.multimodal.utils import fetch_image
from vllm.utils import FlexibleArgumentParser
# Text question that accompanies the images in every example prompt.
QUESTION = "What is the content of each image?"

# Sample images (hosted on Wikimedia Commons) used as the multi-image input.
# The number of entries also determines the per-prompt image limit that the
# model loaders pass to the engine.
IMAGE_URLS = [
    "https://upload.wikimedia.org/wikipedia/commons/d/da/2015_Kaczka_krzy%C5%BCowka_w_wodzie_%28samiec%29.jpg",
    "https://upload.wikimedia.org/wikipedia/commons/7/77/002_The_lion_king_Snyggve_in_the_Serengeti_National_Park_Photo_by_Giles_Laurent.jpg",
]
def load_qwenvl_chat(question: str, image_urls: List[str]):
    """Prepare the engine, prompt and chat template for Qwen-VL-Chat.

    Args:
        question: Text question appended after the image placeholders.
        image_urls: URLs of the images; only their count is used here, to
            size the per-prompt image limit and to number the placeholders.

    Returns:
        A 5-tuple ``(llm, prompt, stop_token_ids, image_data, chat_template)``
        where ``image_data`` is always ``None`` (this loader does not fetch
        the images itself) and ``chat_template`` is the ChatML template used
        to render ``prompt``.
    """
    model_name = "Qwen/Qwen-VL-Chat"
    num_images = len(image_urls)

    llm = LLM(
        model=model_name,
        trust_remote_code=True,
        max_num_seqs=5,
        limit_mm_per_prompt={"image": num_images},
    )

    # One numbered "Picture N" slot per input image, each on its own line.
    picture_slots = [
        f"Picture {idx}: <img></img>\n" for idx in range(1, num_images + 1)
    ]
    placeholders = "".join(picture_slots)

    # This model does not have a chat_template attribute on its tokenizer,
    # so we need to explicitly pass it. We use ChatML since it's used in the
    # generation utils of the model:
    # https://huggingface.co/Qwen/Qwen-VL-Chat/blob/main/qwen_generation_utils.py#L265
    tokenizer = AutoTokenizer.from_pretrained(model_name,
                                              trust_remote_code=True)

    # Copied from: https://huggingface.co/docs/transformers/main/en/chat_templating
    # (split across adjacent literals only for line length; the resulting
    # string is a single template)
    chat_template = (
        "{% if not add_generation_prompt is defined %}"
        "{% set add_generation_prompt = false %}{% endif %}"
        "{% for message in messages %}"
        "{{'<|im_start|>' + message['role'] + '\n' + message['content']"
        " + '<|im_end|>' + '\n'}}"
        "{% endfor %}"
        "{% if add_generation_prompt %}"
        "{{ '<|im_start|>assistant\n' }}{% endif %}"
    )

    user_message = {'role': 'user', 'content': f"{placeholders}\n{question}"}
    prompt = tokenizer.apply_chat_template(
        [user_message],
        tokenize=False,
        add_generation_prompt=True,
        chat_template=chat_template,
    )

    # Generation is cut off at any of these special tokens.
    stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>"]
    stop_token_ids = []
    for token in stop_tokens:
        stop_token_ids.append(tokenizer.convert_tokens_to_ids(token))

    return llm, prompt, stop_token_ids, None, chat_template
def load_phi3v(question: str, image_urls: List[str]):
    """Prepare the engine and a hand-written prompt for Phi-3.5-vision.

    Args:
        question: Text question placed after the image tags.
        image_urls: URLs of the images; only their count is used here, to
            size the per-prompt image limit and to number the image tags.

    Returns:
        A 5-tuple ``(llm, prompt, stop_token_ids, image_data, chat_template)``
        in which the last three entries are all ``None``.
    """
    num_images = len(image_urls)

    llm = LLM(
        model="microsoft/Phi-3.5-vision-instruct",
        trust_remote_code=True,
        max_model_len=4096,
        limit_mm_per_prompt={"image": num_images},
    )

    # Numbered tags <|image_1|>, <|image_2|>, ... one per line.
    image_tags = [f"<|image_{n}|>" for n in range(1, num_images + 1)]
    placeholders = "\n".join(image_tags)

    prompt = f"<|user|>\n{placeholders}\n{question}<|end|>\n<|assistant|>\n"

    return llm, prompt, None, None, None
def load_internvl(question: str, image_urls: List[str]):
    """Prepare the engine and chat-templated prompt for InternVL2.

    Args:
        question: Text question appended after the image placeholders.
        image_urls: URLs of the images; only their count is used here, to
            size the per-prompt image limit and to number the placeholders.

    Returns:
        A 5-tuple ``(llm, prompt, stop_token_ids, image_data, chat_template)``
        where ``image_data`` and ``chat_template`` are ``None``; the prompt is
        rendered with the tokenizer's own chat template.
    """
    model_name = "OpenGVLab/InternVL2-2B"

    llm = LLM(
        model=model_name,
        trust_remote_code=True,
        max_num_seqs=5,
        max_model_len=4096,
        limit_mm_per_prompt={"image": len(image_urls)},
    )

    placeholders = "\n".join(f"Image-{i}: <image>\n"
                             for i, _ in enumerate(image_urls, start=1))
    messages = [{'role': 'user', 'content': f"{placeholders}\n{question}"}]

    tokenizer = AutoTokenizer.from_pretrained(model_name,
                                              trust_remote_code=True)
    prompt = tokenizer.apply_chat_template(messages,
                                           tokenize=False,
                                           add_generation_prompt=True)

    # Stop tokens for InternVL.
    # Model variants may have different stop tokens; please refer to the
    # model card for the correct "stop words":
    # https://huggingface.co/OpenGVLab/InternVL2-2B#service
    stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|end|>"]

    # Because variants differ, a listed token may be unknown to the loaded
    # tokenizer. Drop any token that did not resolve to an id (None) so that
    # only real token ids are handed to the sampling parameters.
    stop_token_ids = [
        token_id
        for token_id in map(tokenizer.convert_tokens_to_ids, stop_tokens)
        if token_id is not None
    ]

    return llm, prompt, stop_token_ids, None, None
def load_qwen2_vl(question, image_urls: List[str]):
    """Prepare the engine, prompt and image inputs for Qwen2-VL.

    When the optional ``qwen_vl_utils`` package is importable, its
    ``process_vision_info`` helper produces the image inputs; otherwise a
    warning is printed, the raw images are fetched directly, and a larger
    ``max_model_len`` is requested (presumably to leave room for images that
    were not resized -- confirm against the model documentation).

    Args:
        question: Text question placed after the image entries.
        image_urls: URLs of the images to include in the prompt.

    Returns:
        A 5-tuple ``(llm, prompt, stop_token_ids, image_data, chat_template)``
        where ``stop_token_ids`` and ``chat_template`` are ``None``.
    """
    try:
        from qwen_vl_utils import process_vision_info
    except ModuleNotFoundError:
        print('WARNING: `qwen-vl-utils` not installed, input images will not '
              'be automatically resized. You can enable this functionality by '
              '`pip install qwen-vl-utils`.')
        process_vision_info = None

    has_vision_utils = process_vision_info is not None
    model_name = "Qwen/Qwen2-VL-7B-Instruct"

    if has_vision_utils:
        context_length = 4096
    else:
        context_length = 32768

    llm = LLM(
        model=model_name,
        max_num_seqs=5,
        max_model_len=context_length,
        limit_mm_per_prompt={"image": len(image_urls)},
    )

    # User turn: every image first, then the text question.
    user_content = [{"type": "image", "image": url} for url in image_urls]
    user_content.append({"type": "text", "text": question})

    messages = [
        {
            "role": "system",
            "content": "You are a helpful assistant."
        },
        {
            "role": "user",
            "content": user_content
        },
    ]

    processor = AutoProcessor.from_pretrained(model_name)
    prompt = processor.apply_chat_template(messages,
                                           tokenize=False,
                                           add_generation_prompt=True)

    if has_vision_utils:
        # Only the first element of the returned pair is used as image data.
        image_data, _ = process_vision_info(messages)
    else:
        image_data = [fetch_image(url) for url in image_urls]

    return llm, prompt, None, image_data, None
# Maps each supported model-type name to the loader function that builds
# its engine and prompt.
model_example_map = {
    "phi3_v": load_phi3v,
    "internvl_chat": load_internvl,
    "qwen2_vl": load_qwen2_vl,
    "qwen_vl_chat": load_qwenvl_chat,
}
def run_generate(model, question: str, image_urls: List[str]):
    """Run one multi-image request through ``llm.generate`` and print it.

    Args:
        model: Key into ``model_example_map`` selecting the loader to use.
        question: Text question passed to the loader.
        image_urls: URLs of the images passed to the loader and, when the
            loader returns no image data, fetched here.
    """
    loader = model_example_map[model]
    llm, prompt, stop_token_ids, image_data, _ = loader(question, image_urls)

    # Loaders that do not prepare images themselves return None here.
    if image_data is None:
        image_data = list(map(fetch_image, image_urls))

    sampling_params = SamplingParams(
        temperature=0.0,
        max_tokens=128,
        stop_token_ids=stop_token_ids,
    )

    request = {
        "prompt": prompt,
        "multi_modal_data": {
            "image": image_data
        },
    }
    outputs = llm.generate(request, sampling_params=sampling_params)

    # Print the first completion of every returned request output.
    for request_output in outputs:
        print(request_output.outputs[0].text)
def run_chat(model: str, question: str, image_urls: List[str]):
    """Run one multi-image conversation through ``llm.chat`` and print it.

    The loader's pre-rendered prompt and image data are ignored; the
    question and image URLs are sent as a structured user message instead,
    together with the loader's chat template (which may be ``None``).

    Args:
        model: Key into ``model_example_map`` selecting the loader to use.
        question: Text part of the user message.
        image_urls: URLs sent as ``image_url`` parts of the user message.
    """
    loader = model_example_map[model]
    llm, _, stop_token_ids, _, chat_template = loader(question, image_urls)

    sampling_params = SamplingParams(
        temperature=0.0,
        max_tokens=128,
        stop_token_ids=stop_token_ids,
    )

    # Message content: the text question first, then one part per image.
    text_part = {
        "type": "text",
        "text": question,
    }
    image_parts = [{
        "type": "image_url",
        "image_url": {
            "url": image_url
        },
    } for image_url in image_urls]
    conversation = [{
        "role": "user",
        "content": [text_part, *image_parts],
    }]

    outputs = llm.chat(
        conversation,
        sampling_params=sampling_params,
        chat_template=chat_template,
    )

    # Print the first completion of every returned request output.
    for request_output in outputs:
        print(request_output.outputs[0].text)
def main(args: Namespace):
    """Dispatch to the runner selected by the parsed command-line arguments.

    Args:
        args: Parsed arguments providing ``model_type`` and ``method``.

    Raises:
        ValueError: If ``args.method`` is neither "generate" nor "chat".
    """
    model = args.model_type
    method = args.method

    # Reject unknown methods up front, then pick the matching runner.
    if method not in ("generate", "chat"):
        raise ValueError(f"Invalid method: {method}")

    runner = run_generate if method == "generate" else run_chat
    runner(model, QUESTION, IMAGE_URLS)
# Script entry point: parse the command-line options and run the demo.
if __name__ == "__main__":
    parser = FlexibleArgumentParser(
        description='Demo on using vLLM for offline inference with '
        'vision language models that support multi-image input')

    # Which model loader to use; valid values are the keys of
    # model_example_map.
    parser.add_argument(
        '--model-type',
        '-m',
        type=str,
        default="phi3_v",
        choices=model_example_map.keys(),
        help='Huggingface "model_type".',
    )

    # Which `vllm.LLM` entry point to exercise.
    parser.add_argument(
        "--method",
        type=str,
        default="generate",
        choices=["generate", "chat"],
        help="The method to run in `vllm.LLM`.",
    )

    args = parser.parse_args()
    main(args)