[V1][VLM][Pixtral-HF] Support Pixtral-HF on V1 (#14275)
Signed-off-by: Linkun Chen <github@lkchen.net>
Commit 5d802522a7 (parent 1769928079)
@@ -866,7 +866,7 @@ See [this page](#generative-models) for more information on how to use generative models.
 - * `PixtralForConditionalGeneration`
   * Pixtral
   * T + I<sup>+</sup>
-  * `mistralai/Pixtral-12B-2409`, `mistral-community/pixtral-12b` (see note), etc.
+  * `mistralai/Pixtral-12B-2409`, `mistral-community/pixtral-12b`, etc.
   *
   * ✅︎
   * ✅︎
@@ -930,10 +930,6 @@ For more details, please see: <gh-pr:4087#issuecomment-2250397630>
 Currently the PaliGemma model series is implemented without PrefixLM attention mask. This model series may be deprecated in a future release.
 :::

-:::{note}
-`mistral-community/pixtral-12b` does not support V1 yet.
-:::
-
 :::{note}
 To use Qwen2.5-VL series models, you have to install Hugging Face Transformers library from source via `pip install git+https://github.com/huggingface/transformers`.
 :::
@@ -4,7 +4,7 @@ from abc import abstractmethod
 from collections.abc import Iterable, Mapping, Sequence
 from functools import cached_property
 from typing import (Final, List, Literal, Optional, Protocol, Set, Tuple,
-                    TypedDict, TypeVar, Union)
+                    TypedDict, TypeVar, Union, cast)

 import torch
 import torch.nn as nn
@@ -35,6 +35,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         PromptReplacement, PromptUpdate)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
 from vllm.sequence import IntermediateTensors
+from vllm.utils import JSONTree, flatten_2d_lists, json_map_leaves

 from .clip import CLIPVisionModel
 from .interfaces import SupportsMultiModal, SupportsPP
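Note: the newly imported helpers are used further down in this diff. As a rough mental model only (these are hypothetical stand-ins, not the vLLM implementations): `flatten_2d_lists` concatenates a list of lists into one flat list, and `json_map_leaves` applies a function to every leaf of a nested list/dict structure, with `JSONTree` being the matching type alias.

```python
# Hypothetical stand-ins that mimic how the imported helpers are used below.
def flatten_2d_lists_demo(lists):
    # Concatenate a list of lists into a single flat list.
    return [x for sub in lists for x in sub]

def json_map_leaves_demo(func, tree):
    # Apply `func` to every leaf of a nested dict/list/tuple structure.
    if isinstance(tree, dict):
        return {k: json_map_leaves_demo(func, v) for k, v in tree.items()}
    if isinstance(tree, (list, tuple)):
        return type(tree)(json_map_leaves_demo(func, v) for v in tree)
    return func(tree)

print(flatten_2d_lists_demo([[True, False], [True]]))   # [True, False, True]
print(json_map_leaves_demo(lambda x: x * 2, {"a": [1, 2], "b": 3}))
# {'a': [2, 4], 'b': 6}
```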
@@ -56,6 +57,25 @@ class LlavaImagePixelInputs(TypedDict):
     in which case the data is passed as a list instead of a batched tensor.
     """

+    feat_is_patch: Union[torch.Tensor, List[torch.Tensor]]
+    """
+    A boolean mask indicating which image features correspond
+    to patch tokens.
+
+    Shape: `(batch_size, num_crops, num_patch)`
+    """
+
+    embed_is_patch: Union[torch.Tensor, List[torch.Tensor]]
+    """
+    A boolean mask indicating which image embeddings correspond
+    to patch tokens.
+
+    Shape: `(batch_size, num_embeds)`
+    """
+
+    num_crops: torch.Tensor
+    """Shape: `(batch_size, num_images)`"""
+

 class LlavaImageEmbeddingInputs(TypedDict):
     type: Literal["image_embeds"]
@@ -65,6 +85,25 @@ class LlavaImageEmbeddingInputs(TypedDict):
     `hidden_size` must match the hidden size of language model backbone.
     """

+    feat_is_patch: Union[torch.Tensor, List[torch.Tensor]]
+    """
+    A boolean mask indicating which image features correspond
+    to patch tokens.
+
+    Shape: `(batch_size, num_crops, num_patch)`
+    """
+
+    embed_is_patch: Union[torch.Tensor, List[torch.Tensor]]
+    """
+    A boolean mask indicating which image embeddings correspond
+    to patch tokens.
+
+    Shape: `(batch_size, num_embeds)`
+    """
+
+    num_crops: torch.Tensor
+    """Shape: `(batch_size, num_images)`"""
+

 LlavaImageInputs = Union[LlavaImagePixelInputs, LlavaImageEmbeddingInputs]

@@ -317,6 +356,26 @@ class PixtralHFMultiModalProcessor(
             for p, (h, w) in zip(pixel_values, image_sizes)
         ]

+        hf_config = self.info.get_hf_config()
+
+        tile_sizes = [
+            get_pixtral_hf_image_feature_grid_size(
+                hf_config.vision_config,
+                image_width=pixel_value.shape[-1],
+                image_height=pixel_value.shape[-2])
+            for pixel_value in processed_outputs["pixel_values"]
+        ]
+        num_crops = torch.tensor([(ncols + 1) * nrows
+                                  for ncols, nrows in tile_sizes])
+        # Each image may result to masks of different sizes, so we need to
+        # flatten the list and later use `num_crops` to get per-image masks.
+        embed_is_patch = torch.tensor(
+            flatten_2d_lists([([True] * ncols + [False]) * nrows
+                              for ncols, nrows in tile_sizes]))
+        processed_outputs["num_crops"] = num_crops
+        processed_outputs["embed_is_patch"] = embed_is_patch
+        processed_outputs["feat_is_patch"] = embed_is_patch
+
         return processed_outputs

     def _get_mm_fields_config(
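For reference, a minimal sketch of the token accounting this hunk introduces: each row of `ncols` patch tokens is followed by one non-patch slot, so an image whose feature grid is `ncols x nrows` contributes `(ncols + 1) * nrows` embedding positions. The grid size below is made up; the real values come from `get_pixtral_hf_image_feature_grid_size`.

```python
# Hypothetical grid for one image: 3 patch columns x 2 patch rows.
ncols, nrows = 3, 2

# One extra non-patch slot per row, matching the (ncols + 1) * nrows
# expression in the hunk above.
num_crops = (ncols + 1) * nrows            # 8 embedding positions
embed_is_patch = ([True] * ncols + [False]) * nrows

print(num_crops)        # 8
print(embed_is_patch)   # [True, True, True, False, True, True, True, False]
assert len(embed_is_patch) == num_crops
```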
@@ -324,7 +383,13 @@ class PixtralHFMultiModalProcessor(
         hf_inputs: BatchFeature,
         hf_processor_mm_kwargs: Mapping[str, object],
     ) -> Mapping[str, MultiModalFieldConfig]:
+        num_crops = hf_inputs.get("num_crops", torch.empty(0)).view(-1)
         return dict(
+            feat_is_patch=MultiModalFieldConfig.flat_from_sizes(
+                "image", num_crops),
+            embed_is_patch=MultiModalFieldConfig.flat_from_sizes(
+                "image", num_crops),
+            num_crops=MultiModalFieldConfig.batched("image"),
            pixel_values=MultiModalFieldConfig.batched("image"),
            image_embeds=MultiModalFieldConfig.batched("image"),
         )
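A rough illustration of why `num_crops` is handed to `MultiModalFieldConfig.flat_from_sizes`: the masks for all images in a request are flattened and concatenated by the processor, and `num_crops` supplies the per-image sizes needed to slice them apart again. The sketch below uses `torch.split` as a conceptual stand-in for that slicing; the numbers are invented.

```python
import torch

# Two images: one with 8 embedding positions, one with 4.
num_crops = torch.tensor([8, 4])
embed_is_patch_flat = torch.tensor([True] * 3 + [False] + [True] * 3 + [False]
                                   + [True] * 3 + [False])

# Recover per-image masks from the flattened tensor using the recorded sizes.
per_image = torch.split(embed_is_patch_flat, num_crops.tolist())
print([m.shape for m in per_image])   # [torch.Size([8]), torch.Size([4])]
```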
@@ -562,6 +627,23 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
         if pixel_values is None and image_embeds is None:
             return None

+        feat_is_patch = kwargs.pop("feat_is_patch", None)
+        if feat_is_patch is not None and not isinstance(
+                feat_is_patch, (torch.Tensor, list)):
+            raise ValueError("Incorrect type of feat_is_patch. "
+                             f"Got type: {type(feat_is_patch)}")
+
+        embed_is_patch = kwargs.pop("embed_is_patch", None)
+        if embed_is_patch is not None and not isinstance(
+                embed_is_patch, (torch.Tensor, list)):
+            raise ValueError("Incorrect type of embed_is_patch. "
+                             f"Got type: {type(embed_is_patch)}")
+
+        num_crops = kwargs.pop("num_crops", None)
+        if num_crops is not None and not isinstance(num_crops, torch.Tensor):
+            raise ValueError("Incorrect type of num_crops. "
+                             f"Got type: {type(num_crops)}")
+
         if pixel_values is not None:
             if not isinstance(pixel_values, (torch.Tensor, list)):
                 raise ValueError("Incorrect type of pixel values. "
@@ -571,12 +653,18 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
                 return LlavaImagePixelInputs(
                     type="pixel_values",
                     data=flatten_bn(pixel_values),
+                    feat_is_patch=feat_is_patch,
+                    embed_is_patch=embed_is_patch,
+                    num_crops=num_crops,
                 )

             return LlavaImagePixelInputs(
                 type="pixel_values",
                 data=self._validate_pixel_values(
                     flatten_bn(pixel_values, concat=True)),
+                feat_is_patch=feat_is_patch,
+                embed_is_patch=embed_is_patch,
+                num_crops=num_crops,
             )

         if image_embeds is not None:
@@ -587,6 +675,9 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
             return LlavaImageEmbeddingInputs(
                 type="image_embeds",
                 data=flatten_bn(image_embeds, concat=True),
+                feat_is_patch=feat_is_patch,
+                embed_is_patch=embed_is_patch,
+                num_crops=num_crops,
             )

         raise AssertionError("This line should be unreachable.")
@@ -633,16 +724,74 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):

         assert self.vision_tower is not None
         image_features = self._process_image_pixels(image_input)
-        return self.multi_modal_projector(image_features)

-    def get_multimodal_embeddings(
-        self, **kwargs
-    ) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
+        if isinstance(image_features, torch.Tensor):
+            return self.multi_modal_projector(image_features)
+
+        feature_sizes = [
+            image_feature.shape[0] for image_feature in image_features
+        ]
+
+        image_embeds = self.multi_modal_projector(torch.cat(image_features))
+        image_embeds = torch.split(image_embeds, feature_sizes)
+        return image_embeds
+
+    def _get_mm_embeds(
+            self,
+            features: torch.Tensor,  # Shape: (num_crop, num_patch, d)
+            feat_is_patch: torch.Tensor,  # Shape: (num_crop, num_patch)
+            num_crops: torch.Tensor,  # Shape: (num_images,)
+            embed_is_patch: torch.Tensor,  # Shape: (num_embeds,)
+    ) -> list[torch.Tensor]:
+        """Scatter the patch features into a contiguous tensor that corresponds
+        to the embedding tokens defined by the multimodal processor.
+
+        Mostly copied from `Molmo._get_mm_embeds`. See following fixme comment.
+        """
+
+        # Insert columns of nan values according to `feat_is_patch`. This work
+        # ideally should be done in `_process_image_input`, but
+        # `_process_image_input` is used in both V0 and V1 path. It's safer to
+        # put the logic here.
+        # FIXME: Move this logic to `_process_image_input` when v0 is
+        # deprecated. Merge this function with `Molmo._get_mm_embeds`.
+        feat_is_patch = feat_is_patch.view(-1)
+        embed_is_patch = embed_is_patch.view(-1)
+        expanded_embedding = torch.full(
+            (sum(num_crops), *features.shape[1:]),
+            torch.nan,
+            dtype=features.dtype).to(features.device)
+        expanded_embedding[feat_is_patch] = features
+
+        num_crops_per_image = num_crops.tolist()
+        feats_per_image = expanded_embedding.split(num_crops_per_image)
+        f_is_patch_per_image = feat_is_patch.split(num_crops_per_image)
+
+        embed_dim = expanded_embedding.shape[-1]
+        num_embeds = embed_is_patch.shape[0]
+
+        embeds_in_batch = list[torch.Tensor]()
+        for feats, f_is_patch in zip(feats_per_image, f_is_patch_per_image):
+            embeds = feats.new_full((num_embeds, embed_dim), torch.nan)
+            embeds[embed_is_patch] = feats[f_is_patch]
+            embeds_in_batch.append(embeds)
+
+        return embeds_in_batch
+
+    def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
             return None
         vision_embeddings = self._process_image_input(image_input)
-        return vision_embeddings
+        if kwargs.get("v0_path", False):
+            return vision_embeddings
+        else:
+            nested_emb = [
+                self._get_mm_embeds(*args) for args in zip(
+                    vision_embeddings, image_input["feat_is_patch"],
+                    image_input["num_crops"], image_input["embed_is_patch"])
+            ]
+            return flatten_2d_lists(nested_emb)

     def get_input_embeddings(
         self,
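A toy version of the nan-scatter trick used by `_get_mm_embeds` above (shapes and values are invented): patch features are scattered into a tensor with one row per embedding token, and the non-patch rows are left as `nan` so they can be identified and dropped later.

```python
import torch

# 8 embedding positions for one image; the 4th and 8th are break/end tokens.
num_embeds, embed_dim = 8, 4
embed_is_patch = torch.tensor(
    [True, True, True, False, True, True, True, False])

# 6 patch features produced by the vision tower + projector.
patch_feats = torch.arange(6 * embed_dim, dtype=torch.float32).view(6, embed_dim)

# Allocate one row per embedding token and fill non-patch rows with nan.
embeds = patch_feats.new_full((num_embeds, embed_dim), torch.nan)
embeds[embed_is_patch] = patch_feats

# The nan rows mark non-patch positions; they are stripped later in
# get_input_embeddings before merging with the text embeddings.
print(torch.isnan(embeds).any(dim=-1))
# tensor([False, False, False,  True, False, False, False,  True])
```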
@@ -651,8 +800,15 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
     ) -> torch.Tensor:
         inputs_embeds = self.language_model.get_input_embeddings(input_ids)
         if multimodal_embeddings is not None:
+            # Extract the patch tokens
+            patch_embeddings = json_map_leaves(
+                lambda x: x[~x.isnan()].view(-1, *x.shape[1:]),
+                cast(JSONTree[torch.Tensor], multimodal_embeddings),
+            )
+
             inputs_embeds = merge_multimodal_embeddings(
-                input_ids, inputs_embeds, multimodal_embeddings,
+                input_ids, inputs_embeds, cast(NestedTensors,
+                                               patch_embeddings),
                 self.config.image_token_index)
         return inputs_embeds

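The lambda passed to `json_map_leaves` strips the `nan` placeholder rows so that only real patch embeddings are merged into the text embeddings. A standalone example of that filtering step, with toy values:

```python
import torch

# Three embedding rows; the middle one is a nan placeholder (non-patch slot).
x = torch.tensor([[1.0, 2.0],
                  [float("nan"), float("nan")],
                  [3.0, 4.0]])

# Keep only the non-nan elements and restore the (num_patches, dim) shape.
patches = x[~x.isnan()].view(-1, *x.shape[1:])
print(patches)  # -> [[1., 2.], [3., 4.]]
```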
@@ -705,6 +861,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
         # NOTE: In v1, inputs_embeds is always generated at model runner, this
         # condition is for v0 compatibility.
         elif inputs_embeds is None:
+            kwargs.update({"v0_path": True})
             vision_embeddings = self.get_multimodal_embeddings(**kwargs)
             inputs_embeds = self.get_input_embeddings(input_ids,
                                                       vision_embeddings)
@@ -1484,8 +1484,8 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,

         img_patch_id = kwargs.pop("img_patch_id", None)
         if not isinstance(img_patch_id, torch.Tensor):
-            raise ValueError("Incorrect type of num_crops. "
-                             f"Got type: {type(num_crops)}")
+            raise ValueError("Incorrect type of img_patch_id. "
+                             f"Got type: {type(img_patch_id)}")
         self.img_patch_id = img_patch_id.flatten().unique().item()

         return MolmoImageInputs(
@@ -1042,9 +1042,13 @@ class PixtralHFVisionModel(nn.Module):
             for img in pixel_values
         ]

+        patch_embeds = [
+            p.flatten(2).permute(0, 2, 1) for p in patch_embeds_list
+        ]
+        embed_sizes = [p.shape[1] for p in patch_embeds]
+
         # flatten to a single sequence
-        patch_embeds = torch.cat(
-            [p.flatten(2).permute(0, 2, 1) for p in patch_embeds_list], dim=1)
+        patch_embeds = torch.cat(patch_embeds, dim=1)
         patch_embeds = self.ln_pre(patch_embeds)

         # positional embeddings
@@ -1075,6 +1079,8 @@ class PixtralHFVisionModel(nn.Module):
         out = resolve_visual_encoder_outputs(out, feature_sample_layers, None,
                                              self.config.num_hidden_layers)

+        # squeeze dim 0 and split into separate tensors for each image
+        out = torch.split(torch.squeeze(out), embed_sizes)
         return out

     # (TODO) Add prefix argument for filtering out weights to be loaded
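Together, these two hunks let `PixtralHFVisionModel` return one tensor per image instead of a single batched tensor: the first records `embed_sizes` before the per-image patch sequences are concatenated, and the second splits the encoder output back apart with those sizes. A small sketch of that round trip, with invented sizes:

```python
import torch

# Per-image patch counts recorded before concatenation (made-up numbers).
embed_sizes = [6, 10]

# A single flattened sequence of 16 patch embeddings with hidden size 4.
out = torch.randn(sum(embed_sizes), 4)

# Split back into one tensor per image, as the new return path does.
per_image = torch.split(out, embed_sizes)
print([t.shape for t in per_image])  # [torch.Size([6, 4]), torch.Size([10, 4])]
```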