"""
This example shows how to use vLLM for running offline inference with
multi-image input on vision language models for text generation,
using the chat template defined by the model.
"""
from argparse import Namespace
from typing import List, NamedTuple, Optional

from PIL.Image import Image
from transformers import AutoProcessor, AutoTokenizer

from vllm import LLM, SamplingParams
from vllm.multimodal.utils import fetch_image
from vllm.utils import FlexibleArgumentParser

QUESTION = "What is the content of each image?"
IMAGE_URLS = [
    "https://upload.wikimedia.org/wikipedia/commons/d/da/2015_Kaczka_krzy%C5%BCowka_w_wodzie_%28samiec%29.jpg",
    "https://upload.wikimedia.org/wikipedia/commons/7/77/002_The_lion_king_Snyggve_in_the_Serengeti_National_Park_Photo_by_Giles_Laurent.jpg",
]


class ModelRequestData(NamedTuple):
    llm: LLM
    prompt: str
    stop_token_ids: Optional[List[int]]
    image_data: List[Image]
    chat_template: Optional[str]
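

# Each loader below builds the LLM for one model family and returns a
# ModelRequestData bundling that LLM with the formatted prompt, optional stop
# token IDs, the fetched PIL images, and an optional chat template (only
# needed when the tokenizer does not define one).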
# NOTE: The default `max_num_seqs` and `max_model_len` may result in OOM on
# lower-end GPUs.
# Unless specified, these settings have been tested to work on a single L4.


def load_qwenvl_chat(question: str,
                     image_urls: List[str]) -> ModelRequestData:
    model_name = "Qwen/Qwen-VL-Chat"
    llm = LLM(
        model=model_name,
        trust_remote_code=True,
        max_model_len=1024,
        max_num_seqs=2,
        limit_mm_per_prompt={"image": len(image_urls)},
    )
    placeholders = "".join(f"Picture {i}: <img></img>\n"
                           for i, _ in enumerate(image_urls, start=1))

    # This model does not have a chat_template attribute on its tokenizer,
    # so we need to explicitly pass it. We use ChatML since it's used in the
    # generation utils of the model:
    # https://huggingface.co/Qwen/Qwen-VL-Chat/blob/main/qwen_generation_utils.py#L265
    tokenizer = AutoTokenizer.from_pretrained(model_name,
                                              trust_remote_code=True)

    # Copied from: https://huggingface.co/docs/transformers/main/en/chat_templating
    chat_template = "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"  # noqa: E501

    messages = [{'role': 'user', 'content': f"{placeholders}\n{question}"}]
    prompt = tokenizer.apply_chat_template(messages,
                                           tokenize=False,
                                           add_generation_prompt=True,
                                           chat_template=chat_template)

    stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>"]
    stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]

    return ModelRequestData(
        llm=llm,
        prompt=prompt,
        stop_token_ids=stop_token_ids,
        image_data=[fetch_image(url) for url in image_urls],
        chat_template=chat_template,
    )


def load_phi3v(question: str, image_urls: List[str]) -> ModelRequestData:
    # num_crops is an override kwarg to the multimodal image processor;
    # For some models, e.g., Phi-3.5-vision-instruct, it is recommended
    # to use 16 for single frame scenarios, and 4 for multi-frame.
    #
    # Generally speaking, a larger value for num_crops results in more
    # tokens per image instance, because it may scale the image more in
    # the image preprocessing. Some references in the model docs and the
    # formula for image tokens after the preprocessing
    # transform can be found below.
    #
    # https://huggingface.co/microsoft/Phi-3.5-vision-instruct#loading-the-model-locally
    # https://huggingface.co/microsoft/Phi-3.5-vision-instruct/blob/main/processing_phi3_v.py#L194
    llm = LLM(
        model="microsoft/Phi-3.5-vision-instruct",
        trust_remote_code=True,
        max_model_len=4096,
        max_num_seqs=2,
        limit_mm_per_prompt={"image": len(image_urls)},
        mm_processor_kwargs={"num_crops": 4},
    )
    placeholders = "\n".join(f"<|image_{i}|>"
                             for i, _ in enumerate(image_urls, start=1))
    prompt = f"<|user|>\n{placeholders}\n{question}<|end|>\n<|assistant|>\n"

    stop_token_ids = None

    return ModelRequestData(
        llm=llm,
        prompt=prompt,
        stop_token_ids=stop_token_ids,
        image_data=[fetch_image(url) for url in image_urls],
        chat_template=None,
    )


def load_internvl(question: str, image_urls: List[str]) -> ModelRequestData:
    model_name = "OpenGVLab/InternVL2-2B"

    llm = LLM(
        model=model_name,
        trust_remote_code=True,
        max_model_len=4096,
        limit_mm_per_prompt={"image": len(image_urls)},
        mm_processor_kwargs={"max_dynamic_patch": 4},
    )

    placeholders = "\n".join(f"Image-{i}: <image>\n"
                             for i, _ in enumerate(image_urls, start=1))
    messages = [{'role': 'user', 'content': f"{placeholders}\n{question}"}]

    tokenizer = AutoTokenizer.from_pretrained(model_name,
                                              trust_remote_code=True)
    prompt = tokenizer.apply_chat_template(messages,
                                           tokenize=False,
                                           add_generation_prompt=True)

    # Stop tokens for InternVL; model variants may have different stop tokens,
    # so please refer to the model card for the correct "stop words":
    # https://huggingface.co/OpenGVLab/InternVL2-2B#service
    stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|end|>"]
    stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]

    return ModelRequestData(
        llm=llm,
        prompt=prompt,
        stop_token_ids=stop_token_ids,
        image_data=[fetch_image(url) for url in image_urls],
        chat_template=None,
    )


def load_nvlm_d(question: str, image_urls: List[str]) -> ModelRequestData:
    model_name = "nvidia/NVLM-D-72B"

    # Adjust this as necessary to fit in GPU memory.
    llm = LLM(
        model=model_name,
        trust_remote_code=True,
        max_model_len=8192,
        tensor_parallel_size=4,
        limit_mm_per_prompt={"image": len(image_urls)},
        mm_processor_kwargs={"max_dynamic_patch": 4},
    )

    placeholders = "\n".join(f"Image-{i}: <image>\n"
                             for i, _ in enumerate(image_urls, start=1))
    messages = [{'role': 'user', 'content': f"{placeholders}\n{question}"}]

    tokenizer = AutoTokenizer.from_pretrained(model_name,
                                              trust_remote_code=True)
    prompt = tokenizer.apply_chat_template(messages,
                                           tokenize=False,
                                           add_generation_prompt=True)
    stop_token_ids = None

    return ModelRequestData(
        llm=llm,
        prompt=prompt,
        stop_token_ids=stop_token_ids,
        image_data=[fetch_image(url) for url in image_urls],
        chat_template=None,
    )


def load_qwen2_vl(question: str, image_urls: List[str]) -> ModelRequestData:
    try:
        from qwen_vl_utils import process_vision_info
    except ModuleNotFoundError:
        print('WARNING: `qwen-vl-utils` not installed, input images will not '
              'be automatically resized. You can enable this functionality by '
              '`pip install qwen-vl-utils`.')
        process_vision_info = None

    model_name = "Qwen/Qwen2-VL-7B-Instruct"
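
    # Without `qwen_vl_utils` the input images are not pre-resized (see the
    # warning above), so a much larger context length is configured below to
    # leave room for the additional image tokens.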
    # Tested on L40
    llm = LLM(
        model=model_name,
        max_model_len=32768 if process_vision_info is None else 4096,
        max_num_seqs=5,
        limit_mm_per_prompt={"image": len(image_urls)},
    )

    placeholders = [{"type": "image", "image": url} for url in image_urls]
    messages = [{
        "role": "system",
        "content": "You are a helpful assistant."
    }, {
        "role":
        "user",
        "content": [
            *placeholders,
            {
                "type": "text",
                "text": question
            },
        ],
    }]

    processor = AutoProcessor.from_pretrained(model_name)

    prompt = processor.apply_chat_template(messages,
                                           tokenize=False,
                                           add_generation_prompt=True)

    stop_token_ids = None

    if process_vision_info is None:
        image_data = [fetch_image(url) for url in image_urls]
    else:
        image_data, _ = process_vision_info(messages)

    return ModelRequestData(
        llm=llm,
        prompt=prompt,
        stop_token_ids=stop_token_ids,
        image_data=image_data,
        chat_template=None,
    )


def load_mllama(question: str, image_urls: List[str]) -> ModelRequestData:
    model_name = "meta-llama/Llama-3.2-11B-Vision-Instruct"

    # The configuration below has been confirmed to launch on a single L40 GPU.
    llm = LLM(
        model=model_name,
        max_model_len=4096,
        max_num_seqs=16,
        enforce_eager=True,
        limit_mm_per_prompt={"image": len(image_urls)},
    )
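
    # The prompt below hard-codes two image placeholders, matching the two
    # entries in IMAGE_URLS; adjust it if a different number of images is
    # passed in.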
    prompt = f"<|image|><|image|><|begin_of_text|>{question}"
    return ModelRequestData(
        llm=llm,
        prompt=prompt,
        stop_token_ids=None,
        image_data=[fetch_image(url) for url in image_urls],
        chat_template=None,
    )


model_example_map = {
    "phi3_v": load_phi3v,
    "internvl_chat": load_internvl,
    "NVLM_D": load_nvlm_d,
    "qwen2_vl": load_qwen2_vl,
    "qwen_vl_chat": load_qwenvl_chat,
    "mllama": load_mllama,
}


def run_generate(model: str, question: str, image_urls: List[str]):
    req_data = model_example_map[model](question, image_urls)

    sampling_params = SamplingParams(temperature=0.0,
                                     max_tokens=128,
                                     stop_token_ids=req_data.stop_token_ids)

    outputs = req_data.llm.generate(
        {
            "prompt": req_data.prompt,
            "multi_modal_data": {
                "image": req_data.image_data
            },
        },
        sampling_params=sampling_params)

    for o in outputs:
        generated_text = o.outputs[0].text
        print(generated_text)
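

# Unlike run_generate, run_chat passes OpenAI-style messages to `LLM.chat`,
# which resolves the image URLs and applies the chat template itself, so the
# pre-built prompt from the loader is not used on this path.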
def run_chat(model: str, question: str, image_urls: List[str]):
    req_data = model_example_map[model](question, image_urls)

    sampling_params = SamplingParams(temperature=0.0,
                                     max_tokens=128,
                                     stop_token_ids=req_data.stop_token_ids)
    outputs = req_data.llm.chat(
        [{
            "role":
            "user",
            "content": [
                {
                    "type": "text",
                    "text": question,
                },
                *({
                    "type": "image_url",
                    "image_url": {
                        "url": image_url
                    },
                } for image_url in image_urls),
            ],
        }],
        sampling_params=sampling_params,
        chat_template=req_data.chat_template,
    )

    for o in outputs:
        generated_text = o.outputs[0].text
        print(generated_text)


def main(args: Namespace):
    model = args.model_type
    method = args.method

    if method == "generate":
        run_generate(model, QUESTION, IMAGE_URLS)
    elif method == "chat":
        run_chat(model, QUESTION, IMAGE_URLS)
    else:
        raise ValueError(f"Invalid method: {method}")


if __name__ == "__main__":
    parser = FlexibleArgumentParser(
        description='Demo on using vLLM for offline inference with '
        'vision language models that support multi-image input for text '
        'generation')
    parser.add_argument('--model-type',
                        '-m',
                        type=str,
                        default="phi3_v",
                        choices=model_example_map.keys(),
                        help='Huggingface "model_type".')
    parser.add_argument("--method",
                        type=str,
                        default="generate",
                        choices=["generate", "chat"],
                        help="The method to run in `vllm.LLM`.")

    args = parser.parse_args()
    main(args)