[V1][Molmo] Fix get_multimodal_embeddings() in molmo.py (#14161)

lkchen 2025-03-04 07:43:59 -08:00 committed by GitHub
parent c8525f06fc
commit b3cf368d79
22 changed files with 249 additions and 150 deletions

View File

@ -21,7 +21,7 @@ from vllm.utils import FlexibleArgumentParser
# Aria
def run_aria(question: str, modality: str):
def run_aria(questions: list[str], modality: str):
assert modality == "image"
model_name = "rhymes-ai/Aria"
@ -32,41 +32,42 @@ def run_aria(question: str, modality: str):
dtype="bfloat16",
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
prompt = (f"<|im_start|>user\n<fim_prefix><|img|><fim_suffix>{question}"
prompts = [(f"<|im_start|>user\n<fim_prefix><|img|><fim_suffix>{question}"
"<|im_end|>\n<|im_start|>assistant\n")
for question in questions]
stop_token_ids = [93532, 93653, 944, 93421, 1019, 93653, 93519]
return llm, prompt, stop_token_ids
return llm, prompts, stop_token_ids
# BLIP-2
def run_blip2(question: str, modality: str):
def run_blip2(questions: list[str], modality: str):
assert modality == "image"
# BLIP-2 prompt format is inaccurate on HuggingFace model repository.
# See https://huggingface.co/Salesforce/blip2-opt-2.7b/discussions/15#64ff02f3f8cf9e4f5b038262 #noqa
prompt = f"Question: {question} Answer:"
prompts = [f"Question: {question} Answer:" for question in questions]
llm = LLM(model="Salesforce/blip2-opt-2.7b",
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
stop_token_ids = None
return llm, prompt, stop_token_ids
return llm, prompts, stop_token_ids
# Chameleon
def run_chameleon(question: str, modality: str):
def run_chameleon(questions: list[str], modality: str):
assert modality == "image"
prompt = f"{question}<image>"
prompts = [f"{question}<image>" for question in questions]
llm = LLM(model="facebook/chameleon-7b",
max_model_len=4096,
max_num_seqs=2,
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
stop_token_ids = None
return llm, prompt, stop_token_ids
return llm, prompts, stop_token_ids
# Deepseek-VL2
def run_deepseek_vl2(question: str, modality: str):
def run_deepseek_vl2(questions: list[str], modality: str):
assert modality == "image"
model_name = "deepseek-ai/deepseek-vl2-tiny"
@ -77,9 +78,12 @@ def run_deepseek_vl2(question: str, modality: str):
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
hf_overrides={"architectures": ["DeepseekVLV2ForCausalLM"]})
prompt = f"<|User|>: <image>\n{question}\n\n<|Assistant|>:"
prompts = [
f"<|User|>: <image>\n{question}\n\n<|Assistant|>:"
for question in questions
]
stop_token_ids = None
return llm, prompt, stop_token_ids
return llm, prompts, stop_token_ids
# Florence2
@ -99,20 +103,20 @@ def run_florence2(question: str, modality: str):
# Fuyu
def run_fuyu(question: str, modality: str):
def run_fuyu(questions: list[str], modality: str):
assert modality == "image"
prompt = f"{question}\n"
prompts = [f"{question}\n" for question in questions]
llm = LLM(model="adept/fuyu-8b",
max_model_len=2048,
max_num_seqs=2,
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
stop_token_ids = None
return llm, prompt, stop_token_ids
return llm, prompts, stop_token_ids
# GLM-4v
def run_glm4v(question: str, modality: str):
def run_glm4v(questions: list[str], modality: str):
assert modality == "image"
model_name = "THUDM/glm-4v-9b"
@ -124,15 +128,17 @@ def run_glm4v(question: str, modality: str):
hf_overrides={"architectures": ["GLM4VForCausalLM"]},
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
prompt = f"<|user|>\n<|begin_of_image|><|endoftext|><|end_of_image|>\
{question}<|assistant|>"
prompts = [
f"<|user|>\n<|begin_of_image|><|endoftext|><|end_of_image|>\
{question}<|assistant|>" for question in questions
]
stop_token_ids = [151329, 151336, 151338]
return llm, prompt, stop_token_ids
return llm, prompts, stop_token_ids
# H2OVL-Mississippi
def run_h2ovl(question: str, modality: str):
def run_h2ovl(questions: list[str], modality: str):
assert modality == "image"
model_name = "h2oai/h2ovl-mississippi-800m"
@ -146,19 +152,24 @@ def run_h2ovl(question: str, modality: str):
tokenizer = AutoTokenizer.from_pretrained(model_name,
trust_remote_code=True)
messages = [{'role': 'user', 'content': f"<image>\n{question}"}]
prompt = tokenizer.apply_chat_template(messages,
prompts = [
tokenizer.apply_chat_template([{
'role': 'user',
'content': f"<image>\n{question}"
}],
tokenize=False,
add_generation_prompt=True)
for question in questions
]
# Stop tokens for H2OVL-Mississippi
# https://huggingface.co/h2oai/h2ovl-mississippi-800m
stop_token_ids = [tokenizer.eos_token_id]
return llm, prompt, stop_token_ids
return llm, prompts, stop_token_ids
# Idefics3-8B-Llama3
def run_idefics3(question: str, modality: str):
def run_idefics3(questions: list[str], modality: str):
assert modality == "image"
model_name = "HuggingFaceM4/Idefics3-8B-Llama3"
@ -176,15 +187,15 @@ def run_idefics3(question: str, modality: str):
},
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
)
prompt = (
prompts = [(
f"<|begin_of_text|>User:<image>{question}<end_of_utterance>\nAssistant:"
)
) for question in questions]
stop_token_ids = None
return llm, prompt, stop_token_ids
return llm, prompts, stop_token_ids
# InternVL
def run_internvl(question: str, modality: str):
def run_internvl(questions: list[str], modality: str):
assert modality == "image"
model_name = "OpenGVLab/InternVL2-2B"
@ -198,10 +209,15 @@ def run_internvl(question: str, modality: str):
tokenizer = AutoTokenizer.from_pretrained(model_name,
trust_remote_code=True)
messages = [{'role': 'user', 'content': f"<image>\n{question}"}]
prompt = tokenizer.apply_chat_template(messages,
prompts = [
tokenizer.apply_chat_template([{
'role': 'user',
'content': f"<image>\n{question}"
}],
tokenize=False,
add_generation_prompt=True)
for question in questions
]
# Stop tokens for InternVL
# models variants may have different stop tokens
@ -209,71 +225,82 @@ def run_internvl(question: str, modality: str):
# https://huggingface.co/OpenGVLab/InternVL2-2B/blob/main/conversation.py
stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|end|>"]
stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]
return llm, prompt, stop_token_ids
return llm, prompts, stop_token_ids
# LLaVA-1.5
def run_llava(question: str, modality: str):
def run_llava(questions: list[str], modality: str):
assert modality == "image"
prompt = f"USER: <image>\n{question}\nASSISTANT:"
prompts = [
f"USER: <image>\n{question}\nASSISTANT:" for question in questions
]
llm = LLM(model="llava-hf/llava-1.5-7b-hf",
max_model_len=4096,
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
stop_token_ids = None
return llm, prompt, stop_token_ids
return llm, prompts, stop_token_ids
# LLaVA-1.6/LLaVA-NeXT
def run_llava_next(question: str, modality: str):
def run_llava_next(questions: list[str], modality: str):
assert modality == "image"
prompt = f"[INST] <image>\n{question} [/INST]"
prompts = [f"[INST] <image>\n{question} [/INST]" for question in questions]
llm = LLM(model="llava-hf/llava-v1.6-mistral-7b-hf",
max_model_len=8192,
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
stop_token_ids = None
return llm, prompt, stop_token_ids
return llm, prompts, stop_token_ids
# LlaVA-NeXT-Video
# Currently only support for video input
def run_llava_next_video(question: str, modality: str):
def run_llava_next_video(questions: list[str], modality: str):
assert modality == "video"
prompt = f"USER: <video>\n{question} ASSISTANT:"
prompts = [
f"USER: <video>\n{question} ASSISTANT:" for question in questions
]
llm = LLM(model="llava-hf/LLaVA-NeXT-Video-7B-hf",
max_model_len=8192,
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
stop_token_ids = None
return llm, prompt, stop_token_ids
return llm, prompts, stop_token_ids
# LLaVA-OneVision
def run_llava_onevision(question: str, modality: str):
def run_llava_onevision(questions: list[str], modality: str):
if modality == "video":
prompt = f"<|im_start|>user <video>\n{question}<|im_end|> \
<|im_start|>assistant\n"
prompts = [
f"<|im_start|>user <video>\n{question}<|im_end|> \
<|im_start|>assistant\n" for question in questions
]
elif modality == "image":
prompt = f"<|im_start|>user <image>\n{question}<|im_end|> \
<|im_start|>assistant\n"
prompts = [
f"<|im_start|>user <image>\n{question}<|im_end|> \
<|im_start|>assistant\n" for question in questions
]
llm = LLM(model="llava-hf/llava-onevision-qwen2-7b-ov-hf",
max_model_len=16384,
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
stop_token_ids = None
return llm, prompt, stop_token_ids
return llm, prompts, stop_token_ids
# Mantis
def run_mantis(question: str, modality: str):
def run_mantis(questions: list[str], modality: str):
assert modality == "image"
llama3_template = '<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n' # noqa: E501
prompt = llama3_template.format(f"{question}\n<image>")
prompts = [
llama3_template.format(f"{question}\n<image>")
for question in questions
]
llm = LLM(
model="TIGER-Lab/Mantis-8B-siglip-llama3",
@ -282,11 +309,11 @@ def run_mantis(question: str, modality: str):
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
)
stop_token_ids = [128009]
return llm, prompt, stop_token_ids
return llm, prompts, stop_token_ids
# MiniCPM-V
def run_minicpmv_base(question: str, modality: str, model_name):
def run_minicpmv_base(questions: list[str], modality: str, model_name):
assert modality in ["image", "video"]
# If you want to use `MiniCPM-o-2_6` with audio inputs, check `audio_language.py` # noqa
@ -333,26 +360,28 @@ def run_minicpmv_base(question: str, modality: str, model_name):
"video": "(<video>./</video>)",
}
messages = [{
'role': 'user',
'content': f'{modality_placeholder[modality]}\n{question}'
}]
prompt = tokenizer.apply_chat_template(messages,
tokenize=False,
add_generation_prompt=True)
return llm, prompt, stop_token_ids
prompts = [
tokenizer.apply_chat_template(
[{
'role': 'user',
'content': f"{modality_placeholder[modality]}\n{question}"
}],
tokenize=False,
add_generation_prompt=True) for question in questions
]
return llm, prompts, stop_token_ids
def run_minicpmo(question: str, modality: str):
return run_minicpmv_base(question, modality, "openbmb/MiniCPM-o-2_6")
def run_minicpmo(questions: list[str], modality: str):
return run_minicpmv_base(questions, modality, "openbmb/MiniCPM-o-2_6")
def run_minicpmv(question: str, modality: str):
return run_minicpmv_base(question, modality, "openbmb/MiniCPM-V-2_6")
def run_minicpmv(questions: list[str], modality: str):
return run_minicpmv_base(questions, modality, "openbmb/MiniCPM-V-2_6")
# LLama 3.2
def run_mllama(question: str, modality: str):
def run_mllama(questions: list[str], modality: str):
assert modality == "image"
model_name = "meta-llama/Llama-3.2-11B-Vision-Instruct"
@ -379,16 +408,16 @@ def run_mllama(question: str, modality: str):
"type": "text",
"text": f"{question}"
}]
}]
prompt = tokenizer.apply_chat_template(messages,
} for question in questions]
prompts = tokenizer.apply_chat_template(messages,
add_generation_prompt=True,
tokenize=False)
stop_token_ids = None
return llm, prompt, stop_token_ids
return llm, prompts, stop_token_ids
# Molmo
def run_molmo(question, modality):
def run_molmo(questions: list[str], modality: str):
assert modality == "image"
model_name = "allenai/Molmo-7B-D-0924"
@ -400,13 +429,16 @@ def run_molmo(question, modality):
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
)
prompt = question
prompts = [
f"<|im_start|>user <image>\n{question}<|im_end|> \
<|im_start|>assistant\n" for question in questions
]
stop_token_ids = None
return llm, prompt, stop_token_ids
return llm, prompts, stop_token_ids
# NVLM-D
def run_nvlm_d(question: str, modality: str):
def run_nvlm_d(questions: list[str], modality: str):
assert modality == "image"
model_name = "nvidia/NVLM-D-72B"
@ -422,12 +454,15 @@ def run_nvlm_d(question: str, modality: str):
tokenizer = AutoTokenizer.from_pretrained(model_name,
trust_remote_code=True)
messages = [{'role': 'user', 'content': f"<image>\n{question}"}]
prompt = tokenizer.apply_chat_template(messages,
messages = [{
'role': 'user',
'content': f"<image>\n{question}"
} for question in questions]
prompts = tokenizer.apply_chat_template(messages,
tokenize=False,
add_generation_prompt=True)
stop_token_ids = None
return llm, prompt, stop_token_ids
return llm, prompts, stop_token_ids
# PaliGemma
@ -435,7 +470,7 @@ def run_paligemma(question: str, modality: str):
assert modality == "image"
# PaliGemma has special prompt format for VQA
prompt = "caption en"
prompt = ["caption en"]
llm = LLM(model="google/paligemma-3b-mix-224",
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
stop_token_ids = None
@ -447,7 +482,7 @@ def run_paligemma2(question: str, modality: str):
assert modality == "image"
# PaliGemma 2 has special prompt format for VQA
prompt = "caption en"
prompt = ["caption en"]
llm = LLM(model="google/paligemma2-3b-ft-docci-448",
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
stop_token_ids = None
@ -455,10 +490,13 @@ def run_paligemma2(question: str, modality: str):
# Phi-3-Vision
def run_phi3v(question: str, modality: str):
def run_phi3v(questions: list[str], modality: str):
assert modality == "image"
prompt = f"<|user|>\n<|image_1|>\n{question}<|end|>\n<|assistant|>\n"
prompts = [
f"<|user|>\n<|image_1|>\n{question}<|end|>\n<|assistant|>\n"
for question in questions
]
# num_crops is an override kwarg to the multimodal image processor;
# For some models, e.g., Phi-3.5-vision-instruct, it is recommended
@ -482,11 +520,11 @@ def run_phi3v(question: str, modality: str):
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
)
stop_token_ids = None
return llm, prompt, stop_token_ids
return llm, prompts, stop_token_ids
# Pixtral HF-format
def run_pixtral_hf(question: str, modality: str):
def run_pixtral_hf(questions: list[str], modality: str):
assert modality == "image"
model_name = "mistral-community/pixtral-12b"
@ -499,13 +537,13 @@ def run_pixtral_hf(question: str, modality: str):
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
)
prompt = f"<s>[INST]{question}\n[IMG][/INST]"
prompts = [f"<s>[INST]{question}\n[IMG][/INST]" for question in questions]
stop_token_ids = None
return llm, prompt, stop_token_ids
return llm, prompts, stop_token_ids
# Qwen
def run_qwen_vl(question: str, modality: str):
def run_qwen_vl(questions: list[str], modality: str):
assert modality == "image"
llm = LLM(
@ -517,13 +555,13 @@ def run_qwen_vl(question: str, modality: str):
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
)
prompt = f"{question}Picture 1: <img></img>\n"
prompts = [f"{question}Picture 1: <img></img>\n" for question in questions]
stop_token_ids = None
return llm, prompt, stop_token_ids
return llm, prompts, stop_token_ids
# Qwen2-VL
def run_qwen2_vl(question: str, modality: str):
def run_qwen2_vl(questions: list[str], modality: str):
model_name = "Qwen/Qwen2-VL-7B-Instruct"
@ -544,16 +582,18 @@ def run_qwen2_vl(question: str, modality: str):
elif modality == "video":
placeholder = "<|video_pad|>"
prompt = ("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
prompts = [
("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
f"<|im_start|>user\n<|vision_start|>{placeholder}<|vision_end|>"
f"{question}<|im_end|>\n"
"<|im_start|>assistant\n")
"<|im_start|>assistant\n") for question in questions
]
stop_token_ids = None
return llm, prompt, stop_token_ids
return llm, prompts, stop_token_ids
# Qwen2.5-VL
def run_qwen2_5_vl(question: str, modality: str):
def run_qwen2_5_vl(questions: list[str], modality: str):
model_name = "Qwen/Qwen2.5-VL-3B-Instruct"
@ -574,12 +614,14 @@ def run_qwen2_5_vl(question: str, modality: str):
elif modality == "video":
placeholder = "<|video_pad|>"
prompt = ("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
prompts = [
("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
f"<|im_start|>user\n<|vision_start|>{placeholder}<|vision_end|>"
f"{question}<|im_end|>\n"
"<|im_start|>assistant\n")
"<|im_start|>assistant\n") for question in questions
]
stop_token_ids = None
return llm, prompt, stop_token_ids
return llm, prompts, stop_token_ids
model_example_map = {
@ -624,29 +666,35 @@ def get_multi_modal_input(args):
# Input image and question
image = ImageAsset("cherry_blossom") \
.pil_image.convert("RGB")
img_question = "What is the content of this image?"
img_questions = [
"What is the content of this image?",
"Describe the content of this image in detail.",
"What's in the image?",
"Where is this image taken?",
]
return {
"data": image,
"question": img_question,
"questions": img_questions,
}
if args.modality == "video":
# Input video and question
video = VideoAsset(name="sample_demo_1.mp4",
num_frames=args.num_frames).np_ndarrays
vid_question = "Why is this video funny?"
vid_questions = ["Why is this video funny?"]
return {
"data": video,
"question": vid_question,
"questions": vid_questions,
}
msg = f"Modality {args.modality} is not supported."
raise ValueError(msg)
def apply_image_repeat(image_repeat_prob, num_prompts, data, prompt, modality):
def apply_image_repeat(image_repeat_prob, num_prompts, data,
prompts: list[str], modality):
"""Repeats images with provided probability of "image_repeat_prob".
Used to simulate hit/miss for the MM preprocessor cache.
"""
@ -666,7 +714,7 @@ def apply_image_repeat(image_repeat_prob, num_prompts, data, prompt, modality):
cur_image.putpixel((0, 0), new_val)
inputs.append({
"prompt": prompt,
"prompt": prompts[i % len(prompts)],
"multi_modal_data": {
modality: cur_image
}
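The docstring above says `apply_image_repeat` simulates hits and misses for the multimodal preprocessor cache, but only the tail of the function is visible in this hunk. A rough standalone sketch of the idea, assuming Pillow and writing the probability handling and pixel tweak from scratch (illustrative only, not the script's actual body):

import random

from PIL import Image


def apply_image_repeat_sketch(image_repeat_prob: float, num_prompts: int,
                              image: Image.Image, prompts: list[str],
                              modality: str) -> list[dict]:
    """Reuse the previous image with probability `image_repeat_prob`
    (a simulated cache hit); otherwise tweak one pixel so the image no
    longer matches the cached entry (a miss)."""
    assert 0.0 <= image_repeat_prob <= 1.0
    rng = random.Random(0)
    inputs = []
    cur_image = image
    for i in range(num_prompts):
        if rng.random() >= image_repeat_prob:
            # Miss: copy the image and change a single pixel so the
            # preprocessor sees "new" content.
            cur_image = cur_image.copy()
            cur_image.putpixel((0, 0), (i % 256, i % 256, i % 256))
        inputs.append({
            "prompt": prompts[i % len(prompts)],
            "multi_modal_data": {modality: cur_image},
        })
    return inputs


if __name__ == "__main__":
    img = Image.new("RGB", (16, 16))
    batch = apply_image_repeat_sketch(0.5, 4, img, ["What's in the image?"],
                                      "image")
    print(len(batch))  # 4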
@ -683,9 +731,14 @@ def main(args):
modality = args.modality
mm_input = get_multi_modal_input(args)
data = mm_input["data"]
question = mm_input["question"]
questions = mm_input["questions"]
llm, prompt, stop_token_ids = model_example_map[model](question, modality)
llm, prompts, stop_token_ids = model_example_map[model](questions,
modality)
# Don't want to check the flag multiple times, so just hijack `prompts`.
prompts = prompts if args.use_different_prompt_per_request else [
prompts[0]
]
# We set temperature to 0.2 so that outputs can be different
# even when all prompts are identical when running batch inference.
@ -697,27 +750,26 @@ def main(args):
if args.num_prompts == 1:
# Single inference
inputs = {
"prompt": prompt,
"prompt": prompts[0],
"multi_modal_data": {
modality: data
},
}
else:
# Batch inference
if args.image_repeat_prob is not None:
# Repeat images with specified probability of "image_repeat_prob"
inputs = apply_image_repeat(args.image_repeat_prob,
args.num_prompts, data, prompt,
args.num_prompts, data, prompts,
modality)
else:
# Use the same image for all prompts
inputs = [{
"prompt": prompt,
"prompt": prompts[i % len(prompts)],
"multi_modal_data": {
modality: data
},
} for _ in range(args.num_prompts)]
} for i in range(args.num_prompts)]
if args.time_generate:
import time
@ -775,5 +827,11 @@ if __name__ == "__main__":
action='store_true',
help='If True, then print the total generate() call time')
parser.add_argument(
'--use-different-prompt-per-request',
action='store_true',
help='If True, then use different prompt (with the same multi-modal '
'data) for each request.')
args = parser.parse_args()
main(args)
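Stripped of the model-specific templates, the change to this example file is uniform: every `run_*` helper now formats one prompt per question, and `main` keeps the whole list only when `--use-different-prompt-per-request` is set, cycling with `prompts[i % len(prompts)]` when building the batch. A minimal self-contained sketch of that flow, reusing the LLaVA template from above and constructing no vLLM objects (the helper names here are ours, not the script's):

def build_llava_prompts(questions: list[str]) -> list[str]:
    # One formatted prompt per question, mirroring the new run_llava() above.
    return [f"USER: <image>\n{q}\nASSISTANT:" for q in questions]


def select_prompts(prompts: list[str], use_different_prompt: bool) -> list[str]:
    # Same trick as main(): keep the full list only when the flag is set.
    return prompts if use_different_prompt else [prompts[0]]


def build_batch(prompts: list[str], data, modality: str,
                num_prompts: int) -> list[dict]:
    # Cycle through the available prompts; every request shares the same
    # multi-modal data, as in the "same image for all prompts" branch.
    return [{
        "prompt": prompts[i % len(prompts)],
        "multi_modal_data": {modality: data},
    } for i in range(num_prompts)]


if __name__ == "__main__":
    questions = [
        "What is the content of this image?",
        "Where is this image taken?",
    ]
    prompts = select_prompts(build_llava_prompts(questions),
                             use_different_prompt=True)
    for item in build_batch(prompts, data=object(), modality="image",
                            num_prompts=4):
        print(item["prompt"].splitlines()[1])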

View File

@ -602,7 +602,9 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
return self.multi_modal_projector(image_outputs, image_attn_mask)
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
def get_multimodal_embeddings(
self, **kwargs
) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return None

View File

@ -628,7 +628,9 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
return self.language_projection(query_output)
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
def get_multimodal_embeddings(
self, **kwargs
) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return None

View File

@ -986,7 +986,9 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
data=self._validate_pixel_values(pixel_values),
)
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
def get_multimodal_embeddings(
self, **kwargs
) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return None

View File

@ -606,7 +606,9 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
return self._pixel_values_to_embedding(
pixel_values=pixel_values, images_spatial_crop=images_spatial_crop)
def get_multimodal_embeddings(self, **kwargs: object) -> torch.Tensor:
def get_multimodal_embeddings(
self, **kwargs: object
) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return None

View File

@ -1037,7 +1037,9 @@ class Florence2ForConditionalGeneration(nn.Module, SupportsMultiModal):
pixel_values = image_input["data"]
return self._encode_image(pixel_values)
def get_multimodal_embeddings(self, **kwargs: object) -> torch.Tensor:
def get_multimodal_embeddings(
self, **kwargs: object
) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return None

View File

@ -18,7 +18,7 @@
""" PyTorch Fuyu model."""
import math
from collections.abc import Iterable, Mapping, Sequence
from typing import List, Literal, Optional, Set, Tuple, TypedDict
from typing import List, Literal, Optional, Set, Tuple, TypedDict, Union
import torch
import torch.nn as nn
@ -327,7 +327,9 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
image_patches_flat)
return vision_embeddings_flat.split(patches_per_image, dim=0)
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
def get_multimodal_embeddings(
self, **kwargs
) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return None

View File

@ -595,7 +595,9 @@ class GLM4VForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP,
return self.transformer.vision(pixel_values)
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
def get_multimodal_embeddings(
self, **kwargs
) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return None

View File

@ -617,7 +617,9 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
self.logits_processor = LogitsProcessor(config.text_config.vocab_size)
self.sampler = get_sampler()
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
def get_multimodal_embeddings(
self, **kwargs
) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
image_input = self.model._parse_and_validate_image_input(**kwargs)
if image_input is None:
return None

View File

@ -4,6 +4,7 @@ from typing import (TYPE_CHECKING, ClassVar, Dict, List, Literal, Optional,
Protocol, Type, Union, overload, runtime_checkable)
import torch
from torch import Tensor
from typing_extensions import TypeIs, TypeVar
from vllm.logger import init_logger
@ -15,12 +16,11 @@ from .interfaces_base import is_pooling_model
if TYPE_CHECKING:
from vllm.attention import AttentionMetadata
from vllm.multimodal.inputs import NestedTensors # noqa: F401
from vllm.sequence import IntermediateTensors
logger = init_logger(__name__)
T = TypeVar("T", default="NestedTensors")
T = TypeVar("T", default=Union[list[Tensor], Tensor, tuple[Tensor, ...]])
@runtime_checkable
@ -36,7 +36,7 @@ class SupportsMultiModal(Protocol):
MRO of your model class.
"""
def get_multimodal_embeddings(self, **kwargs) -> Optional[T]:
def get_multimodal_embeddings(self, **kwargs) -> T:
"""
Returns multimodal embeddings generated from multimodal kwargs
to be merged with text embeddings.
@ -59,18 +59,18 @@ class SupportsMultiModal(Protocol):
@overload
def get_input_embeddings(
self,
input_ids: torch.Tensor,
input_ids: Tensor,
multimodal_embeddings: Optional[T] = None,
attn_metadata: Optional["AttentionMetadata"] = None,
) -> torch.Tensor:
) -> Tensor:
...
@overload
def get_input_embeddings(
self,
input_ids: torch.Tensor,
input_ids: Tensor,
multimodal_embeddings: Optional[T] = None,
) -> torch.Tensor:
) -> Tensor:
"""
Returns the input embeddings merged from the text embeddings from
input_ids and the multimodal embeddings generated from multimodal
@ -210,7 +210,7 @@ class SupportsPP(Protocol):
self,
*,
intermediate_tensors: Optional["IntermediateTensors"],
) -> Union[torch.Tensor, "IntermediateTensors"]:
) -> Union[Tensor, "IntermediateTensors"]:
"""
Accept :class:`IntermediateTensors` when PP rank > 0.
@ -237,7 +237,7 @@ class _SupportsPPType(Protocol):
self,
*,
intermediate_tensors: Optional["IntermediateTensors"],
) -> Union[torch.Tensor, "IntermediateTensors"]:
) -> Union[Tensor, "IntermediateTensors"]:
...
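The protocol change above swaps the `"NestedTensors"` forward-reference default on `T` for an explicit union, and the per-model hunks in this commit annotate `get_multimodal_embeddings` with that union instead of `Optional[NestedTensors]`. A toy sketch of the same typing pattern, assuming a `typing_extensions` version with TypeVar defaults (this is only the shape of the interface, not the vLLM class):

from typing import Protocol, Union, runtime_checkable

import torch
from typing_extensions import TypeVar

# The explicit union that the models' annotations now spell out.
MultiModalEmbeddings = Union[list[torch.Tensor], torch.Tensor,
                             tuple[torch.Tensor, ...]]
T = TypeVar("T", default=MultiModalEmbeddings)


@runtime_checkable
class SupportsMultiModalSketch(Protocol):

    def get_multimodal_embeddings(self, **kwargs) -> T:
        ...


class ToyVisionLM:

    def get_multimodal_embeddings(self, **kwargs) -> MultiModalEmbeddings:
        # One embedding tensor per image; a flat list, never a list of lists.
        return [torch.zeros(4, 8), torch.zeros(6, 8)]


assert isinstance(ToyVisionLM(), SupportsMultiModalSketch)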

View File

@ -904,7 +904,9 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
else:
self.visual_token_mask = None
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
def get_multimodal_embeddings(
self, **kwargs
) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return None

View File

@ -635,7 +635,9 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
image_features = self._process_image_pixels(image_input)
return self.multi_modal_projector(image_features)
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
def get_multimodal_embeddings(
self, **kwargs
) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return None

View File

@ -479,7 +479,9 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
for i, patch_features_batch in enumerate(patch_embeddings)
]
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
def get_multimodal_embeddings(
self, **kwargs
) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return None

View File

@ -420,7 +420,9 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
raise ValueError(
f"Unsupported type of video input {type(video_pixels)}")
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
def get_multimodal_embeddings(
self, **kwargs
) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
video_input = self._parse_and_validate_video_input(**kwargs)
if video_input is None:
return None

View File

@ -50,7 +50,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
PromptInsertion, PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
from vllm.utils import JSONTree, json_map_leaves
from vllm.utils import JSONTree, flatten_2d_lists, json_map_leaves
from .interfaces import (SupportsLoRA, SupportsMultiModal, SupportsPP,
SupportsQuant)
@ -1576,14 +1576,16 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
return embeds_in_batch
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
def get_multimodal_embeddings(
self, **kwargs
) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return None
image_features = self._process_image_input(image_input)
return [
nested_embeds = [
self._get_mm_embeds(*args) for args in zip(
image_features,
image_input["feat_is_patch"],
@ -1591,6 +1593,7 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
image_input["embed_is_patch"],
)
]
return flatten_2d_lists(nested_embeds)
def get_input_embeddings(
self,
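This hunk is the fix the commit title refers to: mapping `_get_mm_embeds` over the batch yields a per-image list of tensors, so the old code returned a list of lists, while the interface above expects a flat list or tuple of tensors. Flattening one nesting level restores that contract. A standalone sketch of the flatten step (the helper below imitates what `vllm.utils.flatten_2d_lists` is used for here; it is not the vLLM implementation):

import torch


def flatten_2d_lists(lists: list[list]) -> list:
    # Collapse exactly one level of nesting: list[list[T]] -> list[T].
    return [item for sublist in lists for item in sublist]


# Per-image output, as produced by mapping _get_mm_embeds over the batch:
# each image may contribute several embedding chunks.
nested_embeds = [
    [torch.zeros(3, 8)],
    [torch.zeros(2, 8), torch.zeros(4, 8)],
]

flat_embeds = flatten_2d_lists(nested_embeds)
assert all(isinstance(t, torch.Tensor) for t in flat_embeds)
print(len(flat_embeds))  # 3 tensors, ready to merge with text embeddings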

View File

@ -263,7 +263,9 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
return self.multi_modal_projector(image_features)
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
def get_multimodal_embeddings(
self, **kwargs
) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return None

View File

@ -648,7 +648,9 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
return image_embeds
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
def get_multimodal_embeddings(
self, **kwargs
) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return None

View File

@ -220,7 +220,9 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
return get_sampler()
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
def get_multimodal_embeddings(
self, **kwargs
) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
image_input, image_tokens = self._parse_and_validate_image_input(
**kwargs)
if image_input is None:

View File

@ -356,7 +356,9 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal,
return torch.split(masked_audio_features,
audio_output_lengths.flatten().tolist())
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
def get_multimodal_embeddings(
self, **kwargs
) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
audio_input = self._parse_and_validate_audio_input(**kwargs)
if audio_input is None:
return None

View File

@ -740,7 +740,9 @@ class QwenVLForConditionalGeneration(QWenBaseModel, SupportsPP, SupportsLoRA,
return self.transformer.visual(image_input["data"])
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
def get_multimodal_embeddings(
self, **kwargs
) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return None

View File

@ -476,7 +476,9 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
return result
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
def get_multimodal_embeddings(
self, **kwargs
) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
audio_input = self._parse_and_validate_audio_input(**kwargs)
if audio_input is None:
return None

View File

@ -692,7 +692,9 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
)
return decoder_outputs
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
def get_multimodal_embeddings(
self, **kwargs
) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
# TODO: This method does not obey the interface for SupportsMultiModal.
# Refactor this once encoder/decoder support is implemented in V1.
audio_input = self._parse_and_validate_audio_input(**kwargs)