[Model] Add PaliGemma (#5189)
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in: parent 9389380015, commit 6206dcb29e
@@ -186,6 +186,10 @@ Vision Language Models
     - LLaVA-NeXT
     - :code:`llava-hf/llava-v1.6-mistral-7b-hf`, :code:`llava-hf/llava-v1.6-vicuna-7b-hf`, etc.
     -
+  * - :code:`PaliGemmaForConditionalGeneration`
+    - PaliGemma
+    - :code:`google/paligemma-3b-pt-224`, :code:`google/paligemma-3b-mix-224`, etc.
+    -
   * - :code:`Phi3VForCausalLM`
     - Phi-3-Vision
     - :code:`microsoft/Phi-3-vision-128k-instruct`, etc.
examples/paligemma_example.py (new file, 52 lines)
@@ -0,0 +1,52 @@
import os
import subprocess

from PIL import Image

from vllm import LLM

# The assets are located at `s3://air-example-data-2/vllm_opensource_llava/`.
# You can use `.buildkite/download-images.sh` to download them


def run_paligemma():
    llm = LLM(model="google/paligemma-3b-mix-224")

    prompt = "caption es"

    image = Image.open("images/stop_sign.jpg")

    outputs = llm.generate({
        "prompt": prompt,
        "multi_modal_data": {
            "image": image
        },
    })

    for o in outputs:
        generated_text = o.outputs[0].text
        print(generated_text)


def main():
    run_paligemma()


if __name__ == "__main__":
    # Download from s3
    s3_bucket_path = "s3://air-example-data-2/vllm_opensource_llava/"
    local_directory = "images"

    # Make sure the local directory exists or create it
    os.makedirs(local_directory, exist_ok=True)

    # Use AWS CLI to sync the directory, assume anonymous access
    subprocess.check_call([
        "aws",
        "s3",
        "sync",
        s3_bucket_path,
        local_directory,
        "--no-sign-request",
    ])
    main()
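The example above relies on vLLM's default sampling settings. Below is a minimal sketch of the same call with explicit greedy decoding, assuming the standard `vllm.SamplingParams` API; the `temperature` and `max_tokens` values are illustrative and not part of this commit.

from PIL import Image

from vllm import LLM, SamplingParams

# Illustrative values only; not part of this commit.
llm = LLM(model="google/paligemma-3b-mix-224")
params = SamplingParams(temperature=0.0, max_tokens=64)  # greedy, short caption

outputs = llm.generate(
    {
        "prompt": "caption es",
        "multi_modal_data": {"image": Image.open("images/stop_sign.jpg")},
    },
    params,
)
print(outputs[0].outputs[0].text)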
tests/models/test_paligemma.py (new file, 147 lines)
@@ -0,0 +1,147 @@
from typing import List, Optional, Tuple, Type

import pytest
from transformers import AutoTokenizer

from vllm.multimodal.utils import rescale_image_size
from vllm.sequence import SampleLogprobs

from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets
from .utils import check_logprobs_close

pytestmark = pytest.mark.vlm

HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
    "stop_sign": "caption es",
    "cherry_blossom": "What is in the picture?",
    "boardwalk": "What is in the picture?",
})

IMAGE_TOKEN_ID = 257152

models = ["google/paligemma-3b-mix-224"]


def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
                                         Optional[SampleLogprobs]],
                      model: str):
    """Sanitize vllm output to be comparable with hf output."""
    output_ids, output_str, out_logprobs = vllm_output

    tokenizer = AutoTokenizer.from_pretrained(model)
    eos_token_id = tokenizer.eos_token_id

    hf_output_ids = [
        token_id for idx, token_id in enumerate(output_ids)
        if token_id != IMAGE_TOKEN_ID or output_ids[idx - 1] != IMAGE_TOKEN_ID
    ]

    hf_output_str = output_str

    if hf_output_ids[-1] == eos_token_id:
        hf_output_str = hf_output_str + tokenizer.decode(eos_token_id)

    return hf_output_ids, hf_output_str, out_logprobs


def run_test(
    hf_runner: Type[HfRunner],
    vllm_runner: Type[VllmRunner],
    image_assets: _ImageAssets,
    model: str,
    *,
    size_factors: List[float],
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
    tensor_parallel_size: int,
    distributed_executor_backend: Optional[str] = None,
):
    """Inference result should be the same between hf and vllm.

    All the image fixtures for the test are under tests/images.
    For the huggingface runner, we provide the PIL images as input.
    For the vllm runner, we provide MultiModalDataDict objects
    and the corresponding vision language config as input.
    Note that the text input is also adjusted to abide by the vllm contract.
    The text output is sanitized so it can be compared with hf.
    """
    images = [asset.pil_image for asset in image_assets]

    inputs_per_image = [(
        [prompt for _ in size_factors],
        [rescale_image_size(image, factor) for factor in size_factors],
    ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]

    # NOTE: take care of the order. run vLLM first, and then run HF.
    # vLLM needs a fresh new process without cuda initialization.
    # if we run HF first, the cuda initialization will be done and it
    # will hurt multiprocessing backend with fork method (the default method).

    # max_model_len should be greater than image_feature_size
    with vllm_runner(model,
                     dtype=dtype,
                     tensor_parallel_size=tensor_parallel_size,
                     distributed_executor_backend=distributed_executor_backend,
                     enforce_eager=True) as vllm_model:
        vllm_outputs_per_image = [
            vllm_model.generate_greedy_logprobs(prompts,
                                                max_tokens,
                                                num_logprobs=num_logprobs,
                                                images=images)
            for prompts, images in inputs_per_image
        ]

    with hf_runner(model, dtype=dtype, is_vision_model=True) as hf_model:
        hf_outputs_per_image = [
            hf_model.generate_greedy_logprobs_limit(prompts,
                                                    max_tokens,
                                                    num_logprobs=num_logprobs,
                                                    images=images)
            for prompts, images in inputs_per_image
        ]

    for hf_outputs, vllm_outputs in zip(hf_outputs_per_image,
                                        vllm_outputs_per_image):

        check_logprobs_close(
            outputs_0_lst=hf_outputs,
            outputs_1_lst=[
                vllm_to_hf_output(vllm_output, model)
                for vllm_output in vllm_outputs
            ],
            name_0="hf",
            name_1="vllm",
        )


@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize(
    "size_factors",
    [
        # No image
        [],
        # Single-scale
        [1.0],
        # Single-scale, batched
        [1.0, 1.0, 1.0],
        # Multi-scale
        [0.25, 0.5, 1.0],
    ],
)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
                dtype: str, max_tokens: int, num_logprobs: int) -> None:
    run_test(
        hf_runner,
        vllm_runner,
        image_assets,
        model,
        size_factors=size_factors,
        dtype=dtype,
        max_tokens=max_tokens,
        num_logprobs=num_logprobs,
        tensor_parallel_size=1,
    )
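For reference, the sanitization done by `vllm_to_hf_output` collapses the run of image placeholder tokens that vLLM keeps in its output ids down to a single token, so the ids can be compared against the HF runner's output. A minimal standalone sketch of that filter with dummy token ids (no tokenizer needed):

IMAGE_TOKEN_ID = 257152

def collapse_image_tokens(output_ids):
    # Keep a token unless it is an image token immediately preceded by one.
    return [
        tok for idx, tok in enumerate(output_ids)
        if tok != IMAGE_TOKEN_ID or output_ids[idx - 1] != IMAGE_TOKEN_ID
    ]

# Three leading image tokens collapse to one; text tokens are untouched.
assert collapse_image_tokens(
    [IMAGE_TOKEN_ID, IMAGE_TOKEN_ID, IMAGE_TOKEN_ID, 2, 3]) == [IMAGE_TOKEN_ID, 2, 3]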
@@ -49,6 +49,8 @@ _GENERATION_MODELS = {
     "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
     "OPTForCausalLM": ("opt", "OPTForCausalLM"),
     "OrionForCausalLM": ("orion", "OrionForCausalLM"),
+    "PaliGemmaForConditionalGeneration":
+    ("paligemma", "PaliGemmaForConditionalGeneration"),
     "PhiForCausalLM": ("phi", "PhiForCausalLM"),
     "Phi3ForCausalLM": ("llama", "LlamaForCausalLM"),
     "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
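The registry key added here is the string vLLM matches against the `architectures` field of a checkpoint's HF config. A quick, illustrative way to confirm that string for a given checkpoint (not part of this commit; the printed value is an assumption):

from transformers import AutoConfig

# The registry key above must match the architecture name reported by the
# HF config of the checkpoint being served.
cfg = AutoConfig.from_pretrained("google/paligemma-3b-mix-224")
print(cfg.architectures)  # expected (assumption): ['PaliGemmaForConditionalGeneration']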
@@ -268,16 +268,22 @@ class GemmaModel(nn.Module):
         normalizer = self.config.hidden_size**0.5
         self.register_buffer("normalizer", torch.tensor(normalizer))

+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.embed_tokens(input_ids)
+
     def forward(
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        hidden_states = self.embed_tokens(input_ids)
+        if inputs_embeds is not None:
+            hidden_states = inputs_embeds
+        else:
+            hidden_states = self.get_input_embeddings(input_ids)
         hidden_states *= self.normalizer

         residual = None
         for i in range(len(self.layers)):
             layer = self.layers[i]
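The new `inputs_embeds` argument exists so that a multimodal wrapper (PaliGemma below) can build the token embeddings itself, splice image features into the positions occupied by the image token, and hand the result directly to the language model. Below is a minimal sketch of that splice, assuming a flattened (num_tokens, hidden_size) layout; `merge_vision_embeddings_sketch` is a hypothetical stand-in for the `merge_vision_embeddings` helper imported in paligemma.py below, not its actual implementation.

import torch

def merge_vision_embeddings_sketch(input_ids: torch.Tensor,
                                   inputs_embeds: torch.Tensor,
                                   vision_embeddings: torch.Tensor,
                                   image_token_id: int) -> torch.Tensor:
    # Overwrite the embeddings at image-token positions with vision features.
    mask = input_ids == image_token_id
    inputs_embeds[mask] = vision_embeddings.to(inputs_embeds.dtype)
    return inputs_embeds

# Toy shapes: 6 tokens, hidden size 4, the first 3 positions are image tokens.
ids = torch.tensor([99, 99, 99, 1, 2, 3])
embeds = torch.zeros(6, 4)
vision = torch.ones(3, 4)
merged = merge_vision_embeddings_sketch(ids, embeds, vision, image_token_id=99)
assert merged[:3].eq(1).all() and merged[3:].eq(0).all()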
vllm/model_executor/models/paligemma.py (new file, 344 lines)
@@ -0,0 +1,344 @@
from typing import Iterable, List, Literal, Optional, Tuple, TypedDict

import torch
from PIL import Image
from torch import nn
from transformers import PaliGemmaConfig, SiglipVisionConfig, SiglipVisionModel

from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import ColumnParallelLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.gemma import GemmaModel
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.image import cached_get_tokenizer
from vllm.sequence import SamplerOutput, SequenceData

from .interfaces import SupportsVision
from .utils import merge_vision_embeddings

logger = init_logger(__name__)

_KEYS_TO_MODIFY_MAPPING = {
    "language_model.model": "language_model",
}


def get_max_paligemma_image_tokens(ctx: InputContext):
    hf_config = ctx.get_hf_config(PaliGemmaConfig)
    text_config = hf_config.text_config

    return text_config.num_image_tokens


def dummy_seq_data_for_paligemma(
    hf_config: PaliGemmaConfig,
    seq_len: int,
    *,
    image_token_id: int,
    image_feature_size_override: Optional[int] = None,
):
    if image_feature_size_override is None:
        image_feature_size = hf_config.text_config.num_image_tokens
    else:
        image_feature_size = image_feature_size_override

    token_ids = [image_token_id] * image_feature_size
    token_ids += [0] * (seq_len - image_feature_size)
    return SequenceData(token_ids)


def dummy_image_for_paligemma(
    hf_config: SiglipVisionConfig,
    *,
    image_width_override: Optional[int] = None,
    image_height_override: Optional[int] = None,
):
    width = height = hf_config.image_size
    if image_width_override is not None:
        width = image_width_override
    if image_height_override is not None:
        height = image_height_override

    image = Image.new("RGB", (width, height), color=0)
    return {"image": image}


def dummy_data_for_paligemma(ctx: InputContext, seq_len: int):
    hf_config = ctx.get_hf_config(PaliGemmaConfig)
    vision_config = hf_config.vision_config

    seq_data = dummy_seq_data_for_paligemma(
        hf_config,
        seq_len,
        image_token_id=hf_config.image_token_index,
    )

    mm_data = dummy_image_for_paligemma(vision_config)
    return seq_data, mm_data


def input_processor_for_paligemma(ctx: InputContext, llm_inputs: LLMInputs):

    """
    The correct prompt format needs to be:
    '<image>' * image_feature_size + '<bos>' + prompt + '\n'

    See https://github.com/huggingface/transformers/blob/25245ec26dc29bcf6102e1b4ddd0dfd02e720cf5/src/transformers/models/paligemma/processing_paligemma.py#L55
    """  # noqa

    multi_modal_data = llm_inputs.get("multi_modal_data")
    if multi_modal_data is None or "image" not in multi_modal_data:
        return llm_inputs

    model_config = ctx.model_config
    hf_config = ctx.get_hf_config(PaliGemmaConfig)

    tokenizer = cached_get_tokenizer(model_config.tokenizer)
    image_feature_size = hf_config.text_config.num_image_tokens
    image_token_str = tokenizer.decode(hf_config.image_token_index)
    bos_token = tokenizer.decode(hf_config.bos_token_id)
    image_token_str_pad = image_token_str * image_feature_size
    image_token_ids_pad = [hf_config.image_token_index] * image_feature_size

    orig_prompt = llm_inputs.get("prompt")
    orig_prompt_ids = llm_inputs.get("prompt_token_ids")

    if image_token_str in orig_prompt:
        logger.warning(
            "The image token '%s' was detected in the prompt and "
            "will be removed. Please follow the proper prompt format"
            " documented on HuggingFace.", image_token_str)
        orig_prompt = orig_prompt.replace(image_token_str, "")
        orig_prompt_ids.remove(hf_config.image_token_index)

    new_prompt = f"{image_token_str_pad}{bos_token}{orig_prompt}\n"
    new_token_ids = image_token_ids_pad + orig_prompt_ids + [108]  # newline

    # NOTE: Create a defensive copy of the original inputs
    return LLMInputs(prompt_token_ids=new_token_ids,
                     prompt=new_prompt,
                     multi_modal_data=multi_modal_data)


class PaliGemmaMultiModalProjector(nn.Module):

    def __init__(self, vision_hidden_size: int, projection_dim: int):
        super().__init__()

        self.linear = ColumnParallelLinear(vision_hidden_size,
                                           projection_dim,
                                           bias=True)

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        hidden_states, _ = self.linear(image_features)
        return hidden_states


class PaliGemmaImagePixelInputs(TypedDict):
    type: Literal["pixel_values"]
    data: torch.Tensor
    """Shape: (batch_size, num_channels, height, width)"""


PaliGemmaImageInputs = PaliGemmaImagePixelInputs


@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_paligemma_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_paligemma)
@INPUT_REGISTRY.register_input_processor(input_processor_for_paligemma)
class PaliGemmaForConditionalGeneration(nn.Module, SupportsVision):

    def __init__(self,
                 config: PaliGemmaConfig,
                 multimodal_config: MultiModalConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None) -> None:
        super().__init__()

        self.config = config
        self.multimodal_config = multimodal_config

        # TODO(ywang96): Port over SiglipVisionModel & TP
        self.vision_tower = SiglipVisionModel(config.vision_config)
        self.multi_modal_projector = PaliGemmaMultiModalProjector(
            vision_hidden_size=config.vision_config.hidden_size,
            projection_dim=config.vision_config.projection_dim)

        self.quant_config = quant_config
        self.language_model = GemmaModel(config.text_config, cache_config,
                                         quant_config)
        self.unpadded_vocab_size = config.text_config.vocab_size
        logit_scale = getattr(config, "logit_scale", 1.0)
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size, logit_scale)
        self.sampler = Sampler()

    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
        h = w = self.config.vision_config.image_size
        expected_dims = (3, h, w)
        actual_dims = tuple(data.shape[1:])

        if actual_dims != expected_dims:
            expected_expr = ("batch_size", *map(str, expected_dims))
            raise ValueError(
                f"The expected shape of pixel values is {expected_expr}. "
                f"You supplied {tuple(data.shape)}.")

        return data

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[PaliGemmaImageInputs]:
        pixel_values = kwargs.pop("pixel_values", None)

        if pixel_values is None:
            return None

        if not isinstance(pixel_values, torch.Tensor):
            raise ValueError("Incorrect type of pixel values. "
                             f"Got type: {type(pixel_values)}")

        return PaliGemmaImagePixelInputs(
            type="pixel_values",
            data=self._validate_pixel_values(pixel_values),
        )

    def _image_pixels_to_features(self, vision_tower: SiglipVisionModel,
                                  pixel_values: torch.Tensor) -> torch.Tensor:

        image_outputs = vision_tower(pixel_values, output_hidden_states=True)

        selected_image_features = image_outputs.last_hidden_state

        return selected_image_features

    def _process_image_pixels(
            self, inputs: PaliGemmaImagePixelInputs) -> torch.Tensor:
        assert self.vision_tower is not None

        pixel_values = inputs["data"]

        return self._image_pixels_to_features(self.vision_tower, pixel_values)

    def _process_image_input(
            self, image_input: PaliGemmaImageInputs) -> torch.Tensor:

        assert self.vision_tower is not None
        image_features = self._process_image_pixels(image_input)

        return self.multi_modal_projector(image_features)

    def forward(self, input_ids: torch.Tensor, positions: torch.Tensor,
                kv_caches: List[torch.Tensor],
                attn_metadata: AttentionMetadata,
                **kwargs: object) -> SamplerOutput:

        parsed_image_input = self._parse_and_validate_image_input(**kwargs)

        if parsed_image_input is not None:
            vision_embeddings = self._process_image_input(parsed_image_input)
            # https://github.com/huggingface/transformers/blob/main/src/transformers/models/paligemma/modeling_paligemma.py#L294 # noqa
            vision_embeddings = vision_embeddings * (self.config.hidden_size**
                                                     -0.5)

            inputs_embeds = self.language_model.get_input_embeddings(input_ids)

            inputs_embeds = merge_vision_embeddings(
                input_ids, inputs_embeds, vision_embeddings,
                self.config.image_token_index)

            input_ids = None
        else:
            inputs_embeds = None

        hidden_states = self.language_model(input_ids,
                                            positions,
                                            kv_caches,
                                            attn_metadata,
                                            inputs_embeds=inputs_embeds)

        return hidden_states

    # Copied from vllm/model_executor/models/gemma.py
    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        logits = self.logits_processor(self.language_model.embed_tokens,
                                       hidden_states, sampling_metadata)
        return logits

    # Copied from vllm/model_executor/models/gemma.py
    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    # Adapted from vllm/model_executor/models/gemma.py
    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params = set()
        for name, loaded_weight in weights:
            for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
                if key_to_modify in name:
                    name = name.replace(key_to_modify, new_key)
            use_default_weight_loading = False
            if "vision" in name:
                if self.vision_tower is not None:
                    # We only do sharding for language model and
                    # not vision model for now.
                    use_default_weight_loading = True
            else:
                for (param_name, shard_name,
                     shard_id) in stacked_params_mapping:
                    if shard_name not in name:
                        continue
                    name = name.replace(shard_name, param_name)
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(param, loaded_weight, shard_id)
                    break
                else:
                    # lm_head is not used in vllm as it is tied with
                    # embed_token. To prevent errors, skip loading
                    # lm_head.weight.
                    if "lm_head.weight" in name:
                        continue
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue
                    use_default_weight_loading = True

            if use_default_weight_loading:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)

            loaded_params.add(name)

        unloaded_params = params_dict.keys() - loaded_params
        if unloaded_params:
            raise RuntimeError(
                "Some weights are not initialized from checkpoints: "
                f"{unloaded_params}")
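To make the contract of `input_processor_for_paligemma` concrete: the processed prompt is `'<image>' * num_image_tokens + '<bos>' + prompt + '\n'`, and the token ids are the image token id repeated the same number of times, followed by the original prompt ids and the newline id (108). A minimal sketch with placeholder values; `num_image_tokens` and the prompt ids below are assumptions for illustration, the real values come from the HF config and tokenizer.

# Illustrative only: the real values come from the HF PaliGemmaConfig and tokenizer.
num_image_tokens = 256          # assumed stand-in for hf_config.text_config.num_image_tokens
image_token_index = 257152      # matches IMAGE_TOKEN_ID used in the test above
newline_token_id = 108          # hard-coded in the processor above

prompt = "caption es"
prompt_ids = [3, 4, 5]          # stand-in for the tokenized prompt

new_prompt = "<image>" * num_image_tokens + "<bos>" + prompt + "\n"
new_token_ids = ([image_token_index] * num_image_tokens + prompt_ids +
                 [newline_token_id])

assert len(new_token_ids) == num_image_tokens + len(prompt_ids) + 1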