1000 lines
37 KiB
Python
1000 lines
37 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
|
|
from abc import abstractmethod
|
|
from collections.abc import Iterable, Mapping, Sequence
|
|
from functools import cached_property
|
|
from typing import (Final, Literal, Optional, Protocol, Set, Tuple, TypedDict,
|
|
TypeVar, Union, cast)
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
from packaging.version import Version
|
|
from transformers import (BatchFeature, CLIPVisionConfig, LlavaConfig,
|
|
PixtralVisionConfig, PretrainedConfig,
|
|
SiglipVisionConfig)
|
|
from transformers import __version__ as TRANSFORMERS_VERSION
|
|
from transformers.models.llava import LlavaProcessor
|
|
from transformers.models.pixtral import PixtralProcessor
|
|
|
|
from vllm.config import VllmConfig
|
|
from vllm.inputs import InputProcessingContext
|
|
from vllm.model_executor.layers.activation import get_act_fn
|
|
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
|
RowParallelLinear)
|
|
from vllm.model_executor.layers.quantization import QuantizationConfig
|
|
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
|
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
|
from vllm.multimodal import MULTIMODAL_REGISTRY
|
|
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
|
|
MultiModalInputs, MultiModalKwargs,
|
|
NestedTensors)
|
|
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
|
|
ImageSize, MultiModalDataItems)
|
|
from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
|
BaseProcessingInfo, ProcessingCache,
|
|
PromptReplacement, PromptUpdate)
|
|
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
|
|
from vllm.sequence import IntermediateTensors
|
|
from vllm.utils import JSONTree, flatten_2d_lists, json_map_leaves
|
|
|
|
from .clip import CLIPVisionModel
|
|
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
|
|
from .pixtral import PixtralHFEncoderInfo, PixtralHFVisionModel
|
|
from .siglip import SiglipVisionModel
|
|
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
|
|
maybe_prefix, merge_multimodal_embeddings)
|
|
from .vision import get_vision_encoder_info
|
|
|
|
|
|
class LlavaImagePixelInputs(TypedDict):
    """Raw-pixel image input for LLaVA models with a fixed-size vision
    encoder (the non-Pixtral path)."""

    type: Literal["pixel_values"]
    pixel_values: torch.Tensor
    """
    Shape: `(batch_size * num_images, num_channels, height, width)`

    All images are concatenated into a single tensor, and `height` and
    `width` must both equal the vision encoder's `image_size`. Inputs whose
    sizes differ per image are only supported through
    `PixtralHFImagePixelInputs`.
    """
|
|
|
|
|
|
class PixtralHFImagePixelInputs(TypedDict):
    """Pixel input for a Pixtral-HF vision tower, together with the masks
    and crop counts used to place patch features onto image tokens."""

    type: Literal["pixel_values_pixtral"]

    # Shape: `(batch_size * num_images, num_channels, height, width)`.
    # `height` and `width` may vary from image to image; when they do, the
    # pixels are carried as a list of tensors instead of one batched tensor.
    pixel_values: Union[torch.Tensor, list[torch.Tensor]]

    # Boolean mask: True where an image *feature* belongs to a patch token.
    # Shape: `(batch_size, num_crops, num_patch)`
    feat_is_patch: Union[torch.Tensor, list[torch.Tensor]]

    # Boolean mask: True where an image *embedding* belongs to a patch token.
    # Shape: `(batch_size, num_embeds)`
    embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]

    # Shape: `(batch_size, num_images)`
    num_crops: Union[torch.Tensor, list[torch.Tensor]]
|
|
|
|
|
|
class LlavaImageEmbeddingInputs(TypedDict):
    """Precomputed image embeddings, used as-is without running the vision
    tower or the projector."""

    type: Literal["image_embeds"]

    # Shape: `(batch_size * num_images, image_feature_size, hidden_size)`.
    # `hidden_size` has to equal the hidden size of the language model
    # backbone.
    data: torch.Tensor
|
|
|
|
|
|
# Every image-input variant the LLaVA model accepts; consumers branch on the
# `type` key to tell them apart.
LlavaImageInputs = Union[
    LlavaImagePixelInputs,
    PixtralHFImagePixelInputs,
    LlavaImageEmbeddingInputs,
]
|
|
|
|
|
|
class LlavaMultiModalProjector(nn.Module):
    """Two-layer MLP (linear -> activation -> linear) that maps vision
    features to the language model's hidden size."""

    def __init__(self,
                 vision_hidden_size: int,
                 text_hidden_size: int,
                 projector_hidden_act: str,
                 multimodal_projector_bias: bool,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__()

        # Submodules are registered in the order linear_1, act, linear_2;
        # the weight-name prefixes mirror the attribute names.
        self.linear_1 = ColumnParallelLinear(
            vision_hidden_size,
            text_hidden_size,
            bias=multimodal_projector_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.linear_1",
        )
        self.act = get_act_fn(projector_hidden_act)
        self.linear_2 = RowParallelLinear(
            text_hidden_size,
            text_hidden_size,
            bias=multimodal_projector_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.linear_2",
        )

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        # Both parallel linear layers return a pair; only element 0 (the
        # projected tensor) is used.
        activated = self.act(self.linear_1(image_features)[0])
        return self.linear_2(activated)[0]
|
|
|
|
|
|
class LlavaLikeConfig(Protocol):
    """Structural type for HF configs handled like a LLaVA config."""

    # Config object of the vision tower.
    vision_config: Final[PretrainedConfig]
    # Token id that marks image placeholders in the prompt.
    image_token_index: Final[int]
    # Feature-selection mode; "default" and "full" are the values handled
    # by the processing helpers in this module.
    vision_feature_select_strategy: Final[str]
    # Encoder layer (or layers) whose outputs become image features;
    # negative values count back from the last layer.
    vision_feature_layer: Final[Union[int, list[int]]]
|
|
|
|
|
|
class LlavaLikeProcessor(Protocol):
    """Structural type for HF processors that expose an image placeholder."""

    # Placeholder text standing for one image inside a prompt string.
    image_token: Final[str]
|
|
|
|
|
|
class BaseLlavaProcessingInfo(BaseProcessingInfo):
    """Shared processing metadata for LLaVA-style models.

    Subclasses provide the HF processor; the number of image tokens is
    derived from the vision encoder plus the feature-select strategy.
    """

    def get_hf_config(self) -> LlavaLikeConfig:
        return self.ctx.get_hf_config(LlavaConfig)

    def get_vision_encoder_info(self):
        return get_vision_encoder_info(self.get_hf_config())

    @abstractmethod
    def get_hf_processor(self, **kwargs: object) -> LlavaLikeProcessor:
        raise NotImplementedError

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        # None: no fixed limit on images per prompt.
        return {"image": None}

    def get_mm_max_tokens_per_item(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> Mapping[str, int]:
        return {"image": self.get_max_image_tokens()}

    def _apply_feature_select_strategy(
        self,
        strategy: str,
        encoder_num_image_tokens: int,
    ) -> int:
        """Adjust the encoder's token count for the selection strategy."""
        if strategy == "full":
            return encoder_num_image_tokens
        if strategy == "default":
            # "default" discards the encoder's first output token.
            return encoder_num_image_tokens - 1

        raise NotImplementedError(
            f"Unexpected feature select strategy: {strategy!r}")

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        config = self.get_hf_config()
        encoder_info = self.get_vision_encoder_info()

        strategy = config.vision_feature_select_strategy
        raw_token_count = encoder_info.get_num_image_tokens(
            image_width=image_width,
            image_height=image_height,
        )
        return self._apply_feature_select_strategy(strategy, raw_token_count)

    def get_image_size_with_most_features(self) -> ImageSize:
        # A square image at the encoder's configured resolution.
        side = self.get_vision_encoder_info().get_image_size()
        return ImageSize(width=side, height=side)

    def get_max_image_tokens(self) -> int:
        width, height = self.get_image_size_with_most_features()
        return self.get_num_image_tokens(image_width=width,
                                         image_height=height)
|
|
|
|
|
|
# Type variable for processing-info classes derived from
# BaseLlavaProcessingInfo; used to parametrize the generic dummy-input
# builder and multimodal processors in this module.
_I = TypeVar("_I", bound=BaseLlavaProcessingInfo)
|
|
|
|
|
|
class LlavaDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
    """Builds profiling inputs made of one placeholder token per image and
    dummy images at the size that yields the most features."""

    def get_dummy_processor_inputs(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> ProcessorInputs:
        image_count = mm_counts.get("image", 0)

        placeholder = self.info.get_hf_processor().image_token
        width, height = self.info.get_image_size_with_most_features()

        dummy_images = self._get_dummy_images(
            width=width,
            height=height,
            num_images=image_count,
        )

        return ProcessorInputs(
            prompt_text=placeholder * image_count,
            mm_data={"image": dummy_images},
        )
|
|
|
|
|
|
class LlavaProcessingInfo(BaseLlavaProcessingInfo):
    """Processing info for checkpoints served by the HF `LlavaProcessor`."""

    def get_hf_processor(self, **kwargs: object):
        # Processor overrides are forwarded unchanged.
        processor_cls = LlavaProcessor
        return self.ctx.get_hf_processor(processor_cls, **kwargs)
|
|
|
|
|
|
class BaseLlavaMultiModalProcessor(BaseMultiModalProcessor[_I]):
    """Processor base that expands each image placeholder token into one
    placeholder per image embedding."""

    # Copied from BaseMultiModalProcessor
    @abstractmethod
    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        raise NotImplementedError

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> Sequence[PromptUpdate]:
        image_token_id = self.info.get_hf_config().image_token_index

        def count_image_tokens(item_idx: int) -> int:
            images = mm_items.get_items(
                "image", (ImageEmbeddingItems, ImageProcessorItems))

            if isinstance(images, ImageEmbeddingItems):
                # Embedding items report their own feature size.
                return images.get_feature_size(item_idx)

            size = images.get_image_size(item_idx)
            return self.info.get_num_image_tokens(
                image_width=size.width,
                image_height=size.height,
            )

        def get_replacement(item_idx: int):
            return [image_token_id] * count_image_tokens(item_idx)

        replacement = PromptReplacement(
            modality="image",
            target=[image_token_id],
            replacement=get_replacement,
        )
        return [replacement]
|
|
|
|
|
|
class LlavaMultiModalProcessor(
        BaseLlavaMultiModalProcessor[LlavaProcessingInfo]):
    """Multimodal processor for checkpoints using the HF LlavaProcessor."""

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        # Pixels and precomputed embeddings both carry one entry per image.
        return {
            "pixel_values": MultiModalFieldConfig.batched("image"),
            "image_embeds": MultiModalFieldConfig.batched("image"),
        }
|
|
|
|
|
|
class PixtralHFProcessingInfo(BaseLlavaProcessingInfo):
    """Processing info for Pixtral checkpoints in HF format, served by the
    HF `PixtralProcessor`."""

    def get_hf_processor(self, **kwargs: object):
        # Processor overrides are forwarded unchanged.
        processor_cls = PixtralProcessor
        return self.ctx.get_hf_processor(processor_cls, **kwargs)
|
|
|
|
|
|
class PixtralHFMultiModalProcessor(
        BaseMultiModalProcessor[PixtralHFProcessingInfo]):
    """Multimodal processor for Pixtral checkpoints in HF format.

    Each image becomes a grid of tokens. Every row holds one image token per
    patch column plus a trailing break token, and the last row's break token
    is replaced by an end token.
    """

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        processed_outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
        )

        pixel_values = processed_outputs.get("pixel_values")
        if pixel_values is None:
            return processed_outputs

        # Before/after https://github.com/huggingface/transformers/pull/35122
        if Version(TRANSFORMERS_VERSION) <= Version("4.48.3"):
            images = mm_data["images"]
            assert isinstance(images, list)

            # Original output: (1, num_images, C, H, W)
            # New output: (num_images, C, H, W)
            assert isinstance(pixel_values, list) and len(pixel_values) == 1
            assert (isinstance(pixel_values[0], list)
                    and len(pixel_values[0]) == len(images))

            processed_outputs["pixel_values"] = pixel_values[0]
        else:
            # Avoid padding since we need the output for each image to be
            # independent of other images for the cache to work correctly
            image_sizes = processed_outputs["image_sizes"]
            assert len(pixel_values) == len(image_sizes)

            cropped = []
            for padded, (h, w) in zip(pixel_values, image_sizes):
                cropped.append(padded[:, :h, :w])
            processed_outputs["pixel_values"] = cropped

        vision_config = self.info.get_hf_config().vision_config
        assert isinstance(vision_config, PixtralVisionConfig)
        encoder_info = PixtralHFEncoderInfo(vision_config)

        grids = [
            encoder_info.get_patch_grid_size(
                image_width=image_pixels.shape[-1],
                image_height=image_pixels.shape[-2],
            ) for image_pixels in processed_outputs["pixel_values"]
        ]

        # Per image: `nrows` rows of `ncols` patch slots plus one extra slot.
        num_crops = torch.tensor([(ncols + 1) * nrows
                                  for ncols, nrows in grids])

        # Masks differ in length between images, so they are concatenated
        # into one flat tensor and split per image later via `num_crops`.
        row_masks = [([True] * ncols + [False]) * nrows
                     for ncols, nrows in grids]
        is_patch = torch.tensor(flatten_2d_lists(row_masks))

        processed_outputs["num_crops"] = num_crops
        processed_outputs["embed_is_patch"] = is_patch
        processed_outputs["feat_is_patch"] = is_patch

        return processed_outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        num_crops = hf_inputs.get("num_crops", torch.empty(0)).view(-1)

        return {
            # Flat masks are sliced per image using the crop counts.
            "feat_is_patch":
            MultiModalFieldConfig.flat_from_sizes("image", num_crops),
            "embed_is_patch":
            MultiModalFieldConfig.flat_from_sizes("image", num_crops),
            "num_crops": MultiModalFieldConfig.batched("image"),
            "pixel_values": MultiModalFieldConfig.batched("image"),
            "image_embeds": MultiModalFieldConfig.batched("image"),
        }

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> Sequence[PromptUpdate]:
        processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        hf_config = self.info.get_hf_config()
        vocab = self.info.get_tokenizer().get_vocab()

        image_break_id = vocab[processor.image_break_token]
        image_token_id = hf_config.image_token_index
        image_end_id = vocab[processor.image_end_token]

        vision_config = hf_config.vision_config
        assert isinstance(vision_config, PixtralVisionConfig)
        encoder_info = PixtralHFEncoderInfo(vision_config)

        def get_replacement(item_idx: int):
            images = mm_items.get_items("image", ImageProcessorItems)
            size = images.get_image_size(item_idx)

            ncols, nrows = encoder_info.get_patch_grid_size(
                image_width=size.width,
                image_height=size.height,
            )

            row = [image_token_id] * ncols + [image_break_id]
            tokens = row * nrows
            # The very last break token is swapped for the end token.
            tokens[-1] = image_end_id
            return tokens

        replacement = PromptReplacement(
            modality="image",
            target=[image_token_id],
            replacement=get_replacement,
        )
        return [replacement]
|
|
|
|
|
|
def _build_llava_or_pixtral_hf_info(
    ctx: InputProcessingContext, ) -> BaseLlavaProcessingInfo:
    """Choose the processing-info class from the checkpoint's vision config.

    Pixtral vision configs get `PixtralHFProcessingInfo`; any other vision
    config gets `LlavaProcessingInfo`.
    """
    vision_config = ctx.get_hf_config(LlavaConfig).vision_config
    is_pixtral = isinstance(vision_config, PixtralVisionConfig)

    info_cls = PixtralHFProcessingInfo if is_pixtral else LlavaProcessingInfo
    return info_cls(ctx)
|
|
|
|
|
|
def _build_llava_or_pixtral_hf_processor(
    info: _I,
    dummy_inputs: BaseDummyInputsBuilder[_I],
    *,
    cache: Optional[ProcessingCache] = None,
    enable_sanity_checks: bool = True,
) -> BaseMultiModalProcessor:
    """Create the multimodal processor matching the processing-info type.

    Raises:
        NotImplementedError: If ``info`` is neither Pixtral-HF nor LLaVA
            processing info.
    """
    # isinstance checks run in this order: Pixtral first, then plain LLaVA.
    processor_for_info = (
        (PixtralHFProcessingInfo, PixtralHFMultiModalProcessor),
        (LlavaProcessingInfo, LlavaMultiModalProcessor),
    )

    for info_cls, processor_cls in processor_for_info:
        if isinstance(info, info_cls):
            return processor_cls(
                info,
                dummy_inputs,  # type: ignore
                cache=cache,
                enable_sanity_checks=enable_sanity_checks,
            )

    raise NotImplementedError(type(info))
|
|
|
|
|
|
def _get_num_hidden_layers(hf_config: LlavaLikeConfig) -> int:
    """Determine the number of hidden layers to initialize up to in the
    visual encoder.

    Args:
        hf_config: Model config with vision feature layer(s).

    Returns:
        The number of encoder layers required so that every requested
        feature layer exists. With several feature layers, this is the
        count needed for the deepest one.

    Raises:
        ValueError: If ``vision_feature_layer`` is an empty list/tuple.
        TypeError: If ``vision_feature_layer`` is neither an int nor a
            list/tuple.
    """
    feature_layers = hf_config.vision_feature_layer
    num_hidden_layers = hf_config.vision_config.num_hidden_layers

    # If we have one feature layer, initialize up to that layer
    if isinstance(feature_layers, int):
        return _get_layer_index(feature_layers, num_hidden_layers)

    # If we have multiple feature layers, initialize up to the deepest one
    if isinstance(feature_layers, (list, tuple)):
        if not feature_layers:
            # max() would otherwise fail with an opaque "empty sequence".
            raise ValueError("vision_feature_layer must contain at least "
                             "one layer index")
        return max(
            _get_layer_index(idx, num_hidden_layers) for idx in feature_layers)

    raise TypeError(f"vision_feature_layer type: {type(feature_layers)}"
                    " is not supported")
|
|
|
|
|
|
def _get_layer_index(feature_layer_index: int, num_hidden_layers: int) -> int:
|
|
"""Given a signed vision feature layer, get the number of hidden layers
|
|
needed to leverage it.
|
|
|
|
Args:
|
|
feature_layer_index: Index of a required layer in the visual encoder.
|
|
num_hidden_layers: The total number of hidden layers in the visual
|
|
encoder.
|
|
"""
|
|
if feature_layer_index < 0:
|
|
return num_hidden_layers + feature_layer_index + 1
|
|
return feature_layer_index
|
|
|
|
|
|
def init_vision_tower_for_llava(
    hf_config: LlavaLikeConfig,
    quant_config: Optional[QuantizationConfig],
    *,
    require_post_norm: Optional[bool] = None,
    prefix: str = "",
) -> Union[CLIPVisionModel, SiglipVisionModel, PixtralHFVisionModel]:
    """Construct the vision tower that matches ``hf_config.vision_config``.

    Only as many encoder layers are built as the deepest configured vision
    feature layer requires.

    Raises:
        NotImplementedError: If the vision config is not a CLIP, SigLIP or
            Pixtral vision config.
    """
    vision_config = hf_config.vision_config
    num_hidden_layers = _get_num_hidden_layers(hf_config)

    # isinstance checks run in this order: CLIP, SigLIP, Pixtral.
    tower_for_config = (
        (CLIPVisionConfig, CLIPVisionModel),
        (SiglipVisionConfig, SiglipVisionModel),
        (PixtralVisionConfig, PixtralHFVisionModel),
    )

    for config_cls, tower_cls in tower_for_config:
        if isinstance(vision_config, config_cls):
            return tower_cls(
                vision_config,
                quant_config=quant_config,
                num_hidden_layers_override=num_hidden_layers,
                require_post_norm=require_post_norm,
                prefix=prefix,
            )

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)
|
|
|
|
|
|
@MULTIMODAL_REGISTRY.register_processor(_build_llava_or_pixtral_hf_processor,
                                        info=_build_llava_or_pixtral_hf_info,
                                        dummy_inputs=LlavaDummyInputsBuilder)
class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
    """LLaVA-style model: vision tower + MLP projector + language model.

    Also serves Pixtral checkpoints in HF format, whose image inputs carry
    extra masks describing where patch features go.
    """

    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"]
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        super().__init__()

        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config

        self.config = config
        self.multimodal_config = multimodal_config

        # NOTE: These are special cases for Pixtral-12B in the HF-format
        # https://huggingface.co/mistral-community/pixtral-12b/blob/main/config.json  # noqa
        if (config.text_config.architectures is None
                and config.text_config.model_type == "mistral"):
            config.text_config.architectures = ["MistralForCausalLM"]
        if (config.projector_hidden_act is None
                and config.vision_config.hidden_act == "gelu"):
            config.projector_hidden_act = "gelu"

        # TODO: Optionally initializes this for supporting embeddings.
        self.vision_tower = init_vision_tower_for_llava(
            config,
            quant_config,
            require_post_norm=False,
            prefix=maybe_prefix(prefix, "vision_tower"))
        self.multi_modal_projector = LlavaMultiModalProjector(
            vision_hidden_size=config.vision_config.hidden_size,
            text_hidden_size=config.text_config.hidden_size,
            projector_hidden_act=config.projector_hidden_act,
            multimodal_projector_bias=config.multimodal_projector_bias,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "multi_modal_projector"))

        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.text_config,
            prefix=maybe_prefix(prefix, "language_model"),
        )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    @cached_property
    def sampler(self):
        # Prefer the language model's own sampler when it defines one.
        if hasattr(self.language_model, "sampler"):
            return self.language_model.sampler

        return get_sampler()

    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
        """Check that every image is (3, image_size, image_size).

        Raises:
            ValueError: If the per-image dimensions differ from that shape.
        """
        h = w = self.config.vision_config.image_size
        expected_dims = (3, h, w)
        actual_dims = tuple(data.shape[1:])

        if actual_dims != expected_dims:
            expected_expr = ("batch_size", *map(str, expected_dims))
            raise ValueError(
                f"The expected shape of pixel values is {expected_expr}. "
                f"You supplied {tuple(data.shape)}.")

        return data

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[LlavaImageInputs]:
        """Turn raw multimodal kwargs into a typed image input.

        Returns:
            None when neither ``pixel_values`` nor ``image_embeds`` is given.
            Otherwise a Pixtral pixel input (for Pixtral vision towers), a
            LLaVA pixel input, or an embedding input.

        Raises:
            ValueError: On inputs of an unexpected type, or when
                ``image_embeds`` are passed to a Pixtral vision tower.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None:
            if not isinstance(pixel_values, (torch.Tensor, list)):
                raise ValueError("Incorrect type of pixel values. "
                                 f"Got type: {type(pixel_values)}")

            if self.config.vision_config.model_type == "pixtral":
                feat_is_patch = kwargs.pop("feat_is_patch")
                if not isinstance(feat_is_patch, (torch.Tensor, list)):
                    raise ValueError("Incorrect type of feat_is_patch. "
                                     f"Got type: {type(feat_is_patch)}")

                embed_is_patch = kwargs.pop("embed_is_patch")
                if not isinstance(embed_is_patch, (torch.Tensor, list)):
                    raise ValueError("Incorrect type of embed_is_patch. "
                                     f"Got type: {type(embed_is_patch)}")

                num_crops = kwargs.pop("num_crops")
                if not isinstance(num_crops, (torch.Tensor, list)):
                    raise ValueError("Incorrect type of num_crops. "
                                     f"Got type: {type(num_crops)}")

                # Pixtral images may differ in size, so pixels are not
                # concatenated into one tensor here.
                return PixtralHFImagePixelInputs(
                    type="pixel_values_pixtral",
                    pixel_values=flatten_bn(pixel_values),
                    feat_is_patch=feat_is_patch,
                    embed_is_patch=embed_is_patch,
                    num_crops=num_crops,
                )

            return LlavaImagePixelInputs(
                type="pixel_values",
                pixel_values=self._validate_pixel_values(
                    flatten_bn(pixel_values, concat=True)),
            )

        if image_embeds is not None:
            if not isinstance(image_embeds, (torch.Tensor, list)):
                raise ValueError("Incorrect type of image embeddings. "
                                 f"Got type: {type(image_embeds)}")

            if self.config.vision_config.model_type == "pixtral":
                raise ValueError("Pixtral-HF does not support image_embeds.")

            return LlavaImageEmbeddingInputs(
                type="image_embeds",
                data=flatten_bn(image_embeds, concat=True),
            )

        raise AssertionError("This line should be unreachable.")

    def _select_image_features(self, image_features: torch.Tensor, *,
                               strategy: str) -> torch.Tensor:
        """Apply the feature-select strategy: "default" drops index 0 along
        dim 1, "full" keeps every token.

        Raises:
            ValueError: For any other strategy.
        """
        # Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421  # noqa
        if strategy == "default":
            return image_features[:, 1:]
        elif strategy == "full":
            return image_features

        raise ValueError(f"Unexpected select feature strategy: {strategy}")

    def _image_pixels_to_features(
        self,
        vision_tower: Union[CLIPVisionModel, SiglipVisionModel,
                            PixtralHFVisionModel],
        pixel_values: Union[torch.Tensor, list[torch.Tensor]],
    ) -> torch.Tensor:
        """Run the vision tower and apply the feature-select strategy."""

        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the vision tower
        image_features = vision_tower(pixel_values)

        return self._select_image_features(
            image_features,
            strategy=self.config.vision_feature_select_strategy,
        )

    def _process_image_pixels(
        self,
        inputs: Union[LlavaImagePixelInputs, PixtralHFImagePixelInputs],
    ) -> torch.Tensor:
        assert self.vision_tower is not None

        pixel_values = inputs["pixel_values"]

        return self._image_pixels_to_features(self.vision_tower, pixel_values)

    def _process_image_input(
        self,
        image_input: LlavaImageInputs,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
        """Produce language-model-space embeddings for the image input.

        Precomputed embeddings are returned unchanged. Pixel inputs go
        through the vision tower and the projector. When the tower returns
        one tensor per image, the features are projected in a single batch
        and split back per image.
        """
        if image_input["type"] == "image_embeds":
            return image_input["data"]

        assert self.vision_tower is not None
        image_features = self._process_image_pixels(image_input)

        if isinstance(image_features, torch.Tensor):
            return self.multi_modal_projector(image_features)

        feature_sizes = [
            image_feature.shape[0] for image_feature in image_features
        ]

        image_embeds = self.multi_modal_projector(torch.cat(image_features))
        image_embeds = torch.split(image_embeds, feature_sizes)
        return image_embeds

    def _get_mm_embeds(
        self,
        features: torch.Tensor,  # Shape: (num_crop, num_patch, d)
        feat_is_patch: torch.Tensor,  # Shape: (num_crop, num_patch)
        num_crops: torch.Tensor,  # Shape: (num_images,)
        embed_is_patch: torch.Tensor,  # Shape: (num_embeds,)
    ) -> list[torch.Tensor]:
        """Scatter the patch features into a contiguous tensor that corresponds
        to the embedding tokens defined by the multimodal processor.

        Slots that are not patch tokens are filled with NaN so that
        `get_input_embeddings` can drop them later.

        NOTE(review): both masks are flattened with `view(-1)` below, so the
        shape comments on the parameters describe their logical layout —
        confirm against the processor output.

        Mostly copied from `Molmo._get_mm_embeds`. See following fixme comment.
        """

        # Insert columns of nan values according to `feat_is_patch`. This work
        # ideally should be done in `_process_image_input`, but
        # `_process_image_input` is used in both V0 and V1 path. It's safer to
        # put the logic here.
        # FIXME: Move this logic to `_process_image_input` when v0 is
        # deprecated. Merge this function with `Molmo._get_mm_embeds`.
        feat_is_patch = feat_is_patch.view(-1)
        embed_is_patch = embed_is_patch.view(-1)

        # Read the crop counts once as Python ints: they size the buffer
        # below and drive the per-image splits.
        num_crops_per_image = num_crops.tolist()

        # Allocate the NaN buffer directly on the features' device instead
        # of building it on CPU and copying it over.
        expanded_embedding = torch.full(
            (sum(num_crops_per_image), *features.shape[1:]),
            torch.nan,
            dtype=features.dtype,
            device=features.device,
        )
        expanded_embedding[feat_is_patch] = features

        feats_per_image = expanded_embedding.split(num_crops_per_image)
        f_is_patch_per_image = feat_is_patch.split(num_crops_per_image)

        embed_dim = expanded_embedding.shape[-1]
        num_embeds = embed_is_patch.shape[0]

        embeds_in_batch = list[torch.Tensor]()
        for feats, f_is_patch in zip(feats_per_image, f_is_patch_per_image):
            embeds = feats.new_full((num_embeds, embed_dim), torch.nan)
            embeds[embed_is_patch] = feats[f_is_patch]
            embeds_in_batch.append(embeds)

        return embeds_in_batch

    def get_multimodal_embeddings(
            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
        """Compute image embeddings from the multimodal kwargs, or return
        None when the kwargs carry no image input."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return None

        vision_embeddings = self._process_image_input(image_input)

        if (kwargs.get("v0_path", False)
                or image_input["type"] != "pixel_values_pixtral"):
            # The path is used for pixtral (V0 only) and llava (V0/V1)
            return vision_embeddings

        nested_emb = [
            self._get_mm_embeds(*args) for args in zip(
                vision_embeddings, image_input["feat_is_patch"],
                image_input["num_crops"], image_input["embed_is_patch"])
        ]
        return flatten_2d_lists(nested_emb)

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
    ) -> torch.Tensor:
        """Embed `input_ids` and merge multimodal embeddings into the
        positions holding `config.image_token_index`."""
        inputs_embeds = self.language_model.get_input_embeddings(input_ids)
        if multimodal_embeddings is not None:
            # Extract the patch tokens: NaN entries mark slots that are not
            # patch embeddings and are dropped here.
            patch_embeddings = json_map_leaves(
                lambda x: x[~x.isnan()].view(-1, *x.shape[1:]),
                cast(JSONTree[torch.Tensor], multimodal_embeddings),
            )

            inputs_embeds = merge_multimodal_embeddings(
                input_ids, inputs_embeds, cast(NestedTensors,
                                               patch_embeddings),
                self.config.image_token_index)
        return inputs_embeds

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run forward pass for LLaVA-1.5.

        One key thing to understand is the `input_ids` already accounts for the
        positions of the to-be-inserted image embeddings.

        Concretely, consider a text prompt:
        `"USER: <image>\\nWhat's the content of the image?\\nASSISTANT:"`.

        Tokenizer outputs:
        `[1, 3148, 1001, 29901, 29871, 32000, 29871, 13, 5618, 29915, 29879,
        278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, 29901]`.

        To reserve space in KV cache, we have to insert placeholder tokens
        before they are inputted to the model, so the input processor prepends
        additional image tokens (denoted as `32000`), resulting in:
        `[1, 3148, 1001, 29901, 29871, 32000, ..., 32000, 29871, 13, 5618,
        29915, 29879, 278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566,
        29901]`.

        We insert 575 tokens so that including the original image token in the
        input, there are a total of 576 (24 * 24) image tokens, which
        corresponds to the number of image tokens inputted to the language
        model, i.e. the number of image tokens outputted by the visual encoder.

        This way, the `positions` and `attn_metadata` are consistent
        with the `input_ids`.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            pixel_values: The pixels in each input image (passed via
                ``kwargs``).

        See also:
            :class:`LlavaImageInputs`
        """
        if intermediate_tensors is not None:
            inputs_embeds = None

        # NOTE: In v1, inputs_embeds is always generated at model runner, this
        # condition is for v0 compatibility.
        elif inputs_embeds is None:
            kwargs.update({"v0_path": True})
            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
            inputs_embeds = self.get_input_embeddings(input_ids,
                                                      vision_embeddings)
            input_ids = None

        hidden_states = self.language_model.model(input_ids,
                                                  positions,
                                                  intermediate_tensors,
                                                  inputs_embeds=inputs_embeds)

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        return self.language_model.sample(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)
|
|
|
|
|
|
class MantisProcessingInfo(LlavaProcessingInfo):
    """Processing info for Mantis checkpoints. Uses the HF `LlavaProcessor`
    with defaults for `patch_size` and `vision_feature_select_strategy`
    that depend on the transformers version."""

    def get_hf_processor(self, **kwargs: object):
        hf_config = self.get_hf_config()
        vision_info = self.get_vision_encoder_info()

        kwargs.setdefault("patch_size", vision_info.get_patch_size())

        if Version(TRANSFORMERS_VERSION) >= Version("4.48"):
            # FIXED: https://github.com/huggingface/transformers/pull/33424/files#diff-6a37acc21efcadaae622b079b2712a131131448ff64262bd219aa346aeec38faL150
            default_strategy = hf_config.vision_feature_select_strategy
        else:
            # BUG: num_additional_image_tokens = 0 but treated as 1,
            # so we set vision_feature_select_strategy to None to offset this
            default_strategy = None
        kwargs.setdefault("vision_feature_select_strategy", default_strategy)

        return self.ctx.get_hf_processor(LlavaProcessor, **kwargs)
|
|
|
|
|
|
class MantisMultiModalProcessor(LlavaMultiModalProcessor):
    """LLaVA processor that, after standard processing, rewrites each
    image's placeholder run into the Mantis prompt format."""

    def apply(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
        hf_processor_mm_kwargs: Mapping[str, object],
        return_mm_hashes: bool = False,
    ) -> MultiModalInputs:
        image_token_id = self.info.get_hf_config().image_token_index

        # Assume that it doesn't depend on the image size
        tokens_per_image = self.info.get_num_image_tokens(
            image_width=-1,
            image_height=-1,
        )

        result = super().apply(prompt, mm_data, hf_processor_mm_kwargs,
                               return_mm_hashes)

        mm_items = self._to_mm_items(mm_data)
        item_counts = mm_items.get_all_counts()
        mm_kwargs = result["mm_kwargs"]

        # We reimplement the functionality of MLlavaProcessor from
        # https://github.com/TIGER-AI-Lab/Mantis.git
        def get_replacement_mantis(item_idx: int):
            pieces = [
                f"(image {item_idx+1}: <Image>",  # 7 tokens
                "<image>" * tokens_per_image,
                "</Image>)",  # 3 tokens
            ]
            return "".join(pieces)

        mantis_replacement = PromptReplacement(
            modality="image",
            target=[image_token_id] * tokens_per_image,
            replacement=get_replacement_mantis,
        )
        mantis_repls = self._bind_and_group_updates([mantis_replacement])

        prompt_ids, prompt, _ = self._apply_prompt_updates(
            result["prompt_token_ids"],
            mantis_repls,
            item_counts,
        )

        # Locate the original LLaVA placeholders inside the rewritten prompt.
        llava_repls = self._bind_and_group_updates(
            self._get_prompt_updates(
                mm_items,
                hf_processor_mm_kwargs,
                mm_kwargs,
            ))
        mm_placeholders = self._find_mm_placeholders(
            llava_repls,
            prompt_ids,
            item_counts,
        )
        self._validate_mm_placeholders(mm_placeholders, item_counts)

        placeholder_ranges = {}
        for modality, placeholders in mm_placeholders.items():
            placeholder_ranges[modality] = [p.to_range() for p in placeholders]

        return MultiModalInputs(
            type="multimodal",
            prompt=prompt,
            prompt_token_ids=prompt_ids,
            mm_kwargs=mm_kwargs,
            mm_placeholders=placeholder_ranges,
        )
|
|
|
|
|
|
# To use this model, please use
# `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'`
@MULTIMODAL_REGISTRY.register_processor(MantisMultiModalProcessor,
                                        info=MantisProcessingInfo,
                                        dummy_inputs=LlavaDummyInputsBuilder)
class MantisForConditionalGeneration(LlavaForConditionalGeneration):
    """Mantis model: the LLaVA model unchanged, registered with the
    Mantis-specific processor and processing info."""
|