[Model] Support E5-V (#9576)

Cyrus Leung 2024-10-23 11:35:29 +08:00 committed by GitHub
parent 29061ed9df
commit 831540cf04
12 changed files with 528 additions and 86 deletions

View File

@ -334,6 +334,14 @@ The following modalities are supported depending on the model:
- **V**\ ideo
- **A**\ udio
Any combination of modalities joined by :code:`+` is supported.
- e.g.: :code:`T + I` means that the model supports text-only, image-only, and text-with-image inputs.
On the other hand, modalities separated by :code:`/` are mutually exclusive.
- e.g.: :code:`T / I` means that the model supports text-only and image-only inputs, but not text-with-image inputs.
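For example, a model tagged :code:`T + I` accepts all three request types in the sketch below, whereas a :code:`T / I` model accepts only the first two. This is a minimal sketch using the offline :code:`LLM` API; the model name and prompt format are illustrative only, so check the tables below for each model's expected prompt.

.. code-block:: python

    from vllm import LLM
    from vllm.assets.image import ImageAsset

    image = ImageAsset("cherry_blossom").pil_image.convert("RGB")
    llm = LLM(model="llava-hf/llava-1.5-7b-hf")  # an example T + I model

    # Text-only input (T)
    llm.generate("USER: Describe spring in one sentence. ASSISTANT:")

    # Image-only input (I)
    llm.generate({
        "prompt": "USER: <image> ASSISTANT:",
        "multi_modal_data": {"image": image},
    })

    # Text-with-image input (T + I); not accepted by models tagged T / I
    llm.generate({
        "prompt": "USER: <image>\nWhat is shown in this image? ASSISTANT:",
        "multi_modal_data": {"image": image},
    })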
.. _supported_vlms:
Text Generation
@ -484,6 +492,12 @@ Multimodal Embedding
- Example HF Models
- :ref:`LoRA <lora>`
- :ref:`PP <distributed_serving>`
* - :code:`LlavaNextForConditionalGeneration`
- LLaVA-NeXT-based
- T / I
- :code:`royokong/e5-v`
-
- ✅︎
* - :code:`Phi3VForCausalLM`
- Phi-3-Vision-based
- T + I
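For reference, :code:`royokong/e5-v` can be served offline for multimodal embedding as in the sketch below, abridged from the example script added in this commit (the prompt wrapper follows the model's Llama-3 chat template):

.. code-block:: python

    from vllm import LLM
    from vllm.assets.image import ImageAsset

    llama3_template = (
        "<|start_header_id|>user<|end_header_id|>\n\n"
        "{}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n \n"
    )
    prompt = llama3_template.format("<image>\nSummary above image in one word: ")
    image = ImageAsset("cherry_blossom").pil_image.convert("RGB")

    llm = LLM(model="royokong/e5-v", task="embedding", max_model_len=4096)
    (output,) = llm.encode({"prompt": prompt, "multi_modal_data": {"image": image}})
    print(output.outputs.embedding)  # embedding vector as a list of floats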

View File

@ -1,6 +1,6 @@
"""
This example shows how to use vLLM for running offline inference
with the correct prompt format on vision language models.
This example shows how to use vLLM for running offline inference with
the correct prompt format on vision language models for text generation.
For most models, the prompt format should follow the corresponding examples
on the HuggingFace model repository.
@ -450,7 +450,7 @@ def main(args):
if __name__ == "__main__":
parser = FlexibleArgumentParser(
description='Demo on using vLLM for offline inference with '
'vision language models')
'vision language models for text generation')
parser.add_argument('--model-type',
'-m',
type=str,

View File

@ -1,22 +1,170 @@
"""
This example shows how to use vLLM for running offline inference with
the correct prompt format on vision language models for multimodal embedding.
For most models, the prompt format should follow the corresponding examples
on the HuggingFace model repository.
"""
from argparse import Namespace
from typing import Literal, NamedTuple, Optional, TypedDict, Union, get_args
from PIL.Image import Image
from vllm import LLM
from vllm.assets.image import ImageAsset
from vllm.multimodal.utils import fetch_image
from vllm.utils import FlexibleArgumentParser
image = ImageAsset("cherry_blossom").pil_image.convert("RGB")
prompt = "<|image_1|> Represent the given image with the following question: What is in the image" # noqa: E501
# Create an LLM.
llm = LLM(
model="TIGER-Lab/VLM2Vec-Full",
task="embedding",
trust_remote_code=True,
max_model_len=4096,
max_num_seqs=2,
mm_processor_kwargs={"num_crops": 16},
)
class TextQuery(TypedDict):
modality: Literal["text"]
text: str
# Generate embedding. The output is a list of EmbeddingRequestOutputs.
outputs = llm.encode({"prompt": prompt, "multi_modal_data": {"image": image}})
# Print the outputs.
for output in outputs:
print(output.outputs.embedding) # list of 3072 floats
class ImageQuery(TypedDict):
modality: Literal["image"]
image: Image
class TextImageQuery(TypedDict):
modality: Literal["text+image"]
text: str
image: Image
QueryModality = Literal["text", "image", "text+image"]
Query = Union[TextQuery, ImageQuery, TextImageQuery]
class ModelRequestData(NamedTuple):
llm: LLM
prompt: str
image: Optional[Image]
def run_e5_v(query: Query):
llama3_template = '<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n \n' # noqa: E501
if query["modality"] == "text":
text = query["text"]
prompt = llama3_template.format(
f"{text}\nSummary above sentence in one word: ")
image = None
elif query["modality"] == "image":
prompt = llama3_template.format(
"<image>\nSummary above image in one word: ")
image = query["image"]
else:
modality = query['modality']
raise ValueError(f"Unsupported query modality: '{modality}'")
llm = LLM(
model="royokong/e5-v",
task="embedding",
max_model_len=4096,
)
return ModelRequestData(
llm=llm,
prompt=prompt,
image=image,
)
def run_vlm2vec(query: Query):
if query["modality"] == "text":
text = query["text"]
prompt = f"Find me an everyday image that matches the given caption: {text}" # noqa: E501
image = None
elif query["modality"] == "image":
prompt = "<|image_1|> Find a day-to-day image that looks similar to the provided image." # noqa: E501
image = query["image"]
elif query["modality"] == "text+image":
text = query["text"]
prompt = f"<|image_1|> Represent the given image with the following question: {text}" # noqa: E501
image = query["image"]
else:
modality = query['modality']
raise ValueError(f"Unsupported query modality: '{modality}'")
llm = LLM(
model="TIGER-Lab/VLM2Vec-Full",
task="embedding",
trust_remote_code=True,
mm_processor_kwargs={"num_crops": 4},
)
return ModelRequestData(
llm=llm,
prompt=prompt,
image=image,
)
def get_query(modality: QueryModality):
if modality == "text":
return TextQuery(modality="text", text="A dog sitting in the grass")
if modality == "image":
return ImageQuery(
modality="image",
image=fetch_image(
"https://upload.wikimedia.org/wikipedia/commons/thumb/4/47/American_Eskimo_Dog.jpg/360px-American_Eskimo_Dog.jpg" # noqa: E501
),
)
if modality == "text+image":
return TextImageQuery(
modality="text+image",
text="A cat standing in the snow.",
image=fetch_image(
"https://upload.wikimedia.org/wikipedia/commons/thumb/b/b6/Felis_catus-cat_on_snow.jpg/179px-Felis_catus-cat_on_snow.jpg" # noqa: E501
),
)
msg = f"Modality {modality} is not supported."
raise ValueError(msg)
def run_encode(model: str, modality: QueryModality):
query = get_query(modality)
req_data = model_example_map[model](query)
mm_data = {}
if req_data.image is not None:
mm_data["image"] = req_data.image
outputs = req_data.llm.encode({
"prompt": req_data.prompt,
"multi_modal_data": mm_data,
})
for output in outputs:
print(output.outputs.embedding)
def main(args: Namespace):
run_encode(args.model_name, args.modality)
model_example_map = {
"e5_v": run_e5_v,
"vlm2vec": run_vlm2vec,
}
if __name__ == "__main__":
parser = FlexibleArgumentParser(
description='Demo on using vLLM for offline inference with '
'vision language models for multimodal embedding')
parser.add_argument('--model-name',
'-m',
type=str,
default="vlm2vec",
choices=model_example_map.keys(),
help='The name of the embedding model.')
parser.add_argument('--modality',
type=str,
default="image",
choices=get_args(QueryModality),
help='Modality of the input.')
args = parser.parse_args()
main(args)

View File

@ -1,7 +1,7 @@
"""
This example shows how to use vLLM for running offline inference with
multi-image input on vision language models, using the chat template defined
by the model.
multi-image input on vision language models for text generation,
using the chat template defined by the model.
"""
from argparse import Namespace
from typing import List, NamedTuple, Optional
@ -334,7 +334,8 @@ def main(args: Namespace):
if __name__ == "__main__":
parser = FlexibleArgumentParser(
description='Demo on using vLLM for offline inference with '
'vision language models that support multi-image input')
'vision language models that support multi-image input for text '
'generation')
parser.add_argument('--model-type',
'-m',
type=str,

View File

@ -43,10 +43,12 @@ _TEST_DIR = os.path.dirname(__file__)
_TEST_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "example.txt")]
_LONG_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "summary.txt")]
PromptImageInput = Union[List[Image.Image], List[List[Image.Image]]]
PromptAudioInput = Union[List[Tuple[np.ndarray, int]],
List[List[Tuple[np.ndarray, int]]]]
PromptVideoInput = Union[List[np.ndarray], List[List[np.ndarray]]]
_M = TypeVar("_M")
_PromptMultiModalInput = Union[List[_M], List[List[_M]]]
PromptImageInput = _PromptMultiModalInput[Image.Image]
PromptAudioInput = _PromptMultiModalInput[Tuple[np.ndarray, int]]
PromptVideoInput = _PromptMultiModalInput[np.ndarray]
def _read_prompts(filename: str) -> List[str]:
@ -318,12 +320,12 @@ class HfRunner:
"text": prompt,
"return_tensors": "pt",
}
if images is not None and images[i] is not None:
processor_kwargs["images"] = images[i]
if videos is not None and videos[i] is not None:
processor_kwargs["videos"] = videos[i]
if audios is not None and audios[i] is not None:
audio, sr = audios[i]
if images is not None and (image := images[i]) is not None:
processor_kwargs["images"] = image
if videos is not None and (video := videos[i]) is not None:
processor_kwargs["videos"] = video
if audios is not None and (audio_tuple := audios[i]) is not None:
audio, sr = audio_tuple
processor_kwargs["audio"] = audio
processor_kwargs["sampling_rate"] = sr
@ -338,7 +340,7 @@ class HfRunner:
self,
prompts: List[str],
images: Optional[PromptImageInput] = None,
videos: Optional[List[np.ndarray]] = None,
videos: Optional[PromptVideoInput] = None,
audios: Optional[PromptAudioInput] = None,
**kwargs: Any,
) -> List[Tuple[List[List[int]], List[str]]]:
@ -368,7 +370,7 @@ class HfRunner:
prompts: List[str],
max_tokens: int,
images: Optional[PromptImageInput] = None,
videos: Optional[List[np.ndarray]] = None,
videos: Optional[PromptVideoInput] = None,
audios: Optional[PromptAudioInput] = None,
**kwargs: Any,
) -> List[Tuple[List[int], str]]:
@ -409,7 +411,7 @@ class HfRunner:
prompts: List[str],
max_tokens: int,
images: Optional[PromptImageInput] = None,
videos: Optional[List[np.ndarray]] = None,
videos: Optional[PromptVideoInput] = None,
audios: Optional[PromptAudioInput] = None,
**kwargs: Any,
) -> List[List[torch.Tensor]]:
@ -488,7 +490,7 @@ class HfRunner:
num_logprobs: int,
images: Optional[PromptImageInput] = None,
audios: Optional[PromptAudioInput] = None,
videos: Optional[List[np.ndarray]] = None,
videos: Optional[PromptVideoInput] = None,
**kwargs: Any,
) -> List[TokensTextLogprobs]:
all_inputs = self.get_inputs(prompts,
@ -657,15 +659,18 @@ class VllmRunner:
inputs = [TextPrompt(prompt=prompt) for prompt in prompts]
if images is not None:
for i, image in enumerate(images):
inputs[i]["multi_modal_data"] = {"image": image}
if image is not None:
inputs[i]["multi_modal_data"] = {"image": image}
if videos is not None:
for i, video in enumerate(videos):
inputs[i]["multi_modal_data"] = {"video": video}
if video is not None:
inputs[i]["multi_modal_data"] = {"video": video}
if audios is not None:
for i, audio in enumerate(audios):
inputs[i]["multi_modal_data"] = {"audio": audio}
if audio is not None:
inputs[i]["multi_modal_data"] = {"audio": audio}
return inputs
@ -837,13 +842,20 @@ class VllmRunner:
returned_outputs.append((token_ids, texts))
return returned_outputs
def encode(self, prompts: List[str]) -> List[List[float]]:
req_outputs = self.model.encode(prompts)
outputs = []
for req_output in req_outputs:
embedding = req_output.outputs.embedding
outputs.append(embedding)
return outputs
def encode(
self,
prompts: List[str],
images: Optional[PromptImageInput] = None,
videos: Optional[PromptVideoInput] = None,
audios: Optional[PromptAudioInput] = None,
) -> List[List[float]]:
inputs = self.get_inputs(prompts,
images=images,
videos=videos,
audios=audios)
req_outputs = self.model.encode(inputs)
return [req_output.outputs.embedding for req_output in req_outputs]
def __enter__(self):
return self

View File

@ -16,7 +16,8 @@ def check_embeddings_close(
for prompt_idx, (embeddings_0, embeddings_1) in enumerate(
zip(embeddings_0_lst, embeddings_1_lst)):
assert len(embeddings_0) == len(embeddings_1)
assert len(embeddings_0) == len(embeddings_1), (
f"Length mismatch: {len(embeddings_0)} vs. {len(embeddings_1)}")
sim = F.cosine_similarity(torch.tensor(embeddings_0),
torch.tensor(embeddings_1),

View File

@ -0,0 +1,135 @@
from typing import List, Type
import pytest
import torch.nn.functional as F
from transformers import AutoModelForVision2Seq
from ....conftest import IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner
from ....utils import large_gpu_test
from ..utils import check_embeddings_close
llama3_template = '<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n \n' # noqa: E501
HF_TEXT_PROMPTS = [
# T -> X
llama3_template.format(
"The label of the object is stop sign\nSummary above sentence in one word: " # noqa: E501
),
# T -> X
llama3_template.format(
"cherry blossom\nSummary above sentence in one word: "),
]
HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
# I -> X
"stop_sign":
llama3_template.format("<image>\nSummary above image in one word: "),
# I -> X
"cherry_blossom":
llama3_template.format("<image>\nSummary above image in one word: "),
})
MODELS = ["royokong/e5-v"]
def _run_test(
hf_runner: Type[HfRunner],
vllm_runner: Type[VllmRunner],
input_texts: List[str],
input_images: PromptImageInput,
model: str,
*,
dtype: str,
) -> None:
# NOTE: take care of the order: run vLLM first, then run HF.
# vLLM needs a fresh process without CUDA initialization;
# if HF runs first, CUDA is already initialized, which breaks the
# multiprocessing backend with the fork start method (the default).
with vllm_runner(model,
task="embedding",
dtype=dtype,
max_model_len=4096,
enforce_eager=True) as vllm_model:
vllm_outputs = vllm_model.encode(input_texts, images=input_images)
with hf_runner(model, dtype=dtype,
auto_cls=AutoModelForVision2Seq) as hf_model:
# Patch the issue where image_token_id
# exceeds the maximum allowed vocab size
hf_model.model.resize_token_embeddings(
hf_model.model.language_model.vocab_size + 1)
all_inputs = hf_model.get_inputs(input_texts, images=input_images)
all_outputs = []
for inputs in all_inputs:
# Based on: https://huggingface.co/royokong/e5-v
outputs = hf_model.model(
**hf_model.wrap_device(inputs,
device=hf_model.model.device.type),
return_dict=True,
output_hidden_states=True,
)
pooled_output = F.normalize(outputs.hidden_states[-1][0, -1, :],
dim=-1)
all_outputs.append(pooled_output.tolist())
hf_outputs = all_outputs
check_embeddings_close(
embeddings_0_lst=hf_outputs,
embeddings_1_lst=vllm_outputs,
name_0="hf",
name_1="vllm",
)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
def test_models_text(
hf_runner,
vllm_runner,
image_assets,
model: str,
dtype: str,
) -> None:
input_texts_images = [(text, None) for text in HF_TEXT_PROMPTS]
input_texts = [text for text, _ in input_texts_images]
input_images = [image for _, image in input_texts_images]
_run_test(
hf_runner,
vllm_runner,
input_texts,
input_images, # type: ignore
model,
dtype=dtype,
)
@large_gpu_test(min_gb=48)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
def test_models_image(
hf_runner,
vllm_runner,
image_assets,
model: str,
dtype: str,
) -> None:
input_texts_images = [
(text, asset.pil_image)
for text, asset in zip(HF_IMAGE_PROMPTS, image_assets)
]
input_texts = [text for text, _ in input_texts_images]
input_images = [image for _, image in input_texts_images]
_run_test(
hf_runner,
vllm_runner,
input_texts,
input_images,
model,
dtype=dtype,
)

View File

@ -1,42 +1,53 @@
from typing import List, Type
import pytest
import torch.nn.functional as F
from ....conftest import IMAGE_ASSETS
from ....conftest import IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner
from ....utils import large_gpu_test
from ..utils import check_embeddings_close
HF_TEXT_PROMPTS = [
# T -> X
"Find me an everyday image that matches the given caption: The label of the object is stop sign", # noqa: E501
# T -> X
"Retrieve an image of this caption: cherry blossom",
]
HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
# T + I -> X
"stop_sign":
"<|image_1|> Select the portion of the image that isolates the object of the given label: The label of the object is stop sign", # noqa: E501
# I -> X
"cherry_blossom":
"<|image_1|> Represent the given image with the following question: What is in the image", # noqa: E501
"<|image_1|> Represent the given image for classification", # noqa: E501
})
MODELS = ["TIGER-Lab/VLM2Vec-Full"]
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
def test_models(
hf_runner,
vllm_runner,
example_prompts,
def _run_test(
hf_runner: Type[HfRunner],
vllm_runner: Type[VllmRunner],
input_texts: List[str],
input_images: PromptImageInput,
model: str,
*,
dtype: str,
) -> None:
# NOTE: take care of the order: run vLLM first, then run HF.
# vLLM needs a fresh process without CUDA initialization;
# if HF runs first, CUDA is already initialized, which breaks the
# multiprocessing backend with the fork start method (the default).
with vllm_runner(model,
task="embedding",
max_model_len=4096,
max_num_seqs=2,
dtype=dtype,
with vllm_runner(model, task="embedding", dtype=dtype,
enforce_eager=True) as vllm_model:
vllm_outputs = vllm_model.encode(example_prompts)
vllm_outputs = vllm_model.encode(input_texts, images=input_images)
with hf_runner(model, dtype=dtype) as hf_model:
all_inputs = hf_model.get_inputs(example_prompts)
# use eager mode for hf runner, since phi3_v doesn't work with flash_attn
hf_model_kwargs = {"_attn_implementation": "eager"}
with hf_runner(model, dtype=dtype,
model_kwargs=hf_model_kwargs) as hf_model:
all_inputs = hf_model.get_inputs(input_texts, images=input_images)
all_outputs = []
for inputs in all_inputs:
@ -61,3 +72,53 @@ def test_models(
name_0="hf",
name_1="vllm",
)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
def test_models_text(
hf_runner,
vllm_runner,
image_assets,
model: str,
dtype: str,
) -> None:
input_texts_images = [(text, None) for text in HF_TEXT_PROMPTS]
input_texts = [text for text, _ in input_texts_images]
input_images = [image for _, image in input_texts_images]
_run_test(
hf_runner,
vllm_runner,
input_texts,
input_images, # type: ignore
model,
dtype=dtype,
)
@large_gpu_test(min_gb=48)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
def test_models_image(
hf_runner,
vllm_runner,
image_assets,
model: str,
dtype: str,
) -> None:
input_texts_images = [
(text, asset.pil_image)
for text, asset in zip(HF_IMAGE_PROMPTS, image_assets)
]
input_texts = [text for text, _ in input_texts_images]
input_images = [image for _, image in input_texts_images]
_run_test(
hf_runner,
vllm_runner,
input_texts,
input_images,
model,
dtype=dtype,
)

View File

@ -13,11 +13,13 @@ from typing_extensions import NotRequired
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, DecoderOnlyInputs, InputContext
from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sequence import IntermediateTensors
from vllm.sequence import IntermediateTensors, PoolerOutput
from vllm.utils import is_list_of
from .clip import (CLIPVisionModel, dummy_image_for_clip,
@ -28,8 +30,8 @@ from .llava import LlavaMultiModalProjector
from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
dummy_seq_data_for_siglip, get_siglip_image_feature_size,
get_siglip_patch_grid_length, input_processor_for_siglip)
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
merge_multimodal_embeddings)
from .utils import (AutoWeightsLoader, embed_multimodal, flatten_bn,
init_vllm_registered_model)
# Results in the max possible feature size (2x2 grid of 336x336px tiles)
MAX_IMAGE_FEATURE_SIZE_HEIGHT = MAX_IMAGE_FEATURE_SIZE_WIDTH = 448
@ -312,6 +314,10 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
self.language_model = init_vllm_registered_model(
config.text_config, cache_config, quant_config)
# The same model class supports both language generation and embedding
# because the architecture name is the same
self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors)
@ -605,14 +611,12 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is not None:
vision_embeddings = self._process_image_input(image_input)
inputs_embeds = self.language_model.model.get_input_embeddings(
input_ids)
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, vision_embeddings,
self.config.image_token_index)
inputs_embeds = embed_multimodal(
input_ids,
self.config.image_token_index,
self.language_model.model.get_input_embeddings,
lambda _: self._process_image_input(image_input),
)
input_ids = None
else:
inputs_embeds = None
@ -641,6 +645,13 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
) -> Optional[SamplerOutput]:
return self.language_model.sample(logits, sampling_metadata)
def pooler(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> Optional[PoolerOutput]:
return self._pooler(hidden_states, pooling_metadata)
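# Editor's note (illustrative, not part of this commit): with
# ``PoolingType.LAST`` and ``normalize=True``, the pooled embedding is
# intended to correspond to the reference E5-V pooling exercised by the
# new test, roughly:
#
#     F.normalize(hidden_states[-1][0, -1, :], dim=-1)
#
# i.e. the L2-normalized hidden state of the last token of the sequence.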
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
loader = AutoWeightsLoader(self)
loader.load_weights(weights)

View File

@ -467,8 +467,6 @@ def input_processor_for_phi3v(ctx: InputContext,
prompt_token_ids = inputs["prompt_token_ids"].copy()
print("prompt_token_ids (old)", prompt_token_ids)
# masked placeholder with image token id
for idx in image_idx:
candidates = _get_image_placeholder_token_id_candidates(model_config,

View File

@ -94,6 +94,7 @@ _EMBEDDING_MODELS = {
"MistralModel": ("llama", "LlamaEmbeddingModel"),
"Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
# [Multimodal]
"LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"), # noqa: E501
"Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
}

View File

@ -1,7 +1,7 @@
import itertools
from dataclasses import dataclass, field
from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional,
Protocol, Tuple, Union, overload)
from typing import (Any, Callable, Dict, Iterable, List, Literal, Mapping,
Optional, Protocol, Tuple, Union, overload)
import torch
import torch.nn as nn
@ -294,10 +294,11 @@ def _embedding_count_expression(embeddings: NestedTensors) -> str:
_embedding_count_expression(inner) for inner in embeddings)
def merge_multimodal_embeddings(input_ids: torch.Tensor,
inputs_embeds: torch.Tensor,
multimodal_embeddings: NestedTensors,
placeholder_token_id: int) -> torch.Tensor:
def _merge_multimodal_embeddings(
inputs_embeds: torch.Tensor,
is_multimodal: torch.Tensor,
multimodal_embeddings: NestedTensors,
) -> torch.Tensor:
"""
Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the
positions in ``inputs_embeds`` corresponding to placeholder tokens in
@ -306,8 +307,7 @@ def merge_multimodal_embeddings(input_ids: torch.Tensor,
Note:
This updates ``inputs_embeds`` in place.
"""
mask = (input_ids == placeholder_token_id)
num_expected_tokens = mask.sum().item()
num_expected_tokens = is_multimodal.sum().item()
assert isinstance(num_expected_tokens, int)
flattened = _flatten_embeddings(multimodal_embeddings)
@ -317,10 +317,70 @@ def merge_multimodal_embeddings(input_ids: torch.Tensor,
f"Attempted to assign {expr} = {flattened.shape[0]} "
f"multimodal tokens to {num_expected_tokens} placeholders")
inputs_embeds[mask] = flattened
inputs_embeds[is_multimodal] = flattened
return inputs_embeds
def embed_multimodal(
input_ids: torch.Tensor,
multimodal_token_id: int,
get_text_embeds: Callable[[torch.Tensor], torch.Tensor],
get_multimodal_embeds: Callable[[torch.Tensor], Union[torch.Tensor,
List[torch.Tensor]]],
) -> torch.Tensor:
"""
Embed token IDs and multimodal inputs and combine their embeddings.
``multimodal_token_id`` is used to determine whether a token ID should
be embedded using ``get_text_embeds`` or ``get_multimodal_embeds``.
Compared to ``merge_multimodal_embeddings``, this avoids running
``get_text_embeds`` on ``input_ids[input_ids == multimodal_token_id]``,
which causes issues when the placeholder token ID exceeds the
vocabulary size of the language model.
"""
is_multimodal = input_ids == multimodal_token_id
is_text = ~is_multimodal
text_embeds = get_text_embeds(input_ids[is_text])
multimodal_embeds = get_multimodal_embeds(input_ids[is_multimodal])
merged_embeds = torch.empty(
(input_ids.shape[0], text_embeds.shape[1]),
dtype=text_embeds.dtype,
device=text_embeds.device,
)
merged_embeds[is_text] = text_embeds
return _merge_multimodal_embeddings(
merged_embeds,
is_multimodal,
multimodal_embeds,
)
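# Illustrative sketch (editor's addition, not part of this commit): a toy
# demonstration of how ``embed_multimodal`` splices multimodal embeddings
# into the text embeddings without ever looking up the placeholder token ID
# in the text embedding table. All shapes, token IDs and helpers below are
# made up purely for illustration.
def _demo_embed_multimodal() -> torch.Tensor:
    hidden_size = 4
    image_token_id = 99  # hypothetical placeholder ID; may exceed text vocab

    # Token sequence: "text text <image> <image> text"
    input_ids = torch.tensor([1, 2, image_token_id, image_token_id, 3])
    text_table = torch.randn(10, hidden_size)  # toy text embedding table
    image_feats = torch.randn(2, hidden_size)  # toy vision encoder output

    return embed_multimodal(
        input_ids,
        image_token_id,
        lambda ids: text_table[ids],  # called only on non-placeholder tokens
        lambda _: image_feats,  # fills the placeholder positions
    )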
def merge_multimodal_embeddings(
input_ids: torch.Tensor,
inputs_embeds: torch.Tensor,
multimodal_embeddings: NestedTensors,
placeholder_token_id: int,
) -> torch.Tensor:
"""
Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the
positions in ``inputs_embeds`` corresponding to placeholder tokens in
``input_ids``.
Note:
This updates ``inputs_embeds`` in place.
"""
return _merge_multimodal_embeddings(
inputs_embeds,
(input_ids == placeholder_token_id),
multimodal_embeddings,
)
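# Illustrative contrast (editor's addition, not part of this commit): unlike
# ``embed_multimodal``, this helper expects ``inputs_embeds`` for the full
# sequence to be precomputed, so the placeholder ID must be a valid index
# into the text embedding table. A toy call (all names and shapes made up):
#
#     input_ids = torch.tensor([1, 32000, 32000, 2])
#     inputs_embeds = embed_tokens(input_ids)  # (4, hidden_size)
#     vision_embeds = torch.randn(2, hidden_size)
#     inputs_embeds = merge_multimodal_embeddings(
#         input_ids, inputs_embeds, vision_embeds,
#         placeholder_token_id=32000)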
class LayerFn(Protocol):
def __call__(self, prefix: str) -> torch.nn.Module: