[Misc] Use NamedTuple in Multi-image example (#8705)
Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
parent 06ed2815e2
commit 8ca5051b9a
@@ -4,8 +4,9 @@ multi-image input on vision language models, using the chat template defined
 by the model.
 """
 from argparse import Namespace
-from typing import List
+from typing import List, NamedTuple, Optional
 
+from PIL.Image import Image
 from transformers import AutoProcessor, AutoTokenizer
 
 from vllm import LLM, SamplingParams
@@ -19,7 +20,15 @@ IMAGE_URLS = [
 ]
 
 
-def load_qwenvl_chat(question: str, image_urls: List[str]):
+class ModelRequestData(NamedTuple):
+    llm: LLM
+    prompt: str
+    stop_token_ids: Optional[List[str]]
+    image_data: List[Image]
+    chat_template: Optional[str]
+
+
+def load_qwenvl_chat(question: str, image_urls: List[str]) -> ModelRequestData:
     model_name = "Qwen/Qwen-VL-Chat"
     llm = LLM(
         model=model_name,
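For reference, a standalone sketch of the pattern this hunk introduces. Field types are simplified so the snippet runs without vLLM or PIL installed; the annotations otherwise mirror the diff.

from typing import List, NamedTuple, Optional

class ModelRequestData(NamedTuple):
    llm: object                    # a vllm.LLM instance in the real file
    prompt: str
    stop_token_ids: Optional[List[str]]
    image_data: List[object]       # PIL.Image.Image objects in the real file
    chat_template: Optional[str]

req = ModelRequestData(llm=None,
                       prompt="<|user|>\nDescribe the images.",
                       stop_token_ids=None,
                       image_data=[],
                       chat_template=None)

# Named access replaces the old positional unpacking
#     llm, prompt, stop_token_ids, image_data, _ = loader(...)
# so call sites cannot silently mix up the tuple order and no longer
# need "_" placeholders for fields they ignore.
print(req.prompt)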
@@ -48,10 +57,16 @@ def load_qwenvl_chat(question: str, image_urls: List[str]):
 
     stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>"]
     stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]
-    return llm, prompt, stop_token_ids, None, chat_template
+    return ModelRequestData(
+        llm=llm,
+        prompt=prompt,
+        stop_token_ids=stop_token_ids,
+        image_data=[fetch_image(url) for url in image_urls],
+        chat_template=chat_template,
+    )
 
 
-def load_phi3v(question: str, image_urls: List[str]):
+def load_phi3v(question: str, image_urls: List[str]) -> ModelRequestData:
     llm = LLM(
         model="microsoft/Phi-3.5-vision-instruct",
         trust_remote_code=True,
@@ -62,10 +77,17 @@ def load_phi3v(question: str, image_urls: List[str]):
                              for i, _ in enumerate(image_urls, start=1))
     prompt = f"<|user|>\n{placeholders}\n{question}<|end|>\n<|assistant|>\n"
     stop_token_ids = None
-    return llm, prompt, stop_token_ids, None, None
+
+    return ModelRequestData(
+        llm=llm,
+        prompt=prompt,
+        stop_token_ids=stop_token_ids,
+        image_data=[fetch_image(url) for url in image_urls],
+        chat_template=None,
+    )
 
 
-def load_internvl(question: str, image_urls: List[str]):
+def load_internvl(question: str, image_urls: List[str]) -> ModelRequestData:
     model_name = "OpenGVLab/InternVL2-2B"
 
     llm = LLM(
@@ -93,10 +115,16 @@ def load_internvl(question: str, image_urls: List[str]):
     stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|end|>"]
     stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]
 
-    return llm, prompt, stop_token_ids, None, None
+    return ModelRequestData(
+        llm=llm,
+        prompt=prompt,
+        stop_token_ids=stop_token_ids,
+        image_data=[fetch_image(url) for url in image_urls],
+        chat_template=None,
+    )
 
 
-def load_qwen2_vl(question, image_urls: List[str]):
+def load_qwen2_vl(question, image_urls: List[str]) -> ModelRequestData:
     try:
         from qwen_vl_utils import process_vision_info
     except ModuleNotFoundError:
@@ -143,7 +171,13 @@ def load_qwen2_vl(question, image_urls: List[str]):
     else:
         image_data, _ = process_vision_info(messages)
 
-    return llm, prompt, stop_token_ids, image_data, None
+    return ModelRequestData(
+        llm=llm,
+        prompt=prompt,
+        stop_token_ids=stop_token_ids,
+        image_data=image_data,
+        chat_template=None,
+    )
 
 
 model_example_map = {
@@ -155,20 +189,17 @@ model_example_map = {
 
 
 def run_generate(model, question: str, image_urls: List[str]):
-    llm, prompt, stop_token_ids, image_data, _ = model_example_map[model](
-        question, image_urls)
-    if image_data is None:
-        image_data = [fetch_image(url) for url in image_urls]
+    req_data = model_example_map[model](question, image_urls)
 
     sampling_params = SamplingParams(temperature=0.0,
                                      max_tokens=128,
-                                     stop_token_ids=stop_token_ids)
+                                     stop_token_ids=req_data.stop_token_ids)
 
-    outputs = llm.generate(
+    outputs = req_data.llm.generate(
         {
-            "prompt": prompt,
+            "prompt": req_data.prompt,
             "multi_modal_data": {
-                "image": image_data
+                "image": req_data.image_data
             },
         },
         sampling_params=sampling_params)
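Note that the old "if image_data is None" fallback disappears here: every loader now calls fetch_image eagerly and returns a populated list. A minimal sketch of the invariant the new run_generate relies on, using a toy loader and map (the real keys and load_* functions live in the example file):

from typing import Callable, Dict, List, NamedTuple

class Req(NamedTuple):
    prompt: str
    image_data: List[str]

def load_toy(question: str, image_urls: List[str]) -> Req:
    # Loaders fetch images up front; image_data is never None.
    return Req(prompt=question,
               image_data=[f"img:{url}" for url in image_urls])

model_example_map: Dict[str, Callable[..., Req]] = {"toy": load_toy}

req_data = model_example_map["toy"]("What is shown?", ["a.jpg", "b.jpg"])
assert req_data.image_data  # always populated, so no None-check needed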
@@ -179,13 +210,12 @@ def run_generate(model, question: str, image_urls: List[str]):
 
 
 def run_chat(model: str, question: str, image_urls: List[str]):
-    llm, _, stop_token_ids, _, chat_template = model_example_map[model](
-        question, image_urls)
+    req_data = model_example_map[model](question, image_urls)
 
     sampling_params = SamplingParams(temperature=0.0,
                                      max_tokens=128,
-                                     stop_token_ids=stop_token_ids)
-    outputs = llm.chat(
+                                     stop_token_ids=req_data.stop_token_ids)
+    outputs = req_data.llm.chat(
         [{
             "role":
             "user",
@@ -203,7 +233,7 @@ def run_chat(model: str, question: str, image_urls: List[str]):
             ],
         }],
         sampling_params=sampling_params,
-        chat_template=chat_template,
+        chat_template=req_data.chat_template,
     )
 
     for o in outputs:
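One possible follow-up, not part of this commit: three of the four loaders pass chat_template=None and one passes stop_token_ids=None, so NamedTuple field defaults could trim the repeated boilerplate. A hedged sketch of that variant:

from typing import List, NamedTuple, Optional

class ModelRequestData(NamedTuple):
    llm: object                 # vllm.LLM in the real file
    prompt: str
    image_data: List[object]    # PIL images in the real file
    stop_token_ids: Optional[List[int]] = None
    chat_template: Optional[str] = None

# Loaders with no stop tokens or custom template could then write:
#     return ModelRequestData(llm=llm, prompt=prompt, image_data=images)
# Note the field order changes, since defaulted fields must come last.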