[Model] Add PaliGemma (#5189)

Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Roger Wang 2024-07-06 18:25:50 -07:00 committed by GitHub
parent 9389380015
commit 6206dcb29e
6 changed files with 557 additions and 2 deletions

View File

@@ -186,6 +186,10 @@ Vision Language Models
- LLaVA-NeXT
- :code:`llava-hf/llava-v1.6-mistral-7b-hf`, :code:`llava-hf/llava-v1.6-vicuna-7b-hf`, etc.
-
* - :code:`PaliGemmaForConditionalGeneration`
- PaliGemma
- :code:`google/paligemma-3b-pt-224`, :code:`google/paligemma-3b-mix-224`, etc.
-
* - :code:`Phi3VForCausalLM`
- Phi-3-Vision
- :code:`microsoft/Phi-3-vision-128k-instruct`, etc.

View File

@@ -0,0 +1,52 @@
import os
import subprocess
from PIL import Image
from vllm import LLM
# The assets are located at `s3://air-example-data-2/vllm_opensource_llava/`.
# You can use `.buildkite/download-images.sh` to download them.
def run_paligemma():
llm = LLM(model="google/paligemma-3b-mix-224")
prompt = "caption es"
image = Image.open("images/stop_sign.jpg")
outputs = llm.generate({
"prompt": prompt,
"multi_modal_data": {
"image": image
},
})
for o in outputs:
generated_text = o.outputs[0].text
print(generated_text)
def main():
run_paligemma()
if __name__ == "__main__":
# Download from s3
s3_bucket_path = "s3://air-example-data-2/vllm_opensource_llava/"
local_directory = "images"
# Make sure the local directory exists or create it
os.makedirs(local_directory, exist_ok=True)
# Use the AWS CLI to sync the directory; assumes anonymous access
subprocess.check_call([
"aws",
"s3",
"sync",
s3_bucket_path,
local_directory,
"--no-sign-request",
])
main()
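
# For reference (added commentary, not part of the original example): the same
# request with sampling pinned explicitly via SamplingParams, so output length
# and determinism do not depend on defaults. The values below are illustrative.
#
#   from vllm import LLM, SamplingParams
#
#   llm = LLM(model="google/paligemma-3b-mix-224")
#   outputs = llm.generate(
#       {
#           "prompt": "caption es",
#           "multi_modal_data": {"image": Image.open("images/stop_sign.jpg")},
#       },
#       SamplingParams(temperature=0.0, max_tokens=64),
#   )
#   print(outputs[0].outputs[0].text)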

View File

@@ -0,0 +1,147 @@
from typing import List, Optional, Tuple, Type
import pytest
from transformers import AutoTokenizer
from vllm.multimodal.utils import rescale_image_size
from vllm.sequence import SampleLogprobs
from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets
from .utils import check_logprobs_close
pytestmark = pytest.mark.vlm
HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
"stop_sign": "caption es",
"cherry_blossom": "What is in the picture?",
"boardwalk": "What is in the picture?",
})
IMAGE_TOKEN_ID = 257152
models = ["google/paligemma-3b-mix-224"]
def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
Optional[SampleLogprobs]],
model: str):
"""Sanitize vllm output to be comparable with hf output."""
output_ids, output_str, out_logprobs = vllm_output
tokenizer = AutoTokenizer.from_pretrained(model)
eos_token_id = tokenizer.eos_token_id
hf_output_ids = [
token_id for idx, token_id in enumerate(output_ids)
if token_id != IMAGE_TOKEN_ID or output_ids[idx - 1] != IMAGE_TOKEN_ID
]
hf_output_str = output_str
if hf_output_ids[-1] == eos_token_id:
hf_output_str = hf_output_str + tokenizer.decode(eos_token_id)
return hf_output_ids, hf_output_str, out_logprobs
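# Toy illustration (added for this write-up, not part of the original test):
# the filter above collapses a run of consecutive image placeholder tokens
# into a single one so the ids line up with the HF reference output. The
# non-image token ids below are made up.
_example_ids = [IMAGE_TOKEN_ID, IMAGE_TOKEN_ID, IMAGE_TOKEN_ID, 2, 100, 101]
_example_kept = [
tok for idx, tok in enumerate(_example_ids)
if tok != IMAGE_TOKEN_ID or _example_ids[idx - 1] != IMAGE_TOKEN_ID
]
assert _example_kept == [IMAGE_TOKEN_ID, 2, 100, 101]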
def run_test(
hf_runner: Type[HfRunner],
vllm_runner: Type[VllmRunner],
image_assets: _ImageAssets,
model: str,
*,
size_factors: List[float],
dtype: str,
max_tokens: int,
num_logprobs: int,
tensor_parallel_size: int,
distributed_executor_backend: Optional[str] = None,
):
"""Inference result should be the same between hf and vllm.
All the image fixtures for the test is under tests/images.
For huggingface runner, we provide the PIL images as input.
For vllm runner, we provide MultiModalDataDict objects
and corresponding vision language config as input.
Note, the text input is also adjusted to abide by vllm contract.
The text output is sanitized to be able to compare with hf.
"""
images = [asset.pil_image for asset in image_assets]
inputs_per_image = [(
[prompt for _ in size_factors],
[rescale_image_size(image, factor) for factor in size_factors],
) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
# NOTE: take care of the order: run vLLM first, and then run HF.
# vLLM needs a fresh new process without CUDA initialization;
# if we run HF first, CUDA will already be initialized, which breaks
# the multiprocessing backend with the fork start method (the default).
# max_model_len should be greater than image_feature_size.
with vllm_runner(model,
dtype=dtype,
tensor_parallel_size=tensor_parallel_size,
distributed_executor_backend=distributed_executor_backend,
enforce_eager=True) as vllm_model:
vllm_outputs_per_image = [
vllm_model.generate_greedy_logprobs(prompts,
max_tokens,
num_logprobs=num_logprobs,
images=images)
for prompts, images in inputs_per_image
]
with hf_runner(model, dtype=dtype, is_vision_model=True) as hf_model:
hf_outputs_per_image = [
hf_model.generate_greedy_logprobs_limit(prompts,
max_tokens,
num_logprobs=num_logprobs,
images=images)
for prompts, images in inputs_per_image
]
for hf_outputs, vllm_outputs in zip(hf_outputs_per_image,
vllm_outputs_per_image):
check_logprobs_close(
outputs_0_lst=hf_outputs,
outputs_1_lst=[
vllm_to_hf_output(vllm_output, model)
for vllm_output in vllm_outputs
],
name_0="hf",
name_1="vllm",
)
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize(
"size_factors",
[
# No image
[],
# Single-scale
[1.0],
# Single-scale, batched
[1.0, 1.0, 1.0],
# Multi-scale
[0.25, 0.5, 1.0],
],
)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
dtype: str, max_tokens: int, num_logprobs: int) -> None:
run_test(
hf_runner,
vllm_runner,
image_assets,
model,
size_factors=size_factors,
dtype=dtype,
max_tokens=max_tokens,
num_logprobs=num_logprobs,
tensor_parallel_size=1,
)

View File

@@ -49,6 +49,8 @@ _GENERATION_MODELS = {
"OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
"OPTForCausalLM": ("opt", "OPTForCausalLM"),
"OrionForCausalLM": ("orion", "OrionForCausalLM"),
"PaliGemmaForConditionalGeneration":
("paligemma", "PaliGemmaForConditionalGeneration"),
"PhiForCausalLM": ("phi", "PhiForCausalLM"),
"Phi3ForCausalLM": ("llama", "LlamaForCausalLM"),
"Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),

View File

@@ -268,16 +268,22 @@ class GemmaModel(nn.Module):
normalizer = self.config.hidden_size**0.5
self.register_buffer("normalizer", torch.tensor(normalizer))
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
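# `inputs_embeds` is supplied by multimodal wrappers such as PaliGemma,
# which merge vision embeddings into the token embeddings before calling
# the language model; when it is absent, `input_ids` are embedded here.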
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
hidden_states *= self.normalizer
residual = None
for i in range(len(self.layers)):
layer = self.layers[i]

View File

@@ -0,0 +1,344 @@
from typing import Iterable, List, Literal, Optional, Tuple, TypedDict
import torch
from PIL import Image
from torch import nn
from transformers import PaliGemmaConfig, SiglipVisionConfig, SiglipVisionModel
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import ColumnParallelLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.gemma import GemmaModel
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.image import cached_get_tokenizer
from vllm.sequence import SamplerOutput, SequenceData
from .interfaces import SupportsVision
from .utils import merge_vision_embeddings
logger = init_logger(__name__)
_KEYS_TO_MODIFY_MAPPING = {
"language_model.model": "language_model",
}
def get_max_paligemma_image_tokens(ctx: InputContext):
hf_config = ctx.get_hf_config(PaliGemmaConfig)
text_config = hf_config.text_config
return text_config.num_image_tokens
def dummy_seq_data_for_paligemma(
hf_config: PaliGemmaConfig,
seq_len: int,
*,
image_token_id: int,
image_feature_size_override: Optional[int] = None,
):
if image_feature_size_override is None:
image_feature_size = hf_config.text_config.num_image_tokens
else:
image_feature_size = image_feature_size_override
token_ids = [image_token_id] * image_feature_size
token_ids += [0] * (seq_len - image_feature_size)
return SequenceData(token_ids)
def dummy_image_for_paligemma(
hf_config: SiglipVisionConfig,
*,
image_width_override: Optional[int] = None,
image_height_override: Optional[int] = None,
):
width = height = hf_config.image_size
if image_width_override is not None:
width = image_width_override
if image_height_override is not None:
height = image_height_override
image = Image.new("RGB", (width, height), color=0)
return {"image": image}
def dummy_data_for_paligemma(ctx: InputContext, seq_len: int):
hf_config = ctx.get_hf_config(PaliGemmaConfig)
vision_config = hf_config.vision_config
seq_data = dummy_seq_data_for_paligemma(
hf_config,
seq_len,
image_token_id=hf_config.image_token_index,
)
mm_data = dummy_image_for_paligemma(vision_config)
return seq_data, mm_data
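# Note (added commentary): the dummy inputs above are what vLLM uses when
# profiling memory for this model: a sequence of `num_image_tokens` image
# placeholders padded with zeros up to `seq_len`, plus a blank RGB image at
# the vision tower's input resolution.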
def input_processor_for_paligemma(ctx: InputContext, llm_inputs: LLMInputs):
"""
The correct prompt format needs to be:
'<image>' * image_feature_size + '<bos>' + prompt + '\n'
See https://github.com/huggingface/transformers/blob/25245ec26dc29bcf6102e1b4ddd0dfd02e720cf5/src/transformers/models/paligemma/processing_paligemma.py#L55
""" # noqa
multi_modal_data = llm_inputs.get("multi_modal_data")
if multi_modal_data is None or "image" not in multi_modal_data:
return llm_inputs
model_config = ctx.model_config
hf_config = ctx.get_hf_config(PaliGemmaConfig)
tokenizer = cached_get_tokenizer(model_config.tokenizer)
image_feature_size = hf_config.text_config.num_image_tokens
image_token_str = tokenizer.decode(hf_config.image_token_index)
bos_token = tokenizer.decode(hf_config.bos_token_id)
image_token_str_pad = image_token_str * image_feature_size
image_token_ids_pad = [hf_config.image_token_index] * image_feature_size
orig_prompt = llm_inputs.get("prompt")
orig_prompt_ids = llm_inputs.get("prompt_token_ids")
if image_token_str in orig_prompt:
logger.warning(
"The image token '%s' was detected in the prompt and "
"will be removed. Please follow the proper prompt format"
" documented on HuggingFace.", image_token_str)
orig_prompt = orig_prompt.replace(image_token_str, "")
orig_prompt_ids.remove(hf_config.image_token_index)
new_prompt = f"{image_token_str_pad}{bos_token}{orig_prompt}\n"
new_token_ids = image_token_ids_pad + orig_prompt_ids + [108]  # newline token id
# NOTE: Create a defensive copy of the original inputs
return LLMInputs(prompt_token_ids=new_token_ids,
prompt=new_prompt,
multi_modal_data=multi_modal_data)
class PaliGemmaMultiModalProjector(nn.Module):
def __init__(self, vision_hidden_size: int, projection_dim: int):
super().__init__()
self.linear = ColumnParallelLinear(vision_hidden_size,
projection_dim,
bias=True)
def forward(self, image_features: torch.Tensor) -> torch.Tensor:
hidden_states, _ = self.linear(image_features)
return hidden_states
class PaliGemmaImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
data: torch.Tensor
"""Shape: (batch_size, num_channels, height, width)"""
PaliGemmaImageInputs = PaliGemmaImagePixelInputs
@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_paligemma_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_paligemma)
@INPUT_REGISTRY.register_input_processor(input_processor_for_paligemma)
class PaliGemmaForConditionalGeneration(nn.Module, SupportsVision):
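# Overall structure mirrors HF's PaliGemmaForConditionalGeneration: a SigLIP
# vision tower, a single linear projector into the language model's embedding
# space, and a Gemma decoder as the language model.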
def __init__(self,
config: PaliGemmaConfig,
multimodal_config: MultiModalConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None) -> None:
super().__init__()
self.config = config
self.multimodal_config = multimodal_config
# TODO(ywang96): Port over SiglipVisionModel & TP
self.vision_tower = SiglipVisionModel(config.vision_config)
self.multi_modal_projector = PaliGemmaMultiModalProjector(
vision_hidden_size=config.vision_config.hidden_size,
projection_dim=config.vision_config.projection_dim)
self.quant_config = quant_config
self.language_model = GemmaModel(config.text_config, cache_config,
quant_config)
self.unpadded_vocab_size = config.text_config.vocab_size
logit_scale = getattr(config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size, logit_scale)
self.sampler = Sampler()
def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
h = w = self.config.vision_config.image_size
expected_dims = (3, h, w)
actual_dims = tuple(data.shape[1:])
if actual_dims != expected_dims:
expected_expr = ("batch_size", *map(str, expected_dims))
raise ValueError(
f"The expected shape of pixel values is {expected_expr}. "
f"You supplied {tuple(data.shape)}.")
return data
def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[PaliGemmaImageInputs]:
pixel_values = kwargs.pop("pixel_values", None)
if pixel_values is None:
return None
if not isinstance(pixel_values, torch.Tensor):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
return PaliGemmaImagePixelInputs(
type="pixel_values",
data=self._validate_pixel_values(pixel_values),
)
def _image_pixels_to_features(self, vision_tower: SiglipVisionModel,
pixel_values: torch.Tensor) -> torch.Tensor:
image_outputs = vision_tower(pixel_values, output_hidden_states=True)
selected_image_features = image_outputs.last_hidden_state
return selected_image_features
def _process_image_pixels(
self, inputs: PaliGemmaImagePixelInputs) -> torch.Tensor:
assert self.vision_tower is not None
pixel_values = inputs["data"]
return self._image_pixels_to_features(self.vision_tower, pixel_values)
def _process_image_input(
self, image_input: PaliGemmaImageInputs) -> torch.Tensor:
assert self.vision_tower is not None
image_features = self._process_image_pixels(image_input)
return self.multi_modal_projector(image_features)
def forward(self, input_ids: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
**kwargs: object) -> SamplerOutput:
parsed_image_input = self._parse_and_validate_image_input(**kwargs)
if parsed_image_input is not None:
vision_embeddings = self._process_image_input(parsed_image_input)
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/paligemma/modeling_paligemma.py#L294 # noqa
vision_embeddings = vision_embeddings * (self.config.hidden_size**
-0.5)
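# GemmaModel multiplies its input embeddings by `normalizer`
# (hidden_size**0.5), so pre-scaling the projected vision features by
# hidden_size**-0.5 leaves them at their original magnitude after the
# merge (see the HF reference linked above).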
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
inputs_embeds = merge_vision_embeddings(
input_ids, inputs_embeds, vision_embeddings,
self.config.image_token_index)
input_ids = None
else:
inputs_embeds = None
hidden_states = self.language_model(input_ids,
positions,
kv_caches,
attn_metadata,
inputs_embeds=inputs_embeds)
return hidden_states
# Copied from vllm/model_executor/models/gemma.py
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.language_model.embed_tokens,
hidden_states, sampling_metadata)
return logits
# Copied from vllm/model_executor/models/gemma.py
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
# Adapted from vllm/model_executor/models/gemma.py
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
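# The mapping above folds the checkpoint's separate q/k/v and gate/up
# projection weights into vLLM's fused qkv_proj and gate_up_proj parameters;
# the shard id tells the fused weight loader which slice each tensor fills.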
params_dict = dict(self.named_parameters())
loaded_params = set()
for name, loaded_weight in weights:
for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
if key_to_modify in name:
name = name.replace(key_to_modify, new_key)
use_default_weight_loading = False
if "vision" in name:
if self.vision_tower is not None:
# We only shard the language model,
# not the vision model, for now.
use_default_weight_loading = True
else:
for (param_name, shard_name,
shard_id) in stacked_params_mapping:
if shard_name not in name:
continue
name = name.replace(shard_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# lm_head is not used in vLLM as it is tied with
# embed_tokens. To prevent errors, skip loading
# lm_head.weight.
if "lm_head.weight" in name:
continue
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
use_default_weight_loading = True
if use_default_weight_loading:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
unloaded_params = params_dict.keys() - loaded_params
if unloaded_params:
raise RuntimeError(
"Some weights are not initialized from checkpoints: "
f"{unloaded_params}")