[V1][Molmo] Fix get_multimodal_embeddings() in molmo.py (#14161)
This commit is contained in:
parent c8525f06fc
commit b3cf368d79
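For context, a minimal self-contained sketch of the flattening behaviour this commit gives Molmo's get_multimodal_embeddings(): it now returns a flat sequence of per-image embedding tensors instead of a nested list, so callers can concatenate them directly. This is not the vLLM implementation; the helper below only assumes the behaviour of vllm.utils.flatten_2d_lists, which the diff imports.

import torch

def flatten_2d_lists(lists):
    # Assumed equivalent of vllm.utils.flatten_2d_lists: remove one nesting level.
    return [item for sub in lists for item in sub]

# Two input images, each contributing several embedding tensors of hidden size 8.
nested = [[torch.zeros(4, 8), torch.zeros(2, 8)], [torch.zeros(3, 8)]]
flat = flatten_2d_lists(nested)   # list[torch.Tensor] of length 3
merged = torch.cat(flat, dim=0)   # shape (9, 8), ready to merge with text embeddings
print(len(flat), merged.shape)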
@@ -21,7 +21,7 @@ from vllm.utils import FlexibleArgumentParser


 # Aria
-def run_aria(question: str, modality: str):
+def run_aria(questions: list[str], modality: str):
     assert modality == "image"
     model_name = "rhymes-ai/Aria"

@@ -32,41 +32,42 @@ def run_aria(question: str, modality: str):
               dtype="bfloat16",
               disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)

-    prompt = (f"<|im_start|>user\n<fim_prefix><|img|><fim_suffix>{question}"
+    prompts = [(f"<|im_start|>user\n<fim_prefix><|img|><fim_suffix>{question}"
                "<|im_end|>\n<|im_start|>assistant\n")
+               for question in questions]

     stop_token_ids = [93532, 93653, 944, 93421, 1019, 93653, 93519]
-    return llm, prompt, stop_token_ids
+    return llm, prompts, stop_token_ids


 # BLIP-2
-def run_blip2(question: str, modality: str):
+def run_blip2(questions: list[str], modality: str):
     assert modality == "image"

     # BLIP-2 prompt format is inaccurate on HuggingFace model repository.
     # See https://huggingface.co/Salesforce/blip2-opt-2.7b/discussions/15#64ff02f3f8cf9e4f5b038262 #noqa
-    prompt = f"Question: {question} Answer:"
+    prompts = [f"Question: {question} Answer:" for question in questions]
     llm = LLM(model="Salesforce/blip2-opt-2.7b",
               disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
     stop_token_ids = None
-    return llm, prompt, stop_token_ids
+    return llm, prompts, stop_token_ids


 # Chameleon
-def run_chameleon(question: str, modality: str):
+def run_chameleon(questions: list[str], modality: str):
     assert modality == "image"

-    prompt = f"{question}<image>"
+    prompts = [f"{question}<image>" for question in questions]
     llm = LLM(model="facebook/chameleon-7b",
               max_model_len=4096,
               max_num_seqs=2,
               disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
     stop_token_ids = None
-    return llm, prompt, stop_token_ids
+    return llm, prompts, stop_token_ids


 # Deepseek-VL2
-def run_deepseek_vl2(question: str, modality: str):
+def run_deepseek_vl2(questions: list[str], modality: str):
     assert modality == "image"

     model_name = "deepseek-ai/deepseek-vl2-tiny"
@@ -77,9 +78,12 @@ def run_deepseek_vl2(question: str, modality: str):
               disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
               hf_overrides={"architectures": ["DeepseekVLV2ForCausalLM"]})

-    prompt = f"<|User|>: <image>\n{question}\n\n<|Assistant|>:"
+    prompts = [
+        f"<|User|>: <image>\n{question}\n\n<|Assistant|>:"
+        for question in questions
+    ]
     stop_token_ids = None
-    return llm, prompt, stop_token_ids
+    return llm, prompts, stop_token_ids


 # Florence2
@@ -99,20 +103,20 @@ def run_florence2(question: str, modality: str):


 # Fuyu
-def run_fuyu(question: str, modality: str):
+def run_fuyu(questions: list[str], modality: str):
     assert modality == "image"

-    prompt = f"{question}\n"
+    prompts = [f"{question}\n" for question in questions]
     llm = LLM(model="adept/fuyu-8b",
               max_model_len=2048,
               max_num_seqs=2,
               disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
     stop_token_ids = None
-    return llm, prompt, stop_token_ids
+    return llm, prompts, stop_token_ids


 # GLM-4v
-def run_glm4v(question: str, modality: str):
+def run_glm4v(questions: list[str], modality: str):
     assert modality == "image"
     model_name = "THUDM/glm-4v-9b"

@@ -124,15 +128,17 @@ def run_glm4v(question: str, modality: str):
               hf_overrides={"architectures": ["GLM4VForCausalLM"]},
               disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)

-    prompt = f"<|user|>\n<|begin_of_image|><|endoftext|><|end_of_image|>\
-        {question}<|assistant|>"
+    prompts = [
+        f"<|user|>\n<|begin_of_image|><|endoftext|><|end_of_image|>\
+        {question}<|assistant|>" for question in questions
+    ]

     stop_token_ids = [151329, 151336, 151338]
-    return llm, prompt, stop_token_ids
+    return llm, prompts, stop_token_ids


 # H2OVL-Mississippi
-def run_h2ovl(question: str, modality: str):
+def run_h2ovl(questions: list[str], modality: str):
     assert modality == "image"

     model_name = "h2oai/h2ovl-mississippi-800m"
@@ -146,19 +152,24 @@ def run_h2ovl(question: str, modality: str):

     tokenizer = AutoTokenizer.from_pretrained(model_name,
                                               trust_remote_code=True)
-    messages = [{'role': 'user', 'content': f"<image>\n{question}"}]
-    prompt = tokenizer.apply_chat_template(messages,
+    prompts = [
+        tokenizer.apply_chat_template([{
+            'role': 'user',
+            'content': f"<image>\n{question}"
+        }],
                                       tokenize=False,
                                       add_generation_prompt=True)
+        for question in questions
+    ]

     # Stop tokens for H2OVL-Mississippi
     # https://huggingface.co/h2oai/h2ovl-mississippi-800m
     stop_token_ids = [tokenizer.eos_token_id]
-    return llm, prompt, stop_token_ids
+    return llm, prompts, stop_token_ids


 # Idefics3-8B-Llama3
-def run_idefics3(question: str, modality: str):
+def run_idefics3(questions: list[str], modality: str):
     assert modality == "image"
     model_name = "HuggingFaceM4/Idefics3-8B-Llama3"

@@ -176,15 +187,15 @@ def run_idefics3(question: str, modality: str):
         },
         disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
     )
-    prompt = (
+    prompts = [(
         f"<|begin_of_text|>User:<image>{question}<end_of_utterance>\nAssistant:"
-    )
+    ) for question in questions]
     stop_token_ids = None
-    return llm, prompt, stop_token_ids
+    return llm, prompts, stop_token_ids


 # InternVL
-def run_internvl(question: str, modality: str):
+def run_internvl(questions: list[str], modality: str):
     assert modality == "image"

     model_name = "OpenGVLab/InternVL2-2B"
@@ -198,10 +209,15 @@ def run_internvl(question: str, modality: str):

     tokenizer = AutoTokenizer.from_pretrained(model_name,
                                               trust_remote_code=True)
-    messages = [{'role': 'user', 'content': f"<image>\n{question}"}]
-    prompt = tokenizer.apply_chat_template(messages,
+    prompts = [
+        tokenizer.apply_chat_template([{
+            'role': 'user',
+            'content': f"<image>\n{question}"
+        }],
                                       tokenize=False,
                                       add_generation_prompt=True)
+        for question in questions
+    ]

     # Stop tokens for InternVL
     # models variants may have different stop tokens
@@ -209,71 +225,82 @@ def run_internvl(question: str, modality: str):
     # https://huggingface.co/OpenGVLab/InternVL2-2B/blob/main/conversation.py
     stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|end|>"]
     stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]
-    return llm, prompt, stop_token_ids
+    return llm, prompts, stop_token_ids


 # LLaVA-1.5
-def run_llava(question: str, modality: str):
+def run_llava(questions: list[str], modality: str):
     assert modality == "image"

-    prompt = f"USER: <image>\n{question}\nASSISTANT:"
+    prompts = [
+        f"USER: <image>\n{question}\nASSISTANT:" for question in questions
+    ]

     llm = LLM(model="llava-hf/llava-1.5-7b-hf",
               max_model_len=4096,
               disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
     stop_token_ids = None
-    return llm, prompt, stop_token_ids
+    return llm, prompts, stop_token_ids


 # LLaVA-1.6/LLaVA-NeXT
-def run_llava_next(question: str, modality: str):
+def run_llava_next(questions: list[str], modality: str):
     assert modality == "image"

-    prompt = f"[INST] <image>\n{question} [/INST]"
+    prompts = [f"[INST] <image>\n{question} [/INST]" for question in questions]
     llm = LLM(model="llava-hf/llava-v1.6-mistral-7b-hf",
               max_model_len=8192,
               disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
     stop_token_ids = None
-    return llm, prompt, stop_token_ids
+    return llm, prompts, stop_token_ids


 # LlaVA-NeXT-Video
 # Currently only support for video input
-def run_llava_next_video(question: str, modality: str):
+def run_llava_next_video(questions: list[str], modality: str):
     assert modality == "video"

-    prompt = f"USER: <video>\n{question} ASSISTANT:"
+    prompts = [
+        f"USER: <video>\n{question} ASSISTANT:" for question in questions
+    ]
     llm = LLM(model="llava-hf/LLaVA-NeXT-Video-7B-hf",
               max_model_len=8192,
               disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
     stop_token_ids = None
-    return llm, prompt, stop_token_ids
+    return llm, prompts, stop_token_ids


 # LLaVA-OneVision
-def run_llava_onevision(question: str, modality: str):
+def run_llava_onevision(questions: list[str], modality: str):

     if modality == "video":
-        prompt = f"<|im_start|>user <video>\n{question}<|im_end|> \
-        <|im_start|>assistant\n"
+        prompts = [
+            f"<|im_start|>user <video>\n{question}<|im_end|> \
+        <|im_start|>assistant\n" for question in questions
+        ]

     elif modality == "image":
-        prompt = f"<|im_start|>user <image>\n{question}<|im_end|> \
-        <|im_start|>assistant\n"
+        prompts = [
+            f"<|im_start|>user <image>\n{question}<|im_end|> \
+        <|im_start|>assistant\n" for question in questions
+        ]

     llm = LLM(model="llava-hf/llava-onevision-qwen2-7b-ov-hf",
               max_model_len=16384,
               disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
     stop_token_ids = None
-    return llm, prompt, stop_token_ids
+    return llm, prompts, stop_token_ids


 # Mantis
-def run_mantis(question: str, modality: str):
+def run_mantis(questions: list[str], modality: str):
     assert modality == "image"

     llama3_template = '<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n' # noqa: E501
-    prompt = llama3_template.format(f"{question}\n<image>")
+    prompts = [
+        llama3_template.format(f"{question}\n<image>")
+        for question in questions
+    ]

     llm = LLM(
         model="TIGER-Lab/Mantis-8B-siglip-llama3",
@@ -282,11 +309,11 @@ def run_mantis(question: str, modality: str):
         disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
     )
     stop_token_ids = [128009]
-    return llm, prompt, stop_token_ids
+    return llm, prompts, stop_token_ids


 # MiniCPM-V
-def run_minicpmv_base(question: str, modality: str, model_name):
+def run_minicpmv_base(questions: list[str], modality: str, model_name):
     assert modality in ["image", "video"]
     # If you want to use `MiniCPM-o-2_6` with audio inputs, check `audio_language.py` # noqa

@@ -333,26 +360,28 @@ def run_minicpmv_base(question: str, modality: str, model_name):
         "video": "(<video>./</video>)",
     }

-    messages = [{
+    prompts = [
+        tokenizer.apply_chat_template(
+            [{
                 'role': 'user',
-        'content': f'{modality_placeholder[modality]}\n{question}'
-    }]
-    prompt = tokenizer.apply_chat_template(messages,
+                'content': f"{modality_placeholder[modality]}\n{question}"
+            }],
             tokenize=False,
-                                           add_generation_prompt=True)
-    return llm, prompt, stop_token_ids
+            add_generation_prompt=True) for question in questions
+    ]
+    return llm, prompts, stop_token_ids


-def run_minicpmo(question: str, modality: str):
-    return run_minicpmv_base(question, modality, "openbmb/MiniCPM-o-2_6")
+def run_minicpmo(questions: list[str], modality: str):
+    return run_minicpmv_base(questions, modality, "openbmb/MiniCPM-o-2_6")


-def run_minicpmv(question: str, modality: str):
-    return run_minicpmv_base(question, modality, "openbmb/MiniCPM-V-2_6")
+def run_minicpmv(questions: list[str], modality: str):
+    return run_minicpmv_base(questions, modality, "openbmb/MiniCPM-V-2_6")


 # LLama 3.2
-def run_mllama(question: str, modality: str):
+def run_mllama(questions: list[str], modality: str):
     assert modality == "image"

     model_name = "meta-llama/Llama-3.2-11B-Vision-Instruct"
@@ -379,16 +408,16 @@ def run_mllama(question: str, modality: str):
             "type": "text",
             "text": f"{question}"
         }]
-    }]
-    prompt = tokenizer.apply_chat_template(messages,
+    } for question in questions]
+    prompts = tokenizer.apply_chat_template(messages,
                                             add_generation_prompt=True,
                                             tokenize=False)
     stop_token_ids = None
-    return llm, prompt, stop_token_ids
+    return llm, prompts, stop_token_ids


 # Molmo
-def run_molmo(question, modality):
+def run_molmo(questions: list[str], modality: str):
     assert modality == "image"

     model_name = "allenai/Molmo-7B-D-0924"
@@ -400,13 +429,16 @@ def run_molmo(question, modality):
         disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
     )

-    prompt = question
+    prompts = [
+        f"<|im_start|>user <image>\n{question}<|im_end|> \
+        <|im_start|>assistant\n" for question in questions
+    ]
     stop_token_ids = None
-    return llm, prompt, stop_token_ids
+    return llm, prompts, stop_token_ids


 # NVLM-D
-def run_nvlm_d(question: str, modality: str):
+def run_nvlm_d(questions: list[str], modality: str):
     assert modality == "image"

     model_name = "nvidia/NVLM-D-72B"
@@ -422,12 +454,15 @@ def run_nvlm_d(question: str, modality: str):

     tokenizer = AutoTokenizer.from_pretrained(model_name,
                                               trust_remote_code=True)
-    messages = [{'role': 'user', 'content': f"<image>\n{question}"}]
-    prompt = tokenizer.apply_chat_template(messages,
+    messages = [{
+        'role': 'user',
+        'content': f"<image>\n{question}"
+    } for question in questions]
+    prompts = tokenizer.apply_chat_template(messages,
                                             tokenize=False,
                                             add_generation_prompt=True)
     stop_token_ids = None
-    return llm, prompt, stop_token_ids
+    return llm, prompts, stop_token_ids


 # PaliGemma
@@ -435,7 +470,7 @@ def run_paligemma(question: str, modality: str):
     assert modality == "image"

     # PaliGemma has special prompt format for VQA
-    prompt = "caption en"
+    prompt = ["caption en"]
     llm = LLM(model="google/paligemma-3b-mix-224",
               disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
     stop_token_ids = None
@@ -447,7 +482,7 @@ def run_paligemma2(question: str, modality: str):
     assert modality == "image"

     # PaliGemma 2 has special prompt format for VQA
-    prompt = "caption en"
+    prompt = ["caption en"]
     llm = LLM(model="google/paligemma2-3b-ft-docci-448",
               disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
     stop_token_ids = None
@@ -455,10 +490,13 @@ def run_paligemma2(question: str, modality: str):


 # Phi-3-Vision
-def run_phi3v(question: str, modality: str):
+def run_phi3v(questions: list[str], modality: str):
     assert modality == "image"

-    prompt = f"<|user|>\n<|image_1|>\n{question}<|end|>\n<|assistant|>\n"
+    prompts = [
+        f"<|user|>\n<|image_1|>\n{question}<|end|>\n<|assistant|>\n"
+        for question in questions
+    ]

     # num_crops is an override kwarg to the multimodal image processor;
     # For some models, e.g., Phi-3.5-vision-instruct, it is recommended
@@ -482,11 +520,11 @@ def run_phi3v(question: str, modality: str):
         disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
     )
     stop_token_ids = None
-    return llm, prompt, stop_token_ids
+    return llm, prompts, stop_token_ids


 # Pixtral HF-format
-def run_pixtral_hf(question: str, modality: str):
+def run_pixtral_hf(questions: list[str], modality: str):
     assert modality == "image"

     model_name = "mistral-community/pixtral-12b"
@@ -499,13 +537,13 @@ def run_pixtral_hf(question: str, modality: str):
         disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
     )

-    prompt = f"<s>[INST]{question}\n[IMG][/INST]"
+    prompts = [f"<s>[INST]{question}\n[IMG][/INST]" for question in questions]
     stop_token_ids = None
-    return llm, prompt, stop_token_ids
+    return llm, prompts, stop_token_ids


 # Qwen
-def run_qwen_vl(question: str, modality: str):
+def run_qwen_vl(questions: list[str], modality: str):
     assert modality == "image"

     llm = LLM(
@@ -517,13 +555,13 @@ def run_qwen_vl(question: str, modality: str):
         disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
     )

-    prompt = f"{question}Picture 1: <img></img>\n"
+    prompts = [f"{question}Picture 1: <img></img>\n" for question in questions]
     stop_token_ids = None
-    return llm, prompt, stop_token_ids
+    return llm, prompts, stop_token_ids


 # Qwen2-VL
-def run_qwen2_vl(question: str, modality: str):
+def run_qwen2_vl(questions: list[str], modality: str):

     model_name = "Qwen/Qwen2-VL-7B-Instruct"

@@ -544,16 +582,18 @@ def run_qwen2_vl(question: str, modality: str):
     elif modality == "video":
         placeholder = "<|video_pad|>"

-    prompt = ("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
+    prompts = [
+        ("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
         f"<|im_start|>user\n<|vision_start|>{placeholder}<|vision_end|>"
         f"{question}<|im_end|>\n"
-              "<|im_start|>assistant\n")
+         "<|im_start|>assistant\n") for question in questions
+    ]
     stop_token_ids = None
-    return llm, prompt, stop_token_ids
+    return llm, prompts, stop_token_ids


 # Qwen2.5-VL
-def run_qwen2_5_vl(question: str, modality: str):
+def run_qwen2_5_vl(questions: list[str], modality: str):

     model_name = "Qwen/Qwen2.5-VL-3B-Instruct"

@@ -574,12 +614,14 @@ def run_qwen2_5_vl(question: str, modality: str):
     elif modality == "video":
         placeholder = "<|video_pad|>"

-    prompt = ("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
+    prompts = [
+        ("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
         f"<|im_start|>user\n<|vision_start|>{placeholder}<|vision_end|>"
         f"{question}<|im_end|>\n"
-              "<|im_start|>assistant\n")
+         "<|im_start|>assistant\n") for question in questions
+    ]
     stop_token_ids = None
-    return llm, prompt, stop_token_ids
+    return llm, prompts, stop_token_ids


 model_example_map = {
@@ -624,29 +666,35 @@ def get_multi_modal_input(args):
         # Input image and question
         image = ImageAsset("cherry_blossom") \
             .pil_image.convert("RGB")
-        img_question = "What is the content of this image?"
+        img_questions = [
+            "What is the content of this image?",
+            "Describe the content of this image in detail.",
+            "What's in the image?",
+            "Where is this image taken?",
+        ]

         return {
             "data": image,
-            "question": img_question,
+            "questions": img_questions,
         }

     if args.modality == "video":
         # Input video and question
         video = VideoAsset(name="sample_demo_1.mp4",
                            num_frames=args.num_frames).np_ndarrays
-        vid_question = "Why is this video funny?"
+        vid_questions = ["Why is this video funny?"]

         return {
             "data": video,
-            "question": vid_question,
+            "questions": vid_questions,
         }

     msg = f"Modality {args.modality} is not supported."
     raise ValueError(msg)


-def apply_image_repeat(image_repeat_prob, num_prompts, data, prompt, modality):
+def apply_image_repeat(image_repeat_prob, num_prompts, data,
+                       prompts: list[str], modality):
     """Repeats images with provided probability of "image_repeat_prob".
     Used to simulate hit/miss for the MM preprocessor cache.
     """
@@ -666,7 +714,7 @@ def apply_image_repeat(image_repeat_prob, num_prompts, data, prompt, modality):
                 cur_image.putpixel((0, 0), new_val)

         inputs.append({
-            "prompt": prompt,
+            "prompt": prompts[i % len(prompts)],
             "multi_modal_data": {
                 modality: cur_image
             }
@@ -683,9 +731,14 @@ def main(args):
     modality = args.modality
     mm_input = get_multi_modal_input(args)
     data = mm_input["data"]
-    question = mm_input["question"]
+    questions = mm_input["questions"]

-    llm, prompt, stop_token_ids = model_example_map[model](question, modality)
+    llm, prompts, stop_token_ids = model_example_map[model](questions,
+                                                            modality)
+    # Don't want to check the flag multiple times, so just hijack `prompts`.
+    prompts = prompts if args.use_different_prompt_per_request else [
+        prompts[0]
+    ]

     # We set temperature to 0.2 so that outputs can be different
     # even when all prompts are identical when running batch inference.
@@ -697,27 +750,26 @@ def main(args):
     if args.num_prompts == 1:
         # Single inference
         inputs = {
-            "prompt": prompt,
+            "prompt": prompts[0],
             "multi_modal_data": {
                 modality: data
             },
         }

     else:
         # Batch inference
         if args.image_repeat_prob is not None:
             # Repeat images with specified probability of "image_repeat_prob"
             inputs = apply_image_repeat(args.image_repeat_prob,
-                                        args.num_prompts, data, prompt,
+                                        args.num_prompts, data, prompts,
                                         modality)
         else:
             # Use the same image for all prompts
             inputs = [{
-                "prompt": prompt,
+                "prompt": prompts[i % len(prompts)],
                 "multi_modal_data": {
                     modality: data
                 },
-            } for _ in range(args.num_prompts)]
+            } for i in range(args.num_prompts)]

     if args.time_generate:
         import time
@@ -775,5 +827,11 @@ if __name__ == "__main__":
         action='store_true',
         help='If True, then print the total generate() call time')

+    parser.add_argument(
+        '--use-different-prompt-per-request',
+        action='store_true',
+        help='If True, then use different prompt (with the same multi-modal '
+        'data) for each request.')
+
     args = parser.parse_args()
     main(args)
@@ -602,7 +602,9 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):

         return self.multi_modal_projector(image_outputs, image_attn_mask)

-    def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
+    def get_multimodal_embeddings(
+            self, **kwargs
+    ) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
             return None
@@ -628,7 +628,9 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):

         return self.language_projection(query_output)

-    def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
+    def get_multimodal_embeddings(
+            self, **kwargs
+    ) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
             return None
@@ -986,7 +986,9 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
             data=self._validate_pixel_values(pixel_values),
         )

-    def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
+    def get_multimodal_embeddings(
+            self, **kwargs
+    ) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
             return None
@@ -606,7 +606,9 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
         return self._pixel_values_to_embedding(
             pixel_values=pixel_values, images_spatial_crop=images_spatial_crop)

-    def get_multimodal_embeddings(self, **kwargs: object) -> torch.Tensor:
+    def get_multimodal_embeddings(
+            self, **kwargs: object
+    ) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
             return None
@@ -1037,7 +1037,9 @@ class Florence2ForConditionalGeneration(nn.Module, SupportsMultiModal):
         pixel_values = image_input["data"]
         return self._encode_image(pixel_values)

-    def get_multimodal_embeddings(self, **kwargs: object) -> torch.Tensor:
+    def get_multimodal_embeddings(
+            self, **kwargs: object
+    ) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
             return None
@@ -18,7 +18,7 @@
 """ PyTorch Fuyu model."""
 import math
 from collections.abc import Iterable, Mapping, Sequence
-from typing import List, Literal, Optional, Set, Tuple, TypedDict
+from typing import List, Literal, Optional, Set, Tuple, TypedDict, Union

 import torch
 import torch.nn as nn
@@ -327,7 +327,9 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
             image_patches_flat)
         return vision_embeddings_flat.split(patches_per_image, dim=0)

-    def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
+    def get_multimodal_embeddings(
+            self, **kwargs
+    ) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
             return None
@@ -595,7 +595,9 @@ class GLM4VForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP,

         return self.transformer.vision(pixel_values)

-    def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
+    def get_multimodal_embeddings(
+            self, **kwargs
+    ) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
             return None
@@ -617,7 +617,9 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
         self.logits_processor = LogitsProcessor(config.text_config.vocab_size)
         self.sampler = get_sampler()

-    def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
+    def get_multimodal_embeddings(
+            self, **kwargs
+    ) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
         image_input = self.model._parse_and_validate_image_input(**kwargs)
         if image_input is None:
             return None
@@ -4,6 +4,7 @@ from typing import (TYPE_CHECKING, ClassVar, Dict, List, Literal, Optional,
                     Protocol, Type, Union, overload, runtime_checkable)

 import torch
+from torch import Tensor
 from typing_extensions import TypeIs, TypeVar

 from vllm.logger import init_logger
@@ -15,12 +16,11 @@ from .interfaces_base import is_pooling_model

 if TYPE_CHECKING:
     from vllm.attention import AttentionMetadata
-    from vllm.multimodal.inputs import NestedTensors  # noqa: F401
     from vllm.sequence import IntermediateTensors

 logger = init_logger(__name__)

-T = TypeVar("T", default="NestedTensors")
+T = TypeVar("T", default=Union[list[Tensor], Tensor, tuple[Tensor, ...]])


 @runtime_checkable
@@ -36,7 +36,7 @@ class SupportsMultiModal(Protocol):
     MRO of your model class.
     """

-    def get_multimodal_embeddings(self, **kwargs) -> Optional[T]:
+    def get_multimodal_embeddings(self, **kwargs) -> T:
         """
         Returns multimodal embeddings generated from multimodal kwargs
         to be merged with text embeddings.
@@ -59,18 +59,18 @@ class SupportsMultiModal(Protocol):
     @overload
     def get_input_embeddings(
         self,
-        input_ids: torch.Tensor,
+        input_ids: Tensor,
         multimodal_embeddings: Optional[T] = None,
         attn_metadata: Optional["AttentionMetadata"] = None,
-    ) -> torch.Tensor:
+    ) -> Tensor:
         ...

     @overload
     def get_input_embeddings(
         self,
-        input_ids: torch.Tensor,
+        input_ids: Tensor,
         multimodal_embeddings: Optional[T] = None,
-    ) -> torch.Tensor:
+    ) -> Tensor:
         """
         Returns the input embeddings merged from the text embeddings from
         input_ids and the multimodal embeddings generated from multimodal
@@ -210,7 +210,7 @@ class SupportsPP(Protocol):
         self,
         *,
         intermediate_tensors: Optional["IntermediateTensors"],
-    ) -> Union[torch.Tensor, "IntermediateTensors"]:
+    ) -> Union[Tensor, "IntermediateTensors"]:
         """
         Accept :class:`IntermediateTensors` when PP rank > 0.

@@ -237,7 +237,7 @@ class _SupportsPPType(Protocol):
         self,
         *,
         intermediate_tensors: Optional["IntermediateTensors"],
-    ) -> Union[torch.Tensor, "IntermediateTensors"]:
+    ) -> Union[Tensor, "IntermediateTensors"]:
         ...


@@ -904,7 +904,9 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
         else:
             self.visual_token_mask = None

-    def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
+    def get_multimodal_embeddings(
+            self, **kwargs
+    ) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
             return None
@@ -635,7 +635,9 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
         image_features = self._process_image_pixels(image_input)
         return self.multi_modal_projector(image_features)

-    def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
+    def get_multimodal_embeddings(
+            self, **kwargs
+    ) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
             return None
@@ -479,7 +479,9 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
             for i, patch_features_batch in enumerate(patch_embeddings)
         ]

-    def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
+    def get_multimodal_embeddings(
+            self, **kwargs
+    ) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
             return None
@@ -420,7 +420,9 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
             raise ValueError(
                 f"Unsupported type of video input {type(video_pixels)}")

-    def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
+    def get_multimodal_embeddings(
+            self, **kwargs
+    ) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
         video_input = self._parse_and_validate_video_input(**kwargs)
         if video_input is None:
             return None
@@ -50,7 +50,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         PromptInsertion, PromptUpdate)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
 from vllm.sequence import IntermediateTensors
-from vllm.utils import JSONTree, json_map_leaves
+from vllm.utils import JSONTree, flatten_2d_lists, json_map_leaves

 from .interfaces import (SupportsLoRA, SupportsMultiModal, SupportsPP,
                          SupportsQuant)
@@ -1576,14 +1576,16 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,

         return embeds_in_batch

-    def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
+    def get_multimodal_embeddings(
+            self, **kwargs
+    ) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
             return None

         image_features = self._process_image_input(image_input)

-        return [
+        nested_embeds = [
             self._get_mm_embeds(*args) for args in zip(
                 image_features,
                 image_input["feat_is_patch"],
@@ -1591,6 +1593,7 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
                 image_input["embed_is_patch"],
             )
         ]
+        return flatten_2d_lists(nested_embeds)

     def get_input_embeddings(
         self,
@@ -263,7 +263,9 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,

         return self.multi_modal_projector(image_features)

-    def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
+    def get_multimodal_embeddings(
+            self, **kwargs
+    ) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
             return None
@@ -648,7 +648,9 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,

         return image_embeds

-    def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
+    def get_multimodal_embeddings(
+            self, **kwargs
+    ) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
             return None
@@ -220,7 +220,9 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,

         return get_sampler()

-    def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
+    def get_multimodal_embeddings(
+            self, **kwargs
+    ) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
         image_input, image_tokens = self._parse_and_validate_image_input(
             **kwargs)
         if image_input is None:
@@ -356,7 +356,9 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal,
         return torch.split(masked_audio_features,
                            audio_output_lengths.flatten().tolist())

-    def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
+    def get_multimodal_embeddings(
+            self, **kwargs
+    ) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
         audio_input = self._parse_and_validate_audio_input(**kwargs)
         if audio_input is None:
             return None
@@ -740,7 +740,9 @@ class QwenVLForConditionalGeneration(QWenBaseModel, SupportsPP, SupportsLoRA,

         return self.transformer.visual(image_input["data"])

-    def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
+    def get_multimodal_embeddings(
+            self, **kwargs
+    ) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
             return None
@@ -476,7 +476,9 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):

         return result

-    def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
+    def get_multimodal_embeddings(
+            self, **kwargs
+    ) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
         audio_input = self._parse_and_validate_audio_input(**kwargs)
         if audio_input is None:
             return None
@@ -692,7 +692,9 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
         )
         return decoder_outputs

-    def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
+    def get_multimodal_embeddings(
+            self, **kwargs
+    ) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
         # TODO: This method does not obey the interface for SupportsMultiModal.
         # Refactor this once encoder/decoder support is implemented in V1.
         audio_input = self._parse_and_validate_audio_input(**kwargs)
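As a usage note, the example-script changes above build a list of prompts and cycle through it when batching requests ("prompt": prompts[i % len(prompts)]). A standalone sketch of that pattern, simplified and with a plain placeholder standing in for the shared image, looks like this:

prompts = [
    "What is the content of this image?",
    "Describe the content of this image in detail.",
]
num_prompts = 4
image = object()  # placeholder for the shared PIL image reused by every request
inputs = [{
    "prompt": prompts[i % len(prompts)],   # cycle the questions across requests
    "multi_modal_data": {"image": image},  # the same image is attached each time
} for i in range(num_prompts)]
assert [item["prompt"] for item in inputs[:2]] == prompts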