[V1][VLM][Pixtral-HF] Support Pixtral-HF on V1 (#14275)

Signed-off-by: Linkun Chen <github@lkchen.net>
lkchen 2025-03-06 00:58:41 -08:00 committed by GitHub
parent 1769928079
commit 5d802522a7
4 changed files with 175 additions and 16 deletions

View File

@@ -866,7 +866,7 @@ See [this page](#generative-models) for more information on how to use generativ
 - * `PixtralForConditionalGeneration`
 * Pixtral
 * T + I<sup>+</sup>
-* `mistralai/Pixtral-12B-2409`, `mistral-community/pixtral-12b` (see note), etc.
+* `mistralai/Pixtral-12B-2409`, `mistral-community/pixtral-12b`, etc.
 *
 * ✅︎
 * ✅︎
@@ -930,10 +930,6 @@ For more details, please see: <gh-pr:4087#issuecomment-2250397630>
 Currently the PaliGemma model series is implemented without PrefixLM attention mask. This model series may be deprecated in a future release.
 :::
-:::{note}
-`mistral-community/pixtral-12b` does not support V1 yet.
-:::
 :::{note}
 To use Qwen2.5-VL series models, you have to install Hugging Face Transformers library from source via `pip install git+https://github.com/huggingface/transformers`.
 :::

View File

@@ -4,7 +4,7 @@ from abc import abstractmethod
 from collections.abc import Iterable, Mapping, Sequence
 from functools import cached_property
 from typing import (Final, List, Literal, Optional, Protocol, Set, Tuple,
-TypedDict, TypeVar, Union)
+TypedDict, TypeVar, Union, cast)

 import torch
 import torch.nn as nn
@@ -35,6 +35,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
 PromptReplacement, PromptUpdate)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
 from vllm.sequence import IntermediateTensors
+from vllm.utils import JSONTree, flatten_2d_lists, json_map_leaves

 from .clip import CLIPVisionModel
 from .interfaces import SupportsMultiModal, SupportsPP
@@ -56,6 +57,25 @@ class LlavaImagePixelInputs(TypedDict):
 in which case the data is passed as a list instead of a batched tensor.
 """
+feat_is_patch: Union[torch.Tensor, List[torch.Tensor]]
+"""
+A boolean mask indicating which image features correspond
+to patch tokens.
+Shape: `(batch_size, num_crops, num_patch)`
+"""
+embed_is_patch: Union[torch.Tensor, List[torch.Tensor]]
+"""
+A boolean mask indicating which image embeddings correspond
+to patch tokens.
+Shape: `(batch_size, num_embeds)`
+"""
+num_crops: torch.Tensor
+"""Shape: `(batch_size, num_images)`"""

 class LlavaImageEmbeddingInputs(TypedDict):
 type: Literal["image_embeds"]
@@ -65,6 +85,25 @@ class LlavaImageEmbeddingInputs(TypedDict):
 `hidden_size` must match the hidden size of language model backbone.
 """
+feat_is_patch: Union[torch.Tensor, List[torch.Tensor]]
+"""
+A boolean mask indicating which image features correspond
+to patch tokens.
+Shape: `(batch_size, num_crops, num_patch)`
+"""
+embed_is_patch: Union[torch.Tensor, List[torch.Tensor]]
+"""
+A boolean mask indicating which image embeddings correspond
+to patch tokens.
+Shape: `(batch_size, num_embeds)`
+"""
+num_crops: torch.Tensor
+"""Shape: `(batch_size, num_images)`"""

 LlavaImageInputs = Union[LlavaImagePixelInputs, LlavaImageEmbeddingInputs]
@@ -317,6 +356,26 @@ class PixtralHFMultiModalProcessor(
 for p, (h, w) in zip(pixel_values, image_sizes)
 ]
+hf_config = self.info.get_hf_config()
+tile_sizes = [
+get_pixtral_hf_image_feature_grid_size(
+hf_config.vision_config,
+image_width=pixel_value.shape[-1],
+image_height=pixel_value.shape[-2])
+for pixel_value in processed_outputs["pixel_values"]
+]
+num_crops = torch.tensor([(ncols + 1) * nrows
+for ncols, nrows in tile_sizes])
+# Each image may result in masks of different sizes, so we need to
+# flatten the list and later use `num_crops` to recover per-image masks.
+embed_is_patch = torch.tensor(
+flatten_2d_lists([([True] * ncols + [False]) * nrows
+for ncols, nrows in tile_sizes]))
+processed_outputs["num_crops"] = num_crops
+processed_outputs["embed_is_patch"] = embed_is_patch
+processed_outputs["feat_is_patch"] = embed_is_patch
 return processed_outputs

 def _get_mm_fields_config(
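For intuition, here is a minimal standalone sketch of the mask construction in the hunk above (toy grid sizes, plain PyTorch): each row of an `ncols x nrows` patch grid contributes `ncols` patch tokens followed by one image-break token, so an image occupies `(ncols + 1) * nrows` slots and `embed_is_patch` is `True` only at the patch positions.

```python
# Standalone sketch of the mask construction above (toy grid sizes).
import torch

tile_sizes = [(2, 3), (4, 2)]  # hypothetical (ncols, nrows) per image

# One slot per patch plus one image-break slot at the end of each row.
num_crops = torch.tensor([(ncols + 1) * nrows for ncols, nrows in tile_sizes])
# -> tensor([ 9, 10])

# True marks patch slots; False marks the image-break slots.
embed_is_patch = torch.tensor(
    [flag for ncols, nrows in tile_sizes
     for flag in ([True] * ncols + [False]) * nrows])
# First image: T T F T T F T T F  (three rows of two patches plus a break each)
```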
@@ -324,7 +383,13 @@ class PixtralHFMultiModalProcessor(
 hf_inputs: BatchFeature,
 hf_processor_mm_kwargs: Mapping[str, object],
 ) -> Mapping[str, MultiModalFieldConfig]:
+num_crops = hf_inputs.get("num_crops", torch.empty(0)).view(-1)
 return dict(
+feat_is_patch=MultiModalFieldConfig.flat_from_sizes(
+"image", num_crops),
+embed_is_patch=MultiModalFieldConfig.flat_from_sizes(
+"image", num_crops),
+num_crops=MultiModalFieldConfig.batched("image"),
 pixel_values=MultiModalFieldConfig.batched("image"),
 image_embeds=MultiModalFieldConfig.batched("image"),
 )
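Since `feat_is_patch` and `embed_is_patch` are stored flat across all images, `num_crops` is what allows them to be sliced back per image. A rough sketch of that bookkeeping with plain `torch.split` (toy values; the real slicing is handled by the `MultiModalFieldConfig` machinery):

```python
# Rough sketch of recovering per-image masks from the flat layout
# (toy values; plain torch.split instead of MultiModalFieldConfig).
import torch

num_crops = torch.tensor([9, 10])                # hypothetical sizes for two images
embed_is_patch_flat = torch.arange(19) % 3 != 2  # stand-in flat mask of length 19

per_image_masks = torch.split(embed_is_patch_flat, num_crops.tolist())
assert [m.shape[0] for m in per_image_masks] == num_crops.tolist()
```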
@@ -562,6 +627,23 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
 if pixel_values is None and image_embeds is None:
 return None
+feat_is_patch = kwargs.pop("feat_is_patch", None)
+if feat_is_patch is not None and not isinstance(
+feat_is_patch, (torch.Tensor, list)):
+raise ValueError("Incorrect type of feat_is_patch. "
+f"Got type: {type(feat_is_patch)}")
+embed_is_patch = kwargs.pop("embed_is_patch", None)
+if embed_is_patch is not None and not isinstance(
+embed_is_patch, (torch.Tensor, list)):
+raise ValueError("Incorrect type of embed_is_patch. "
+f"Got type: {type(embed_is_patch)}")
+num_crops = kwargs.pop("num_crops", None)
+if num_crops is not None and not isinstance(num_crops, torch.Tensor):
+raise ValueError("Incorrect type of num_crops. "
+f"Got type: {type(num_crops)}")
 if pixel_values is not None:
 if not isinstance(pixel_values, (torch.Tensor, list)):
 raise ValueError("Incorrect type of pixel values. "
@@ -571,12 +653,18 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
 return LlavaImagePixelInputs(
 type="pixel_values",
 data=flatten_bn(pixel_values),
+feat_is_patch=feat_is_patch,
+embed_is_patch=embed_is_patch,
+num_crops=num_crops,
 )

 return LlavaImagePixelInputs(
 type="pixel_values",
 data=self._validate_pixel_values(
 flatten_bn(pixel_values, concat=True)),
+feat_is_patch=feat_is_patch,
+embed_is_patch=embed_is_patch,
+num_crops=num_crops,
 )

 if image_embeds is not None:
@@ -587,6 +675,9 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
 return LlavaImageEmbeddingInputs(
 type="image_embeds",
 data=flatten_bn(image_embeds, concat=True),
+feat_is_patch=feat_is_patch,
+embed_is_patch=embed_is_patch,
+num_crops=num_crops,
 )

 raise AssertionError("This line should be unreachable.")
@@ -633,16 +724,74 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
 assert self.vision_tower is not None
 image_features = self._process_image_pixels(image_input)
-return self.multi_modal_projector(image_features)
-def get_multimodal_embeddings(
-self, **kwargs
-) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
+if isinstance(image_features, torch.Tensor):
+return self.multi_modal_projector(image_features)
+feature_sizes = [
+image_feature.shape[0] for image_feature in image_features
+]
+image_embeds = self.multi_modal_projector(torch.cat(image_features))
+image_embeds = torch.split(image_embeds, feature_sizes)
+return image_embeds

+def _get_mm_embeds(
+self,
+features: torch.Tensor,  # Shape: (num_crop, num_patch, d)
+feat_is_patch: torch.Tensor,  # Shape: (num_crop, num_patch)
+num_crops: torch.Tensor,  # Shape: (num_images,)
+embed_is_patch: torch.Tensor,  # Shape: (num_embeds,)
+) -> list[torch.Tensor]:
+"""Scatter the patch features into a contiguous tensor that corresponds
+to the embedding tokens defined by the multimodal processor.
+Mostly copied from `Molmo._get_mm_embeds`. See the FIXME comment below.
+"""
+# Insert rows of NaN values according to `feat_is_patch`. Ideally this
+# would be done in `_process_image_input`, but that method is used by
+# both the V0 and V1 paths, so it is safer to put the logic here.
+# FIXME: Move this logic to `_process_image_input` when V0 is
+# deprecated. Merge this function with `Molmo._get_mm_embeds`.
+feat_is_patch = feat_is_patch.view(-1)
+embed_is_patch = embed_is_patch.view(-1)
+expanded_embedding = torch.full(
+(sum(num_crops), *features.shape[1:]),
+torch.nan,
+dtype=features.dtype).to(features.device)
+expanded_embedding[feat_is_patch] = features
+num_crops_per_image = num_crops.tolist()
+feats_per_image = expanded_embedding.split(num_crops_per_image)
+f_is_patch_per_image = feat_is_patch.split(num_crops_per_image)
+embed_dim = expanded_embedding.shape[-1]
+num_embeds = embed_is_patch.shape[0]
+embeds_in_batch = list[torch.Tensor]()
+for feats, f_is_patch in zip(feats_per_image, f_is_patch_per_image):
+embeds = feats.new_full((num_embeds, embed_dim), torch.nan)
+embeds[embed_is_patch] = feats[f_is_patch]
+embeds_in_batch.append(embeds)
+return embeds_in_batch

+def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
 image_input = self._parse_and_validate_image_input(**kwargs)
 if image_input is None:
 return None
 vision_embeddings = self._process_image_input(image_input)
-return vision_embeddings
+if kwargs.get("v0_path", False):
+return vision_embeddings
+else:
+nested_emb = [
+self._get_mm_embeds(*args) for args in zip(
+vision_embeddings, image_input["feat_is_patch"],
+image_input["num_crops"], image_input["embed_is_patch"])
+]
+return flatten_2d_lists(nested_emb)

 def get_input_embeddings(
 self,
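A toy illustration of the NaN-scatter trick that `_get_mm_embeds` relies on (hypothetical sizes, independent of the real model): patch features are written into a full-size buffer prefilled with NaN, so non-patch slots stay NaN and can be recognized later.

```python
# Toy illustration of the NaN-scatter used in `_get_mm_embeds`
# (hypothetical sizes, not the real model dimensions).
import torch

embed_dim = 4
embed_is_patch = torch.tensor([True, True, False])  # slot 2 is an image break
patch_feats = torch.randn(2, embed_dim)             # one feature row per patch slot

embeds = patch_feats.new_full((embed_is_patch.shape[0], embed_dim), torch.nan)
embeds[embed_is_patch] = patch_feats
# Rows 0-1 now hold patch features; row 2 stays NaN, marking a non-patch slot
# (e.g. an image-break token) that is dropped again before merging.
```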
@@ -651,8 +800,15 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
 ) -> torch.Tensor:
 inputs_embeds = self.language_model.get_input_embeddings(input_ids)
 if multimodal_embeddings is not None:
+# Extract the patch tokens
+patch_embeddings = json_map_leaves(
+lambda x: x[~x.isnan()].view(-1, *x.shape[1:]),
+cast(JSONTree[torch.Tensor], multimodal_embeddings),
+)
 inputs_embeds = merge_multimodal_embeddings(
-input_ids, inputs_embeds, multimodal_embeddings,
+input_ids, inputs_embeds, cast(NestedTensors,
+patch_embeddings),
 self.config.image_token_index)
 return inputs_embeds
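A minimal sketch of the NaN-filtering step above (toy tensor; it assumes, as the code does, that every row is either fully NaN or fully valid):

```python
# Minimal sketch of stripping the NaN placeholder rows before merging
# (toy tensor; every row is assumed fully NaN or fully valid).
import torch

embed_dim = 4
embeds = torch.randn(3, embed_dim)
embeds[2] = torch.nan  # the image-break placeholder from the scatter step

patch_only = embeds[~embeds.isnan()].view(-1, embed_dim)
# Boolean indexing flattens the tensor, so the reshape only works because
# whole rows are NaN, never individual elements.
assert patch_only.shape == (2, embed_dim)
```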
@@ -705,6 +861,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
 # NOTE: In v1, inputs_embeds is always generated at model runner, this
 # condition is for v0 compatibility.
 elif inputs_embeds is None:
+kwargs.update({"v0_path": True})
 vision_embeddings = self.get_multimodal_embeddings(**kwargs)
 inputs_embeds = self.get_input_embeddings(input_ids,
 vision_embeddings)

View File

@@ -1484,8 +1484,8 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
 img_patch_id = kwargs.pop("img_patch_id", None)
 if not isinstance(img_patch_id, torch.Tensor):
-raise ValueError("Incorrect type of num_crops. "
-f"Got type: {type(num_crops)}")
+raise ValueError("Incorrect type of img_patch_id. "
+f"Got type: {type(img_patch_id)}")
 self.img_patch_id = img_patch_id.flatten().unique().item()
 return MolmoImageInputs(

View File

@@ -1042,9 +1042,13 @@ class PixtralHFVisionModel(nn.Module):
 for img in pixel_values
 ]
+patch_embeds = [
+p.flatten(2).permute(0, 2, 1) for p in patch_embeds_list
+]
+embed_sizes = [p.shape[1] for p in patch_embeds]
 # flatten to a single sequence
-patch_embeds = torch.cat(
-[p.flatten(2).permute(0, 2, 1) for p in patch_embeds_list], dim=1)
+patch_embeds = torch.cat(patch_embeds, dim=1)
 patch_embeds = self.ln_pre(patch_embeds)
 # positional embeddings
@@ -1075,6 +1079,8 @@
 out = resolve_visual_encoder_outputs(out, feature_sample_layers, None,
 self.config.num_hidden_layers)
+# squeeze dim 0 and split into separate tensors for each image
+out = torch.split(torch.squeeze(out), embed_sizes)
 return out

 # (TODO) Add prefix argument for filtering out weights to be loaded
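And a small sketch of the per-image split added at the end of the vision tower (toy sizes): the flattened patch sequence is cut back into one tensor per image using the `embed_sizes` recorded before concatenation.

```python
# Toy sketch of the per-image split at the end of the vision tower.
import torch

embed_sizes = [9, 10]                           # hypothetical patches per image
hidden = 8
out = torch.randn(1, sum(embed_sizes), hidden)  # flattened sequence, batch dim of 1

per_image = torch.split(torch.squeeze(out), embed_sizes)
assert [t.shape[0] for t in per_image] == embed_sizes
```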