# Copyright 2024 The vLLM team.
# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property
from typing import Any, List, Literal, Optional, Set, Tuple, TypedDict, Union

import torch
import torch.nn as nn
from transformers import (BatchFeature, CLIPVisionConfig, PretrainedConfig,
                          ProcessorMixin)

from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
                                    MultiModalInputsV2, MultiModalKwargs,
                                    NestedTensors, PlaceholderRange)
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
                                   ImageSize, MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                        BaseProcessingInfo,
                                        BoundPromptReplacement,
                                        PlaceholderInfo, PromptReplacement)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of

from .clip import CLIPVisionModel
from .interfaces import SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
                    init_vllm_registered_model, maybe_prefix,
                    merge_multimodal_embeddings)

logger = init_logger(__name__)

# Cannot find the following number from the hf config.
_IMAGE_TOKEN_ID = 32044

CLIP_VIT_LARGE_PATCH14_336_CONFIG = CLIPVisionConfig(dropout=0.0,
                                                     hidden_act="quick_gelu",
                                                     hidden_size=1024,
                                                     image_size=336,
                                                     intermediate_size=4096,
                                                     num_attention_heads=16,
                                                     num_channels=3,
                                                     num_hidden_layers=24,
                                                     patch_size=14,
                                                     projection_dim=768)
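# NOTE: These values match the openai/clip-vit-large-patch14-336 vision tower;
# the config is hard-coded here because the Phi-3-vision HF config only
# carries a small `img_processor` dict (see below) rather than a full
# CLIPVisionConfig.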


def _init_img_processor(hf_config: PretrainedConfig,
                        quant_config: Optional[QuantizationConfig],
                        prefix: str = "") -> CLIPVisionModel:
    clip_config = CLIP_VIT_LARGE_PATCH14_336_CONFIG
    layer_idx = hf_config.img_processor.get('layer_idx', -2)

    # Initialize the CLIP only up to the required feature layer
    if layer_idx < 0:
        num_hidden_layers = clip_config.num_hidden_layers + \
            layer_idx + 1
    else:
        num_hidden_layers = layer_idx + 1
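    # e.g. with the default layer_idx of -2 and 24 CLIP layers, only the
    # first 23 layers are instantiated, so the truncated tower's final
    # hidden state is exactly the requested feature layer.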

    img_processor = CLIPVisionModel(
        clip_config,
        quant_config,
        num_hidden_layers_override=num_hidden_layers,
        prefix=prefix,
    )

    return img_processor


class Phi3VImagePixelInputs(TypedDict):
    type: Literal["pixel_values"]
    data: Union[torch.Tensor, List[torch.Tensor]]
    """
    Shape:
    `(batch_size * num_images, 1 + num_patches, num_channels, height, width)`

    Note that `num_patches` may be different per batch and image,
    in which case the data is passed as a list instead of a batched tensor.
    """

    image_sizes: torch.Tensor
    """
    Shape: `(batch_size * num_images, 2)`

    This should be in `(height, width)` format.
    """


class Phi3VImageEmbeddingInputs(TypedDict):
    type: Literal["image_embeds"]
    data: Union[torch.Tensor, List[torch.Tensor]]
    """Shape: `(batch_size * num_images, image_feature_size, hidden_size)`

    `hidden_size` must match the hidden size of language model backbone.
    """


Phi3VImageInputs = Union[Phi3VImagePixelInputs, Phi3VImageEmbeddingInputs]


class Phi3ImageEmbeddingBase(nn.Module):

    def __init__(self) -> None:
        super().__init__()
        self.layer_idx: int
        self.type_feature: str
        self.img_processor: CLIPVisionModel

    def get_img_features(self,
                         img_embeds: torch.FloatTensor) -> torch.FloatTensor:
        TYPE_FEATURE = self.type_feature

        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the img_processor
        img_feature = self.img_processor(img_embeds)

        if TYPE_FEATURE == "patch":
            patch_feature = img_feature[:, 1:]
            return patch_feature

        if TYPE_FEATURE == "cls_patch":
            return img_feature

        raise NotImplementedError


# adapted from https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_embedding_phi3_v.py
class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
    """Phi3 Image embedding with HD transform."""

    def __init__(self,
                 config: PretrainedConfig,
                 quant_config: Optional[QuantizationConfig],
                 prefix: str = "") -> None:
        super().__init__()

        # n_embd or hidden_size
        hidden_size = config.n_embd if hasattr(
            config, 'n_embd') else config.hidden_size

        self.img_processor = _init_img_processor(
            config, quant_config, prefix=f"{prefix}.img_processor")

        image_dim_out = config.img_processor['image_dim_out']
        self.num_img_tokens = config.img_processor['num_img_tokens']

        self.image_dim_out = image_dim_out

        # global_gn and sub_gn for hd transform, serves as line separator
        self.use_hd_transform = config.embd_layer.get('use_hd_transform',
                                                      False)
        self.with_learnable_separator = config.embd_layer.get(
            'with_learnable_separator', False)
        self.hd_transform_order = config.embd_layer.get(
            'hd_transform_order', 'glb_sub')
        # use_hd_transform and with_learnable_separator should have same value
        assert self.use_hd_transform and self.with_learnable_separator

        # 1024 * 4, merge spatial to channel dimension
        self.glb_GN = nn.Parameter(torch.empty([1, 1, self.image_dim_out * 4]))
        self.sub_GN = nn.Parameter(
            torch.empty([1, 1, 1, self.image_dim_out * 4]))

        dim_projection = hidden_size
        depth = 2
        layers = [nn.Linear(image_dim_out * 4, dim_projection)]
        for _ in range(1, depth):
            layers.extend(
                [nn.GELU(),
                 nn.Linear(dim_projection, dim_projection)])
        self.img_projection = nn.Sequential(*layers)

        self.type_feature = config.img_processor.get('type_feature', 'patch')

    def forward(self, pixel_values: torch.FloatTensor,
                image_sizes: torch.Tensor) -> torch.FloatTensor:
        """
        process image and return vision embeddings.

        pixel_values: (num_images, num_crops, c, h, w)
        output: (num_images, num_img_tokens, hidden_size)
        """
        num_images, num_crops, c, h, w = pixel_values.shape
        pixel_values = pixel_values.flatten(0, 1)
        img_features = self.get_img_features(pixel_values)
        img_features = img_features.reshape(num_images, num_crops, -1,
                                            self.image_dim_out)
        image_features_proj = self.hd_feature_transform(
            img_features, image_sizes)
        return image_features_proj

    def hd_feature_transform(self, image_features, image_sizes):
        """
        image_features: (num_images, num_crops+1, 24*24, 1024)
        """
        assert (
            self.hd_transform_order == 'sub_glb'
        ), f'hd_transform_order `{self.hd_transform_order}` not implemented'
        if isinstance(self.img_projection, nn.Sequential):
            target_device = self.img_projection[0].bias.device
            target_dtype = self.img_projection[0].bias.dtype
        else:  # It's a single nn.Linear layer
            target_device = self.img_projection.bias.device
            target_dtype = self.img_projection.bias.dtype

        global_image_features = image_features[:,
                                               0]  # (num_images, 24*24, 1024)
        # global feature can be viewed as a special HD case with num_crops 1x1
        global_image_features_hd = self.reshape_hd_patches_2x2merge(
            global_image_features, 1, 1)
        global_image_features_hd_newline = self.add_image_newline(
            global_image_features_hd)

        batch_image_features_proj = []
        # need a for loop to process each image because of different image sizes
        # (patch arrangement is different for each image)
        for i, img_size in enumerate(image_sizes):
            h, w = img_size
            h_crop = h // 336
            w_crop = w // 336
            num_crops = h_crop * w_crop

            # NOTE: real num_crops is padded
            # (num_crops, 24*24, 1024)
            sub_image_features = image_features[i, 1:1 + num_crops]
            sub_image_features_hd = self.reshape_hd_patches_2x2merge(
                sub_image_features, h_crop, w_crop)
            sub_image_features_hd_newline = self.add_image_newline(
                sub_image_features_hd)

            # [sub features, separator, global features]
            image_embeddings = torch.cat([
                sub_image_features_hd_newline.squeeze(
                    0),  # (h_crop*12*(w_crop*12+1), 4096)
                self.glb_GN.squeeze(0),
                global_image_features_hd_newline[i],
            ])
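            # Per image this is h_crop*12*(w_crop*12+1) + 1 + 12*13 feature
            # vectors, which should line up with the token count returned by
            # the HF processor's calc_num_image_tokens_from_image_size()
            # used in Phi3VProcessingInfo below, i.e.
            # (h_crop*w_crop + 1)*144 + (h_crop + 1)*12 + 1.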
            img_proj = self.img_projection(
                image_embeddings.to(target_device, target_dtype))
            batch_image_features_proj.append(img_proj)

        return batch_image_features_proj

    def reshape_hd_patches_2x2merge(self, image_features, h_crop, w_crop):
        """
        image_features: (num_images*num_crops, 24*24, 1024)
        output: (num_images, h_crop*12, w_crop*12, 4096)
        where h_crop*w_crop == num_crops
        """
        N, L, C = image_features.shape
        assert L == 576 and C == 1024 and N % (h_crop * w_crop) == 0
        num_images = N // (h_crop * w_crop)
        H = int(L**0.5)
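        # e.g. 6 crops laid out as h_crop=2, w_crop=3 become (1, 24, 36, 4096):
        # each 2x2 block of the 24x24 patch grid is folded into the channel
        # dimension (1024 -> 4096) before the crops are stitched together.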
        image_features_hd = (
            image_features.reshape(N, H, H, C)  # N, 24, 24, 1024
            .reshape(N, H // 2, 2, H // 2, 2, C)  # N, 12, 2, 12, 2, 1024
            .permute(0, 1, 3, 2, 4, 5)  # N, 12, 12, 2, 2, 1024
            .reshape(N, -1, 4 * C)  # N, 144, 4096
            .reshape(num_images, h_crop, w_crop, H // 2, H // 2,
                     -1)  # n_img, h_crop, w_crop, 12, 12, 4096
            .permute(0, 1, 3, 2, 4, 5)  # n_img, h_crop, 12, w_crop, 12, 4096
            .reshape(num_images, h_crop * H // 2, w_crop * H // 2,
                     4 * C)  # n_img, h_crop*12, w_crop*12, 4096
        )
        return image_features_hd

    def add_image_newline(self, image_features_hd):
        """
        image_features_hd: (num_images, h_crop*12, w_crop*12, 4096)
        output: (num_images, (h_crop*12) * (w_crop*12+1), 4096)
        """
        num_images, h, w, hid_dim = image_features_hd.shape
        # add the newline token to the HD image feature patches
        newline_embeddings = self.sub_GN.expand(num_images, h, -1,
                                                -1)  # (n_img, h, 1, hid_dim)
        image_features_hd_newline = torch.cat(
            [image_features_hd, newline_embeddings],
            dim=2).reshape(num_images, -1, hid_dim)
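        # e.g. (1, 24, 36, 4096) gains one sub_GN column per row, giving
        # (1, 24, 37, 4096), and is then flattened to (1, 888, 4096).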
        return image_features_hd_newline


class Phi3VProcessingInfo(BaseProcessingInfo):

    def get_hf_processor(
        self,
        *,
        num_crops: Optional[int] = None,
    ) -> ProcessorMixin:
        if num_crops is not None:
            return self.ctx.get_hf_processor(num_crops=num_crops)

        return self.ctx.get_hf_processor()

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        return {"image": None}

    def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
        target_width, target_height = self.get_image_size_with_most_features()

        max_image_tokens = self.get_num_image_tokens(
            image_width=target_width,
            image_height=target_height,
            processor=None,
        )

        return {"image": max_image_tokens}

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        processor: Optional[ProcessorMixin],
    ) -> int:
        if processor is None:
            processor = self.get_hf_processor()

        return processor.calc_num_image_tokens_from_image_size(  # type: ignore
            width=image_width,
            height=image_height,
        )

    def get_image_size_with_most_features(self) -> ImageSize:
        # Results in the max possible feature size (h:w = 16:1)
        return ImageSize(height=8000, width=50)


class Phi3VDummyInputsBuilder(BaseDummyInputsBuilder[Phi3VProcessingInfo]):

    def get_dummy_processor_inputs(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> ProcessorInputs:
        num_images = mm_counts.get("image", 0)

        target_width, target_height = \
            self.info.get_image_size_with_most_features()

        mm_data = {
            "image":
            self._get_dummy_images(width=target_width,
                                   height=target_height,
                                   num_images=num_images)
        }

        hf_processor = self.info.get_hf_processor()
        image_tokens: list[str] = hf_processor.img_tokens  # type: ignore

        return ProcessorInputs(
            prompt_text="".join(image_tokens[:num_images]),
            mm_data=mm_data,
        )


class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]):

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        processed_outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
        )

        input_ids = processed_outputs["input_ids"]
        assert isinstance(input_ids, torch.Tensor)

        # Phi3v processor has inserted -1, -2 etc as placeholder in prompt_ids,
        # which will cause OverflowError when decoding the prompt_ids.
        # Therefore, we need to do an early replacement here
        input_ids.masked_fill_(input_ids < 0, _IMAGE_TOKEN_ID)

        return processed_outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        return dict(
            pixel_values=MultiModalFieldConfig.batched("image"),
            image_sizes=MultiModalFieldConfig.batched("image"),
            image_embeds=MultiModalFieldConfig.batched("image"),
        )

    def _get_prompt_replacements(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, Any],
        out_mm_kwargs: MultiModalKwargs,
    ) -> list[PromptReplacement]:
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        image_tokens: list[str] = hf_processor.img_tokens  # type: ignore

        tokenizer = self.info.get_tokenizer()
        bos_token_id = tokenizer.bos_token_id
        assert isinstance(bos_token_id, int)
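        # Each HF image placeholder string is expanded to `num_image_tokens`
        # copies of _IMAGE_TOKEN_ID followed by a single bos token; the
        # trailing bos mirrors the HF processor's output and is excluded
        # from the placeholder ranges in `apply` below.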

        def get_replacement_phi3v(item_idx: int):
            images = mm_items.get_items(
                "image", (ImageEmbeddingItems, ImageProcessorItems))

            if isinstance(images, ImageEmbeddingItems):
                num_image_tokens = images.get_feature_size(item_idx)
            else:
                image_size = images.get_image_size(item_idx)
                num_image_tokens = self.info.get_num_image_tokens(
                    image_width=image_size.width,
                    image_height=image_size.height,
                    processor=hf_processor,
                )

            return [_IMAGE_TOKEN_ID] * num_image_tokens + [bos_token_id]

        num_images = mm_items.get_count("image", strict=False)

        return [
            PromptReplacement(
                modality="image",
                target=image_token,
                replacement=get_replacement_phi3v,
            ) for image_token in image_tokens[:num_images]
        ]

    def _apply_prompt_replacements(
        self,
        token_ids: list[int],
        mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]],
        mm_item_counts: Mapping[str, int],
    ) -> tuple[list[int], str, Mapping[str, list[PlaceholderInfo]]]:
        token_ids, text, placeholders = super()._apply_prompt_replacements(
            token_ids=token_ids,
            mm_prompt_repls=mm_prompt_repls,
            mm_item_counts=mm_item_counts,
        )

        # Keep the behavior in line with HF processor
        if text.startswith("<s> <|image|>"):
            text = text.replace("<s> <|image|>", "<s><|image|>", 1)
            token_ids = [token_ids[0], *token_ids[2:]]
            placeholders = {
                modality: [
                    PlaceholderInfo(
                        modality=p.modality,
                        item_idx=p.item_idx,
                        start_idx=p.start_idx - 1,
                        replacement=p.replacement,
                    ) for p in ps
                ]
                for modality, ps in placeholders.items()
            }

        return token_ids, text, placeholders

    def apply(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> MultiModalInputsV2:
        result = super().apply(prompt, mm_data, hf_processor_mm_kwargs)

        # Only <|image|> tokens should be considered as placeholders,
        # so we ignore the trailing bos_token_id
        result["mm_placeholders"] = {
            modality: [
                PlaceholderRange(offset=p["offset"], length=p["length"] - 1)
                for p in ps
            ]
            for modality, ps in result["mm_placeholders"].items()
        }

        return result


@MULTIMODAL_REGISTRY.register_processor(Phi3VMultiModalProcessor,
                                        info=Phi3VProcessingInfo,
                                        dummy_inputs=Phi3VDummyInputsBuilder)
class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            "model.vision_embed_tokens.wte": "embed_tokens",
            "model.vision_embed_tokens.": "vision_embed_tokens.",
            "lm_head.": "language_model.lm_head.",
            "model.": "language_model.model.",
        })
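    # The mapper above routes HF checkpoint weights into this module tree:
    # the text embedding stored under `model.vision_embed_tokens.wte` is
    # loaded into `embed_tokens`, the rest of `model.vision_embed_tokens.*`
    # is kept under `vision_embed_tokens`, and the remaining `model.*` /
    # `lm_head.*` weights go to `language_model`.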

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config
        self.config = config
        self.multimodal_config = multimodal_config
        self.image_token_id = _IMAGE_TOKEN_ID

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "model.embed_tokens"),
        )

        # TODO: Optionally initialize this to support input embeddings.
        self.vision_embed_tokens = Phi3HDImageEmbedding(
            config,
            quant_config,
            prefix=maybe_prefix(prefix, "model.vision_embed_tokens"))

        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            # The prefix is empty intentionally because default prefix of
            # LlamaForCausalLM is "model"
            prefix="",
            # We don't directly initialize vLLM's LlamaForCausalLM so we
            # can automatically apply embedding wrapper if this model is
            # initialized as an embedding model
            architectures=["LlamaForCausalLM"],
        )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    @cached_property
    def sampler(self):
        if hasattr(self.language_model, "sampler"):
            return self.language_model.sampler

        return get_sampler()

    def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor:
        expected_dims = (2, )

        def _validate_shape(d: torch.Tensor):
            actual_dims = tuple(d.shape)

            if actual_dims != expected_dims:
                expected_expr = str(expected_dims)
                raise ValueError(
                    f"The expected shape of image sizes per image per batch "
                    f"is {expected_expr}. You supplied {tuple(d.shape)}.")

        for d in data:
            _validate_shape(d)

        return data

    def _validate_pixel_values(
        self, data: Union[torch.Tensor, List[torch.Tensor]]
    ) -> Union[torch.Tensor, List[torch.Tensor]]:

        h = w = CLIP_VIT_LARGE_PATCH14_336_CONFIG.image_size
        expected_dims = (3, h, w)

        def _validate_shape(d: torch.Tensor):
            actual_dims = tuple(d.shape[1:])

            if actual_dims != expected_dims:
                expected_expr = ("num_patches", *map(str, expected_dims))
                raise ValueError(
                    "The expected shape of pixel values per image per batch "
                    f"is {expected_expr}. You supplied {tuple(d.shape)}.")

        for d in data:
            _validate_shape(d)

        return data

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[Phi3VImageInputs]:
        pixel_values = kwargs.pop("pixel_values", None)
        image_sizes = kwargs.pop("image_sizes", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None:
            if not isinstance(pixel_values, (torch.Tensor, list)):
                raise ValueError("Incorrect type of pixel values. "
                                 f"Got type: {type(pixel_values)}")

            if not isinstance(image_sizes, (torch.Tensor, list)):
                raise ValueError("Incorrect type of image sizes. "
                                 f"Got type: {type(image_sizes)}")

            return Phi3VImagePixelInputs(
                type="pixel_values",
                data=self._validate_pixel_values(flatten_bn(pixel_values)),
                image_sizes=self._validate_image_sizes(
                    flatten_bn(image_sizes, concat=True)))

        if image_embeds is not None:
            if not isinstance(image_embeds, torch.Tensor):
                raise ValueError("Incorrect type of image embeddings. "
                                 f"Got type: {type(image_embeds)}")

            return Phi3VImageEmbeddingInputs(
                type="image_embeds",
                data=flatten_bn(image_embeds),
            )

        raise AssertionError("This line should be unreachable.")

    def _process_image_input(
        self,
        image_input: Phi3VImageInputs,
    ) -> torch.Tensor:

        if image_input["type"] == "image_embeds":
            image_data = image_input["data"]
            if is_list_of(image_data, torch.Tensor):
                # it's already a list of tensors
                return image_data
            if len(image_data.shape) == 3:
                # 3D tensor
                return list(torch.unbind(image_data, dim=0))
            raise ValueError(
                "We expect batched 2D tensors; "
                "this can be either a list of 2D tensors or a single 3D tensor."
            )

        assert self.vision_embed_tokens is not None
        image_embeds = self.vision_embed_tokens(image_input["data"],
                                                image_input["image_sizes"])

        return image_embeds

    def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return None
        vision_embeddings = self._process_image_input(image_input)
        return vision_embeddings

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[NestedTensors] = None,
    ) -> torch.Tensor:
        inputs_embeds = self.embed_tokens(input_ids)
        if multimodal_embeddings is not None:
            inputs_embeds = merge_multimodal_embeddings(
                input_ids, inputs_embeds, multimodal_embeddings,
                self.image_token_id)
        return inputs_embeds

    def forward(self,
                input_ids: torch.Tensor,
                positions: torch.Tensor,
                kv_caches: List[torch.Tensor],
                attn_metadata: AttentionMetadata,
                intermediate_tensors: Optional[IntermediateTensors] = None,
                inputs_embeds: Optional[torch.Tensor] = None,
                **kwargs: object):

        if intermediate_tensors is not None:
            inputs_embeds = None

        # NOTE: In v1, inputs_embeds is always generated at model runner, this
        # condition is for v0 compatibility
        elif inputs_embeds is None:
            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
            inputs_embeds = self.get_input_embeddings(input_ids,
                                                      vision_embeddings)
            input_ids = None

        hidden_states = self.language_model.model(input_ids,
                                                  positions,
                                                  kv_caches,
                                                  attn_metadata,
                                                  intermediate_tensors,
                                                  inputs_embeds=inputs_embeds)

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        return self.language_model.sample(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:

        loader = AutoWeightsLoader(self)
        autoloaded_weights = loader.load_weights(weights,
                                                 mapper=self.hf_to_vllm_mapper)

        # The HF config doesn't specify whether these are tied,
        # so we detect it this way
        if "embed_tokens.weight" not in autoloaded_weights:
            self.embed_tokens = self.language_model.model.embed_tokens
            autoloaded_weights.add("embed_tokens.weight")
        return autoloaded_weights