Revert "[V1] Scatter and gather placeholders in the model runner" (#16075)
Parent: f5722a5052
Commit: af51d80fa1
@@ -860,8 +860,8 @@ prompt_tokens, prompts_length = _tokenize_prompts_with_image_and_batch(
 )
 ```

-To assign the vision embeddings to only the image tokens, instead of a string
-you can return an instance of {class}`~vllm.multimodal.processing.PromptUpdateDetails`:
+To accommodate this, instead of a string you can return an instance of {class}`~vllm.multimodal.processing.PromptUpdateDetails`
+with different `full` and `feature` attributes:

 ```python
 hf_config = self.info.get_hf_config()
@@ -879,9 +879,9 @@ def get_replacement_fuyu(item_idx: int):
 image_tokens = ([_IMAGE_TOKEN_ID] * ncols +
 [_NEWLINE_TOKEN_ID]) * nrows

-return PromptUpdateDetails.select_token_id(
-image_tokens + [bos_token_id],
-embed_token_id=_IMAGE_TOKEN_ID,
+return PromptUpdateDetails(
+full=image_tokens + [bos_token_id],
+features=image_tokens,
 )
 ```

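The two call styles contrasted in the docs hunk above can be hard to read in diff form. Below is a minimal, self-contained sketch of the data they describe. `PromptUpdateDetailsSketch` is a stand-in defined here purely for illustration (the real class is `vllm.multimodal.processing.PromptUpdateDetails`, whose constructor differs on either side of this revert); `_NEWLINE_TOKEN_ID` and `bos_token_id` values are assumptions, not taken from the diff.

```python
from dataclasses import dataclass

# Stand-in for illustration only; not vLLM's class.
@dataclass
class PromptUpdateDetailsSketch:
    full: list[int]       # every token inserted into the prompt
    features: list[int]   # the subset that should receive vision embeddings

_IMAGE_TOKEN_ID = 71011    # value shown in the Fuyu hunks below
_NEWLINE_TOKEN_ID = 71019  # assumed value, not shown in this diff
bos_token_id = 1           # assumed value

# A 3x2 patch grid: each row is ncols image tokens followed by a newline token.
image_tokens = ([_IMAGE_TOKEN_ID] * 3 + [_NEWLINE_TOKEN_ID]) * 2

# Restored style (this commit): spell out both sequences explicitly.
details = PromptUpdateDetailsSketch(
    full=image_tokens + [bos_token_id],
    features=image_tokens,
)
assert set(details.features) <= set(details.full)
```

The pre-revert style derived the same subset by matching a token ID inside `full` via `PromptUpdateDetails.select_token_id(full, embed_token_id=...)`, as the removed lines above show.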
@@ -914,9 +914,9 @@ def _get_prompt_updates(
 image_tokens = ([_IMAGE_TOKEN_ID] * ncols +
 [_NEWLINE_TOKEN_ID]) * nrows

-return PromptUpdateDetails.select_token_id(
-image_tokens + [bos_token_id],
-embed_token_id=_IMAGE_TOKEN_ID,
+return PromptUpdateDetails(
+full=image_tokens + [bos_token_id],
+features=image_tokens,
 )

 return [
@@ -989,6 +989,9 @@ See [this page](#generative-models) for more information on how to use generativ
 <sup>+</sup> Multiple items can be inputted per text prompt for this modality.

 :::{important}
+To use Gemma3 series models, you have to install Hugging Face Transformers library from source via
+`pip install git+https://github.com/huggingface/transformers`.
+
 Pan-and-scan image pre-processing is currently supported on V0 (but not V1).
 You can enable it by passing `--mm-processor-kwargs '{"do_pan_and_scan": True}'`.
 :::
@@ -47,7 +47,7 @@ def run_minicpmo(question: str, audio_count: int) -> ModelRequestData:
 model=model_name,
 trust_remote_code=True,
 max_model_len=4096,
-max_num_seqs=2,
+max_num_seqs=5,
 limit_mm_per_prompt={"audio": audio_count},
 )

@@ -55,10 +55,7 @@ def server(request, audio_assets):
 for key, value in request.param.items()
 ]

-with RemoteOpenAIServer(MODEL_NAME,
-args,
-env_dict={"VLLM_AUDIO_FETCH_TIMEOUT":
-"30"}) as remote_server:
+with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
 yield remote_server


@@ -167,7 +167,7 @@ VLM_TEST_SETTINGS = {
 "cherry_blossom": "<image>What is the season?", # noqa: E501
 }),
 multi_image_prompt="<image><image>Describe the two images in detail.", # noqa: E501
-max_model_len=4096,
+max_model_len=8192,
 max_num_seqs=2,
 auto_cls=AutoModelForImageTextToText,
 vllm_runner_kwargs={"mm_processor_kwargs": {"crop_to_patches": True}}
@@ -176,8 +176,6 @@ def test_chat(
 model,
 dtype=dtype,
 tokenizer_mode="mistral",
-load_format="mistral",
-config_format="mistral",
 max_model_len=max_model_len,
 limit_mm_per_prompt=LIMIT_MM_PER_PROMPT,
 ) as vllm_model:
@@ -200,14 +198,22 @@ def test_chat(


 @large_gpu_test(min_gb=48)
-@pytest.mark.parametrize("prompt,expected_ranges",
-[(_create_engine_inputs_hf(IMG_URLS[:1]),
-[PlaceholderRange(offset=11, length=494)]),
-(_create_engine_inputs_hf(IMG_URLS[1:4]), [
-PlaceholderRange(offset=11, length=266),
-PlaceholderRange(offset=277, length=1056),
-PlaceholderRange(offset=1333, length=418)
-])])
+@pytest.mark.parametrize(
+"prompt,expected_ranges",
+[(_create_engine_inputs_hf(IMG_URLS[:1]), [{
+"offset": 11,
+"length": 494
+}]),
+(_create_engine_inputs_hf(IMG_URLS[1:4]), [{
+"offset": 11,
+"length": 266
+}, {
+"offset": 277,
+"length": 1056
+}, {
+"offset": 1333,
+"length": 418
+}])])
 def test_multi_modal_placeholders(vllm_runner, prompt,
 expected_ranges: list[PlaceholderRange],
 monkeypatch) -> None:
@@ -92,8 +92,8 @@ def _validate_image_prompt_replacements_one(
 first_placeholder = image_placeholders[0]

 # NOTE: There is a BOS token
-assert first_placeholder.offset == 1
-assert first_placeholder.length == (
+assert first_placeholder["offset"] == 1
+assert first_placeholder["length"] == (
 len(processed_inputs["prompt_token_ids"]) - 1) // num_imgs

 except Exception as exc:
@@ -92,8 +92,8 @@ def _validate_image_prompt_replacements_one(

 first_placeholder = image_placeholders[0]

-assert first_placeholder.offset == 0
-assert first_placeholder.length == len(
+assert first_placeholder["offset"] == 0
+assert first_placeholder["length"] == len(
 processed_inputs["prompt_token_ids"]) // num_imgs
 except Exception as exc:
 failed_size_excs.append((image_size, exc))
@@ -277,9 +277,7 @@ _MULTIMODAL_EXAMPLE_MODELS = {
 trust_remote_code=True,
 hf_overrides={"architectures": ["GLM4VForCausalLM"]}), # noqa: E501
 "H2OVLChatModel": _HfExamplesInfo("h2oai/h2ovl-mississippi-800m",
-extras={"2b": "h2oai/h2ovl-mississippi-2b"}, # noqa: E501
-max_transformers_version="4.48", # noqa: E501
-transformers_version_reason="HF model is not compatible."), # noqa: E501
+extras={"2b": "h2oai/h2ovl-mississippi-2b"}), # noqa: E501
 "InternVLChatModel": _HfExamplesInfo("OpenGVLab/InternVL2-1B",
 extras={"2B": "OpenGVLab/InternVL2-2B"}, # noqa: E501
 trust_remote_code=True),
@@ -785,7 +785,6 @@ def test_find_update_tokens(
 item_idx=0,
 start_idx=6,
 tokens=[32000, 32000],
-is_embed=None,
 ),
 ],
 "pattern_4": [
@@ -794,7 +793,6 @@ def test_find_update_tokens(
 item_idx=0,
 start_idx=3,
 tokens=[32000],
-is_embed=None,
 ),
 ],
 }
@@ -809,14 +807,12 @@ def test_find_update_tokens(
 item_idx=0,
 start_idx=1,
 tokens=[32000, 32000],
-is_embed=None,
 ),
 PlaceholderFeaturesInfo(
 modality="pattern_1",
 item_idx=1,
 start_idx=5,
 tokens=[32000, 32000],
-is_embed=None,
 ),
 ],
 "pattern_3": [
@@ -825,7 +821,6 @@ def test_find_update_tokens(
 item_idx=0,
 start_idx=7,
 tokens=[1550, 918, 1550],
-is_embed=None,
 ),
 ],
 # No match for pattern_4 as it has lower priority than pattern_1
@@ -840,14 +835,12 @@ def test_find_update_tokens(
 item_idx=0,
 start_idx=1,
 tokens=[32000, 32000],
-is_embed=None,
 ),
 PlaceholderFeaturesInfo(
 modality="pattern_1",
 item_idx=1,
 start_idx=3,
 tokens=[32000, 32000],
-is_embed=None,
 ),
 ],
 "pattern_4": [
@@ -856,7 +849,6 @@ def test_find_update_tokens(
 item_idx=0,
 start_idx=5,
 tokens=[32000],
-is_embed=None,
 ),
 ],
 "pattern_3": [
@@ -865,7 +857,6 @@ def test_find_update_tokens(
 item_idx=0,
 start_idx=6,
 tokens=[1550, 918, 1550],
-is_embed=None,
 ),
 ],
 }
@@ -3,7 +3,7 @@
 import pytest
 import torch

-from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
+from vllm.multimodal.inputs import MultiModalKwargs
 from vllm.sampling_params import SamplingParams
 from vllm.utils import sha256
 # disable yapf here as it formats differently than isort such that both fail
@@ -158,10 +158,13 @@ def test_generate_block_hash_extra_keys():
 request = make_request(
 request_id=0,
 prompt_token_ids=[_ for _ in range(20)],
-mm_positions=[
-PlaceholderRange(offset=0, length=5),
-PlaceholderRange(offset=10, length=5),
-],
+mm_positions=[{
+"offset": 0,
+"length": 5
+}, {
+"offset": 10,
+"length": 5
+}],
 mm_hashes=["hash1", "hash2"],
 )

@@ -219,10 +222,13 @@ def test_hash_request_tokens(hash_fn):
 request = make_request(
 request_id=0,
 prompt_token_ids=[_ for _ in range(6)],
-mm_positions=[
-PlaceholderRange(offset=0, length=3),
-PlaceholderRange(offset=3, length=3),
-],
+mm_positions=[{
+"offset": 0,
+"length": 3
+}, {
+"offset": 3,
+"length": 3
+}],
 mm_hashes=["hash1", "hash2"],
 )

@@ -247,19 +253,25 @@ def test_hash_tokens_different_mm_input(hash_fn):
 request1 = make_request(
 request_id=0,
 prompt_token_ids=[_ for _ in range(6)],
-mm_positions=[
-PlaceholderRange(offset=0, length=3),
-PlaceholderRange(offset=3, length=3),
-],
+mm_positions=[{
+"offset": 0,
+"length": 3
+}, {
+"offset": 3,
+"length": 3
+}],
 mm_hashes=["hash1", "hash2"],
 )
 request2 = make_request(
 request_id=1,
 prompt_token_ids=[_ for _ in range(6)],
-mm_positions=[
-PlaceholderRange(offset=0, length=3),
-PlaceholderRange(offset=3, length=3),
-],
+mm_positions=[{
+"offset": 0,
+"length": 3
+}, {
+"offset": 3,
+"length": 3
+}],
 mm_hashes=["hash3", "hash2"],
 )
 block_size = 3
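For readers skimming the prefix-caching tests above, here is what each `mm_positions` entry denotes: a span of multimodal placeholder positions inside the prompt token IDs. The runnable sketch below uses the plain-dict form restored by this revert (the pre-revert tests expressed the same data as `PlaceholderRange(offset=..., length=...)`); the values are copied from the tests, not from any real request.

```python
# Each entry marks where an image's placeholder tokens sit in the prompt.
prompt_token_ids = list(range(20))
mm_positions = [{"offset": 0, "length": 5}, {"offset": 10, "length": 5}]

for pos in mm_positions:
    start, end = pos["offset"], pos["offset"] + pos["length"]
    span = prompt_token_ids[start:end]
    print(f"placeholder tokens at positions {start}..{end - 1}: {span}")
```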
@@ -27,7 +27,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
 BaseProcessingInfo,
 MultiModalFieldConfig,
 PromptReplacement, PromptUpdate,
-PromptUpdateDetails)
+encode_tokens)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
 from vllm.sequence import IntermediateTensors

@@ -35,6 +35,7 @@ from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
 from .siglip import SiglipVisionModel
 from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
 maybe_prefix, merge_multimodal_embeddings)
+from .vision import scatter_patch_features, select_patch_features


 class AyaVisionImagePixelInputs(TypedDict):
@@ -50,6 +51,13 @@ class AyaVisionImagePixelInputs(TypedDict):
 num_patches: torch.Tensor
 """Shape: `(batch_size * num_images)`"""

+embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
+"""
+A boolean mask indicating which image embeddings correspond to patch tokens.
+
+Shape: `(batch_size * num_images, num_embeds)`
+"""
+

 class AyaVisionMultiModalProjector(nn.Module):

@@ -127,20 +135,21 @@ class AyaVisionProcessingInfo(BaseProcessingInfo):
 def get_max_image_tokens(self) -> int:
 hf_processor = self.get_hf_processor()
 image_processor = hf_processor.image_processor

 image_size = self.get_image_size_with_most_features()
+tokenizer = hf_processor.tokenizer
 num_patches = self.get_num_patches(
 image_width=image_size.width,
 image_height=image_size.height,
 size=image_processor.size,
 min_patches=image_processor.min_patches,
-max_patches=image_processor.max_patches,
+max_patches=image_processor.max_patches)
+image_string = hf_processor._prompt_split_image(num_patches)
+x = encode_tokens(
+tokenizer,
+image_string,
+add_special_tokens=False,
 )
-img_patches_per_tile = (hf_processor.img_size //
-hf_processor.patch_size)**2
-
-return num_patches * img_patches_per_tile
+return len(x)

 def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
 return {"image": None}
@@ -212,6 +221,7 @@ class AyaVisionMultiModalProcessor(
 hf_processor = self.info.get_hf_processor(**mm_kwargs)
 image_processor = hf_processor.image_processor

+hf_config = self.info.get_hf_config()
 # HF processor pops the `num_patches` kwarg, which is needed by vLLM
 if (images :=
 mm_data.get("images")) is not None and '<image>' in prompt:
@@ -224,7 +234,6 @@ class AyaVisionMultiModalProcessor(
 parsed_images.get_image_size(i)
 for i in range(len(parsed_images))
 ]
-
 num_patches = [
 self.info.get_num_patches(
 image_width=image_size.width,
@@ -234,6 +243,20 @@ class AyaVisionMultiModalProcessor(
 max_patches=image_processor.max_patches)
 for image_size in image_sizes
 ]
+image_tokens_list = [
+hf_processor._prompt_split_image(num_patch)
+for num_patch in num_patches
+]
+tokenizer = self.info.get_tokenizer()
+image_token_ids = [
+tokenizer.encode(image_tokens, add_special_tokens=False)
+for image_tokens in image_tokens_list
+]
+embed_is_patch = [
+torch.tensor(image_repl_tokens) == hf_config.image_token_index
+for image_repl_tokens in image_token_ids
+]
+processed_outputs["embed_is_patch"] = embed_is_patch
 processed_outputs["num_patches"] = torch.tensor(num_patches)

 return processed_outputs
@@ -248,6 +271,7 @@ class AyaVisionMultiModalProcessor(
 pixel_values=MultiModalFieldConfig.flat_from_sizes(
 "image", num_patches),
 num_patches=MultiModalFieldConfig.batched("image"),
+embed_is_patch=MultiModalFieldConfig.batched("image"),
 image_embeds=MultiModalFieldConfig.batched("image"),
 )

@@ -259,7 +283,6 @@ class AyaVisionMultiModalProcessor(
 ) -> Sequence[PromptUpdate]:
 hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
 image_token = hf_processor.image_token
-img_patch_token = hf_processor.img_patch_token
 image_processor = hf_processor.image_processor

 def get_replacement(item_idx: int):
@@ -271,11 +294,8 @@ class AyaVisionMultiModalProcessor(
 image_height=image_size.height,
 size=image_processor.size,
 min_patches=image_processor.min_patches,
-max_patches=image_processor.max_patches,
-)
-repl = hf_processor._prompt_split_image(num_patches=num_patches)
-
-return PromptUpdateDetails.select_text(repl, img_patch_token)
+max_patches=image_processor.max_patches)
+return hf_processor._prompt_split_image(num_patches=num_patches)

 return [
 PromptReplacement(
@@ -404,6 +424,7 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal,
 self, **kwargs: object) -> Optional[AyaVisionImagePixelInputs]:
 pixel_values = kwargs.pop("pixel_values", None)
 num_patches = kwargs.pop("num_patches", None)
+embed_is_patch = kwargs.pop("embed_is_patch", None)
 image_embeds = kwargs.pop("image_embeds", None)
 assert image_embeds is None, "Aya Vision does not support image_embeds."

@@ -415,13 +436,18 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal,
 raise ValueError("Incorrect type of num_patches. "
 f"Got type: {type(num_patches)}")

+if not isinstance(embed_is_patch, (torch.Tensor, list)):
+raise ValueError("Incorrect type of embed_is_patch. "
+f"Got type: {type(embed_is_patch)}")
+
 pixel_values = flatten_bn(pixel_values, concat=True)
 num_patches = flatten_bn(num_patches, concat=True)
+embed_is_patch = flatten_bn(embed_is_patch)
 return AyaVisionImagePixelInputs(
 type="pixel_values",
 pixel_values=self._validate_pixel_values(pixel_values),
 num_patches=num_patches,
+embed_is_patch=embed_is_patch,
 )

 def get_multimodal_embeddings(
@@ -429,8 +455,11 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal,
 image_input = self._parse_and_validate_image_input(**kwargs)
 if image_input is None:
 return None
-return self._process_image_input(image_input, **kwargs)
+image_features = self._process_image_input(image_input, **kwargs)
+return scatter_patch_features(
+image_features,
+image_input["embed_is_patch"],
+)

 def get_input_embeddings(
 self,
@@ -442,9 +471,9 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal,
 inputs_embeds = merge_multimodal_embeddings(
 input_ids=input_ids,
 inputs_embeds=inputs_embeds,
-multimodal_embeddings=multimodal_embeddings,
-placeholder_token_id=self.config.image_token_index,
-)
+multimodal_embeddings=select_patch_features(
+multimodal_embeddings),
+placeholder_token_id=self.config.image_token_index)

 return inputs_embeds

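The model-side hunks above and below restore calls to `scatter_patch_features` / `select_patch_features` from `.vision`, together with an `embed_is_patch` mask that marks which placeholder positions carry real patch embeddings. The sketch below is not vLLM's implementation; the helper names carry a `_sketch` suffix and their signatures are hypothetical, chosen only to illustrate the mask-driven scatter/gather idea.

```python
import torch

def scatter_patch_features_sketch(patch_embeds: torch.Tensor,
                                  embed_is_patch: torch.Tensor) -> torch.Tensor:
    """Place per-patch embeddings at the True positions of the mask;
    leave the other positions (e.g. newline placeholders) as zeros."""
    num_embeds, hidden = embed_is_patch.numel(), patch_embeds.shape[-1]
    out = patch_embeds.new_zeros(num_embeds, hidden)
    out[embed_is_patch] = patch_embeds
    return out

def select_patch_features_sketch(scattered: torch.Tensor,
                                 embed_is_patch: torch.Tensor) -> torch.Tensor:
    """Inverse view: recover only the patch-token embeddings."""
    return scattered[embed_is_patch]

# Example: a 2-row, 3-column patch grid with a trailing newline token per row
# (mirrors the Fuyu-style mask built later in this diff).
ncols, nrows, hidden = 3, 2, 4
mask = torch.tensor(([True] * ncols + [False]) * nrows)
patches = torch.randn(int(mask.sum()), hidden)
full = scatter_patch_features_sketch(patches, mask)
assert torch.equal(select_patch_features_sketch(full, mask), patches)
```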
@@ -162,9 +162,9 @@ class ChameleonMultiModalProcessor(
 PromptReplacement(
 modality="image",
 target=[image_token_id],
-replacement=PromptUpdateDetails.select_token_id(
-[image_start_id] + image_tokens + [image_end_id],
-embed_token_id=image_token_id,
+replacement=PromptUpdateDetails(
+full=([image_start_id] + image_tokens + [image_end_id]),
+features=image_tokens,
 ),
 )
 ]
@@ -18,7 +18,7 @@
 """ PyTorch Fuyu model."""
 import math
 from collections.abc import Iterable, Mapping, Sequence
-from typing import Literal, Optional, Set, Tuple, TypedDict
+from typing import Literal, Optional, Set, Tuple, TypedDict, Union

 import torch
 import torch.nn as nn
@@ -43,6 +43,7 @@ from vllm.sequence import IntermediateTensors
 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
 from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
 merge_multimodal_embeddings)
+from .vision import scatter_patch_features, select_patch_features

 # Cannot find the following 2 numbers from hf config.
 _IMAGE_TOKEN_ID = 71011
@@ -65,6 +66,14 @@ class FuyuImagePatchInputs(TypedDict):
 flattened just like `flat_data`.
 """

+embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
+"""
+A boolean mask indicating which image embeddings correspond
+to patch tokens.
+
+Shape: `(batch_size * num_images, num_embeds)`
+"""
+

 class FuyuProcessingInfo(BaseProcessingInfo):

@@ -85,7 +94,15 @@ class FuyuProcessingInfo(BaseProcessingInfo):
 seq_len: int,
 mm_counts: Mapping[str, int],
 ) -> Mapping[str, int]:
-return {"image": self.get_max_image_tokens()}
+target_width, target_height = self.get_image_size_with_most_features()
+
+max_ncols, max_nrows = self.get_image_feature_grid_size(
+image_width=target_width,
+image_height=target_height,
+)
+max_image_tokens = (max_ncols + 1) * max_nrows
+
+return {"image": max_image_tokens}

 def get_image_feature_grid_size(
 self,
@@ -111,32 +128,11 @@ class FuyuProcessingInfo(BaseProcessingInfo):
 nrows = math.ceil(image_height / patch_height)
 return ncols, nrows

-def get_num_image_tokens(
-self,
-*,
-image_width: int,
-image_height: int,
-) -> int:
-ncols, nrows = self.get_image_feature_grid_size(
-image_width=image_width,
-image_height=image_height,
-)
-
-return ncols * nrows
-
 def get_image_size_with_most_features(self) -> ImageSize:
 image_processor = self.get_image_processor()
 return ImageSize(width=image_processor.size["width"],
 height=image_processor.size["height"])

-def get_max_image_tokens(self) -> int:
-target_width, target_height = self.get_image_size_with_most_features()
-
-return self.get_num_image_tokens(
-image_width=target_width,
-image_height=target_height,
-)
-

 class FuyuDummyInputsBuilder(BaseDummyInputsBuilder[FuyuProcessingInfo]):

@@ -196,6 +192,19 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]):

 processed_outputs["image_patches"] = image_patches[0]

+# get patch grid size for each image
+embed_is_patch = []
+for image in images:
+ncols, nrows = self.info.get_image_feature_grid_size(
+image_width=image.width,
+image_height=image.height,
+)
+
+mask = torch.tensor(([True] * ncols + [False]) * nrows)
+embed_is_patch.append(mask)
+
+processed_outputs["embed_is_patch"] = embed_is_patch
+
 return processed_outputs

 def _apply_hf_processor_tokens_only(
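The Fuyu hunks above restore the grid arithmetic behind these masks: each row of the patch grid is followed by one newline placeholder, so an image costs (ncols + 1) * nrows placeholder tokens. A hedged sketch follows; the patch and image sizes are chosen only for illustration and are not read from the HF config here.

```python
import math

def image_tokens_for(image_width: int, image_height: int,
                     patch_width: int = 30, patch_height: int = 30) -> int:
    # Grid dimensions, as in get_image_feature_grid_size above.
    ncols = math.ceil(image_width / patch_width)
    nrows = math.ceil(image_height / patch_height)
    # One newline placeholder per row, as in the restored get_mm_max_tokens_per_item.
    return (ncols + 1) * nrows

# Example call with an assumed maximum image size.
print(image_tokens_for(1080, 1080))
```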
@@ -215,7 +224,8 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]):
 hf_inputs: BatchFeature,
 hf_processor_mm_kwargs: Mapping[str, object],
 ) -> Mapping[str, MultiModalFieldConfig]:
-return dict(image_patches=MultiModalFieldConfig.batched("image"))
+return dict(image_patches=MultiModalFieldConfig.batched("image"),
+embed_is_patch=MultiModalFieldConfig.batched("image"))

 def _get_prompt_updates(
 self,
@@ -242,9 +252,9 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]):
 image_tokens = ([_IMAGE_TOKEN_ID] * ncols +
 [_NEWLINE_TOKEN_ID]) * nrows

-return PromptUpdateDetails.select_token_id(
-image_tokens + [bos_token_id],
-embed_token_id=_IMAGE_TOKEN_ID,
+return PromptUpdateDetails(
+full=image_tokens + [bos_token_id],
+features=image_tokens,
 )

 return [
@@ -319,13 +329,20 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
 raise ValueError("Incorrect type of image patches. "
 f"Got type: {type(image_patches)}")

+embed_is_patch = kwargs.pop("embed_is_patch")
+if not isinstance(embed_is_patch, (torch.Tensor, list)):
+raise ValueError("Incorrect type of embed_is_patch. "
+f"Got type: {type(embed_is_patch)}")
+
 image_patches_flat = flatten_bn(image_patches)
+embed_is_patch = flatten_bn(embed_is_patch)

 return FuyuImagePatchInputs(
 type="image_patches",
 flat_data=self._validate_pixel_values(
 flatten_bn(image_patches_flat, concat=True)),
 patches_per_image=[x.size(0) for x in image_patches_flat],
+embed_is_patch=embed_is_patch,
 )

 return None
@@ -347,7 +364,12 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
 if image_input is None:
 return None

-return self._process_image_input(image_input)
+image_features = self._process_image_input(image_input)
+
+return scatter_patch_features(
+image_features,
+image_input["embed_is_patch"],
+)

 def get_input_embeddings(
 self,
@@ -357,11 +379,8 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
 inputs_embeds = self.language_model.get_input_embeddings(input_ids)
 if multimodal_embeddings is not None:
 inputs_embeds = merge_multimodal_embeddings(
-input_ids,
-inputs_embeds,
-multimodal_embeddings,
-_IMAGE_TOKEN_ID,
-)
+input_ids, inputs_embeds,
+select_patch_features(multimodal_embeddings), _IMAGE_TOKEN_ID)
 return inputs_embeds

 def forward(
@@ -25,7 +25,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
 PlaceholderFeaturesInfo,
 PromptReplacement, PromptTargetMatch,
 PromptUpdate, PromptUpdateDetails,
-find_mm_placeholders,
+encode_tokens, find_mm_placeholders,
 replace_token_matches)
 # yapf: enable
 from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
@@ -36,6 +36,7 @@ from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
 from .siglip import SiglipVisionModel
 from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
 maybe_prefix, merge_multimodal_embeddings)
+from .vision import scatter_patch_features, select_patch_features

 logger = init_logger(__name__)

@@ -53,6 +54,14 @@ class Gemma3ImagePixelInputs(TypedDict):
 num_patches: torch.Tensor
 """Shape: `(batch_size * num_images)`"""

+embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
+"""
+A boolean mask indicating which image embeddings correspond
+to patch tokens.
+
+Shape: `(batch_size * num_images, num_embeds)`
+"""
+

 Gemma3ImageInputs = Gemma3ImagePixelInputs

@@ -174,7 +183,7 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
 if processor is None:
 processor = self.get_hf_processor()

-boi_token = processor.boi_token
+image_token = processor.boi_token

 num_crops = self.get_num_crops(
 image_width=image_width,
@@ -183,21 +192,19 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
 )

 if num_crops == 0:
-image_text = boi_token
+image_text = image_token
 else:
-crops_image_tokens = " ".join(boi_token for _ in range(num_crops))
+crops_image_tokens = " ".join(image_token
+for _ in range(num_crops))
 image_text = (
-f"Here is the original image {boi_token} and here are some "
+f"Here is the original image {image_token} and here are some "
 f"crops to help you see better {crops_image_tokens}")

-repl_full = image_text.replace(boi_token,
+repl_full = image_text.replace(image_token,
 processor.full_image_sequence)
+repl_features = repl_full.strip("\n")

-tokenizer = processor.tokenizer
-vocab = tokenizer.get_vocab()
-image_token_id = vocab[tokenizer.image_token]
-
-return PromptUpdateDetails.select_token_id(repl_full, image_token_id)
+return PromptUpdateDetails(full=repl_full, features=repl_features)

 def get_num_image_tokens(
 self,
@@ -206,17 +213,19 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
 image_height: int,
 processor: Optional[Gemma3Processor],
 ) -> int:
-if processor is None:
-processor = self.get_hf_processor()
-
-num_crops = self.get_num_crops(
+tokenizer = self.get_tokenizer()
+image_repl = self.get_image_repl(
 image_width=image_width,
 image_height=image_height,
 processor=processor,
 )
-image_seq_len = processor.image_seq_length

-return (num_crops + 1) * image_seq_len
+image_repl_tokens = encode_tokens(
+tokenizer,
+image_repl.features,
+add_special_tokens=False,
+)
+return len(image_repl_tokens)

 def get_image_size_with_most_features(self) -> ImageSize:
 processor = self.get_hf_processor()
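The Gemma3 hunk above swaps a closed-form count, (num_crops + 1) * image_seq_length, for counting the tokens that the replacement feature text actually encodes to, so the count always matches the prompt update. A small sketch of the two approaches; the values are assumed for illustration, and `tokens_by_encoding` is a hypothetical helper that only mirrors the shape of the restored call.

```python
image_seq_length = 256   # tokens emitted per image crop (assumed value)
num_crops = 4            # pan-and-scan crops for this image (assumed value)

# Pre-revert arithmetic: one full-image block plus one block per crop.
tokens_by_arithmetic = (num_crops + 1) * image_seq_length

# Restored approach (sketched): tokenize the per-image replacement text and
# count the resulting IDs, as encode_tokens(...) does in the hunk above.
def tokens_by_encoding(tokenizer, image_repl_features: str) -> int:
    return len(tokenizer.encode(image_repl_features, add_special_tokens=False))
```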
@ -292,6 +301,28 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
|
|||||||
]
|
]
|
||||||
hf_processor = self.info.get_hf_processor(**mm_kwargs)
|
hf_processor = self.info.get_hf_processor(**mm_kwargs)
|
||||||
|
|
||||||
|
image_repl_features = [
|
||||||
|
self.info.get_image_repl(image_width=size.width,
|
||||||
|
image_height=size.height,
|
||||||
|
processor=hf_processor).features
|
||||||
|
for size in image_sizes
|
||||||
|
]
|
||||||
|
|
||||||
|
tokenizer = self.info.get_tokenizer()
|
||||||
|
image_repls_feature_tokens = [
|
||||||
|
tokenizer.encode(image_repl, add_special_tokens=False)
|
||||||
|
for image_repl in image_repl_features
|
||||||
|
]
|
||||||
|
|
||||||
|
vocab = tokenizer.get_vocab()
|
||||||
|
image_token_id = vocab[tokenizer.image_token]
|
||||||
|
|
||||||
|
embed_is_patch = [
|
||||||
|
torch.tensor(image_repl_tokens) == image_token_id
|
||||||
|
for image_repl_tokens in image_repls_feature_tokens
|
||||||
|
]
|
||||||
|
processed_outputs["embed_is_patch"] = embed_is_patch
|
||||||
|
|
||||||
num_crops = [
|
num_crops = [
|
||||||
self.info.get_num_crops(image_width=size.width,
|
self.info.get_num_crops(image_width=size.width,
|
||||||
image_height=size.height,
|
image_height=size.height,
|
||||||
@ -313,6 +344,7 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
|
|||||||
pixel_values=MultiModalFieldConfig.flat_from_sizes(
|
pixel_values=MultiModalFieldConfig.flat_from_sizes(
|
||||||
"image", num_crops + 1),
|
"image", num_crops + 1),
|
||||||
num_crops=MultiModalFieldConfig.batched("image"),
|
num_crops=MultiModalFieldConfig.batched("image"),
|
||||||
|
embed_is_patch=MultiModalFieldConfig.batched("image"),
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_prompt_updates(
|
def _get_prompt_updates(
|
||||||
@ -422,7 +454,6 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
|
|||||||
item_idx=p.item_idx,
|
item_idx=p.item_idx,
|
||||||
start_idx=repl_orig_idxs[p.start_idx],
|
start_idx=repl_orig_idxs[p.start_idx],
|
||||||
tokens=p.tokens,
|
tokens=p.tokens,
|
||||||
is_embed=p.is_embed,
|
|
||||||
) for p in placeholders
|
) for p in placeholders
|
||||||
]
|
]
|
||||||
for modality, placeholders in repls.items()
|
for modality, placeholders in repls.items()
|
||||||
@ -541,6 +572,7 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
|
|||||||
self, **kwargs: object) -> Optional[Gemma3ImageInputs]:
|
self, **kwargs: object) -> Optional[Gemma3ImageInputs]:
|
||||||
pixel_values = kwargs.pop("pixel_values", None)
|
pixel_values = kwargs.pop("pixel_values", None)
|
||||||
num_crops = kwargs.pop("num_crops", None)
|
num_crops = kwargs.pop("num_crops", None)
|
||||||
|
embed_is_patch = kwargs.pop("embed_is_patch", None)
|
||||||
image_embeds = kwargs.pop("image_embeds", None)
|
image_embeds = kwargs.pop("image_embeds", None)
|
||||||
assert image_embeds is None, "Gemma3 does not support image_embeds."
|
assert image_embeds is None, "Gemma3 does not support image_embeds."
|
||||||
if pixel_values is None:
|
if pixel_values is None:
|
||||||
@ -554,13 +586,19 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
|
|||||||
raise ValueError("Incorrect type of num_crops. "
|
raise ValueError("Incorrect type of num_crops. "
|
||||||
f"Got type: {type(num_crops)}")
|
f"Got type: {type(num_crops)}")
|
||||||
|
|
||||||
|
if not isinstance(embed_is_patch, (torch.Tensor, list)):
|
||||||
|
raise ValueError("Incorrect type of embed_is_patch. "
|
||||||
|
f"Got type: {type(embed_is_patch)}")
|
||||||
|
|
||||||
pixel_values = flatten_bn(pixel_values, concat=True)
|
pixel_values = flatten_bn(pixel_values, concat=True)
|
||||||
num_crops = flatten_bn(num_crops, concat=True)
|
num_crops = flatten_bn(num_crops, concat=True)
|
||||||
|
embed_is_patch = flatten_bn(embed_is_patch)
|
||||||
|
|
||||||
return Gemma3ImagePixelInputs(
|
return Gemma3ImagePixelInputs(
|
||||||
type="pixel_values",
|
type="pixel_values",
|
||||||
pixel_values=self._validate_pixel_values(pixel_values),
|
pixel_values=self._validate_pixel_values(pixel_values),
|
||||||
num_patches=num_crops + 1,
|
num_patches=num_crops + 1,
|
||||||
|
embed_is_patch=embed_is_patch,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _image_pixels_to_features(
|
def _image_pixels_to_features(
|
||||||
@ -597,7 +635,12 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
|
|||||||
if image_input is None:
|
if image_input is None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
return self._process_image_input(image_input)
|
image_features = self._process_image_input(image_input)
|
||||||
|
|
||||||
|
return scatter_patch_features(
|
||||||
|
image_features,
|
||||||
|
image_input["embed_is_patch"],
|
||||||
|
)
|
||||||
|
|
||||||
def get_input_embeddings(
|
def get_input_embeddings(
|
||||||
self,
|
self,
|
||||||
@ -609,7 +652,7 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
|
|||||||
inputs_embeds = merge_multimodal_embeddings(
|
inputs_embeds = merge_multimodal_embeddings(
|
||||||
input_ids,
|
input_ids,
|
||||||
inputs_embeds,
|
inputs_embeds,
|
||||||
multimodal_embeddings,
|
select_patch_features(multimodal_embeddings),
|
||||||
self.config.image_token_index,
|
self.config.image_token_index,
|
||||||
)
|
)
|
||||||
return inputs_embeds
|
return inputs_embeds
|
||||||
|
@ -257,7 +257,7 @@ class H2OVLProcessor(BaseInternVLProcessor):
|
|||||||
repl_features = IMG_CONTEXT * feature_size
|
repl_features = IMG_CONTEXT * feature_size
|
||||||
repl_full = IMG_START + repl_features + IMG_END
|
repl_full = IMG_START + repl_features + IMG_END
|
||||||
|
|
||||||
return PromptUpdateDetails.select_text(repl_full, IMG_CONTEXT)
|
return PromptUpdateDetails(full=repl_full, features=repl_features)
|
||||||
|
|
||||||
def resolve_min_max_num(
|
def resolve_min_max_num(
|
||||||
self,
|
self,
|
||||||
|
@ -41,7 +41,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
|||||||
MultiModalDataItems,
|
MultiModalDataItems,
|
||||||
MultiModalFieldConfig,
|
MultiModalFieldConfig,
|
||||||
PromptReplacement, PromptUpdate,
|
PromptReplacement, PromptUpdate,
|
||||||
PromptUpdateDetails)
|
encode_tokens)
|
||||||
# yapf: enable
|
# yapf: enable
|
||||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
|
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
|
||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
@ -54,6 +54,7 @@ from .interfaces import MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal
|
|||||||
from .llama import LlamaModel
|
from .llama import LlamaModel
|
||||||
from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
|
from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
|
||||||
merge_multimodal_embeddings)
|
merge_multimodal_embeddings)
|
||||||
|
from .vision import scatter_patch_features, select_patch_features
|
||||||
|
|
||||||
|
|
||||||
class Idefics3ImagePixelInputs(TypedDict):
|
class Idefics3ImagePixelInputs(TypedDict):
|
||||||
@ -68,6 +69,14 @@ class Idefics3ImagePixelInputs(TypedDict):
|
|||||||
num_patches: torch.Tensor
|
num_patches: torch.Tensor
|
||||||
"""Shape: `(batch_size * num_images)`"""
|
"""Shape: `(batch_size * num_images)`"""
|
||||||
|
|
||||||
|
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
|
||||||
|
"""
|
||||||
|
A boolean mask indicating which image embeddings correspond
|
||||||
|
to patch tokens.
|
||||||
|
|
||||||
|
Shape: `(batch_size * num_images, num_embeds)`
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
class Idefics3ImageEmbeddingInputs(TypedDict):
|
class Idefics3ImageEmbeddingInputs(TypedDict):
|
||||||
type: Literal["image_embeds"]
|
type: Literal["image_embeds"]
|
||||||
@ -77,6 +86,14 @@ class Idefics3ImageEmbeddingInputs(TypedDict):
|
|||||||
`hidden_size` must match the hidden size of language model backbone.
|
`hidden_size` must match the hidden size of language model backbone.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
|
||||||
|
"""
|
||||||
|
A boolean mask indicating which image embeddings correspond
|
||||||
|
to patch tokens.
|
||||||
|
|
||||||
|
Shape: `(batch_size * num_images, num_embeds)`
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
ImageInputs = Union[Idefics3ImagePixelInputs, Idefics3ImageEmbeddingInputs]
|
ImageInputs = Union[Idefics3ImagePixelInputs, Idefics3ImageEmbeddingInputs]
|
||||||
|
|
||||||
@ -258,16 +275,19 @@ class Idefics3ProcessingInfo(BaseProcessingInfo):
|
|||||||
image_height: int,
|
image_height: int,
|
||||||
processor: Optional[Idefics3Processor],
|
processor: Optional[Idefics3Processor],
|
||||||
) -> int:
|
) -> int:
|
||||||
if processor is None:
|
tokenizer = self.get_tokenizer()
|
||||||
processor = self.get_hf_processor()
|
image_repl = self.get_image_repl(
|
||||||
|
|
||||||
num_patches = self.get_num_patches(
|
|
||||||
image_width=image_width,
|
image_width=image_width,
|
||||||
image_height=image_height,
|
image_height=image_height,
|
||||||
processor=processor,
|
processor=processor,
|
||||||
)
|
)
|
||||||
|
|
||||||
return num_patches * processor.image_seq_len
|
image_repl_tokens = encode_tokens(
|
||||||
|
tokenizer,
|
||||||
|
image_repl,
|
||||||
|
add_special_tokens=False,
|
||||||
|
)
|
||||||
|
return len(image_repl_tokens)
|
||||||
|
|
||||||
def get_image_size_with_most_features(self) -> ImageSize:
|
def get_image_size_with_most_features(self) -> ImageSize:
|
||||||
processor = self.get_hf_processor()
|
processor = self.get_hf_processor()
|
||||||
@ -344,6 +364,28 @@ class Idefics3MultiModalProcessor(
|
|||||||
]
|
]
|
||||||
hf_processor = self.info.get_hf_processor(**mm_kwargs)
|
hf_processor = self.info.get_hf_processor(**mm_kwargs)
|
||||||
|
|
||||||
|
image_repl_features = [
|
||||||
|
self.info.get_image_repl(image_width=size.width,
|
||||||
|
image_height=size.height,
|
||||||
|
processor=hf_processor)
|
||||||
|
for size in image_sizes
|
||||||
|
]
|
||||||
|
|
||||||
|
tokenizer = self.info.get_tokenizer()
|
||||||
|
image_repls_feature_tokens = [
|
||||||
|
tokenizer.encode(image_repl, add_special_tokens=False)
|
||||||
|
for image_repl in image_repl_features
|
||||||
|
]
|
||||||
|
|
||||||
|
vocab = tokenizer.get_vocab()
|
||||||
|
image_token_id = vocab[hf_processor.image_token.content]
|
||||||
|
|
||||||
|
embed_is_patch = [
|
||||||
|
torch.tensor(image_repl_tokens) == image_token_id
|
||||||
|
for image_repl_tokens in image_repls_feature_tokens
|
||||||
|
]
|
||||||
|
processed_outputs["embed_is_patch"] = embed_is_patch
|
||||||
|
|
||||||
num_patches = [
|
num_patches = [
|
||||||
self.info.get_num_patches(
|
self.info.get_num_patches(
|
||||||
image_width=size.width,
|
image_width=size.width,
|
||||||
@ -373,6 +415,7 @@ class Idefics3MultiModalProcessor(
|
|||||||
"image", num_patches),
|
"image", num_patches),
|
||||||
image_embeds=MultiModalFieldConfig.batched("image"),
|
image_embeds=MultiModalFieldConfig.batched("image"),
|
||||||
num_patches=MultiModalFieldConfig.batched("image"),
|
num_patches=MultiModalFieldConfig.batched("image"),
|
||||||
|
embed_is_patch=MultiModalFieldConfig.batched("image"),
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_prompt_updates(
|
def _get_prompt_updates(
|
||||||
@ -384,22 +427,17 @@ class Idefics3MultiModalProcessor(
|
|||||||
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
|
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
|
||||||
image_token = hf_processor.image_token.content
|
image_token = hf_processor.image_token.content
|
||||||
|
|
||||||
def get_replacement_idefics3(item_idx: int) -> PromptUpdateDetails:
|
def get_replacement_idefics3(item_idx: int) -> str:
|
||||||
images = mm_items.get_items("image", ImageProcessorItems)
|
images = mm_items.get_items("image", ImageProcessorItems)
|
||||||
|
|
||||||
image_size = images.get_image_size(item_idx)
|
image_size = images.get_image_size(item_idx)
|
||||||
|
|
||||||
image_repl = self.info.get_image_repl(
|
return self.info.get_image_repl(
|
||||||
image_width=image_size.width,
|
image_width=image_size.width,
|
||||||
image_height=image_size.height,
|
image_height=image_size.height,
|
||||||
processor=hf_processor,
|
processor=hf_processor,
|
||||||
)
|
)
|
||||||
|
|
||||||
return PromptUpdateDetails.select_text(
|
|
||||||
image_repl,
|
|
||||||
embed_text=image_token,
|
|
||||||
)
|
|
||||||
|
|
||||||
return [
|
return [
|
||||||
PromptReplacement(
|
PromptReplacement(
|
||||||
modality="image",
|
modality="image",
|
||||||
@ -637,6 +675,13 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
|
|||||||
if pixel_values is None and image_embeds is None:
|
if pixel_values is None and image_embeds is None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
embed_is_patch = kwargs.pop("embed_is_patch")
|
||||||
|
if not isinstance(embed_is_patch, (torch.Tensor, list)):
|
||||||
|
raise ValueError("Incorrect type of embed_is_patch. "
|
||||||
|
f"Got type: {type(embed_is_patch)}")
|
||||||
|
|
||||||
|
embed_is_patch = flatten_bn(embed_is_patch)
|
||||||
|
|
||||||
if image_embeds is not None:
|
if image_embeds is not None:
|
||||||
if not isinstance(image_embeds, (torch.Tensor, list)):
|
if not isinstance(image_embeds, (torch.Tensor, list)):
|
||||||
raise ValueError("Incorrect type of image embeddings. "
|
raise ValueError("Incorrect type of image embeddings. "
|
||||||
@ -645,6 +690,7 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
|
|||||||
return Idefics3ImageEmbeddingInputs(
|
return Idefics3ImageEmbeddingInputs(
|
||||||
type="image_embeds",
|
type="image_embeds",
|
||||||
data=flatten_bn(image_embeds, concat=True),
|
data=flatten_bn(image_embeds, concat=True),
|
||||||
|
embed_is_patch=embed_is_patch,
|
||||||
)
|
)
|
||||||
|
|
||||||
if pixel_values is not None:
|
if pixel_values is not None:
|
||||||
@ -672,6 +718,7 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
|
|||||||
pixel_values=self._validate_pixel_values(pixel_values),
|
pixel_values=self._validate_pixel_values(pixel_values),
|
||||||
pixel_attention_mask=pixel_attention_mask,
|
pixel_attention_mask=pixel_attention_mask,
|
||||||
num_patches=num_patches,
|
num_patches=num_patches,
|
||||||
|
embed_is_patch=embed_is_patch,
|
||||||
)
|
)
|
||||||
|
|
||||||
raise AssertionError("This line should be unreachable.")
|
raise AssertionError("This line should be unreachable.")
|
||||||
@ -707,7 +754,12 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
|
|||||||
if image_input is None:
|
if image_input is None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
return self._process_image_input(image_input)
|
image_features = self._process_image_input(image_input)
|
||||||
|
|
||||||
|
return scatter_patch_features(
|
||||||
|
image_features,
|
||||||
|
image_input["embed_is_patch"],
|
||||||
|
)
|
||||||
|
|
||||||
def get_input_embeddings(
|
def get_input_embeddings(
|
||||||
self,
|
self,
|
||||||
@ -719,7 +771,7 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
|
|||||||
inputs_embeds = merge_multimodal_embeddings(
|
inputs_embeds = merge_multimodal_embeddings(
|
||||||
input_ids,
|
input_ids,
|
||||||
inputs_embeds,
|
inputs_embeds,
|
||||||
multimodal_embeddings,
|
select_patch_features(multimodal_embeddings),
|
||||||
self.config.image_token_id,
|
self.config.image_token_id,
|
||||||
)
|
)
|
||||||
return inputs_embeds
|
return inputs_embeds
|
||||||
|
@@ -39,6 +39,7 @@ from vllm.transformers_utils.tokenizer import AnyTokenizer
 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
 from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
                     maybe_prefix, merge_multimodal_embeddings)
+from .vision import scatter_patch_features, select_patch_features

 IMG_START = '<img>'
 IMG_END = '</img>'
@@ -59,6 +60,14 @@ class InternVLImagePixelInputs(TypedDict):
     num_patches: torch.Tensor
     """Shape: `(batch_size * num_images)`"""

+    embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
+    """
+    A boolean mask indicating which image embeddings correspond
+    to patch tokens.
+
+    Shape: `(batch_size * num_images, num_embeds)`
+    """
+

 class InternVLImageEmbeddingInputs(TypedDict):
     type: Literal["image_embeds"]
@@ -410,12 +419,24 @@ class BaseInternVLProcessor(ABC):
                 torch.tensor([len(item) for item in pixel_values_lst]),
             }

+            tokenizer = self.tokenizer
+            image_token_id = self.image_token_id
+
+            embed_is_patch = list[torch.Tensor]()
+
             for pixel_values in pixel_values_lst:
                 num_patches = pixel_values.shape[0]
                 feature_size = num_patches * self.num_image_token

                 image_repl = self.get_image_repl(feature_size, num_patches)
+                feature_tokens = tokenizer.encode(image_repl.features,
+                                                  add_special_tokens=False)
+
                 text = [t.replace('<image>', image_repl.full, 1) for t in text]
+                embed_is_patch.append(
+                    torch.tensor(feature_tokens) == image_token_id)
+
+            image_inputs["embed_is_patch"] = embed_is_patch

         text_inputs = self.tokenizer(text)

@@ -439,7 +460,7 @@ class InternVLProcessor(BaseInternVLProcessor):
         repl_features = IMG_CONTEXT * feature_size
         repl_full = IMG_START + repl_features + IMG_END

-        return PromptUpdateDetails.select_text(repl_full, IMG_CONTEXT)
+        return PromptUpdateDetails(full=repl_full, features=repl_features)


 class BaseInternVLProcessingInfo(BaseProcessingInfo):
@@ -578,6 +599,7 @@ class InternVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
             pixel_values_flat=MultiModalFieldConfig.flat_from_sizes(
                 "image", image_num_patches),
             image_num_patches=MultiModalFieldConfig.batched("image"),
+            embed_is_patch=MultiModalFieldConfig.batched("image"),
             image_embeds=MultiModalFieldConfig.batched("image"),
             image_token_id=MultiModalFieldConfig.shared("image", num_images),
         )
@@ -809,6 +831,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
             self, **kwargs: object) -> Optional[InternVLImageInputs]:
         pixel_values_flat = kwargs.pop("pixel_values_flat", None)
         image_num_patches = kwargs.pop("image_num_patches", None)
+        embed_is_patch = kwargs.pop("embed_is_patch", None)
         image_embeds = kwargs.pop("image_embeds", None)

         if pixel_values_flat is None and image_embeds is None:
@@ -837,14 +860,20 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
                 raise ValueError("Incorrect type of image_num_patches. "
                                  f"Got type: {type(image_num_patches)}")

+            if not isinstance(embed_is_patch, (torch.Tensor, list)):
+                raise ValueError("Incorrect type of embed_is_patch. "
+                                 f"Got type: {type(embed_is_patch)}")
+
             pixel_values_flat = flatten_bn(pixel_values_flat, concat=True)
             image_num_patches = flatten_bn(image_num_patches, concat=True)
+            embed_is_patch = flatten_bn(embed_is_patch)

             return InternVLImagePixelInputs(
                 type="pixel_values",
                 pixel_values_flat=self._validate_pixel_values(
                     pixel_values_flat),
                 num_patches=image_num_patches,
+                embed_is_patch=embed_is_patch,
             )

         raise AssertionError("This line should be unreachable.")
@@ -890,7 +919,15 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
         if image_input is None:
             return None

-        return self._process_image_input(image_input)
+        image_features = self._process_image_input(image_input)
+
+        if image_input["type"] != "pixel_values":
+            return image_features
+
+        return scatter_patch_features(
+            image_features,
+            image_input["embed_is_patch"],
+        )

     def get_input_embeddings(
         self,
@@ -904,7 +941,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
             inputs_embeds = merge_multimodal_embeddings(
                 input_ids,
                 inputs_embeds,
-                multimodal_embeddings,
+                select_patch_features(multimodal_embeddings),
                 self.img_context_token_id,
             )
         return inputs_embeds
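Aside: the hunks above restore call sites of `scatter_patch_features` / `select_patch_features`. The following is only a rough, self-contained sketch of the round-trip behaviour those helpers provide (the padding value and the per-image list handling in vLLM's actual `vision.py` are assumptions, not copied from the source):

```python
import torch

def scatter_patch_features_sketch(
    features: torch.Tensor,        # (num_patches, hidden_size)
    embed_is_patch: torch.Tensor,  # (num_embeds,) boolean mask
) -> torch.Tensor:
    """Place patch features at the masked positions; pad the rest with NaN."""
    num_embeds = embed_is_patch.shape[0]
    out = features.new_full((num_embeds, features.shape[-1]), torch.nan)
    out[embed_is_patch] = features
    return out

def select_patch_features_sketch(embeds: torch.Tensor) -> torch.Tensor:
    """Recover only the rows that actually hold patch features."""
    is_patch = ~embeds.isnan().any(dim=-1)
    return embeds[is_patch]

# Round trip: the mask marks 4 of 6 placeholder positions as patch tokens.
mask = torch.tensor([True, True, False, True, True, False])
feats = torch.randn(4, 8)
scattered = scatter_patch_features_sketch(feats, mask)
assert torch.equal(select_patch_features_sketch(scattered), feats)
```

The point of the scatter/select pair is that the embedding sequence handed to `merge_multimodal_embeddings` has one row per placeholder token, even though the vision tower only produces features for the patch positions.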
@@ -32,8 +32,7 @@ from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
                                    ImageSize, MultiModalDataItems)
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         BaseProcessingInfo, ProcessingCache,
-                                        PromptReplacement, PromptUpdate,
-                                        PromptUpdateDetails)
+                                        PromptReplacement, PromptUpdate)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
 from vllm.sequence import IntermediateTensors

@@ -43,7 +42,8 @@ from .pixtral import PixtralHFEncoderInfo, PixtralHFVisionModel
 from .siglip import SiglipVisionModel
 from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
                     maybe_prefix, merge_multimodal_embeddings)
-from .vision import get_vision_encoder_info
+from .vision import (get_vision_encoder_info, scatter_patch_features,
+                     select_patch_features)


 class LlavaImagePixelInputs(TypedDict):
@@ -67,6 +67,14 @@ class PixtralHFImagePixelInputs(TypedDict):
     in which case the data is passed as a list instead of a batched tensor.
     """

+    embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
+    """
+    A boolean mask indicating which image embeddings correspond
+    to patch tokens.
+
+    Shape: `(batch_size * num_images, num_embeds)`
+    """
+

 class LlavaImageEmbeddingInputs(TypedDict):
     type: Literal["image_embeds"]
@@ -335,6 +343,23 @@ class PixtralHFMultiModalProcessor(
                 for p, (h, w) in zip(pixel_values, image_sizes)
             ]

+            hf_config = self.info.get_hf_config()
+            vision_config = hf_config.vision_config
+            assert isinstance(vision_config, PixtralVisionConfig)
+            encoder_info = PixtralHFEncoderInfo(vision_config)
+
+            tile_sizes = [
+                encoder_info.get_patch_grid_size(
+                    image_width=pixel_value.shape[-1],
+                    image_height=pixel_value.shape[-2],
+                ) for pixel_value in processed_outputs["pixel_values"]
+            ]
+            embed_is_patch = [
+                torch.tensor(([True] * ncols + [False]) * nrows)
+                for ncols, nrows in tile_sizes
+            ]
+            processed_outputs["embed_is_patch"] = embed_is_patch
+
         return processed_outputs

     def _get_mm_fields_config(
@@ -344,6 +369,7 @@ class PixtralHFMultiModalProcessor(
     ) -> Mapping[str, MultiModalFieldConfig]:
         return dict(
             pixel_values=MultiModalFieldConfig.batched("image"),
+            embed_is_patch=MultiModalFieldConfig.batched("image"),
             image_embeds=MultiModalFieldConfig.batched("image"),
         )

@@ -378,7 +404,7 @@ class PixtralHFMultiModalProcessor(
             tokens = ([image_token_id] * ncols + [image_break_id]) * nrows
             tokens[-1] = image_end_id

-            return PromptUpdateDetails.select_token_id(tokens, image_token_id)
+            return tokens

         return [
             PromptReplacement(
@@ -586,9 +612,17 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
                                  f"Got type: {type(pixel_values)}")

             if self.config.vision_config.model_type == "pixtral":
+                embed_is_patch = kwargs.pop("embed_is_patch")
+                if not isinstance(embed_is_patch, (torch.Tensor, list)):
+                    raise ValueError("Incorrect type of embed_is_patch. "
+                                     f"Got type: {type(embed_is_patch)}")
+
+                embed_is_patch = flatten_bn(embed_is_patch)
+
                 return PixtralHFImagePixelInputs(
                     type="pixel_values_pixtral",
                     pixel_values=flatten_bn(pixel_values),
+                    embed_is_patch=embed_is_patch,
                 )

             return LlavaImagePixelInputs(
@@ -680,7 +714,16 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
         if image_input is None:
             return None

-        return self._process_image_input(image_input)
+        image_features = self._process_image_input(image_input)
+
+        if image_input["type"] != "pixel_values_pixtral":
+            # The path is used for pixtral (V0 only) and llava (V0/V1)
+            return image_features
+
+        return scatter_patch_features(
+            image_features,
+            image_input["embed_is_patch"],
+        )

     def get_input_embeddings(
         self,
@@ -692,7 +735,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
             inputs_embeds = merge_multimodal_embeddings(
                 input_ids,
                 inputs_embeds,
-                multimodal_embeddings,
+                select_patch_features(multimodal_embeddings),
                 self.config.image_token_index,
             )
         return inputs_embeds
@@ -40,8 +40,7 @@ from vllm.multimodal.parse import (AudioItem, AudioProcessorItems,
                                    DictEmbeddingItems, ModalityData,
                                    ModalityDataItems, MultiModalDataItems,
                                    MultiModalDataParser)
-from vllm.multimodal.processing import (PromptReplacement, PromptUpdate,
-                                        PromptUpdateDetails)
+from vllm.multimodal.processing import PromptReplacement, PromptUpdate
 from vllm.multimodal.profiling import ProcessorInputs

 from .minicpmv import (_MAX_FRAMES_PER_VIDEO, MiniCPMV2_6,
@@ -51,6 +50,7 @@ from .minicpmv import (_MAX_FRAMES_PER_VIDEO, MiniCPMV2_6,
                        _minicpmv_field_config)
 from .utils import (AutoWeightsLoader, cast_overflow_tensors, flatten_bn,
                     maybe_prefix)
+from .vision import scatter_patch_features

 CPU_DEVICE = torch.device("cpu")

@@ -73,6 +73,14 @@ class MiniCPMOAudioFeatureInputs(TypedDict):
     which equals to `audio_features.shape[-1]`
     """

+    embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
+    """
+    A boolean mask indicating which audio embeddings correspond
+    to patch tokens.
+
+    Shape: `(batch_size * num_audios, num_embeds)`
+    """
+

 class MiniCPMOAudioEmbeddingInputs(TypedDict):
     type: Literal["audio_embeds"]
@@ -85,6 +93,14 @@ class MiniCPMOAudioEmbeddingInputs(TypedDict):
     Length of each slice may vary, so pass it as a list.
     """

+    embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
+    """
+    A boolean mask indicating which audio embeddings correspond
+    to patch tokens.
+
+    Shape: `(batch_size * num_audios, num_embeds)`
+    """
+

 MiniCPMOAudioInputs = Union[MiniCPMOAudioFeatureInputs,
                             MiniCPMOAudioEmbeddingInputs]
@@ -99,6 +115,7 @@ def _minicpmo_field_config(hf_inputs: Mapping[str, torch.Tensor]):
         audio_features=MultiModalFieldConfig.batched("audio"),
         audio_feature_lens=MultiModalFieldConfig.batched("audio"),
         audio_embeds=MultiModalFieldConfig.batched("audio"),
+        audio_embed_is_patch=MultiModalFieldConfig.batched("audio"),
         audio_token_id=MultiModalFieldConfig.shared("audio", num_audios),
     )

@@ -180,7 +197,8 @@ class MiniCPMOProcessingInfo(MiniCPMVProcessingInfo):
         pool_step = self.get_default_audio_pool_step()
         fbank_feat_in_chunk = 100
         cnn_feat_in_chunk = (fbank_feat_in_chunk - 1) // 2 + 1
-        return (cnn_feat_in_chunk - pool_step) // pool_step + 1
+        num_audio_tokens = (cnn_feat_in_chunk - pool_step) // pool_step + 1
+        return num_audio_tokens + 2  # <audio>(<unk>*N)</audio>

     def get_max_audio_chunks_with_most_features(self) -> int:
         return 30
@@ -191,7 +209,8 @@ class MiniCPMOProcessingInfo(MiniCPMVProcessingInfo):

     def get_audio_len_by_num_chunks(self, num_chunks: int) -> int:
         sampling_rate = self.get_default_audio_sampling_rate()
-        num_tokens_per_chunk = self.get_max_audio_tokens_per_chunk()
+        # exclude <audio> </audio>
+        num_tokens_per_chunk = self.get_max_audio_tokens_per_chunk() - 2
         return int(num_chunks * sampling_rate / num_tokens_per_chunk) + 1

     def get_num_frames_with_most_features(
@@ -276,6 +295,13 @@ class MiniCPMOMultiModalProcessor(

         if isinstance(parsed_audios, MiniCPMOAudioEmbeddingItems):
             audio_inputs = {}
+
+            audio_lens = [
+                self.info.get_audio_len_by_num_chunks(
+                    sum(map(len,
+                            parsed_audios.get(i)["audio_embeds"])))
+                for i in range(len(parsed_audios))
+            ]
         else:
             audio_inputs = self._base_call_hf_processor(
                 prompts=[self.info.audio_pattern] * len(parsed_audios),
@@ -297,7 +323,27 @@ class MiniCPMOMultiModalProcessor(
             ]
             audio_inputs["audio_features"] = unpadded_audio_features

+            audio_lens = [
+                parsed_audios.get_audio_length(i)
+                for i in range(len(parsed_audios))
+            ]
+
+        audio_repl_features = [
+            self.get_audio_prompt_texts(audio_len) for audio_len in audio_lens
+        ]
+
         tokenizer = self.info.get_tokenizer()
+        audio_repls_feature_tokens = [
+            tokenizer.encode(audio_repl, add_special_tokens=False)
+            for audio_repl in audio_repl_features
+        ]
+
+        embed_is_patch = [
+            self.get_embed_is_patch(audio_repl_tokens)
+            for audio_repl_tokens in audio_repls_feature_tokens
+        ]
+        audio_inputs["audio_embed_is_patch"] = embed_is_patch
+
         unk_token_id = tokenizer.get_vocab()["<unk>"]
         audio_inputs["audio_token_id"] = torch.tensor(unk_token_id)

@@ -338,10 +384,7 @@ class MiniCPMOMultiModalProcessor(
             else:
                 audio_len = audios.get_audio_length(item_idx)

-            return PromptUpdateDetails.select_text(
-                self.get_audio_prompt_texts(audio_len),
-                "<unk>",
-            )
+            return self.get_audio_prompt_texts(audio_len)

         return [
             *base_updates,
@@ -670,6 +713,13 @@ class MiniCPMO(MiniCPMV2_6):
         assert isinstance(audio_token_id, torch.Tensor)
         self.mm_token_ids.add(audio_token_id.flatten().unique().item())

+        audio_embed_is_patch = kwargs.pop("audio_embed_is_patch")
+        if not isinstance(audio_embed_is_patch, (torch.Tensor, list)):
+            raise ValueError("Incorrect type of audio_embed_is_patch. "
+                             f"Got type: {type(audio_embed_is_patch)}")
+
+        audio_embed_is_patch = flatten_bn(audio_embed_is_patch)
+
         if audio_embeds is not None:
             if not isinstance(audio_embeds, (torch.Tensor, list)):
                 raise ValueError("Incorrect type of audio_embeds. "
@@ -680,6 +730,7 @@ class MiniCPMO(MiniCPMV2_6):
             return MiniCPMOAudioEmbeddingInputs(
                 type="audio_embeds",
                 audio_embeds=audio_embeds_flat,
+                embed_is_patch=audio_embed_is_patch,
             )

         if not isinstance(audio_features, (torch.Tensor, list)):
@@ -698,6 +749,7 @@ class MiniCPMO(MiniCPMV2_6):
             type="audio_features",
             audio_features=audio_features_flat,
             audio_feature_lens=audio_feature_lens_flat,
+            embed_is_patch=audio_embed_is_patch,
         )

     def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
@@ -729,6 +781,10 @@ class MiniCPMO(MiniCPMV2_6):
             if modality == "audios":
                 audio_input = modalities["audios"]
                 audio_features = self._process_audio_input(audio_input)
-                multimodal_embeddings += tuple(audio_features)
+                multimodal_embeddings += tuple(
+                    scatter_patch_features(
+                        audio_features,
+                        audio_input["embed_is_patch"],
+                    ))

         return multimodal_embeddings
@@ -56,7 +56,7 @@ from vllm.multimodal.parse import (DictEmbeddingItems, ImageItem,
                                    VideoItem, VideoProcessorItems)
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         BaseProcessingInfo, PromptReplacement,
-                                        PromptUpdate, PromptUpdateDetails)
+                                        PromptUpdate)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
 from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors
@@ -67,6 +67,7 @@ from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
                          SupportsMultiModal, SupportsPP)
 from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
                     merge_multimodal_embeddings)
+from .vision import scatter_patch_features, select_patch_features

 # For profile run
 _MAX_FRAMES_PER_VIDEO = 16
@@ -89,6 +90,14 @@ class MiniCPMVImagePixelInputs(TypedDict):
     This should be in `(height, width)` format.
     """

+    embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
+    """
+    A boolean mask indicating which image embeddings correspond
+    to patch tokens.
+
+    Shape: `(batch_size * num_images, num_embeds)`
+    """
+
     num_slices: torch.Tensor
     """Shape: `(batch_size * num_images)`"""

@@ -103,6 +112,14 @@ class MiniCPMVImageEmbeddingInputs(TypedDict):
     instead of a batched tensor.
     """

+    embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
+    """
+    A boolean mask indicating which image embeddings correspond
+    to patch tokens.
+
+    Shape: `(batch_size * num_images, num_embeds)`
+    """
+

 MiniCPMVImageInputs = Union[MiniCPMVImagePixelInputs,
                             MiniCPMVImageEmbeddingInputs]
@@ -228,10 +245,12 @@ def _minicpmv_field_config(hf_inputs: Mapping[str, torch.Tensor]):
         image_sizes=MultiModalFieldConfig.batched("image"),
         tgt_sizes=MultiModalFieldConfig.batched("image"),
         image_embeds=MultiModalFieldConfig.batched("image"),
+        embed_is_patch=MultiModalFieldConfig.batched("image"),
         video_pixel_values=MultiModalFieldConfig.batched("video"),
         video_image_sizes=MultiModalFieldConfig.batched("video"),
         video_tgt_sizes=MultiModalFieldConfig.batched("video"),
         video_embeds=MultiModalFieldConfig.batched("video"),
+        video_embed_is_patch=MultiModalFieldConfig.batched("video"),
         image_token_id=MultiModalFieldConfig.shared("image", num_images),
         video_token_id=MultiModalFieldConfig.shared("video", num_videos),
     )
@@ -379,43 +398,22 @@ class MiniCPMVProcessingInfo(BaseProcessingInfo):
             use_image_id=use_image_id,
         )

-    def get_sliced_grid(
-        self,
-        image_size: ImageSize,
-        # For MiniCPM V/O 2.6
-        max_slice_nums: Optional[int] = None,
-    ) -> Optional[tuple[int, int]]:
-        image_processor = self.get_image_processor()
-        version = self.get_model_version()
-
-        if version == (2, 0) or version == (2, 5):
-            return image_processor.get_sliced_grid(image_size)
-
-        if max_slice_nums is None:
-            max_slice_nums = image_processor.max_slice_nums
-
-        return image_processor.get_sliced_grid(
-            image_size,
-            max_slice_nums=max_slice_nums,
-        )
-
     def get_num_image_tokens(
         self,
         image_size: ImageSize,
         max_slice_nums: Optional[int] = None,
+        use_image_id: bool = True,
     ) -> int:
-        image_processor = self.get_image_processor()
-        grid = self.get_sliced_grid(
+        tokenizer = self.get_tokenizer()
+        image_placeholders = self.get_slice_image_placeholder(
             image_size,
             max_slice_nums=max_slice_nums,
+            use_image_id=use_image_id,
         )
-        if grid is None:
-            ncols = nrows = 0
-        else:
-            ncols, nrows = grid
-
-        return (ncols * nrows + 1) * image_processor.image_feature_size
+        image_token_ids = tokenizer.encode(image_placeholders,
+                                           add_special_tokens=False)
+
+        return len(image_token_ids)

     def get_max_image_tokens(self) -> int:
         image_size = self.get_image_size_with_most_features()
@@ -435,6 +433,7 @@ class MiniCPMVProcessingInfo(BaseProcessingInfo):
         return self.get_num_image_tokens(
             frame_size,
             max_slice_nums=self.get_video_max_slice_num(),
+            use_image_id=False,
         )

     def get_max_video_tokens(
@@ -540,6 +539,14 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
             use_image_id=False,
         ) * num_frames

+    def get_embed_is_patch(
+        self,
+        input_ids: list[int],
+    ) -> torch.Tensor:
+        tokenizer = self.info.get_tokenizer()
+        unk_token_id = tokenizer.get_vocab()["<unk>"]
+        return torch.tensor(input_ids) == unk_token_id
+
     def process_images(
         self,
         mm_data: Mapping[str, object],
@@ -563,7 +570,26 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
             out_keys={"pixel_values", "image_sizes", "tgt_sizes"},
         )

+        image_sizes = [
+            parsed_images.get_image_size(i) for i in range(len(parsed_images))
+        ]
+        image_repl_features = [
+            self.get_image_prompt_texts(size, idx)
+            for idx, size in enumerate(image_sizes)
+        ]
+
         tokenizer = self.info.get_tokenizer()
+        image_repls_feature_tokens = [
+            tokenizer.encode(image_repl, add_special_tokens=False)
+            for image_repl in image_repl_features
+        ]
+
+        embed_is_patch = [
+            self.get_embed_is_patch(image_repl_tokens)
+            for image_repl_tokens in image_repls_feature_tokens
+        ]
+        image_inputs["embed_is_patch"] = embed_is_patch
+
         unk_token_id = tokenizer.get_vocab()["<unk>"]
         image_inputs["image_token_id"] = torch.tensor(unk_token_id)

@@ -599,9 +625,31 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
             out_keys={"pixel_values", "image_sizes", "tgt_sizes"},
         )

-        video_inputs = {f"video_{k}": v for k, v in video_inputs.items()}
+        frame_sizes = [
+            parsed_videos.get_frame_size(i) for i in range(len(parsed_videos))
+        ]
+        num_frames = [
+            parsed_videos.get_num_frames(i) for i in range(len(parsed_videos))
+        ]
+        video_repl_features = [
+            self.get_video_prompt_texts(size, nframes)
+            for size, nframes in zip(frame_sizes, num_frames)
+        ]

         tokenizer = self.info.get_tokenizer()
+        video_repls_feature_tokens = [
+            tokenizer.encode(video_repl, add_special_tokens=False)
+            for video_repl in video_repl_features
+        ]
+
+        embed_is_patch = [
+            self.get_embed_is_patch(video_repl_tokens)
+            for video_repl_tokens in video_repls_feature_tokens
+        ]
+        video_inputs["embed_is_patch"] = embed_is_patch
+
+        video_inputs = {f"video_{k}": v for k, v in video_inputs.items()}
+
         unk_token_id = tokenizer.get_vocab()["<unk>"]
         video_inputs["video_token_id"] = torch.tensor(unk_token_id)

@@ -692,10 +740,7 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):

             image_size = images.get_image_size(item_idx)

-            return PromptUpdateDetails.select_text(
-                self.get_image_prompt_texts(image_size, item_idx),
-                "<unk>",
-            )
+            return self.get_image_prompt_texts(image_size, item_idx)

         def get_video_replacement(item_idx: int):
             videos = mm_items.get_items(
@@ -704,10 +749,7 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
             frame_size = videos.get_frame_size(item_idx)
             num_frames = videos.get_num_frames(item_idx)

-            return PromptUpdateDetails.select_text(
-                self.get_video_prompt_texts(frame_size, num_frames),
-                "<unk>",
-            )
+            return self.get_video_prompt_texts(frame_size, num_frames)

         get_replacement = {
             "image": get_image_replacement,
@@ -790,6 +832,14 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
         assert isinstance(image_token_id, torch.Tensor)
         self.mm_token_ids.add(image_token_id.flatten().unique().item())

+        embed_is_patch = kwargs.pop("embed_is_patch")
+        if not isinstance(embed_is_patch, (torch.Tensor, list)):
+            raise ValueError(
+                f"Incorrect type of embed_is_patch for {modality=}. "
+                f"Got type: {type(embed_is_patch)}")
+
+        embed_is_patch = flatten_bn(embed_is_patch)
+
         if image_embeds is not None:
             if not isinstance(image_embeds, (torch.Tensor, list)):
                 raise ValueError(
@@ -801,6 +851,7 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
             return MiniCPMVImageEmbeddingInputs(
                 type="image_embeds",
                 image_embeds=image_embeds_flat,
+                embed_is_patch=embed_is_patch,
             )

         if not isinstance(pixel_values, (torch.Tensor, list)):
@@ -828,6 +879,7 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
             type="pixel_values",
             pixel_values=pixel_values_flat,
             tgt_sizes=tgt_sizes_flat,
+            embed_is_patch=embed_is_patch,
             num_slices=num_slices_flat,
         )

@@ -884,11 +936,19 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
             if modality == "images":
                 image_input = modalities["images"]
                 image_features = self._process_vision_input(image_input)
-                multimodal_embeddings += tuple(image_features)
+                multimodal_embeddings += tuple(
+                    scatter_patch_features(
+                        image_features,
+                        image_input["embed_is_patch"],
+                    ))
             if modality == "videos":
                 video_input = modalities["videos"]
                 video_features = self._process_vision_input(video_input)
-                multimodal_embeddings += tuple(video_features)
+                multimodal_embeddings += tuple(
+                    scatter_patch_features(
+                        video_features,
+                        video_input["embed_is_patch"],
+                    ))

         return multimodal_embeddings

@@ -911,7 +971,7 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
             inputs_embeds = merge_multimodal_embeddings(
                 input_ids,
                 inputs_embeds,
-                multimodal_embeddings,
+                select_patch_features(multimodal_embeddings),
                 list(self.mm_token_ids),
             )
         return inputs_embeds
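Aside: the restored MiniCPM-V/O processors derive `embed_is_patch` by tokenizing each prompt replacement and marking the positions of the `<unk>` placeholder token. A minimal sketch with made-up token ids (the real code uses the model's tokenizer and `get_vocab()["<unk>"]`):

```python
import torch

# Toy token ids standing in for the tokenizer output of one image replacement,
# e.g. "<image>" + "<unk>" * 4 + "</image>"; the id values are made up.
IMAGE_START_ID, UNK_ID, IMAGE_END_ID = 101, 0, 102
repl_token_ids = [IMAGE_START_ID] + [UNK_ID] * 4 + [IMAGE_END_ID]

# True exactly at the positions that scatter_patch_features will later fill
# with vision features.
embed_is_patch = torch.tensor(repl_token_ids) == UNK_ID
print(embed_is_patch)  # tensor([False,  True,  True,  True,  True, False])
```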
@@ -27,8 +27,7 @@ from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
                                    MultiModalDataItems)
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         BaseProcessingInfo, ProcessingCache,
-                                        PromptReplacement, PromptUpdate,
-                                        PromptUpdateDetails)
+                                        PromptReplacement, PromptUpdate)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
 from vllm.sequence import IntermediateTensors

@@ -36,7 +35,8 @@ from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
 from .pixtral import PixtralHFEncoderInfo, PixtralHFVisionModel
 from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
                     maybe_prefix, merge_multimodal_embeddings)
-from .vision import get_vision_encoder_info
+from .vision import (get_vision_encoder_info, scatter_patch_features,
+                     select_patch_features)


 class Mistral3ImagePixelInputs(TypedDict):
@@ -49,6 +49,14 @@ class Mistral3ImagePixelInputs(TypedDict):
     in which case the data is passed as a list instead of a batched tensor.
     """

+    embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
+    """
+    A boolean mask indicating which image embeddings correspond
+    to patch tokens.
+
+    Shape: `(batch_size, num_images, num_embeds)`
+    """
+

 class Mistral3PatchMerger(nn.Module):
     """
@@ -258,6 +266,23 @@ class Mistral3MultiModalProcessor(
                 p[:, :h, :w] for p, (h, w) in zip(pixel_values, image_sizes)
             ]

+            hf_config = self.info.get_hf_config()
+            vision_config = hf_config.vision_config
+            assert isinstance(vision_config, PixtralVisionConfig)
+            encoder_info = PixtralHFEncoderInfo(vision_config)
+
+            tile_sizes = [
+                encoder_info.get_patch_grid_size(
+                    image_width=pixel_value.shape[-1],
+                    image_height=pixel_value.shape[-2],
+                ) for pixel_value in processed_outputs["pixel_values"]
+            ]
+            embed_is_patch = [
+                torch.tensor(([True] * ncols + [False]) * nrows)
+                for ncols, nrows in tile_sizes
+            ]
+            processed_outputs["embed_is_patch"] = embed_is_patch
+
         return processed_outputs

     def _get_mm_fields_config(
@@ -267,6 +292,7 @@ class Mistral3MultiModalProcessor(
     ) -> Mapping[str, MultiModalFieldConfig]:
         return dict(
             pixel_values=MultiModalFieldConfig.batched("image"),
+            embed_is_patch=MultiModalFieldConfig.batched("image"),
             image_embeds=MultiModalFieldConfig.batched("image"),
         )

@@ -301,7 +327,7 @@ class Mistral3MultiModalProcessor(
             tokens = ([image_token_id] * ncols + [image_break_id]) * nrows
             tokens[-1] = image_end_id

-            return PromptUpdateDetails.select_token_id(tokens, image_token_id)
+            return tokens

         return [
             PromptReplacement(
@@ -392,6 +418,8 @@ def init_vision_tower_for_llava(
     )


+# TODO(mgoin): Support V1, there are issues with image batching/chunking
+# that need to be resolved first.
 @MULTIMODAL_REGISTRY.register_processor(
     _build_mistral3_processor,
     info=_build_mistral3_info,
@@ -481,9 +509,16 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsMultiModal,
                 raise ValueError("Incorrect type of pixel values. "
                                  f"Got type: {type(pixel_values)}")

+            assert self.config.vision_config.model_type == "pixtral"
+            embed_is_patch = kwargs.pop("embed_is_patch")
+            if not isinstance(embed_is_patch, (torch.Tensor, list)):
+                raise ValueError("Incorrect type of embed_is_patch. "
+                                 f"Got type: {type(embed_is_patch)}")
+
             return Mistral3ImagePixelInputs(
                 type="pixel_values_pixtral",
                 pixel_values=flatten_bn(pixel_values),
+                embed_is_patch=flatten_bn(embed_is_patch),
             )

     def _process_image_input(
@@ -522,7 +557,10 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsMultiModal,

         vision_embeddings = self._process_image_input(image_input)

-        return vision_embeddings
+        return scatter_patch_features(
+            vision_embeddings,
+            image_input["embed_is_patch"],
+        )

     def get_input_embeddings(
         self,
@@ -534,7 +572,7 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsMultiModal,
             inputs_embeds = merge_multimodal_embeddings(
                 input_ids,
                 inputs_embeds,
-                multimodal_embeddings,
+                select_patch_features(multimodal_embeddings),
                 self.config.image_token_index,
             )
         return inputs_embeds
@@ -46,8 +46,7 @@ from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
                                    MultiModalDataItems)
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         BaseProcessingInfo, PromptIndexTargets,
-                                        PromptInsertion, PromptUpdate,
-                                        PromptUpdateDetails)
+                                        PromptInsertion, PromptUpdate)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
 from vllm.sequence import IntermediateTensors

@@ -57,6 +56,7 @@ from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
                     is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix, merge_multimodal_embeddings)
+from .vision import scatter_patch_features, select_patch_features

 # TODO: hard-coded for now. Consider making it configurable.
 VIT_LAYERS = [-2, -9]
@@ -84,6 +84,14 @@ class MolmoImageInputs(TypedDict):
     Shape: `(batch_size * num_images, num_crops, num_patch)`
     """

+    embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
+    """
+    A boolean mask indicating which image embeddings correspond
+    to patch tokens.
+
+    Shape: `(batch_size * num_images, num_embeds)`
+    """
+
     num_crops: torch.Tensor
     """Shape: `(batch_size * num_images)`"""

@@ -1138,6 +1146,30 @@ class MolmoProcessorWrapper:
         if image_input_idx is not None:
             feat_is_patch = image_input_idx >= 0

+            input_is_embed = torch.isin(
+                input_ids,
+                torch.tensor([
+                    self.image_patch_id,
+                    self.im_col_id,
+                    self.im_start_id,
+                    self.im_end_id,
+                ]),
+            )
+            embed_ids = input_ids[input_is_embed]
+            embed_is_patch = embed_ids == self.image_patch_id
+            assert embed_is_patch.sum() == feat_is_patch.sum()
+
+            # image_tokens = extra_joint + joint
+            # Both `extra_joint` and `joint` have `im_start_id` and `im_end_id`
+            embed_start = torch.nonzero(embed_ids == self.im_start_id)[::2, 0]
+            embed_end = torch.nonzero(embed_ids == self.im_end_id)[1::2, 0]
+            assert len(embed_start) == len(embed_end) == len(images)
+
+            embed_is_patch = [
+                embed_is_patch[start:end + 1]
+                for start, end in zip(embed_start, embed_end)
+            ]
+
             tilings = [
                 self.select_tiling(
                     image_width=image.size[0],
@@ -1149,6 +1181,7 @@ class MolmoProcessorWrapper:
             assert num_crops.sum() == len(feat_is_patch)

             outputs["feat_is_patch"] = feat_is_patch
+            outputs["embed_is_patch"] = embed_is_patch
             outputs["num_crops"] = num_crops
             outputs["img_patch_id"] = self.image_patch_id

@@ -1187,13 +1220,17 @@ class MolmoProcessingInfo(BaseProcessingInfo):
         )
         pooling_size = processor.pooling_size

-        image_token_length_w = processor.image_token_length_w
-        image_token_length_h = processor.image_token_length_h
+        base_image_input_size = processor.base_image_input_size
+        base_image_input_d = processor.image_patch_size

-        extra = image_token_length_w * image_token_length_h
-        joint = ((ncols + 1) // pooling_size) * ((nrows + 1) // pooling_size)
+        crop_patches = base_image_input_size[0] // base_image_input_d

-        return extra + joint
+        per_row = ncols // pooling_size + 1
+        joint = per_row * (nrows // pooling_size) + 2
+        image_token_length = (crop_patches + pooling_size - 1) // pooling_size
+        resize = (image_token_length + 1) * image_token_length + 2
+
+        return resize + joint

     def get_max_image_tokens(self) -> int:
         target_width, target_height = self.get_image_size_with_most_features()
@@ -1291,6 +1328,7 @@ class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]):
                 "image", num_crops),
             feat_is_patch=MultiModalFieldConfig.flat_from_sizes(
                 "image", num_crops),
+            embed_is_patch=MultiModalFieldConfig.batched("image"),
             num_crops=MultiModalFieldConfig.batched("image"),
             img_patch_id=MultiModalFieldConfig.shared("image", num_images),
         )
@@ -1330,10 +1368,8 @@ class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]):
             joint = ([img_start_id] + joint_row *
                      ((nrows + 1) // pooling_size) + [img_end_id])

-            return PromptUpdateDetails.select_token_id(
-                extra_joint + joint,
-                embed_token_id=img_patch_id,
-            )
+            image_tokens = extra_joint + joint
+            return image_tokens

         return [
             PromptInsertion(
@@ -1439,6 +1475,11 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
             raise ValueError("Incorrect type of feat_is_patch. "
                              f"Got type: {type(feat_is_patch)}")

+        embed_is_patch = kwargs.pop("embed_is_patch", None)
+        if not isinstance(embed_is_patch, (torch.Tensor, list)):
+            raise ValueError("Incorrect type of embed_is_patch. "
+                             f"Got type: {type(embed_is_patch)}")
+
         num_crops = kwargs.pop("num_crops", None)
         if not isinstance(num_crops, (torch.Tensor, list)):
             raise ValueError("Incorrect type of num_crops. "
@@ -1450,12 +1491,14 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
                              f"Got type: {type(img_patch_id)}")
         self.img_patch_id = img_patch_id.flatten().unique().item()

+        embed_is_patch = flatten_bn(embed_is_patch)
         num_crops = flatten_bn(num_crops, concat=True)

         return MolmoImageInputs(
             images=images,
             image_masks=image_masks,
             feat_is_patch=feat_is_patch,
+            embed_is_patch=embed_is_patch,
             num_crops=num_crops,
         )

@@ -1494,7 +1537,12 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
         if image_input is None:
             return None

-        return self._process_image_input(image_input)
+        image_features = self._process_image_input(image_input)
+
+        return scatter_patch_features(
+            image_features,
+            image_input["embed_is_patch"],
+        )

     def get_input_embeddings(
         self,
@@ -1508,7 +1556,7 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
             inputs_embeds = merge_multimodal_embeddings(
                 input_ids,
                 inputs_embeds,
-                multimodal_embeddings,
+                select_patch_features(multimodal_embeddings),
                 self.img_patch_id,
             )
         return inputs_embeds
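Aside: the restored `MolmoProcessorWrapper` hunk pairs image start/end markers by striding over `torch.nonzero` results, because each image contributes two bracketed token runs and its full span goes from the first start marker to the second end marker. A toy example with made-up token ids showing why the `[::2]` / `[1::2]` slicing selects one span per image:

```python
import torch

# Hypothetical token ids for the four special markers.
IM_START, IM_END, PATCH, COL = 1, 2, 3, 4

# One image contributes two bracketed runs (the "extra_joint" and "joint" parts),
# so its full span runs from the first IM_START to the second IM_END.
embed_ids = torch.tensor([
    IM_START, PATCH, PATCH, IM_END,       # extra_joint of image 0
    IM_START, PATCH, COL, PATCH, IM_END,  # joint of image 0
])

starts = torch.nonzero(embed_ids == IM_START)[::2, 0]   # first start per image
ends = torch.nonzero(embed_ids == IM_END)[1::2, 0]      # second end per image
spans = [embed_ids[s:e + 1] for s, e in zip(starts, ends)]
print(starts, ends)  # tensor([0]) tensor([8]) -> one span covering all 9 tokens
```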
@@ -57,7 +57,7 @@ class NVLMProcessor(BaseInternVLProcessor):
         # when trying to find "<tile" as a subsequence of "<Image><tile"
         repl = "<Image>" + features + "</Image>"

-        return PromptUpdateDetails.select_text(repl, IMG_PAD)
+        return PromptUpdateDetails(full=repl, features=repl)


 class NVLMProcessingInfo(BaseInternVLProcessingInfo):
@@ -84,6 +84,31 @@ class NVLMProcessingInfo(BaseInternVLProcessingInfo):
             **kwargs,
         )

+    def get_max_image_tokens(self) -> int:
+        hf_processor = self.get_hf_processor()
+        tokenizer = hf_processor.tokenizer
+
+        max_num_patches = hf_processor.max_dynamic_patch
+        # we need +1 here because max_dynamic_patch in config doesn't
+        # include the thumbnail patch
+        tile_pos_identifiers = [
+            f"<tile_{i+1}>" for i in range(max_num_patches)
+        ]
+        if hf_processor.use_thumbnail and max_num_patches != 1:
+            tile_pos_identifiers += ["<tile_global_thumbnail>"]
+
+        # "<Image><tile" is tokenized as ["<Image", "><", "tile"]
+        # so we include <tile_1> in the start_str
+        start_str = "<Image>" + tile_pos_identifiers.pop(0)
+        end_str = "</Image>"
+        start_token_len = len(tokenizer.encode(start_str))
+        end_token_len = len(tokenizer.encode(end_str))
+        tile_token_len = sum(
+            len(tokenizer.encode(identifier))
+            for identifier in tile_pos_identifiers)
+        non_image_tokens_num = start_token_len + end_token_len + tile_token_len
+        return super().get_max_image_tokens() + non_image_tokens_num
+

 class NVLMDummyInputsBuilder(InternVLDummyInputsBuilder[NVLMProcessingInfo]):

@@ -152,7 +177,10 @@ class NVLMMultiModalProcessor(InternVLMultiModalProcessor[NVLMProcessingInfo]):

             repl = hf_processor.get_image_repl(feature_size, num_patches)

-            return PromptUpdateDetails.select_text(repl.full + "\n", IMG_PAD)
+            return PromptUpdateDetails(
+                full=repl.full + "\n",
+                features=repl.features + "\n",
+            )

         # See note in dummy data regarding why we have the extra newline
         return [
@@ -162,9 +162,9 @@ class PaliGemmaMultiModalProcessor(
                 modality="image",
                 target=PromptIndexTargets.prefix(
                     [bos_token_id] if tokenizer.add_bos_token else []),
-                insertion=PromptUpdateDetails.select_token_id(
-                    image_tokens + [bos_token_id],
-                    embed_token_id=image_token_id,
+                insertion=PromptUpdateDetails(
+                    full=image_tokens + [bos_token_id],
+                    features=image_tokens,
                 ),
             )
         ]
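Aside: across these hunks the revert replaces the `PromptUpdateDetails.select_token_id` / `select_text` helpers with direct construction of `PromptUpdateDetails(full=..., features=...)`. As a simplified stand-in (not the actual vLLM class definition), the object carries the complete replacement plus the subset that corresponds to multimodal feature placeholders:

```python
from dataclasses import dataclass
from typing import Union

# Hypothetical, simplified stand-in for
# vllm.multimodal.processing.PromptUpdateDetails on the restored code path.
@dataclass
class PromptUpdateDetailsSketch:
    full: Union[str, list[int]]      # the full replacement text or token ids
    features: Union[str, list[int]]  # the part that receives mm embeddings

image_tokens = [32000] * 16  # made-up image placeholder token id
bos_token_id = 1
detail = PromptUpdateDetailsSketch(
    full=image_tokens + [bos_token_id],
    features=image_tokens,
)
```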
@@ -40,7 +40,8 @@ from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         BaseProcessingInfo, BoundPromptUpdate,
                                         PlaceholderFeaturesInfo,
-                                        PromptReplacement, PromptUpdate)
+                                        PromptReplacement, PromptUpdate,
+                                        PromptUpdateDetails)
 # yapf: enable
 from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
 from vllm.sequence import IntermediateTensors
@@ -442,7 +443,12 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]):
                 processor=hf_processor,
             )

-            return [_IMAGE_TOKEN_ID] * num_image_tokens
+            image_tokens = [_IMAGE_TOKEN_ID] * num_image_tokens
+
+            return PromptUpdateDetails(
+                full=image_tokens,
+                features=image_tokens,
+            )

         num_images = mm_items.get_count("image", strict=False)

@@ -511,7 +517,6 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]):
                     item_idx=p.item_idx,
                     start_idx=p.start_idx - 1,
                     tokens=p.tokens,
-                    is_embed=p.is_embed,
                 ) for p in ps
             ]
             for modality, ps in placeholders.items()
@@ -37,7 +37,7 @@ from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
                                    MultiModalDataItems)
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         BaseProcessingInfo, PromptReplacement,
-                                        PromptUpdate, PromptUpdateDetails)
+                                        PromptUpdate)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.tokenizer import (MistralTokenizer,
@@ -46,7 +46,8 @@ from vllm.transformers_utils.tokenizer import (MistralTokenizer,
 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
 from .utils import (flatten_bn, init_vllm_registered_model, maybe_prefix,
                     merge_multimodal_embeddings)
-from .vision import VisionEncoderInfo, resolve_visual_encoder_outputs
+from .vision import (VisionEncoderInfo, resolve_visual_encoder_outputs,
+                     scatter_patch_features, select_patch_features)
 
 try:
     from xformers import ops as xops
@@ -67,6 +68,14 @@ class PixtralImagePixelInputs(TypedDict):
     The result of stacking :attr:`ImageEncoding.tokens` from each prompt.
     """
 
+    embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
+    """
+    A boolean mask indicating which image embeddings correspond
+    to patch tokens.
+
+    Shape: `(batch_size * num_images, num_embeds)`
+    """
+
 
 class PixtralProcessorAdapter:
     """
@ -135,8 +144,11 @@ class PixtralProcessorAdapter:
|
|||||||
"For more info, see: "
|
"For more info, see: "
|
||||||
"https://github.com/vllm-project/vllm/issues/8411.")
|
"https://github.com/vllm-project/vllm/issues/8411.")
|
||||||
|
|
||||||
|
image_token_id = self.image_token_id
|
||||||
|
|
||||||
images_processed = list[torch.Tensor]()
|
images_processed = list[torch.Tensor]()
|
||||||
images_tokens = list[torch.Tensor]()
|
images_tokens = list[torch.Tensor]()
|
||||||
|
images_embed_is_patch = list[torch.Tensor]()
|
||||||
|
|
||||||
for image in images:
|
for image in images:
|
||||||
image_inputs = self.image_processor(ImageChunk(image=image))
|
image_inputs = self.image_processor(ImageChunk(image=image))
|
||||||
@ -145,10 +157,12 @@ class PixtralProcessorAdapter:
|
|||||||
|
|
||||||
images_processed.append(image_processed)
|
images_processed.append(image_processed)
|
||||||
images_tokens.append(image_tokens)
|
images_tokens.append(image_tokens)
|
||||||
|
images_embed_is_patch.append(image_tokens == image_token_id)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"input_ids": torch.cat(images_tokens)[None].expand(len(text), -1),
|
"input_ids": torch.cat(images_tokens)[None].expand(len(text), -1),
|
||||||
"images": images_processed,
|
"images": images_processed,
|
||||||
|
"embed_is_patch": images_embed_is_patch,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -199,7 +213,7 @@ class PixtralProcessingInfo(BaseProcessingInfo):
         ncols, nrows = processor.image_processor._image_to_num_tokens(
             Image.new("RGB", (image_width, image_height)))
 
-        return ncols * nrows
+        return (ncols + 1) * nrows
 
     def get_image_size_with_most_features(self) -> ImageSize:
         image_processor = self.get_hf_processor().image_processor
@@ -249,7 +263,10 @@ class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo]
         hf_inputs: Mapping[str, NestedTensors],
         hf_processor_mm_kwargs: Mapping[str, object],
     ) -> Mapping[str, MultiModalFieldConfig]:
-        return dict(images=MultiModalFieldConfig.batched("image"))
+        return dict(
+            images=MultiModalFieldConfig.batched("image"),
+            embed_is_patch=MultiModalFieldConfig.batched("image"),
+        )
 
     def _get_prompt_updates(
         self,
@@ -273,7 +290,7 @@ class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo]
             tokens = ([image_token_id] * ncols + [image_break_id]) * nrows
             tokens[-1] = image_end_id
 
-            return PromptUpdateDetails.select_token_id(tokens, image_token_id)
+            return tokens
 
         return [
             PromptReplacement(
@ -364,9 +381,17 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
|
|||||||
raise ValueError("Incorrect type of images. "
|
raise ValueError("Incorrect type of images. "
|
||||||
f"Got type: {type(images)}")
|
f"Got type: {type(images)}")
|
||||||
|
|
||||||
|
embed_is_patch = kwargs.pop("embed_is_patch")
|
||||||
|
if not isinstance(embed_is_patch, (torch.Tensor, list)):
|
||||||
|
raise ValueError("Incorrect type of embed_is_patch. "
|
||||||
|
f"Got type: {type(embed_is_patch)}")
|
||||||
|
|
||||||
|
embed_is_patch = flatten_bn(embed_is_patch)
|
||||||
|
|
||||||
return PixtralImagePixelInputs(
|
return PixtralImagePixelInputs(
|
||||||
type="pixel_values",
|
type="pixel_values",
|
||||||
images=flatten_bn(images),
|
images=flatten_bn(images),
|
||||||
|
embed_is_patch=embed_is_patch,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _process_image_input(
|
def _process_image_input(
|
||||||
@ -402,7 +427,12 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
|
|||||||
if image_input is None:
|
if image_input is None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
return self._process_image_input(image_input)
|
image_features = self._process_image_input(image_input)
|
||||||
|
|
||||||
|
return scatter_patch_features(
|
||||||
|
image_features,
|
||||||
|
image_input["embed_is_patch"],
|
||||||
|
)
|
||||||
|
|
||||||
def get_input_embeddings(
|
def get_input_embeddings(
|
||||||
self,
|
self,
|
||||||
@ -414,7 +444,7 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
|
|||||||
inputs_embeds = merge_multimodal_embeddings(
|
inputs_embeds = merge_multimodal_embeddings(
|
||||||
input_ids,
|
input_ids,
|
||||||
inputs_embeds,
|
inputs_embeds,
|
||||||
multimodal_embeddings,
|
select_patch_features(multimodal_embeddings),
|
||||||
self.vision_args.image_token_id,
|
self.vision_args.image_token_id,
|
||||||
)
|
)
|
||||||
return inputs_embeds
|
return inputs_embeds
|
||||||
@ -933,7 +963,9 @@ class PixtralHFEncoderInfo(VisionEncoderInfo[PixtralVisionConfig]):
|
|||||||
image_width=image_width,
|
image_width=image_width,
|
||||||
image_height=image_height,
|
image_height=image_height,
|
||||||
)
|
)
|
||||||
return ncols * nrows
|
|
||||||
|
# Consider the image_break_token
|
||||||
|
return (ncols + 1) * nrows
|
||||||
|
|
||||||
def get_max_image_tokens(self) -> int:
|
def get_max_image_tokens(self) -> int:
|
||||||
image_size = self.get_image_size()
|
image_size = self.get_image_size()
|
||||||
|
@@ -229,9 +229,9 @@ class Qwen2AudioMultiModalProcessor(
 
             audio_tokens = [audio_token_id] * num_features
 
-            return PromptUpdateDetails.select_token_id(
-                [audio_bos_id] + audio_tokens + [audio_eos_id],
-                embed_token_id=audio_token_id,
+            return PromptUpdateDetails(
+                full=[audio_bos_id] + audio_tokens + [audio_eos_id],
+                features=audio_tokens,
             )
 
         return [
@@ -647,9 +647,9 @@ class QwenVLMultiModalProcessor(BaseMultiModalProcessor[QwenVLProcessingInfo]):
             PromptReplacement(
                 modality="image",
                 target=[img_start_id, img_end_id],
-                replacement=PromptUpdateDetails.select_token_id(
-                    [img_start_id] + image_tokens + [img_end_id],
-                    embed_token_id=img_pad_id,
+                replacement=PromptUpdateDetails(
+                    full=[img_start_id] + image_tokens + [img_end_id],
+                    features=image_tokens,
                 ),
             )
         ]
@ -40,6 +40,7 @@ from vllm.transformers_utils.tokenizer import AnyTokenizer
|
|||||||
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
|
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
|
||||||
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
|
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
|
||||||
maybe_prefix, merge_multimodal_embeddings)
|
maybe_prefix, merge_multimodal_embeddings)
|
||||||
|
from .vision import scatter_patch_features, select_patch_features
|
||||||
|
|
||||||
IMG_START = '<img>'
|
IMG_START = '<img>'
|
||||||
IMG_END = '</img>'
|
IMG_END = '</img>'
|
||||||
@ -60,6 +61,14 @@ class SkyworkR1VImagePixelInputs(TypedDict):
|
|||||||
num_patches: torch.Tensor
|
num_patches: torch.Tensor
|
||||||
"""Shape: `(batch_size * num_images)`"""
|
"""Shape: `(batch_size * num_images)`"""
|
||||||
|
|
||||||
|
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
|
||||||
|
"""
|
||||||
|
A boolean mask indicating which image embeddings correspond
|
||||||
|
to patch tokens.
|
||||||
|
|
||||||
|
Shape: `(batch_size * num_images, num_embeds)`
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
class SkyworkR1VImageEmbeddingInputs(TypedDict):
|
class SkyworkR1VImageEmbeddingInputs(TypedDict):
|
||||||
type: Literal["image_embeds"]
|
type: Literal["image_embeds"]
|
||||||
@ -410,13 +419,24 @@ class BaseSkyworkR1VProcessor(ABC):
|
|||||||
torch.tensor([len(item) for item in pixel_values_lst]),
|
torch.tensor([len(item) for item in pixel_values_lst]),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
tokenizer = self.tokenizer
|
||||||
|
image_token_id = self.image_token_id
|
||||||
|
|
||||||
|
embed_is_patch = list[torch.Tensor]()
|
||||||
|
|
||||||
for pixel_values in pixel_values_lst:
|
for pixel_values in pixel_values_lst:
|
||||||
num_patches = pixel_values.shape[0]
|
num_patches = pixel_values.shape[0]
|
||||||
feature_size = num_patches * self.num_image_token
|
feature_size = num_patches * self.num_image_token
|
||||||
|
|
||||||
image_repl = self.get_image_repl(feature_size, num_patches)
|
image_repl = self.get_image_repl(feature_size, num_patches)
|
||||||
|
feature_tokens = tokenizer.encode(image_repl.features,
|
||||||
|
add_special_tokens=False)
|
||||||
|
|
||||||
text = [t.replace('<image>', image_repl.full, 1) for t in text]
|
text = [t.replace('<image>', image_repl.full, 1) for t in text]
|
||||||
|
embed_is_patch.append(
|
||||||
|
torch.tensor(feature_tokens) == image_token_id)
|
||||||
|
|
||||||
|
image_inputs["embed_is_patch"] = embed_is_patch
|
||||||
|
|
||||||
text_inputs = self.tokenizer(text)
|
text_inputs = self.tokenizer(text)
|
||||||
|
|
||||||
@ -440,7 +460,7 @@ class SkyworkR1VProcessor(BaseSkyworkR1VProcessor):
|
|||||||
repl_features = IMG_CONTEXT * feature_size
|
repl_features = IMG_CONTEXT * feature_size
|
||||||
repl_full = IMG_START + repl_features + IMG_END
|
repl_full = IMG_START + repl_features + IMG_END
|
||||||
|
|
||||||
return PromptUpdateDetails.select_text(repl_full, IMG_CONTEXT)
|
return PromptUpdateDetails(full=repl_full, features=repl_features)
|
||||||
|
|
||||||
|
|
||||||
class BaseSkyworkR1VProcessingInfo(BaseProcessingInfo):
|
class BaseSkyworkR1VProcessingInfo(BaseProcessingInfo):
|
||||||
@ -579,6 +599,7 @@ class SkyworkR1VMultiModalProcessor(BaseMultiModalProcessor[_I]):
|
|||||||
pixel_values_flat=MultiModalFieldConfig.flat_from_sizes(
|
pixel_values_flat=MultiModalFieldConfig.flat_from_sizes(
|
||||||
"image", image_num_patches),
|
"image", image_num_patches),
|
||||||
image_num_patches=MultiModalFieldConfig.batched("image"),
|
image_num_patches=MultiModalFieldConfig.batched("image"),
|
||||||
|
embed_is_patch=MultiModalFieldConfig.batched("image"),
|
||||||
image_embeds=MultiModalFieldConfig.batched("image"),
|
image_embeds=MultiModalFieldConfig.batched("image"),
|
||||||
image_token_id=MultiModalFieldConfig.shared("image", num_images),
|
image_token_id=MultiModalFieldConfig.shared("image", num_images),
|
||||||
)
|
)
|
||||||
@ -814,6 +835,7 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
|
|||||||
self, **kwargs: object) -> Optional[SkyworkR1VImageInputs]:
|
self, **kwargs: object) -> Optional[SkyworkR1VImageInputs]:
|
||||||
pixel_values_flat = kwargs.pop("pixel_values_flat", None)
|
pixel_values_flat = kwargs.pop("pixel_values_flat", None)
|
||||||
image_num_patches = kwargs.pop("image_num_patches", None)
|
image_num_patches = kwargs.pop("image_num_patches", None)
|
||||||
|
embed_is_patch = kwargs.pop("embed_is_patch", None)
|
||||||
image_embeds = kwargs.pop("image_embeds", None)
|
image_embeds = kwargs.pop("image_embeds", None)
|
||||||
|
|
||||||
if pixel_values_flat is None and image_embeds is None:
|
if pixel_values_flat is None and image_embeds is None:
|
||||||
@ -842,14 +864,20 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
|
|||||||
raise ValueError("Incorrect type of image_num_patches. "
|
raise ValueError("Incorrect type of image_num_patches. "
|
||||||
f"Got type: {type(image_num_patches)}")
|
f"Got type: {type(image_num_patches)}")
|
||||||
|
|
||||||
|
if not isinstance(embed_is_patch, (torch.Tensor, list)):
|
||||||
|
raise ValueError("Incorrect type of embed_is_patch. "
|
||||||
|
f"Got type: {type(embed_is_patch)}")
|
||||||
|
|
||||||
pixel_values_flat = flatten_bn(pixel_values_flat, concat=True)
|
pixel_values_flat = flatten_bn(pixel_values_flat, concat=True)
|
||||||
image_num_patches = flatten_bn(image_num_patches, concat=True)
|
image_num_patches = flatten_bn(image_num_patches, concat=True)
|
||||||
|
embed_is_patch = flatten_bn(embed_is_patch)
|
||||||
|
|
||||||
return SkyworkR1VImagePixelInputs(
|
return SkyworkR1VImagePixelInputs(
|
||||||
type="pixel_values",
|
type="pixel_values",
|
||||||
pixel_values_flat=self._validate_pixel_values(
|
pixel_values_flat=self._validate_pixel_values(
|
||||||
pixel_values_flat),
|
pixel_values_flat),
|
||||||
num_patches=image_num_patches,
|
num_patches=image_num_patches,
|
||||||
|
embed_is_patch=embed_is_patch,
|
||||||
)
|
)
|
||||||
|
|
||||||
raise AssertionError("This line should be unreachable.")
|
raise AssertionError("This line should be unreachable.")
|
||||||
@ -895,7 +923,15 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
|
|||||||
if image_input is None:
|
if image_input is None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
return self._process_image_input(image_input)
|
image_features = self._process_image_input(image_input)
|
||||||
|
|
||||||
|
if image_input["type"] != "pixel_values":
|
||||||
|
return image_features
|
||||||
|
|
||||||
|
return scatter_patch_features(
|
||||||
|
image_features,
|
||||||
|
image_input["embed_is_patch"],
|
||||||
|
)
|
||||||
|
|
||||||
def get_input_embeddings(
|
def get_input_embeddings(
|
||||||
self,
|
self,
|
||||||
@ -909,7 +945,7 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
|
|||||||
inputs_embeds = merge_multimodal_embeddings(
|
inputs_embeds = merge_multimodal_embeddings(
|
||||||
input_ids,
|
input_ids,
|
||||||
inputs_embeds,
|
inputs_embeds,
|
||||||
multimodal_embeddings,
|
select_patch_features(multimodal_embeddings),
|
||||||
self.img_context_token_id,
|
self.img_context_token_id,
|
||||||
)
|
)
|
||||||
return inputs_embeds
|
return inputs_embeds
|
||||||
@@ -1,7 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0
 
 from abc import ABC, abstractmethod
-from typing import Final, Generic, Optional, Protocol, TypeVar, Union
+from collections.abc import Sequence
+from typing import Final, Generic, Optional, Protocol, TypeVar, Union, cast
 
 import torch
 from transformers import PretrainedConfig
@@ -9,9 +10,12 @@ from transformers import PretrainedConfig
 import vllm.envs as envs
 from vllm.attention.selector import (backend_name_to_enum,
                                      get_global_forced_attn_backend)
+from vllm.jsontree import JSONTree, json_map_leaves
 from vllm.logger import init_logger
 from vllm.platforms import _Backend, current_platform
 
+from .interfaces import MultiModalEmbeddings
+
 logger = init_logger(__name__)
 
 _C = TypeVar("_C", bound=PretrainedConfig)
@ -151,3 +155,74 @@ def resolve_visual_encoder_outputs(
|
|||||||
if post_layer_norm is not None and uses_last_layer:
|
if post_layer_norm is not None and uses_last_layer:
|
||||||
hs_pool[-1] = post_layer_norm(encoder_outputs)
|
hs_pool[-1] = post_layer_norm(encoder_outputs)
|
||||||
return torch.cat(hs_pool, dim=-1)
|
return torch.cat(hs_pool, dim=-1)
|
||||||
|
|
||||||
|
|
||||||
|
def scatter_patch_features(
|
||||||
|
patches: Union[torch.Tensor, Sequence[torch.Tensor]],
|
||||||
|
embed_is_patch: Union[torch.Tensor, Sequence[torch.Tensor]],
|
||||||
|
) -> tuple[torch.Tensor, ...]:
|
||||||
|
"""
|
||||||
|
Scatter the patch features into a contiguous tensor that corresponds
|
||||||
|
to the embedding tokens defined by the multimodal processor.
|
||||||
|
|
||||||
|
The rest of the values in the tensor are set to NaN so that they
|
||||||
|
can be filtered out by :func`select_patch_features`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
patches: The patch features for each image.
|
||||||
|
Shape: `(num_images, <patch_dims>, feature_depth)`
|
||||||
|
embed_is_patch: A boolean mask indicating which image embeddings
|
||||||
|
correspond to patch tokens for each image.
|
||||||
|
Shape: `(num_images, num_embeds)`
|
||||||
|
|
||||||
|
Note:
|
||||||
|
The original code only considers patch tokens as feature
|
||||||
|
tokens, but our processor considers all image-related tokens
|
||||||
|
as feature tokens because the feature tokens need to be
|
||||||
|
consecutive in `input_ids`.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
A simplified example for one image:
|
||||||
|
|
||||||
|
.. code-block::
|
||||||
|
|
||||||
|
Embedding tokens (from HF processor):
|
||||||
|
[<start> <patch> <patch> <col> <patch> <patch> <col> <end> ]
|
||||||
|
|
||||||
|
embed_is_patch (from HF processor):
|
||||||
|
[ False True True False True True False False ]
|
||||||
|
|
||||||
|
Encoder outputs (from model):
|
||||||
|
[ p1 p2 p3 p4 ]
|
||||||
|
|
||||||
|
The resulting embedding tensor is:
|
||||||
|
[ nan p1 p2 nan p3 p4 nan nan ]
|
||||||
|
"""
|
||||||
|
if len(patches) != len(embed_is_patch):
|
||||||
|
raise ValueError(f"Inconsistent num_images: {len(patches)=} vs. "
|
||||||
|
f"{len(embed_is_patch)=}")
|
||||||
|
|
||||||
|
def get_embed_one(patches_one: torch.Tensor, e_is_patch: torch.Tensor):
|
||||||
|
embed_one = patches_one.new_full(
|
||||||
|
(e_is_patch.shape[0], patches_one.shape[-1]),
|
||||||
|
fill_value=torch.nan,
|
||||||
|
)
|
||||||
|
embed_one[e_is_patch] = patches_one
|
||||||
|
return embed_one
|
||||||
|
|
||||||
|
return tuple(
|
||||||
|
get_embed_one(patches_one, e_is_patch)
|
||||||
|
for patches_one, e_is_patch in zip(patches, embed_is_patch))
|
||||||
|
|
||||||
|
|
||||||
|
def select_patch_features(
|
||||||
|
multimodal_embeddings: MultiModalEmbeddings) -> MultiModalEmbeddings:
|
||||||
|
"""
|
||||||
|
Given the outputs of :func:`scatter_patch_features`, return only
|
||||||
|
the values that correspond to patch features.
|
||||||
|
"""
|
||||||
|
selected_features = json_map_leaves(
|
||||||
|
lambda x: x[~x.isnan()].view(-1, *x.shape[1:]),
|
||||||
|
cast(JSONTree[torch.Tensor], multimodal_embeddings),
|
||||||
|
)
|
||||||
|
return cast(MultiModalEmbeddings, selected_features)
|
||||||
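
A minimal standalone sketch (illustration only, not part of the commit) of the scatter-then-select round trip these restored helpers perform, using made-up toy tensors that mirror the docstring example above:

```python
import torch

# One image: 8 embedding tokens, of which 4 are real patch tokens.
embed_is_patch = torch.tensor(
    [False, True, True, False, True, True, False, False])
patches = torch.arange(4 * 3, dtype=torch.float32).reshape(4, 3)

# Scatter: place patch features at patch positions, NaN everywhere else.
scattered = patches.new_full(
    (embed_is_patch.shape[0], patches.shape[-1]), fill_value=torch.nan)
scattered[embed_is_patch] = patches

# Select: drop the NaN rows to recover the original patch features.
selected = scattered[~scattered.isnan()].view(-1, scattered.shape[-1])
assert torch.equal(selected, patches)
```
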
@@ -385,8 +385,8 @@ class MultiModalPlaceholderMap:
         for placeholder_dict, mm_item in zip(multi_modal_placeholders,
                                              multi_modal_items):
             placeholder = range(
-                placeholder_dict.offset,
-                placeholder_dict.offset + placeholder_dict.length,
+                placeholder_dict["offset"],
+                placeholder_dict["offset"] + placeholder_dict["length"],
             )
             intersection = range(
                 max(positions.start, placeholder.start),
@@ -109,8 +109,7 @@ The built-in modalities are defined by :class:`MultiModalDataBuiltins`.
 """
 
 
-@dataclass(frozen=True)
-class PlaceholderRange:
+class PlaceholderRange(TypedDict):
     """
     Placeholder location information for multi-modal data.
 
@@ -122,8 +121,8 @@ class PlaceholderRange:
 
     .. code-block::
 
-        A: PlaceholderRange(offset=0, length=4)
-        B: PlaceholderRange(offset=5, length=4)
+        A: { "offset": 0, "length": 4 }
+        B: { "offset": 5, "length": 4 }
     """
 
     offset: int
@@ -132,31 +131,6 @@ class PlaceholderRange:
     length: int
     """The length of the placeholder."""
 
-    is_embed: Optional[torch.Tensor] = None
-    """
-    A boolean mask of shape `(length,)` indicating which positions
-    between `offset` and `offset + length` to assign embeddings to.
-    """
-
-    def get_num_embeds(self) -> int:
-        if self.is_embed is None:
-            return self.length
-
-        return int(self.is_embed.sum().item())
-
-    def __eq__(self, other: object) -> bool:
-        if not isinstance(other, self.__class__):
-            return False
-        if not (self.offset, self.length) == (other.offset, other.length):
-            return False
-
-        if self.is_embed is None:
-            return other.is_embed is None
-        if other.is_embed is None:
-            return self.is_embed is None
-
-        return nested_tensors_equal(self.is_embed, other.is_embed)
-
 
 NestedTensors = Union[list["NestedTensors"], list[torch.Tensor], torch.Tensor,
                       tuple[torch.Tensor, ...]]
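
Since `PlaceholderRange` goes back to being a plain `TypedDict` here, downstream code indexes it with string keys, as in the profiler and scheduler changes later in this commit. A small hedged sketch of that usage (toy values, not taken from the diff):

```python
from typing import TypedDict


class PlaceholderRange(TypedDict):  # mirrors the reverted definition above
    offset: int
    length: int


# Two image placeholders in one prompt, as in the docstring example above.
placeholders = [
    PlaceholderRange(offset=0, length=4),
    PlaceholderRange(offset=5, length=4),
]

total_mm_tokens = sum(p["length"] for p in placeholders)                 # 8
prompt_span_end = max(p["offset"] + p["length"] for p in placeholders)   # 9
```
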
@@ -108,46 +108,16 @@ class PromptUpdateDetails(Generic[_S]):
     full: _S
     """The full content."""
 
-    is_embed: Optional[Callable[["_BoundPromptSequence"], torch.Tensor]] = None
+    features: _S
     """
-    Given :attr:`full`, return a boolean mask of shape `(len(full),)`
-    indicating which positions of `full` to assign embeddings to.
-
-    `None` (default) means to assign embeddings to all positions of `full`.
-
-    The embeddings are obtained by calling
-    :class:`SupportsMultiModal.get_multimodal_embeddings`.
+    The part of the content that corresponds to feature placeholders;
+    this will be replaced by the output of the vision encoder during model
+    inference.
     """
 
     @staticmethod
     def from_seq(seq: _S) -> "PromptUpdateDetails[_S]":
-        return PromptUpdateDetails(full=seq)
-
-    @staticmethod
-    def select_text(
-        seq: _S,
-        embed_text: str,
-    ) -> "PromptUpdateDetails[_S]":
-
-        def is_embed(full: "_BoundPromptSequence") -> torch.Tensor:
-            embed_token_ids = encode_tokens(full.tokenizer, embed_text)
-
-            return torch.isin(
-                torch.tensor(full.token_ids),
-                torch.tensor(embed_token_ids),
-            )
-
-        return PromptUpdateDetails(full=seq, is_embed=is_embed)
-
-    @staticmethod
-    def select_token_id(
-        seq: _S,
-        embed_token_id: int,
-    ) -> "PromptUpdateDetails[_S]":
-        return PromptUpdateDetails(
-            full=seq,
-            is_embed=lambda f: torch.tensor(f.token_ids) == embed_token_id,
-        )
+        return PromptUpdateDetails(full=seq, features=seq)
 
 
 PromptUpdateInfo = Union[PromptSeq, PromptUpdateDetails]
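
To make the restored `full`/`features` contract concrete, here is a hedged sketch with made-up token ids in the Fuyu-style layout used earlier in this commit; the ids are illustrative only, not taken from a real tokenizer:

```python
_IMAGE_TOKEN_ID = 71011    # illustrative placeholder token id
_NEWLINE_TOKEN_ID = 71019  # illustrative newline token id
bos_token_id = 1

ncols, nrows = 2, 2
image_tokens = ([_IMAGE_TOKEN_ID] * ncols + [_NEWLINE_TOKEN_ID]) * nrows

full = image_tokens + [bos_token_id]   # the whole replacement sequence
features = image_tokens                # the part filled by encoder outputs

# The `_iter_placeholders` change later in this commit locates `features`
# inside `full`, so the feature tokens must form a consecutive subsequence.
assert features == full[:len(features)]
```
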
@@ -436,7 +406,7 @@ class _BoundPromptSequence:
 @dataclass
 class _BoundPromptContent:
     full: _BoundPromptSequence
-    is_embed: Optional[Callable[["_BoundPromptSequence"], torch.Tensor]]
+    features: _BoundPromptSequence
 
 
 @dataclass
@@ -496,8 +466,10 @@ class BoundPromptUpdate:
 
         bound_full = _BoundPromptSequence.from_seq(self.tokenizer,
                                                    content.full)
+        bound_features = _BoundPromptSequence.from_seq(self.tokenizer,
+                                                       content.features)
         bound_content = _BoundPromptContent(full=bound_full,
-                                            is_embed=content.is_embed)
+                                            features=bound_features)
 
         if cache_key is not None:
             self._content_cache[cache_key] = bound_content
@@ -633,19 +605,15 @@ class PlaceholderFeaturesInfo:
     item_idx: int
     start_idx: int
     tokens: list[int]
-    is_embed: Optional[torch.Tensor]
 
     @property
     def length(self) -> int:
         return len(self.tokens)
 
     def to_range(self) -> PlaceholderRange:
-        # TODO: Is it worth it to optimize this by stripping the
-        # leading and ending positions where `is_embed=False`?
         return PlaceholderRange(
             offset=self.start_idx,
             length=self.length,
-            is_embed=self.is_embed,
         )
@@ -838,17 +806,22 @@ def _iter_placeholders(
                     continue
 
                 if prompt[start_idx:end_idx_full] == content_tokens_full:
-                    content_is_embed = content.is_embed
-                    if content_is_embed is not None:
-                        content_is_embed = content_is_embed(content.full)
-
-                    yield PlaceholderFeaturesInfo(
-                        modality=modality,
-                        item_idx=item_idx,
-                        start_idx=start_idx,
-                        tokens=content_tokens_full,
-                        is_embed=content_is_embed,
-                    )
+                    content_tokens_feat = content.features.token_ids
+
+                    try:
+                        match = next(
+                            iter_token_matches(content_tokens_full,
+                                               content_tokens_feat))
+                        yield PlaceholderFeaturesInfo(
+                            modality=modality,
+                            item_idx=item_idx,
+                            start_idx=start_idx + match.start_idx,
+                            tokens=content_tokens_feat,
+                        )
+                    except StopIteration:
+                        raise AssertionError(
+                            f"{content_tokens_feat=} should be a "
+                            f"subsequence of {content_tokens_full=}") from None
 
                     # Exclude overlapping matches
                     start_idx = end_idx_full
@@ -180,7 +180,7 @@ class MultiModalProfiler(Generic[_I]):
         placeholders_by_modality = mm_inputs["mm_placeholders"]
 
         total_placeholders_by_modality = {
-            modality: sum(item.get_num_embeds() for item in placeholders)
+            modality: sum(item["length"] for item in placeholders)
             for modality, placeholders in placeholders_by_modality.items()
         }
         expected_placeholders_by_modality = {
@@ -340,7 +340,7 @@ def merge_and_sort_multimodal_metadata(
         all_items.append((modality, placeholder, hash_value))
 
     # Sort all items by offset
-    all_items.sort(key=lambda x: x[1].offset)
+    all_items.sort(key=lambda x: x[1]['offset'])
 
     # Split into separate lists
     sorted_modalities = [item[0] for item in all_items]
@ -310,7 +310,8 @@ def _gen_mm_extra_hash_keys(request: Request, start_token_idx: int,
|
|||||||
# Note that we assume mm_positions is sorted by offset.
|
# Note that we assume mm_positions is sorted by offset.
|
||||||
# We do not need to check all mm inputs if the start token index is out of
|
# We do not need to check all mm inputs if the start token index is out of
|
||||||
# range. This usually happens in the late prefill phase and decoding phase.
|
# range. This usually happens in the late prefill phase and decoding phase.
|
||||||
if mm_positions[-1].offset + mm_positions[-1].length < start_token_idx:
|
if mm_positions[-1]["offset"] + mm_positions[-1][
|
||||||
|
"length"] < start_token_idx:
|
||||||
return extra_keys, start_mm_idx
|
return extra_keys, start_mm_idx
|
||||||
|
|
||||||
# Support start_mm_idx == -1 to indicate the last mm input.
|
# Support start_mm_idx == -1 to indicate the last mm input.
|
||||||
@ -321,8 +322,8 @@ def _gen_mm_extra_hash_keys(request: Request, start_token_idx: int,
|
|||||||
curr_mm_idx = start_mm_idx
|
curr_mm_idx = start_mm_idx
|
||||||
while mm_positions and curr_mm_idx < len(mm_positions):
|
while mm_positions and curr_mm_idx < len(mm_positions):
|
||||||
assert mm_hashes[curr_mm_idx] is not None
|
assert mm_hashes[curr_mm_idx] is not None
|
||||||
offset = mm_positions[curr_mm_idx].offset
|
offset = mm_positions[curr_mm_idx]["offset"]
|
||||||
length = mm_positions[curr_mm_idx].length
|
length = mm_positions[curr_mm_idx]["length"]
|
||||||
if end_token_idx > offset:
|
if end_token_idx > offset:
|
||||||
if start_token_idx > offset + length:
|
if start_token_idx > offset + length:
|
||||||
# This block has passed the current mm input.
|
# This block has passed the current mm input.
|
||||||
|
@ -505,8 +505,8 @@ class Scheduler(SchedulerInterface):
|
|||||||
assert mm_positions is not None
|
assert mm_positions is not None
|
||||||
assert len(mm_positions) > 0
|
assert len(mm_positions) > 0
|
||||||
for i, pos_info in enumerate(mm_positions):
|
for i, pos_info in enumerate(mm_positions):
|
||||||
start_pos = pos_info.offset
|
start_pos = pos_info["offset"]
|
||||||
num_encoder_tokens = pos_info.length
|
num_encoder_tokens = pos_info["length"]
|
||||||
|
|
||||||
# The encoder output is needed if the two ranges overlap:
|
# The encoder output is needed if the two ranges overlap:
|
||||||
# [num_computed_tokens, num_computed_tokens + num_new_tokens) and
|
# [num_computed_tokens, num_computed_tokens + num_new_tokens) and
|
||||||
@ -596,8 +596,8 @@ class Scheduler(SchedulerInterface):
|
|||||||
if cached_encoder_input_ids:
|
if cached_encoder_input_ids:
|
||||||
for input_id in list(cached_encoder_input_ids):
|
for input_id in list(cached_encoder_input_ids):
|
||||||
mm_positions = request.mm_positions[input_id]
|
mm_positions = request.mm_positions[input_id]
|
||||||
start_pos = mm_positions.offset
|
start_pos = mm_positions["offset"]
|
||||||
num_tokens = mm_positions.length
|
num_tokens = mm_positions["length"]
|
||||||
if start_pos + num_tokens <= request.num_computed_tokens:
|
if start_pos + num_tokens <= request.num_computed_tokens:
|
||||||
# The encoder output is already processed and stored
|
# The encoder output is already processed and stored
|
||||||
# in the decoder's KV cache.
|
# in the decoder's KV cache.
|
||||||
|
@ -121,7 +121,7 @@ class Request:
|
|||||||
|
|
||||||
def get_num_encoder_tokens(self, input_id: int) -> int:
|
def get_num_encoder_tokens(self, input_id: int) -> int:
|
||||||
assert input_id < len(self.mm_positions)
|
assert input_id < len(self.mm_positions)
|
||||||
num_tokens = self.mm_positions[input_id].length
|
num_tokens = self.mm_positions[input_id]["length"]
|
||||||
return num_tokens
|
return num_tokens
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -19,8 +19,7 @@ from vllm.logger import init_logger
|
|||||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||||
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
|
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
|
||||||
from vllm.model_executor.model_loader import get_model
|
from vllm.model_executor.model_loader import get_model
|
||||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
|
||||||
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
|
|
||||||
from vllm.multimodal.utils import group_mm_inputs_by_modality
|
from vllm.multimodal.utils import group_mm_inputs_by_modality
|
||||||
from vllm.sampling_params import SamplingType
|
from vllm.sampling_params import SamplingType
|
||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
@ -44,8 +43,7 @@ from vllm.v1.utils import bind_kv_cache
|
|||||||
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
|
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
|
||||||
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
|
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
|
||||||
|
|
||||||
from .utils import (gather_mm_placeholders, sanity_check_mm_encoder_outputs,
|
from .utils import sanity_check_mm_encoder_outputs
|
||||||
scatter_mm_placeholders)
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
import xgrammar as xgr
|
import xgrammar as xgr
|
||||||
@ -831,22 +829,19 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
)
|
)
|
||||||
return metadata
|
return metadata
|
||||||
|
|
||||||
def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"):
|
def _execute_encoder(self, scheduler_output: "SchedulerOutput"):
|
||||||
scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs
|
scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs
|
||||||
if not scheduled_encoder_inputs:
|
if not scheduled_encoder_inputs:
|
||||||
return
|
return
|
||||||
|
|
||||||
# Batch the multi-modal inputs.
|
# Batch the multi-modal inputs.
|
||||||
mm_inputs = list[MultiModalKwargs]()
|
mm_inputs: list[MultiModalKwargs] = []
|
||||||
req_ids_pos = list[tuple[str, int, PlaceholderRange]]()
|
req_input_ids: list[tuple[str, int]] = []
|
||||||
for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
|
for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
|
||||||
req_state = self.requests[req_id]
|
req_state = self.requests[req_id]
|
||||||
for input_id, pos_info in zip(
|
for input_id in encoder_input_ids:
|
||||||
encoder_input_ids,
|
|
||||||
req_state.mm_positions,
|
|
||||||
):
|
|
||||||
mm_inputs.append(req_state.mm_inputs[input_id])
|
mm_inputs.append(req_state.mm_inputs[input_id])
|
||||||
req_ids_pos.append((req_id, input_id, pos_info))
|
req_input_ids.append((req_id, input_id))
|
||||||
|
|
||||||
# Batch mm inputs as much as we can: if a request in the batch has
|
# Batch mm inputs as much as we can: if a request in the batch has
|
||||||
# multiple modalities or a different modality than the previous one,
|
# multiple modalities or a different modality than the previous one,
|
||||||
@ -882,23 +877,16 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
encoder_outputs.append(output)
|
encoder_outputs.append(output)
|
||||||
|
|
||||||
# Cache the encoder outputs.
|
# Cache the encoder outputs.
|
||||||
for (req_id, input_id, pos_info), output in zip(
|
for (req_id, input_id), output in zip(req_input_ids, encoder_outputs):
|
||||||
req_ids_pos,
|
|
||||||
encoder_outputs,
|
|
||||||
):
|
|
||||||
if req_id not in self.encoder_cache:
|
if req_id not in self.encoder_cache:
|
||||||
self.encoder_cache[req_id] = {}
|
self.encoder_cache[req_id] = {}
|
||||||
|
self.encoder_cache[req_id][input_id] = output
|
||||||
|
|
||||||
self.encoder_cache[req_id][input_id] = scatter_mm_placeholders(
|
def _gather_encoder_outputs(
|
||||||
output,
|
|
||||||
is_embed=pos_info.is_embed,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _gather_mm_embeddings(
|
|
||||||
self,
|
self,
|
||||||
scheduler_output: "SchedulerOutput",
|
scheduler_output: "SchedulerOutput",
|
||||||
) -> list[torch.Tensor]:
|
) -> list[torch.Tensor]:
|
||||||
mm_embeds: list[torch.Tensor] = []
|
encoder_outputs: list[torch.Tensor] = []
|
||||||
for req_id in self.input_batch.req_ids:
|
for req_id in self.input_batch.req_ids:
|
||||||
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
|
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
|
||||||
req_id]
|
req_id]
|
||||||
@ -906,8 +894,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
num_computed_tokens = req_state.num_computed_tokens
|
num_computed_tokens = req_state.num_computed_tokens
|
||||||
mm_positions = req_state.mm_positions
|
mm_positions = req_state.mm_positions
|
||||||
for i, pos_info in enumerate(mm_positions):
|
for i, pos_info in enumerate(mm_positions):
|
||||||
start_pos = pos_info.offset
|
start_pos = pos_info["offset"]
|
||||||
num_encoder_tokens = pos_info.length
|
num_encoder_tokens = pos_info["length"]
|
||||||
|
|
||||||
# The encoder output is needed if the two ranges overlap:
|
# The encoder output is needed if the two ranges overlap:
|
||||||
# [num_computed_tokens,
|
# [num_computed_tokens,
|
||||||
@ -929,16 +917,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
assert req_id in self.encoder_cache
|
assert req_id in self.encoder_cache
|
||||||
assert i in self.encoder_cache[req_id]
|
assert i in self.encoder_cache[req_id]
|
||||||
encoder_output = self.encoder_cache[req_id][i]
|
encoder_output = self.encoder_cache[req_id][i]
|
||||||
|
encoder_outputs.append(encoder_output[start_idx:end_idx])
|
||||||
if (is_embed := pos_info.is_embed) is not None:
|
return encoder_outputs
|
||||||
is_embed = is_embed[start_idx:end_idx]
|
|
||||||
|
|
||||||
mm_embeds_item = gather_mm_placeholders(
|
|
||||||
encoder_output[start_idx:end_idx],
|
|
||||||
is_embed=is_embed,
|
|
||||||
)
|
|
||||||
mm_embeds.append(mm_embeds_item)
|
|
||||||
return mm_embeds
|
|
||||||
|
|
||||||
def get_model(self) -> nn.Module:
|
def get_model(self) -> nn.Module:
|
||||||
return self.model
|
return self.model
|
||||||
@ -1003,10 +983,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
|
|
||||||
if self.is_multimodal_model:
|
if self.is_multimodal_model:
|
||||||
# Run the multimodal encoder if any.
|
# Run the multimodal encoder if any.
|
||||||
self._execute_mm_encoder(scheduler_output)
|
self._execute_encoder(scheduler_output)
|
||||||
mm_embeds = self._gather_mm_embeddings(scheduler_output)
|
encoder_outputs = self._gather_encoder_outputs(scheduler_output)
|
||||||
else:
|
else:
|
||||||
mm_embeds = []
|
encoder_outputs = []
|
||||||
|
|
||||||
# Prepare the decoder inputs.
|
# Prepare the decoder inputs.
|
||||||
attn_metadata, logits_indices, spec_decode_metadata = (
|
attn_metadata, logits_indices, spec_decode_metadata = (
|
||||||
@ -1028,9 +1008,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
# embeddings), we always use embeddings (rather than token ids)
|
# embeddings), we always use embeddings (rather than token ids)
|
||||||
# as input to the multimodal model, even when the input is text.
|
# as input to the multimodal model, even when the input is text.
|
||||||
input_ids = self.input_ids[:num_scheduled_tokens]
|
input_ids = self.input_ids[:num_scheduled_tokens]
|
||||||
if mm_embeds:
|
if encoder_outputs:
|
||||||
inputs_embeds = self.model.get_input_embeddings(
|
inputs_embeds = self.model.get_input_embeddings(
|
||||||
input_ids, mm_embeds)
|
input_ids, encoder_outputs)
|
||||||
else:
|
else:
|
||||||
inputs_embeds = self.model.get_input_embeddings(input_ids)
|
inputs_embeds = self.model.get_input_embeddings(input_ids)
|
||||||
# TODO(woosuk): Avoid the copy. Optimize.
|
# TODO(woosuk): Avoid the copy. Optimize.
|
||||||
|
@ -19,8 +19,7 @@ from vllm.config import VllmConfig
|
|||||||
from vllm.forward_context import set_forward_context
|
from vllm.forward_context import set_forward_context
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.model_loader import get_model
|
from vllm.model_executor.model_loader import get_model
|
||||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
|
||||||
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
|
|
||||||
from vllm.multimodal.utils import group_mm_inputs_by_modality
|
from vllm.multimodal.utils import group_mm_inputs_by_modality
|
||||||
from vllm.sampling_params import SamplingType
|
from vllm.sampling_params import SamplingType
|
||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
@ -37,8 +36,7 @@ from vllm.v1.sample.tpu.sampler import Sampler as TPUSampler
|
|||||||
from vllm.v1.utils import bind_kv_cache
|
from vllm.v1.utils import bind_kv_cache
|
||||||
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
|
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
|
||||||
|
|
||||||
from .utils import (gather_mm_placeholders, sanity_check_mm_encoder_outputs,
|
from .utils import sanity_check_mm_encoder_outputs
|
||||||
scatter_mm_placeholders)
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from vllm.v1.core.sched.output import SchedulerOutput
|
from vllm.v1.core.sched.output import SchedulerOutput
|
||||||
@ -509,47 +507,19 @@ class TPUModelRunner:
|
|||||||
logits_indices = logits_indices.to(self.device)
|
logits_indices = logits_indices.to(self.device)
|
||||||
return attn_metadata, logits_indices
|
return attn_metadata, logits_indices
|
||||||
|
|
||||||
def _scatter_placeholders(
|
def _execute_encoder(self, scheduler_output: "SchedulerOutput"):
|
||||||
self,
|
|
||||||
embeds: torch.Tensor,
|
|
||||||
is_embed: Optional[torch.Tensor],
|
|
||||||
) -> torch.Tensor:
|
|
||||||
if is_embed is None:
|
|
||||||
return embeds
|
|
||||||
|
|
||||||
placeholders = embeds.new_full(
|
|
||||||
(is_embed.shape[0], embeds.shape[-1]),
|
|
||||||
fill_value=torch.nan,
|
|
||||||
)
|
|
||||||
placeholders[is_embed] = embeds
|
|
||||||
return placeholders
|
|
||||||
|
|
||||||
def _gather_placeholders(
|
|
||||||
self,
|
|
||||||
placeholders: torch.Tensor,
|
|
||||||
is_embed: Optional[torch.Tensor],
|
|
||||||
) -> torch.Tensor:
|
|
||||||
if is_embed is None:
|
|
||||||
return placeholders
|
|
||||||
|
|
||||||
return placeholders[is_embed]
|
|
||||||
|
|
||||||
def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"):
|
|
||||||
scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs
|
scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs
|
||||||
if not scheduled_encoder_inputs:
|
if not scheduled_encoder_inputs:
|
||||||
return
|
return
|
||||||
|
|
||||||
# Batch the multi-modal inputs.
|
# Batch the multi-modal inputs.
|
||||||
mm_inputs = list[MultiModalKwargs]()
|
mm_inputs: list[MultiModalKwargs] = []
|
||||||
req_ids_pos = list[tuple[str, int, PlaceholderRange]]()
|
req_input_ids: list[tuple[str, int]] = []
|
||||||
for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
|
for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
|
||||||
req_state = self.requests[req_id]
|
req_state = self.requests[req_id]
|
||||||
for input_id, pos_info in zip(
|
for input_id in encoder_input_ids:
|
||||||
encoder_input_ids,
|
|
||||||
req_state.mm_positions,
|
|
||||||
):
|
|
||||||
mm_inputs.append(req_state.mm_inputs[input_id])
|
mm_inputs.append(req_state.mm_inputs[input_id])
|
||||||
req_ids_pos.append((req_id, input_id, pos_info))
|
req_input_ids.append((req_id, input_id))
|
||||||
|
|
||||||
# Batch mm inputs as much as we can: if a request in the batch has
|
# Batch mm inputs as much as we can: if a request in the batch has
|
||||||
# multiple modalities or a different modality than the previous one,
|
# multiple modalities or a different modality than the previous one,
|
||||||
@ -585,23 +555,16 @@ class TPUModelRunner:
|
|||||||
encoder_outputs.append(output)
|
encoder_outputs.append(output)
|
||||||
|
|
||||||
# Cache the encoder outputs.
|
# Cache the encoder outputs.
|
||||||
for (req_id, input_id, pos_info), output in zip(
|
for (req_id, input_id), output in zip(req_input_ids, encoder_outputs):
|
||||||
req_ids_pos,
|
|
||||||
encoder_outputs,
|
|
||||||
):
|
|
||||||
if req_id not in self.encoder_cache:
|
if req_id not in self.encoder_cache:
|
||||||
self.encoder_cache[req_id] = {}
|
self.encoder_cache[req_id] = {}
|
||||||
|
self.encoder_cache[req_id][input_id] = output
|
||||||
|
|
||||||
self.encoder_cache[req_id][input_id] = scatter_mm_placeholders(
|
def _gather_encoder_outputs(
|
||||||
output,
|
|
||||||
is_embed=pos_info.is_embed,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _gather_mm_embeddings(
|
|
||||||
self,
|
self,
|
||||||
scheduler_output: "SchedulerOutput",
|
scheduler_output: "SchedulerOutput",
|
||||||
) -> list[torch.Tensor]:
|
) -> list[torch.Tensor]:
|
||||||
mm_embeds: list[torch.Tensor] = []
|
encoder_outputs: list[torch.Tensor] = []
|
||||||
for req_id in self.input_batch.req_ids:
|
for req_id in self.input_batch.req_ids:
|
||||||
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
|
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
|
||||||
req_id]
|
req_id]
|
||||||
@ -609,8 +572,8 @@ class TPUModelRunner:
|
|||||||
num_computed_tokens = req_state.num_computed_tokens
|
num_computed_tokens = req_state.num_computed_tokens
|
||||||
mm_positions = req_state.mm_positions
|
mm_positions = req_state.mm_positions
|
||||||
for i, pos_info in enumerate(mm_positions):
|
for i, pos_info in enumerate(mm_positions):
|
||||||
start_pos = pos_info.offset
|
start_pos = pos_info["offset"]
|
||||||
num_encoder_tokens = pos_info.length
|
num_encoder_tokens = pos_info["length"]
|
||||||
|
|
||||||
# The encoder output is needed if the two ranges overlap:
|
# The encoder output is needed if the two ranges overlap:
|
||||||
# [num_computed_tokens,
|
# [num_computed_tokens,
|
||||||
@ -632,16 +595,8 @@ class TPUModelRunner:
|
|||||||
assert req_id in self.encoder_cache
|
assert req_id in self.encoder_cache
|
||||||
assert i in self.encoder_cache[req_id]
|
assert i in self.encoder_cache[req_id]
|
||||||
encoder_output = self.encoder_cache[req_id][i]
|
encoder_output = self.encoder_cache[req_id][i]
|
||||||
|
encoder_outputs.append(encoder_output[start_idx:end_idx])
|
||||||
if (is_embed := pos_info.is_embed) is not None:
|
return encoder_outputs
|
||||||
is_embed = is_embed[start_idx:end_idx]
|
|
||||||
|
|
||||||
mm_embeds_item = gather_mm_placeholders(
|
|
||||||
encoder_output[start_idx:end_idx],
|
|
||||||
is_embed=is_embed,
|
|
||||||
)
|
|
||||||
mm_embeds.append(mm_embeds_item)
|
|
||||||
return mm_embeds
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def execute_model(
|
def execute_model(
|
||||||
@ -657,10 +612,10 @@ class TPUModelRunner:
|
|||||||
|
|
||||||
if self.is_multimodal_model:
|
if self.is_multimodal_model:
|
||||||
# Run the multimodal encoder if any.
|
# Run the multimodal encoder if any.
|
||||||
self._execute_mm_encoder(scheduler_output)
|
self._execute_encoder(scheduler_output)
|
||||||
mm_embeds = self._gather_mm_embeddings(scheduler_output)
|
encoder_outputs = self._gather_encoder_outputs(scheduler_output)
|
||||||
else:
|
else:
|
||||||
mm_embeds = []
|
encoder_outputs = []
|
||||||
|
|
||||||
# Prepare inputs
|
# Prepare inputs
|
||||||
attn_metadata, logits_indices = self._prepare_inputs(scheduler_output)
|
attn_metadata, logits_indices = self._prepare_inputs(scheduler_output)
|
||||||
@ -668,9 +623,9 @@ class TPUModelRunner:
|
|||||||
# NOTE(woosuk): To unify token ids and soft tokens (vision
|
# NOTE(woosuk): To unify token ids and soft tokens (vision
|
||||||
# embeddings), we always use embeddings (rather than token ids)
|
# embeddings), we always use embeddings (rather than token ids)
|
||||||
# as input to the multimodal model, even when the input is text.
|
# as input to the multimodal model, even when the input is text.
|
||||||
if mm_embeds:
|
if encoder_outputs:
|
||||||
inputs_embeds = self.model.get_input_embeddings(
|
inputs_embeds = self.model.get_input_embeddings(
|
||||||
self.input_ids, mm_embeds)
|
self.input_ids, encoder_outputs)
|
||||||
else:
|
else:
|
||||||
inputs_embeds = self.model.get_input_embeddings(self.input_ids)
|
inputs_embeds = self.model.get_input_embeddings(self.input_ids)
|
||||||
input_ids = None
|
input_ids = None
|
||||||
|
@@ -1,6 +1,4 @@
 # SPDX-License-Identifier: Apache-2.0
-from typing import Optional
-
 
 import torch
 
@@ -29,46 +27,3 @@ def sanity_check_mm_encoder_outputs(
         f"but got tensors with shapes {[e.shape for e in mm_embeddings]} "
         "instead. This is most likely due to incorrect implementation "
         "of the model's `get_multimodal_embeddings` method.")
-
-
-def scatter_mm_placeholders(
-    embeds: torch.Tensor,
-    is_embed: Optional[torch.Tensor],
-) -> torch.Tensor:
-    """
-    Scatter the multimodal embeddings into a contiguous tensor that represents
-    the placeholder tokens.
-
-    :class:`vllm.multimodal.processing.PromptUpdateDetails.is_embed`.
-
-    Args:
-        embeds: The multimodal embeddings.
-            Shape: `(num_embeds, embed_dim)`
-        is_embed: A boolean mask indicating which positions in the placeholder
-            tokens need to be filled with multimodal embeddings.
-            Shape: `(num_placeholders, num_embeds)`
-    """
-    if is_embed is None:
-        return embeds
-
-    placeholders = embeds.new_full(
-        (is_embed.shape[0], embeds.shape[-1]),
-        fill_value=torch.nan,
-    )
-    placeholders[is_embed] = embeds
-    return placeholders
-
-
-def gather_mm_placeholders(
-    placeholders: torch.Tensor,
-    is_embed: Optional[torch.Tensor],
-) -> torch.Tensor:
-    """
-    Reconstructs the embeddings from the placeholder tokens.
-
-    This is the operation of :func:`scatter_mm_placeholders`.
-    """
-    if is_embed is None:
-        return placeholders
-
-    return placeholders[is_embed]
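
For reference, a hedged toy example of what the removed `scatter_mm_placeholders`/`gather_mm_placeholders` pair did, written with plain torch and made-up shapes (not part of the commit):

```python
import torch

embeds = torch.tensor([[1.0, 2.0]])            # (num_embeds, embed_dim)
is_embed = torch.tensor([False, True, False])  # (num_placeholder_tokens,)

# scatter: NaN-pad the embeddings out to the full placeholder length.
placeholders = embeds.new_full(
    (is_embed.shape[0], embeds.shape[-1]), fill_value=torch.nan)
placeholders[is_embed] = embeds

# gather: recover the embeddings with the same mask (the inverse step).
recovered = placeholders[is_embed]
assert torch.equal(recovered, embeds)
```
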