diff --git a/docs/source/contributing/model/multimodal.md b/docs/source/contributing/model/multimodal.md
index c4894d39..9cbfc329 100644
--- a/docs/source/contributing/model/multimodal.md
+++ b/docs/source/contributing/model/multimodal.md
@@ -860,8 +860,8 @@ prompt_tokens, prompts_length = _tokenize_prompts_with_image_and_batch(
)
```
-To assign the vision embeddings to only the image tokens, instead of a string
-you can return an instance of {class}`~vllm.multimodal.processing.PromptUpdateDetails`:
+To accommodate this, instead of a string you can return an instance of {class}`~vllm.multimodal.processing.PromptUpdateDetails`
+with different `full` and `features` attributes:
```python
hf_config = self.info.get_hf_config()
@@ -879,9 +879,9 @@ def get_replacement_fuyu(item_idx: int):
image_tokens = ([_IMAGE_TOKEN_ID] * ncols +
[_NEWLINE_TOKEN_ID]) * nrows
- return PromptUpdateDetails.select_token_id(
- image_tokens + [bos_token_id],
- embed_token_id=_IMAGE_TOKEN_ID,
+ return PromptUpdateDetails(
+ full=image_tokens + [bos_token_id],
+ features=image_tokens,
)
```
@@ -914,9 +914,9 @@ def _get_prompt_updates(
image_tokens = ([_IMAGE_TOKEN_ID] * ncols +
[_NEWLINE_TOKEN_ID]) * nrows
- return PromptUpdateDetails.select_token_id(
- image_tokens + [bos_token_id],
- embed_token_id=_IMAGE_TOKEN_ID,
+ return PromptUpdateDetails(
+ full=image_tokens + [bos_token_id],
+ features=image_tokens,
)
return [
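For illustration, a minimal sketch of the same idea outside the diff context (token IDs below are placeholders, not Fuyu's real vocabulary): `full` is the complete token sequence spliced into the prompt, while `features` is the subset that the vision embeddings are assigned to.

```python
from vllm.multimodal.processing import PromptUpdateDetails

# Placeholder IDs standing in for _IMAGE_TOKEN_ID, _NEWLINE_TOKEN_ID and BOS.
IMAGE_TOKEN_ID, NEWLINE_TOKEN_ID, BOS_TOKEN_ID = 100, 101, 1

ncols, nrows = 3, 2  # example feature grid size
image_tokens = ([IMAGE_TOKEN_ID] * ncols + [NEWLINE_TOKEN_ID]) * nrows

repl = PromptUpdateDetails(
    full=image_tokens + [BOS_TOKEN_ID],  # everything inserted into the prompt
    features=image_tokens,               # only these positions receive embeddings
)
```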
diff --git a/docs/source/models/supported_models.md b/docs/source/models/supported_models.md
index 316fc3b2..74b4eab9 100644
--- a/docs/source/models/supported_models.md
+++ b/docs/source/models/supported_models.md
@@ -989,6 +989,9 @@ See [this page](#generative-models) for more information on how to use generativ
+ Multiple items can be inputted per text prompt for this modality.
:::{important}
+To use Gemma3 series models, you have to install the Hugging Face Transformers library from source via
+`pip install git+https://github.com/huggingface/transformers`.
+
Pan-and-scan image pre-processing is currently supported on V0 (but not V1).
You can enable it by passing `--mm-processor-kwargs '{"do_pan_and_scan": True}'`.
:::
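For reference, a minimal offline-inference sketch of the same setting (the model name is only an example; any Gemma3 checkpoint works the same way):

```python
from vllm import LLM

# Enable pan-and-scan pre-processing on V0 by passing the processor kwargs
# programmatically instead of via --mm-processor-kwargs.
llm = LLM(
    model="google/gemma-3-4b-it",
    mm_processor_kwargs={"do_pan_and_scan": True},
)
```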
diff --git a/examples/offline_inference/audio_language.py b/examples/offline_inference/audio_language.py
index f33efbab..840892ea 100644
--- a/examples/offline_inference/audio_language.py
+++ b/examples/offline_inference/audio_language.py
@@ -47,7 +47,7 @@ def run_minicpmo(question: str, audio_count: int) -> ModelRequestData:
model=model_name,
trust_remote_code=True,
max_model_len=4096,
- max_num_seqs=2,
+ max_num_seqs=5,
limit_mm_per_prompt={"audio": audio_count},
)
diff --git a/tests/models/decoder_only/audio_language/test_ultravox.py b/tests/models/decoder_only/audio_language/test_ultravox.py
index 242f3398..83ece5d2 100644
--- a/tests/models/decoder_only/audio_language/test_ultravox.py
+++ b/tests/models/decoder_only/audio_language/test_ultravox.py
@@ -55,10 +55,7 @@ def server(request, audio_assets):
for key, value in request.param.items()
]
- with RemoteOpenAIServer(MODEL_NAME,
- args,
- env_dict={"VLLM_AUDIO_FETCH_TIMEOUT":
- "30"}) as remote_server:
+ with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
yield remote_server
diff --git a/tests/models/decoder_only/vision_language/test_models.py b/tests/models/decoder_only/vision_language/test_models.py
index b984cd6f..3b34f012 100644
--- a/tests/models/decoder_only/vision_language/test_models.py
+++ b/tests/models/decoder_only/vision_language/test_models.py
@@ -167,7 +167,7 @@ VLM_TEST_SETTINGS = {
"cherry_blossom": "What is the season?", # noqa: E501
}),
multi_image_prompt="Describe the two images in detail.", # noqa: E501
- max_model_len=4096,
+ max_model_len=8192,
max_num_seqs=2,
auto_cls=AutoModelForImageTextToText,
vllm_runner_kwargs={"mm_processor_kwargs": {"crop_to_patches": True}}
diff --git a/tests/models/decoder_only/vision_language/test_pixtral.py b/tests/models/decoder_only/vision_language/test_pixtral.py
index 6ebe75f0..ee619d8d 100644
--- a/tests/models/decoder_only/vision_language/test_pixtral.py
+++ b/tests/models/decoder_only/vision_language/test_pixtral.py
@@ -176,8 +176,6 @@ def test_chat(
model,
dtype=dtype,
tokenizer_mode="mistral",
- load_format="mistral",
- config_format="mistral",
max_model_len=max_model_len,
limit_mm_per_prompt=LIMIT_MM_PER_PROMPT,
) as vllm_model:
@@ -200,14 +198,22 @@ def test_chat(
@large_gpu_test(min_gb=48)
-@pytest.mark.parametrize("prompt,expected_ranges",
- [(_create_engine_inputs_hf(IMG_URLS[:1]),
- [PlaceholderRange(offset=11, length=494)]),
- (_create_engine_inputs_hf(IMG_URLS[1:4]), [
- PlaceholderRange(offset=11, length=266),
- PlaceholderRange(offset=277, length=1056),
- PlaceholderRange(offset=1333, length=418)
- ])])
+@pytest.mark.parametrize(
+ "prompt,expected_ranges",
+ [(_create_engine_inputs_hf(IMG_URLS[:1]), [{
+ "offset": 11,
+ "length": 494
+ }]),
+ (_create_engine_inputs_hf(IMG_URLS[1:4]), [{
+ "offset": 11,
+ "length": 266
+ }, {
+ "offset": 277,
+ "length": 1056
+ }, {
+ "offset": 1333,
+ "length": 418
+ }])])
def test_multi_modal_placeholders(vllm_runner, prompt,
expected_ranges: list[PlaceholderRange],
monkeypatch) -> None:
diff --git a/tests/models/multimodal/processing/test_llava_next.py b/tests/models/multimodal/processing/test_llava_next.py
index b82bfe48..fe56a200 100644
--- a/tests/models/multimodal/processing/test_llava_next.py
+++ b/tests/models/multimodal/processing/test_llava_next.py
@@ -92,8 +92,8 @@ def _validate_image_prompt_replacements_one(
first_placeholder = image_placeholders[0]
# NOTE: There is a BOS token
- assert first_placeholder.offset == 1
- assert first_placeholder.length == (
+ assert first_placeholder["offset"] == 1
+ assert first_placeholder["length"] == (
len(processed_inputs["prompt_token_ids"]) - 1) // num_imgs
except Exception as exc:
diff --git a/tests/models/multimodal/processing/test_llava_onevision.py b/tests/models/multimodal/processing/test_llava_onevision.py
index dcc8dc8d..7cefdd37 100644
--- a/tests/models/multimodal/processing/test_llava_onevision.py
+++ b/tests/models/multimodal/processing/test_llava_onevision.py
@@ -92,8 +92,8 @@ def _validate_image_prompt_replacements_one(
first_placeholder = image_placeholders[0]
- assert first_placeholder.offset == 0
- assert first_placeholder.length == len(
+ assert first_placeholder["offset"] == 0
+ assert first_placeholder["length"] == len(
processed_inputs["prompt_token_ids"]) // num_imgs
except Exception as exc:
failed_size_excs.append((image_size, exc))
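These test changes reflect that placeholder ranges are accessed as plain mappings (`["offset"]`, `["length"]`) rather than as attributes. A minimal sketch, assuming the TypedDict-style shape implied by the tests:

```python
from typing import TypedDict

class PlaceholderRangeDict(TypedDict):  # hypothetical name, for illustration only
    offset: int
    length: int

first_placeholder: PlaceholderRangeDict = {"offset": 1, "length": 576}
assert first_placeholder["offset"] == 1  # dict-style access, as in the tests above
```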
diff --git a/tests/models/registry.py b/tests/models/registry.py
index 9996bd2e..39e104a1 100644
--- a/tests/models/registry.py
+++ b/tests/models/registry.py
@@ -277,9 +277,7 @@ _MULTIMODAL_EXAMPLE_MODELS = {
trust_remote_code=True,
hf_overrides={"architectures": ["GLM4VForCausalLM"]}), # noqa: E501
"H2OVLChatModel": _HfExamplesInfo("h2oai/h2ovl-mississippi-800m",
- extras={"2b": "h2oai/h2ovl-mississippi-2b"}, # noqa: E501
- max_transformers_version="4.48", # noqa: E501
- transformers_version_reason="HF model is not compatible."), # noqa: E501
+ extras={"2b": "h2oai/h2ovl-mississippi-2b"}), # noqa: E501
"InternVLChatModel": _HfExamplesInfo("OpenGVLab/InternVL2-1B",
extras={"2B": "OpenGVLab/InternVL2-2B"}, # noqa: E501
trust_remote_code=True),
diff --git a/tests/multimodal/test_processing.py b/tests/multimodal/test_processing.py
index fa9588a0..da112bd7 100644
--- a/tests/multimodal/test_processing.py
+++ b/tests/multimodal/test_processing.py
@@ -785,7 +785,6 @@ def test_find_update_tokens(
item_idx=0,
start_idx=6,
tokens=[32000, 32000],
- is_embed=None,
),
],
"pattern_4": [
@@ -794,7 +793,6 @@ def test_find_update_tokens(
item_idx=0,
start_idx=3,
tokens=[32000],
- is_embed=None,
),
],
}
@@ -809,14 +807,12 @@ def test_find_update_tokens(
item_idx=0,
start_idx=1,
tokens=[32000, 32000],
- is_embed=None,
),
PlaceholderFeaturesInfo(
modality="pattern_1",
item_idx=1,
start_idx=5,
tokens=[32000, 32000],
- is_embed=None,
),
],
"pattern_3": [
@@ -825,7 +821,6 @@ def test_find_update_tokens(
item_idx=0,
start_idx=7,
tokens=[1550, 918, 1550],
- is_embed=None,
),
],
# No match for pattern_4 as it has lower priority than pattern_1
@@ -840,14 +835,12 @@ def test_find_update_tokens(
item_idx=0,
start_idx=1,
tokens=[32000, 32000],
- is_embed=None,
),
PlaceholderFeaturesInfo(
modality="pattern_1",
item_idx=1,
start_idx=3,
tokens=[32000, 32000],
- is_embed=None,
),
],
"pattern_4": [
@@ -856,7 +849,6 @@ def test_find_update_tokens(
item_idx=0,
start_idx=5,
tokens=[32000],
- is_embed=None,
),
],
"pattern_3": [
@@ -865,7 +857,6 @@ def test_find_update_tokens(
item_idx=0,
start_idx=6,
tokens=[1550, 918, 1550],
- is_embed=None,
),
],
}
diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py
index 51836644..8362af24 100644
--- a/tests/v1/core/test_kv_cache_utils.py
+++ b/tests/v1/core/test_kv_cache_utils.py
@@ -3,7 +3,7 @@
import pytest
import torch
-from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
+from vllm.multimodal.inputs import MultiModalKwargs
from vllm.sampling_params import SamplingParams
from vllm.utils import sha256
# disable yapf here as it formats differently than isort such that both fail
@@ -158,10 +158,13 @@ def test_generate_block_hash_extra_keys():
request = make_request(
request_id=0,
prompt_token_ids=[_ for _ in range(20)],
- mm_positions=[
- PlaceholderRange(offset=0, length=5),
- PlaceholderRange(offset=10, length=5),
- ],
+ mm_positions=[{
+ "offset": 0,
+ "length": 5
+ }, {
+ "offset": 10,
+ "length": 5
+ }],
mm_hashes=["hash1", "hash2"],
)
@@ -219,10 +222,13 @@ def test_hash_request_tokens(hash_fn):
request = make_request(
request_id=0,
prompt_token_ids=[_ for _ in range(6)],
- mm_positions=[
- PlaceholderRange(offset=0, length=3),
- PlaceholderRange(offset=3, length=3),
- ],
+ mm_positions=[{
+ "offset": 0,
+ "length": 3
+ }, {
+ "offset": 3,
+ "length": 3
+ }],
mm_hashes=["hash1", "hash2"],
)
@@ -247,19 +253,25 @@ def test_hash_tokens_different_mm_input(hash_fn):
request1 = make_request(
request_id=0,
prompt_token_ids=[_ for _ in range(6)],
- mm_positions=[
- PlaceholderRange(offset=0, length=3),
- PlaceholderRange(offset=3, length=3),
- ],
+ mm_positions=[{
+ "offset": 0,
+ "length": 3
+ }, {
+ "offset": 3,
+ "length": 3
+ }],
mm_hashes=["hash1", "hash2"],
)
request2 = make_request(
request_id=1,
prompt_token_ids=[_ for _ in range(6)],
- mm_positions=[
- PlaceholderRange(offset=0, length=3),
- PlaceholderRange(offset=3, length=3),
- ],
+ mm_positions=[{
+ "offset": 0,
+ "length": 3
+ }, {
+ "offset": 3,
+ "length": 3
+ }],
mm_hashes=["hash3", "hash2"],
)
block_size = 3
diff --git a/vllm/model_executor/models/aya_vision.py b/vllm/model_executor/models/aya_vision.py
index 6b68885d..b4bf1d82 100644
--- a/vllm/model_executor/models/aya_vision.py
+++ b/vllm/model_executor/models/aya_vision.py
@@ -27,7 +27,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo,
MultiModalFieldConfig,
PromptReplacement, PromptUpdate,
- PromptUpdateDetails)
+ encode_tokens)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
@@ -35,6 +35,7 @@ from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .siglip import SiglipVisionModel
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings)
+from .vision import scatter_patch_features, select_patch_features
class AyaVisionImagePixelInputs(TypedDict):
@@ -50,6 +51,13 @@ class AyaVisionImagePixelInputs(TypedDict):
num_patches: torch.Tensor
"""Shape: `(batch_size * num_images)`"""
+ embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
+ """
+ A boolean mask indicating which image embeddings correspond to patch tokens.
+
+ Shape: `(batch_size * num_images, num_embeds)`
+ """
+
class AyaVisionMultiModalProjector(nn.Module):
@@ -127,20 +135,21 @@ class AyaVisionProcessingInfo(BaseProcessingInfo):
def get_max_image_tokens(self) -> int:
hf_processor = self.get_hf_processor()
image_processor = hf_processor.image_processor
-
image_size = self.get_image_size_with_most_features()
+ tokenizer = hf_processor.tokenizer
num_patches = self.get_num_patches(
image_width=image_size.width,
image_height=image_size.height,
size=image_processor.size,
min_patches=image_processor.min_patches,
- max_patches=image_processor.max_patches,
+ max_patches=image_processor.max_patches)
+ image_string = hf_processor._prompt_split_image(num_patches)
+ x = encode_tokens(
+ tokenizer,
+ image_string,
+ add_special_tokens=False,
)
-
- img_patches_per_tile = (hf_processor.img_size //
- hf_processor.patch_size)**2
-
- return num_patches * img_patches_per_tile
+ return len(x)
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None}
@@ -212,6 +221,7 @@ class AyaVisionMultiModalProcessor(
hf_processor = self.info.get_hf_processor(**mm_kwargs)
image_processor = hf_processor.image_processor
+ hf_config = self.info.get_hf_config()
# HF processor pops the `num_patches` kwarg, which is needed by vLLM
if (images :=
mm_data.get("images")) is not None and '' in prompt:
@@ -224,7 +234,6 @@ class AyaVisionMultiModalProcessor(
parsed_images.get_image_size(i)
for i in range(len(parsed_images))
]
-
num_patches = [
self.info.get_num_patches(
image_width=image_size.width,
@@ -234,6 +243,20 @@ class AyaVisionMultiModalProcessor(
max_patches=image_processor.max_patches)
for image_size in image_sizes
]
+ image_tokens_list = [
+ hf_processor._prompt_split_image(num_patch)
+ for num_patch in num_patches
+ ]
+ tokenizer = self.info.get_tokenizer()
+ image_token_ids = [
+ tokenizer.encode(image_tokens, add_special_tokens=False)
+ for image_tokens in image_tokens_list
+ ]
+ embed_is_patch = [
+ torch.tensor(image_repl_tokens) == hf_config.image_token_index
+ for image_repl_tokens in image_token_ids
+ ]
+ processed_outputs["embed_is_patch"] = embed_is_patch
processed_outputs["num_patches"] = torch.tensor(num_patches)
return processed_outputs
@@ -248,6 +271,7 @@ class AyaVisionMultiModalProcessor(
pixel_values=MultiModalFieldConfig.flat_from_sizes(
"image", num_patches),
num_patches=MultiModalFieldConfig.batched("image"),
+ embed_is_patch=MultiModalFieldConfig.batched("image"),
image_embeds=MultiModalFieldConfig.batched("image"),
)
@@ -259,7 +283,6 @@ class AyaVisionMultiModalProcessor(
) -> Sequence[PromptUpdate]:
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
image_token = hf_processor.image_token
- img_patch_token = hf_processor.img_patch_token
image_processor = hf_processor.image_processor
def get_replacement(item_idx: int):
@@ -271,11 +294,8 @@ class AyaVisionMultiModalProcessor(
image_height=image_size.height,
size=image_processor.size,
min_patches=image_processor.min_patches,
- max_patches=image_processor.max_patches,
- )
- repl = hf_processor._prompt_split_image(num_patches=num_patches)
-
- return PromptUpdateDetails.select_text(repl, img_patch_token)
+ max_patches=image_processor.max_patches)
+ return hf_processor._prompt_split_image(num_patches=num_patches)
return [
PromptReplacement(
@@ -404,6 +424,7 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal,
self, **kwargs: object) -> Optional[AyaVisionImagePixelInputs]:
pixel_values = kwargs.pop("pixel_values", None)
num_patches = kwargs.pop("num_patches", None)
+ embed_is_patch = kwargs.pop("embed_is_patch", None)
image_embeds = kwargs.pop("image_embeds", None)
assert image_embeds is None, "Aya Vision does not support image_embeds."
@@ -415,13 +436,18 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal,
raise ValueError("Incorrect type of num_patches. "
f"Got type: {type(num_patches)}")
+ if not isinstance(embed_is_patch, (torch.Tensor, list)):
+ raise ValueError("Incorrect type of embed_is_patch. "
+ f"Got type: {type(embed_is_patch)}")
+
pixel_values = flatten_bn(pixel_values, concat=True)
num_patches = flatten_bn(num_patches, concat=True)
-
+ embed_is_patch = flatten_bn(embed_is_patch)
return AyaVisionImagePixelInputs(
type="pixel_values",
pixel_values=self._validate_pixel_values(pixel_values),
num_patches=num_patches,
+ embed_is_patch=embed_is_patch,
)
def get_multimodal_embeddings(
@@ -429,8 +455,11 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal,
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return None
-
- return self._process_image_input(image_input, **kwargs)
+ image_features = self._process_image_input(image_input, **kwargs)
+ return scatter_patch_features(
+ image_features,
+ image_input["embed_is_patch"],
+ )
def get_input_embeddings(
self,
@@ -442,9 +471,9 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal,
inputs_embeds = merge_multimodal_embeddings(
input_ids=input_ids,
inputs_embeds=inputs_embeds,
- multimodal_embeddings=multimodal_embeddings,
- placeholder_token_id=self.config.image_token_index,
- )
+ multimodal_embeddings=select_patch_features(
+ multimodal_embeddings),
+ placeholder_token_id=self.config.image_token_index)
return inputs_embeds
diff --git a/vllm/model_executor/models/chameleon.py b/vllm/model_executor/models/chameleon.py
index 3d527cb6..f758c98e 100644
--- a/vllm/model_executor/models/chameleon.py
+++ b/vllm/model_executor/models/chameleon.py
@@ -162,9 +162,9 @@ class ChameleonMultiModalProcessor(
PromptReplacement(
modality="image",
target=[image_token_id],
- replacement=PromptUpdateDetails.select_token_id(
- [image_start_id] + image_tokens + [image_end_id],
- embed_token_id=image_token_id,
+ replacement=PromptUpdateDetails(
+ full=([image_start_id] + image_tokens + [image_end_id]),
+ features=image_tokens,
),
)
]
diff --git a/vllm/model_executor/models/fuyu.py b/vllm/model_executor/models/fuyu.py
index 189b91db..a807b047 100644
--- a/vllm/model_executor/models/fuyu.py
+++ b/vllm/model_executor/models/fuyu.py
@@ -18,7 +18,7 @@
""" PyTorch Fuyu model."""
import math
from collections.abc import Iterable, Mapping, Sequence
-from typing import Literal, Optional, Set, Tuple, TypedDict
+from typing import Literal, Optional, Set, Tuple, TypedDict, Union
import torch
import torch.nn as nn
@@ -43,6 +43,7 @@ from vllm.sequence import IntermediateTensors
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
merge_multimodal_embeddings)
+from .vision import scatter_patch_features, select_patch_features
# Cannot find the following 2 numbers from hf config.
_IMAGE_TOKEN_ID = 71011
@@ -65,6 +66,14 @@ class FuyuImagePatchInputs(TypedDict):
flattened just like `flat_data`.
"""
+ embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
+ """
+ A boolean mask indicating which image embeddings correspond
+ to patch tokens.
+
+ Shape: `(batch_size * num_images, num_embeds)`
+ """
+
class FuyuProcessingInfo(BaseProcessingInfo):
@@ -85,7 +94,15 @@ class FuyuProcessingInfo(BaseProcessingInfo):
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
- return {"image": self.get_max_image_tokens()}
+ target_width, target_height = self.get_image_size_with_most_features()
+
+ max_ncols, max_nrows = self.get_image_feature_grid_size(
+ image_width=target_width,
+ image_height=target_height,
+ )
+ max_image_tokens = (max_ncols + 1) * max_nrows
+
+ return {"image": max_image_tokens}
def get_image_feature_grid_size(
self,
@@ -111,32 +128,11 @@ class FuyuProcessingInfo(BaseProcessingInfo):
nrows = math.ceil(image_height / patch_height)
return ncols, nrows
- def get_num_image_tokens(
- self,
- *,
- image_width: int,
- image_height: int,
- ) -> int:
- ncols, nrows = self.get_image_feature_grid_size(
- image_width=image_width,
- image_height=image_height,
- )
-
- return ncols * nrows
-
def get_image_size_with_most_features(self) -> ImageSize:
image_processor = self.get_image_processor()
return ImageSize(width=image_processor.size["width"],
height=image_processor.size["height"])
- def get_max_image_tokens(self) -> int:
- target_width, target_height = self.get_image_size_with_most_features()
-
- return self.get_num_image_tokens(
- image_width=target_width,
- image_height=target_height,
- )
-
class FuyuDummyInputsBuilder(BaseDummyInputsBuilder[FuyuProcessingInfo]):
@@ -196,6 +192,19 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]):
processed_outputs["image_patches"] = image_patches[0]
+ # get patch grid size for each image
+ embed_is_patch = []
+ for image in images:
+ ncols, nrows = self.info.get_image_feature_grid_size(
+ image_width=image.width,
+ image_height=image.height,
+ )
+
+ mask = torch.tensor(([True] * ncols + [False]) * nrows)
+ embed_is_patch.append(mask)
+
+ processed_outputs["embed_is_patch"] = embed_is_patch
+
return processed_outputs
def _apply_hf_processor_tokens_only(
@@ -215,7 +224,8 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]):
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
- return dict(image_patches=MultiModalFieldConfig.batched("image"))
+ return dict(image_patches=MultiModalFieldConfig.batched("image"),
+ embed_is_patch=MultiModalFieldConfig.batched("image"))
def _get_prompt_updates(
self,
@@ -242,9 +252,9 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]):
image_tokens = ([_IMAGE_TOKEN_ID] * ncols +
[_NEWLINE_TOKEN_ID]) * nrows
- return PromptUpdateDetails.select_token_id(
- image_tokens + [bos_token_id],
- embed_token_id=_IMAGE_TOKEN_ID,
+ return PromptUpdateDetails(
+ full=image_tokens + [bos_token_id],
+ features=image_tokens,
)
return [
@@ -319,13 +329,20 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
raise ValueError("Incorrect type of image patches. "
f"Got type: {type(image_patches)}")
+ embed_is_patch = kwargs.pop("embed_is_patch")
+ if not isinstance(embed_is_patch, (torch.Tensor, list)):
+ raise ValueError("Incorrect type of embed_is_patch. "
+ f"Got type: {type(embed_is_patch)}")
+
image_patches_flat = flatten_bn(image_patches)
+ embed_is_patch = flatten_bn(embed_is_patch)
return FuyuImagePatchInputs(
type="image_patches",
flat_data=self._validate_pixel_values(
flatten_bn(image_patches_flat, concat=True)),
patches_per_image=[x.size(0) for x in image_patches_flat],
+ embed_is_patch=embed_is_patch,
)
return None
@@ -347,7 +364,12 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
if image_input is None:
return None
- return self._process_image_input(image_input)
+ image_features = self._process_image_input(image_input)
+
+ return scatter_patch_features(
+ image_features,
+ image_input["embed_is_patch"],
+ )
def get_input_embeddings(
self,
@@ -357,11 +379,8 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
inputs_embeds = merge_multimodal_embeddings(
- input_ids,
- inputs_embeds,
- multimodal_embeddings,
- _IMAGE_TOKEN_ID,
- )
+ input_ids, inputs_embeds,
+ select_patch_features(multimodal_embeddings), _IMAGE_TOKEN_ID)
return inputs_embeds
def forward(
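The Fuyu changes above all derive from the same grid arithmetic: each of the `nrows` rows contributes `ncols` patch tokens plus one newline token, so the prompt holds `(ncols + 1) * nrows` placeholder positions, of which only `ncols * nrows` carry image features. A small self-contained sketch (sizes are illustrative, not Fuyu's actual defaults):

```python
import math

import torch

image_width, image_height = 1080, 1080
patch_width = patch_height = 30  # illustrative patch size

ncols = math.ceil(image_width / patch_width)
nrows = math.ceil(image_height / patch_height)

# True marks positions that receive vision embeddings; False marks the
# per-row newline token appended after each row of patches.
embed_is_patch = torch.tensor(([True] * ncols + [False]) * nrows)
assert embed_is_patch.numel() == (ncols + 1) * nrows
assert int(embed_is_patch.sum()) == ncols * nrows
```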
diff --git a/vllm/model_executor/models/gemma3_mm.py b/vllm/model_executor/models/gemma3_mm.py
index 9552ee1f..bbdea70a 100644
--- a/vllm/model_executor/models/gemma3_mm.py
+++ b/vllm/model_executor/models/gemma3_mm.py
@@ -25,7 +25,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
PlaceholderFeaturesInfo,
PromptReplacement, PromptTargetMatch,
PromptUpdate, PromptUpdateDetails,
- find_mm_placeholders,
+ encode_tokens, find_mm_placeholders,
replace_token_matches)
# yapf: enable
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
@@ -36,6 +36,7 @@ from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
from .siglip import SiglipVisionModel
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings)
+from .vision import scatter_patch_features, select_patch_features
logger = init_logger(__name__)
@@ -53,6 +54,14 @@ class Gemma3ImagePixelInputs(TypedDict):
num_patches: torch.Tensor
"""Shape: `(batch_size * num_images)`"""
+ embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
+ """
+ A boolean mask indicating which image embeddings correspond
+ to patch tokens.
+
+ Shape: `(batch_size * num_images, num_embeds)`
+ """
+
Gemma3ImageInputs = Gemma3ImagePixelInputs
@@ -174,7 +183,7 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
if processor is None:
processor = self.get_hf_processor()
- boi_token = processor.boi_token
+ image_token = processor.boi_token
num_crops = self.get_num_crops(
image_width=image_width,
@@ -183,21 +192,19 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
)
if num_crops == 0:
- image_text = boi_token
+ image_text = image_token
else:
- crops_image_tokens = " ".join(boi_token for _ in range(num_crops))
+ crops_image_tokens = " ".join(image_token
+ for _ in range(num_crops))
image_text = (
- f"Here is the original image {boi_token} and here are some "
+ f"Here is the original image {image_token} and here are some "
f"crops to help you see better {crops_image_tokens}")
- repl_full = image_text.replace(boi_token,
+ repl_full = image_text.replace(image_token,
processor.full_image_sequence)
+ repl_features = repl_full.strip("\n")
- tokenizer = processor.tokenizer
- vocab = tokenizer.get_vocab()
- image_token_id = vocab[tokenizer.image_token]
-
- return PromptUpdateDetails.select_token_id(repl_full, image_token_id)
+ return PromptUpdateDetails(full=repl_full, features=repl_features)
def get_num_image_tokens(
self,
@@ -206,17 +213,19 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
image_height: int,
processor: Optional[Gemma3Processor],
) -> int:
- if processor is None:
- processor = self.get_hf_processor()
-
- num_crops = self.get_num_crops(
+ tokenizer = self.get_tokenizer()
+ image_repl = self.get_image_repl(
image_width=image_width,
image_height=image_height,
processor=processor,
)
- image_seq_len = processor.image_seq_length
- return (num_crops + 1) * image_seq_len
+ image_repl_tokens = encode_tokens(
+ tokenizer,
+ image_repl.features,
+ add_special_tokens=False,
+ )
+ return len(image_repl_tokens)
def get_image_size_with_most_features(self) -> ImageSize:
processor = self.get_hf_processor()
@@ -292,6 +301,28 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
]
hf_processor = self.info.get_hf_processor(**mm_kwargs)
+ image_repl_features = [
+ self.info.get_image_repl(image_width=size.width,
+ image_height=size.height,
+ processor=hf_processor).features
+ for size in image_sizes
+ ]
+
+ tokenizer = self.info.get_tokenizer()
+ image_repls_feature_tokens = [
+ tokenizer.encode(image_repl, add_special_tokens=False)
+ for image_repl in image_repl_features
+ ]
+
+ vocab = tokenizer.get_vocab()
+ image_token_id = vocab[tokenizer.image_token]
+
+ embed_is_patch = [
+ torch.tensor(image_repl_tokens) == image_token_id
+ for image_repl_tokens in image_repls_feature_tokens
+ ]
+ processed_outputs["embed_is_patch"] = embed_is_patch
+
num_crops = [
self.info.get_num_crops(image_width=size.width,
image_height=size.height,
@@ -313,6 +344,7 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
pixel_values=MultiModalFieldConfig.flat_from_sizes(
"image", num_crops + 1),
num_crops=MultiModalFieldConfig.batched("image"),
+ embed_is_patch=MultiModalFieldConfig.batched("image"),
)
def _get_prompt_updates(
@@ -422,7 +454,6 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
item_idx=p.item_idx,
start_idx=repl_orig_idxs[p.start_idx],
tokens=p.tokens,
- is_embed=p.is_embed,
) for p in placeholders
]
for modality, placeholders in repls.items()
@@ -541,6 +572,7 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
self, **kwargs: object) -> Optional[Gemma3ImageInputs]:
pixel_values = kwargs.pop("pixel_values", None)
num_crops = kwargs.pop("num_crops", None)
+ embed_is_patch = kwargs.pop("embed_is_patch", None)
image_embeds = kwargs.pop("image_embeds", None)
assert image_embeds is None, "Gemma3 does not support image_embeds."
if pixel_values is None:
@@ -554,13 +586,19 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
raise ValueError("Incorrect type of num_crops. "
f"Got type: {type(num_crops)}")
+ if not isinstance(embed_is_patch, (torch.Tensor, list)):
+ raise ValueError("Incorrect type of embed_is_patch. "
+ f"Got type: {type(embed_is_patch)}")
+
pixel_values = flatten_bn(pixel_values, concat=True)
num_crops = flatten_bn(num_crops, concat=True)
+ embed_is_patch = flatten_bn(embed_is_patch)
return Gemma3ImagePixelInputs(
type="pixel_values",
pixel_values=self._validate_pixel_values(pixel_values),
num_patches=num_crops + 1,
+ embed_is_patch=embed_is_patch,
)
def _image_pixels_to_features(
@@ -597,7 +635,12 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
if image_input is None:
return None
- return self._process_image_input(image_input)
+ image_features = self._process_image_input(image_input)
+
+ return scatter_patch_features(
+ image_features,
+ image_input["embed_is_patch"],
+ )
def get_input_embeddings(
self,
@@ -609,7 +652,7 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
- multimodal_embeddings,
+ select_patch_features(multimodal_embeddings),
self.config.image_token_index,
)
return inputs_embeds
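The `embed_is_patch` construction used here (and in the Aya Vision, Idefics3 and InternVL changes) follows one pattern: tokenize the per-image replacement text and mark exactly the positions that hold the image placeholder token. A minimal sketch with made-up token IDs:

```python
import torch

# Stand-ins for tokenizer.encode(image_repl, add_special_tokens=False)
# and vocab[tokenizer.image_token]; the real values come from the HF processor.
image_token_id = 9999
image_repl_tokens = [7, image_token_id, image_token_id, image_token_id, 8]

embed_is_patch = torch.tensor(image_repl_tokens) == image_token_id
# tensor([False,  True,  True,  True, False]): only the True positions are
# filled with vision-tower outputs when the features are scattered back.
```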
diff --git a/vllm/model_executor/models/h2ovl.py b/vllm/model_executor/models/h2ovl.py
index f975a19a..3b2ad695 100644
--- a/vllm/model_executor/models/h2ovl.py
+++ b/vllm/model_executor/models/h2ovl.py
@@ -257,7 +257,7 @@ class H2OVLProcessor(BaseInternVLProcessor):
repl_features = IMG_CONTEXT * feature_size
repl_full = IMG_START + repl_features + IMG_END
- return PromptUpdateDetails.select_text(repl_full, IMG_CONTEXT)
+ return PromptUpdateDetails(full=repl_full, features=repl_features)
def resolve_min_max_num(
self,
diff --git a/vllm/model_executor/models/idefics3.py b/vllm/model_executor/models/idefics3.py
index 347106bc..da4a4434 100644
--- a/vllm/model_executor/models/idefics3.py
+++ b/vllm/model_executor/models/idefics3.py
@@ -41,7 +41,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
MultiModalDataItems,
MultiModalFieldConfig,
PromptReplacement, PromptUpdate,
- PromptUpdateDetails)
+ encode_tokens)
# yapf: enable
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
@@ -54,6 +54,7 @@ from .interfaces import MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal
from .llama import LlamaModel
from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
merge_multimodal_embeddings)
+from .vision import scatter_patch_features, select_patch_features
class Idefics3ImagePixelInputs(TypedDict):
@@ -68,6 +69,14 @@ class Idefics3ImagePixelInputs(TypedDict):
num_patches: torch.Tensor
"""Shape: `(batch_size * num_images)`"""
+ embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
+ """
+ A boolean mask indicating which image embeddings correspond
+ to patch tokens.
+
+ Shape: `(batch_size * num_images, num_embeds)`
+ """
+
class Idefics3ImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"]
@@ -77,6 +86,14 @@ class Idefics3ImageEmbeddingInputs(TypedDict):
`hidden_size` must match the hidden size of language model backbone.
"""
+ embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
+ """
+ A boolean mask indicating which image embeddings correspond
+ to patch tokens.
+
+ Shape: `(batch_size * num_images, num_embeds)`
+ """
+
ImageInputs = Union[Idefics3ImagePixelInputs, Idefics3ImageEmbeddingInputs]
@@ -258,16 +275,19 @@ class Idefics3ProcessingInfo(BaseProcessingInfo):
image_height: int,
processor: Optional[Idefics3Processor],
) -> int:
- if processor is None:
- processor = self.get_hf_processor()
-
- num_patches = self.get_num_patches(
+ tokenizer = self.get_tokenizer()
+ image_repl = self.get_image_repl(
image_width=image_width,
image_height=image_height,
processor=processor,
)
- return num_patches * processor.image_seq_len
+ image_repl_tokens = encode_tokens(
+ tokenizer,
+ image_repl,
+ add_special_tokens=False,
+ )
+ return len(image_repl_tokens)
def get_image_size_with_most_features(self) -> ImageSize:
processor = self.get_hf_processor()
@@ -344,6 +364,28 @@ class Idefics3MultiModalProcessor(
]
hf_processor = self.info.get_hf_processor(**mm_kwargs)
+ image_repl_features = [
+ self.info.get_image_repl(image_width=size.width,
+ image_height=size.height,
+ processor=hf_processor)
+ for size in image_sizes
+ ]
+
+ tokenizer = self.info.get_tokenizer()
+ image_repls_feature_tokens = [
+ tokenizer.encode(image_repl, add_special_tokens=False)
+ for image_repl in image_repl_features
+ ]
+
+ vocab = tokenizer.get_vocab()
+ image_token_id = vocab[hf_processor.image_token.content]
+
+ embed_is_patch = [
+ torch.tensor(image_repl_tokens) == image_token_id
+ for image_repl_tokens in image_repls_feature_tokens
+ ]
+ processed_outputs["embed_is_patch"] = embed_is_patch
+
num_patches = [
self.info.get_num_patches(
image_width=size.width,
@@ -373,6 +415,7 @@ class Idefics3MultiModalProcessor(
"image", num_patches),
image_embeds=MultiModalFieldConfig.batched("image"),
num_patches=MultiModalFieldConfig.batched("image"),
+ embed_is_patch=MultiModalFieldConfig.batched("image"),
)
def _get_prompt_updates(
@@ -384,22 +427,17 @@ class Idefics3MultiModalProcessor(
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
image_token = hf_processor.image_token.content
- def get_replacement_idefics3(item_idx: int) -> PromptUpdateDetails:
+ def get_replacement_idefics3(item_idx: int) -> str:
images = mm_items.get_items("image", ImageProcessorItems)
image_size = images.get_image_size(item_idx)
- image_repl = self.info.get_image_repl(
+ return self.info.get_image_repl(
image_width=image_size.width,
image_height=image_size.height,
processor=hf_processor,
)
- return PromptUpdateDetails.select_text(
- image_repl,
- embed_text=image_token,
- )
-
return [
PromptReplacement(
modality="image",
@@ -637,6 +675,13 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
if pixel_values is None and image_embeds is None:
return None
+ embed_is_patch = kwargs.pop("embed_is_patch")
+ if not isinstance(embed_is_patch, (torch.Tensor, list)):
+ raise ValueError("Incorrect type of embed_is_patch. "
+ f"Got type: {type(embed_is_patch)}")
+
+ embed_is_patch = flatten_bn(embed_is_patch)
+
if image_embeds is not None:
if not isinstance(image_embeds, (torch.Tensor, list)):
raise ValueError("Incorrect type of image embeddings. "
@@ -645,6 +690,7 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
return Idefics3ImageEmbeddingInputs(
type="image_embeds",
data=flatten_bn(image_embeds, concat=True),
+ embed_is_patch=embed_is_patch,
)
if pixel_values is not None:
@@ -672,6 +718,7 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
pixel_values=self._validate_pixel_values(pixel_values),
pixel_attention_mask=pixel_attention_mask,
num_patches=num_patches,
+ embed_is_patch=embed_is_patch,
)
raise AssertionError("This line should be unreachable.")
@@ -707,7 +754,12 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
if image_input is None:
return None
- return self._process_image_input(image_input)
+ image_features = self._process_image_input(image_input)
+
+ return scatter_patch_features(
+ image_features,
+ image_input["embed_is_patch"],
+ )
def get_input_embeddings(
self,
@@ -719,7 +771,7 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
- multimodal_embeddings,
+ select_patch_features(multimodal_embeddings),
self.config.image_token_id,
)
return inputs_embeds
diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py
index cf5608e3..0729f4c7 100644
--- a/vllm/model_executor/models/internvl.py
+++ b/vllm/model_executor/models/internvl.py
@@ -39,6 +39,7 @@ from vllm.transformers_utils.tokenizer import AnyTokenizer
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings)
+from .vision import scatter_patch_features, select_patch_features
IMG_START = '<img>'
IMG_END = '</img>'
@@ -59,6 +60,14 @@ class InternVLImagePixelInputs(TypedDict):
num_patches: torch.Tensor
"""Shape: `(batch_size * num_images)`"""
+ embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
+ """
+ A boolean mask indicating which image embeddings correspond
+ to patch tokens.
+
+ Shape: `(batch_size * num_images, num_embeds)`
+ """
+
class InternVLImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"]
@@ -410,12 +419,24 @@ class BaseInternVLProcessor(ABC):
torch.tensor([len(item) for item in pixel_values_lst]),
}
+ tokenizer = self.tokenizer
+ image_token_id = self.image_token_id
+
+ embed_is_patch = list[torch.Tensor]()
+
for pixel_values in pixel_values_lst:
num_patches = pixel_values.shape[0]
feature_size = num_patches * self.num_image_token
image_repl = self.get_image_repl(feature_size, num_patches)
+ feature_tokens = tokenizer.encode(image_repl.features,
+ add_special_tokens=False)
+
text = [t.replace('<image>', image_repl.full, 1) for t in text]
+ embed_is_patch.append(
+ torch.tensor(feature_tokens) == image_token_id)
+
+ image_inputs["embed_is_patch"] = embed_is_patch
text_inputs = self.tokenizer(text)
@@ -439,7 +460,7 @@ class InternVLProcessor(BaseInternVLProcessor):
repl_features = IMG_CONTEXT * feature_size
repl_full = IMG_START + repl_features + IMG_END
- return PromptUpdateDetails.select_text(repl_full, IMG_CONTEXT)
+ return PromptUpdateDetails(full=repl_full, features=repl_features)
class BaseInternVLProcessingInfo(BaseProcessingInfo):
@@ -578,6 +599,7 @@ class InternVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
pixel_values_flat=MultiModalFieldConfig.flat_from_sizes(
"image", image_num_patches),
image_num_patches=MultiModalFieldConfig.batched("image"),
+ embed_is_patch=MultiModalFieldConfig.batched("image"),
image_embeds=MultiModalFieldConfig.batched("image"),
image_token_id=MultiModalFieldConfig.shared("image", num_images),
)
@@ -809,6 +831,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
self, **kwargs: object) -> Optional[InternVLImageInputs]:
pixel_values_flat = kwargs.pop("pixel_values_flat", None)
image_num_patches = kwargs.pop("image_num_patches", None)
+ embed_is_patch = kwargs.pop("embed_is_patch", None)
image_embeds = kwargs.pop("image_embeds", None)
if pixel_values_flat is None and image_embeds is None:
@@ -837,14 +860,20 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
raise ValueError("Incorrect type of image_num_patches. "
f"Got type: {type(image_num_patches)}")
+ if not isinstance(embed_is_patch, (torch.Tensor, list)):
+ raise ValueError("Incorrect type of embed_is_patch. "
+ f"Got type: {type(embed_is_patch)}")
+
pixel_values_flat = flatten_bn(pixel_values_flat, concat=True)
image_num_patches = flatten_bn(image_num_patches, concat=True)
+ embed_is_patch = flatten_bn(embed_is_patch)
return InternVLImagePixelInputs(
type="pixel_values",
pixel_values_flat=self._validate_pixel_values(
pixel_values_flat),
num_patches=image_num_patches,
+ embed_is_patch=embed_is_patch,
)
raise AssertionError("This line should be unreachable.")
@@ -890,7 +919,15 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
if image_input is None:
return None
- return self._process_image_input(image_input)
+ image_features = self._process_image_input(image_input)
+
+ if image_input["type"] != "pixel_values":
+ return image_features
+
+ return scatter_patch_features(
+ image_features,
+ image_input["embed_is_patch"],
+ )
def get_input_embeddings(
self,
@@ -904,7 +941,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
- multimodal_embeddings,
+ select_patch_features(multimodal_embeddings),
self.img_context_token_id,
)
return inputs_embeds
diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py
index b34ac38f..45a0bf73 100644
--- a/vllm/model_executor/models/llava.py
+++ b/vllm/model_executor/models/llava.py
@@ -32,8 +32,7 @@ from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
ImageSize, MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, ProcessingCache,
- PromptReplacement, PromptUpdate,
- PromptUpdateDetails)
+ PromptReplacement, PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
@@ -43,7 +42,8 @@ from .pixtral import PixtralHFEncoderInfo, PixtralHFVisionModel
from .siglip import SiglipVisionModel
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings)
-from .vision import get_vision_encoder_info
+from .vision import (get_vision_encoder_info, scatter_patch_features,
+ select_patch_features)
class LlavaImagePixelInputs(TypedDict):
@@ -67,6 +67,14 @@ class PixtralHFImagePixelInputs(TypedDict):
in which case the data is passed as a list instead of a batched tensor.
"""
+ embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
+ """
+ A boolean mask indicating which image embeddings correspond
+ to patch tokens.
+
+ Shape: `(batch_size * num_images, num_embeds)`
+ """
+
class LlavaImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"]
@@ -335,6 +343,23 @@ class PixtralHFMultiModalProcessor(
for p, (h, w) in zip(pixel_values, image_sizes)
]
+ hf_config = self.info.get_hf_config()
+ vision_config = hf_config.vision_config
+ assert isinstance(vision_config, PixtralVisionConfig)
+ encoder_info = PixtralHFEncoderInfo(vision_config)
+
+ tile_sizes = [
+ encoder_info.get_patch_grid_size(
+ image_width=pixel_value.shape[-1],
+ image_height=pixel_value.shape[-2],
+ ) for pixel_value in processed_outputs["pixel_values"]
+ ]
+ embed_is_patch = [
+ torch.tensor(([True] * ncols + [False]) * nrows)
+ for ncols, nrows in tile_sizes
+ ]
+ processed_outputs["embed_is_patch"] = embed_is_patch
+
return processed_outputs
def _get_mm_fields_config(
@@ -344,6 +369,7 @@ class PixtralHFMultiModalProcessor(
) -> Mapping[str, MultiModalFieldConfig]:
return dict(
pixel_values=MultiModalFieldConfig.batched("image"),
+ embed_is_patch=MultiModalFieldConfig.batched("image"),
image_embeds=MultiModalFieldConfig.batched("image"),
)
@@ -378,7 +404,7 @@ class PixtralHFMultiModalProcessor(
tokens = ([image_token_id] * ncols + [image_break_id]) * nrows
tokens[-1] = image_end_id
- return PromptUpdateDetails.select_token_id(tokens, image_token_id)
+ return tokens
return [
PromptReplacement(
@@ -586,9 +612,17 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
f"Got type: {type(pixel_values)}")
if self.config.vision_config.model_type == "pixtral":
+ embed_is_patch = kwargs.pop("embed_is_patch")
+ if not isinstance(embed_is_patch, (torch.Tensor, list)):
+ raise ValueError("Incorrect type of embed_is_patch. "
+ f"Got type: {type(embed_is_patch)}")
+
+ embed_is_patch = flatten_bn(embed_is_patch)
+
return PixtralHFImagePixelInputs(
type="pixel_values_pixtral",
pixel_values=flatten_bn(pixel_values),
+ embed_is_patch=embed_is_patch,
)
return LlavaImagePixelInputs(
@@ -680,7 +714,16 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
if image_input is None:
return None
- return self._process_image_input(image_input)
+ image_features = self._process_image_input(image_input)
+
+ if image_input["type"] != "pixel_values_pixtral":
+ # The path is used for pixtral (V0 only) and llava (V0/V1)
+ return image_features
+
+ return scatter_patch_features(
+ image_features,
+ image_input["embed_is_patch"],
+ )
def get_input_embeddings(
self,
@@ -692,7 +735,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
- multimodal_embeddings,
+ select_patch_features(multimodal_embeddings),
self.config.image_token_index,
)
return inputs_embeds
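Conceptually, `scatter_patch_features` places the computed patch embeddings at the `True` positions of `embed_is_patch`, and `select_patch_features` recovers just those rows before merging them into the input embeddings. A rough single-image stand-in (not the actual vLLM helpers, whose signatures may differ):

```python
import torch

def scatter_patch_features_sketch(features: torch.Tensor,
                                  is_patch: torch.Tensor) -> torch.Tensor:
    # features: (num_patches, hidden); is_patch: (num_embeds,) bool mask.
    out = features.new_full((is_patch.numel(), features.shape[-1]), float("nan"))
    out[is_patch] = features  # non-patch slots stay as filler
    return out

def select_patch_features_sketch(embeds: torch.Tensor) -> torch.Tensor:
    # Drop the filler rows, keeping only the real patch features.
    return embeds[~embeds.isnan().any(dim=-1)]

features = torch.randn(6, 8)
is_patch = torch.tensor([True] * 3 + [False] + [True] * 3 + [False])
scattered = scatter_patch_features_sketch(features, is_patch)
assert torch.equal(select_patch_features_sketch(scattered), features)
```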
diff --git a/vllm/model_executor/models/minicpmo.py b/vllm/model_executor/models/minicpmo.py
index a4fb0cb1..c74e086d 100644
--- a/vllm/model_executor/models/minicpmo.py
+++ b/vllm/model_executor/models/minicpmo.py
@@ -40,8 +40,7 @@ from vllm.multimodal.parse import (AudioItem, AudioProcessorItems,
DictEmbeddingItems, ModalityData,
ModalityDataItems, MultiModalDataItems,
MultiModalDataParser)
-from vllm.multimodal.processing import (PromptReplacement, PromptUpdate,
- PromptUpdateDetails)
+from vllm.multimodal.processing import PromptReplacement, PromptUpdate
from vllm.multimodal.profiling import ProcessorInputs
from .minicpmv import (_MAX_FRAMES_PER_VIDEO, MiniCPMV2_6,
@@ -51,6 +50,7 @@ from .minicpmv import (_MAX_FRAMES_PER_VIDEO, MiniCPMV2_6,
_minicpmv_field_config)
from .utils import (AutoWeightsLoader, cast_overflow_tensors, flatten_bn,
maybe_prefix)
+from .vision import scatter_patch_features
CPU_DEVICE = torch.device("cpu")
@@ -73,6 +73,14 @@ class MiniCPMOAudioFeatureInputs(TypedDict):
which equals to `audio_features.shape[-1]`
"""
+ embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
+ """
+ A boolean mask indicating which audio embeddings correspond
+ to patch tokens.
+
+ Shape: `(batch_size * num_audios, num_embeds)`
+ """
+
class MiniCPMOAudioEmbeddingInputs(TypedDict):
type: Literal["audio_embeds"]
@@ -85,6 +93,14 @@ class MiniCPMOAudioEmbeddingInputs(TypedDict):
Length of each slice may vary, so pass it as a list.
"""
+ embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
+ """
+ A boolean mask indicating which audio embeddings correspond
+ to patch tokens.
+
+ Shape: `(batch_size * num_audios, num_embeds)`
+ """
+
MiniCPMOAudioInputs = Union[MiniCPMOAudioFeatureInputs,
MiniCPMOAudioEmbeddingInputs]
@@ -99,6 +115,7 @@ def _minicpmo_field_config(hf_inputs: Mapping[str, torch.Tensor]):
audio_features=MultiModalFieldConfig.batched("audio"),
audio_feature_lens=MultiModalFieldConfig.batched("audio"),
audio_embeds=MultiModalFieldConfig.batched("audio"),
+ audio_embed_is_patch=MultiModalFieldConfig.batched("audio"),
audio_token_id=MultiModalFieldConfig.shared("audio", num_audios),
)
@@ -180,7 +197,8 @@ class MiniCPMOProcessingInfo(MiniCPMVProcessingInfo):
pool_step = self.get_default_audio_pool_step()
fbank_feat_in_chunk = 100
cnn_feat_in_chunk = (fbank_feat_in_chunk - 1) // 2 + 1
- return (cnn_feat_in_chunk - pool_step) // pool_step + 1
+ num_audio_tokens = (cnn_feat_in_chunk - pool_step) // pool_step + 1
+ return num_audio_tokens + 2  # audio start and end tokens
def get_max_audio_chunks_with_most_features(self) -> int:
return 30
@@ -191,7 +209,8 @@ class MiniCPMOProcessingInfo(MiniCPMVProcessingInfo):
def get_audio_len_by_num_chunks(self, num_chunks: int) -> int:
sampling_rate = self.get_default_audio_sampling_rate()
- num_tokens_per_chunk = self.get_max_audio_tokens_per_chunk()
+ # exclude the audio start and end tokens
+ num_tokens_per_chunk = self.get_max_audio_tokens_per_chunk() - 2
return int(num_chunks * sampling_rate / num_tokens_per_chunk) + 1
def get_num_frames_with_most_features(
@@ -276,6 +295,13 @@ class MiniCPMOMultiModalProcessor(
if isinstance(parsed_audios, MiniCPMOAudioEmbeddingItems):
audio_inputs = {}
+
+ audio_lens = [
+ self.info.get_audio_len_by_num_chunks(
+ sum(map(len,
+ parsed_audios.get(i)["audio_embeds"])))
+ for i in range(len(parsed_audios))
+ ]
else:
audio_inputs = self._base_call_hf_processor(
prompts=[self.info.audio_pattern] * len(parsed_audios),
@@ -297,7 +323,27 @@ class MiniCPMOMultiModalProcessor(
]
audio_inputs["audio_features"] = unpadded_audio_features
+ audio_lens = [
+ parsed_audios.get_audio_length(i)
+ for i in range(len(parsed_audios))
+ ]
+
+ audio_repl_features = [
+ self.get_audio_prompt_texts(audio_len) for audio_len in audio_lens
+ ]
+
tokenizer = self.info.get_tokenizer()
+ audio_repls_feature_tokens = [
+ tokenizer.encode(audio_repl, add_special_tokens=False)
+ for audio_repl in audio_repl_features
+ ]
+
+ embed_is_patch = [
+ self.get_embed_is_patch(audio_repl_tokens)
+ for audio_repl_tokens in audio_repls_feature_tokens
+ ]
+ audio_inputs["audio_embed_is_patch"] = embed_is_patch
+
unk_token_id = tokenizer.get_vocab()["<unk>"]
audio_inputs["audio_token_id"] = torch.tensor(unk_token_id)
@@ -338,10 +384,7 @@ class MiniCPMOMultiModalProcessor(
else:
audio_len = audios.get_audio_length(item_idx)
- return PromptUpdateDetails.select_text(
- self.get_audio_prompt_texts(audio_len),
- "",
- )
+ return self.get_audio_prompt_texts(audio_len)
return [
*base_updates,
@@ -670,6 +713,13 @@ class MiniCPMO(MiniCPMV2_6):
assert isinstance(audio_token_id, torch.Tensor)
self.mm_token_ids.add(audio_token_id.flatten().unique().item())
+ audio_embed_is_patch = kwargs.pop("audio_embed_is_patch")
+ if not isinstance(audio_embed_is_patch, (torch.Tensor, list)):
+ raise ValueError("Incorrect type of audio_embed_is_patch. "
+ f"Got type: {type(audio_embed_is_patch)}")
+
+ audio_embed_is_patch = flatten_bn(audio_embed_is_patch)
+
if audio_embeds is not None:
if not isinstance(audio_embeds, (torch.Tensor, list)):
raise ValueError("Incorrect type of audio_embeds. "
@@ -680,6 +730,7 @@ class MiniCPMO(MiniCPMV2_6):
return MiniCPMOAudioEmbeddingInputs(
type="audio_embeds",
audio_embeds=audio_embeds_flat,
+ embed_is_patch=audio_embed_is_patch,
)
if not isinstance(audio_features, (torch.Tensor, list)):
@@ -698,6 +749,7 @@ class MiniCPMO(MiniCPMV2_6):
type="audio_features",
audio_features=audio_features_flat,
audio_feature_lens=audio_feature_lens_flat,
+ embed_is_patch=audio_embed_is_patch,
)
def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
@@ -729,6 +781,10 @@ class MiniCPMO(MiniCPMV2_6):
if modality == "audios":
audio_input = modalities["audios"]
audio_features = self._process_audio_input(audio_input)
- multimodal_embeddings += tuple(audio_features)
+ multimodal_embeddings += tuple(
+ scatter_patch_features(
+ audio_features,
+ audio_input["embed_is_patch"],
+ ))
return multimodal_embeddings
diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py
index eb20a963..5fab9df3 100644
--- a/vllm/model_executor/models/minicpmv.py
+++ b/vllm/model_executor/models/minicpmv.py
@@ -56,7 +56,7 @@ from vllm.multimodal.parse import (DictEmbeddingItems, ImageItem,
VideoItem, VideoProcessorItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement,
- PromptUpdate, PromptUpdateDetails)
+ PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
@@ -67,6 +67,7 @@ from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
SupportsMultiModal, SupportsPP)
from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
merge_multimodal_embeddings)
+from .vision import scatter_patch_features, select_patch_features
# For profile run
_MAX_FRAMES_PER_VIDEO = 16
@@ -89,6 +90,14 @@ class MiniCPMVImagePixelInputs(TypedDict):
This should be in `(height, width)` format.
"""
+ embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
+ """
+ A boolean mask indicating which image embeddings correspond
+ to patch tokens.
+
+ Shape: `(batch_size * num_images, num_embeds)`
+ """
+
num_slices: torch.Tensor
"""Shape: `(batch_size * num_images)`"""
@@ -103,6 +112,14 @@ class MiniCPMVImageEmbeddingInputs(TypedDict):
instead of a batched tensor.
"""
+ embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
+ """
+ A boolean mask indicating which image embeddings correspond
+ to patch tokens.
+
+ Shape: `(batch_size * num_images, num_embeds)`
+ """
+
MiniCPMVImageInputs = Union[MiniCPMVImagePixelInputs,
MiniCPMVImageEmbeddingInputs]
@@ -228,10 +245,12 @@ def _minicpmv_field_config(hf_inputs: Mapping[str, torch.Tensor]):
image_sizes=MultiModalFieldConfig.batched("image"),
tgt_sizes=MultiModalFieldConfig.batched("image"),
image_embeds=MultiModalFieldConfig.batched("image"),
+ embed_is_patch=MultiModalFieldConfig.batched("image"),
video_pixel_values=MultiModalFieldConfig.batched("video"),
video_image_sizes=MultiModalFieldConfig.batched("video"),
video_tgt_sizes=MultiModalFieldConfig.batched("video"),
video_embeds=MultiModalFieldConfig.batched("video"),
+ video_embed_is_patch=MultiModalFieldConfig.batched("video"),
image_token_id=MultiModalFieldConfig.shared("image", num_images),
video_token_id=MultiModalFieldConfig.shared("video", num_videos),
)
@@ -379,43 +398,22 @@ class MiniCPMVProcessingInfo(BaseProcessingInfo):
use_image_id=use_image_id,
)
- def get_sliced_grid(
- self,
- image_size: ImageSize,
- # For MiniCPM V/O 2.6
- max_slice_nums: Optional[int] = None,
- ) -> Optional[tuple[int, int]]:
- image_processor = self.get_image_processor()
- version = self.get_model_version()
-
- if version == (2, 0) or version == (2, 5):
- return image_processor.get_sliced_grid(image_size)
-
- if max_slice_nums is None:
- max_slice_nums = image_processor.max_slice_nums
-
- return image_processor.get_sliced_grid(
- image_size,
- max_slice_nums=max_slice_nums,
- )
-
def get_num_image_tokens(
self,
image_size: ImageSize,
max_slice_nums: Optional[int] = None,
+ use_image_id: bool = True,
) -> int:
- image_processor = self.get_image_processor()
-
- grid = self.get_sliced_grid(
+ tokenizer = self.get_tokenizer()
+ image_placeholders = self.get_slice_image_placeholder(
image_size,
max_slice_nums=max_slice_nums,
+ use_image_id=use_image_id,
)
- if grid is None:
- ncols = nrows = 0
- else:
- ncols, nrows = grid
+ image_token_ids = tokenizer.encode(image_placeholders,
+ add_special_tokens=False)
- return (ncols * nrows + 1) * image_processor.image_feature_size
+ return len(image_token_ids)
def get_max_image_tokens(self) -> int:
image_size = self.get_image_size_with_most_features()
@@ -435,6 +433,7 @@ class MiniCPMVProcessingInfo(BaseProcessingInfo):
return self.get_num_image_tokens(
frame_size,
max_slice_nums=self.get_video_max_slice_num(),
+ use_image_id=False,
)
def get_max_video_tokens(
@@ -540,6 +539,14 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
use_image_id=False,
) * num_frames
+ def get_embed_is_patch(
+ self,
+ input_ids: list[int],
+ ) -> torch.Tensor:
+ tokenizer = self.info.get_tokenizer()
+ unk_token_id = tokenizer.get_vocab()["<unk>"]
+ return torch.tensor(input_ids) == unk_token_id
+
def process_images(
self,
mm_data: Mapping[str, object],
@@ -563,7 +570,26 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
out_keys={"pixel_values", "image_sizes", "tgt_sizes"},
)
+ image_sizes = [
+ parsed_images.get_image_size(i) for i in range(len(parsed_images))
+ ]
+ image_repl_features = [
+ self.get_image_prompt_texts(size, idx)
+ for idx, size in enumerate(image_sizes)
+ ]
+
tokenizer = self.info.get_tokenizer()
+ image_repls_feature_tokens = [
+ tokenizer.encode(image_repl, add_special_tokens=False)
+ for image_repl in image_repl_features
+ ]
+
+ embed_is_patch = [
+ self.get_embed_is_patch(image_repl_tokens)
+ for image_repl_tokens in image_repls_feature_tokens
+ ]
+ image_inputs["embed_is_patch"] = embed_is_patch
+
unk_token_id = tokenizer.get_vocab()["<unk>"]
image_inputs["image_token_id"] = torch.tensor(unk_token_id)
@@ -599,9 +625,31 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
out_keys={"pixel_values", "image_sizes", "tgt_sizes"},
)
- video_inputs = {f"video_{k}": v for k, v in video_inputs.items()}
+ frame_sizes = [
+ parsed_videos.get_frame_size(i) for i in range(len(parsed_videos))
+ ]
+ num_frames = [
+ parsed_videos.get_num_frames(i) for i in range(len(parsed_videos))
+ ]
+ video_repl_features = [
+ self.get_video_prompt_texts(size, nframes)
+ for size, nframes in zip(frame_sizes, num_frames)
+ ]
tokenizer = self.info.get_tokenizer()
+ video_repls_feature_tokens = [
+ tokenizer.encode(video_repl, add_special_tokens=False)
+ for video_repl in video_repl_features
+ ]
+
+ embed_is_patch = [
+ self.get_embed_is_patch(video_repl_tokens)
+ for video_repl_tokens in video_repls_feature_tokens
+ ]
+ video_inputs["embed_is_patch"] = embed_is_patch
+
+ video_inputs = {f"video_{k}": v for k, v in video_inputs.items()}
+
unk_token_id = tokenizer.get_vocab()["<unk>"]
video_inputs["video_token_id"] = torch.tensor(unk_token_id)
@@ -692,10 +740,7 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
image_size = images.get_image_size(item_idx)
- return PromptUpdateDetails.select_text(
- self.get_image_prompt_texts(image_size, item_idx),
- "",
- )
+ return self.get_image_prompt_texts(image_size, item_idx)
def get_video_replacement(item_idx: int):
videos = mm_items.get_items(
@@ -704,10 +749,7 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
frame_size = videos.get_frame_size(item_idx)
num_frames = videos.get_num_frames(item_idx)
- return PromptUpdateDetails.select_text(
- self.get_video_prompt_texts(frame_size, num_frames),
- "",
- )
+ return self.get_video_prompt_texts(frame_size, num_frames)
get_replacement = {
"image": get_image_replacement,
@@ -790,6 +832,14 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
assert isinstance(image_token_id, torch.Tensor)
self.mm_token_ids.add(image_token_id.flatten().unique().item())
+ embed_is_patch = kwargs.pop("embed_is_patch")
+ if not isinstance(embed_is_patch, (torch.Tensor, list)):
+ raise ValueError(
+ f"Incorrect type of embed_is_patch for {modality=}. "
+ f"Got type: {type(embed_is_patch)}")
+
+ embed_is_patch = flatten_bn(embed_is_patch)
+
if image_embeds is not None:
if not isinstance(image_embeds, (torch.Tensor, list)):
raise ValueError(
@@ -801,6 +851,7 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
return MiniCPMVImageEmbeddingInputs(
type="image_embeds",
image_embeds=image_embeds_flat,
+ embed_is_patch=embed_is_patch,
)
if not isinstance(pixel_values, (torch.Tensor, list)):
@@ -828,6 +879,7 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
type="pixel_values",
pixel_values=pixel_values_flat,
tgt_sizes=tgt_sizes_flat,
+ embed_is_patch=embed_is_patch,
num_slices=num_slices_flat,
)
@@ -884,11 +936,19 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
if modality == "images":
image_input = modalities["images"]
image_features = self._process_vision_input(image_input)
- multimodal_embeddings += tuple(image_features)
+ multimodal_embeddings += tuple(
+ scatter_patch_features(
+ image_features,
+ image_input["embed_is_patch"],
+ ))
if modality == "videos":
video_input = modalities["videos"]
video_features = self._process_vision_input(video_input)
- multimodal_embeddings += tuple(video_features)
+ multimodal_embeddings += tuple(
+ scatter_patch_features(
+ video_features,
+ video_input["embed_is_patch"],
+ ))
return multimodal_embeddings
@@ -911,7 +971,7 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
- multimodal_embeddings,
+ select_patch_features(multimodal_embeddings),
list(self.mm_token_ids),
)
return inputs_embeds
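
The MiniCPM-V hunk above derives `embed_is_patch` by encoding the image placeholder text and marking the positions that hold the `<unk>` token id. A minimal sketch of that comparison, using a toy vocabulary instead of the real MiniCPM-V tokenizer (all names and ids below are illustrative):

```python
import torch

# Toy vocabulary standing in for tokenizer.get_vocab() (illustrative ids).
toy_vocab = {"<image>": 5, "<unk>": 0, "</image>": 6}

def toy_encode(tokens: list[str]) -> list[int]:
    # Stand-in for tokenizer.encode(..., add_special_tokens=False)
    return [toy_vocab[t] for t in tokens]

# Placeholder text for one image: a start tag, two patch slots, an end tag.
input_ids = toy_encode(["<image>", "<unk>", "<unk>", "</image>"])

# Same comparison as get_embed_is_patch(): patch positions are exactly the
# positions holding the <unk> token id.
embed_is_patch = torch.tensor(input_ids) == toy_vocab["<unk>"]
print(embed_is_patch)  # tensor([False,  True,  True, False])
```
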
diff --git a/vllm/model_executor/models/mistral3.py b/vllm/model_executor/models/mistral3.py
index b6fbc6b1..872769dd 100644
--- a/vllm/model_executor/models/mistral3.py
+++ b/vllm/model_executor/models/mistral3.py
@@ -27,8 +27,7 @@ from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, ProcessingCache,
- PromptReplacement, PromptUpdate,
- PromptUpdateDetails)
+ PromptReplacement, PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
@@ -36,7 +35,8 @@ from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .pixtral import PixtralHFEncoderInfo, PixtralHFVisionModel
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings)
-from .vision import get_vision_encoder_info
+from .vision import (get_vision_encoder_info, scatter_patch_features,
+ select_patch_features)
class Mistral3ImagePixelInputs(TypedDict):
@@ -49,6 +49,14 @@ class Mistral3ImagePixelInputs(TypedDict):
in which case the data is passed as a list instead of a batched tensor.
"""
+ embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
+ """
+ A boolean mask indicating which image embeddings correspond
+ to patch tokens.
+
+ Shape: `(batch_size, num_images, num_embeds)`
+ """
+
class Mistral3PatchMerger(nn.Module):
"""
@@ -258,6 +266,23 @@ class Mistral3MultiModalProcessor(
p[:, :h, :w] for p, (h, w) in zip(pixel_values, image_sizes)
]
+ hf_config = self.info.get_hf_config()
+ vision_config = hf_config.vision_config
+ assert isinstance(vision_config, PixtralVisionConfig)
+ encoder_info = PixtralHFEncoderInfo(vision_config)
+
+ tile_sizes = [
+ encoder_info.get_patch_grid_size(
+ image_width=pixel_value.shape[-1],
+ image_height=pixel_value.shape[-2],
+ ) for pixel_value in processed_outputs["pixel_values"]
+ ]
+ embed_is_patch = [
+ torch.tensor(([True] * ncols + [False]) * nrows)
+ for ncols, nrows in tile_sizes
+ ]
+ processed_outputs["embed_is_patch"] = embed_is_patch
+
return processed_outputs
def _get_mm_fields_config(
@@ -267,6 +292,7 @@ class Mistral3MultiModalProcessor(
) -> Mapping[str, MultiModalFieldConfig]:
return dict(
pixel_values=MultiModalFieldConfig.batched("image"),
+ embed_is_patch=MultiModalFieldConfig.batched("image"),
image_embeds=MultiModalFieldConfig.batched("image"),
)
@@ -301,7 +327,7 @@ class Mistral3MultiModalProcessor(
tokens = ([image_token_id] * ncols + [image_break_id]) * nrows
tokens[-1] = image_end_id
- return PromptUpdateDetails.select_token_id(tokens, image_token_id)
+ return tokens
return [
PromptReplacement(
@@ -392,6 +418,8 @@ def init_vision_tower_for_llava(
)
+# TODO(mgoin): Support V1, there are issues with image batching/chunking
+# that need to be resolved first.
@MULTIMODAL_REGISTRY.register_processor(
_build_mistral3_processor,
info=_build_mistral3_info,
@@ -481,9 +509,16 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsMultiModal,
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
+ assert self.config.vision_config.model_type == "pixtral"
+ embed_is_patch = kwargs.pop("embed_is_patch")
+ if not isinstance(embed_is_patch, (torch.Tensor, list)):
+ raise ValueError("Incorrect type of embed_is_patch. "
+ f"Got type: {type(embed_is_patch)}")
+
return Mistral3ImagePixelInputs(
type="pixel_values_pixtral",
pixel_values=flatten_bn(pixel_values),
+ embed_is_patch=flatten_bn(embed_is_patch),
)
def _process_image_input(
@@ -522,7 +557,10 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsMultiModal,
vision_embeddings = self._process_image_input(image_input)
- return vision_embeddings
+ return scatter_patch_features(
+ vision_embeddings,
+ image_input["embed_is_patch"],
+ )
def get_input_embeddings(
self,
@@ -534,7 +572,7 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsMultiModal,
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
- multimodal_embeddings,
+ select_patch_features(multimodal_embeddings),
self.config.image_token_index,
)
return inputs_embeds
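
For Mistral-3 the same mask is built straight from the Pixtral tile grid: every row contributes `ncols` patch embeddings followed by one non-patch slot for the break token (the final one being the end token). A short sketch with an assumed 3x2 grid:

```python
import torch

ncols, nrows = 3, 2  # assumed patch grid, not taken from any real image

# Mirrors the list comprehension in the hunk above: each row is ncols patch
# positions followed by one False for [IMG_BREAK] ([IMG_END] on the last row).
embed_is_patch = torch.tensor(([True] * ncols + [False]) * nrows)
print(embed_is_patch)
# tensor([ True,  True,  True, False,  True,  True,  True, False])
```
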
diff --git a/vllm/model_executor/models/molmo.py b/vllm/model_executor/models/molmo.py
index 6857bfa8..b2f79515 100644
--- a/vllm/model_executor/models/molmo.py
+++ b/vllm/model_executor/models/molmo.py
@@ -46,8 +46,7 @@ from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptIndexTargets,
- PromptInsertion, PromptUpdate,
- PromptUpdateDetails)
+ PromptInsertion, PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
@@ -57,6 +56,7 @@ from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix, merge_multimodal_embeddings)
+from .vision import scatter_patch_features, select_patch_features
# TODO: hard-coded for now. Consider making it configurable.
VIT_LAYERS = [-2, -9]
@@ -84,6 +84,14 @@ class MolmoImageInputs(TypedDict):
Shape: `(batch_size * num_images, num_crops, num_patch)`
"""
+ embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
+ """
+ A boolean mask indicating which image embeddings correspond
+ to patch tokens.
+
+ Shape: `(batch_size * num_images, num_embeds)`
+ """
+
num_crops: torch.Tensor
"""Shape: `(batch_size * num_images)`"""
@@ -1138,6 +1146,30 @@ class MolmoProcessorWrapper:
if image_input_idx is not None:
feat_is_patch = image_input_idx >= 0
+ input_is_embed = torch.isin(
+ input_ids,
+ torch.tensor([
+ self.image_patch_id,
+ self.im_col_id,
+ self.im_start_id,
+ self.im_end_id,
+ ]),
+ )
+ embed_ids = input_ids[input_is_embed]
+ embed_is_patch = embed_ids == self.image_patch_id
+ assert embed_is_patch.sum() == feat_is_patch.sum()
+
+ # image_tokens = extra_joint + joint
+ # Both `extra_joint` and `joint` have `im_start_id` and `im_end_id`
+ embed_start = torch.nonzero(embed_ids == self.im_start_id)[::2, 0]
+ embed_end = torch.nonzero(embed_ids == self.im_end_id)[1::2, 0]
+ assert len(embed_start) == len(embed_end) == len(images)
+
+ embed_is_patch = [
+ embed_is_patch[start:end + 1]
+ for start, end in zip(embed_start, embed_end)
+ ]
+
tilings = [
self.select_tiling(
image_width=image.size[0],
@@ -1149,6 +1181,7 @@ class MolmoProcessorWrapper:
assert num_crops.sum() == len(feat_is_patch)
outputs["feat_is_patch"] = feat_is_patch
+ outputs["embed_is_patch"] = embed_is_patch
outputs["num_crops"] = num_crops
outputs["img_patch_id"] = self.image_patch_id
@@ -1187,13 +1220,17 @@ class MolmoProcessingInfo(BaseProcessingInfo):
)
pooling_size = processor.pooling_size
- image_token_length_w = processor.image_token_length_w
- image_token_length_h = processor.image_token_length_h
+ base_image_input_size = processor.base_image_input_size
+ base_image_input_d = processor.image_patch_size
- extra = image_token_length_w * image_token_length_h
- joint = ((ncols + 1) // pooling_size) * ((nrows + 1) // pooling_size)
+ crop_patches = base_image_input_size[0] // base_image_input_d
- return extra + joint
+ per_row = ncols // pooling_size + 1
+ joint = per_row * (nrows // pooling_size) + 2
+ image_token_length = (crop_patches + pooling_size - 1) // pooling_size
+ resize = (image_token_length + 1) * image_token_length + 2
+
+ return resize + joint
def get_max_image_tokens(self) -> int:
target_width, target_height = self.get_image_size_with_most_features()
@@ -1291,6 +1328,7 @@ class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]):
"image", num_crops),
feat_is_patch=MultiModalFieldConfig.flat_from_sizes(
"image", num_crops),
+ embed_is_patch=MultiModalFieldConfig.batched("image"),
num_crops=MultiModalFieldConfig.batched("image"),
img_patch_id=MultiModalFieldConfig.shared("image", num_images),
)
@@ -1330,10 +1368,8 @@ class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]):
joint = ([img_start_id] + joint_row *
((nrows + 1) // pooling_size) + [img_end_id])
- return PromptUpdateDetails.select_token_id(
- extra_joint + joint,
- embed_token_id=img_patch_id,
- )
+ image_tokens = extra_joint + joint
+ return image_tokens
return [
PromptInsertion(
@@ -1439,6 +1475,11 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
raise ValueError("Incorrect type of feat_is_patch. "
f"Got type: {type(feat_is_patch)}")
+ embed_is_patch = kwargs.pop("embed_is_patch", None)
+ if not isinstance(embed_is_patch, (torch.Tensor, list)):
+ raise ValueError("Incorrect type of embed_is_patch. "
+ f"Got type: {type(embed_is_patch)}")
+
num_crops = kwargs.pop("num_crops", None)
if not isinstance(num_crops, (torch.Tensor, list)):
raise ValueError("Incorrect type of num_crops. "
@@ -1450,12 +1491,14 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
f"Got type: {type(img_patch_id)}")
self.img_patch_id = img_patch_id.flatten().unique().item()
+ embed_is_patch = flatten_bn(embed_is_patch)
num_crops = flatten_bn(num_crops, concat=True)
return MolmoImageInputs(
images=images,
image_masks=image_masks,
feat_is_patch=feat_is_patch,
+ embed_is_patch=embed_is_patch,
num_crops=num_crops,
)
@@ -1494,7 +1537,12 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
if image_input is None:
return None
- return self._process_image_input(image_input)
+ image_features = self._process_image_input(image_input)
+
+ return scatter_patch_features(
+ image_features,
+ image_input["embed_is_patch"],
+ )
def get_input_embeddings(
self,
@@ -1508,7 +1556,7 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
- multimodal_embeddings,
+ select_patch_features(multimodal_embeddings),
self.img_patch_id,
)
return inputs_embeds
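
The Molmo hunk splits the per-prompt mask into per-image spans by pairing every other occurrence of `im_start_id` and `im_end_id`, because each image contributes two wrapped blocks (`extra_joint` followed by `joint`). A toy illustration of that slicing with made-up token ids:

```python
import torch

# Illustrative ids only; these are not Molmo's real special-token ids.
PATCH, COL, START, END = 1, 2, 3, 4

# One image = extra_joint block + joint block, each wrapped in START/END.
embed_ids = torch.tensor([
    START, PATCH, PATCH, COL, END,                      # extra_joint
    START, PATCH, PATCH, COL, PATCH, PATCH, COL, END,   # joint
])
embed_is_patch = embed_ids == PATCH

# Take the first START of each pair ([::2]) and the second END of each pair
# ([1::2]) so that one span covers both blocks of a single image.
embed_start = torch.nonzero(embed_ids == START)[::2, 0]
embed_end = torch.nonzero(embed_ids == END)[1::2, 0]
per_image = [embed_is_patch[s:e + 1] for s, e in zip(embed_start, embed_end)]

assert len(per_image) == 1  # one image in this toy prompt
print(per_image[0])
```
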
diff --git a/vllm/model_executor/models/nvlm_d.py b/vllm/model_executor/models/nvlm_d.py
index 314f75c2..9d04f30c 100644
--- a/vllm/model_executor/models/nvlm_d.py
+++ b/vllm/model_executor/models/nvlm_d.py
@@ -57,7 +57,7 @@ class NVLMProcessor(BaseInternVLProcessor):
# when trying to find ""
- return PromptUpdateDetails.select_text(repl, IMG_PAD)
+ return PromptUpdateDetails(full=repl, features=repl)
class NVLMProcessingInfo(BaseInternVLProcessingInfo):
@@ -84,6 +84,31 @@ class NVLMProcessingInfo(BaseInternVLProcessingInfo):
**kwargs,
)
+ def get_max_image_tokens(self) -> int:
+ hf_processor = self.get_hf_processor()
+ tokenizer = hf_processor.tokenizer
+
+ max_num_patches = hf_processor.max_dynamic_patch
+ # we need +1 here because max_dynamic_patch in config doesn't
+ # include the thumbnail patch
+ tile_pos_identifiers = [
+ f"" for i in range(max_num_patches)
+ ]
+ if hf_processor.use_thumbnail and max_num_patches != 1:
+ tile_pos_identifiers += ["<tile_global_thumbnail>"]
+
+ # "<", "tile"]
+ # so we include in the start_str
+ start_str = "" + tile_pos_identifiers.pop(0)
+ end_str = ""
+ start_token_len = len(tokenizer.encode(start_str))
+ end_token_len = len(tokenizer.encode(end_str))
+ tile_token_len = sum(
+ len(tokenizer.encode(identifier))
+ for identifier in tile_pos_identifiers)
+ non_image_tokens_num = start_token_len + end_token_len + tile_token_len
+ return super().get_max_image_tokens() + non_image_tokens_num
+
class NVLMDummyInputsBuilder(InternVLDummyInputsBuilder[NVLMProcessingInfo]):
@@ -152,7 +177,10 @@ class NVLMMultiModalProcessor(InternVLMultiModalProcessor[NVLMProcessingInfo]):
repl = hf_processor.get_image_repl(feature_size, num_patches)
- return PromptUpdateDetails.select_text(repl.full + "\n", IMG_PAD)
+ return PromptUpdateDetails(
+ full=repl.full + "\n",
+ features=repl.features + "\n",
+ )
# See note in dummy data regarding why we have the extra newline
return [
diff --git a/vllm/model_executor/models/paligemma.py b/vllm/model_executor/models/paligemma.py
index 845f77ac..6fedb8c8 100644
--- a/vllm/model_executor/models/paligemma.py
+++ b/vllm/model_executor/models/paligemma.py
@@ -162,9 +162,9 @@ class PaliGemmaMultiModalProcessor(
modality="image",
target=PromptIndexTargets.prefix(
[bos_token_id] if tokenizer.add_bos_token else []),
- insertion=PromptUpdateDetails.select_token_id(
- image_tokens + [bos_token_id],
- embed_token_id=image_token_id,
+ insertion=PromptUpdateDetails(
+ full=image_tokens + [bos_token_id],
+ features=image_tokens,
),
)
]
diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py
index d3b0688f..d5c64989 100644
--- a/vllm/model_executor/models/phi3v.py
+++ b/vllm/model_executor/models/phi3v.py
@@ -40,7 +40,8 @@ from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, BoundPromptUpdate,
PlaceholderFeaturesInfo,
- PromptReplacement, PromptUpdate)
+ PromptReplacement, PromptUpdate,
+ PromptUpdateDetails)
# yapf: enable
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
@@ -442,7 +443,12 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]):
processor=hf_processor,
)
- return [_IMAGE_TOKEN_ID] * num_image_tokens
+ image_tokens = [_IMAGE_TOKEN_ID] * num_image_tokens
+
+ return PromptUpdateDetails(
+ full=image_tokens,
+ features=image_tokens,
+ )
num_images = mm_items.get_count("image", strict=False)
@@ -511,7 +517,6 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]):
item_idx=p.item_idx,
start_idx=p.start_idx - 1,
tokens=p.tokens,
- is_embed=p.is_embed,
) for p in ps
]
for modality, ps in placeholders.items()
diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py
index e07c6516..f8c7cc93 100644
--- a/vllm/model_executor/models/pixtral.py
+++ b/vllm/model_executor/models/pixtral.py
@@ -37,7 +37,7 @@ from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement,
- PromptUpdate, PromptUpdateDetails)
+ PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.tokenizer import (MistralTokenizer,
@@ -46,7 +46,8 @@ from vllm.transformers_utils.tokenizer import (MistralTokenizer,
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .utils import (flatten_bn, init_vllm_registered_model, maybe_prefix,
merge_multimodal_embeddings)
-from .vision import VisionEncoderInfo, resolve_visual_encoder_outputs
+from .vision import (VisionEncoderInfo, resolve_visual_encoder_outputs,
+ scatter_patch_features, select_patch_features)
try:
from xformers import ops as xops
@@ -67,6 +68,14 @@ class PixtralImagePixelInputs(TypedDict):
The result of stacking :attr:`ImageEncoding.tokens` from each prompt.
"""
+ embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
+ """
+ A boolean mask indicating which image embeddings correspond
+ to patch tokens.
+
+ Shape: `(batch_size * num_images, num_embeds)`
+ """
+
class PixtralProcessorAdapter:
"""
@@ -135,8 +144,11 @@ class PixtralProcessorAdapter:
"For more info, see: "
"https://github.com/vllm-project/vllm/issues/8411.")
+ image_token_id = self.image_token_id
+
images_processed = list[torch.Tensor]()
images_tokens = list[torch.Tensor]()
+ images_embed_is_patch = list[torch.Tensor]()
for image in images:
image_inputs = self.image_processor(ImageChunk(image=image))
@@ -145,10 +157,12 @@ class PixtralProcessorAdapter:
images_processed.append(image_processed)
images_tokens.append(image_tokens)
+ images_embed_is_patch.append(image_tokens == image_token_id)
return {
"input_ids": torch.cat(images_tokens)[None].expand(len(text), -1),
"images": images_processed,
+ "embed_is_patch": images_embed_is_patch,
}
@@ -199,7 +213,7 @@ class PixtralProcessingInfo(BaseProcessingInfo):
ncols, nrows = processor.image_processor._image_to_num_tokens(
Image.new("RGB", (image_width, image_height)))
- return ncols * nrows
+ return (ncols + 1) * nrows
def get_image_size_with_most_features(self) -> ImageSize:
image_processor = self.get_hf_processor().image_processor
@@ -249,7 +263,10 @@ class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo]
hf_inputs: Mapping[str, NestedTensors],
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
- return dict(images=MultiModalFieldConfig.batched("image"))
+ return dict(
+ images=MultiModalFieldConfig.batched("image"),
+ embed_is_patch=MultiModalFieldConfig.batched("image"),
+ )
def _get_prompt_updates(
self,
@@ -273,7 +290,7 @@ class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo]
tokens = ([image_token_id] * ncols + [image_break_id]) * nrows
tokens[-1] = image_end_id
- return PromptUpdateDetails.select_token_id(tokens, image_token_id)
+ return tokens
return [
PromptReplacement(
@@ -364,9 +381,17 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
raise ValueError("Incorrect type of images. "
f"Got type: {type(images)}")
+ embed_is_patch = kwargs.pop("embed_is_patch")
+ if not isinstance(embed_is_patch, (torch.Tensor, list)):
+ raise ValueError("Incorrect type of embed_is_patch. "
+ f"Got type: {type(embed_is_patch)}")
+
+ embed_is_patch = flatten_bn(embed_is_patch)
+
return PixtralImagePixelInputs(
type="pixel_values",
images=flatten_bn(images),
+ embed_is_patch=embed_is_patch,
)
def _process_image_input(
@@ -402,7 +427,12 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
if image_input is None:
return None
- return self._process_image_input(image_input)
+ image_features = self._process_image_input(image_input)
+
+ return scatter_patch_features(
+ image_features,
+ image_input["embed_is_patch"],
+ )
def get_input_embeddings(
self,
@@ -414,7 +444,7 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
- multimodal_embeddings,
+ select_patch_features(multimodal_embeddings),
self.vision_args.image_token_id,
)
return inputs_embeds
@@ -933,7 +963,9 @@ class PixtralHFEncoderInfo(VisionEncoderInfo[PixtralVisionConfig]):
image_width=image_width,
image_height=image_height,
)
- return ncols * nrows
+
+ # Consider the image_break_token
+ return (ncols + 1) * nrows
def get_max_image_tokens(self) -> int:
image_size = self.get_image_size()
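
The corrected Pixtral count includes the `[IMG_BREAK]`/`[IMG_END]` token that closes each row of `[IMG]` tokens, hence `(ncols + 1) * nrows`. A quick worked check with an assumed grid:

```python
# Assumed grid of 4 patch columns and 3 rows: each row holds 4 [IMG] tokens
# plus one trailing [IMG_BREAK] (the last row ends with [IMG_END] instead),
# giving (4 + 1) * 3 = 15 image-related tokens in the prompt.
ncols, nrows = 4, 3
num_image_tokens = (ncols + 1) * nrows
assert num_image_tokens == 15
```
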
diff --git a/vllm/model_executor/models/qwen2_audio.py b/vllm/model_executor/models/qwen2_audio.py
index 54220037..ccb5a3f6 100644
--- a/vllm/model_executor/models/qwen2_audio.py
+++ b/vllm/model_executor/models/qwen2_audio.py
@@ -229,9 +229,9 @@ class Qwen2AudioMultiModalProcessor(
audio_tokens = [audio_token_id] * num_features
- return PromptUpdateDetails.select_token_id(
- [audio_bos_id] + audio_tokens + [audio_eos_id],
- embed_token_id=audio_token_id,
+ return PromptUpdateDetails(
+ full=[audio_bos_id] + audio_tokens + [audio_eos_id],
+ features=audio_tokens,
)
return [
diff --git a/vllm/model_executor/models/qwen_vl.py b/vllm/model_executor/models/qwen_vl.py
index a2ec9a9a..4e9d02ae 100644
--- a/vllm/model_executor/models/qwen_vl.py
+++ b/vllm/model_executor/models/qwen_vl.py
@@ -647,9 +647,9 @@ class QwenVLMultiModalProcessor(BaseMultiModalProcessor[QwenVLProcessingInfo]):
PromptReplacement(
modality="image",
target=[img_start_id, img_end_id],
- replacement=PromptUpdateDetails.select_token_id(
- [img_start_id] + image_tokens + [img_end_id],
- embed_token_id=img_pad_id,
+ replacement=PromptUpdateDetails(
+ full=[img_start_id] + image_tokens + [img_end_id],
+ features=image_tokens,
),
)
]
diff --git a/vllm/model_executor/models/skyworkr1v.py b/vllm/model_executor/models/skyworkr1v.py
index e3deae82..ac5de0e3 100644
--- a/vllm/model_executor/models/skyworkr1v.py
+++ b/vllm/model_executor/models/skyworkr1v.py
@@ -40,6 +40,7 @@ from vllm.transformers_utils.tokenizer import AnyTokenizer
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings)
+from .vision import scatter_patch_features, select_patch_features
IMG_START = '<img>'
IMG_END = '</img>'
IMG_CONTEXT = '<IMG_CONTEXT>'
@@ -60,6 +61,14 @@ class SkyworkR1VImagePixelInputs(TypedDict):
num_patches: torch.Tensor
"""Shape: `(batch_size * num_images)`"""
+ embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
+ """
+ A boolean mask indicating which image embeddings correspond
+ to patch tokens.
+
+ Shape: `(batch_size * num_images, num_embeds)`
+ """
+
class SkyworkR1VImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"]
@@ -410,13 +419,24 @@ class BaseSkyworkR1VProcessor(ABC):
torch.tensor([len(item) for item in pixel_values_lst]),
}
+ tokenizer = self.tokenizer
+ image_token_id = self.image_token_id
+
+ embed_is_patch = list[torch.Tensor]()
+
for pixel_values in pixel_values_lst:
num_patches = pixel_values.shape[0]
feature_size = num_patches * self.num_image_token
image_repl = self.get_image_repl(feature_size, num_patches)
+ feature_tokens = tokenizer.encode(image_repl.features,
+ add_special_tokens=False)
text = [t.replace('<image>', image_repl.full, 1) for t in text]
+ embed_is_patch.append(
+ torch.tensor(feature_tokens) == image_token_id)
+
+ image_inputs["embed_is_patch"] = embed_is_patch
text_inputs = self.tokenizer(text)
@@ -440,7 +460,7 @@ class SkyworkR1VProcessor(BaseSkyworkR1VProcessor):
repl_features = IMG_CONTEXT * feature_size
repl_full = IMG_START + repl_features + IMG_END
- return PromptUpdateDetails.select_text(repl_full, IMG_CONTEXT)
+ return PromptUpdateDetails(full=repl_full, features=repl_features)
class BaseSkyworkR1VProcessingInfo(BaseProcessingInfo):
@@ -579,6 +599,7 @@ class SkyworkR1VMultiModalProcessor(BaseMultiModalProcessor[_I]):
pixel_values_flat=MultiModalFieldConfig.flat_from_sizes(
"image", image_num_patches),
image_num_patches=MultiModalFieldConfig.batched("image"),
+ embed_is_patch=MultiModalFieldConfig.batched("image"),
image_embeds=MultiModalFieldConfig.batched("image"),
image_token_id=MultiModalFieldConfig.shared("image", num_images),
)
@@ -814,6 +835,7 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
self, **kwargs: object) -> Optional[SkyworkR1VImageInputs]:
pixel_values_flat = kwargs.pop("pixel_values_flat", None)
image_num_patches = kwargs.pop("image_num_patches", None)
+ embed_is_patch = kwargs.pop("embed_is_patch", None)
image_embeds = kwargs.pop("image_embeds", None)
if pixel_values_flat is None and image_embeds is None:
@@ -842,14 +864,20 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
raise ValueError("Incorrect type of image_num_patches. "
f"Got type: {type(image_num_patches)}")
+ if not isinstance(embed_is_patch, (torch.Tensor, list)):
+ raise ValueError("Incorrect type of embed_is_patch. "
+ f"Got type: {type(embed_is_patch)}")
+
pixel_values_flat = flatten_bn(pixel_values_flat, concat=True)
image_num_patches = flatten_bn(image_num_patches, concat=True)
+ embed_is_patch = flatten_bn(embed_is_patch)
return SkyworkR1VImagePixelInputs(
type="pixel_values",
pixel_values_flat=self._validate_pixel_values(
pixel_values_flat),
num_patches=image_num_patches,
+ embed_is_patch=embed_is_patch,
)
raise AssertionError("This line should be unreachable.")
@@ -895,7 +923,15 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
if image_input is None:
return None
- return self._process_image_input(image_input)
+ image_features = self._process_image_input(image_input)
+
+ if image_input["type"] != "pixel_values":
+ return image_features
+
+ return scatter_patch_features(
+ image_features,
+ image_input["embed_is_patch"],
+ )
def get_input_embeddings(
self,
@@ -909,7 +945,7 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
- multimodal_embeddings,
+ select_patch_features(multimodal_embeddings),
self.img_context_token_id,
)
return inputs_embeds
diff --git a/vllm/model_executor/models/vision.py b/vllm/model_executor/models/vision.py
index 347f5149..9e00da68 100644
--- a/vllm/model_executor/models/vision.py
+++ b/vllm/model_executor/models/vision.py
@@ -1,7 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
from abc import ABC, abstractmethod
-from typing import Final, Generic, Optional, Protocol, TypeVar, Union
+from collections.abc import Sequence
+from typing import Final, Generic, Optional, Protocol, TypeVar, Union, cast
import torch
from transformers import PretrainedConfig
@@ -9,9 +10,12 @@ from transformers import PretrainedConfig
import vllm.envs as envs
from vllm.attention.selector import (backend_name_to_enum,
get_global_forced_attn_backend)
+from vllm.jsontree import JSONTree, json_map_leaves
from vllm.logger import init_logger
from vllm.platforms import _Backend, current_platform
+from .interfaces import MultiModalEmbeddings
+
logger = init_logger(__name__)
_C = TypeVar("_C", bound=PretrainedConfig)
@@ -151,3 +155,74 @@ def resolve_visual_encoder_outputs(
if post_layer_norm is not None and uses_last_layer:
hs_pool[-1] = post_layer_norm(encoder_outputs)
return torch.cat(hs_pool, dim=-1)
+
+
+def scatter_patch_features(
+ patches: Union[torch.Tensor, Sequence[torch.Tensor]],
+ embed_is_patch: Union[torch.Tensor, Sequence[torch.Tensor]],
+) -> tuple[torch.Tensor, ...]:
+ """
+ Scatter the patch features into a contiguous tensor that corresponds
+ to the embedding tokens defined by the multimodal processor.
+
+ The rest of the values in the tensor are set to NaN so that they
+ can be filtered out by :func:`select_patch_features`.
+
+ Args:
+ patches: The patch features for each image.
+ Shape: `(num_images, num_patches, feature_depth)`
+ embed_is_patch: A boolean mask indicating which image embeddings
+ correspond to patch tokens for each image.
+ Shape: `(num_images, num_embeds)`
+
+ Note:
+ The original code only considers patch tokens as feature
+ tokens, but our processor considers all image-related tokens
+ as feature tokens because the feature tokens need to be
+ consecutive in `input_ids`.
+
+ Example:
+ A simplified example for one image:
+
+ .. code-block::
+
+ Embedding tokens (from HF processor):
+ [ <start> <patch> <patch> <col> <patch> <patch> <col> <end> ]
+
+ embed_is_patch (from HF processor):
+ [ False True True False True True False False ]
+
+ Encoder outputs (from model):
+ [ p1 p2 p3 p4 ]
+
+ The resulting embedding tensor is:
+ [ nan p1 p2 nan p3 p4 nan nan ]
+ """
+ if len(patches) != len(embed_is_patch):
+ raise ValueError(f"Inconsistent num_images: {len(patches)=} vs. "
+ f"{len(embed_is_patch)=}")
+
+ def get_embed_one(patches_one: torch.Tensor, e_is_patch: torch.Tensor):
+ embed_one = patches_one.new_full(
+ (e_is_patch.shape[0], patches_one.shape[-1]),
+ fill_value=torch.nan,
+ )
+ embed_one[e_is_patch] = patches_one
+ return embed_one
+
+ return tuple(
+ get_embed_one(patches_one, e_is_patch)
+ for patches_one, e_is_patch in zip(patches, embed_is_patch))
+
+
+def select_patch_features(
+ multimodal_embeddings: MultiModalEmbeddings) -> MultiModalEmbeddings:
+ """
+ Given the outputs of :func:`scatter_patch_features`, return only
+ the values that correspond to patch features.
+ """
+ selected_features = json_map_leaves(
+ lambda x: x[~x.isnan()].view(-1, *x.shape[1:]),
+ cast(JSONTree[torch.Tensor], multimodal_embeddings),
+ )
+ return cast(MultiModalEmbeddings, selected_features)
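
The two helpers round-trip through a NaN sentinel: `scatter_patch_features` pads the patch features out to the full placeholder length, and `select_patch_features` drops the NaN rows again. A self-contained check with toy shapes, mirroring the docstring example:

```python
import torch

embed_is_patch = torch.tensor(
    [False, True, True, False, True, True, False, False])
patches = torch.arange(8, dtype=torch.float32).view(4, 2)  # 4 patches, dim 2

# What scatter_patch_features does for one image: NaN-fill, then scatter.
scattered = patches.new_full(
    (embed_is_patch.shape[0], patches.shape[-1]), fill_value=torch.nan)
scattered[embed_is_patch] = patches

# What select_patch_features undoes: keep only the non-NaN rows.
recovered = scattered[~scattered.isnan()].view(-1, patches.shape[-1])
assert torch.equal(recovered, patches)
```
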
diff --git a/vllm/multimodal/base.py b/vllm/multimodal/base.py
index ad95b982..5159b0bc 100644
--- a/vllm/multimodal/base.py
+++ b/vllm/multimodal/base.py
@@ -385,8 +385,8 @@ class MultiModalPlaceholderMap:
for placeholder_dict, mm_item in zip(multi_modal_placeholders,
multi_modal_items):
placeholder = range(
- placeholder_dict.offset,
- placeholder_dict.offset + placeholder_dict.length,
+ placeholder_dict["offset"],
+ placeholder_dict["offset"] + placeholder_dict["length"],
)
intersection = range(
max(positions.start, placeholder.start),
diff --git a/vllm/multimodal/inputs.py b/vllm/multimodal/inputs.py
index 53729799..81d72ff1 100644
--- a/vllm/multimodal/inputs.py
+++ b/vllm/multimodal/inputs.py
@@ -109,8 +109,7 @@ The built-in modalities are defined by :class:`MultiModalDataBuiltins`.
"""
-@dataclass(frozen=True)
-class PlaceholderRange:
+class PlaceholderRange(TypedDict):
"""
Placeholder location information for multi-modal data.
@@ -122,8 +121,8 @@ class PlaceholderRange:
.. code-block::
- A: PlaceholderRange(offset=0, length=4)
- B: PlaceholderRange(offset=5, length=4)
+ A: { "offset": 0, "length": 4 }
+ B: { "offset": 5, "length": 4 }
"""
offset: int
@@ -132,31 +131,6 @@ class PlaceholderRange:
length: int
"""The length of the placeholder."""
- is_embed: Optional[torch.Tensor] = None
- """
- A boolean mask of shape `(length,)` indicating which positions
- between `offset` and `offset + length` to assign embeddings to.
- """
-
- def get_num_embeds(self) -> int:
- if self.is_embed is None:
- return self.length
-
- return int(self.is_embed.sum().item())
-
- def __eq__(self, other: object) -> bool:
- if not isinstance(other, self.__class__):
- return False
- if not (self.offset, self.length) == (other.offset, other.length):
- return False
-
- if self.is_embed is None:
- return other.is_embed is None
- if other.is_embed is None:
- return self.is_embed is None
-
- return nested_tensors_equal(self.is_embed, other.is_embed)
-
NestedTensors = Union[list["NestedTensors"], list[torch.Tensor], torch.Tensor,
tuple[torch.Tensor, ...]]
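
With `PlaceholderRange` now a `TypedDict`, downstream code reads the fields by key rather than by attribute, which is exactly what the remaining hunks below switch to. A tiny sketch (values are illustrative):

```python
from typing import TypedDict

class PlaceholderRange(TypedDict):  # mirrors the definition above
    offset: int
    length: int

pos: PlaceholderRange = {"offset": 5, "length": 4}

# Attribute access (pos.offset) no longer applies; use subscripts instead.
span = range(pos["offset"], pos["offset"] + pos["length"])
print(list(span))  # [5, 6, 7, 8]
```
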
diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py
index a37d2975..c8864c33 100644
--- a/vllm/multimodal/processing.py
+++ b/vllm/multimodal/processing.py
@@ -108,46 +108,16 @@ class PromptUpdateDetails(Generic[_S]):
full: _S
"""The full content."""
- is_embed: Optional[Callable[["_BoundPromptSequence"], torch.Tensor]] = None
+ features: _S
"""
- Given :attr:`full`, return a boolean mask of shape `(len(full),)`
- indicating which positions of `full` to assign embeddings to.
-
- `None` (default) means to assign embeddings to all positions of `full`.
-
- The embeddings are obtained by calling
- :class:`SupportsMultiModal.get_multimodal_embeddings`.
+ The part of the content that corresponds to feature placeholders;
+ this will be replaced by the output of the vision encoder during model
+ inference.
"""
@staticmethod
def from_seq(seq: _S) -> "PromptUpdateDetails[_S]":
- return PromptUpdateDetails(full=seq)
-
- @staticmethod
- def select_text(
- seq: _S,
- embed_text: str,
- ) -> "PromptUpdateDetails[_S]":
-
- def is_embed(full: "_BoundPromptSequence") -> torch.Tensor:
- embed_token_ids = encode_tokens(full.tokenizer, embed_text)
-
- return torch.isin(
- torch.tensor(full.token_ids),
- torch.tensor(embed_token_ids),
- )
-
- return PromptUpdateDetails(full=seq, is_embed=is_embed)
-
- @staticmethod
- def select_token_id(
- seq: _S,
- embed_token_id: int,
- ) -> "PromptUpdateDetails[_S]":
- return PromptUpdateDetails(
- full=seq,
- is_embed=lambda f: torch.tensor(f.token_ids) == embed_token_id,
- )
+ return PromptUpdateDetails(full=seq, features=seq)
PromptUpdateInfo = Union[PromptSeq, PromptUpdateDetails]
@@ -436,7 +406,7 @@ class _BoundPromptSequence:
@dataclass
class _BoundPromptContent:
full: _BoundPromptSequence
- is_embed: Optional[Callable[["_BoundPromptSequence"], torch.Tensor]]
+ features: _BoundPromptSequence
@dataclass
@@ -496,8 +466,10 @@ class BoundPromptUpdate:
bound_full = _BoundPromptSequence.from_seq(self.tokenizer,
content.full)
+ bound_features = _BoundPromptSequence.from_seq(self.tokenizer,
+ content.features)
bound_content = _BoundPromptContent(full=bound_full,
- is_embed=content.is_embed)
+ features=bound_features)
if cache_key is not None:
self._content_cache[cache_key] = bound_content
@@ -633,19 +605,15 @@ class PlaceholderFeaturesInfo:
item_idx: int
start_idx: int
tokens: list[int]
- is_embed: Optional[torch.Tensor]
@property
def length(self) -> int:
return len(self.tokens)
def to_range(self) -> PlaceholderRange:
- # TODO: Is it worth it to optimize this by stripping the
- # leading and ending positions where `is_embed=False`?
return PlaceholderRange(
offset=self.start_idx,
length=self.length,
- is_embed=self.is_embed,
)
@@ -838,17 +806,22 @@ def _iter_placeholders(
continue
if prompt[start_idx:end_idx_full] == content_tokens_full:
- content_is_embed = content.is_embed
- if content_is_embed is not None:
- content_is_embed = content_is_embed(content.full)
+ content_tokens_feat = content.features.token_ids
- yield PlaceholderFeaturesInfo(
- modality=modality,
- item_idx=item_idx,
- start_idx=start_idx,
- tokens=content_tokens_full,
- is_embed=content_is_embed,
- )
+ try:
+ match = next(
+ iter_token_matches(content_tokens_full,
+ content_tokens_feat))
+ yield PlaceholderFeaturesInfo(
+ modality=modality,
+ item_idx=item_idx,
+ start_idx=start_idx + match.start_idx,
+ tokens=content_tokens_feat,
+ )
+ except StopIteration:
+ raise AssertionError(
+ f"{content_tokens_feat=} should be a "
+ f"subsequence of {content_tokens_full=}") from None
# Exclude overlapping matches
start_idx = end_idx_full
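
With `features` replacing `is_embed`, `_iter_placeholders` locates the feature tokens as a contiguous subsequence of `full` (via `iter_token_matches`) and offsets the placeholder start accordingly. A simplified stand-in that shows the invariant; the class and token ids here are illustrative, not the real API:

```python
from dataclasses import dataclass

@dataclass
class PromptUpdateDetailsSketch:  # simplified stand-in, not the vLLM class
    full: list[int]
    features: list[int]

# Illustrative token ids in the spirit of the Qwen2-Audio hunk above.
AUDIO_BOS, AUDIO_TOKEN, AUDIO_EOS = 3, 151646, 4
audio_tokens = [AUDIO_TOKEN] * 5
details = PromptUpdateDetailsSketch(
    full=[AUDIO_BOS] + audio_tokens + [AUDIO_EOS],  # inserted into the prompt
    features=audio_tokens,                          # replaced by encoder output
)

# The processor finds `features` inside `full` to shift the placeholder start;
# here the feature span begins one token after the bos marker.
start = next(i for i in range(len(details.full))
             if details.full[i:i + len(details.features)] == details.features)
assert start == 1
```
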
diff --git a/vllm/multimodal/profiling.py b/vllm/multimodal/profiling.py
index 4616e4e9..1df9a1f5 100644
--- a/vllm/multimodal/profiling.py
+++ b/vllm/multimodal/profiling.py
@@ -180,7 +180,7 @@ class MultiModalProfiler(Generic[_I]):
placeholders_by_modality = mm_inputs["mm_placeholders"]
total_placeholders_by_modality = {
- modality: sum(item.get_num_embeds() for item in placeholders)
+ modality: sum(item["length"] for item in placeholders)
for modality, placeholders in placeholders_by_modality.items()
}
expected_placeholders_by_modality = {
diff --git a/vllm/multimodal/utils.py b/vllm/multimodal/utils.py
index 77c83f0c..fc0fb892 100644
--- a/vllm/multimodal/utils.py
+++ b/vllm/multimodal/utils.py
@@ -340,7 +340,7 @@ def merge_and_sort_multimodal_metadata(
all_items.append((modality, placeholder, hash_value))
# Sort all items by offset
- all_items.sort(key=lambda x: x[1].offset)
+ all_items.sort(key=lambda x: x[1]['offset'])
# Split into separate lists
sorted_modalities = [item[0] for item in all_items]
diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py
index afcf7e34..34bc9369 100644
--- a/vllm/v1/core/kv_cache_utils.py
+++ b/vllm/v1/core/kv_cache_utils.py
@@ -310,7 +310,8 @@ def _gen_mm_extra_hash_keys(request: Request, start_token_idx: int,
# Note that we assume mm_positions is sorted by offset.
# We do not need to check all mm inputs if the start token index is out of
# range. This usually happens in the late prefill phase and decoding phase.
- if mm_positions[-1].offset + mm_positions[-1].length < start_token_idx:
+ if mm_positions[-1]["offset"] + mm_positions[-1][
+ "length"] < start_token_idx:
return extra_keys, start_mm_idx
# Support start_mm_idx == -1 to indicate the last mm input.
@@ -321,8 +322,8 @@ def _gen_mm_extra_hash_keys(request: Request, start_token_idx: int,
curr_mm_idx = start_mm_idx
while mm_positions and curr_mm_idx < len(mm_positions):
assert mm_hashes[curr_mm_idx] is not None
- offset = mm_positions[curr_mm_idx].offset
- length = mm_positions[curr_mm_idx].length
+ offset = mm_positions[curr_mm_idx]["offset"]
+ length = mm_positions[curr_mm_idx]["length"]
if end_token_idx > offset:
if start_token_idx > offset + length:
# This block has passed the current mm input.
diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py
index b3905987..81f8ad25 100644
--- a/vllm/v1/core/sched/scheduler.py
+++ b/vllm/v1/core/sched/scheduler.py
@@ -505,8 +505,8 @@ class Scheduler(SchedulerInterface):
assert mm_positions is not None
assert len(mm_positions) > 0
for i, pos_info in enumerate(mm_positions):
- start_pos = pos_info.offset
- num_encoder_tokens = pos_info.length
+ start_pos = pos_info["offset"]
+ num_encoder_tokens = pos_info["length"]
# The encoder output is needed if the two ranges overlap:
# [num_computed_tokens, num_computed_tokens + num_new_tokens) and
@@ -596,8 +596,8 @@ class Scheduler(SchedulerInterface):
if cached_encoder_input_ids:
for input_id in list(cached_encoder_input_ids):
mm_positions = request.mm_positions[input_id]
- start_pos = mm_positions.offset
- num_tokens = mm_positions.length
+ start_pos = mm_positions["offset"]
+ num_tokens = mm_positions["length"]
if start_pos + num_tokens <= request.num_computed_tokens:
# The encoder output is already processed and stored
# in the decoder's KV cache.
diff --git a/vllm/v1/request.py b/vllm/v1/request.py
index daf59fd7..490fe4e8 100644
--- a/vllm/v1/request.py
+++ b/vllm/v1/request.py
@@ -121,7 +121,7 @@ class Request:
def get_num_encoder_tokens(self, input_id: int) -> int:
assert input_id < len(self.mm_positions)
- num_tokens = self.mm_positions[input_id].length
+ num_tokens = self.mm_positions[input_id]["length"]
return num_tokens
@property
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index aba71845..51380633 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -19,8 +19,7 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
from vllm.model_executor.model_loader import get_model
-from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
+from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.multimodal.utils import group_mm_inputs_by_modality
from vllm.sampling_params import SamplingType
from vllm.sequence import IntermediateTensors
@@ -44,8 +43,7 @@ from vllm.v1.utils import bind_kv_cache
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
-from .utils import (gather_mm_placeholders, sanity_check_mm_encoder_outputs,
- scatter_mm_placeholders)
+from .utils import sanity_check_mm_encoder_outputs
if TYPE_CHECKING:
import xgrammar as xgr
@@ -831,22 +829,19 @@ class GPUModelRunner(LoRAModelRunnerMixin):
)
return metadata
- def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"):
+ def _execute_encoder(self, scheduler_output: "SchedulerOutput"):
scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs
if not scheduled_encoder_inputs:
return
# Batch the multi-modal inputs.
- mm_inputs = list[MultiModalKwargs]()
- req_ids_pos = list[tuple[str, int, PlaceholderRange]]()
+ mm_inputs: list[MultiModalKwargs] = []
+ req_input_ids: list[tuple[str, int]] = []
for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
req_state = self.requests[req_id]
- for input_id, pos_info in zip(
- encoder_input_ids,
- req_state.mm_positions,
- ):
+ for input_id in encoder_input_ids:
mm_inputs.append(req_state.mm_inputs[input_id])
- req_ids_pos.append((req_id, input_id, pos_info))
+ req_input_ids.append((req_id, input_id))
# Batch mm inputs as much as we can: if a request in the batch has
# multiple modalities or a different modality than the previous one,
@@ -882,23 +877,16 @@ class GPUModelRunner(LoRAModelRunnerMixin):
encoder_outputs.append(output)
# Cache the encoder outputs.
- for (req_id, input_id, pos_info), output in zip(
- req_ids_pos,
- encoder_outputs,
- ):
+ for (req_id, input_id), output in zip(req_input_ids, encoder_outputs):
if req_id not in self.encoder_cache:
self.encoder_cache[req_id] = {}
+ self.encoder_cache[req_id][input_id] = output
- self.encoder_cache[req_id][input_id] = scatter_mm_placeholders(
- output,
- is_embed=pos_info.is_embed,
- )
-
- def _gather_mm_embeddings(
+ def _gather_encoder_outputs(
self,
scheduler_output: "SchedulerOutput",
) -> list[torch.Tensor]:
- mm_embeds: list[torch.Tensor] = []
+ encoder_outputs: list[torch.Tensor] = []
for req_id in self.input_batch.req_ids:
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
req_id]
@@ -906,8 +894,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
num_computed_tokens = req_state.num_computed_tokens
mm_positions = req_state.mm_positions
for i, pos_info in enumerate(mm_positions):
- start_pos = pos_info.offset
- num_encoder_tokens = pos_info.length
+ start_pos = pos_info["offset"]
+ num_encoder_tokens = pos_info["length"]
# The encoder output is needed if the two ranges overlap:
# [num_computed_tokens,
@@ -929,16 +917,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
assert req_id in self.encoder_cache
assert i in self.encoder_cache[req_id]
encoder_output = self.encoder_cache[req_id][i]
-
- if (is_embed := pos_info.is_embed) is not None:
- is_embed = is_embed[start_idx:end_idx]
-
- mm_embeds_item = gather_mm_placeholders(
- encoder_output[start_idx:end_idx],
- is_embed=is_embed,
- )
- mm_embeds.append(mm_embeds_item)
- return mm_embeds
+ encoder_outputs.append(encoder_output[start_idx:end_idx])
+ return encoder_outputs
def get_model(self) -> nn.Module:
return self.model
@@ -1003,10 +983,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if self.is_multimodal_model:
# Run the multimodal encoder if any.
- self._execute_mm_encoder(scheduler_output)
- mm_embeds = self._gather_mm_embeddings(scheduler_output)
+ self._execute_encoder(scheduler_output)
+ encoder_outputs = self._gather_encoder_outputs(scheduler_output)
else:
- mm_embeds = []
+ encoder_outputs = []
# Prepare the decoder inputs.
attn_metadata, logits_indices, spec_decode_metadata = (
@@ -1028,9 +1008,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# embeddings), we always use embeddings (rather than token ids)
# as input to the multimodal model, even when the input is text.
input_ids = self.input_ids[:num_scheduled_tokens]
- if mm_embeds:
+ if encoder_outputs:
inputs_embeds = self.model.get_input_embeddings(
- input_ids, mm_embeds)
+ input_ids, encoder_outputs)
else:
inputs_embeds = self.model.get_input_embeddings(input_ids)
# TODO(woosuk): Avoid the copy. Optimize.
diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py
index 488912fb..0668e716 100644
--- a/vllm/v1/worker/tpu_model_runner.py
+++ b/vllm/v1/worker/tpu_model_runner.py
@@ -19,8 +19,7 @@ from vllm.config import VllmConfig
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model
-from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
+from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.multimodal.utils import group_mm_inputs_by_modality
from vllm.sampling_params import SamplingType
from vllm.sequence import IntermediateTensors
@@ -37,8 +36,7 @@ from vllm.v1.sample.tpu.sampler import Sampler as TPUSampler
from vllm.v1.utils import bind_kv_cache
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
-from .utils import (gather_mm_placeholders, sanity_check_mm_encoder_outputs,
- scatter_mm_placeholders)
+from .utils import sanity_check_mm_encoder_outputs
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
@@ -509,47 +507,19 @@ class TPUModelRunner:
logits_indices = logits_indices.to(self.device)
return attn_metadata, logits_indices
- def _scatter_placeholders(
- self,
- embeds: torch.Tensor,
- is_embed: Optional[torch.Tensor],
- ) -> torch.Tensor:
- if is_embed is None:
- return embeds
-
- placeholders = embeds.new_full(
- (is_embed.shape[0], embeds.shape[-1]),
- fill_value=torch.nan,
- )
- placeholders[is_embed] = embeds
- return placeholders
-
- def _gather_placeholders(
- self,
- placeholders: torch.Tensor,
- is_embed: Optional[torch.Tensor],
- ) -> torch.Tensor:
- if is_embed is None:
- return placeholders
-
- return placeholders[is_embed]
-
- def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"):
+ def _execute_encoder(self, scheduler_output: "SchedulerOutput"):
scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs
if not scheduled_encoder_inputs:
return
# Batch the multi-modal inputs.
- mm_inputs = list[MultiModalKwargs]()
- req_ids_pos = list[tuple[str, int, PlaceholderRange]]()
+ mm_inputs: list[MultiModalKwargs] = []
+ req_input_ids: list[tuple[str, int]] = []
for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
req_state = self.requests[req_id]
- for input_id, pos_info in zip(
- encoder_input_ids,
- req_state.mm_positions,
- ):
+ for input_id in encoder_input_ids:
mm_inputs.append(req_state.mm_inputs[input_id])
- req_ids_pos.append((req_id, input_id, pos_info))
+ req_input_ids.append((req_id, input_id))
# Batch mm inputs as much as we can: if a request in the batch has
# multiple modalities or a different modality than the previous one,
@@ -585,23 +555,16 @@ class TPUModelRunner:
encoder_outputs.append(output)
# Cache the encoder outputs.
- for (req_id, input_id, pos_info), output in zip(
- req_ids_pos,
- encoder_outputs,
- ):
+ for (req_id, input_id), output in zip(req_input_ids, encoder_outputs):
if req_id not in self.encoder_cache:
self.encoder_cache[req_id] = {}
+ self.encoder_cache[req_id][input_id] = output
- self.encoder_cache[req_id][input_id] = scatter_mm_placeholders(
- output,
- is_embed=pos_info.is_embed,
- )
-
- def _gather_mm_embeddings(
+ def _gather_encoder_outputs(
self,
scheduler_output: "SchedulerOutput",
) -> list[torch.Tensor]:
- mm_embeds: list[torch.Tensor] = []
+ encoder_outputs: list[torch.Tensor] = []
for req_id in self.input_batch.req_ids:
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
req_id]
@@ -609,8 +572,8 @@ class TPUModelRunner:
num_computed_tokens = req_state.num_computed_tokens
mm_positions = req_state.mm_positions
for i, pos_info in enumerate(mm_positions):
- start_pos = pos_info.offset
- num_encoder_tokens = pos_info.length
+ start_pos = pos_info["offset"]
+ num_encoder_tokens = pos_info["length"]
# The encoder output is needed if the two ranges overlap:
# [num_computed_tokens,
@@ -632,16 +595,8 @@ class TPUModelRunner:
assert req_id in self.encoder_cache
assert i in self.encoder_cache[req_id]
encoder_output = self.encoder_cache[req_id][i]
-
- if (is_embed := pos_info.is_embed) is not None:
- is_embed = is_embed[start_idx:end_idx]
-
- mm_embeds_item = gather_mm_placeholders(
- encoder_output[start_idx:end_idx],
- is_embed=is_embed,
- )
- mm_embeds.append(mm_embeds_item)
- return mm_embeds
+ encoder_outputs.append(encoder_output[start_idx:end_idx])
+ return encoder_outputs
@torch.no_grad()
def execute_model(
@@ -657,10 +612,10 @@ class TPUModelRunner:
if self.is_multimodal_model:
# Run the multimodal encoder if any.
- self._execute_mm_encoder(scheduler_output)
- mm_embeds = self._gather_mm_embeddings(scheduler_output)
+ self._execute_encoder(scheduler_output)
+ encoder_outputs = self._gather_encoder_outputs(scheduler_output)
else:
- mm_embeds = []
+ encoder_outputs = []
# Prepare inputs
attn_metadata, logits_indices = self._prepare_inputs(scheduler_output)
@@ -668,9 +623,9 @@ class TPUModelRunner:
# NOTE(woosuk): To unify token ids and soft tokens (vision
# embeddings), we always use embeddings (rather than token ids)
# as input to the multimodal model, even when the input is text.
- if mm_embeds:
+ if encoder_outputs:
inputs_embeds = self.model.get_input_embeddings(
- self.input_ids, mm_embeds)
+ self.input_ids, encoder_outputs)
else:
inputs_embeds = self.model.get_input_embeddings(self.input_ids)
input_ids = None
diff --git a/vllm/v1/worker/utils.py b/vllm/v1/worker/utils.py
index e46ca0c9..b1d3aa7c 100644
--- a/vllm/v1/worker/utils.py
+++ b/vllm/v1/worker/utils.py
@@ -1,6 +1,4 @@
# SPDX-License-Identifier: Apache-2.0
-from typing import Optional
-
import torch
@@ -29,46 +27,3 @@ def sanity_check_mm_encoder_outputs(
f"but got tensors with shapes {[e.shape for e in mm_embeddings]} "
"instead. This is most likely due to incorrect implementation "
"of the model's `get_multimodal_embeddings` method.")
-
-
-def scatter_mm_placeholders(
- embeds: torch.Tensor,
- is_embed: Optional[torch.Tensor],
-) -> torch.Tensor:
- """
- Scatter the multimodal embeddings into a contiguous tensor that represents
- the placeholder tokens.
-
- :class:`vllm.multimodal.processing.PromptUpdateDetails.is_embed`.
-
- Args:
- embeds: The multimodal embeddings.
- Shape: `(num_embeds, embed_dim)`
- is_embed: A boolean mask indicating which positions in the placeholder
- tokens need to be filled with multimodal embeddings.
- Shape: `(num_placeholders, num_embeds)`
- """
- if is_embed is None:
- return embeds
-
- placeholders = embeds.new_full(
- (is_embed.shape[0], embeds.shape[-1]),
- fill_value=torch.nan,
- )
- placeholders[is_embed] = embeds
- return placeholders
-
-
-def gather_mm_placeholders(
- placeholders: torch.Tensor,
- is_embed: Optional[torch.Tensor],
-) -> torch.Tensor:
- """
- Reconstructs the embeddings from the placeholder tokens.
-
- This is the inverse operation of :func:`scatter_mm_placeholders`.
- """
- if is_embed is None:
- return placeholders
-
- return placeholders[is_embed]