[VLM] Merged multi-modal processor for Pixtral (#12211)

Signed-off-by: remi <remi@mistral.ai>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
Author: Rémi Delacourt, 2025-03-15 14:28:27 +01:00, committed by GitHub
parent 74bc397b0a
commit 61c6a5a796
9 changed files with 620 additions and 358 deletions


@@ -43,12 +43,18 @@ from vllm.sampling_params import SamplingParams
 # python demo.py advanced


-def run_simple_demo():
+def run_simple_demo(args: argparse.Namespace):
     model_name = "mistralai/Pixtral-12B-2409"
     sampling_params = SamplingParams(max_tokens=8192)

-    # Lower max_num_seqs or max_model_len on low-VRAM GPUs.
-    llm = LLM(model=model_name, tokenizer_mode="mistral")
+    # Lower max_model_len and/or max_num_seqs on low-VRAM GPUs.
+    llm = LLM(
+        model=model_name,
+        tokenizer_mode="mistral",
+        max_model_len=4096,
+        max_num_seqs=2,
+        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
+    )

     prompt = "Describe this image in one sentence."
     image_url = "https://picsum.photos/id/237/200/300"
@@ -76,7 +82,7 @@ def run_simple_demo():
     print(outputs[0].outputs[0].text)


-def run_advanced_demo():
+def run_advanced_demo(args: argparse.Namespace):
     model_name = "mistralai/Pixtral-12B-2409"
     max_img_per_msg = 5
     max_tokens_per_img = 4096
@@ -87,6 +93,7 @@ def run_advanced_demo():
         tokenizer_mode="mistral",
         limit_mm_per_prompt={"image": max_img_per_msg},
         max_model_len=max_img_per_msg * max_tokens_per_img,
+        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
     )

     prompt = "Describe the following image."
@@ -153,14 +160,19 @@ def main():
         help="Specify the demo mode: 'simple' or 'advanced'",
     )
+    parser.add_argument(
+        '--disable-mm-preprocessor-cache',
+        action='store_true',
+        help='If True, disables caching of multi-modal preprocessor/mapper.')
+
     args = parser.parse_args()

     if args.mode == "simple":
         print("Running simple demo...")
-        run_simple_demo()
+        run_simple_demo(args)
     elif args.mode == "advanced":
         print("Running advanced demo...")
-        run_advanced_demo()
+        run_advanced_demo(args)


 if __name__ == "__main__":
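
For orientation, a minimal sketch of what the updated simple demo now constructs (the arguments mirror the diff above; whether you disable the multi-modal preprocessor cache is an assumption you pick per run):

    from vllm import LLM, SamplingParams

    # Mirrors run_simple_demo() after this change: smaller context and batch
    # for low-VRAM GPUs, plus the new preprocessor-cache toggle.
    llm = LLM(
        model="mistralai/Pixtral-12B-2409",
        tokenizer_mode="mistral",
        max_model_len=4096,
        max_num_seqs=2,
        disable_mm_preprocessor_cache=False,  # set True to bypass the cache
    )
    sampling_params = SamplingParams(max_tokens=8192)

    messages = [{
        "role": "user",
        "content": [
            {"type": "text", "text": "Describe this image in one sentence."},
            {"type": "image_url",
             "image_url": {"url": "https://picsum.photos/id/237/200/300"}},
        ],
    }]
    outputs = llm.chat(messages, sampling_params=sampling_params)
    print(outputs[0].outputs[0].text)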


@@ -2,17 +2,23 @@
 import copy
 from functools import partial
-from typing import Optional
+from typing import Optional, Union

 import numpy as np
 import pytest
+from mistral_common.protocol.instruct.messages import (ImageChunk, TextChunk,
+                                                        UserMessage)
+from mistral_common.protocol.instruct.request import ChatCompletionRequest
 from PIL import Image
+from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast

 from vllm.config import ModelConfig
 from vllm.inputs import InputProcessingContext
-from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.processing import ProcessingCache
-from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
+from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict
+from vllm.multimodal.inputs import MultiModalInputs
+from vllm.multimodal.processing import BaseMultiModalProcessor, ProcessingCache
+from vllm.transformers_utils.tokenizer import (MistralTokenizer,
+                                               cached_tokenizer_from_config)

 from ....multimodal.utils import random_audio, random_image, random_video
 from ...registry import HF_EXAMPLE_MODELS
@@ -85,14 +91,6 @@ def _test_processing_correctness(
         partial(random_audio, rng, min_len=512, max_len=1024, sr=16000),
     }

-    tokenizer_encode_kwargs = {}
-    if model_config.hf_config.model_type in ("mllama", "whisper", "ultravox"):
-        # For some multimodal models, tokenizer will always add bos_token
-        # at the beginning of prompt by default, causing hf_processor outputs
-        # incorrect token ids. So we need use `add_special_tokens=False` here
-        # to leave bos_token to be added by the processor.
-        tokenizer_encode_kwargs = {"add_special_tokens": False}
-
     for batch_idx in range(num_batches):
         mm_data = {
             k:
@@ -115,43 +113,131 @@ def _test_processing_correctness(
             elif len(mm_data[k]) == 1:
                 mm_data[k] = mm_data[k][0]

-        baseline_result = baseline_processor.apply(
-            prompt,
-            mm_data=mm_data,
-            hf_processor_mm_kwargs={},
-        )
-        cached_result = cached_processor.apply(
-            prompt,
-            mm_data=mm_data,
-            hf_processor_mm_kwargs={},
-        )
-
-        assert _drop_mm_kwargs_keys(
-            baseline_result, ignore_mm_keys) == _drop_mm_kwargs_keys(
-                cached_result, ignore_mm_keys), (
-                    f"Failed ({batch_idx=}, {prompt=}, {mm_data=})")
-
-        baseline_tokenized_result = baseline_processor.apply(
-            tokenizer.encode(prompt, **tokenizer_encode_kwargs),
-            mm_data=mm_data,
-            hf_processor_mm_kwargs={},
-        )
-
-        assert _drop_mm_kwargs_keys(
-            baseline_result, ignore_mm_keys) == _drop_mm_kwargs_keys(
-                baseline_tokenized_result, ignore_mm_keys), (
-                    f"Failed ({batch_idx=}, {prompt=}, {mm_data=})")
-
-        cached_tokenized_result = cached_processor.apply(
-            tokenizer.encode(prompt, **tokenizer_encode_kwargs),
-            mm_data=mm_data,
-            hf_processor_mm_kwargs={},
-        )
-
-        assert _drop_mm_kwargs_keys(
-            cached_result, ignore_mm_keys) == _drop_mm_kwargs_keys(
-                cached_tokenized_result, ignore_mm_keys), (
-                    f"Failed ({batch_idx=}, {prompt=}, {mm_data=})")
+        if isinstance(tokenizer, MistralTokenizer):
+            _test_processing_correctness_mistral(
+                model_config,
+                tokenizer,
+                prompt,
+                mm_data,
+                baseline_processor,
+                cached_processor,
+                batch_idx,
+                ignore_mm_keys=ignore_mm_keys,
+            )
+        else:
+            _test_processing_correctness_hf(
+                model_config,
+                tokenizer,
+                prompt,
+                mm_data,
+                baseline_processor,
+                cached_processor,
+                batch_idx,
+                ignore_mm_keys=ignore_mm_keys,
+            )
+
+
+def _test_processing_correctness_hf(
+    model_config: ModelConfig,
+    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
+    prompt: str,
+    mm_data: MultiModalDataDict,
+    baseline_processor: BaseMultiModalProcessor,
+    cached_processor: BaseMultiModalProcessor,
+    batch_idx: int,
+    ignore_mm_keys: Optional[list[str]] = None,
+):
+    if model_config.hf_config.model_type in ("mllama", "whisper", "ultravox"):
+        # For some multimodal models, tokenizer will always add bos_token
+        # at the beginning of prompt by default, causing hf_processor outputs
+        # incorrect token ids. So we need use `add_special_tokens=False` here
+        # to leave bos_token to be added by the processor.
+        token_prompt = tokenizer.encode(prompt, add_special_tokens=False)
+    else:
+        token_prompt = tokenizer.encode(prompt)
+
+    baseline_result = baseline_processor.apply(
+        prompt,
+        mm_data=mm_data,
+        hf_processor_mm_kwargs={},
+    )
+    cached_result = cached_processor.apply(
+        prompt,
+        mm_data=mm_data,
+        hf_processor_mm_kwargs={},
+    )
+
+    assert _inputs_equal(
+        baseline_result,
+        cached_result,
+        ignore_mm_keys,
+    ), f"Failed ({batch_idx=}, {prompt=}, {mm_data=})"
+
+    baseline_tokenized_result = baseline_processor.apply(
+        token_prompt,
+        mm_data=mm_data,
+        hf_processor_mm_kwargs={},
+    )
+
+    assert _inputs_equal(
+        baseline_result,
+        baseline_tokenized_result,
+        ignore_mm_keys,
+    ), f"Failed ({batch_idx=}, {prompt=}, {mm_data=})"
+
+    cached_tokenized_result = cached_processor.apply(
+        token_prompt,
+        mm_data=mm_data,
+        hf_processor_mm_kwargs={},
+    )
+
+    assert _inputs_equal(
+        cached_result,
+        cached_tokenized_result,
+        ignore_mm_keys,
+    ), f"Failed ({batch_idx=}, {prompt=}, {mm_data=})"
+
+
+def _test_processing_correctness_mistral(
+    model_config: ModelConfig,
+    tokenizer: MistralTokenizer,
+    prompt: str,
+    mm_data: MultiModalDataDict,
+    baseline_processor: BaseMultiModalProcessor,
+    cached_processor: BaseMultiModalProcessor,
+    batch_idx: int,
+    ignore_mm_keys: Optional[list[str]] = None,
+):
+    images = mm_data.get("image", [])
+    if not isinstance(images, list):
+        images = [images]
+
+    request = ChatCompletionRequest(messages=[
+        UserMessage(content=[
+            TextChunk(text=prompt),
+            *(ImageChunk(image=image) for image in images),
+        ]),
+    ])
+    res = tokenizer.mistral.encode_chat_completion(request)
+    token_prompt = res.tokens
+
+    # Mistral chat outputs tokens directly, rather than text prompts
+    baseline_tokenized_result = baseline_processor.apply(
+        token_prompt,
+        mm_data=mm_data,
+        hf_processor_mm_kwargs={},
+    )
+    cached_tokenized_result = cached_processor.apply(
+        token_prompt,
+        mm_data=mm_data,
+        hf_processor_mm_kwargs={},
+    )
+
+    assert _inputs_equal(
+        baseline_tokenized_result,
+        cached_tokenized_result,
+        ignore_mm_keys,
+    ), f"Failed ({batch_idx=}, {prompt=}, {mm_data=})"
 # yapf: disable
@@ -173,6 +259,7 @@ def _test_processing_correctness(
     "llava-hf/llava-onevision-qwen2-0.5b-ov-hf",
     "meta-llama/Llama-3.2-11B-Vision-Instruct",
     "TIGER-Lab/Mantis-8B-siglip-llama3",
+    "mistralai/Pixtral-12B-2409",
     "mistral-community/pixtral-12b",
     "openbmb/MiniCPM-o-2_6",
     "openbmb/MiniCPM-V-2_6",
@@ -241,8 +328,19 @@ def test_processing_correctness_phi3v(
     )


-def _drop_mm_kwargs_keys(result: dict,
-                         ignore_mm_keys: Optional[list[str]] = None) -> dict:
+def _inputs_equal(
+    a: MultiModalInputs,
+    b: MultiModalInputs,
+    ignore_mm_keys: Optional[list[str]] = None,
+):
+    return _drop_mm_kwargs_keys(a, ignore_mm_keys) == _drop_mm_kwargs_keys(
+        b, ignore_mm_keys)
+
+
+def _drop_mm_kwargs_keys(
+    result: MultiModalInputs,
+    ignore_mm_keys: Optional[list[str]] = None,
+) -> MultiModalInputs:
     """Drop specified keys from result['mm_kwargs'].

     This is mainly to avoid doing exact match of audio_features in ultravox.
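
To make the comparison semantics concrete, here is a hypothetical stand-in (plain dicts instead of MultiModalInputs) showing what ignoring a key inside "mm_kwargs" buys you, e.g. skipping exact matches on ultravox audio_features:

    a = {"prompt_token_ids": [1, 2, 3],
         "mm_kwargs": {"pixel_values": 0.5, "audio_features": 1.0}}
    b = {"prompt_token_ids": [1, 2, 3],
         "mm_kwargs": {"pixel_values": 0.5, "audio_features": 2.0}}

    def drop_keys(result, ignore_mm_keys):
        # Same idea as _drop_mm_kwargs_keys: compare everything except the
        # listed entries of result["mm_kwargs"].
        out = dict(result, mm_kwargs=dict(result["mm_kwargs"]))
        for k in ignore_mm_keys or []:
            out["mm_kwargs"].pop(k, None)
        return out

    assert drop_keys(a, ["audio_features"]) == drop_keys(b, ["audio_features"])
    assert drop_keys(a, None) != drop_keys(b, None)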


@@ -68,23 +68,15 @@ class PixtralHFImagePixelInputs(TypedDict):
     in which case the data is passed as a list instead of a batched tensor.
     """

-    feat_is_patch: Union[torch.Tensor, list[torch.Tensor]]
-    """
-    A boolean mask indicating which image features correspond
-    to patch tokens.
-
-    Shape: `(batch_size, num_crops, num_patch)`
-    """
-
     embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
     """
     A boolean mask indicating which image embeddings correspond
     to patch tokens.

-    Shape: `(batch_size, num_embeds)`
+    Shape: `(batch_size, num_images, num_embeds)`
     """

-    num_crops: Union[torch.Tensor, list[torch.Tensor]]
+    num_patches: Union[torch.Tensor, list[torch.Tensor]]
     """Shape: `(batch_size, num_images)`"""
@@ -360,16 +352,16 @@ class PixtralHFMultiModalProcessor(
                     image_height=pixel_value.shape[-2],
                 ) for pixel_value in processed_outputs["pixel_values"]
             ]
-            num_crops = torch.tensor([(ncols + 1) * nrows
-                                      for ncols, nrows in tile_sizes])
+            num_patches = torch.tensor([(ncols + 1) * nrows
+                                        for ncols, nrows in tile_sizes])
             # Each image may result to masks of different sizes, so we need to
-            # flatten the list and later use `num_crops` to get per-image masks.
-            embed_is_patch = torch.tensor(
-                flatten_2d_lists([([True] * ncols + [False]) * nrows
-                                  for ncols, nrows in tile_sizes]))
-            processed_outputs["num_crops"] = num_crops
+            # later use `num_patches` to get per-image masks.
+            embed_is_patch = [
+                torch.tensor(([True] * ncols + [False]) * nrows)
+                for ncols, nrows in tile_sizes
+            ]
+            processed_outputs["num_patches"] = num_patches
             processed_outputs["embed_is_patch"] = embed_is_patch
-            processed_outputs["feat_is_patch"] = embed_is_patch

         return processed_outputs
@@ -378,14 +370,10 @@ class PixtralHFMultiModalProcessor(
         hf_inputs: BatchFeature,
         hf_processor_mm_kwargs: Mapping[str, object],
     ) -> Mapping[str, MultiModalFieldConfig]:
-        num_crops = hf_inputs.get("num_crops", torch.empty(0)).view(-1)
         return dict(
-            feat_is_patch=MultiModalFieldConfig.flat_from_sizes(
-                "image", num_crops),
-            embed_is_patch=MultiModalFieldConfig.flat_from_sizes(
-                "image", num_crops),
-            num_crops=MultiModalFieldConfig.batched("image"),
             pixel_values=MultiModalFieldConfig.batched("image"),
+            num_patches=MultiModalFieldConfig.batched("image"),
+            embed_is_patch=MultiModalFieldConfig.batched("image"),
             image_embeds=MultiModalFieldConfig.batched("image"),
         )
@@ -628,27 +616,21 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
                                  f"Got type: {type(pixel_values)}")

             if self.config.vision_config.model_type == "pixtral":
-                feat_is_patch = kwargs.pop("feat_is_patch")
-                if not isinstance(feat_is_patch, (torch.Tensor, list)):
-                    raise ValueError("Incorrect type of feat_is_patch. "
-                                     f"Got type: {type(feat_is_patch)}")
-
                 embed_is_patch = kwargs.pop("embed_is_patch")
                 if not isinstance(embed_is_patch, (torch.Tensor, list)):
                     raise ValueError("Incorrect type of embed_is_patch. "
                                      f"Got type: {type(embed_is_patch)}")

-                num_crops = kwargs.pop("num_crops")
-                if not isinstance(num_crops, (torch.Tensor, list)):
-                    raise ValueError("Incorrect type of num_crops. "
-                                     f"Got type: {type(num_crops)}")
+                num_patches = kwargs.pop("num_patches")
+                if not isinstance(num_patches, (torch.Tensor, list)):
+                    raise ValueError("Incorrect type of num_patches. "
+                                     f"Got type: {type(num_patches)}")

                 return PixtralHFImagePixelInputs(
                     type="pixel_values_pixtral",
                     pixel_values=flatten_bn(pixel_values),
-                    feat_is_patch=feat_is_patch,
                     embed_is_patch=embed_is_patch,
-                    num_crops=num_crops,
+                    num_patches=num_patches,
                 )

             return LlavaImagePixelInputs(
@@ -687,21 +669,26 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
         vision_tower: Union[CLIPVisionModel, SiglipVisionModel,
                             PixtralHFVisionModel],
         pixel_values: Union[torch.Tensor, list[torch.Tensor]],
-    ) -> torch.Tensor:
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:

         # NOTE: we skip the step to select the vision feature layer since
         # this is already done inside the vision tower
         image_features = vision_tower(pixel_values)

-        return self._select_image_features(
-            image_features,
-            strategy=self.config.vision_feature_select_strategy,
+        def select_features(leaf: torch.Tensor):
+            return self._select_image_features(
+                leaf,
+                strategy=self.config.vision_feature_select_strategy,
+            )
+
+        return cast(
+            Union[torch.Tensor, tuple[torch.Tensor, ...]],
+            json_map_leaves(select_features, image_features),
         )

     def _process_image_pixels(
         self,
         inputs: Union[LlavaImagePixelInputs, PixtralHFImagePixelInputs],
-    ) -> torch.Tensor:
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
         assert self.vision_tower is not None

         pixel_values = inputs["pixel_values"]
@@ -731,45 +718,30 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):

     def _get_mm_embeds(
         self,
-        features: torch.Tensor,  # Shape: (num_crop, num_patch, d)
-        feat_is_patch: torch.Tensor,  # Shape: (num_crop, num_patch)
-        num_crops: torch.Tensor,  # Shape: (num_images,)
-        embed_is_patch: torch.Tensor,  # Shape: (num_embeds,)
-    ) -> list[torch.Tensor]:
+        features: torch.Tensor,  # Shape: (num_patch, d)
+        num_patches: torch.Tensor,  # Shape: (num_images,)
+        embed_is_patch: torch.Tensor,  # Shape: (num_images, num_embeds)
+    ) -> tuple[torch.Tensor, ...]:
         """Scatter the patch features into a contiguous tensor that corresponds
         to the embedding tokens defined by the multimodal processor.

         Mostly copied from `Molmo._get_mm_embeds`. See following fixme comment.
         """

-        # Insert columns of nan values according to `feat_is_patch`. This work
+        # Insert columns of nan values according to `embed_is_patch`. This work
         # ideally should be done in `_process_image_input`, but
         # `_process_image_input` is used in both V0 and V1 path. It's safer to
         # put the logic here.
         # FIXME: Move this logic to `_process_image_input` when v0 is
         # deprecated. Merge this function with `Molmo._get_mm_embeds`.
-        feat_is_patch = feat_is_patch.view(-1)
-        embed_is_patch = embed_is_patch.view(-1)
-
-        expanded_embedding = torch.full(
-            (sum(num_crops), *features.shape[1:]),
-            torch.nan,
-            dtype=features.dtype).to(features.device)
-        expanded_embedding[feat_is_patch] = features
-
-        num_crops_per_image = num_crops.tolist()
-        feats_per_image = expanded_embedding.split(num_crops_per_image)
-        f_is_patch_per_image = feat_is_patch.split(num_crops_per_image)
-
-        embed_dim = expanded_embedding.shape[-1]
-        num_embeds = embed_is_patch.shape[0]
-
-        embeds_in_batch = list[torch.Tensor]()
-        for feats, f_is_patch in zip(feats_per_image, f_is_patch_per_image):
-            embeds = feats.new_full((num_embeds, embed_dim), torch.nan)
-            embeds[embed_is_patch] = feats[f_is_patch]
-            embeds_in_batch.append(embeds)
-
-        return embeds_in_batch
+        num_patches_per_image: list[int] = num_patches.tolist()
+
+        embeds_flat = features.new_full(
+            (sum(num_patches_per_image), *features.shape[1:]),
+            fill_value=torch.nan,
+        )
+        embeds_flat[embed_is_patch.view(-1)] = features
+
+        return embeds_flat.split(num_patches_per_image)

     def get_multimodal_embeddings(
             self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
@@ -784,12 +756,12 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
             # The path is used for pixtral (V0 only) and llava (V0/V1)
             return vision_embeddings

-        nested_emb = [
-            self._get_mm_embeds(*args) for args in zip(
-                vision_embeddings, image_input["feat_is_patch"],
-                image_input["num_crops"], image_input["embed_is_patch"])
-        ]
-        return flatten_2d_lists(nested_emb)
+        return flatten_2d_lists(
+            self._get_mm_embeds(*args) for args in zip(
+                vision_embeddings,
+                image_input["num_patches"],
+                image_input["embed_is_patch"],
+            ))

     def get_input_embeddings(
         self,
@@ -805,9 +777,11 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
             )

             inputs_embeds = merge_multimodal_embeddings(
-                input_ids, inputs_embeds, cast(NestedTensors,
-                                               patch_embeddings),
-                self.config.image_token_index)
+                input_ids,
+                inputs_embeds,
+                cast(NestedTensors, patch_embeddings),
+                self.config.image_token_index,
+            )

         return inputs_embeds

     def forward(
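
The `_get_mm_embeds` rewrite above relies on a NaN-scatter trick: patch features are written into a NaN-filled buffer that covers every embedding slot (patches plus the break/end positions), split per image, and the NaN rows are filtered out again when the embeddings are merged into the prompt. A self-contained sketch of that technique with toy sizes (plain PyTorch, not the vLLM classes):

    import torch

    # One image laid out as 2 rows x 3 patch columns: each row contributes
    # 3 patch slots plus one break/end slot, i.e. (3 + 1) * 2 = 8 slots
    # but only 6 real patch features.
    ncols, nrows, d = 3, 2, 4
    embed_is_patch = torch.tensor(([True] * ncols + [False]) * nrows)  # (8,)
    num_patches = torch.tensor([(ncols + 1) * nrows])                  # [8]
    features = torch.randn(ncols * nrows, d)                           # (6, d)

    # Scatter: NaN buffer over all slots, real patches written in place.
    buf = features.new_full((int(num_patches.sum()), d), fill_value=torch.nan)
    buf[embed_is_patch] = features
    per_image = buf.split(num_patches.tolist())  # one (8, d) tensor here

    # Later, when merging into the token embeddings, the NaN rows are
    # dropped again, recovering exactly the 6 patch embeddings.
    patches_only = per_image[0][~per_image[0].isnan().any(dim=-1)]
    assert patches_only.shape == (ncols * nrows, d)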


@@ -1585,15 +1585,13 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
         image_features = self._process_image_input(image_input)

-        nested_embeds = [
-            self._get_mm_embeds(*args) for args in zip(
-                image_features,
-                image_input["feat_is_patch"],
-                image_input["num_crops"],
-                image_input["embed_is_patch"],
-            )
-        ]
-        return flatten_2d_lists(nested_embeds)
+        return flatten_2d_lists(
+            self._get_mm_embeds(*args) for args in zip(
+                image_features,
+                image_input["feat_is_patch"],
+                image_input["num_crops"],
+                image_input["embed_is_patch"],
+            ))

     def get_input_embeddings(
         self,


@@ -1,7 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0

-from typing import (Iterable, Literal, Mapping, Optional, Set, Tuple,
-                    TypedDict, Union)
+from collections.abc import Iterable, Mapping, Sequence
+from typing import Literal, Optional, Set, Tuple, TypedDict, Union

 import torch
 from torch import nn
@@ -17,7 +16,7 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
 from vllm.multimodal.parse import MultiModalDataItems
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         BaseProcessingInfo, PromptIndexTargets,
-                                        PromptInsertion, PromptReplacement,
+                                        PromptInsertion, PromptUpdate,
                                         PromptUpdateDetails)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
 from vllm.sequence import IntermediateTensors
@@ -144,7 +143,7 @@ class PaliGemmaMultiModalProcessor(
         mm_items: MultiModalDataItems,
         hf_processor_mm_kwargs: Mapping[str, object],
         out_mm_kwargs: MultiModalKwargs,
-    ) -> list[PromptReplacement]:
+    ) -> Sequence[PromptUpdate]:
         hf_config = self.info.get_hf_config()
         image_token_id = hf_config.image_token_index


@@ -1,26 +1,28 @@
 # SPDX-License-Identifier: Apache-2.0

 import math
-from collections.abc import Iterable, Mapping
+from collections.abc import Iterable, Mapping, Sequence
 from dataclasses import dataclass, fields
 from functools import cached_property
-from typing import List, Optional, Set, Tuple, Union
+from typing import List, Literal, Optional, Set, Tuple, TypedDict, Union, cast

 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from mistral_common.protocol.instruct.messages import ImageChunk
+from mistral_common.tokens.tokenizers.multimodal import ImageEncoder
 from PIL import Image
-from transformers import PixtralVisionConfig
+from transformers import PixtralVisionConfig, TensorType
+from transformers.image_utils import ImageInput
 from transformers.models.pixtral.image_processing_pixtral import (
     _num_image_tokens as _get_pixtral_hf_num_image_tokens)
 from transformers.models.pixtral.modeling_pixtral import (
     PixtralRotaryEmbedding, apply_rotary_pos_emb, position_ids_in_meshgrid)
+from transformers.tokenization_utils_base import TextInput

 from vllm.config import VllmConfig
 from vllm.distributed import divide, get_tensor_model_parallel_world_size
-from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
-                         InputContext, token_inputs)
+from vllm.jsontree import JSONTree, json_map_leaves
 from vllm.model_executor.layers.activation import get_act_and_mul_fn
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
@@ -31,13 +33,20 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
-from vllm.multimodal.inputs import PlaceholderRange
-from vllm.multimodal.utils import consecutive_placeholder_ranges
-from vllm.sequence import IntermediateTensors, SequenceData
-from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
+from vllm.multimodal.inputs import MultiModalFieldConfig, NestedTensors
+from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
+                                   MultiModalDataItems)
+from vllm.multimodal.processing import (BaseMultiModalProcessor,
+                                        BaseProcessingInfo, PromptReplacement,
+                                        PromptUpdate)
+from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
+from vllm.sequence import IntermediateTensors
+from vllm.transformers_utils.tokenizer import (MistralTokenizer,
+                                               cached_tokenizer_from_config)
+from vllm.utils import flatten_2d_lists

 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
-from .utils import (init_vllm_registered_model, maybe_prefix,
+from .utils import (flatten_bn, init_vllm_registered_model, maybe_prefix,
                     merge_multimodal_embeddings)
 from .vision import VisionEncoderInfo, resolve_visual_encoder_outputs
@@ -48,132 +57,275 @@ except ImportError:
     USE_XFORMERS_OPS = False


-def get_max_pixtral_image_tokens(ctx: InputContext):
-    tokenizer = cached_tokenizer_from_config(ctx.model_config)
-    mm_encoder = tokenizer.instruct.mm_encoder
-
-    image_config = mm_encoder.mm_config if hasattr(
-        mm_encoder, "mm_config") else mm_encoder.image_config
-
-    max_image_size = image_config.max_image_size
-    image_patch_size = image_config.image_patch_size
-
-    return ((max_image_size // image_patch_size)**2)
-
-
-def dummy_data_for_pixtral(ctx: InputContext, seq_len: int,
-                           mm_counts: Mapping[str, int]):
-    tokenizer = cached_tokenizer_from_config(ctx.model_config)
-
-    mm_encoder = tokenizer.mistral.instruct_tokenizer.mm_encoder
-    image_token_id = mm_encoder.special_ids.img
-    mm_config = ctx.get_mm_config()
-    num_images = mm_config.get_limit_per_prompt("image")
-
-    # dummy size
-    size = 256
-    image = Image.new("RGB", (size, size), color=0)
-
-    encoding = tokenizer.instruct.mm_encoder(ImageChunk(image=image))
-    image_feature_size = len(encoding.tokens)
-    num_image_tokens = image_feature_size * num_images
-    seq_data = SequenceData.from_prompt_token_counts(
-        (image_token_id, num_image_tokens),
-        (0, seq_len - num_image_tokens),
-    )
-
-    mm_data = {"image": num_images * [image]}
-    mm_placeholders = {
-        "image":
-        consecutive_placeholder_ranges(num_items=num_images,
-                                       item_size=image_feature_size)
-    }
-    return DummyData(seq_data, mm_data, mm_placeholders)
-
-
-def input_mapper_for_pixtral(ctx: InputContext,
-                             data: object) -> MultiModalKwargs:
-    """Maps the input data to its MultiModalKwargs (if any).
-
-    Args:
-        ctx: Context of the loaded model.
-        data: data potentially containing PIL images to be processed
-            and mapped to `images`.
-
-    Returns:
-        MultiModalKwargs containing the stacked normalized images tensor or
-        image embeddings.
-    """
-    tokenizer = cached_tokenizer_from_config(ctx.model_config)
-
-    data_list = data if isinstance(data, list) else [data]
-
-    images = []
-    image_tokens_list = []
-    for image_data in data_list:
-        image = ImageChunk(image=image_data)
-        encoding = tokenizer.instruct.mm_encoder(image)
-        image = torch.from_numpy(encoding.image).to(dtype=torch.float16)
-        images.append(image)
-        image_tokens_list.append(encoding.tokens)
-
-    image_tokens = torch.tensor([
-        token_id for image_tokens in image_tokens_list
-        for token_id in image_tokens
-    ])
-    return MultiModalKwargs({"images": images, "image_tokens": image_tokens})
-
-
-def input_processor_for_pixtral(ctx: InputContext, inputs: DecoderOnlyInputs):
-    multi_modal_data = inputs.get("multi_modal_data")
-    if multi_modal_data is None or "image" not in multi_modal_data:
-        return inputs
-
-    prompt_token_ids = inputs.get("prompt_token_ids")
-    prompt = inputs.get("prompt")
-    tokenizer = cached_tokenizer_from_config(ctx.model_config)
-
-    mm_encoder = tokenizer.mistral.instruct_tokenizer.mm_encoder
-    image_token_id = mm_encoder.special_ids.img
-    image_break_id = mm_encoder.special_ids.img_break
-    image_end_id = mm_encoder.special_ids.img_end
-
-    if image_token_id not in inputs['prompt_token_ids']:
-        raise ValueError(
-            f"You've passed {inputs=} without {image_token_id=}"
-            " Make sure to process your input via mistral_common's"
-            " tokenizer or pass a chat completion request. For more"
-            " For more info, see: "
-            "https://github.com/vllm-project/vllm/issues/8411.")
-
-    # Get precise tracking of placeholder positions
-    placeholder_ranges = []
-    curr_offset = -1
-    curr_length = 0
-    for i in range(len(prompt_token_ids)):
-        if prompt_token_ids[i] in (image_token_id, image_break_id):
-            if curr_offset < 0:
-                curr_offset = i
-            curr_length += 1
-        elif prompt_token_ids[i] == image_end_id:
-            curr_length += 1
-            placeholder_ranges.append(
-                PlaceholderRange(offset=curr_offset, length=curr_length))
-            curr_offset = -1
-            curr_length = 0
-        else:
-            pass
-    return token_inputs(prompt=prompt,
-                        prompt_token_ids=prompt_token_ids,
-                        multi_modal_data=multi_modal_data,
-                        multi_modal_placeholders={"image": placeholder_ranges})
-
-
-@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_pixtral)
-@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_pixtral_image_tokens)
-@INPUT_REGISTRY.register_dummy_data(dummy_data_for_pixtral)
-@INPUT_REGISTRY.register_input_processor(input_processor_for_pixtral)
+class PixtralImagePixelInputs(TypedDict):
+    type: Literal["pixel_values"]
+
+    images: Union[torch.Tensor, list[torch.Tensor]]
+    """
+    Shape: `(batch_size * num_images, num_channels, image_width, image_height)`
+
+    The result of stacking :attr:`ImageEncoding.tokens` from each prompt.
+    """
+
+    embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
+    """
+    A boolean mask indicating which image embeddings correspond
+    to patch tokens.
+
+    Shape: `(batch_size, num_images, num_embeds)`
+    """
+
+    num_patches: Union[torch.Tensor, list[torch.Tensor]]
+    """Shape: `(batch_size, num_images)`"""
+
+
+class PixtralProcessorAdapter:
+    """
+    Provide a HF-compatible interface for
+    :class:`mistral_common.tokens.tokenizers.multimodal.ImageEncoder`.
+    """
+
+    def __init__(self, tokenizer: MistralTokenizer) -> None:
+        super().__init__()
+
+        self.tokenizer = tokenizer
+
+    @property
+    def image_processor(self) -> ImageEncoder:
+        image_encoder = self.tokenizer.instruct.mm_encoder
+        assert isinstance(image_encoder, ImageEncoder)
+        return image_encoder
+
+    @cached_property
+    def image_break_id(self) -> int:
+        return self.image_processor.special_ids.img_break
+
+    @cached_property
+    def image_token_id(self) -> int:
+        return self.image_processor.special_ids.img
+
+    @cached_property
+    def image_end_id(self) -> int:
+        return self.image_processor.special_ids.img_end
+
+    @cached_property
+    def image_size(self) -> int:
+        return self.image_processor.mm_config.max_image_size
+
+    @cached_property
+    def patch_size(self) -> int:
+        return self.image_processor.mm_config.image_patch_size
+
+    def __call__(
+        self,
+        text: Optional[Union[TextInput, list[TextInput]]] = None,
+        images: Optional[Union[ImageInput, list[ImageInput]]] = None,
+        return_tensors: Optional[Union[str, TensorType]] = None,
+        **kwargs,
+    ) -> Mapping[str, NestedTensors]:
+        if text is None:
+            text = []
+        if not isinstance(text, list):
+            text = [text]
+        if images is None:
+            images = []
+        if not isinstance(images, list):
+            images = [images]
+
+        if not images:
+            input_ids = self.tokenizer(text).input_ids
+            return {"input_ids": torch.tensor(input_ids)}
+
+        # Allow dummy text, which is used for profiling as well as token inputs
+        if any(len(t) > 0 for t in text):
+            raise ValueError(
+                "You've passed text inputs instead of token inputs. "
+                "Make sure to process your input via `mistral_common`'s "
+                "tokenizer or pass a chat completion request. "
+                "For more info, see: "
+                "https://github.com/vllm-project/vllm/issues/8411.")
+
+        image_token_id = self.image_token_id
+
+        images_processed = list[torch.Tensor]()
+        images_tokens = list[torch.Tensor]()
+        images_embed_is_patch = list[torch.Tensor]()
+        images_num_patches = list[int]()
+
+        for image in images:
+            image_inputs = self.image_processor(ImageChunk(image=image))
+            image_processed = torch.tensor(image_inputs.image)
+            image_tokens = torch.tensor(image_inputs.tokens)
+
+            images_processed.append(image_processed)
+            images_tokens.append(image_tokens)
+            images_embed_is_patch.append(image_tokens == image_token_id)
+            images_num_patches.append(len(image_tokens))
+
+        return {
+            "input_ids": torch.cat(images_tokens)[None].expand(len(text), -1),
+            "images": images_processed,
+            "embed_is_patch": images_embed_is_patch,
+            "num_patches": torch.tensor(images_num_patches),
+        }
+
+
+class PixtralProcessingInfo(BaseProcessingInfo):
+
+    def get_tokenizer(self) -> MistralTokenizer:
+        tokenizer = cached_tokenizer_from_config(self.ctx.model_config)
+        if not isinstance(tokenizer, MistralTokenizer):
+            raise ValueError("This model requires `--tokenizer-mode mistral`")
+
+        return tokenizer
+
+    def get_hf_processor(self) -> PixtralProcessorAdapter:
+        return PixtralProcessorAdapter(self.get_tokenizer())
+
+    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
+        return {"image": None}
+
+    def get_mm_max_tokens_per_item(
+        self,
+        seq_len: int,
+        mm_counts: Mapping[str, int],
+    ) -> Mapping[str, int]:
+        return {"image": self.get_max_image_tokens()}
+
+    def get_vision_config(
+        self,
+        processor: Optional[PixtralProcessorAdapter] = None,
+    ):
+        if processor is None:
+            processor = self.get_hf_processor()
+
+        return PixtralVisionConfig(
+            image_size=processor.image_size,
+            patch_size=processor.patch_size,
+        )
+
+    def get_num_image_tokens(
+        self,
+        *,
+        image_width: int,
+        image_height: int,
+        processor: Optional[PixtralProcessorAdapter] = None,
+    ) -> int:
+        if processor is None:
+            processor = self.get_hf_processor()
+
+        ncols, nrows = processor.image_processor._image_to_num_tokens(
+            Image.new("RGB", (image_width, image_height)))
+
+        return (ncols + 1) * nrows
+
+    def get_image_size_with_most_features(self) -> ImageSize:
+        image_processor = self.get_hf_processor().image_processor
+        max_image_size = image_processor.mm_config.max_image_size
+
+        return ImageSize(width=max_image_size, height=max_image_size)
+
+    def get_max_image_tokens(self) -> int:
+        target_width, target_height = self.get_image_size_with_most_features()
+
+        return self.get_num_image_tokens(
+            image_width=target_width,
+            image_height=target_height,
+        )
+
+
+class PixtralDummyInputsBuilder(BaseDummyInputsBuilder[PixtralProcessingInfo]):
+
+    def get_dummy_processor_inputs(
+        self,
+        seq_len: int,
+        mm_counts: Mapping[str, int],
+    ) -> ProcessorInputs:
+        num_images = mm_counts.get("image", 0)
+
+        target_width, target_height = \
+            self.info.get_image_size_with_most_features()
+
+        mm_data = {
+            "image":
+            self._get_dummy_images(width=target_width,
+                                   height=target_height,
+                                   num_images=num_images)
+        }
+
+        return ProcessorInputs(
+            prompt_text="",
+            mm_data=mm_data,
+        )
+
+
+class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo]
+                                 ):
+
+    def _get_mm_fields_config(
+        self,
+        hf_inputs: Mapping[str, NestedTensors],
+        hf_processor_mm_kwargs: Mapping[str, object],
+    ) -> Mapping[str, MultiModalFieldConfig]:
+        return dict(
+            images=MultiModalFieldConfig.batched("image"),
+            embed_is_patch=MultiModalFieldConfig.batched("image"),
+            num_patches=MultiModalFieldConfig.batched("image"),
+        )
+
+    def _get_prompt_updates(
+        self,
+        mm_items: MultiModalDataItems,
+        hf_processor_mm_kwargs: Mapping[str, object],
+        out_mm_kwargs: MultiModalKwargs,
+    ) -> Sequence[PromptUpdate]:
+        processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
+
+        image_break_id = processor.image_break_id
+        image_token_id = processor.image_token_id
+        image_end_id = processor.image_end_id
+
+        def get_replacement(item_idx: int):
+            images = mm_items.get_items("image", ImageProcessorItems)
+            image_size = images.get_image_size(item_idx)
+
+            ncols, nrows = processor.image_processor._image_to_num_tokens(
+                Image.new("RGB", (image_size.width, image_size.height)))
+
+            tokens = ([image_token_id] * ncols + [image_break_id]) * nrows
+            tokens[-1] = image_end_id
+
+            return tokens
+
+        return [
+            PromptReplacement(
+                modality="image",
+                target="",  # Never match the prompt (see below note)
+                replacement=get_replacement,
+            ),
+        ]
+
+    def _cached_apply_hf_processor(
+        self,
+        prompt: Union[str, list[int]],
+        mm_data_items: MultiModalDataItems,
+        hf_processor_mm_kwargs: Mapping[str, object],
+    ) -> tuple[list[int], MultiModalKwargs, bool]:
+        prompt_ids, mm_kwargs, _ = super()._cached_apply_hf_processor(
+            prompt=prompt,
+            mm_data_items=mm_data_items,
+            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
+        )
+
+        # NOTE: The tokens are already inserted by the chat template
+        return prompt_ids, mm_kwargs, True
+
+
+@MULTIMODAL_REGISTRY.register_processor(PixtralMultiModalProcessor,
+                                        info=PixtralProcessingInfo,
+                                        dummy_inputs=PixtralDummyInputsBuilder)
 class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
                                       SupportsPP):
@@ -191,13 +343,6 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
             if key in dataclass_fields
         }

-        if not ("image_break_token_id" in vision_args
-                and "image_end_token_id" in vision_args):
-            raise ValueError(
-                "'image_break_token_id' and 'image_end_token_id' not found "
-                "in the vision_encoder arguments. Please download the latest "
-                "version of 'params.json' from the model repository.")
-
         self.vision_args = VisionEncoderArgs(**vision_args)

         # init MistralForCausalLM
@@ -221,36 +366,92 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
         return get_sampler()

+    def _parse_and_validate_image_input(
+            self, **kwargs: object) -> Optional[PixtralImagePixelInputs]:
+        images = kwargs.pop("images", None)
+        if images is None:
+            return None
+
+        if not isinstance(images, (torch.Tensor, list)):
+            raise ValueError("Incorrect type of images. "
+                             f"Got type: {type(images)}")
+
+        embed_is_patch = kwargs.pop("embed_is_patch")
+        if not isinstance(embed_is_patch, (torch.Tensor, list)):
+            raise ValueError("Incorrect type of embed_is_patch. "
+                             f"Got type: {type(embed_is_patch)}")
+
+        num_patches = kwargs.pop("num_patches")
+        if not isinstance(num_patches, (torch.Tensor, list)):
+            raise ValueError("Incorrect type of num_patches. "
+                             f"Got type: {type(num_patches)}")
+
+        return PixtralImagePixelInputs(
+            type="pixel_values",
+            images=flatten_bn(images),
+            embed_is_patch=embed_is_patch,
+            num_patches=num_patches,
+        )
+
+    def _process_image_input(
+        self,
+        image_input: PixtralImagePixelInputs,
+    ) -> tuple[torch.Tensor, ...]:
+        images = image_input["images"]
+
+        image_features = self.vision_encoder(images)
+        feature_sizes = [
+            image_feature.shape[0] for image_feature in image_features
+        ]
+
+        image_embeds = self.vision_language_adapter(torch.cat(image_features))
+        image_embeds = torch.split(image_embeds, feature_sizes)
+        return image_embeds
+
+    def _get_mm_embeds(
+        self,
+        features: torch.Tensor,  # Shape: (num_patch, d)
+        num_patches: torch.Tensor,  # Shape: (num_images,)
+        embed_is_patch: torch.Tensor,  # Shape: (num_images, num_embeds)
+    ) -> tuple[torch.Tensor, ...]:
+        """Scatter the patch features into a contiguous tensor that corresponds
+        to the embedding tokens defined by the multimodal processor.
+
+        Mostly copied from `Molmo._get_mm_embeds`. See following fixme comment.
+        """
+
+        # Insert columns of nan values according to `embed_is_patch`. This work
+        # ideally should be done in `_process_image_input`, but
+        # `_process_image_input` is used in both V0 and V1 path. It's safer to
+        # put the logic here.
+        # FIXME: Move this logic to `_process_image_input` when v0 is
+        # deprecated. Merge this function with `Molmo._get_mm_embeds`.
+        num_patches_per_image: list[int] = num_patches.tolist()
+
+        embeds_flat = features.new_full(
+            (sum(num_patches_per_image), *features.shape[1:]),
+            fill_value=torch.nan,
+        )
+        embeds_flat[embed_is_patch.view(-1)] = features
+
+        return embeds_flat.split(num_patches_per_image)
+
     def get_multimodal_embeddings(
             self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
-        image_input, image_tokens = self._parse_and_validate_image_input(
-            **kwargs)
+        image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
             return None

-        vision_embeddings = self._process_image_input(image_input)
+        image_features = self._process_image_input(image_input)

-        # NOTE: We patch the outputs of the vision encoder with embeddings
-        # from `[IMG_BREAK]` and `[IMG_END]` tokens.
-        image_embeds = self.language_model.get_input_embeddings(image_tokens)
-        image_token_mask = image_tokens == self.vision_args.image_token_id
-        image_embeds[image_token_mask] = vision_embeddings
+        if kwargs.get("v0_path", False):
+            return image_features

-        # NOTE: Image embeddings are split into separate tensors for each image
-        # by the indices of `[IMG_END]` token.
-        image_end_mask = image_tokens == self.vision_args.image_end_token_id
-        split_indices = torch.where(image_end_mask)[0] + 1
-        if len(split_indices) <= 1:
-            # Do not split, return as tensor of shape [1, fs, hs]
-            return image_embeds.unsqueeze(0)
-
-        # If the last split index is the last index in image_tokens, we
-        # ignore it to avoid empty split tensor
-        if split_indices[-1] == len(image_tokens):
-            split_indices = split_indices[:-1]
-
-        image_embeds = image_embeds.tensor_split(split_indices.cpu())
-
-        return image_embeds
+        return flatten_2d_lists(
+            self._get_mm_embeds(*args) for args in zip(
+                image_features,
+                image_input["num_patches"],
+                image_input["embed_is_patch"],
+            ))

     def get_input_embeddings(
         self,
@@ -259,12 +460,17 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
     ) -> torch.Tensor:
         inputs_embeds = self.language_model.get_input_embeddings(input_ids)
         if multimodal_embeddings is not None:
+            # Extract the patch tokens
+            patch_embeddings = json_map_leaves(
+                lambda x: x[~x.isnan()].view(-1, *x.shape[1:]),
+                cast(JSONTree[torch.Tensor], multimodal_embeddings),
+            )
+
             inputs_embeds = merge_multimodal_embeddings(
-                input_ids, inputs_embeds, multimodal_embeddings, [
-                    self.vision_args.image_token_id,
-                    self.vision_args.image_break_token_id,
-                    self.vision_args.image_end_token_id,
-                ])
+                input_ids,
+                inputs_embeds,
+                cast(NestedTensors, patch_embeddings),
+                self.vision_args.image_token_id,
+            )
         return inputs_embeds

     def forward(
@@ -275,14 +481,14 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
         inputs_embeds: Optional[torch.Tensor] = None,
         **kwargs: object,
     ) -> Union[torch.Tensor, IntermediateTensors]:
-        """Run forward pass for pixtral.
-        """
+        """Run forward pass for pixtral."""
         if intermediate_tensors is not None:
             inputs_embeds = None

         # NOTE: In v1, inputs_embeds is always generated at model runner, this
         # condition is for v0 compatibility.
         elif inputs_embeds is None:
+            kwargs.update({"v0_path": True})
             vision_embeddings = self.get_multimodal_embeddings(**kwargs)
             inputs_embeds = self.get_input_embeddings(input_ids,
                                                       vision_embeddings)
@@ -295,47 +501,6 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,

         return hidden_states

-    def _parse_and_validate_image_input(
-        self,
-        images: Optional[Union[List[List[torch.Tensor]], List[torch.Tensor],
-                               torch.Tensor]] = None,
-        image_tokens: Optional[torch.Tensor] = None,
-    ) -> Tuple[Optional[List[torch.Tensor]], Optional[torch.Tensor]]:
-        if images is None:
-            return None, None
-
-        if isinstance(images, torch.Tensor):
-            # if passed as batch take all images
-            N, B, C, W, H = images.shape
-            images = images.reshape(N * B, C, W, H)
-            images = [images[i] for i in range(images.size(0))]
-        elif isinstance(images, list):
-            # if passed as list flatten lists of tensors
-            flatten_images = []
-            for imgs_per_req in images:
-                imgs_per_req = [
-                    imgs_per_req[i] for i in range(imgs_per_req.size(0))
-                ] if isinstance(imgs_per_req, torch.Tensor) else imgs_per_req
-
-                flatten_images.extend(imgs_per_req)
-
-            images = flatten_images
-
-        if isinstance(image_tokens, torch.Tensor):
-            # image_tokens are batched
-            image_tokens = image_tokens.flatten()
-        elif isinstance(image_tokens, list):
-            # image_tokens are of different lengths thus passed as a list
-            image_tokens = torch.cat(image_tokens)
-
-        assert image_tokens.dim() == 1
-
-        return images, image_tokens
-
-    def _process_image_input(self,
-                             image_input: List[torch.Tensor]) -> torch.Tensor:
-        return self.vision_language_adapter(self.vision_encoder(image_input))
-
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
@@ -400,8 +565,6 @@ class VisionEncoderArgs:
     num_attention_heads: int
     rope_theta: float  # for rope-2D
     image_token_id: int
-    image_break_token_id: int
-    image_end_token_id: int
     adapter_bias: bool = True
@@ -637,9 +800,13 @@ class VisionTransformer(nn.Module):
             self.patch_conv(img.unsqueeze(0).to(self.dtype)) for img in images
         ]

+        patch_embeds = [
+            p.flatten(2).permute(0, 2, 1) for p in patch_embeds_list
+        ]
+        embed_sizes = [p.shape[1] for p in patch_embeds]
+
         # flatten to a single sequence
-        patch_embeds = torch.cat(
-            [p.flatten(2).permute(0, 2, 1) for p in patch_embeds_list], dim=1)
+        patch_embeds = torch.cat(patch_embeds, dim=1)
         patch_embeds = self.ln_pre(patch_embeds)

         # positional embeddings
@@ -655,8 +822,8 @@ class VisionTransformer(nn.Module):
                             "with the Mistral format")
         out = self.transformer(patch_embeds, mask=mask, freqs_cis=freqs_cis)

-        # remove batch dimension of the single sequence
-        return out.squeeze(0)
+        # squeeze dim 0 and split into separate tensors for each image
+        return torch.split(out.squeeze(0), embed_sizes)


 class VisionLanguageAdapter(nn.Module):
@@ -978,9 +1145,9 @@ class PixtralHFVisionModel(nn.Module):

     def forward(
         self,
-        pixel_values: List[torch.Tensor],
+        pixel_values: list[torch.Tensor],
         feature_sample_layers: Optional[list[int]] = None,
-    ) -> torch.Tensor:
+    ) -> tuple[torch.Tensor, ...]:
         """
         Args:
             pixel_values: Each image to be processed will be a separate tensor
@@ -1039,8 +1206,7 @@ class PixtralHFVisionModel(nn.Module):
                                               self.config.num_hidden_layers)

         # squeeze dim 0 and split into separate tensors for each image
-        out = torch.split(torch.squeeze(out), embed_sizes)
-        return out
+        return torch.split(out.squeeze(0), embed_sizes)

 # (TODO) Add prefix argument for filtering out weights to be loaded
 # ref: https://github.com/vllm-project/vllm/pull/7186#discussion_r1734163986
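
The prompt-update logic above builds the placeholder sequence for each image row by row: `ncols` image tokens followed by an image-break token, repeated `nrows` times, with the final break replaced by the image-end token, so an image always contributes `(ncols + 1) * nrows` tokens. A toy illustration with hypothetical token ids (10 = `[IMG]`, 12 = `[IMG_BREAK]`, 13 = `[IMG_END]`):

    IMG, IMG_BREAK, IMG_END = 10, 12, 13  # hypothetical ids, for illustration only

    def pixtral_image_tokens(ncols: int, nrows: int) -> list[int]:
        # Same construction as get_replacement() in the merged processor.
        tokens = ([IMG] * ncols + [IMG_BREAK]) * nrows
        tokens[-1] = IMG_END
        return tokens

    # 2 rows x 3 patch columns -> (3 + 1) * 2 = 8 placeholder tokens.
    assert pixtral_image_tokens(3, 2) == [10, 10, 10, 12, 10, 10, 10, 13]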


@@ -77,7 +77,9 @@ class PromptIndexTargets:
         else:
             if isinstance(prefix, str):
                 # Make both `list[int]`
-                prefix = encode_tokens(tokenizer, prefix)
+                prefix = encode_tokens(tokenizer,
+                                       prefix,
+                                       add_special_tokens=False)

             match_idx = len(prefix)
             return match_idx if prompt[:match_idx] == prefix else None
@@ -318,7 +320,7 @@ def _cached_encode(
     tokenizer: AnyTokenizer,
     text: str,
     *,
-    add_special_tokens: bool = False,
+    add_special_tokens: Optional[bool] = None,
 ) -> list[int]:
     return encode_tokens(tokenizer,
                          text,
@@ -330,7 +332,7 @@ def _cached_decode(
     tokenizer: AnyTokenizer,
     token_ids: tuple[int, ...],
     *,
-    skip_special_tokens: bool = False,
+    skip_special_tokens: Optional[bool] = None,
 ) -> str:
     return decode_tokens(tokenizer,
                          list(token_ids),
@@ -395,7 +397,9 @@ class _BoundPromptSequence:
     def token_ids(self) -> list[int]:
         if self._token_ids is None:
             assert self._text is not None
-            self._token_ids = _cached_encode(self.tokenizer, self._text)
+            self._token_ids = _cached_encode(self.tokenizer,
+                                             self._text,
+                                             add_special_tokens=False)

         return self._token_ids
@@ -1046,7 +1050,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
         mm_items: MultiModalDataItems,
         hf_processor_mm_kwargs: Mapping[str, object],
         out_mm_kwargs: MultiModalKwargs,
-    ) -> list[PromptUpdate]:
+    ) -> Sequence[PromptUpdate]:
         """
         Given the original multi-modal items for this modality
         and HF-processed data, output the updates to perform.


@@ -34,13 +34,20 @@ def decode_tokens(
     tokenizer: AnyTokenizer,
     token_ids: list[int],
     *,
-    skip_special_tokens: bool = False,
+    skip_special_tokens: Optional[bool] = None,
 ) -> str:
     """
     Backend-agnostic equivalent of HF's
-    :code:`tokenizer.decode(token_ids, skip_special_tokens=...)`.
+    :code:`tokenizer.decode(token_ids, ...)`.
+
+    :code:`skip_special_tokens=None` means to use the backend's default
+    settings.
     """
-    return tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
+    if skip_special_tokens is not None:
+        return tokenizer.decode(token_ids,
+                                skip_special_tokens=skip_special_tokens)
+
+    return tokenizer.decode(token_ids)


 def encode_tokens(
@@ -51,10 +58,14 @@ def encode_tokens(
 ) -> list[int]:
     """
     Backend-agnostic equivalent of HF's
-    :code:`tokenizer.encode(text, add_special_tokens=...)`.
+    :code:`tokenizer.encode(text, ...)`.
+
+    :code:`add_special_tokens=None` means to use the backend's default
+    settings.
     """
     if add_special_tokens is not None:
         return tokenizer.encode(text, add_special_tokens=add_special_tokens)

     return tokenizer.encode(text)


@@ -845,7 +845,7 @@ def is_list_of(
     assert_never(check)


-def flatten_2d_lists(lists: list[list[T]]) -> list[T]:
+def flatten_2d_lists(lists: Iterable[Iterable[T]]) -> list[T]:
     """Flatten a list of lists to a single list."""
     return [item for sublist in lists for item in sublist]
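
The annotation change is typing-only, but it reflects how the updated call sites above now use the helper: they pass a generator expression instead of building a nested list first. A quick usage sketch of the relaxed contract:

    from typing import Iterable, TypeVar

    T = TypeVar("T")

    def flatten_2d_lists(lists: Iterable[Iterable[T]]) -> list[T]:
        """Flatten an iterable of iterables into a single list."""
        return [item for sublist in lists for item in sublist]

    # Nested lists and generator expressions (as used in the updated
    # get_multimodal_embeddings() implementations) both flatten the same way.
    assert flatten_2d_lists([[1, 2], [3]]) == [1, 2, 3]
    assert flatten_2d_lists(([i, i + 1] for i in (0, 2))) == [0, 1, 2, 3]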