[V1] Scatter and gather placeholders in the model runner (#15712)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk> Signed-off-by: mgoin <mgoin64@gmail.com> Signed-off-by: Roger Wang <ywang@roblox.com> Co-authored-by: mgoin <mgoin64@gmail.com> Co-authored-by: Roger Wang <ywang@roblox.com>
This commit is contained in:
parent
651cf0fec1
commit
f5722a5052
@ -860,8 +860,8 @@ prompt_tokens, prompts_length = _tokenize_prompts_with_image_and_batch(
|
||||
)
|
||||
```
|
||||
|
||||
To accommodate this, instead of a string you can return an instance of {class}`~vllm.multimodal.processing.PromptUpdateDetails`
|
||||
with different `full` and `feature` attributes:
|
||||
To assign the vision embeddings to only the image tokens, instead of a string
|
||||
you can return an instance of {class}`~vllm.multimodal.processing.PromptUpdateDetails`:
|
||||
|
||||
```python
|
||||
hf_config = self.info.get_hf_config()
|
||||
@ -879,9 +879,9 @@ def get_replacement_fuyu(item_idx: int):
|
||||
image_tokens = ([_IMAGE_TOKEN_ID] * ncols +
|
||||
[_NEWLINE_TOKEN_ID]) * nrows
|
||||
|
||||
return PromptUpdateDetails(
|
||||
full=image_tokens + [bos_token_id],
|
||||
features=image_tokens,
|
||||
return PromptUpdateDetails.select_token_id(
|
||||
image_tokens + [bos_token_id],
|
||||
embed_token_id=_IMAGE_TOKEN_ID,
|
||||
)
|
||||
```
|
||||
|
||||
@ -914,9 +914,9 @@ def _get_prompt_updates(
|
||||
image_tokens = ([_IMAGE_TOKEN_ID] * ncols +
|
||||
[_NEWLINE_TOKEN_ID]) * nrows
|
||||
|
||||
return PromptUpdateDetails(
|
||||
full=image_tokens + [bos_token_id],
|
||||
features=image_tokens,
|
||||
return PromptUpdateDetails.select_token_id(
|
||||
image_tokens + [bos_token_id],
|
||||
embed_token_id=_IMAGE_TOKEN_ID,
|
||||
)
|
||||
|
||||
return [
|
||||
|
@ -989,9 +989,6 @@ See [this page](#generative-models) for more information on how to use generativ
|
||||
<sup>+</sup> Multiple items can be inputted per text prompt for this modality.
|
||||
|
||||
:::{important}
|
||||
To use Gemma3 series models, you have to install Hugging Face Transformers library from source via
|
||||
`pip install git+https://github.com/huggingface/transformers`.
|
||||
|
||||
Pan-and-scan image pre-processing is currently supported on V0 (but not V1).
|
||||
You can enable it by passing `--mm-processor-kwargs '{"do_pan_and_scan": True}'`.
|
||||
:::
|
||||
|
@ -47,7 +47,7 @@ def run_minicpmo(question: str, audio_count: int) -> ModelRequestData:
|
||||
model=model_name,
|
||||
trust_remote_code=True,
|
||||
max_model_len=4096,
|
||||
max_num_seqs=5,
|
||||
max_num_seqs=2,
|
||||
limit_mm_per_prompt={"audio": audio_count},
|
||||
)
|
||||
|
||||
|
@ -55,7 +55,10 @@ def server(request, audio_assets):
|
||||
for key, value in request.param.items()
|
||||
]
|
||||
|
||||
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
|
||||
with RemoteOpenAIServer(MODEL_NAME,
|
||||
args,
|
||||
env_dict={"VLLM_AUDIO_FETCH_TIMEOUT":
|
||||
"30"}) as remote_server:
|
||||
yield remote_server
|
||||
|
||||
|
||||
|
@ -167,7 +167,7 @@ VLM_TEST_SETTINGS = {
|
||||
"cherry_blossom": "<image>What is the season?", # noqa: E501
|
||||
}),
|
||||
multi_image_prompt="<image><image>Describe the two images in detail.", # noqa: E501
|
||||
max_model_len=8192,
|
||||
max_model_len=4096,
|
||||
max_num_seqs=2,
|
||||
auto_cls=AutoModelForImageTextToText,
|
||||
vllm_runner_kwargs={"mm_processor_kwargs": {"crop_to_patches": True}}
|
||||
|
@ -176,6 +176,8 @@ def test_chat(
|
||||
model,
|
||||
dtype=dtype,
|
||||
tokenizer_mode="mistral",
|
||||
load_format="mistral",
|
||||
config_format="mistral",
|
||||
max_model_len=max_model_len,
|
||||
limit_mm_per_prompt=LIMIT_MM_PER_PROMPT,
|
||||
) as vllm_model:
|
||||
@ -198,22 +200,14 @@ def test_chat(
|
||||
|
||||
|
||||
@large_gpu_test(min_gb=48)
|
||||
@pytest.mark.parametrize(
|
||||
"prompt,expected_ranges",
|
||||
[(_create_engine_inputs_hf(IMG_URLS[:1]), [{
|
||||
"offset": 11,
|
||||
"length": 494
|
||||
}]),
|
||||
(_create_engine_inputs_hf(IMG_URLS[1:4]), [{
|
||||
"offset": 11,
|
||||
"length": 266
|
||||
}, {
|
||||
"offset": 277,
|
||||
"length": 1056
|
||||
}, {
|
||||
"offset": 1333,
|
||||
"length": 418
|
||||
}])])
|
||||
@pytest.mark.parametrize("prompt,expected_ranges",
|
||||
[(_create_engine_inputs_hf(IMG_URLS[:1]),
|
||||
[PlaceholderRange(offset=11, length=494)]),
|
||||
(_create_engine_inputs_hf(IMG_URLS[1:4]), [
|
||||
PlaceholderRange(offset=11, length=266),
|
||||
PlaceholderRange(offset=277, length=1056),
|
||||
PlaceholderRange(offset=1333, length=418)
|
||||
])])
|
||||
def test_multi_modal_placeholders(vllm_runner, prompt,
|
||||
expected_ranges: list[PlaceholderRange],
|
||||
monkeypatch) -> None:
|
||||
|
@ -92,8 +92,8 @@ def _validate_image_prompt_replacements_one(
|
||||
first_placeholder = image_placeholders[0]
|
||||
|
||||
# NOTE: There is a BOS token
|
||||
assert first_placeholder["offset"] == 1
|
||||
assert first_placeholder["length"] == (
|
||||
assert first_placeholder.offset == 1
|
||||
assert first_placeholder.length == (
|
||||
len(processed_inputs["prompt_token_ids"]) - 1) // num_imgs
|
||||
|
||||
except Exception as exc:
|
||||
|
@ -92,8 +92,8 @@ def _validate_image_prompt_replacements_one(
|
||||
|
||||
first_placeholder = image_placeholders[0]
|
||||
|
||||
assert first_placeholder["offset"] == 0
|
||||
assert first_placeholder["length"] == len(
|
||||
assert first_placeholder.offset == 0
|
||||
assert first_placeholder.length == len(
|
||||
processed_inputs["prompt_token_ids"]) // num_imgs
|
||||
except Exception as exc:
|
||||
failed_size_excs.append((image_size, exc))
|
||||
|
@ -277,7 +277,9 @@ _MULTIMODAL_EXAMPLE_MODELS = {
|
||||
trust_remote_code=True,
|
||||
hf_overrides={"architectures": ["GLM4VForCausalLM"]}), # noqa: E501
|
||||
"H2OVLChatModel": _HfExamplesInfo("h2oai/h2ovl-mississippi-800m",
|
||||
extras={"2b": "h2oai/h2ovl-mississippi-2b"}), # noqa: E501
|
||||
extras={"2b": "h2oai/h2ovl-mississippi-2b"}, # noqa: E501
|
||||
max_transformers_version="4.48", # noqa: E501
|
||||
transformers_version_reason="HF model is not compatible."), # noqa: E501
|
||||
"InternVLChatModel": _HfExamplesInfo("OpenGVLab/InternVL2-1B",
|
||||
extras={"2B": "OpenGVLab/InternVL2-2B"}, # noqa: E501
|
||||
trust_remote_code=True),
|
||||
|
@ -785,6 +785,7 @@ def test_find_update_tokens(
|
||||
item_idx=0,
|
||||
start_idx=6,
|
||||
tokens=[32000, 32000],
|
||||
is_embed=None,
|
||||
),
|
||||
],
|
||||
"pattern_4": [
|
||||
@ -793,6 +794,7 @@ def test_find_update_tokens(
|
||||
item_idx=0,
|
||||
start_idx=3,
|
||||
tokens=[32000],
|
||||
is_embed=None,
|
||||
),
|
||||
],
|
||||
}
|
||||
@ -807,12 +809,14 @@ def test_find_update_tokens(
|
||||
item_idx=0,
|
||||
start_idx=1,
|
||||
tokens=[32000, 32000],
|
||||
is_embed=None,
|
||||
),
|
||||
PlaceholderFeaturesInfo(
|
||||
modality="pattern_1",
|
||||
item_idx=1,
|
||||
start_idx=5,
|
||||
tokens=[32000, 32000],
|
||||
is_embed=None,
|
||||
),
|
||||
],
|
||||
"pattern_3": [
|
||||
@ -821,6 +825,7 @@ def test_find_update_tokens(
|
||||
item_idx=0,
|
||||
start_idx=7,
|
||||
tokens=[1550, 918, 1550],
|
||||
is_embed=None,
|
||||
),
|
||||
],
|
||||
# No match for pattern_4 as it has lower priority than pattern_1
|
||||
@ -835,12 +840,14 @@ def test_find_update_tokens(
|
||||
item_idx=0,
|
||||
start_idx=1,
|
||||
tokens=[32000, 32000],
|
||||
is_embed=None,
|
||||
),
|
||||
PlaceholderFeaturesInfo(
|
||||
modality="pattern_1",
|
||||
item_idx=1,
|
||||
start_idx=3,
|
||||
tokens=[32000, 32000],
|
||||
is_embed=None,
|
||||
),
|
||||
],
|
||||
"pattern_4": [
|
||||
@ -849,6 +856,7 @@ def test_find_update_tokens(
|
||||
item_idx=0,
|
||||
start_idx=5,
|
||||
tokens=[32000],
|
||||
is_embed=None,
|
||||
),
|
||||
],
|
||||
"pattern_3": [
|
||||
@ -857,6 +865,7 @@ def test_find_update_tokens(
|
||||
item_idx=0,
|
||||
start_idx=6,
|
||||
tokens=[1550, 918, 1550],
|
||||
is_embed=None,
|
||||
),
|
||||
],
|
||||
}
|
||||
|
@ -3,7 +3,7 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.multimodal.inputs import MultiModalKwargs
|
||||
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.utils import sha256
|
||||
# disable yapf here as it formats differently than isort such that both fail
|
||||
@ -158,13 +158,10 @@ def test_generate_block_hash_extra_keys():
|
||||
request = make_request(
|
||||
request_id=0,
|
||||
prompt_token_ids=[_ for _ in range(20)],
|
||||
mm_positions=[{
|
||||
"offset": 0,
|
||||
"length": 5
|
||||
}, {
|
||||
"offset": 10,
|
||||
"length": 5
|
||||
}],
|
||||
mm_positions=[
|
||||
PlaceholderRange(offset=0, length=5),
|
||||
PlaceholderRange(offset=10, length=5),
|
||||
],
|
||||
mm_hashes=["hash1", "hash2"],
|
||||
)
|
||||
|
||||
@ -222,13 +219,10 @@ def test_hash_request_tokens(hash_fn):
|
||||
request = make_request(
|
||||
request_id=0,
|
||||
prompt_token_ids=[_ for _ in range(6)],
|
||||
mm_positions=[{
|
||||
"offset": 0,
|
||||
"length": 3
|
||||
}, {
|
||||
"offset": 3,
|
||||
"length": 3
|
||||
}],
|
||||
mm_positions=[
|
||||
PlaceholderRange(offset=0, length=3),
|
||||
PlaceholderRange(offset=3, length=3),
|
||||
],
|
||||
mm_hashes=["hash1", "hash2"],
|
||||
)
|
||||
|
||||
@ -253,25 +247,19 @@ def test_hash_tokens_different_mm_input(hash_fn):
|
||||
request1 = make_request(
|
||||
request_id=0,
|
||||
prompt_token_ids=[_ for _ in range(6)],
|
||||
mm_positions=[{
|
||||
"offset": 0,
|
||||
"length": 3
|
||||
}, {
|
||||
"offset": 3,
|
||||
"length": 3
|
||||
}],
|
||||
mm_positions=[
|
||||
PlaceholderRange(offset=0, length=3),
|
||||
PlaceholderRange(offset=3, length=3),
|
||||
],
|
||||
mm_hashes=["hash1", "hash2"],
|
||||
)
|
||||
request2 = make_request(
|
||||
request_id=1,
|
||||
prompt_token_ids=[_ for _ in range(6)],
|
||||
mm_positions=[{
|
||||
"offset": 0,
|
||||
"length": 3
|
||||
}, {
|
||||
"offset": 3,
|
||||
"length": 3
|
||||
}],
|
||||
mm_positions=[
|
||||
PlaceholderRange(offset=0, length=3),
|
||||
PlaceholderRange(offset=3, length=3),
|
||||
],
|
||||
mm_hashes=["hash3", "hash2"],
|
||||
)
|
||||
block_size = 3
|
||||
|
@ -27,7 +27,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
BaseProcessingInfo,
|
||||
MultiModalFieldConfig,
|
||||
PromptReplacement, PromptUpdate,
|
||||
encode_tokens)
|
||||
PromptUpdateDetails)
|
||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
@ -35,7 +35,6 @@ from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
|
||||
from .siglip import SiglipVisionModel
|
||||
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
|
||||
maybe_prefix, merge_multimodal_embeddings)
|
||||
from .vision import scatter_patch_features, select_patch_features
|
||||
|
||||
|
||||
class AyaVisionImagePixelInputs(TypedDict):
|
||||
@ -51,13 +50,6 @@ class AyaVisionImagePixelInputs(TypedDict):
|
||||
num_patches: torch.Tensor
|
||||
"""Shape: `(batch_size * num_images)`"""
|
||||
|
||||
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
|
||||
"""
|
||||
A boolean mask indicating which image embeddings correspond to patch tokens.
|
||||
|
||||
Shape: `(batch_size * num_images, num_embeds)`
|
||||
"""
|
||||
|
||||
|
||||
class AyaVisionMultiModalProjector(nn.Module):
|
||||
|
||||
@ -135,21 +127,20 @@ class AyaVisionProcessingInfo(BaseProcessingInfo):
|
||||
def get_max_image_tokens(self) -> int:
|
||||
hf_processor = self.get_hf_processor()
|
||||
image_processor = hf_processor.image_processor
|
||||
|
||||
image_size = self.get_image_size_with_most_features()
|
||||
tokenizer = hf_processor.tokenizer
|
||||
num_patches = self.get_num_patches(
|
||||
image_width=image_size.width,
|
||||
image_height=image_size.height,
|
||||
size=image_processor.size,
|
||||
min_patches=image_processor.min_patches,
|
||||
max_patches=image_processor.max_patches)
|
||||
image_string = hf_processor._prompt_split_image(num_patches)
|
||||
x = encode_tokens(
|
||||
tokenizer,
|
||||
image_string,
|
||||
add_special_tokens=False,
|
||||
max_patches=image_processor.max_patches,
|
||||
)
|
||||
return len(x)
|
||||
|
||||
img_patches_per_tile = (hf_processor.img_size //
|
||||
hf_processor.patch_size)**2
|
||||
|
||||
return num_patches * img_patches_per_tile
|
||||
|
||||
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
|
||||
return {"image": None}
|
||||
@ -221,7 +212,6 @@ class AyaVisionMultiModalProcessor(
|
||||
hf_processor = self.info.get_hf_processor(**mm_kwargs)
|
||||
image_processor = hf_processor.image_processor
|
||||
|
||||
hf_config = self.info.get_hf_config()
|
||||
# HF processor pops the `num_patches` kwarg, which is needed by vLLM
|
||||
if (images :=
|
||||
mm_data.get("images")) is not None and '<image>' in prompt:
|
||||
@ -234,6 +224,7 @@ class AyaVisionMultiModalProcessor(
|
||||
parsed_images.get_image_size(i)
|
||||
for i in range(len(parsed_images))
|
||||
]
|
||||
|
||||
num_patches = [
|
||||
self.info.get_num_patches(
|
||||
image_width=image_size.width,
|
||||
@ -243,20 +234,6 @@ class AyaVisionMultiModalProcessor(
|
||||
max_patches=image_processor.max_patches)
|
||||
for image_size in image_sizes
|
||||
]
|
||||
image_tokens_list = [
|
||||
hf_processor._prompt_split_image(num_patch)
|
||||
for num_patch in num_patches
|
||||
]
|
||||
tokenizer = self.info.get_tokenizer()
|
||||
image_token_ids = [
|
||||
tokenizer.encode(image_tokens, add_special_tokens=False)
|
||||
for image_tokens in image_tokens_list
|
||||
]
|
||||
embed_is_patch = [
|
||||
torch.tensor(image_repl_tokens) == hf_config.image_token_index
|
||||
for image_repl_tokens in image_token_ids
|
||||
]
|
||||
processed_outputs["embed_is_patch"] = embed_is_patch
|
||||
processed_outputs["num_patches"] = torch.tensor(num_patches)
|
||||
|
||||
return processed_outputs
|
||||
@ -271,7 +248,6 @@ class AyaVisionMultiModalProcessor(
|
||||
pixel_values=MultiModalFieldConfig.flat_from_sizes(
|
||||
"image", num_patches),
|
||||
num_patches=MultiModalFieldConfig.batched("image"),
|
||||
embed_is_patch=MultiModalFieldConfig.batched("image"),
|
||||
image_embeds=MultiModalFieldConfig.batched("image"),
|
||||
)
|
||||
|
||||
@ -283,6 +259,7 @@ class AyaVisionMultiModalProcessor(
|
||||
) -> Sequence[PromptUpdate]:
|
||||
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
|
||||
image_token = hf_processor.image_token
|
||||
img_patch_token = hf_processor.img_patch_token
|
||||
image_processor = hf_processor.image_processor
|
||||
|
||||
def get_replacement(item_idx: int):
|
||||
@ -294,8 +271,11 @@ class AyaVisionMultiModalProcessor(
|
||||
image_height=image_size.height,
|
||||
size=image_processor.size,
|
||||
min_patches=image_processor.min_patches,
|
||||
max_patches=image_processor.max_patches)
|
||||
return hf_processor._prompt_split_image(num_patches=num_patches)
|
||||
max_patches=image_processor.max_patches,
|
||||
)
|
||||
repl = hf_processor._prompt_split_image(num_patches=num_patches)
|
||||
|
||||
return PromptUpdateDetails.select_text(repl, img_patch_token)
|
||||
|
||||
return [
|
||||
PromptReplacement(
|
||||
@ -424,7 +404,6 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
self, **kwargs: object) -> Optional[AyaVisionImagePixelInputs]:
|
||||
pixel_values = kwargs.pop("pixel_values", None)
|
||||
num_patches = kwargs.pop("num_patches", None)
|
||||
embed_is_patch = kwargs.pop("embed_is_patch", None)
|
||||
image_embeds = kwargs.pop("image_embeds", None)
|
||||
assert image_embeds is None, "Aya Vision does not support image_embeds."
|
||||
|
||||
@ -436,18 +415,13 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
raise ValueError("Incorrect type of num_patches. "
|
||||
f"Got type: {type(num_patches)}")
|
||||
|
||||
if not isinstance(embed_is_patch, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of embed_is_patch. "
|
||||
f"Got type: {type(embed_is_patch)}")
|
||||
|
||||
pixel_values = flatten_bn(pixel_values, concat=True)
|
||||
num_patches = flatten_bn(num_patches, concat=True)
|
||||
embed_is_patch = flatten_bn(embed_is_patch)
|
||||
|
||||
return AyaVisionImagePixelInputs(
|
||||
type="pixel_values",
|
||||
pixel_values=self._validate_pixel_values(pixel_values),
|
||||
num_patches=num_patches,
|
||||
embed_is_patch=embed_is_patch,
|
||||
)
|
||||
|
||||
def get_multimodal_embeddings(
|
||||
@ -455,11 +429,8 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
image_input = self._parse_and_validate_image_input(**kwargs)
|
||||
if image_input is None:
|
||||
return None
|
||||
image_features = self._process_image_input(image_input, **kwargs)
|
||||
return scatter_patch_features(
|
||||
image_features,
|
||||
image_input["embed_is_patch"],
|
||||
)
|
||||
|
||||
return self._process_image_input(image_input, **kwargs)
|
||||
|
||||
def get_input_embeddings(
|
||||
self,
|
||||
@ -471,9 +442,9 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
inputs_embeds = merge_multimodal_embeddings(
|
||||
input_ids=input_ids,
|
||||
inputs_embeds=inputs_embeds,
|
||||
multimodal_embeddings=select_patch_features(
|
||||
multimodal_embeddings),
|
||||
placeholder_token_id=self.config.image_token_index)
|
||||
multimodal_embeddings=multimodal_embeddings,
|
||||
placeholder_token_id=self.config.image_token_index,
|
||||
)
|
||||
|
||||
return inputs_embeds
|
||||
|
||||
|
@ -162,9 +162,9 @@ class ChameleonMultiModalProcessor(
|
||||
PromptReplacement(
|
||||
modality="image",
|
||||
target=[image_token_id],
|
||||
replacement=PromptUpdateDetails(
|
||||
full=([image_start_id] + image_tokens + [image_end_id]),
|
||||
features=image_tokens,
|
||||
replacement=PromptUpdateDetails.select_token_id(
|
||||
[image_start_id] + image_tokens + [image_end_id],
|
||||
embed_token_id=image_token_id,
|
||||
),
|
||||
)
|
||||
]
|
||||
|
@ -18,7 +18,7 @@
|
||||
""" PyTorch Fuyu model."""
|
||||
import math
|
||||
from collections.abc import Iterable, Mapping, Sequence
|
||||
from typing import Literal, Optional, Set, Tuple, TypedDict, Union
|
||||
from typing import Literal, Optional, Set, Tuple, TypedDict
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -43,7 +43,6 @@ from vllm.sequence import IntermediateTensors
|
||||
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
|
||||
from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
|
||||
merge_multimodal_embeddings)
|
||||
from .vision import scatter_patch_features, select_patch_features
|
||||
|
||||
# Cannot find the following 2 numbers from hf config.
|
||||
_IMAGE_TOKEN_ID = 71011
|
||||
@ -66,14 +65,6 @@ class FuyuImagePatchInputs(TypedDict):
|
||||
flattened just like `flat_data`.
|
||||
"""
|
||||
|
||||
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
|
||||
"""
|
||||
A boolean mask indicating which image embeddings correspond
|
||||
to patch tokens.
|
||||
|
||||
Shape: `(batch_size * num_images, num_embeds)`
|
||||
"""
|
||||
|
||||
|
||||
class FuyuProcessingInfo(BaseProcessingInfo):
|
||||
|
||||
@ -94,15 +85,7 @@ class FuyuProcessingInfo(BaseProcessingInfo):
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
) -> Mapping[str, int]:
|
||||
target_width, target_height = self.get_image_size_with_most_features()
|
||||
|
||||
max_ncols, max_nrows = self.get_image_feature_grid_size(
|
||||
image_width=target_width,
|
||||
image_height=target_height,
|
||||
)
|
||||
max_image_tokens = (max_ncols + 1) * max_nrows
|
||||
|
||||
return {"image": max_image_tokens}
|
||||
return {"image": self.get_max_image_tokens()}
|
||||
|
||||
def get_image_feature_grid_size(
|
||||
self,
|
||||
@ -128,11 +111,32 @@ class FuyuProcessingInfo(BaseProcessingInfo):
|
||||
nrows = math.ceil(image_height / patch_height)
|
||||
return ncols, nrows
|
||||
|
||||
def get_num_image_tokens(
|
||||
self,
|
||||
*,
|
||||
image_width: int,
|
||||
image_height: int,
|
||||
) -> int:
|
||||
ncols, nrows = self.get_image_feature_grid_size(
|
||||
image_width=image_width,
|
||||
image_height=image_height,
|
||||
)
|
||||
|
||||
return ncols * nrows
|
||||
|
||||
def get_image_size_with_most_features(self) -> ImageSize:
|
||||
image_processor = self.get_image_processor()
|
||||
return ImageSize(width=image_processor.size["width"],
|
||||
height=image_processor.size["height"])
|
||||
|
||||
def get_max_image_tokens(self) -> int:
|
||||
target_width, target_height = self.get_image_size_with_most_features()
|
||||
|
||||
return self.get_num_image_tokens(
|
||||
image_width=target_width,
|
||||
image_height=target_height,
|
||||
)
|
||||
|
||||
|
||||
class FuyuDummyInputsBuilder(BaseDummyInputsBuilder[FuyuProcessingInfo]):
|
||||
|
||||
@ -192,19 +196,6 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]):
|
||||
|
||||
processed_outputs["image_patches"] = image_patches[0]
|
||||
|
||||
# get patch grid size for each image
|
||||
embed_is_patch = []
|
||||
for image in images:
|
||||
ncols, nrows = self.info.get_image_feature_grid_size(
|
||||
image_width=image.width,
|
||||
image_height=image.height,
|
||||
)
|
||||
|
||||
mask = torch.tensor(([True] * ncols + [False]) * nrows)
|
||||
embed_is_patch.append(mask)
|
||||
|
||||
processed_outputs["embed_is_patch"] = embed_is_patch
|
||||
|
||||
return processed_outputs
|
||||
|
||||
def _apply_hf_processor_tokens_only(
|
||||
@ -224,8 +215,7 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]):
|
||||
hf_inputs: BatchFeature,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
) -> Mapping[str, MultiModalFieldConfig]:
|
||||
return dict(image_patches=MultiModalFieldConfig.batched("image"),
|
||||
embed_is_patch=MultiModalFieldConfig.batched("image"))
|
||||
return dict(image_patches=MultiModalFieldConfig.batched("image"))
|
||||
|
||||
def _get_prompt_updates(
|
||||
self,
|
||||
@ -252,9 +242,9 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]):
|
||||
image_tokens = ([_IMAGE_TOKEN_ID] * ncols +
|
||||
[_NEWLINE_TOKEN_ID]) * nrows
|
||||
|
||||
return PromptUpdateDetails(
|
||||
full=image_tokens + [bos_token_id],
|
||||
features=image_tokens,
|
||||
return PromptUpdateDetails.select_token_id(
|
||||
image_tokens + [bos_token_id],
|
||||
embed_token_id=_IMAGE_TOKEN_ID,
|
||||
)
|
||||
|
||||
return [
|
||||
@ -329,20 +319,13 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
raise ValueError("Incorrect type of image patches. "
|
||||
f"Got type: {type(image_patches)}")
|
||||
|
||||
embed_is_patch = kwargs.pop("embed_is_patch")
|
||||
if not isinstance(embed_is_patch, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of embed_is_patch. "
|
||||
f"Got type: {type(embed_is_patch)}")
|
||||
|
||||
image_patches_flat = flatten_bn(image_patches)
|
||||
embed_is_patch = flatten_bn(embed_is_patch)
|
||||
|
||||
return FuyuImagePatchInputs(
|
||||
type="image_patches",
|
||||
flat_data=self._validate_pixel_values(
|
||||
flatten_bn(image_patches_flat, concat=True)),
|
||||
patches_per_image=[x.size(0) for x in image_patches_flat],
|
||||
embed_is_patch=embed_is_patch,
|
||||
)
|
||||
|
||||
return None
|
||||
@ -364,12 +347,7 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
if image_input is None:
|
||||
return None
|
||||
|
||||
image_features = self._process_image_input(image_input)
|
||||
|
||||
return scatter_patch_features(
|
||||
image_features,
|
||||
image_input["embed_is_patch"],
|
||||
)
|
||||
return self._process_image_input(image_input)
|
||||
|
||||
def get_input_embeddings(
|
||||
self,
|
||||
@ -379,8 +357,11 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
|
||||
if multimodal_embeddings is not None:
|
||||
inputs_embeds = merge_multimodal_embeddings(
|
||||
input_ids, inputs_embeds,
|
||||
select_patch_features(multimodal_embeddings), _IMAGE_TOKEN_ID)
|
||||
input_ids,
|
||||
inputs_embeds,
|
||||
multimodal_embeddings,
|
||||
_IMAGE_TOKEN_ID,
|
||||
)
|
||||
return inputs_embeds
|
||||
|
||||
def forward(
|
||||
|
@ -25,7 +25,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
PlaceholderFeaturesInfo,
|
||||
PromptReplacement, PromptTargetMatch,
|
||||
PromptUpdate, PromptUpdateDetails,
|
||||
encode_tokens, find_mm_placeholders,
|
||||
find_mm_placeholders,
|
||||
replace_token_matches)
|
||||
# yapf: enable
|
||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
|
||||
@ -36,7 +36,6 @@ from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
|
||||
from .siglip import SiglipVisionModel
|
||||
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
|
||||
maybe_prefix, merge_multimodal_embeddings)
|
||||
from .vision import scatter_patch_features, select_patch_features
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -54,14 +53,6 @@ class Gemma3ImagePixelInputs(TypedDict):
|
||||
num_patches: torch.Tensor
|
||||
"""Shape: `(batch_size * num_images)`"""
|
||||
|
||||
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
|
||||
"""
|
||||
A boolean mask indicating which image embeddings correspond
|
||||
to patch tokens.
|
||||
|
||||
Shape: `(batch_size * num_images, num_embeds)`
|
||||
"""
|
||||
|
||||
|
||||
Gemma3ImageInputs = Gemma3ImagePixelInputs
|
||||
|
||||
@ -183,7 +174,7 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
|
||||
if processor is None:
|
||||
processor = self.get_hf_processor()
|
||||
|
||||
image_token = processor.boi_token
|
||||
boi_token = processor.boi_token
|
||||
|
||||
num_crops = self.get_num_crops(
|
||||
image_width=image_width,
|
||||
@ -192,19 +183,21 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
|
||||
)
|
||||
|
||||
if num_crops == 0:
|
||||
image_text = image_token
|
||||
image_text = boi_token
|
||||
else:
|
||||
crops_image_tokens = " ".join(image_token
|
||||
for _ in range(num_crops))
|
||||
crops_image_tokens = " ".join(boi_token for _ in range(num_crops))
|
||||
image_text = (
|
||||
f"Here is the original image {image_token} and here are some "
|
||||
f"Here is the original image {boi_token} and here are some "
|
||||
f"crops to help you see better {crops_image_tokens}")
|
||||
|
||||
repl_full = image_text.replace(image_token,
|
||||
repl_full = image_text.replace(boi_token,
|
||||
processor.full_image_sequence)
|
||||
repl_features = repl_full.strip("\n")
|
||||
|
||||
return PromptUpdateDetails(full=repl_full, features=repl_features)
|
||||
tokenizer = processor.tokenizer
|
||||
vocab = tokenizer.get_vocab()
|
||||
image_token_id = vocab[tokenizer.image_token]
|
||||
|
||||
return PromptUpdateDetails.select_token_id(repl_full, image_token_id)
|
||||
|
||||
def get_num_image_tokens(
|
||||
self,
|
||||
@ -213,19 +206,17 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
|
||||
image_height: int,
|
||||
processor: Optional[Gemma3Processor],
|
||||
) -> int:
|
||||
tokenizer = self.get_tokenizer()
|
||||
image_repl = self.get_image_repl(
|
||||
if processor is None:
|
||||
processor = self.get_hf_processor()
|
||||
|
||||
num_crops = self.get_num_crops(
|
||||
image_width=image_width,
|
||||
image_height=image_height,
|
||||
processor=processor,
|
||||
)
|
||||
image_seq_len = processor.image_seq_length
|
||||
|
||||
image_repl_tokens = encode_tokens(
|
||||
tokenizer,
|
||||
image_repl.features,
|
||||
add_special_tokens=False,
|
||||
)
|
||||
return len(image_repl_tokens)
|
||||
return (num_crops + 1) * image_seq_len
|
||||
|
||||
def get_image_size_with_most_features(self) -> ImageSize:
|
||||
processor = self.get_hf_processor()
|
||||
@ -301,28 +292,6 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
|
||||
]
|
||||
hf_processor = self.info.get_hf_processor(**mm_kwargs)
|
||||
|
||||
image_repl_features = [
|
||||
self.info.get_image_repl(image_width=size.width,
|
||||
image_height=size.height,
|
||||
processor=hf_processor).features
|
||||
for size in image_sizes
|
||||
]
|
||||
|
||||
tokenizer = self.info.get_tokenizer()
|
||||
image_repls_feature_tokens = [
|
||||
tokenizer.encode(image_repl, add_special_tokens=False)
|
||||
for image_repl in image_repl_features
|
||||
]
|
||||
|
||||
vocab = tokenizer.get_vocab()
|
||||
image_token_id = vocab[tokenizer.image_token]
|
||||
|
||||
embed_is_patch = [
|
||||
torch.tensor(image_repl_tokens) == image_token_id
|
||||
for image_repl_tokens in image_repls_feature_tokens
|
||||
]
|
||||
processed_outputs["embed_is_patch"] = embed_is_patch
|
||||
|
||||
num_crops = [
|
||||
self.info.get_num_crops(image_width=size.width,
|
||||
image_height=size.height,
|
||||
@ -344,7 +313,6 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
|
||||
pixel_values=MultiModalFieldConfig.flat_from_sizes(
|
||||
"image", num_crops + 1),
|
||||
num_crops=MultiModalFieldConfig.batched("image"),
|
||||
embed_is_patch=MultiModalFieldConfig.batched("image"),
|
||||
)
|
||||
|
||||
def _get_prompt_updates(
|
||||
@ -454,6 +422,7 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
|
||||
item_idx=p.item_idx,
|
||||
start_idx=repl_orig_idxs[p.start_idx],
|
||||
tokens=p.tokens,
|
||||
is_embed=p.is_embed,
|
||||
) for p in placeholders
|
||||
]
|
||||
for modality, placeholders in repls.items()
|
||||
@ -572,7 +541,6 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
|
||||
self, **kwargs: object) -> Optional[Gemma3ImageInputs]:
|
||||
pixel_values = kwargs.pop("pixel_values", None)
|
||||
num_crops = kwargs.pop("num_crops", None)
|
||||
embed_is_patch = kwargs.pop("embed_is_patch", None)
|
||||
image_embeds = kwargs.pop("image_embeds", None)
|
||||
assert image_embeds is None, "Gemma3 does not support image_embeds."
|
||||
if pixel_values is None:
|
||||
@ -586,19 +554,13 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
|
||||
raise ValueError("Incorrect type of num_crops. "
|
||||
f"Got type: {type(num_crops)}")
|
||||
|
||||
if not isinstance(embed_is_patch, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of embed_is_patch. "
|
||||
f"Got type: {type(embed_is_patch)}")
|
||||
|
||||
pixel_values = flatten_bn(pixel_values, concat=True)
|
||||
num_crops = flatten_bn(num_crops, concat=True)
|
||||
embed_is_patch = flatten_bn(embed_is_patch)
|
||||
|
||||
return Gemma3ImagePixelInputs(
|
||||
type="pixel_values",
|
||||
pixel_values=self._validate_pixel_values(pixel_values),
|
||||
num_patches=num_crops + 1,
|
||||
embed_is_patch=embed_is_patch,
|
||||
)
|
||||
|
||||
def _image_pixels_to_features(
|
||||
@ -635,12 +597,7 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
|
||||
if image_input is None:
|
||||
return None
|
||||
|
||||
image_features = self._process_image_input(image_input)
|
||||
|
||||
return scatter_patch_features(
|
||||
image_features,
|
||||
image_input["embed_is_patch"],
|
||||
)
|
||||
return self._process_image_input(image_input)
|
||||
|
||||
def get_input_embeddings(
|
||||
self,
|
||||
@ -652,7 +609,7 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
|
||||
inputs_embeds = merge_multimodal_embeddings(
|
||||
input_ids,
|
||||
inputs_embeds,
|
||||
select_patch_features(multimodal_embeddings),
|
||||
multimodal_embeddings,
|
||||
self.config.image_token_index,
|
||||
)
|
||||
return inputs_embeds
|
||||
|
@ -257,7 +257,7 @@ class H2OVLProcessor(BaseInternVLProcessor):
|
||||
repl_features = IMG_CONTEXT * feature_size
|
||||
repl_full = IMG_START + repl_features + IMG_END
|
||||
|
||||
return PromptUpdateDetails(full=repl_full, features=repl_features)
|
||||
return PromptUpdateDetails.select_text(repl_full, IMG_CONTEXT)
|
||||
|
||||
def resolve_min_max_num(
|
||||
self,
|
||||
|
@ -41,7 +41,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
MultiModalDataItems,
|
||||
MultiModalFieldConfig,
|
||||
PromptReplacement, PromptUpdate,
|
||||
encode_tokens)
|
||||
PromptUpdateDetails)
|
||||
# yapf: enable
|
||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
|
||||
from vllm.sequence import IntermediateTensors
|
||||
@ -54,7 +54,6 @@ from .interfaces import MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal
|
||||
from .llama import LlamaModel
|
||||
from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
|
||||
merge_multimodal_embeddings)
|
||||
from .vision import scatter_patch_features, select_patch_features
|
||||
|
||||
|
||||
class Idefics3ImagePixelInputs(TypedDict):
|
||||
@ -69,14 +68,6 @@ class Idefics3ImagePixelInputs(TypedDict):
|
||||
num_patches: torch.Tensor
|
||||
"""Shape: `(batch_size * num_images)`"""
|
||||
|
||||
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
|
||||
"""
|
||||
A boolean mask indicating which image embeddings correspond
|
||||
to patch tokens.
|
||||
|
||||
Shape: `(batch_size * num_images, num_embeds)`
|
||||
"""
|
||||
|
||||
|
||||
class Idefics3ImageEmbeddingInputs(TypedDict):
|
||||
type: Literal["image_embeds"]
|
||||
@ -86,14 +77,6 @@ class Idefics3ImageEmbeddingInputs(TypedDict):
|
||||
`hidden_size` must match the hidden size of language model backbone.
|
||||
"""
|
||||
|
||||
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
|
||||
"""
|
||||
A boolean mask indicating which image embeddings correspond
|
||||
to patch tokens.
|
||||
|
||||
Shape: `(batch_size * num_images, num_embeds)`
|
||||
"""
|
||||
|
||||
|
||||
ImageInputs = Union[Idefics3ImagePixelInputs, Idefics3ImageEmbeddingInputs]
|
||||
|
||||
@ -275,19 +258,16 @@ class Idefics3ProcessingInfo(BaseProcessingInfo):
|
||||
image_height: int,
|
||||
processor: Optional[Idefics3Processor],
|
||||
) -> int:
|
||||
tokenizer = self.get_tokenizer()
|
||||
image_repl = self.get_image_repl(
|
||||
if processor is None:
|
||||
processor = self.get_hf_processor()
|
||||
|
||||
num_patches = self.get_num_patches(
|
||||
image_width=image_width,
|
||||
image_height=image_height,
|
||||
processor=processor,
|
||||
)
|
||||
|
||||
image_repl_tokens = encode_tokens(
|
||||
tokenizer,
|
||||
image_repl,
|
||||
add_special_tokens=False,
|
||||
)
|
||||
return len(image_repl_tokens)
|
||||
return num_patches * processor.image_seq_len
|
||||
|
||||
def get_image_size_with_most_features(self) -> ImageSize:
|
||||
processor = self.get_hf_processor()
|
||||
@ -364,28 +344,6 @@ class Idefics3MultiModalProcessor(
|
||||
]
|
||||
hf_processor = self.info.get_hf_processor(**mm_kwargs)
|
||||
|
||||
image_repl_features = [
|
||||
self.info.get_image_repl(image_width=size.width,
|
||||
image_height=size.height,
|
||||
processor=hf_processor)
|
||||
for size in image_sizes
|
||||
]
|
||||
|
||||
tokenizer = self.info.get_tokenizer()
|
||||
image_repls_feature_tokens = [
|
||||
tokenizer.encode(image_repl, add_special_tokens=False)
|
||||
for image_repl in image_repl_features
|
||||
]
|
||||
|
||||
vocab = tokenizer.get_vocab()
|
||||
image_token_id = vocab[hf_processor.image_token.content]
|
||||
|
||||
embed_is_patch = [
|
||||
torch.tensor(image_repl_tokens) == image_token_id
|
||||
for image_repl_tokens in image_repls_feature_tokens
|
||||
]
|
||||
processed_outputs["embed_is_patch"] = embed_is_patch
|
||||
|
||||
num_patches = [
|
||||
self.info.get_num_patches(
|
||||
image_width=size.width,
|
||||
@ -415,7 +373,6 @@ class Idefics3MultiModalProcessor(
|
||||
"image", num_patches),
|
||||
image_embeds=MultiModalFieldConfig.batched("image"),
|
||||
num_patches=MultiModalFieldConfig.batched("image"),
|
||||
embed_is_patch=MultiModalFieldConfig.batched("image"),
|
||||
)
|
||||
|
||||
def _get_prompt_updates(
|
||||
@ -427,17 +384,22 @@ class Idefics3MultiModalProcessor(
|
||||
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
|
||||
image_token = hf_processor.image_token.content
|
||||
|
||||
def get_replacement_idefics3(item_idx: int) -> str:
|
||||
def get_replacement_idefics3(item_idx: int) -> PromptUpdateDetails:
|
||||
images = mm_items.get_items("image", ImageProcessorItems)
|
||||
|
||||
image_size = images.get_image_size(item_idx)
|
||||
|
||||
return self.info.get_image_repl(
|
||||
image_repl = self.info.get_image_repl(
|
||||
image_width=image_size.width,
|
||||
image_height=image_size.height,
|
||||
processor=hf_processor,
|
||||
)
|
||||
|
||||
return PromptUpdateDetails.select_text(
|
||||
image_repl,
|
||||
embed_text=image_token,
|
||||
)
|
||||
|
||||
return [
|
||||
PromptReplacement(
|
||||
modality="image",
|
||||
@ -675,13 +637,6 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
if pixel_values is None and image_embeds is None:
|
||||
return None
|
||||
|
||||
embed_is_patch = kwargs.pop("embed_is_patch")
|
||||
if not isinstance(embed_is_patch, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of embed_is_patch. "
|
||||
f"Got type: {type(embed_is_patch)}")
|
||||
|
||||
embed_is_patch = flatten_bn(embed_is_patch)
|
||||
|
||||
if image_embeds is not None:
|
||||
if not isinstance(image_embeds, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of image embeddings. "
|
||||
@ -690,7 +645,6 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
return Idefics3ImageEmbeddingInputs(
|
||||
type="image_embeds",
|
||||
data=flatten_bn(image_embeds, concat=True),
|
||||
embed_is_patch=embed_is_patch,
|
||||
)
|
||||
|
||||
if pixel_values is not None:
|
||||
@ -718,7 +672,6 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
pixel_values=self._validate_pixel_values(pixel_values),
|
||||
pixel_attention_mask=pixel_attention_mask,
|
||||
num_patches=num_patches,
|
||||
embed_is_patch=embed_is_patch,
|
||||
)
|
||||
|
||||
raise AssertionError("This line should be unreachable.")
|
||||
@ -754,12 +707,7 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
if image_input is None:
|
||||
return None
|
||||
|
||||
image_features = self._process_image_input(image_input)
|
||||
|
||||
return scatter_patch_features(
|
||||
image_features,
|
||||
image_input["embed_is_patch"],
|
||||
)
|
||||
return self._process_image_input(image_input)
|
||||
|
||||
def get_input_embeddings(
|
||||
self,
|
||||
@ -771,7 +719,7 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
inputs_embeds = merge_multimodal_embeddings(
|
||||
input_ids,
|
||||
inputs_embeds,
|
||||
select_patch_features(multimodal_embeddings),
|
||||
multimodal_embeddings,
|
||||
self.config.image_token_id,
|
||||
)
|
||||
return inputs_embeds
|
||||
|
@ -39,7 +39,6 @@ from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
|
||||
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
|
||||
maybe_prefix, merge_multimodal_embeddings)
|
||||
from .vision import scatter_patch_features, select_patch_features
|
||||
|
||||
IMG_START = '<img>'
|
||||
IMG_END = '</img>'
|
||||
@ -60,14 +59,6 @@ class InternVLImagePixelInputs(TypedDict):
|
||||
num_patches: torch.Tensor
|
||||
"""Shape: `(batch_size * num_images)`"""
|
||||
|
||||
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
|
||||
"""
|
||||
A boolean mask indicating which image embeddings correspond
|
||||
to patch tokens.
|
||||
|
||||
Shape: `(batch_size * num_images, num_embeds)`
|
||||
"""
|
||||
|
||||
|
||||
class InternVLImageEmbeddingInputs(TypedDict):
|
||||
type: Literal["image_embeds"]
|
||||
@ -419,24 +410,12 @@ class BaseInternVLProcessor(ABC):
|
||||
torch.tensor([len(item) for item in pixel_values_lst]),
|
||||
}
|
||||
|
||||
tokenizer = self.tokenizer
|
||||
image_token_id = self.image_token_id
|
||||
|
||||
embed_is_patch = list[torch.Tensor]()
|
||||
|
||||
for pixel_values in pixel_values_lst:
|
||||
num_patches = pixel_values.shape[0]
|
||||
feature_size = num_patches * self.num_image_token
|
||||
|
||||
image_repl = self.get_image_repl(feature_size, num_patches)
|
||||
feature_tokens = tokenizer.encode(image_repl.features,
|
||||
add_special_tokens=False)
|
||||
|
||||
text = [t.replace('<image>', image_repl.full, 1) for t in text]
|
||||
embed_is_patch.append(
|
||||
torch.tensor(feature_tokens) == image_token_id)
|
||||
|
||||
image_inputs["embed_is_patch"] = embed_is_patch
|
||||
|
||||
text_inputs = self.tokenizer(text)
|
||||
|
||||
@ -460,7 +439,7 @@ class InternVLProcessor(BaseInternVLProcessor):
|
||||
repl_features = IMG_CONTEXT * feature_size
|
||||
repl_full = IMG_START + repl_features + IMG_END
|
||||
|
||||
return PromptUpdateDetails(full=repl_full, features=repl_features)
|
||||
return PromptUpdateDetails.select_text(repl_full, IMG_CONTEXT)
|
||||
|
||||
|
||||
class BaseInternVLProcessingInfo(BaseProcessingInfo):
|
||||
@ -599,7 +578,6 @@ class InternVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
|
||||
pixel_values_flat=MultiModalFieldConfig.flat_from_sizes(
|
||||
"image", image_num_patches),
|
||||
image_num_patches=MultiModalFieldConfig.batched("image"),
|
||||
embed_is_patch=MultiModalFieldConfig.batched("image"),
|
||||
image_embeds=MultiModalFieldConfig.batched("image"),
|
||||
image_token_id=MultiModalFieldConfig.shared("image", num_images),
|
||||
)
|
||||
@ -831,7 +809,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
self, **kwargs: object) -> Optional[InternVLImageInputs]:
|
||||
pixel_values_flat = kwargs.pop("pixel_values_flat", None)
|
||||
image_num_patches = kwargs.pop("image_num_patches", None)
|
||||
embed_is_patch = kwargs.pop("embed_is_patch", None)
|
||||
image_embeds = kwargs.pop("image_embeds", None)
|
||||
|
||||
if pixel_values_flat is None and image_embeds is None:
|
||||
@ -860,20 +837,14 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
raise ValueError("Incorrect type of image_num_patches. "
|
||||
f"Got type: {type(image_num_patches)}")
|
||||
|
||||
if not isinstance(embed_is_patch, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of embed_is_patch. "
|
||||
f"Got type: {type(embed_is_patch)}")
|
||||
|
||||
pixel_values_flat = flatten_bn(pixel_values_flat, concat=True)
|
||||
image_num_patches = flatten_bn(image_num_patches, concat=True)
|
||||
embed_is_patch = flatten_bn(embed_is_patch)
|
||||
|
||||
return InternVLImagePixelInputs(
|
||||
type="pixel_values",
|
||||
pixel_values_flat=self._validate_pixel_values(
|
||||
pixel_values_flat),
|
||||
num_patches=image_num_patches,
|
||||
embed_is_patch=embed_is_patch,
|
||||
)
|
||||
|
||||
raise AssertionError("This line should be unreachable.")
|
||||
@ -919,15 +890,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
if image_input is None:
|
||||
return None
|
||||
|
||||
image_features = self._process_image_input(image_input)
|
||||
|
||||
if image_input["type"] != "pixel_values":
|
||||
return image_features
|
||||
|
||||
return scatter_patch_features(
|
||||
image_features,
|
||||
image_input["embed_is_patch"],
|
||||
)
|
||||
return self._process_image_input(image_input)
|
||||
|
||||
def get_input_embeddings(
|
||||
self,
|
||||
@ -941,7 +904,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
inputs_embeds = merge_multimodal_embeddings(
|
||||
input_ids,
|
||||
inputs_embeds,
|
||||
select_patch_features(multimodal_embeddings),
|
||||
multimodal_embeddings,
|
||||
self.img_context_token_id,
|
||||
)
|
||||
return inputs_embeds
|
||||
|
@ -32,7 +32,8 @@ from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
|
||||
ImageSize, MultiModalDataItems)
|
||||
from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
BaseProcessingInfo, ProcessingCache,
|
||||
PromptReplacement, PromptUpdate)
|
||||
PromptReplacement, PromptUpdate,
|
||||
PromptUpdateDetails)
|
||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
@ -42,8 +43,7 @@ from .pixtral import PixtralHFEncoderInfo, PixtralHFVisionModel
|
||||
from .siglip import SiglipVisionModel
|
||||
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
|
||||
maybe_prefix, merge_multimodal_embeddings)
|
||||
from .vision import (get_vision_encoder_info, scatter_patch_features,
|
||||
select_patch_features)
|
||||
from .vision import get_vision_encoder_info
|
||||
|
||||
|
||||
class LlavaImagePixelInputs(TypedDict):
|
||||
@ -67,14 +67,6 @@ class PixtralHFImagePixelInputs(TypedDict):
|
||||
in which case the data is passed as a list instead of a batched tensor.
|
||||
"""
|
||||
|
||||
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
|
||||
"""
|
||||
A boolean mask indicating which image embeddings correspond
|
||||
to patch tokens.
|
||||
|
||||
Shape: `(batch_size * num_images, num_embeds)`
|
||||
"""
|
||||
|
||||
|
||||
class LlavaImageEmbeddingInputs(TypedDict):
|
||||
type: Literal["image_embeds"]
|
||||
@ -343,23 +335,6 @@ class PixtralHFMultiModalProcessor(
|
||||
for p, (h, w) in zip(pixel_values, image_sizes)
|
||||
]
|
||||
|
||||
hf_config = self.info.get_hf_config()
|
||||
vision_config = hf_config.vision_config
|
||||
assert isinstance(vision_config, PixtralVisionConfig)
|
||||
encoder_info = PixtralHFEncoderInfo(vision_config)
|
||||
|
||||
tile_sizes = [
|
||||
encoder_info.get_patch_grid_size(
|
||||
image_width=pixel_value.shape[-1],
|
||||
image_height=pixel_value.shape[-2],
|
||||
) for pixel_value in processed_outputs["pixel_values"]
|
||||
]
|
||||
embed_is_patch = [
|
||||
torch.tensor(([True] * ncols + [False]) * nrows)
|
||||
for ncols, nrows in tile_sizes
|
||||
]
|
||||
processed_outputs["embed_is_patch"] = embed_is_patch
|
||||
|
||||
return processed_outputs
|
||||
|
||||
def _get_mm_fields_config(
|
||||
@ -369,7 +344,6 @@ class PixtralHFMultiModalProcessor(
|
||||
) -> Mapping[str, MultiModalFieldConfig]:
|
||||
return dict(
|
||||
pixel_values=MultiModalFieldConfig.batched("image"),
|
||||
embed_is_patch=MultiModalFieldConfig.batched("image"),
|
||||
image_embeds=MultiModalFieldConfig.batched("image"),
|
||||
)
|
||||
|
||||
@ -404,7 +378,7 @@ class PixtralHFMultiModalProcessor(
|
||||
tokens = ([image_token_id] * ncols + [image_break_id]) * nrows
|
||||
tokens[-1] = image_end_id
|
||||
|
||||
return tokens
|
||||
return PromptUpdateDetails.select_token_id(tokens, image_token_id)
|
||||
|
||||
return [
|
||||
PromptReplacement(
|
||||
@ -612,17 +586,9 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
f"Got type: {type(pixel_values)}")
|
||||
|
||||
if self.config.vision_config.model_type == "pixtral":
|
||||
embed_is_patch = kwargs.pop("embed_is_patch")
|
||||
if not isinstance(embed_is_patch, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of embed_is_patch. "
|
||||
f"Got type: {type(embed_is_patch)}")
|
||||
|
||||
embed_is_patch = flatten_bn(embed_is_patch)
|
||||
|
||||
return PixtralHFImagePixelInputs(
|
||||
type="pixel_values_pixtral",
|
||||
pixel_values=flatten_bn(pixel_values),
|
||||
embed_is_patch=embed_is_patch,
|
||||
)
|
||||
|
||||
return LlavaImagePixelInputs(
|
||||
@ -714,16 +680,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
if image_input is None:
|
||||
return None
|
||||
|
||||
image_features = self._process_image_input(image_input)
|
||||
|
||||
if image_input["type"] != "pixel_values_pixtral":
|
||||
# The path is used for pixtral (V0 only) and llava (V0/V1)
|
||||
return image_features
|
||||
|
||||
return scatter_patch_features(
|
||||
image_features,
|
||||
image_input["embed_is_patch"],
|
||||
)
|
||||
return self._process_image_input(image_input)
|
||||
|
||||
def get_input_embeddings(
|
||||
self,
|
||||
@ -735,7 +692,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
inputs_embeds = merge_multimodal_embeddings(
|
||||
input_ids,
|
||||
inputs_embeds,
|
||||
select_patch_features(multimodal_embeddings),
|
||||
multimodal_embeddings,
|
||||
self.config.image_token_index,
|
||||
)
|
||||
return inputs_embeds
|
||||
|
@ -40,7 +40,8 @@ from vllm.multimodal.parse import (AudioItem, AudioProcessorItems,
|
||||
DictEmbeddingItems, ModalityData,
|
||||
ModalityDataItems, MultiModalDataItems,
|
||||
MultiModalDataParser)
|
||||
from vllm.multimodal.processing import PromptReplacement, PromptUpdate
|
||||
from vllm.multimodal.processing import (PromptReplacement, PromptUpdate,
|
||||
PromptUpdateDetails)
|
||||
from vllm.multimodal.profiling import ProcessorInputs
|
||||
|
||||
from .minicpmv import (_MAX_FRAMES_PER_VIDEO, MiniCPMV2_6,
|
||||
@ -50,7 +51,6 @@ from .minicpmv import (_MAX_FRAMES_PER_VIDEO, MiniCPMV2_6,
|
||||
_minicpmv_field_config)
|
||||
from .utils import (AutoWeightsLoader, cast_overflow_tensors, flatten_bn,
|
||||
maybe_prefix)
|
||||
from .vision import scatter_patch_features
|
||||
|
||||
CPU_DEVICE = torch.device("cpu")
|
||||
|
||||
@ -73,14 +73,6 @@ class MiniCPMOAudioFeatureInputs(TypedDict):
|
||||
which equals to `audio_features.shape[-1]`
|
||||
"""
|
||||
|
||||
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
|
||||
"""
|
||||
A boolean mask indicating which audio embeddings correspond
|
||||
to patch tokens.
|
||||
|
||||
Shape: `(batch_size * num_audios, num_embeds)`
|
||||
"""
|
||||
|
||||
|
||||
class MiniCPMOAudioEmbeddingInputs(TypedDict):
|
||||
type: Literal["audio_embeds"]
|
||||
@ -93,14 +85,6 @@ class MiniCPMOAudioEmbeddingInputs(TypedDict):
|
||||
Length of each slice may vary, so pass it as a list.
|
||||
"""
|
||||
|
||||
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
|
||||
"""
|
||||
A boolean mask indicating which audio embeddings correspond
|
||||
to patch tokens.
|
||||
|
||||
Shape: `(batch_size * num_audios, num_embeds)`
|
||||
"""
|
||||
|
||||
|
||||
MiniCPMOAudioInputs = Union[MiniCPMOAudioFeatureInputs,
|
||||
MiniCPMOAudioEmbeddingInputs]
|
||||
@ -115,7 +99,6 @@ def _minicpmo_field_config(hf_inputs: Mapping[str, torch.Tensor]):
|
||||
audio_features=MultiModalFieldConfig.batched("audio"),
|
||||
audio_feature_lens=MultiModalFieldConfig.batched("audio"),
|
||||
audio_embeds=MultiModalFieldConfig.batched("audio"),
|
||||
audio_embed_is_patch=MultiModalFieldConfig.batched("audio"),
|
||||
audio_token_id=MultiModalFieldConfig.shared("audio", num_audios),
|
||||
)
|
||||
|
||||
@ -197,8 +180,7 @@ class MiniCPMOProcessingInfo(MiniCPMVProcessingInfo):
|
||||
pool_step = self.get_default_audio_pool_step()
|
||||
fbank_feat_in_chunk = 100
|
||||
cnn_feat_in_chunk = (fbank_feat_in_chunk - 1) // 2 + 1
|
||||
num_audio_tokens = (cnn_feat_in_chunk - pool_step) // pool_step + 1
|
||||
return num_audio_tokens + 2 # <audio>(<unk>*N)</audio>
|
||||
return (cnn_feat_in_chunk - pool_step) // pool_step + 1
|
||||
|
||||
def get_max_audio_chunks_with_most_features(self) -> int:
|
||||
return 30
|
||||
@ -209,8 +191,7 @@ class MiniCPMOProcessingInfo(MiniCPMVProcessingInfo):
|
||||
|
||||
def get_audio_len_by_num_chunks(self, num_chunks: int) -> int:
|
||||
sampling_rate = self.get_default_audio_sampling_rate()
|
||||
# exclude <audio> </audio>
|
||||
num_tokens_per_chunk = self.get_max_audio_tokens_per_chunk() - 2
|
||||
num_tokens_per_chunk = self.get_max_audio_tokens_per_chunk()
|
||||
return int(num_chunks * sampling_rate / num_tokens_per_chunk) + 1
|
||||
|
||||
def get_num_frames_with_most_features(
|
||||
@ -295,13 +276,6 @@ class MiniCPMOMultiModalProcessor(
|
||||
|
||||
if isinstance(parsed_audios, MiniCPMOAudioEmbeddingItems):
|
||||
audio_inputs = {}
|
||||
|
||||
audio_lens = [
|
||||
self.info.get_audio_len_by_num_chunks(
|
||||
sum(map(len,
|
||||
parsed_audios.get(i)["audio_embeds"])))
|
||||
for i in range(len(parsed_audios))
|
||||
]
|
||||
else:
|
||||
audio_inputs = self._base_call_hf_processor(
|
||||
prompts=[self.info.audio_pattern] * len(parsed_audios),
|
||||
@ -323,27 +297,7 @@ class MiniCPMOMultiModalProcessor(
|
||||
]
|
||||
audio_inputs["audio_features"] = unpadded_audio_features
|
||||
|
||||
audio_lens = [
|
||||
parsed_audios.get_audio_length(i)
|
||||
for i in range(len(parsed_audios))
|
||||
]
|
||||
|
||||
audio_repl_features = [
|
||||
self.get_audio_prompt_texts(audio_len) for audio_len in audio_lens
|
||||
]
|
||||
|
||||
tokenizer = self.info.get_tokenizer()
|
||||
audio_repls_feature_tokens = [
|
||||
tokenizer.encode(audio_repl, add_special_tokens=False)
|
||||
for audio_repl in audio_repl_features
|
||||
]
|
||||
|
||||
embed_is_patch = [
|
||||
self.get_embed_is_patch(audio_repl_tokens)
|
||||
for audio_repl_tokens in audio_repls_feature_tokens
|
||||
]
|
||||
audio_inputs["audio_embed_is_patch"] = embed_is_patch
|
||||
|
||||
unk_token_id = tokenizer.get_vocab()["<unk>"]
|
||||
audio_inputs["audio_token_id"] = torch.tensor(unk_token_id)
|
||||
|
||||
@ -384,7 +338,10 @@ class MiniCPMOMultiModalProcessor(
|
||||
else:
|
||||
audio_len = audios.get_audio_length(item_idx)
|
||||
|
||||
return self.get_audio_prompt_texts(audio_len)
|
||||
return PromptUpdateDetails.select_text(
|
||||
self.get_audio_prompt_texts(audio_len),
|
||||
"<unk>",
|
||||
)
|
||||
|
||||
return [
|
||||
*base_updates,
|
||||
@ -713,13 +670,6 @@ class MiniCPMO(MiniCPMV2_6):
|
||||
assert isinstance(audio_token_id, torch.Tensor)
|
||||
self.mm_token_ids.add(audio_token_id.flatten().unique().item())
|
||||
|
||||
audio_embed_is_patch = kwargs.pop("audio_embed_is_patch")
|
||||
if not isinstance(audio_embed_is_patch, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of audio_embed_is_patch. "
|
||||
f"Got type: {type(audio_embed_is_patch)}")
|
||||
|
||||
audio_embed_is_patch = flatten_bn(audio_embed_is_patch)
|
||||
|
||||
if audio_embeds is not None:
|
||||
if not isinstance(audio_embeds, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of audio_embeds. "
|
||||
@ -730,7 +680,6 @@ class MiniCPMO(MiniCPMV2_6):
|
||||
return MiniCPMOAudioEmbeddingInputs(
|
||||
type="audio_embeds",
|
||||
audio_embeds=audio_embeds_flat,
|
||||
embed_is_patch=audio_embed_is_patch,
|
||||
)
|
||||
|
||||
if not isinstance(audio_features, (torch.Tensor, list)):
|
||||
@ -749,7 +698,6 @@ class MiniCPMO(MiniCPMV2_6):
|
||||
type="audio_features",
|
||||
audio_features=audio_features_flat,
|
||||
audio_feature_lens=audio_feature_lens_flat,
|
||||
embed_is_patch=audio_embed_is_patch,
|
||||
)
|
||||
|
||||
def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
|
||||
@ -781,10 +729,6 @@ class MiniCPMO(MiniCPMV2_6):
|
||||
if modality == "audios":
|
||||
audio_input = modalities["audios"]
|
||||
audio_features = self._process_audio_input(audio_input)
|
||||
multimodal_embeddings += tuple(
|
||||
scatter_patch_features(
|
||||
audio_features,
|
||||
audio_input["embed_is_patch"],
|
||||
))
|
||||
multimodal_embeddings += tuple(audio_features)
|
||||
|
||||
return multimodal_embeddings
|
||||
|
@ -56,7 +56,7 @@ from vllm.multimodal.parse import (DictEmbeddingItems, ImageItem,
|
||||
VideoItem, VideoProcessorItems)
|
||||
from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
BaseProcessingInfo, PromptReplacement,
|
||||
PromptUpdate)
|
||||
PromptUpdate, PromptUpdateDetails)
|
||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sequence import IntermediateTensors
|
||||
@ -67,7 +67,6 @@ from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
|
||||
SupportsMultiModal, SupportsPP)
|
||||
from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
|
||||
merge_multimodal_embeddings)
|
||||
from .vision import scatter_patch_features, select_patch_features
|
||||
|
||||
# For profile run
|
||||
_MAX_FRAMES_PER_VIDEO = 16
|
||||
@ -90,14 +89,6 @@ class MiniCPMVImagePixelInputs(TypedDict):
|
||||
This should be in `(height, width)` format.
|
||||
"""
|
||||
|
||||
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
|
||||
"""
|
||||
A boolean mask indicating which image embeddings correspond
|
||||
to patch tokens.
|
||||
|
||||
Shape: `(batch_size * num_images, num_embeds)`
|
||||
"""
|
||||
|
||||
num_slices: torch.Tensor
|
||||
"""Shape: `(batch_size * num_images)`"""
|
||||
|
||||
@ -112,14 +103,6 @@ class MiniCPMVImageEmbeddingInputs(TypedDict):
|
||||
instead of a batched tensor.
|
||||
"""
|
||||
|
||||
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
|
||||
"""
|
||||
A boolean mask indicating which image embeddings correspond
|
||||
to patch tokens.
|
||||
|
||||
Shape: `(batch_size * num_images, num_embeds)`
|
||||
"""
|
||||
|
||||
|
||||
MiniCPMVImageInputs = Union[MiniCPMVImagePixelInputs,
|
||||
MiniCPMVImageEmbeddingInputs]
|
||||
@ -245,12 +228,10 @@ def _minicpmv_field_config(hf_inputs: Mapping[str, torch.Tensor]):
|
||||
image_sizes=MultiModalFieldConfig.batched("image"),
|
||||
tgt_sizes=MultiModalFieldConfig.batched("image"),
|
||||
image_embeds=MultiModalFieldConfig.batched("image"),
|
||||
embed_is_patch=MultiModalFieldConfig.batched("image"),
|
||||
video_pixel_values=MultiModalFieldConfig.batched("video"),
|
||||
video_image_sizes=MultiModalFieldConfig.batched("video"),
|
||||
video_tgt_sizes=MultiModalFieldConfig.batched("video"),
|
||||
video_embeds=MultiModalFieldConfig.batched("video"),
|
||||
video_embed_is_patch=MultiModalFieldConfig.batched("video"),
|
||||
image_token_id=MultiModalFieldConfig.shared("image", num_images),
|
||||
video_token_id=MultiModalFieldConfig.shared("video", num_videos),
|
||||
)
|
||||
@ -398,22 +379,43 @@ class MiniCPMVProcessingInfo(BaseProcessingInfo):
|
||||
use_image_id=use_image_id,
|
||||
)
|
||||
|
||||
def get_sliced_grid(
|
||||
self,
|
||||
image_size: ImageSize,
|
||||
# For MiniCPM V/O 2.6
|
||||
max_slice_nums: Optional[int] = None,
|
||||
) -> Optional[tuple[int, int]]:
|
||||
image_processor = self.get_image_processor()
|
||||
version = self.get_model_version()
|
||||
|
||||
if version == (2, 0) or version == (2, 5):
|
||||
return image_processor.get_sliced_grid(image_size)
|
||||
|
||||
if max_slice_nums is None:
|
||||
max_slice_nums = image_processor.max_slice_nums
|
||||
|
||||
return image_processor.get_sliced_grid(
|
||||
image_size,
|
||||
max_slice_nums=max_slice_nums,
|
||||
)
|
||||
|
||||
def get_num_image_tokens(
|
||||
self,
|
||||
image_size: ImageSize,
|
||||
max_slice_nums: Optional[int] = None,
|
||||
use_image_id: bool = True,
|
||||
) -> int:
|
||||
tokenizer = self.get_tokenizer()
|
||||
image_placeholders = self.get_slice_image_placeholder(
|
||||
image_processor = self.get_image_processor()
|
||||
|
||||
grid = self.get_sliced_grid(
|
||||
image_size,
|
||||
max_slice_nums=max_slice_nums,
|
||||
use_image_id=use_image_id,
|
||||
)
|
||||
image_token_ids = tokenizer.encode(image_placeholders,
|
||||
add_special_tokens=False)
|
||||
if grid is None:
|
||||
ncols = nrows = 0
|
||||
else:
|
||||
ncols, nrows = grid
|
||||
|
||||
return len(image_token_ids)
|
||||
return (ncols * nrows + 1) * image_processor.image_feature_size
|
||||
|
||||
def get_max_image_tokens(self) -> int:
|
||||
image_size = self.get_image_size_with_most_features()
|
||||
@ -433,7 +435,6 @@ class MiniCPMVProcessingInfo(BaseProcessingInfo):
|
||||
return self.get_num_image_tokens(
|
||||
frame_size,
|
||||
max_slice_nums=self.get_video_max_slice_num(),
|
||||
use_image_id=False,
|
||||
)
|
||||
|
||||
def get_max_video_tokens(
|
||||
@ -539,14 +540,6 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
|
||||
use_image_id=False,
|
||||
) * num_frames
|
||||
|
||||
def get_embed_is_patch(
|
||||
self,
|
||||
input_ids: list[int],
|
||||
) -> torch.Tensor:
|
||||
tokenizer = self.info.get_tokenizer()
|
||||
unk_token_id = tokenizer.get_vocab()["<unk>"]
|
||||
return torch.tensor(input_ids) == unk_token_id
|
||||
|
||||
def process_images(
|
||||
self,
|
||||
mm_data: Mapping[str, object],
|
||||
@ -570,26 +563,7 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
|
||||
out_keys={"pixel_values", "image_sizes", "tgt_sizes"},
|
||||
)
|
||||
|
||||
image_sizes = [
|
||||
parsed_images.get_image_size(i) for i in range(len(parsed_images))
|
||||
]
|
||||
image_repl_features = [
|
||||
self.get_image_prompt_texts(size, idx)
|
||||
for idx, size in enumerate(image_sizes)
|
||||
]
|
||||
|
||||
tokenizer = self.info.get_tokenizer()
|
||||
image_repls_feature_tokens = [
|
||||
tokenizer.encode(image_repl, add_special_tokens=False)
|
||||
for image_repl in image_repl_features
|
||||
]
|
||||
|
||||
embed_is_patch = [
|
||||
self.get_embed_is_patch(image_repl_tokens)
|
||||
for image_repl_tokens in image_repls_feature_tokens
|
||||
]
|
||||
image_inputs["embed_is_patch"] = embed_is_patch
|
||||
|
||||
unk_token_id = tokenizer.get_vocab()["<unk>"]
|
||||
image_inputs["image_token_id"] = torch.tensor(unk_token_id)
|
||||
|
||||
@ -625,31 +599,9 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
|
||||
out_keys={"pixel_values", "image_sizes", "tgt_sizes"},
|
||||
)
|
||||
|
||||
frame_sizes = [
|
||||
parsed_videos.get_frame_size(i) for i in range(len(parsed_videos))
|
||||
]
|
||||
num_frames = [
|
||||
parsed_videos.get_num_frames(i) for i in range(len(parsed_videos))
|
||||
]
|
||||
video_repl_features = [
|
||||
self.get_video_prompt_texts(size, nframes)
|
||||
for size, nframes in zip(frame_sizes, num_frames)
|
||||
]
|
||||
|
||||
tokenizer = self.info.get_tokenizer()
|
||||
video_repls_feature_tokens = [
|
||||
tokenizer.encode(video_repl, add_special_tokens=False)
|
||||
for video_repl in video_repl_features
|
||||
]
|
||||
|
||||
embed_is_patch = [
|
||||
self.get_embed_is_patch(video_repl_tokens)
|
||||
for video_repl_tokens in video_repls_feature_tokens
|
||||
]
|
||||
video_inputs["embed_is_patch"] = embed_is_patch
|
||||
|
||||
video_inputs = {f"video_{k}": v for k, v in video_inputs.items()}
|
||||
|
||||
tokenizer = self.info.get_tokenizer()
|
||||
unk_token_id = tokenizer.get_vocab()["<unk>"]
|
||||
video_inputs["video_token_id"] = torch.tensor(unk_token_id)
|
||||
|
||||
@ -740,7 +692,10 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
|
||||
|
||||
image_size = images.get_image_size(item_idx)
|
||||
|
||||
return self.get_image_prompt_texts(image_size, item_idx)
|
||||
return PromptUpdateDetails.select_text(
|
||||
self.get_image_prompt_texts(image_size, item_idx),
|
||||
"<unk>",
|
||||
)
|
||||
|
||||
def get_video_replacement(item_idx: int):
|
||||
videos = mm_items.get_items(
|
||||
@ -749,7 +704,10 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
|
||||
frame_size = videos.get_frame_size(item_idx)
|
||||
num_frames = videos.get_num_frames(item_idx)
|
||||
|
||||
return self.get_video_prompt_texts(frame_size, num_frames)
|
||||
return PromptUpdateDetails.select_text(
|
||||
self.get_video_prompt_texts(frame_size, num_frames),
|
||||
"<unk>",
|
||||
)
|
||||
|
||||
get_replacement = {
|
||||
"image": get_image_replacement,
|
||||
@ -832,14 +790,6 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
assert isinstance(image_token_id, torch.Tensor)
|
||||
self.mm_token_ids.add(image_token_id.flatten().unique().item())
|
||||
|
||||
embed_is_patch = kwargs.pop("embed_is_patch")
|
||||
if not isinstance(embed_is_patch, (torch.Tensor, list)):
|
||||
raise ValueError(
|
||||
f"Incorrect type of embed_is_patch for {modality=}. "
|
||||
f"Got type: {type(embed_is_patch)}")
|
||||
|
||||
embed_is_patch = flatten_bn(embed_is_patch)
|
||||
|
||||
if image_embeds is not None:
|
||||
if not isinstance(image_embeds, (torch.Tensor, list)):
|
||||
raise ValueError(
|
||||
@ -851,7 +801,6 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
return MiniCPMVImageEmbeddingInputs(
|
||||
type="image_embeds",
|
||||
image_embeds=image_embeds_flat,
|
||||
embed_is_patch=embed_is_patch,
|
||||
)
|
||||
|
||||
if not isinstance(pixel_values, (torch.Tensor, list)):
|
||||
@ -879,7 +828,6 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
type="pixel_values",
|
||||
pixel_values=pixel_values_flat,
|
||||
tgt_sizes=tgt_sizes_flat,
|
||||
embed_is_patch=embed_is_patch,
|
||||
num_slices=num_slices_flat,
|
||||
)
|
||||
|
||||
@ -936,19 +884,11 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
if modality == "images":
|
||||
image_input = modalities["images"]
|
||||
image_features = self._process_vision_input(image_input)
|
||||
multimodal_embeddings += tuple(
|
||||
scatter_patch_features(
|
||||
image_features,
|
||||
image_input["embed_is_patch"],
|
||||
))
|
||||
multimodal_embeddings += tuple(image_features)
|
||||
if modality == "videos":
|
||||
video_input = modalities["videos"]
|
||||
video_features = self._process_vision_input(video_input)
|
||||
multimodal_embeddings += tuple(
|
||||
scatter_patch_features(
|
||||
video_features,
|
||||
video_input["embed_is_patch"],
|
||||
))
|
||||
multimodal_embeddings += tuple(video_features)
|
||||
|
||||
return multimodal_embeddings
|
||||
|
||||
@ -971,7 +911,7 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
inputs_embeds = merge_multimodal_embeddings(
|
||||
input_ids,
|
||||
inputs_embeds,
|
||||
select_patch_features(multimodal_embeddings),
|
||||
multimodal_embeddings,
|
||||
list(self.mm_token_ids),
|
||||
)
|
||||
return inputs_embeds
|
||||
|
@ -27,7 +27,8 @@ from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
|
||||
MultiModalDataItems)
|
||||
from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
BaseProcessingInfo, ProcessingCache,
|
||||
PromptReplacement, PromptUpdate)
|
||||
PromptReplacement, PromptUpdate,
|
||||
PromptUpdateDetails)
|
||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
@ -35,8 +36,7 @@ from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
|
||||
from .pixtral import PixtralHFEncoderInfo, PixtralHFVisionModel
|
||||
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
|
||||
maybe_prefix, merge_multimodal_embeddings)
|
||||
from .vision import (get_vision_encoder_info, scatter_patch_features,
|
||||
select_patch_features)
|
||||
from .vision import get_vision_encoder_info
|
||||
|
||||
|
||||
class Mistral3ImagePixelInputs(TypedDict):
|
||||
@ -49,14 +49,6 @@ class Mistral3ImagePixelInputs(TypedDict):
|
||||
in which case the data is passed as a list instead of a batched tensor.
|
||||
"""
|
||||
|
||||
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
|
||||
"""
|
||||
A boolean mask indicating which image embeddings correspond
|
||||
to patch tokens.
|
||||
|
||||
Shape: `(batch_size, num_images, num_embeds)`
|
||||
"""
|
||||
|
||||
|
||||
class Mistral3PatchMerger(nn.Module):
|
||||
"""
|
||||
@ -266,23 +258,6 @@ class Mistral3MultiModalProcessor(
|
||||
p[:, :h, :w] for p, (h, w) in zip(pixel_values, image_sizes)
|
||||
]
|
||||
|
||||
hf_config = self.info.get_hf_config()
|
||||
vision_config = hf_config.vision_config
|
||||
assert isinstance(vision_config, PixtralVisionConfig)
|
||||
encoder_info = PixtralHFEncoderInfo(vision_config)
|
||||
|
||||
tile_sizes = [
|
||||
encoder_info.get_patch_grid_size(
|
||||
image_width=pixel_value.shape[-1],
|
||||
image_height=pixel_value.shape[-2],
|
||||
) for pixel_value in processed_outputs["pixel_values"]
|
||||
]
|
||||
embed_is_patch = [
|
||||
torch.tensor(([True] * ncols + [False]) * nrows)
|
||||
for ncols, nrows in tile_sizes
|
||||
]
|
||||
processed_outputs["embed_is_patch"] = embed_is_patch
|
||||
|
||||
return processed_outputs
|
||||
|
||||
def _get_mm_fields_config(
|
||||
@ -292,7 +267,6 @@ class Mistral3MultiModalProcessor(
|
||||
) -> Mapping[str, MultiModalFieldConfig]:
|
||||
return dict(
|
||||
pixel_values=MultiModalFieldConfig.batched("image"),
|
||||
embed_is_patch=MultiModalFieldConfig.batched("image"),
|
||||
image_embeds=MultiModalFieldConfig.batched("image"),
|
||||
)
|
||||
|
||||
@ -327,7 +301,7 @@ class Mistral3MultiModalProcessor(
|
||||
tokens = ([image_token_id] * ncols + [image_break_id]) * nrows
|
||||
tokens[-1] = image_end_id
|
||||
|
||||
return tokens
|
||||
return PromptUpdateDetails.select_token_id(tokens, image_token_id)
|
||||
|
||||
return [
|
||||
PromptReplacement(
|
||||
@ -418,8 +392,6 @@ def init_vision_tower_for_llava(
|
||||
)
|
||||
|
||||
|
||||
# TODO(mgoin): Support V1, there are issues with image batching/chunking
|
||||
# that need to be resolved first.
|
||||
@MULTIMODAL_REGISTRY.register_processor(
|
||||
_build_mistral3_processor,
|
||||
info=_build_mistral3_info,
|
||||
@ -509,16 +481,9 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
raise ValueError("Incorrect type of pixel values. "
|
||||
f"Got type: {type(pixel_values)}")
|
||||
|
||||
assert self.config.vision_config.model_type == "pixtral"
|
||||
embed_is_patch = kwargs.pop("embed_is_patch")
|
||||
if not isinstance(embed_is_patch, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of embed_is_patch. "
|
||||
f"Got type: {type(embed_is_patch)}")
|
||||
|
||||
return Mistral3ImagePixelInputs(
|
||||
type="pixel_values_pixtral",
|
||||
pixel_values=flatten_bn(pixel_values),
|
||||
embed_is_patch=flatten_bn(embed_is_patch),
|
||||
)
|
||||
|
||||
def _process_image_input(
|
||||
@ -557,10 +522,7 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
|
||||
vision_embeddings = self._process_image_input(image_input)
|
||||
|
||||
return scatter_patch_features(
|
||||
vision_embeddings,
|
||||
image_input["embed_is_patch"],
|
||||
)
|
||||
return vision_embeddings
|
||||
|
||||
def get_input_embeddings(
|
||||
self,
|
||||
@ -572,7 +534,7 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
inputs_embeds = merge_multimodal_embeddings(
|
||||
input_ids,
|
||||
inputs_embeds,
|
||||
select_patch_features(multimodal_embeddings),
|
||||
multimodal_embeddings,
|
||||
self.config.image_token_index,
|
||||
)
|
||||
return inputs_embeds
|
||||
|
@ -46,7 +46,8 @@ from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
|
||||
MultiModalDataItems)
|
||||
from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
BaseProcessingInfo, PromptIndexTargets,
|
||||
PromptInsertion, PromptUpdate)
|
||||
PromptInsertion, PromptUpdate,
|
||||
PromptUpdateDetails)
|
||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
@ -56,7 +57,6 @@ from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
|
||||
is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory, make_layers,
|
||||
maybe_prefix, merge_multimodal_embeddings)
|
||||
from .vision import scatter_patch_features, select_patch_features
|
||||
|
||||
# TODO: hard-coded for now. Consider making it configurable.
|
||||
VIT_LAYERS = [-2, -9]
|
||||
@ -84,14 +84,6 @@ class MolmoImageInputs(TypedDict):
|
||||
Shape: `(batch_size * num_images, num_crops, num_patch)`
|
||||
"""
|
||||
|
||||
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
|
||||
"""
|
||||
A boolean mask indicating which image embeddings correspond
|
||||
to patch tokens.
|
||||
|
||||
Shape: `(batch_size * num_images, num_embeds)`
|
||||
"""
|
||||
|
||||
num_crops: torch.Tensor
|
||||
"""Shape: `(batch_size * num_images)`"""
|
||||
|
||||
@ -1146,30 +1138,6 @@ class MolmoProcessorWrapper:
|
||||
if image_input_idx is not None:
|
||||
feat_is_patch = image_input_idx >= 0
|
||||
|
||||
input_is_embed = torch.isin(
|
||||
input_ids,
|
||||
torch.tensor([
|
||||
self.image_patch_id,
|
||||
self.im_col_id,
|
||||
self.im_start_id,
|
||||
self.im_end_id,
|
||||
]),
|
||||
)
|
||||
embed_ids = input_ids[input_is_embed]
|
||||
embed_is_patch = embed_ids == self.image_patch_id
|
||||
assert embed_is_patch.sum() == feat_is_patch.sum()
|
||||
|
||||
# image_tokens = extra_joint + joint
|
||||
# Both `extra_joint` and `joint` have `im_start_id` and `im_end_id`
|
||||
embed_start = torch.nonzero(embed_ids == self.im_start_id)[::2, 0]
|
||||
embed_end = torch.nonzero(embed_ids == self.im_end_id)[1::2, 0]
|
||||
assert len(embed_start) == len(embed_end) == len(images)
|
||||
|
||||
embed_is_patch = [
|
||||
embed_is_patch[start:end + 1]
|
||||
for start, end in zip(embed_start, embed_end)
|
||||
]
|
||||
|
||||
tilings = [
|
||||
self.select_tiling(
|
||||
image_width=image.size[0],
|
||||
@ -1181,7 +1149,6 @@ class MolmoProcessorWrapper:
|
||||
assert num_crops.sum() == len(feat_is_patch)
|
||||
|
||||
outputs["feat_is_patch"] = feat_is_patch
|
||||
outputs["embed_is_patch"] = embed_is_patch
|
||||
outputs["num_crops"] = num_crops
|
||||
outputs["img_patch_id"] = self.image_patch_id
|
||||
|
||||
@ -1220,17 +1187,13 @@ class MolmoProcessingInfo(BaseProcessingInfo):
|
||||
)
|
||||
pooling_size = processor.pooling_size
|
||||
|
||||
base_image_input_size = processor.base_image_input_size
|
||||
base_image_input_d = processor.image_patch_size
|
||||
image_token_length_w = processor.image_token_length_w
|
||||
image_token_length_h = processor.image_token_length_h
|
||||
|
||||
crop_patches = base_image_input_size[0] // base_image_input_d
|
||||
extra = image_token_length_w * image_token_length_h
|
||||
joint = ((ncols + 1) // pooling_size) * ((nrows + 1) // pooling_size)
|
||||
|
||||
per_row = ncols // pooling_size + 1
|
||||
joint = per_row * (nrows // pooling_size) + 2
|
||||
image_token_length = (crop_patches + pooling_size - 1) // pooling_size
|
||||
resize = (image_token_length + 1) * image_token_length + 2
|
||||
|
||||
return resize + joint
|
||||
return extra + joint
|
||||
|
||||
def get_max_image_tokens(self) -> int:
|
||||
target_width, target_height = self.get_image_size_with_most_features()
|
||||
@ -1328,7 +1291,6 @@ class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]):
|
||||
"image", num_crops),
|
||||
feat_is_patch=MultiModalFieldConfig.flat_from_sizes(
|
||||
"image", num_crops),
|
||||
embed_is_patch=MultiModalFieldConfig.batched("image"),
|
||||
num_crops=MultiModalFieldConfig.batched("image"),
|
||||
img_patch_id=MultiModalFieldConfig.shared("image", num_images),
|
||||
)
|
||||
@ -1368,8 +1330,10 @@ class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]):
|
||||
joint = ([img_start_id] + joint_row *
|
||||
((nrows + 1) // pooling_size) + [img_end_id])
|
||||
|
||||
image_tokens = extra_joint + joint
|
||||
return image_tokens
|
||||
return PromptUpdateDetails.select_token_id(
|
||||
extra_joint + joint,
|
||||
embed_token_id=img_patch_id,
|
||||
)
|
||||
|
||||
return [
|
||||
PromptInsertion(
|
||||
@ -1475,11 +1439,6 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
|
||||
raise ValueError("Incorrect type of feat_is_patch. "
|
||||
f"Got type: {type(feat_is_patch)}")
|
||||
|
||||
embed_is_patch = kwargs.pop("embed_is_patch", None)
|
||||
if not isinstance(embed_is_patch, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of embed_is_patch. "
|
||||
f"Got type: {type(embed_is_patch)}")
|
||||
|
||||
num_crops = kwargs.pop("num_crops", None)
|
||||
if not isinstance(num_crops, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of num_crops. "
|
||||
@ -1491,14 +1450,12 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
|
||||
f"Got type: {type(img_patch_id)}")
|
||||
self.img_patch_id = img_patch_id.flatten().unique().item()
|
||||
|
||||
embed_is_patch = flatten_bn(embed_is_patch)
|
||||
num_crops = flatten_bn(num_crops, concat=True)
|
||||
|
||||
return MolmoImageInputs(
|
||||
images=images,
|
||||
image_masks=image_masks,
|
||||
feat_is_patch=feat_is_patch,
|
||||
embed_is_patch=embed_is_patch,
|
||||
num_crops=num_crops,
|
||||
)
|
||||
|
||||
@ -1537,12 +1494,7 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
|
||||
if image_input is None:
|
||||
return None
|
||||
|
||||
image_features = self._process_image_input(image_input)
|
||||
|
||||
return scatter_patch_features(
|
||||
image_features,
|
||||
image_input["embed_is_patch"],
|
||||
)
|
||||
return self._process_image_input(image_input)
|
||||
|
||||
def get_input_embeddings(
|
||||
self,
|
||||
@ -1556,7 +1508,7 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
|
||||
inputs_embeds = merge_multimodal_embeddings(
|
||||
input_ids,
|
||||
inputs_embeds,
|
||||
select_patch_features(multimodal_embeddings),
|
||||
multimodal_embeddings,
|
||||
self.img_patch_id,
|
||||
)
|
||||
return inputs_embeds
|
||||
|
@ -57,7 +57,7 @@ class NVLMProcessor(BaseInternVLProcessor):
|
||||
# when trying to find "<tile" as a subsequence of "<Image><tile"
|
||||
repl = "<Image>" + features + "</Image>"
|
||||
|
||||
return PromptUpdateDetails(full=repl, features=repl)
|
||||
return PromptUpdateDetails.select_text(repl, IMG_PAD)
|
||||
|
||||
|
||||
class NVLMProcessingInfo(BaseInternVLProcessingInfo):
|
||||
@ -84,31 +84,6 @@ class NVLMProcessingInfo(BaseInternVLProcessingInfo):
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def get_max_image_tokens(self) -> int:
|
||||
hf_processor = self.get_hf_processor()
|
||||
tokenizer = hf_processor.tokenizer
|
||||
|
||||
max_num_patches = hf_processor.max_dynamic_patch
|
||||
# we need +1 here because max_dynamic_patch in config doesn't
|
||||
# include the thumbnail patch
|
||||
tile_pos_identifiers = [
|
||||
f"<tile_{i+1}>" for i in range(max_num_patches)
|
||||
]
|
||||
if hf_processor.use_thumbnail and max_num_patches != 1:
|
||||
tile_pos_identifiers += ["<tile_global_thumbnail>"]
|
||||
|
||||
# "<Image><tile" is tokenized as ["<Image", "><", "tile"]
|
||||
# so we include <tile_1> in the start_str
|
||||
start_str = "<Image>" + tile_pos_identifiers.pop(0)
|
||||
end_str = "</Image>"
|
||||
start_token_len = len(tokenizer.encode(start_str))
|
||||
end_token_len = len(tokenizer.encode(end_str))
|
||||
tile_token_len = sum(
|
||||
len(tokenizer.encode(identifier))
|
||||
for identifier in tile_pos_identifiers)
|
||||
non_image_tokens_num = start_token_len + end_token_len + tile_token_len
|
||||
return super().get_max_image_tokens() + non_image_tokens_num
|
||||
|
||||
|
||||
class NVLMDummyInputsBuilder(InternVLDummyInputsBuilder[NVLMProcessingInfo]):
|
||||
|
||||
@ -177,10 +152,7 @@ class NVLMMultiModalProcessor(InternVLMultiModalProcessor[NVLMProcessingInfo]):
|
||||
|
||||
repl = hf_processor.get_image_repl(feature_size, num_patches)
|
||||
|
||||
return PromptUpdateDetails(
|
||||
full=repl.full + "\n",
|
||||
features=repl.features + "\n",
|
||||
)
|
||||
return PromptUpdateDetails.select_text(repl.full + "\n", IMG_PAD)
|
||||
|
||||
# See note in dummy data regarding why we have the extra newline
|
||||
return [
|
||||
|
@ -162,9 +162,9 @@ class PaliGemmaMultiModalProcessor(
|
||||
modality="image",
|
||||
target=PromptIndexTargets.prefix(
|
||||
[bos_token_id] if tokenizer.add_bos_token else []),
|
||||
insertion=PromptUpdateDetails(
|
||||
full=image_tokens + [bos_token_id],
|
||||
features=image_tokens,
|
||||
insertion=PromptUpdateDetails.select_token_id(
|
||||
image_tokens + [bos_token_id],
|
||||
embed_token_id=image_token_id,
|
||||
),
|
||||
)
|
||||
]
|
||||
|
@ -40,8 +40,7 @@ from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
|
||||
from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
BaseProcessingInfo, BoundPromptUpdate,
|
||||
PlaceholderFeaturesInfo,
|
||||
PromptReplacement, PromptUpdate,
|
||||
PromptUpdateDetails)
|
||||
PromptReplacement, PromptUpdate)
|
||||
# yapf: enable
|
||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
|
||||
from vllm.sequence import IntermediateTensors
|
||||
@ -443,12 +442,7 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]):
|
||||
processor=hf_processor,
|
||||
)
|
||||
|
||||
image_tokens = [_IMAGE_TOKEN_ID] * num_image_tokens
|
||||
|
||||
return PromptUpdateDetails(
|
||||
full=image_tokens,
|
||||
features=image_tokens,
|
||||
)
|
||||
return [_IMAGE_TOKEN_ID] * num_image_tokens
|
||||
|
||||
num_images = mm_items.get_count("image", strict=False)
|
||||
|
||||
@ -517,6 +511,7 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]):
|
||||
item_idx=p.item_idx,
|
||||
start_idx=p.start_idx - 1,
|
||||
tokens=p.tokens,
|
||||
is_embed=p.is_embed,
|
||||
) for p in ps
|
||||
]
|
||||
for modality, ps in placeholders.items()
|
||||
|
@ -37,7 +37,7 @@ from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
|
||||
MultiModalDataItems)
|
||||
from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
BaseProcessingInfo, PromptReplacement,
|
||||
PromptUpdate)
|
||||
PromptUpdate, PromptUpdateDetails)
|
||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.transformers_utils.tokenizer import (MistralTokenizer,
|
||||
@ -46,8 +46,7 @@ from vllm.transformers_utils.tokenizer import (MistralTokenizer,
|
||||
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
|
||||
from .utils import (flatten_bn, init_vllm_registered_model, maybe_prefix,
|
||||
merge_multimodal_embeddings)
|
||||
from .vision import (VisionEncoderInfo, resolve_visual_encoder_outputs,
|
||||
scatter_patch_features, select_patch_features)
|
||||
from .vision import VisionEncoderInfo, resolve_visual_encoder_outputs
|
||||
|
||||
try:
|
||||
from xformers import ops as xops
|
||||
@ -68,14 +67,6 @@ class PixtralImagePixelInputs(TypedDict):
|
||||
The result of stacking :attr:`ImageEncoding.tokens` from each prompt.
|
||||
"""
|
||||
|
||||
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
|
||||
"""
|
||||
A boolean mask indicating which image embeddings correspond
|
||||
to patch tokens.
|
||||
|
||||
Shape: `(batch_size * num_images, num_embeds)`
|
||||
"""
|
||||
|
||||
|
||||
class PixtralProcessorAdapter:
|
||||
"""
|
||||
@ -144,11 +135,8 @@ class PixtralProcessorAdapter:
|
||||
"For more info, see: "
|
||||
"https://github.com/vllm-project/vllm/issues/8411.")
|
||||
|
||||
image_token_id = self.image_token_id
|
||||
|
||||
images_processed = list[torch.Tensor]()
|
||||
images_tokens = list[torch.Tensor]()
|
||||
images_embed_is_patch = list[torch.Tensor]()
|
||||
|
||||
for image in images:
|
||||
image_inputs = self.image_processor(ImageChunk(image=image))
|
||||
@ -157,12 +145,10 @@ class PixtralProcessorAdapter:
|
||||
|
||||
images_processed.append(image_processed)
|
||||
images_tokens.append(image_tokens)
|
||||
images_embed_is_patch.append(image_tokens == image_token_id)
|
||||
|
||||
return {
|
||||
"input_ids": torch.cat(images_tokens)[None].expand(len(text), -1),
|
||||
"images": images_processed,
|
||||
"embed_is_patch": images_embed_is_patch,
|
||||
}
|
||||
|
||||
|
||||
@ -213,7 +199,7 @@ class PixtralProcessingInfo(BaseProcessingInfo):
|
||||
ncols, nrows = processor.image_processor._image_to_num_tokens(
|
||||
Image.new("RGB", (image_width, image_height)))
|
||||
|
||||
return (ncols + 1) * nrows
|
||||
return ncols * nrows
|
||||
|
||||
def get_image_size_with_most_features(self) -> ImageSize:
|
||||
image_processor = self.get_hf_processor().image_processor
|
||||
@ -263,10 +249,7 @@ class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo]
|
||||
hf_inputs: Mapping[str, NestedTensors],
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
) -> Mapping[str, MultiModalFieldConfig]:
|
||||
return dict(
|
||||
images=MultiModalFieldConfig.batched("image"),
|
||||
embed_is_patch=MultiModalFieldConfig.batched("image"),
|
||||
)
|
||||
return dict(images=MultiModalFieldConfig.batched("image"))
|
||||
|
||||
def _get_prompt_updates(
|
||||
self,
|
||||
@ -290,7 +273,7 @@ class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo]
|
||||
tokens = ([image_token_id] * ncols + [image_break_id]) * nrows
|
||||
tokens[-1] = image_end_id
|
||||
|
||||
return tokens
|
||||
return PromptUpdateDetails.select_token_id(tokens, image_token_id)
|
||||
|
||||
return [
|
||||
PromptReplacement(
|
||||
@ -381,17 +364,9 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
raise ValueError("Incorrect type of images. "
|
||||
f"Got type: {type(images)}")
|
||||
|
||||
embed_is_patch = kwargs.pop("embed_is_patch")
|
||||
if not isinstance(embed_is_patch, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of embed_is_patch. "
|
||||
f"Got type: {type(embed_is_patch)}")
|
||||
|
||||
embed_is_patch = flatten_bn(embed_is_patch)
|
||||
|
||||
return PixtralImagePixelInputs(
|
||||
type="pixel_values",
|
||||
images=flatten_bn(images),
|
||||
embed_is_patch=embed_is_patch,
|
||||
)
|
||||
|
||||
def _process_image_input(
|
||||
@ -427,12 +402,7 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
if image_input is None:
|
||||
return None
|
||||
|
||||
image_features = self._process_image_input(image_input)
|
||||
|
||||
return scatter_patch_features(
|
||||
image_features,
|
||||
image_input["embed_is_patch"],
|
||||
)
|
||||
return self._process_image_input(image_input)
|
||||
|
||||
def get_input_embeddings(
|
||||
self,
|
||||
@ -444,7 +414,7 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
inputs_embeds = merge_multimodal_embeddings(
|
||||
input_ids,
|
||||
inputs_embeds,
|
||||
select_patch_features(multimodal_embeddings),
|
||||
multimodal_embeddings,
|
||||
self.vision_args.image_token_id,
|
||||
)
|
||||
return inputs_embeds
|
||||
@ -963,9 +933,7 @@ class PixtralHFEncoderInfo(VisionEncoderInfo[PixtralVisionConfig]):
|
||||
image_width=image_width,
|
||||
image_height=image_height,
|
||||
)
|
||||
|
||||
# Consider the image_break_token
|
||||
return (ncols + 1) * nrows
|
||||
return ncols * nrows
|
||||
|
||||
def get_max_image_tokens(self) -> int:
|
||||
image_size = self.get_image_size()
|
||||
|
@ -229,9 +229,9 @@ class Qwen2AudioMultiModalProcessor(
|
||||
|
||||
audio_tokens = [audio_token_id] * num_features
|
||||
|
||||
return PromptUpdateDetails(
|
||||
full=[audio_bos_id] + audio_tokens + [audio_eos_id],
|
||||
features=audio_tokens,
|
||||
return PromptUpdateDetails.select_token_id(
|
||||
[audio_bos_id] + audio_tokens + [audio_eos_id],
|
||||
embed_token_id=audio_token_id,
|
||||
)
|
||||
|
||||
return [
|
||||
|
@ -647,9 +647,9 @@ class QwenVLMultiModalProcessor(BaseMultiModalProcessor[QwenVLProcessingInfo]):
|
||||
PromptReplacement(
|
||||
modality="image",
|
||||
target=[img_start_id, img_end_id],
|
||||
replacement=PromptUpdateDetails(
|
||||
full=[img_start_id] + image_tokens + [img_end_id],
|
||||
features=image_tokens,
|
||||
replacement=PromptUpdateDetails.select_token_id(
|
||||
[img_start_id] + image_tokens + [img_end_id],
|
||||
embed_token_id=img_pad_id,
|
||||
),
|
||||
)
|
||||
]
|
||||
|
@ -40,7 +40,6 @@ from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
|
||||
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
|
||||
maybe_prefix, merge_multimodal_embeddings)
|
||||
from .vision import scatter_patch_features, select_patch_features
|
||||
|
||||
IMG_START = '<img>'
|
||||
IMG_END = '</img>'
|
||||
@ -61,14 +60,6 @@ class SkyworkR1VImagePixelInputs(TypedDict):
|
||||
num_patches: torch.Tensor
|
||||
"""Shape: `(batch_size * num_images)`"""
|
||||
|
||||
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
|
||||
"""
|
||||
A boolean mask indicating which image embeddings correspond
|
||||
to patch tokens.
|
||||
|
||||
Shape: `(batch_size * num_images, num_embeds)`
|
||||
"""
|
||||
|
||||
|
||||
class SkyworkR1VImageEmbeddingInputs(TypedDict):
|
||||
type: Literal["image_embeds"]
|
||||
@ -419,24 +410,13 @@ class BaseSkyworkR1VProcessor(ABC):
|
||||
torch.tensor([len(item) for item in pixel_values_lst]),
|
||||
}
|
||||
|
||||
tokenizer = self.tokenizer
|
||||
image_token_id = self.image_token_id
|
||||
|
||||
embed_is_patch = list[torch.Tensor]()
|
||||
|
||||
for pixel_values in pixel_values_lst:
|
||||
num_patches = pixel_values.shape[0]
|
||||
feature_size = num_patches * self.num_image_token
|
||||
|
||||
image_repl = self.get_image_repl(feature_size, num_patches)
|
||||
feature_tokens = tokenizer.encode(image_repl.features,
|
||||
add_special_tokens=False)
|
||||
|
||||
text = [t.replace('<image>', image_repl.full, 1) for t in text]
|
||||
embed_is_patch.append(
|
||||
torch.tensor(feature_tokens) == image_token_id)
|
||||
|
||||
image_inputs["embed_is_patch"] = embed_is_patch
|
||||
|
||||
text_inputs = self.tokenizer(text)
|
||||
|
||||
@ -460,7 +440,7 @@ class SkyworkR1VProcessor(BaseSkyworkR1VProcessor):
|
||||
repl_features = IMG_CONTEXT * feature_size
|
||||
repl_full = IMG_START + repl_features + IMG_END
|
||||
|
||||
return PromptUpdateDetails(full=repl_full, features=repl_features)
|
||||
return PromptUpdateDetails.select_text(repl_full, IMG_CONTEXT)
|
||||
|
||||
|
||||
class BaseSkyworkR1VProcessingInfo(BaseProcessingInfo):
|
||||
@ -599,7 +579,6 @@ class SkyworkR1VMultiModalProcessor(BaseMultiModalProcessor[_I]):
|
||||
pixel_values_flat=MultiModalFieldConfig.flat_from_sizes(
|
||||
"image", image_num_patches),
|
||||
image_num_patches=MultiModalFieldConfig.batched("image"),
|
||||
embed_is_patch=MultiModalFieldConfig.batched("image"),
|
||||
image_embeds=MultiModalFieldConfig.batched("image"),
|
||||
image_token_id=MultiModalFieldConfig.shared("image", num_images),
|
||||
)
|
||||
@ -835,7 +814,6 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
self, **kwargs: object) -> Optional[SkyworkR1VImageInputs]:
|
||||
pixel_values_flat = kwargs.pop("pixel_values_flat", None)
|
||||
image_num_patches = kwargs.pop("image_num_patches", None)
|
||||
embed_is_patch = kwargs.pop("embed_is_patch", None)
|
||||
image_embeds = kwargs.pop("image_embeds", None)
|
||||
|
||||
if pixel_values_flat is None and image_embeds is None:
|
||||
@ -864,20 +842,14 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
raise ValueError("Incorrect type of image_num_patches. "
|
||||
f"Got type: {type(image_num_patches)}")
|
||||
|
||||
if not isinstance(embed_is_patch, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of embed_is_patch. "
|
||||
f"Got type: {type(embed_is_patch)}")
|
||||
|
||||
pixel_values_flat = flatten_bn(pixel_values_flat, concat=True)
|
||||
image_num_patches = flatten_bn(image_num_patches, concat=True)
|
||||
embed_is_patch = flatten_bn(embed_is_patch)
|
||||
|
||||
return SkyworkR1VImagePixelInputs(
|
||||
type="pixel_values",
|
||||
pixel_values_flat=self._validate_pixel_values(
|
||||
pixel_values_flat),
|
||||
num_patches=image_num_patches,
|
||||
embed_is_patch=embed_is_patch,
|
||||
)
|
||||
|
||||
raise AssertionError("This line should be unreachable.")
|
||||
@ -923,15 +895,7 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
if image_input is None:
|
||||
return None
|
||||
|
||||
image_features = self._process_image_input(image_input)
|
||||
|
||||
if image_input["type"] != "pixel_values":
|
||||
return image_features
|
||||
|
||||
return scatter_patch_features(
|
||||
image_features,
|
||||
image_input["embed_is_patch"],
|
||||
)
|
||||
return self._process_image_input(image_input)
|
||||
|
||||
def get_input_embeddings(
|
||||
self,
|
||||
@ -945,7 +909,7 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
inputs_embeds = merge_multimodal_embeddings(
|
||||
input_ids,
|
||||
inputs_embeds,
|
||||
select_patch_features(multimodal_embeddings),
|
||||
multimodal_embeddings,
|
||||
self.img_context_token_id,
|
||||
)
|
||||
return inputs_embeds
|
||||
|
@ -1,8 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Sequence
|
||||
from typing import Final, Generic, Optional, Protocol, TypeVar, Union, cast
|
||||
from typing import Final, Generic, Optional, Protocol, TypeVar, Union
|
||||
|
||||
import torch
|
||||
from transformers import PretrainedConfig
|
||||
@ -10,12 +9,9 @@ from transformers import PretrainedConfig
|
||||
import vllm.envs as envs
|
||||
from vllm.attention.selector import (backend_name_to_enum,
|
||||
get_global_forced_attn_backend)
|
||||
from vllm.jsontree import JSONTree, json_map_leaves
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import _Backend, current_platform
|
||||
|
||||
from .interfaces import MultiModalEmbeddings
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
_C = TypeVar("_C", bound=PretrainedConfig)
|
||||
@ -155,74 +151,3 @@ def resolve_visual_encoder_outputs(
|
||||
if post_layer_norm is not None and uses_last_layer:
|
||||
hs_pool[-1] = post_layer_norm(encoder_outputs)
|
||||
return torch.cat(hs_pool, dim=-1)
|
||||
|
||||
|
||||
def scatter_patch_features(
|
||||
patches: Union[torch.Tensor, Sequence[torch.Tensor]],
|
||||
embed_is_patch: Union[torch.Tensor, Sequence[torch.Tensor]],
|
||||
) -> tuple[torch.Tensor, ...]:
|
||||
"""
|
||||
Scatter the patch features into a contiguous tensor that corresponds
|
||||
to the embedding tokens defined by the multimodal processor.
|
||||
|
||||
The rest of the values in the tensor are set to NaN so that they
|
||||
can be filtered out by :func`select_patch_features`.
|
||||
|
||||
Args:
|
||||
patches: The patch features for each image.
|
||||
Shape: `(num_images, <patch_dims>, feature_depth)`
|
||||
embed_is_patch: A boolean mask indicating which image embeddings
|
||||
correspond to patch tokens for each image.
|
||||
Shape: `(num_images, num_embeds)`
|
||||
|
||||
Note:
|
||||
The original code only considers patch tokens as feature
|
||||
tokens, but our processor considers all image-related tokens
|
||||
as feature tokens because the feature tokens need to be
|
||||
consecutive in `input_ids`.
|
||||
|
||||
Example:
|
||||
A simplified example for one image:
|
||||
|
||||
.. code-block::
|
||||
|
||||
Embedding tokens (from HF processor):
|
||||
[<start> <patch> <patch> <col> <patch> <patch> <col> <end> ]
|
||||
|
||||
embed_is_patch (from HF processor):
|
||||
[ False True True False True True False False ]
|
||||
|
||||
Encoder outputs (from model):
|
||||
[ p1 p2 p3 p4 ]
|
||||
|
||||
The resulting embedding tensor is:
|
||||
[ nan p1 p2 nan p3 p4 nan nan ]
|
||||
"""
|
||||
if len(patches) != len(embed_is_patch):
|
||||
raise ValueError(f"Inconsistent num_images: {len(patches)=} vs. "
|
||||
f"{len(embed_is_patch)=}")
|
||||
|
||||
def get_embed_one(patches_one: torch.Tensor, e_is_patch: torch.Tensor):
|
||||
embed_one = patches_one.new_full(
|
||||
(e_is_patch.shape[0], patches_one.shape[-1]),
|
||||
fill_value=torch.nan,
|
||||
)
|
||||
embed_one[e_is_patch] = patches_one
|
||||
return embed_one
|
||||
|
||||
return tuple(
|
||||
get_embed_one(patches_one, e_is_patch)
|
||||
for patches_one, e_is_patch in zip(patches, embed_is_patch))
|
||||
|
||||
|
||||
def select_patch_features(
|
||||
multimodal_embeddings: MultiModalEmbeddings) -> MultiModalEmbeddings:
|
||||
"""
|
||||
Given the outputs of :func:`scatter_patch_features`, return only
|
||||
the values that correspond to patch features.
|
||||
"""
|
||||
selected_features = json_map_leaves(
|
||||
lambda x: x[~x.isnan()].view(-1, *x.shape[1:]),
|
||||
cast(JSONTree[torch.Tensor], multimodal_embeddings),
|
||||
)
|
||||
return cast(MultiModalEmbeddings, selected_features)
|
||||
|
@ -385,8 +385,8 @@ class MultiModalPlaceholderMap:
|
||||
for placeholder_dict, mm_item in zip(multi_modal_placeholders,
|
||||
multi_modal_items):
|
||||
placeholder = range(
|
||||
placeholder_dict["offset"],
|
||||
placeholder_dict["offset"] + placeholder_dict["length"],
|
||||
placeholder_dict.offset,
|
||||
placeholder_dict.offset + placeholder_dict.length,
|
||||
)
|
||||
intersection = range(
|
||||
max(positions.start, placeholder.start),
|
||||
|
@ -109,7 +109,8 @@ The built-in modalities are defined by :class:`MultiModalDataBuiltins`.
|
||||
"""
|
||||
|
||||
|
||||
class PlaceholderRange(TypedDict):
|
||||
@dataclass(frozen=True)
|
||||
class PlaceholderRange:
|
||||
"""
|
||||
Placeholder location information for multi-modal data.
|
||||
|
||||
@ -121,8 +122,8 @@ class PlaceholderRange(TypedDict):
|
||||
|
||||
.. code-block::
|
||||
|
||||
A: { "offset": 0, "length": 4 }
|
||||
B: { "offset": 5, "length": 4 }
|
||||
A: PlaceholderRange(offset=0, length=4)
|
||||
B: PlaceholderRange(offset=5, length=4)
|
||||
"""
|
||||
|
||||
offset: int
|
||||
@ -131,6 +132,31 @@ class PlaceholderRange(TypedDict):
|
||||
length: int
|
||||
"""The length of the placeholder."""
|
||||
|
||||
is_embed: Optional[torch.Tensor] = None
|
||||
"""
|
||||
A boolean mask of shape `(length,)` indicating which positions
|
||||
between `offset` and `offset + length` to assign embeddings to.
|
||||
"""
|
||||
|
||||
def get_num_embeds(self) -> int:
|
||||
if self.is_embed is None:
|
||||
return self.length
|
||||
|
||||
return int(self.is_embed.sum().item())
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if not isinstance(other, self.__class__):
|
||||
return False
|
||||
if not (self.offset, self.length) == (other.offset, other.length):
|
||||
return False
|
||||
|
||||
if self.is_embed is None:
|
||||
return other.is_embed is None
|
||||
if other.is_embed is None:
|
||||
return self.is_embed is None
|
||||
|
||||
return nested_tensors_equal(self.is_embed, other.is_embed)
|
||||
|
||||
|
||||
NestedTensors = Union[list["NestedTensors"], list[torch.Tensor], torch.Tensor,
|
||||
tuple[torch.Tensor, ...]]
|
||||
|
@ -108,16 +108,46 @@ class PromptUpdateDetails(Generic[_S]):
|
||||
full: _S
|
||||
"""The full content."""
|
||||
|
||||
features: _S
|
||||
is_embed: Optional[Callable[["_BoundPromptSequence"], torch.Tensor]] = None
|
||||
"""
|
||||
The part of the content that corresponds to feature placeholders;
|
||||
this will be replaced by the output of the vision encoder during model
|
||||
inference.
|
||||
Given :attr:`full`, return a boolean mask of shape `(len(full),)`
|
||||
indicating which positions of `full` to assign embeddings to.
|
||||
|
||||
`None` (default) means to assign embeddings to all positions of `full`.
|
||||
|
||||
The embeddings are obtained by calling
|
||||
:class:`SupportsMultiModal.get_multimodal_embeddings`.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def from_seq(seq: _S) -> "PromptUpdateDetails[_S]":
|
||||
return PromptUpdateDetails(full=seq, features=seq)
|
||||
return PromptUpdateDetails(full=seq)
|
||||
|
||||
@staticmethod
|
||||
def select_text(
|
||||
seq: _S,
|
||||
embed_text: str,
|
||||
) -> "PromptUpdateDetails[_S]":
|
||||
|
||||
def is_embed(full: "_BoundPromptSequence") -> torch.Tensor:
|
||||
embed_token_ids = encode_tokens(full.tokenizer, embed_text)
|
||||
|
||||
return torch.isin(
|
||||
torch.tensor(full.token_ids),
|
||||
torch.tensor(embed_token_ids),
|
||||
)
|
||||
|
||||
return PromptUpdateDetails(full=seq, is_embed=is_embed)
|
||||
|
||||
@staticmethod
|
||||
def select_token_id(
|
||||
seq: _S,
|
||||
embed_token_id: int,
|
||||
) -> "PromptUpdateDetails[_S]":
|
||||
return PromptUpdateDetails(
|
||||
full=seq,
|
||||
is_embed=lambda f: torch.tensor(f.token_ids) == embed_token_id,
|
||||
)
|
||||
|
||||
|
||||
PromptUpdateInfo = Union[PromptSeq, PromptUpdateDetails]
|
||||
@ -406,7 +436,7 @@ class _BoundPromptSequence:
|
||||
@dataclass
|
||||
class _BoundPromptContent:
|
||||
full: _BoundPromptSequence
|
||||
features: _BoundPromptSequence
|
||||
is_embed: Optional[Callable[["_BoundPromptSequence"], torch.Tensor]]
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -466,10 +496,8 @@ class BoundPromptUpdate:
|
||||
|
||||
bound_full = _BoundPromptSequence.from_seq(self.tokenizer,
|
||||
content.full)
|
||||
bound_features = _BoundPromptSequence.from_seq(self.tokenizer,
|
||||
content.features)
|
||||
bound_content = _BoundPromptContent(full=bound_full,
|
||||
features=bound_features)
|
||||
is_embed=content.is_embed)
|
||||
|
||||
if cache_key is not None:
|
||||
self._content_cache[cache_key] = bound_content
|
||||
@ -605,15 +633,19 @@ class PlaceholderFeaturesInfo:
|
||||
item_idx: int
|
||||
start_idx: int
|
||||
tokens: list[int]
|
||||
is_embed: Optional[torch.Tensor]
|
||||
|
||||
@property
|
||||
def length(self) -> int:
|
||||
return len(self.tokens)
|
||||
|
||||
def to_range(self) -> PlaceholderRange:
|
||||
# TODO: Is it worth it to optimize this by stripping the
|
||||
# leading and ending positions where `is_embed=False`?
|
||||
return PlaceholderRange(
|
||||
offset=self.start_idx,
|
||||
length=self.length,
|
||||
is_embed=self.is_embed,
|
||||
)
|
||||
|
||||
|
||||
@ -806,22 +838,17 @@ def _iter_placeholders(
|
||||
continue
|
||||
|
||||
if prompt[start_idx:end_idx_full] == content_tokens_full:
|
||||
content_tokens_feat = content.features.token_ids
|
||||
content_is_embed = content.is_embed
|
||||
if content_is_embed is not None:
|
||||
content_is_embed = content_is_embed(content.full)
|
||||
|
||||
try:
|
||||
match = next(
|
||||
iter_token_matches(content_tokens_full,
|
||||
content_tokens_feat))
|
||||
yield PlaceholderFeaturesInfo(
|
||||
modality=modality,
|
||||
item_idx=item_idx,
|
||||
start_idx=start_idx + match.start_idx,
|
||||
tokens=content_tokens_feat,
|
||||
)
|
||||
except StopIteration:
|
||||
raise AssertionError(
|
||||
f"{content_tokens_feat=} should be a "
|
||||
f"subsequence of {content_tokens_full=}") from None
|
||||
yield PlaceholderFeaturesInfo(
|
||||
modality=modality,
|
||||
item_idx=item_idx,
|
||||
start_idx=start_idx,
|
||||
tokens=content_tokens_full,
|
||||
is_embed=content_is_embed,
|
||||
)
|
||||
|
||||
# Exclude overlapping matches
|
||||
start_idx = end_idx_full
|
||||
|
@ -180,7 +180,7 @@ class MultiModalProfiler(Generic[_I]):
|
||||
placeholders_by_modality = mm_inputs["mm_placeholders"]
|
||||
|
||||
total_placeholders_by_modality = {
|
||||
modality: sum(item["length"] for item in placeholders)
|
||||
modality: sum(item.get_num_embeds() for item in placeholders)
|
||||
for modality, placeholders in placeholders_by_modality.items()
|
||||
}
|
||||
expected_placeholders_by_modality = {
|
||||
|
@ -340,7 +340,7 @@ def merge_and_sort_multimodal_metadata(
|
||||
all_items.append((modality, placeholder, hash_value))
|
||||
|
||||
# Sort all items by offset
|
||||
all_items.sort(key=lambda x: x[1]['offset'])
|
||||
all_items.sort(key=lambda x: x[1].offset)
|
||||
|
||||
# Split into separate lists
|
||||
sorted_modalities = [item[0] for item in all_items]
|
||||
|
@ -310,8 +310,7 @@ def _gen_mm_extra_hash_keys(request: Request, start_token_idx: int,
|
||||
# Note that we assume mm_positions is sorted by offset.
|
||||
# We do not need to check all mm inputs if the start token index is out of
|
||||
# range. This usually happens in the late prefill phase and decoding phase.
|
||||
if mm_positions[-1]["offset"] + mm_positions[-1][
|
||||
"length"] < start_token_idx:
|
||||
if mm_positions[-1].offset + mm_positions[-1].length < start_token_idx:
|
||||
return extra_keys, start_mm_idx
|
||||
|
||||
# Support start_mm_idx == -1 to indicate the last mm input.
|
||||
@ -322,8 +321,8 @@ def _gen_mm_extra_hash_keys(request: Request, start_token_idx: int,
|
||||
curr_mm_idx = start_mm_idx
|
||||
while mm_positions and curr_mm_idx < len(mm_positions):
|
||||
assert mm_hashes[curr_mm_idx] is not None
|
||||
offset = mm_positions[curr_mm_idx]["offset"]
|
||||
length = mm_positions[curr_mm_idx]["length"]
|
||||
offset = mm_positions[curr_mm_idx].offset
|
||||
length = mm_positions[curr_mm_idx].length
|
||||
if end_token_idx > offset:
|
||||
if start_token_idx > offset + length:
|
||||
# This block has passed the current mm input.
|
||||
|
@ -505,8 +505,8 @@ class Scheduler(SchedulerInterface):
|
||||
assert mm_positions is not None
|
||||
assert len(mm_positions) > 0
|
||||
for i, pos_info in enumerate(mm_positions):
|
||||
start_pos = pos_info["offset"]
|
||||
num_encoder_tokens = pos_info["length"]
|
||||
start_pos = pos_info.offset
|
||||
num_encoder_tokens = pos_info.length
|
||||
|
||||
# The encoder output is needed if the two ranges overlap:
|
||||
# [num_computed_tokens, num_computed_tokens + num_new_tokens) and
|
||||
@ -596,8 +596,8 @@ class Scheduler(SchedulerInterface):
|
||||
if cached_encoder_input_ids:
|
||||
for input_id in list(cached_encoder_input_ids):
|
||||
mm_positions = request.mm_positions[input_id]
|
||||
start_pos = mm_positions["offset"]
|
||||
num_tokens = mm_positions["length"]
|
||||
start_pos = mm_positions.offset
|
||||
num_tokens = mm_positions.length
|
||||
if start_pos + num_tokens <= request.num_computed_tokens:
|
||||
# The encoder output is already processed and stored
|
||||
# in the decoder's KV cache.
|
||||
|
@ -121,7 +121,7 @@ class Request:
|
||||
|
||||
def get_num_encoder_tokens(self, input_id: int) -> int:
|
||||
assert input_id < len(self.mm_positions)
|
||||
num_tokens = self.mm_positions[input_id]["length"]
|
||||
num_tokens = self.mm_positions[input_id].length
|
||||
return num_tokens
|
||||
|
||||
@property
|
||||
|
@ -19,7 +19,8 @@ from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
|
||||
from vllm.model_executor.model_loader import get_model
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
|
||||
from vllm.multimodal.utils import group_mm_inputs_by_modality
|
||||
from vllm.sampling_params import SamplingType
|
||||
from vllm.sequence import IntermediateTensors
|
||||
@ -43,7 +44,8 @@ from vllm.v1.utils import bind_kv_cache
|
||||
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
|
||||
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
|
||||
|
||||
from .utils import sanity_check_mm_encoder_outputs
|
||||
from .utils import (gather_mm_placeholders, sanity_check_mm_encoder_outputs,
|
||||
scatter_mm_placeholders)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import xgrammar as xgr
|
||||
@ -829,19 +831,22 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
)
|
||||
return metadata
|
||||
|
||||
def _execute_encoder(self, scheduler_output: "SchedulerOutput"):
|
||||
def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"):
|
||||
scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs
|
||||
if not scheduled_encoder_inputs:
|
||||
return
|
||||
|
||||
# Batch the multi-modal inputs.
|
||||
mm_inputs: list[MultiModalKwargs] = []
|
||||
req_input_ids: list[tuple[str, int]] = []
|
||||
mm_inputs = list[MultiModalKwargs]()
|
||||
req_ids_pos = list[tuple[str, int, PlaceholderRange]]()
|
||||
for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
|
||||
req_state = self.requests[req_id]
|
||||
for input_id in encoder_input_ids:
|
||||
for input_id, pos_info in zip(
|
||||
encoder_input_ids,
|
||||
req_state.mm_positions,
|
||||
):
|
||||
mm_inputs.append(req_state.mm_inputs[input_id])
|
||||
req_input_ids.append((req_id, input_id))
|
||||
req_ids_pos.append((req_id, input_id, pos_info))
|
||||
|
||||
# Batch mm inputs as much as we can: if a request in the batch has
|
||||
# multiple modalities or a different modality than the previous one,
|
||||
@ -877,16 +882,23 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
encoder_outputs.append(output)
|
||||
|
||||
# Cache the encoder outputs.
|
||||
for (req_id, input_id), output in zip(req_input_ids, encoder_outputs):
|
||||
for (req_id, input_id, pos_info), output in zip(
|
||||
req_ids_pos,
|
||||
encoder_outputs,
|
||||
):
|
||||
if req_id not in self.encoder_cache:
|
||||
self.encoder_cache[req_id] = {}
|
||||
self.encoder_cache[req_id][input_id] = output
|
||||
|
||||
def _gather_encoder_outputs(
|
||||
self.encoder_cache[req_id][input_id] = scatter_mm_placeholders(
|
||||
output,
|
||||
is_embed=pos_info.is_embed,
|
||||
)
|
||||
|
||||
def _gather_mm_embeddings(
|
||||
self,
|
||||
scheduler_output: "SchedulerOutput",
|
||||
) -> list[torch.Tensor]:
|
||||
encoder_outputs: list[torch.Tensor] = []
|
||||
mm_embeds: list[torch.Tensor] = []
|
||||
for req_id in self.input_batch.req_ids:
|
||||
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
|
||||
req_id]
|
||||
@ -894,8 +906,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
num_computed_tokens = req_state.num_computed_tokens
|
||||
mm_positions = req_state.mm_positions
|
||||
for i, pos_info in enumerate(mm_positions):
|
||||
start_pos = pos_info["offset"]
|
||||
num_encoder_tokens = pos_info["length"]
|
||||
start_pos = pos_info.offset
|
||||
num_encoder_tokens = pos_info.length
|
||||
|
||||
# The encoder output is needed if the two ranges overlap:
|
||||
# [num_computed_tokens,
|
||||
@ -917,8 +929,16 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
assert req_id in self.encoder_cache
|
||||
assert i in self.encoder_cache[req_id]
|
||||
encoder_output = self.encoder_cache[req_id][i]
|
||||
encoder_outputs.append(encoder_output[start_idx:end_idx])
|
||||
return encoder_outputs
|
||||
|
||||
if (is_embed := pos_info.is_embed) is not None:
|
||||
is_embed = is_embed[start_idx:end_idx]
|
||||
|
||||
mm_embeds_item = gather_mm_placeholders(
|
||||
encoder_output[start_idx:end_idx],
|
||||
is_embed=is_embed,
|
||||
)
|
||||
mm_embeds.append(mm_embeds_item)
|
||||
return mm_embeds
|
||||
|
||||
def get_model(self) -> nn.Module:
|
||||
return self.model
|
||||
@ -983,10 +1003,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
|
||||
if self.is_multimodal_model:
|
||||
# Run the multimodal encoder if any.
|
||||
self._execute_encoder(scheduler_output)
|
||||
encoder_outputs = self._gather_encoder_outputs(scheduler_output)
|
||||
self._execute_mm_encoder(scheduler_output)
|
||||
mm_embeds = self._gather_mm_embeddings(scheduler_output)
|
||||
else:
|
||||
encoder_outputs = []
|
||||
mm_embeds = []
|
||||
|
||||
# Prepare the decoder inputs.
|
||||
attn_metadata, logits_indices, spec_decode_metadata = (
|
||||
@ -1008,9 +1028,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
# embeddings), we always use embeddings (rather than token ids)
|
||||
# as input to the multimodal model, even when the input is text.
|
||||
input_ids = self.input_ids[:num_scheduled_tokens]
|
||||
if encoder_outputs:
|
||||
if mm_embeds:
|
||||
inputs_embeds = self.model.get_input_embeddings(
|
||||
input_ids, encoder_outputs)
|
||||
input_ids, mm_embeds)
|
||||
else:
|
||||
inputs_embeds = self.model.get_input_embeddings(input_ids)
|
||||
# TODO(woosuk): Avoid the copy. Optimize.
|
||||
|
@ -19,7 +19,8 @@ from vllm.config import VllmConfig
|
||||
from vllm.forward_context import set_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.model_loader import get_model
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
|
||||
from vllm.multimodal.utils import group_mm_inputs_by_modality
|
||||
from vllm.sampling_params import SamplingType
|
||||
from vllm.sequence import IntermediateTensors
|
||||
@ -36,7 +37,8 @@ from vllm.v1.sample.tpu.sampler import Sampler as TPUSampler
|
||||
from vllm.v1.utils import bind_kv_cache
|
||||
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
|
||||
|
||||
from .utils import sanity_check_mm_encoder_outputs
|
||||
from .utils import (gather_mm_placeholders, sanity_check_mm_encoder_outputs,
|
||||
scatter_mm_placeholders)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
@ -507,19 +509,47 @@ class TPUModelRunner:
|
||||
logits_indices = logits_indices.to(self.device)
|
||||
return attn_metadata, logits_indices
|
||||
|
||||
def _execute_encoder(self, scheduler_output: "SchedulerOutput"):
|
||||
def _scatter_placeholders(
|
||||
self,
|
||||
embeds: torch.Tensor,
|
||||
is_embed: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
if is_embed is None:
|
||||
return embeds
|
||||
|
||||
placeholders = embeds.new_full(
|
||||
(is_embed.shape[0], embeds.shape[-1]),
|
||||
fill_value=torch.nan,
|
||||
)
|
||||
placeholders[is_embed] = embeds
|
||||
return placeholders
|
||||
|
||||
def _gather_placeholders(
|
||||
self,
|
||||
placeholders: torch.Tensor,
|
||||
is_embed: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
if is_embed is None:
|
||||
return placeholders
|
||||
|
||||
return placeholders[is_embed]
|
||||
|
||||
def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"):
|
||||
scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs
|
||||
if not scheduled_encoder_inputs:
|
||||
return
|
||||
|
||||
# Batch the multi-modal inputs.
|
||||
mm_inputs: list[MultiModalKwargs] = []
|
||||
req_input_ids: list[tuple[str, int]] = []
|
||||
mm_inputs = list[MultiModalKwargs]()
|
||||
req_ids_pos = list[tuple[str, int, PlaceholderRange]]()
|
||||
for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
|
||||
req_state = self.requests[req_id]
|
||||
for input_id in encoder_input_ids:
|
||||
for input_id, pos_info in zip(
|
||||
encoder_input_ids,
|
||||
req_state.mm_positions,
|
||||
):
|
||||
mm_inputs.append(req_state.mm_inputs[input_id])
|
||||
req_input_ids.append((req_id, input_id))
|
||||
req_ids_pos.append((req_id, input_id, pos_info))
|
||||
|
||||
# Batch mm inputs as much as we can: if a request in the batch has
|
||||
# multiple modalities or a different modality than the previous one,
|
||||
@ -555,16 +585,23 @@ class TPUModelRunner:
|
||||
encoder_outputs.append(output)
|
||||
|
||||
# Cache the encoder outputs.
|
||||
for (req_id, input_id), output in zip(req_input_ids, encoder_outputs):
|
||||
for (req_id, input_id, pos_info), output in zip(
|
||||
req_ids_pos,
|
||||
encoder_outputs,
|
||||
):
|
||||
if req_id not in self.encoder_cache:
|
||||
self.encoder_cache[req_id] = {}
|
||||
self.encoder_cache[req_id][input_id] = output
|
||||
|
||||
def _gather_encoder_outputs(
|
||||
self.encoder_cache[req_id][input_id] = scatter_mm_placeholders(
|
||||
output,
|
||||
is_embed=pos_info.is_embed,
|
||||
)
|
||||
|
||||
def _gather_mm_embeddings(
|
||||
self,
|
||||
scheduler_output: "SchedulerOutput",
|
||||
) -> list[torch.Tensor]:
|
||||
encoder_outputs: list[torch.Tensor] = []
|
||||
mm_embeds: list[torch.Tensor] = []
|
||||
for req_id in self.input_batch.req_ids:
|
||||
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
|
||||
req_id]
|
||||
@ -572,8 +609,8 @@ class TPUModelRunner:
|
||||
num_computed_tokens = req_state.num_computed_tokens
|
||||
mm_positions = req_state.mm_positions
|
||||
for i, pos_info in enumerate(mm_positions):
|
||||
start_pos = pos_info["offset"]
|
||||
num_encoder_tokens = pos_info["length"]
|
||||
start_pos = pos_info.offset
|
||||
num_encoder_tokens = pos_info.length
|
||||
|
||||
# The encoder output is needed if the two ranges overlap:
|
||||
# [num_computed_tokens,
|
||||
@ -595,8 +632,16 @@ class TPUModelRunner:
|
||||
assert req_id in self.encoder_cache
|
||||
assert i in self.encoder_cache[req_id]
|
||||
encoder_output = self.encoder_cache[req_id][i]
|
||||
encoder_outputs.append(encoder_output[start_idx:end_idx])
|
||||
return encoder_outputs
|
||||
|
||||
if (is_embed := pos_info.is_embed) is not None:
|
||||
is_embed = is_embed[start_idx:end_idx]
|
||||
|
||||
mm_embeds_item = gather_mm_placeholders(
|
||||
encoder_output[start_idx:end_idx],
|
||||
is_embed=is_embed,
|
||||
)
|
||||
mm_embeds.append(mm_embeds_item)
|
||||
return mm_embeds
|
||||
|
||||
@torch.no_grad()
|
||||
def execute_model(
|
||||
@ -612,10 +657,10 @@ class TPUModelRunner:
|
||||
|
||||
if self.is_multimodal_model:
|
||||
# Run the multimodal encoder if any.
|
||||
self._execute_encoder(scheduler_output)
|
||||
encoder_outputs = self._gather_encoder_outputs(scheduler_output)
|
||||
self._execute_mm_encoder(scheduler_output)
|
||||
mm_embeds = self._gather_mm_embeddings(scheduler_output)
|
||||
else:
|
||||
encoder_outputs = []
|
||||
mm_embeds = []
|
||||
|
||||
# Prepare inputs
|
||||
attn_metadata, logits_indices = self._prepare_inputs(scheduler_output)
|
||||
@ -623,9 +668,9 @@ class TPUModelRunner:
|
||||
# NOTE(woosuk): To unify token ids and soft tokens (vision
|
||||
# embeddings), we always use embeddings (rather than token ids)
|
||||
# as input to the multimodal model, even when the input is text.
|
||||
if encoder_outputs:
|
||||
if mm_embeds:
|
||||
inputs_embeds = self.model.get_input_embeddings(
|
||||
self.input_ids, encoder_outputs)
|
||||
self.input_ids, mm_embeds)
|
||||
else:
|
||||
inputs_embeds = self.model.get_input_embeddings(self.input_ids)
|
||||
input_ids = None
|
||||
|
@ -1,4 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
@ -27,3 +29,46 @@ def sanity_check_mm_encoder_outputs(
|
||||
f"but got tensors with shapes {[e.shape for e in mm_embeddings]} "
|
||||
"instead. This is most likely due to incorrect implementation "
|
||||
"of the model's `get_multimodal_embeddings` method.")
|
||||
|
||||
|
||||
def scatter_mm_placeholders(
|
||||
embeds: torch.Tensor,
|
||||
is_embed: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Scatter the multimodal embeddings into a contiguous tensor that represents
|
||||
the placeholder tokens.
|
||||
|
||||
:class:`vllm.multimodal.processing.PromptUpdateDetails.is_embed`.
|
||||
|
||||
Args:
|
||||
embeds: The multimodal embeddings.
|
||||
Shape: `(num_embeds, embed_dim)`
|
||||
is_embed: A boolean mask indicating which positions in the placeholder
|
||||
tokens need to be filled with multimodal embeddings.
|
||||
Shape: `(num_placeholders, num_embeds)`
|
||||
"""
|
||||
if is_embed is None:
|
||||
return embeds
|
||||
|
||||
placeholders = embeds.new_full(
|
||||
(is_embed.shape[0], embeds.shape[-1]),
|
||||
fill_value=torch.nan,
|
||||
)
|
||||
placeholders[is_embed] = embeds
|
||||
return placeholders
|
||||
|
||||
|
||||
def gather_mm_placeholders(
|
||||
placeholders: torch.Tensor,
|
||||
is_embed: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Reconstructs the embeddings from the placeholder tokens.
|
||||
|
||||
This is the operation of :func:`scatter_mm_placeholders`.
|
||||
"""
|
||||
if is_embed is None:
|
||||
return placeholders
|
||||
|
||||
return placeholders[is_embed]
|
||||
|
Loading…
x
Reference in New Issue
Block a user