[VLM] Merged multi-modal processor for Pixtral (#12211)

Signed-off-by: remi <remi@mistral.ai>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
Rémi Delacourt 2025-03-15 14:28:27 +01:00 committed by GitHub
parent 74bc397b0a
commit 61c6a5a796
9 changed files with 620 additions and 358 deletions


@@ -43,12 +43,18 @@ from vllm.sampling_params import SamplingParams
# python demo.py advanced
def run_simple_demo():
def run_simple_demo(args: argparse.Namespace):
model_name = "mistralai/Pixtral-12B-2409"
sampling_params = SamplingParams(max_tokens=8192)
# Lower max_num_seqs or max_model_len on low-VRAM GPUs.
llm = LLM(model=model_name, tokenizer_mode="mistral")
# Lower max_model_len and/or max_num_seqs on low-VRAM GPUs.
llm = LLM(
model=model_name,
tokenizer_mode="mistral",
max_model_len=4096,
max_num_seqs=2,
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
)
prompt = "Describe this image in one sentence."
image_url = "https://picsum.photos/id/237/200/300"
@@ -76,7 +82,7 @@ def run_simple_demo():
print(outputs[0].outputs[0].text)
def run_advanced_demo():
def run_advanced_demo(args: argparse.Namespace):
model_name = "mistralai/Pixtral-12B-2409"
max_img_per_msg = 5
max_tokens_per_img = 4096
@@ -87,6 +93,7 @@ def run_advanced_demo():
tokenizer_mode="mistral",
limit_mm_per_prompt={"image": max_img_per_msg},
max_model_len=max_img_per_msg * max_tokens_per_img,
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
)
prompt = "Describe the following image."
@@ -153,14 +160,19 @@ def main():
help="Specify the demo mode: 'simple' or 'advanced'",
)
parser.add_argument(
'--disable-mm-preprocessor-cache',
action='store_true',
help='If set, disables caching of the multi-modal preprocessor/mapper.')
args = parser.parse_args()
if args.mode == "simple":
print("Running simple demo...")
run_simple_demo()
run_simple_demo(args)
elif args.mode == "advanced":
print("Running advanced demo...")
run_advanced_demo()
run_advanced_demo(args)
if __name__ == "__main__":
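For reference, a minimal usage sketch of the refactored entry points, which now take the parsed arguments; the programmatic invocation below is an assumption, not part of the demo:

import argparse

args = argparse.Namespace(
    mode="simple",
    disable_mm_preprocessor_cache=True,  # flag introduced by this change
)
run_simple_demo(args)  # assumes run_simple_demo is imported from this demo module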


@@ -2,17 +2,23 @@
import copy
from functools import partial
from typing import Optional
from typing import Optional, Union
import numpy as np
import pytest
from mistral_common.protocol.instruct.messages import (ImageChunk, TextChunk,
UserMessage)
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from PIL import Image
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from vllm.config import ModelConfig
from vllm.inputs import InputProcessingContext
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.processing import ProcessingCache
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict
from vllm.multimodal.inputs import MultiModalInputs
from vllm.multimodal.processing import BaseMultiModalProcessor, ProcessingCache
from vllm.transformers_utils.tokenizer import (MistralTokenizer,
cached_tokenizer_from_config)
from ....multimodal.utils import random_audio, random_image, random_video
from ...registry import HF_EXAMPLE_MODELS
@@ -85,14 +91,6 @@ def _test_processing_correctness(
partial(random_audio, rng, min_len=512, max_len=1024, sr=16000),
}
tokenizer_encode_kwargs = {}
if model_config.hf_config.model_type in ("mllama", "whisper", "ultravox"):
# For some multimodal models, the tokenizer always adds bos_token
# at the beginning of the prompt by default, causing hf_processor to output
# incorrect token ids. So we need to use `add_special_tokens=False` here
# to leave the bos_token to be added by the processor.
tokenizer_encode_kwargs = {"add_special_tokens": False}
for batch_idx in range(num_batches):
mm_data = {
k:
@@ -115,43 +113,131 @@ def _test_processing_correctness(
elif len(mm_data[k]) == 1:
mm_data[k] = mm_data[k][0]
baseline_result = baseline_processor.apply(
prompt,
mm_data=mm_data,
hf_processor_mm_kwargs={},
)
cached_result = cached_processor.apply(
prompt,
mm_data=mm_data,
hf_processor_mm_kwargs={},
)
if isinstance(tokenizer, MistralTokenizer):
_test_processing_correctness_mistral(
model_config,
tokenizer,
prompt,
mm_data,
baseline_processor,
cached_processor,
batch_idx,
ignore_mm_keys=ignore_mm_keys,
)
else:
_test_processing_correctness_hf(
model_config,
tokenizer,
prompt,
mm_data,
baseline_processor,
cached_processor,
batch_idx,
ignore_mm_keys=ignore_mm_keys,
)
assert _drop_mm_kwargs_keys(
baseline_result, ignore_mm_keys) == _drop_mm_kwargs_keys(
cached_result, ignore_mm_keys), (
f"Failed ({batch_idx=}, {prompt=}, {mm_data=})")
baseline_tokenized_result = baseline_processor.apply(
tokenizer.encode(prompt, **tokenizer_encode_kwargs),
mm_data=mm_data,
hf_processor_mm_kwargs={},
)
def _test_processing_correctness_hf(
model_config: ModelConfig,
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
prompt: str,
mm_data: MultiModalDataDict,
baseline_processor: BaseMultiModalProcessor,
cached_processor: BaseMultiModalProcessor,
batch_idx: int,
ignore_mm_keys: Optional[list[str]] = None,
):
if model_config.hf_config.model_type in ("mllama", "whisper", "ultravox"):
# For some multimodal models, the tokenizer always adds bos_token
# at the beginning of the prompt by default, causing hf_processor to output
# incorrect token ids. So we need to use `add_special_tokens=False` here
# to leave the bos_token to be added by the processor.
token_prompt = tokenizer.encode(prompt, add_special_tokens=False)
else:
token_prompt = tokenizer.encode(prompt)
assert _drop_mm_kwargs_keys(
baseline_result, ignore_mm_keys) == _drop_mm_kwargs_keys(
baseline_tokenized_result, ignore_mm_keys), (
f"Failed ({batch_idx=}, {prompt=}, {mm_data=})")
baseline_result = baseline_processor.apply(
prompt,
mm_data=mm_data,
hf_processor_mm_kwargs={},
)
cached_result = cached_processor.apply(
prompt,
mm_data=mm_data,
hf_processor_mm_kwargs={},
)
cached_tokenized_result = cached_processor.apply(
tokenizer.encode(prompt, **tokenizer_encode_kwargs),
mm_data=mm_data,
hf_processor_mm_kwargs={},
)
assert _inputs_equal(
baseline_result,
cached_result,
ignore_mm_keys,
), f"Failed ({batch_idx=}, {prompt=}, {mm_data=})"
assert _drop_mm_kwargs_keys(
cached_result, ignore_mm_keys) == _drop_mm_kwargs_keys(
cached_tokenized_result, ignore_mm_keys), (
f"Failed ({batch_idx=}, {prompt=}, {mm_data=})")
baseline_tokenized_result = baseline_processor.apply(
token_prompt,
mm_data=mm_data,
hf_processor_mm_kwargs={},
)
assert _inputs_equal(
baseline_result,
baseline_tokenized_result,
ignore_mm_keys,
), f"Failed ({batch_idx=}, {prompt=}, {mm_data=})"
cached_tokenized_result = cached_processor.apply(
token_prompt,
mm_data=mm_data,
hf_processor_mm_kwargs={},
)
assert _inputs_equal(
cached_result,
cached_tokenized_result,
ignore_mm_keys,
), f"Failed ({batch_idx=}, {prompt=}, {mm_data=})"
def _test_processing_correctness_mistral(
model_config: ModelConfig,
tokenizer: MistralTokenizer,
prompt: str,
mm_data: MultiModalDataDict,
baseline_processor: BaseMultiModalProcessor,
cached_processor: BaseMultiModalProcessor,
batch_idx: int,
ignore_mm_keys: Optional[list[str]] = None,
):
images = mm_data.get("image", [])
if not isinstance(images, list):
images = [images]
request = ChatCompletionRequest(messages=[
UserMessage(content=[
TextChunk(text=prompt),
*(ImageChunk(image=image) for image in images),
]),
])
res = tokenizer.mistral.encode_chat_completion(request)
token_prompt = res.tokens
# The Mistral chat encoder outputs token ids directly, rather than a text prompt
baseline_tokenized_result = baseline_processor.apply(
token_prompt,
mm_data=mm_data,
hf_processor_mm_kwargs={},
)
cached_tokenized_result = cached_processor.apply(
token_prompt,
mm_data=mm_data,
hf_processor_mm_kwargs={},
)
assert _inputs_equal(
baseline_tokenized_result,
cached_tokenized_result,
ignore_mm_keys,
), f"Failed ({batch_idx=}, {prompt=}, {mm_data=})"
# yapf: disable
@@ -173,6 +259,7 @@ def _test_processing_correctness(
"llava-hf/llava-onevision-qwen2-0.5b-ov-hf",
"meta-llama/Llama-3.2-11B-Vision-Instruct",
"TIGER-Lab/Mantis-8B-siglip-llama3",
"mistralai/Pixtral-12B-2409",
"mistral-community/pixtral-12b",
"openbmb/MiniCPM-o-2_6",
"openbmb/MiniCPM-V-2_6",
@@ -241,8 +328,19 @@ def test_processing_correctness_phi3v(
)
def _drop_mm_kwargs_keys(result: dict,
ignore_mm_keys: Optional[list[str]] = None) -> dict:
def _inputs_equal(
a: MultiModalInputs,
b: MultiModalInputs,
ignore_mm_keys: Optional[list[str]] = None,
):
return _drop_mm_kwargs_keys(a, ignore_mm_keys) == _drop_mm_kwargs_keys(
b, ignore_mm_keys)
def _drop_mm_kwargs_keys(
result: MultiModalInputs,
ignore_mm_keys: Optional[list[str]] = None,
) -> MultiModalInputs:
"""Drop specified keys from result['mm_kwargs'].
This is mainly to avoid doing exact match of audio_features in ultravox.
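The helper's body is cut off by this hunk; the following is a rough, hypothetical sketch of the intended filtering (the name `_drop_mm_kwargs_keys_sketch` and its body are assumptions, not the real implementation):

def _drop_mm_kwargs_keys_sketch(result, ignore_mm_keys=None):
    # Hypothetical sketch only: return a copy with the ignored keys removed
    # from result["mm_kwargs"] so they are excluded from the equality check.
    if not ignore_mm_keys:
        return result
    filtered = dict(result)
    filtered["mm_kwargs"] = {
        k: v
        for k, v in result["mm_kwargs"].items() if k not in ignore_mm_keys
    }
    return filtered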


@@ -68,23 +68,15 @@ class PixtralHFImagePixelInputs(TypedDict):
in which case the data is passed as a list instead of a batched tensor.
"""
feat_is_patch: Union[torch.Tensor, list[torch.Tensor]]
"""
A boolean mask indicating which image features correspond
to patch tokens.
Shape: `(batch_size, num_crops, num_patch)`
"""
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
"""
A boolean mask indicating which image embeddings correspond
to patch tokens.
Shape: `(batch_size, num_embeds)`
Shape: `(batch_size, num_images, num_embeds)`
"""
num_crops: Union[torch.Tensor, list[torch.Tensor]]
num_patches: Union[torch.Tensor, list[torch.Tensor]]
"""Shape: `(batch_size, num_images)`"""
@@ -360,16 +352,16 @@ class PixtralHFMultiModalProcessor(
image_height=pixel_value.shape[-2],
) for pixel_value in processed_outputs["pixel_values"]
]
num_crops = torch.tensor([(ncols + 1) * nrows
for ncols, nrows in tile_sizes])
num_patches = torch.tensor([(ncols + 1) * nrows
for ncols, nrows in tile_sizes])
# Each image may result in masks of different sizes, so we need to
# flatten the list and later use `num_crops` to get per-image masks.
embed_is_patch = torch.tensor(
flatten_2d_lists([([True] * ncols + [False]) * nrows
for ncols, nrows in tile_sizes]))
processed_outputs["num_crops"] = num_crops
# later use `num_patches` to get per-image masks.
embed_is_patch = [
torch.tensor(([True] * ncols + [False]) * nrows)
for ncols, nrows in tile_sizes
]
processed_outputs["num_patches"] = num_patches
processed_outputs["embed_is_patch"] = embed_is_patch
processed_outputs["feat_is_patch"] = embed_is_patch
return processed_outputs
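As a concrete illustration of the per-image mask built above (the grid size is made up):

# A 2-row x 3-column tile grid contributes 3 patch positions per row followed
# by one non-patch position (the [IMG_BREAK] slot; the final one becomes
# [IMG_END] in the prompt).
ncols, nrows = 3, 2
mask = ([True] * ncols + [False]) * nrows
# mask == [True, True, True, False, True, True, True, False]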
@@ -378,14 +370,10 @@ class PixtralHFMultiModalProcessor(
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
num_crops = hf_inputs.get("num_crops", torch.empty(0)).view(-1)
return dict(
feat_is_patch=MultiModalFieldConfig.flat_from_sizes(
"image", num_crops),
embed_is_patch=MultiModalFieldConfig.flat_from_sizes(
"image", num_crops),
num_crops=MultiModalFieldConfig.batched("image"),
pixel_values=MultiModalFieldConfig.batched("image"),
num_patches=MultiModalFieldConfig.batched("image"),
embed_is_patch=MultiModalFieldConfig.batched("image"),
image_embeds=MultiModalFieldConfig.batched("image"),
)
@@ -628,27 +616,21 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
f"Got type: {type(pixel_values)}")
if self.config.vision_config.model_type == "pixtral":
feat_is_patch = kwargs.pop("feat_is_patch")
if not isinstance(feat_is_patch, (torch.Tensor, list)):
raise ValueError("Incorrect type of feat_is_patch. "
f"Got type: {type(feat_is_patch)}")
embed_is_patch = kwargs.pop("embed_is_patch")
if not isinstance(embed_is_patch, (torch.Tensor, list)):
raise ValueError("Incorrect type of embed_is_patch. "
f"Got type: {type(embed_is_patch)}")
num_crops = kwargs.pop("num_crops")
if not isinstance(num_crops, (torch.Tensor, list)):
raise ValueError("Incorrect type of num_crops. "
f"Got type: {type(num_crops)}")
num_patches = kwargs.pop("num_patches")
if not isinstance(num_patches, (torch.Tensor, list)):
raise ValueError("Incorrect type of num_patches. "
f"Got type: {type(num_patches)}")
return PixtralHFImagePixelInputs(
type="pixel_values_pixtral",
pixel_values=flatten_bn(pixel_values),
feat_is_patch=feat_is_patch,
embed_is_patch=embed_is_patch,
num_crops=num_crops,
num_patches=num_patches,
)
return LlavaImagePixelInputs(
@@ -687,21 +669,26 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
vision_tower: Union[CLIPVisionModel, SiglipVisionModel,
PixtralHFVisionModel],
pixel_values: Union[torch.Tensor, list[torch.Tensor]],
) -> torch.Tensor:
) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
# NOTE: we skip the step to select the vision feature layer since
# this is already done inside the vision tower
image_features = vision_tower(pixel_values)
return self._select_image_features(
image_features,
strategy=self.config.vision_feature_select_strategy,
def select_features(leaf: torch.Tensor):
return self._select_image_features(
leaf,
strategy=self.config.vision_feature_select_strategy,
)
return cast(
Union[torch.Tensor, tuple[torch.Tensor, ...]],
json_map_leaves(select_features, image_features),
)
def _process_image_pixels(
self,
inputs: Union[LlavaImagePixelInputs, PixtralHFImagePixelInputs],
) -> torch.Tensor:
) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
assert self.vision_tower is not None
pixel_values = inputs["pixel_values"]
@@ -731,45 +718,30 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
def _get_mm_embeds(
self,
features: torch.Tensor, # Shape: (num_crop, num_patch, d)
feat_is_patch: torch.Tensor, # Shape: (num_crop, num_patch)
num_crops: torch.Tensor, # Shape: (num_images,)
embed_is_patch: torch.Tensor, # Shape: (num_embeds,)
) -> list[torch.Tensor]:
features: torch.Tensor, # Shape: (num_patch, d)
num_patches: torch.Tensor, # Shape: (num_images,)
embed_is_patch: torch.Tensor, # Shape: (num_images, num_embeds)
) -> tuple[torch.Tensor, ...]:
"""Scatter the patch features into a contiguous tensor that corresponds
to the embedding tokens defined by the multimodal processor.
Mostly copied from `Molmo._get_mm_embeds`; see the FIXME comment below.
"""
# Insert columns of nan values according to `feat_is_patch`. This work
# Insert columns of nan values according to `embed_is_patch`. This work
# should ideally be done in `_process_image_input`, but
# `_process_image_input` is used in both the V0 and V1 paths. It's safer to
# put the logic here.
# FIXME: Move this logic to `_process_image_input` when v0 is
# deprecated. Merge this function with `Molmo._get_mm_embeds`.
feat_is_patch = feat_is_patch.view(-1)
embed_is_patch = embed_is_patch.view(-1)
expanded_embedding = torch.full(
(sum(num_crops), *features.shape[1:]),
torch.nan,
dtype=features.dtype).to(features.device)
expanded_embedding[feat_is_patch] = features
num_patches_per_image: list[int] = num_patches.tolist()
num_crops_per_image = num_crops.tolist()
feats_per_image = expanded_embedding.split(num_crops_per_image)
f_is_patch_per_image = feat_is_patch.split(num_crops_per_image)
embeds_flat = features.new_full(
(sum(num_patches_per_image), *features.shape[1:]),
fill_value=torch.nan,
)
embeds_flat[embed_is_patch.view(-1)] = features
embed_dim = expanded_embedding.shape[-1]
num_embeds = embed_is_patch.shape[0]
embeds_in_batch = list[torch.Tensor]()
for feats, f_is_patch in zip(feats_per_image, f_is_patch_per_image):
embeds = feats.new_full((num_embeds, embed_dim), torch.nan)
embeds[embed_is_patch] = feats[f_is_patch]
embeds_in_batch.append(embeds)
return embeds_in_batch
return embeds_flat.split(num_patches_per_image)
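A small worked example of the scatter above, with made-up sizes:

import torch

# 2 images, each a 1x2 tile grid: 2 patch features per image and 3 embedding
# slots per image (the extra slot is the [IMG_BREAK]/[IMG_END] position).
features = torch.arange(8, dtype=torch.float32).view(4, 2)  # (num_patch, d)
num_patches = torch.tensor([3, 3])                          # embedding slots per image
embed_is_patch = torch.tensor([[True, True, False],
                               [True, True, False]])

embeds_flat = features.new_full((int(num_patches.sum()), 2), torch.nan)
embeds_flat[embed_is_patch.view(-1)] = features
per_image = embeds_flat.split(num_patches.tolist())
# per_image[0]: two feature rows followed by one row of NaNs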
def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
@@ -784,12 +756,12 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
# The path is used for pixtral (V0 only) and llava (V0/V1)
return vision_embeddings
nested_emb = [
return flatten_2d_lists(
self._get_mm_embeds(*args) for args in zip(
vision_embeddings, image_input["feat_is_patch"],
image_input["num_crops"], image_input["embed_is_patch"])
]
return flatten_2d_lists(nested_emb)
vision_embeddings,
image_input["num_patches"],
image_input["embed_is_patch"],
))
def get_input_embeddings(
self,
@@ -805,9 +777,11 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
)
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, cast(NestedTensors,
patch_embeddings),
self.config.image_token_index)
input_ids,
inputs_embeds,
cast(NestedTensors, patch_embeddings),
self.config.image_token_index,
)
return inputs_embeds
def forward(


@@ -1585,15 +1585,13 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
image_features = self._process_image_input(image_input)
nested_embeds = [
return flatten_2d_lists(
self._get_mm_embeds(*args) for args in zip(
image_features,
image_input["feat_is_patch"],
image_input["num_crops"],
image_input["embed_is_patch"],
)
]
return flatten_2d_lists(nested_embeds)
))
def get_input_embeddings(
self,


@@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
from typing import (Iterable, Literal, Mapping, Optional, Set, Tuple,
TypedDict, Union)
from collections.abc import Iterable, Mapping, Sequence
from typing import Literal, Optional, Set, Tuple, TypedDict, Union
import torch
from torch import nn
@@ -17,7 +16,7 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
from vllm.multimodal.parse import MultiModalDataItems
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptIndexTargets,
PromptInsertion, PromptReplacement,
PromptInsertion, PromptUpdate,
PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
@@ -144,7 +143,7 @@ class PaliGemmaMultiModalProcessor(
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
) -> Sequence[PromptUpdate]:
hf_config = self.info.get_hf_config()
image_token_id = hf_config.image_token_index


@@ -1,26 +1,28 @@
# SPDX-License-Identifier: Apache-2.0
import math
from collections.abc import Iterable, Mapping
from collections.abc import Iterable, Mapping, Sequence
from dataclasses import dataclass, fields
from functools import cached_property
from typing import List, Optional, Set, Tuple, Union
from typing import List, Literal, Optional, Set, Tuple, TypedDict, Union, cast
import torch
import torch.nn as nn
import torch.nn.functional as F
from mistral_common.protocol.instruct.messages import ImageChunk
from mistral_common.tokens.tokenizers.multimodal import ImageEncoder
from PIL import Image
from transformers import PixtralVisionConfig
from transformers import PixtralVisionConfig, TensorType
from transformers.image_utils import ImageInput
from transformers.models.pixtral.image_processing_pixtral import (
_num_image_tokens as _get_pixtral_hf_num_image_tokens)
from transformers.models.pixtral.modeling_pixtral import (
PixtralRotaryEmbedding, apply_rotary_pos_emb, position_ids_in_meshgrid)
from transformers.tokenization_utils_base import TextInput
from vllm.config import VllmConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext, token_inputs)
from vllm.jsontree import JSONTree, json_map_leaves
from vllm.model_executor.layers.activation import get_act_and_mul_fn
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
@@ -31,13 +33,20 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.multimodal.inputs import PlaceholderRange
from vllm.multimodal.utils import consecutive_placeholder_ranges
from vllm.sequence import IntermediateTensors, SequenceData
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
from vllm.multimodal.inputs import MultiModalFieldConfig, NestedTensors
from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement,
PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.tokenizer import (MistralTokenizer,
cached_tokenizer_from_config)
from vllm.utils import flatten_2d_lists
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .utils import (init_vllm_registered_model, maybe_prefix,
from .utils import (flatten_bn, init_vllm_registered_model, maybe_prefix,
merge_multimodal_embeddings)
from .vision import VisionEncoderInfo, resolve_visual_encoder_outputs
@@ -48,132 +57,275 @@ except ImportError:
USE_XFORMERS_OPS = False
def get_max_pixtral_image_tokens(ctx: InputContext):
tokenizer = cached_tokenizer_from_config(ctx.model_config)
mm_encoder = tokenizer.instruct.mm_encoder
class PixtralImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
image_config = mm_encoder.mm_config if hasattr(
mm_encoder, "mm_config") else mm_encoder.image_config
max_image_size = image_config.max_image_size
image_patch_size = image_config.image_patch_size
return ((max_image_size // image_patch_size)**2)
def dummy_data_for_pixtral(ctx: InputContext, seq_len: int,
mm_counts: Mapping[str, int]):
tokenizer = cached_tokenizer_from_config(ctx.model_config)
mm_encoder = tokenizer.mistral.instruct_tokenizer.mm_encoder
image_token_id = mm_encoder.special_ids.img
mm_config = ctx.get_mm_config()
num_images = mm_config.get_limit_per_prompt("image")
# dummy size
size = 256
image = Image.new("RGB", (size, size), color=0)
encoding = tokenizer.instruct.mm_encoder(ImageChunk(image=image))
image_feature_size = len(encoding.tokens)
num_image_tokens = image_feature_size * num_images
seq_data = SequenceData.from_prompt_token_counts(
(image_token_id, num_image_tokens),
(0, seq_len - num_image_tokens),
)
mm_data = {"image": num_images * [image]}
mm_placeholders = {
"image":
consecutive_placeholder_ranges(num_items=num_images,
item_size=image_feature_size)
}
return DummyData(seq_data, mm_data, mm_placeholders)
def input_mapper_for_pixtral(ctx: InputContext,
data: object) -> MultiModalKwargs:
"""Maps the input data to its MultiModalKwargs (if any).
Args:
ctx: Context of the loaded model.
data: data potentially containing PIL images to be processed
and mapped to `images`.
Returns:
MultiModalKwargs containing the stacked normalized images tensor or
image embeddings.
images: Union[torch.Tensor, list[torch.Tensor]]
"""
tokenizer = cached_tokenizer_from_config(ctx.model_config)
Shape: `(batch_size * num_images, num_channels, image_width, image_height)`
data_list = data if isinstance(data, list) else [data]
The result of stacking :attr:`ImageEncoding.tokens` from each prompt.
"""
images = []
image_tokens_list = []
for image_data in data_list:
image = ImageChunk(image=image_data)
encoding = tokenizer.instruct.mm_encoder(image)
image = torch.from_numpy(encoding.image).to(dtype=torch.float16)
images.append(image)
image_tokens_list.append(encoding.tokens)
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
"""
A boolean mask indicating which image embeddings correspond
to patch tokens.
Shape: `(batch_size, num_images, num_embeds)`
"""
image_tokens = torch.tensor([
token_id for image_tokens in image_tokens_list
for token_id in image_tokens
])
return MultiModalKwargs({"images": images, "image_tokens": image_tokens})
num_patches: Union[torch.Tensor, list[torch.Tensor]]
"""Shape: `(batch_size, num_images)`"""
def input_processor_for_pixtral(ctx: InputContext, inputs: DecoderOnlyInputs):
multi_modal_data = inputs.get("multi_modal_data")
if multi_modal_data is None or "image" not in multi_modal_data:
return inputs
class PixtralProcessorAdapter:
"""
Provide a HF-compatible interface for
:class:`mistral_common.tokens.tokenizers.multimodal.ImageEncoder`.
"""
prompt_token_ids = inputs.get("prompt_token_ids")
prompt = inputs.get("prompt")
tokenizer = cached_tokenizer_from_config(ctx.model_config)
def __init__(self, tokenizer: MistralTokenizer) -> None:
super().__init__()
mm_encoder = tokenizer.mistral.instruct_tokenizer.mm_encoder
image_token_id = mm_encoder.special_ids.img
image_break_id = mm_encoder.special_ids.img_break
image_end_id = mm_encoder.special_ids.img_end
self.tokenizer = tokenizer
if image_token_id not in inputs['prompt_token_ids']:
raise ValueError(
f"You've passed {inputs=} without {image_token_id=}"
" Make sure to process your input via mistral_common's"
" tokenizer or pass a chat completion request. For more"
" For more info, see: "
"https://github.com/vllm-project/vllm/issues/8411.")
@property
def image_processor(self) -> ImageEncoder:
image_encoder = self.tokenizer.instruct.mm_encoder
assert isinstance(image_encoder, ImageEncoder)
return image_encoder
# Get precise tracking of placeholder positions
placeholder_ranges = []
curr_offset = -1
curr_length = 0
for i in range(len(prompt_token_ids)):
if prompt_token_ids[i] in (image_token_id, image_break_id):
if curr_offset < 0:
curr_offset = i
curr_length += 1
elif prompt_token_ids[i] == image_end_id:
curr_length += 1
placeholder_ranges.append(
PlaceholderRange(offset=curr_offset, length=curr_length))
curr_offset = -1
curr_length = 0
else:
pass
return token_inputs(prompt=prompt,
prompt_token_ids=prompt_token_ids,
multi_modal_data=multi_modal_data,
multi_modal_placeholders={"image": placeholder_ranges})
@cached_property
def image_break_id(self) -> int:
return self.image_processor.special_ids.img_break
@cached_property
def image_token_id(self) -> int:
return self.image_processor.special_ids.img
@cached_property
def image_end_id(self) -> int:
return self.image_processor.special_ids.img_end
@cached_property
def image_size(self) -> int:
return self.image_processor.mm_config.max_image_size
@cached_property
def patch_size(self) -> int:
return self.image_processor.mm_config.image_patch_size
def __call__(
self,
text: Optional[Union[TextInput, list[TextInput]]] = None,
images: Optional[Union[ImageInput, list[ImageInput]]] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
**kwargs,
) -> Mapping[str, NestedTensors]:
if text is None:
text = []
if not isinstance(text, list):
text = [text]
if images is None:
images = []
if not isinstance(images, list):
images = [images]
if not images:
input_ids = self.tokenizer(text).input_ids
return {"input_ids": torch.tensor(input_ids)}
# Allow dummy text, which is used for profiling as well as token inputs
if any(len(t) > 0 for t in text):
raise ValueError(
"You've passed text inputs instead of token inputs. "
"Make sure to process your input via `mistral_common`'s "
"tokenizer or pass a chat completion request. "
"For more info, see: "
"https://github.com/vllm-project/vllm/issues/8411.")
image_token_id = self.image_token_id
images_processed = list[torch.Tensor]()
images_tokens = list[torch.Tensor]()
images_embed_is_patch = list[torch.Tensor]()
images_num_patches = list[int]()
for image in images:
image_inputs = self.image_processor(ImageChunk(image=image))
image_processed = torch.tensor(image_inputs.image)
image_tokens = torch.tensor(image_inputs.tokens)
images_processed.append(image_processed)
images_tokens.append(image_tokens)
images_embed_is_patch.append(image_tokens == image_token_id)
images_num_patches.append(len(image_tokens))
return {
"input_ids": torch.cat(images_tokens)[None].expand(len(text), -1),
"images": images_processed,
"embed_is_patch": images_embed_is_patch,
"num_patches": torch.tensor(images_num_patches),
}
@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_pixtral)
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_pixtral_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_pixtral)
@INPUT_REGISTRY.register_input_processor(input_processor_for_pixtral)
class PixtralProcessingInfo(BaseProcessingInfo):
def get_tokenizer(self) -> MistralTokenizer:
tokenizer = cached_tokenizer_from_config(self.ctx.model_config)
if not isinstance(tokenizer, MistralTokenizer):
raise ValueError("This model requires `--tokenizer-mode mistral`")
return tokenizer
def get_hf_processor(self) -> PixtralProcessorAdapter:
return PixtralProcessorAdapter(self.get_tokenizer())
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None}
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
return {"image": self.get_max_image_tokens()}
def get_vision_config(
self,
processor: Optional[PixtralProcessorAdapter] = None,
):
if processor is None:
processor = self.get_hf_processor()
return PixtralVisionConfig(
image_size=processor.image_size,
patch_size=processor.patch_size,
)
def get_num_image_tokens(
self,
*,
image_width: int,
image_height: int,
processor: Optional[PixtralProcessorAdapter] = None,
) -> int:
if processor is None:
processor = self.get_hf_processor()
ncols, nrows = processor.image_processor._image_to_num_tokens(
Image.new("RGB", (image_width, image_height)))
return (ncols + 1) * nrows
def get_image_size_with_most_features(self) -> ImageSize:
image_processor = self.get_hf_processor().image_processor
max_image_size = image_processor.mm_config.max_image_size
return ImageSize(width=max_image_size, height=max_image_size)
def get_max_image_tokens(self) -> int:
target_width, target_height = self.get_image_size_with_most_features()
return self.get_num_image_tokens(
image_width=target_width,
image_height=target_height,
)
class PixtralDummyInputsBuilder(BaseDummyInputsBuilder[PixtralProcessingInfo]):
def get_dummy_processor_inputs(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
num_images = mm_counts.get("image", 0)
target_width, target_height = \
self.info.get_image_size_with_most_features()
mm_data = {
"image":
self._get_dummy_images(width=target_width,
height=target_height,
num_images=num_images)
}
return ProcessorInputs(
prompt_text="",
mm_data=mm_data,
)
class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo]
):
def _get_mm_fields_config(
self,
hf_inputs: Mapping[str, NestedTensors],
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
return dict(
images=MultiModalFieldConfig.batched("image"),
embed_is_patch=MultiModalFieldConfig.batched("image"),
num_patches=MultiModalFieldConfig.batched("image"),
)
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> Sequence[PromptUpdate]:
processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
image_break_id = processor.image_break_id
image_token_id = processor.image_token_id
image_end_id = processor.image_end_id
def get_replacement(item_idx: int):
images = mm_items.get_items("image", ImageProcessorItems)
image_size = images.get_image_size(item_idx)
ncols, nrows = processor.image_processor._image_to_num_tokens(
Image.new("RGB", (image_size.width, image_size.height)))
tokens = ([image_token_id] * ncols + [image_break_id]) * nrows
tokens[-1] = image_end_id
return tokens
return [
PromptReplacement(
modality="image",
target="", # Never match the prompt (see below note)
replacement=get_replacement,
),
]
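To make the replacement layout concrete (the grid size and token ids below are assumptions for illustration):

# For ncols=3, nrows=2 the replacement reads:
#   [IMG] [IMG] [IMG] [IMG_BREAK] [IMG] [IMG] [IMG] [IMG_END]
IMG, BREAK, END = 10, 12, 13  # hypothetical placeholder ids
ncols, nrows = 3, 2
tokens = ([IMG] * ncols + [BREAK]) * nrows
tokens[-1] = END
# tokens == [10, 10, 10, 12, 10, 10, 10, 13]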
def _cached_apply_hf_processor(
self,
prompt: Union[str, list[int]],
mm_data_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
) -> tuple[list[int], MultiModalKwargs, bool]:
prompt_ids, mm_kwargs, _ = super()._cached_apply_hf_processor(
prompt=prompt,
mm_data_items=mm_data_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
)
# NOTE: The tokens are already inserted by the chat template
return prompt_ids, mm_kwargs, True
@MULTIMODAL_REGISTRY.register_processor(PixtralMultiModalProcessor,
info=PixtralProcessingInfo,
dummy_inputs=PixtralDummyInputsBuilder)
class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsPP):
@@ -191,13 +343,6 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
if key in dataclass_fields
}
if not ("image_break_token_id" in vision_args
and "image_end_token_id" in vision_args):
raise ValueError(
"'image_break_token_id' and 'image_end_token_id' not found "
"in the vision_encoder arguments. Please download the latest "
"version of 'params.json' from the model repository.")
self.vision_args = VisionEncoderArgs(**vision_args)
# init MistralForCausalLM
@@ -221,36 +366,92 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
return get_sampler()
def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[PixtralImagePixelInputs]:
images = kwargs.pop("images", None)
if images is None:
return None
if not isinstance(images, (torch.Tensor, list)):
raise ValueError("Incorrect type of images. "
f"Got type: {type(images)}")
embed_is_patch = kwargs.pop("embed_is_patch")
if not isinstance(embed_is_patch, (torch.Tensor, list)):
raise ValueError("Incorrect type of embed_is_patch. "
f"Got type: {type(embed_is_patch)}")
num_patches = kwargs.pop("num_patches")
if not isinstance(num_patches, (torch.Tensor, list)):
raise ValueError("Incorrect type of num_patches. "
f"Got type: {type(num_patches)}")
return PixtralImagePixelInputs(
type="pixel_values",
images=flatten_bn(images),
embed_is_patch=embed_is_patch,
num_patches=num_patches,
)
def _process_image_input(
self,
image_input: PixtralImagePixelInputs,
) -> tuple[torch.Tensor, ...]:
images = image_input["images"]
image_features = self.vision_encoder(images)
feature_sizes = [
image_feature.shape[0] for image_feature in image_features
]
image_embeds = self.vision_language_adapter(torch.cat(image_features))
image_embeds = torch.split(image_embeds, feature_sizes)
return image_embeds
def _get_mm_embeds(
self,
features: torch.Tensor, # Shape: (num_patch, d)
num_patches: torch.Tensor, # Shape: (num_images,)
embed_is_patch: torch.Tensor, # Shape: (num_images, num_embeds)
) -> tuple[torch.Tensor, ...]:
"""Scatter the patch features into a contiguous tensor that corresponds
to the embedding tokens defined by the multimodal processor.
Mostly copied from `Molmo._get_mm_embeds`; see the FIXME comment below.
"""
# Insert columns of nan values according to `embed_is_patch`. This work
# should ideally be done in `_process_image_input`, but
# `_process_image_input` is used in both the V0 and V1 paths. It's safer to
# put the logic here.
# FIXME: Move this logic to `_process_image_input` when v0 is
# deprecated. Merge this function with `Molmo._get_mm_embeds`.
num_patches_per_image: list[int] = num_patches.tolist()
embeds_flat = features.new_full(
(sum(num_patches_per_image), *features.shape[1:]),
fill_value=torch.nan,
)
embeds_flat[embed_is_patch.view(-1)] = features
return embeds_flat.split(num_patches_per_image)
def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input, image_tokens = self._parse_and_validate_image_input(
**kwargs)
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return None
vision_embeddings = self._process_image_input(image_input)
image_features = self._process_image_input(image_input)
# NOTE: We patch the outputs of the vision encoder with embeddings
# from `[IMG_BREAK]` and `[IMG_END]` tokens.
image_embeds = self.language_model.get_input_embeddings(image_tokens)
image_token_mask = image_tokens == self.vision_args.image_token_id
image_embeds[image_token_mask] = vision_embeddings
if kwargs.get("v0_path", False):
return image_features
# NOTE: Image embeddings are split into separate tensors for each image
# by the indices of `[IMG_END]` token.
image_end_mask = image_tokens == self.vision_args.image_end_token_id
split_indices = torch.where(image_end_mask)[0] + 1
if len(split_indices) <= 1:
# Do not split, return as tensor of shape [1, fs, hs]
return image_embeds.unsqueeze(0)
# If the last split index is the last index in image_tokens, we
# ignore it to avoid empty split tensor
if split_indices[-1] == len(image_tokens):
split_indices = split_indices[:-1]
image_embeds = image_embeds.tensor_split(split_indices.cpu())
return image_embeds
return flatten_2d_lists(
self._get_mm_embeds(*args) for args in zip(
image_features,
image_input["num_patches"],
image_input["embed_is_patch"],
))
def get_input_embeddings(
self,
@@ -259,12 +460,17 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
# Extract the patch tokens
patch_embeddings = json_map_leaves(
lambda x: x[~x.isnan()].view(-1, *x.shape[1:]),
cast(JSONTree[torch.Tensor], multimodal_embeddings),
)
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings, [
self.vision_args.image_token_id,
self.vision_args.image_break_token_id,
self.vision_args.image_end_token_id,
])
input_ids,
inputs_embeds,
cast(NestedTensors, patch_embeddings),
self.vision_args.image_token_id,
)
return inputs_embeds
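A tiny sketch of the NaN-filtering step above, which recovers only the patch rows before merging (the values are made up):

import torch

# One NaN-padded per-image embedding with 3 slots; the last slot was a
# non-patch ([IMG_BREAK]/[IMG_END]) position filled with NaN by _get_mm_embeds.
x = torch.tensor([[1.0, 2.0],
                  [3.0, 4.0],
                  [float("nan"), float("nan")]])
patches_only = x[~x.isnan()].view(-1, *x.shape[1:])
# patches_only == tensor([[1., 2.], [3., 4.]])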
def forward(
@@ -275,14 +481,14 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object,
) -> Union[torch.Tensor, IntermediateTensors]:
"""Run forward pass for pixtral.
"""
"""Run forward pass for pixtral."""
if intermediate_tensors is not None:
inputs_embeds = None
# NOTE: In v1, inputs_embeds is always generated at the model runner; this
# condition is for v0 compatibility.
elif inputs_embeds is None:
kwargs.update({"v0_path": True})
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings)
@@ -295,47 +501,6 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
return hidden_states
def _parse_and_validate_image_input(
self,
images: Optional[Union[List[List[torch.Tensor]], List[torch.Tensor],
torch.Tensor]] = None,
image_tokens: Optional[torch.Tensor] = None,
) -> Tuple[Optional[List[torch.Tensor]], Optional[torch.Tensor]]:
if images is None:
return None, None
if isinstance(images, torch.Tensor):
# if passed as batch take all images
N, B, C, W, H = images.shape
images = images.reshape(N * B, C, W, H)
images = [images[i] for i in range(images.size(0))]
elif isinstance(images, list):
# if passed as list flatten lists of tensors
flatten_images = []
for imgs_per_req in images:
imgs_per_req = [
imgs_per_req[i] for i in range(imgs_per_req.size(0))
] if isinstance(imgs_per_req, torch.Tensor) else imgs_per_req
flatten_images.extend(imgs_per_req)
images = flatten_images
if isinstance(image_tokens, torch.Tensor):
# image_tokens are batched
image_tokens = image_tokens.flatten()
elif isinstance(image_tokens, list):
# image_tokens are of different lengths thus passed as a list
image_tokens = torch.cat(image_tokens)
assert image_tokens.dim() == 1
return images, image_tokens
def _process_image_input(self,
image_input: List[torch.Tensor]) -> torch.Tensor:
return self.vision_language_adapter(self.vision_encoder(image_input))
def compute_logits(
self,
hidden_states: torch.Tensor,
@@ -400,8 +565,6 @@ class VisionEncoderArgs:
num_attention_heads: int
rope_theta: float # for rope-2D
image_token_id: int
image_break_token_id: int
image_end_token_id: int
adapter_bias: bool = True
@@ -637,9 +800,13 @@ class VisionTransformer(nn.Module):
self.patch_conv(img.unsqueeze(0).to(self.dtype)) for img in images
]
patch_embeds = [
p.flatten(2).permute(0, 2, 1) for p in patch_embeds_list
]
embed_sizes = [p.shape[1] for p in patch_embeds]
# flatten to a single sequence
patch_embeds = torch.cat(
[p.flatten(2).permute(0, 2, 1) for p in patch_embeds_list], dim=1)
patch_embeds = torch.cat(patch_embeds, dim=1)
patch_embeds = self.ln_pre(patch_embeds)
# positional embeddings
@@ -655,8 +822,8 @@ class VisionTransformer(nn.Module):
"with the Mistral format")
out = self.transformer(patch_embeds, mask=mask, freqs_cis=freqs_cis)
# remove batch dimension of the single sequence
return out.squeeze(0)
# squeeze dim 0 and split into separate tensors for each image
return torch.split(out.squeeze(0), embed_sizes)
class VisionLanguageAdapter(nn.Module):
@@ -978,9 +1145,9 @@ class PixtralHFVisionModel(nn.Module):
def forward(
self,
pixel_values: List[torch.Tensor],
pixel_values: list[torch.Tensor],
feature_sample_layers: Optional[list[int]] = None,
) -> torch.Tensor:
) -> tuple[torch.Tensor, ...]:
"""
Args:
pixel_values: Each image to be processed will be a separate tensor
@@ -1039,8 +1206,7 @@ class PixtralHFVisionModel(nn.Module):
self.config.num_hidden_layers)
# squeeze dim 0 and split into separate tensors for each image
out = torch.split(torch.squeeze(out), embed_sizes)
return out
return torch.split(out.squeeze(0), embed_sizes)
# (TODO) Add prefix argument for filtering out weights to be loaded
# ref: https://github.com/vllm-project/vllm/pull/7186#discussion_r1734163986


@@ -77,7 +77,9 @@ class PromptIndexTargets:
else:
if isinstance(prefix, str):
# Make both `list[int]`
prefix = encode_tokens(tokenizer, prefix)
prefix = encode_tokens(tokenizer,
prefix,
add_special_tokens=False)
match_idx = len(prefix)
return match_idx if prompt[:match_idx] == prefix else None
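To see why the prefix is encoded without special tokens (the tokenizer repo below is an assumption): a prefix encoded with a leading BOS token could never match a slice taken from inside the prompt.

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")  # assumed example
with_bos = tok.encode("Hello")                               # BOS id prepended by default
without_bos = tok.encode("Hello", add_special_tokens=False)  # raw prefix ids only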
@@ -318,7 +320,7 @@ def _cached_encode(
tokenizer: AnyTokenizer,
text: str,
*,
add_special_tokens: bool = False,
add_special_tokens: Optional[bool] = None,
) -> list[int]:
return encode_tokens(tokenizer,
text,
@@ -330,7 +332,7 @@ def _cached_decode(
tokenizer: AnyTokenizer,
token_ids: tuple[int, ...],
*,
skip_special_tokens: bool = False,
skip_special_tokens: Optional[bool] = None,
) -> str:
return decode_tokens(tokenizer,
list(token_ids),
@@ -395,7 +397,9 @@ class _BoundPromptSequence:
def token_ids(self) -> list[int]:
if self._token_ids is None:
assert self._text is not None
self._token_ids = _cached_encode(self.tokenizer, self._text)
self._token_ids = _cached_encode(self.tokenizer,
self._text,
add_special_tokens=False)
return self._token_ids
@@ -1046,7 +1050,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptUpdate]:
) -> Sequence[PromptUpdate]:
"""
Given the original multi-modal items for this modality
and HF-processed data, output the updates to perform.


@@ -34,13 +34,20 @@ def decode_tokens(
tokenizer: AnyTokenizer,
token_ids: list[int],
*,
skip_special_tokens: bool = False,
skip_special_tokens: Optional[bool] = None,
) -> str:
"""
Backend-agnostic equivalent of HF's
:code:`tokenizer.decode(token_ids, skip_special_tokens=...)`.
:code:`tokenizer.decode(token_ids, ...)`.
:code:`skip_special_tokens=None` means to use the backend's default
settings.
"""
return tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
if skip_special_tokens is not None:
return tokenizer.decode(token_ids,
skip_special_tokens=skip_special_tokens)
return tokenizer.decode(token_ids)
def encode_tokens(
@@ -51,10 +58,14 @@ def encode_tokens(
) -> list[int]:
"""
Backend-agnostic equivalent of HF's
:code:`tokenizer.encode(text, add_special_tokens=...)`.
:code:`tokenizer.encode(text, ...)`.
:code:`add_special_tokens=None` means to use the backend's default
settings.
"""
if add_special_tokens is not None:
return tokenizer.encode(text, add_special_tokens=add_special_tokens)
return tokenizer.encode(text)
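A short usage sketch of the backend-default behavior introduced here, assuming these helpers are exported from vllm.transformers_utils.tokenizer (the GPT-2 tokenizer is an arbitrary assumption):

from transformers import AutoTokenizer

from vllm.transformers_utils.tokenizer import decode_tokens, encode_tokens

tok = AutoTokenizer.from_pretrained("gpt2")
ids_default = encode_tokens(tok, "Hello world")                          # backend default
ids_plain = encode_tokens(tok, "Hello world", add_special_tokens=False)  # explicit override
text = decode_tokens(tok, ids_plain)                                     # backend default decode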


@@ -845,7 +845,7 @@ def is_list_of(
assert_never(check)
def flatten_2d_lists(lists: list[list[T]]) -> list[T]:
def flatten_2d_lists(lists: Iterable[Iterable[T]]) -> list[T]:
"""Flatten a list of lists to a single list."""
return [item for sublist in lists for item in sublist]
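Since the signature now accepts any iterable of iterables, callers can pass a generator expression directly, which is exactly what the model files above now do; for example:

flatten_2d_lists([[1, 2], [3]])               # -> [1, 2, 3]
flatten_2d_lists(range(i) for i in range(3))  # -> [0, 0, 1]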