# SPDX-License-Identifier: Apache-2.0

from typing import (Iterable, Literal, Mapping, Optional, Set, Tuple,
                    TypedDict, Union)

import torch
from torch import nn
from transformers import PaliGemmaConfig

from vllm.config import VllmConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
                         InputContext, token_inputs)
from vllm.logger import init_logger
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import NestedTensors
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config

from .interfaces import SupportsMultiModal, SupportsPP
from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
                     dummy_seq_data_for_siglip, get_max_siglip_image_tokens)
from .utils import (AutoWeightsLoader, init_vllm_registered_model,
|
2024-11-10 22:41:46 -08:00
|
|
|
maybe_prefix, merge_multimodal_embeddings)
|
2024-07-06 18:25:50 -07:00
|
|
|
|
|
|
|
logger = init_logger(__name__)


class PaliGemmaImagePixelInputs(TypedDict):
    type: Literal["pixel_values"]
    data: torch.Tensor
    """Shape: `(batch_size * num_images, num_channels, height, width)`"""
class PaliGemmaImageEmbeddingInputs(TypedDict):
    type: Literal["image_embeds"]
    data: torch.Tensor
    """Shape: `(batch_size * num_images, image_feature_size, hidden_size)`

    `hidden_size` must match the hidden size of the language model backbone.
    """
PaliGemmaImageInputs = Union[PaliGemmaImagePixelInputs,
                             PaliGemmaImageEmbeddingInputs]


def get_max_paligemma_image_tokens(ctx: InputContext):
    hf_config = ctx.get_hf_config(PaliGemmaConfig)
    vision_config = hf_config.vision_config

    return get_max_siglip_image_tokens(vision_config)


def dummy_data_for_paligemma(ctx: InputContext, seq_len: int,
                             mm_counts: Mapping[str, int]):
    hf_config = ctx.get_hf_config(PaliGemmaConfig)
    vision_config = hf_config.vision_config
    num_images = mm_counts["image"]

    seq_data, ranges = dummy_seq_data_for_siglip(
        vision_config,
        seq_len,
        num_images,
        image_token_id=hf_config.image_token_index,
    )

    mm_data = dummy_image_for_siglip(vision_config, num_images)
    return DummyData(seq_data, mm_data, ranges)


def input_processor_for_paligemma(ctx: InputContext,
                                  inputs: DecoderOnlyInputs):
    """
    The correct prompt format needs to be:
    '<image>' * image_feature_size + '<bos>' + prompt + '\n'

    See https://github.com/huggingface/transformers/blob/25245ec26dc29bcf6102e1b4ddd0dfd02e720cf5/src/transformers/models/paligemma/processing_paligemma.py#L55
    """ # noqa

    multi_modal_data = inputs.get("multi_modal_data")
    if multi_modal_data is None or "image" not in multi_modal_data:
        return inputs

    model_config = ctx.model_config
    hf_config = ctx.get_hf_config(PaliGemmaConfig)

    tokenizer = cached_tokenizer_from_config(model_config)
    image_feature_size = hf_config.text_config.num_image_tokens
    image_token_str = tokenizer.decode(hf_config.image_token_index)
    bos_token = tokenizer.decode(hf_config.bos_token_id)
    image_token_str_pad = image_token_str * image_feature_size
    image_token_ids_pad = [hf_config.image_token_index] * image_feature_size

    orig_prompt = inputs.get("prompt")
    orig_prompt_ids = inputs.get("prompt_token_ids")

    if orig_prompt is not None and image_token_str in orig_prompt:
        logger.warning(
            "The image token '%s' was detected in the prompt and "
            "will be removed. Please follow the proper prompt format "
            "documented on HuggingFace.", image_token_str)
        orig_prompt = orig_prompt.replace(image_token_str, "")
        # Drop every occurrence, mirroring str.replace above (list.remove
        # would only drop the first one).
        orig_prompt_ids = [
            token_id for token_id in orig_prompt_ids
            if token_id != hf_config.image_token_index
        ]

    new_prompt = f"{image_token_str_pad}{bos_token}{orig_prompt}\n"

    # The PaliGemma 2 tokenizer does not include a starting BOS token
    if orig_prompt_ids[0] != hf_config.bos_token_id:
        orig_prompt_ids = [hf_config.bos_token_id] + orig_prompt_ids

    new_token_ids = image_token_ids_pad + orig_prompt_ids + [108]  # newline

    # NOTE: Create a defensive copy of the original inputs
    return token_inputs(prompt_token_ids=new_token_ids,
                        prompt=new_prompt,
                        multi_modal_data=multi_modal_data)


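# A worked example of the transform above (hypothetical sizes): with
# image_feature_size = 256, the user prompt "caption en" becomes
#     "<image>" * 256 + "<bos>" + "caption en" + "\n"
# and the token ids become
#     [image_token_index] * 256 + [bos_token_id] + prompt_ids + [108]
# where 108 is the Gemma token id for "\n" (see the `# newline` comment).

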
class PaliGemmaMultiModalProjector(nn.Module):

    def __init__(self, vision_hidden_size: int, projection_dim: int):
        super().__init__()

        self.linear = nn.Linear(vision_hidden_size, projection_dim, bias=True)

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        hidden_states = self.linear(image_features)
        return hidden_states


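# Note: the projector is a single linear layer that maps SigLIP patch
# features of shape (..., vision_hidden_size) to (..., projection_dim) so
# they can be merged with the language model's token embeddings.

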
@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_paligemma_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_paligemma)
@INPUT_REGISTRY.register_input_processor(input_processor_for_paligemma)
class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
                                        SupportsPP):
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

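    # The mapping above tells the weight loader that checkpoint weights named
    # q_proj/k_proj/v_proj (and gate_proj/up_proj) are packed into the fused
    # qkv_proj (and gate_up_proj) parameters used by vLLM's fused layers.
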
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config
        self.config = config
        self.multimodal_config = multimodal_config

        self.vision_tower = SiglipVisionModel(config.vision_config,
                                              quant_config,
                                              prefix=maybe_prefix(
                                                  prefix, "vision_tower"))
        self.multi_modal_projector = PaliGemmaMultiModalProjector(
            vision_hidden_size=config.vision_config.hidden_size,
            projection_dim=config.vision_config.projection_dim)

        self.quant_config = quant_config

        if config.text_config.model_type == "gemma":
            config.text_config.architectures = ["GemmaForCausalLM"]
        else:
            config.text_config.architectures = ["Gemma2ForCausalLM"]
        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.text_config,
            prefix=maybe_prefix(prefix, "language_model"),
        )
        logit_scale = getattr(config, "logit_scale", 1.0)
        self.language_model.logits_processor.scale *= logit_scale

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    @property
    def sampler(self):
        return self.language_model.sampler

    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
        h = w = self.config.vision_config.image_size
        expected_dims = (3, h, w)
        actual_dims = tuple(data.shape[1:])

        if actual_dims != expected_dims:
            expected_expr = ("batch_size", *map(str, expected_dims))
            raise ValueError(
                f"The expected shape of pixel values is {expected_expr}. "
                f"You supplied {tuple(data.shape)}.")

        return data

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[PaliGemmaImageInputs]:
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None:
            if not isinstance(pixel_values, torch.Tensor):
                raise ValueError("Incorrect type of pixel values. "
                                 f"Got type: {type(pixel_values)}")

            # Remove the N dimension until multiple images are supported.
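            # (batch_size, 1, num_channels, height, width)
            #     -> (batch_size, num_channels, height, width)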
            pixel_values = pixel_values.squeeze(1)

            return PaliGemmaImagePixelInputs(
                type="pixel_values",
                data=self._validate_pixel_values(pixel_values),
            )

        if image_embeds is not None:
            if not isinstance(image_embeds, torch.Tensor):
                raise ValueError("Incorrect type of image embeddings. "
                                 f"Got type: {type(image_embeds)}")

            # Remove the N dimension until multiple images are supported.
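            # (batch_size, 1, image_feature_size, hidden_size)
            #     -> (batch_size, image_feature_size, hidden_size)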
            image_embeds = image_embeds.squeeze(1)

            return PaliGemmaImageEmbeddingInputs(
                type="image_embeds",
                data=image_embeds,
            )

        raise AssertionError("This line should be unreachable.")

    def _image_pixels_to_features(
        self,
        vision_tower: SiglipVisionModel,
        pixel_values: torch.Tensor,
    ) -> torch.Tensor:

        target_dtype = vision_tower.get_input_embeddings().weight.dtype
        image_features = vision_tower(pixel_values.to(dtype=target_dtype))

        return image_features

    def _process_image_input(
        self,
        image_input: PaliGemmaImageInputs,
    ) -> torch.Tensor:

        if image_input["type"] == "image_embeds":
            return image_input["data"]

        assert self.vision_tower is not None
        pixel_values = image_input["data"]
        image_features = self._image_pixels_to_features(
            self.vision_tower,
            pixel_values,
        )

        return self.multi_modal_projector(image_features)

    def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return None
        vision_embeddings = self._process_image_input(image_input)
        # https://github.com/huggingface/transformers/blob/main/src/transformers/models/paligemma/modeling_paligemma.py#L294 # noqa
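        # Scaling by hidden_size**-0.5 mirrors the normalizer that the HF
        # implementation applies to image features before merging them with
        # text embeddings, keeping both streams on a comparable scale.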
        vision_embeddings = vision_embeddings * (self.config.hidden_size**-0.5)
        return vision_embeddings

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[NestedTensors] = None,
    ) -> torch.Tensor:
        inputs_embeds = self.language_model.get_input_embeddings(input_ids)
        if multimodal_embeddings is not None:
            inputs_embeds = merge_multimodal_embeddings(
                input_ids, inputs_embeds, multimodal_embeddings,
                self.config.image_token_index)
        return inputs_embeds

    def forward(self,
                input_ids: torch.Tensor,
                positions: torch.Tensor,
                intermediate_tensors: Optional[IntermediateTensors] = None,
                inputs_embeds: Optional[torch.Tensor] = None,
                **kwargs: object) -> Union[SamplerOutput, IntermediateTensors]:
        if intermediate_tensors is not None:
            inputs_embeds = None

        # NOTE: In v1, inputs_embeds is always generated at model runner, this
        # condition is for v0 compatibility.
        elif inputs_embeds is None:
            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
            inputs_embeds = self.get_input_embeddings(input_ids,
                                                      vision_embeddings)
            input_ids = None

        hidden_states = self.language_model.model(input_ids,
                                                  positions,
                                                  intermediate_tensors,
                                                  inputs_embeds=inputs_embeds)

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        return self.language_model.sample(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)
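

# A minimal usage sketch (assumptions: a PaliGemma checkpoint such as
# "google/paligemma-3b-mix-224" is available and Pillow is installed; the
# prompt and image path are illustrative):
#
#     from PIL import Image
#     from vllm import LLM
#
#     llm = LLM(model="google/paligemma-3b-mix-224")
#     image = Image.open("example.jpg")
#     outputs = llm.generate({
#         "prompt": "caption en",
#         "multi_modal_data": {"image": image},
#     })
#     print(outputs[0].outputs[0].text)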