[VLM] Cleanup siglip legacy code and fix broken paligemma multimodal processor (#14602)
Signed-off-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
parent
70b808fe1a
commit
1477ffc381
@ -24,9 +24,10 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .interfaces import SupportsMultiModal, SupportsPP
|
||||
from .siglip import SiglipVisionModel, get_max_siglip_image_tokens
|
||||
from .siglip import SiglipVisionModel
|
||||
from .utils import (AutoWeightsLoader, init_vllm_registered_model,
|
||||
maybe_prefix, merge_multimodal_embeddings)
|
||||
from .vision import get_vision_encoder_info
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -67,6 +68,9 @@ class PaliGemmaProcessingInfo(BaseProcessingInfo):
|
||||
def get_hf_config(self):
|
||||
return self.ctx.get_hf_config(PaliGemmaConfig)
|
||||
|
||||
def get_vision_encoder_info(self):
|
||||
return get_vision_encoder_info(self.get_hf_config())
|
||||
|
||||
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
|
||||
return {"image": 1}
|
||||
|
||||
@ -78,9 +82,8 @@ class PaliGemmaProcessingInfo(BaseProcessingInfo):
|
||||
return {"image": self.get_num_image_tokens()}
|
||||
|
||||
def get_num_image_tokens(self) -> int:
|
||||
hf_config = self.get_hf_config()
|
||||
vision_config = hf_config.vision_config
|
||||
return get_max_siglip_image_tokens(vision_config)
|
||||
vision_encoder_info = self.get_vision_encoder_info()
|
||||
return vision_encoder_info.get_max_image_tokens()
|
||||
|
||||
|
||||
class PaliGemmaDummyInputsBuilder(
|
||||
@ -173,8 +176,10 @@ class PaliGemmaMultiModalProcessor(
|
||||
prompt: Union[str, list[int]],
|
||||
mm_data: MultiModalDataDict,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
return_mm_hashes: bool = False,
|
||||
) -> MultiModalInputs:
|
||||
mm_inputs = super().apply(prompt, mm_data, hf_processor_mm_kwargs)
|
||||
mm_inputs = super().apply(prompt, mm_data, hf_processor_mm_kwargs,
|
||||
return_mm_hashes)
|
||||
prompt_token_ids = mm_inputs["prompt_token_ids"]
|
||||
|
||||
tokenizer = self.info.get_tokenizer()
|
||||
|
@ -6,7 +6,6 @@ import math
|
||||
from typing import Iterable, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
from torch import nn
|
||||
from transformers import SiglipVisionConfig
|
||||
|
||||
@ -20,74 +19,10 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.multimodal.utils import consecutive_placeholder_ranges
|
||||
from vllm.sequence import SequenceData
|
||||
|
||||
from .vision import VisionEncoderInfo, resolve_visual_encoder_outputs
|
||||
|
||||
|
||||
def get_siglip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
|
||||
# Since interpolation is applied, the image size need not be divisible
|
||||
# assert image_size % patch_size == 0
|
||||
return image_size // patch_size
|
||||
|
||||
|
||||
def get_siglip_num_patches(*, image_size: int, patch_size: int) -> int:
|
||||
grid_length = get_siglip_patch_grid_length(image_size=image_size,
|
||||
patch_size=patch_size)
|
||||
return grid_length * grid_length
|
||||
|
||||
|
||||
def get_siglip_image_feature_size(hf_config: SiglipVisionConfig) -> int:
|
||||
return get_siglip_num_patches(image_size=hf_config.image_size,
|
||||
patch_size=hf_config.patch_size)
|
||||
|
||||
|
||||
def get_max_siglip_image_tokens(hf_config: SiglipVisionConfig) -> int:
|
||||
return get_siglip_image_feature_size(hf_config)
|
||||
|
||||
|
||||
def dummy_seq_data_for_siglip(
|
||||
hf_config: SiglipVisionConfig,
|
||||
seq_len: int,
|
||||
num_images: int,
|
||||
*,
|
||||
image_token_id: int,
|
||||
image_feature_size_override: Optional[int] = None,
|
||||
mm_key: str = "image",
|
||||
):
|
||||
if image_feature_size_override is None:
|
||||
image_feature_size = get_siglip_image_feature_size(hf_config)
|
||||
else:
|
||||
image_feature_size = image_feature_size_override
|
||||
|
||||
return SequenceData.from_prompt_token_counts(
|
||||
(image_token_id, image_feature_size * num_images),
|
||||
(0, seq_len - image_feature_size * num_images),
|
||||
), {
|
||||
mm_key:
|
||||
consecutive_placeholder_ranges(num_items=num_images,
|
||||
item_size=image_feature_size)
|
||||
}
|
||||
|
||||
|
||||
def dummy_image_for_siglip(
|
||||
hf_config: SiglipVisionConfig,
|
||||
num_images: int,
|
||||
*,
|
||||
image_width_override: Optional[int] = None,
|
||||
image_height_override: Optional[int] = None,
|
||||
):
|
||||
width = height = hf_config.image_size
|
||||
if image_width_override is not None:
|
||||
width = image_width_override
|
||||
if image_height_override is not None:
|
||||
height = image_height_override
|
||||
|
||||
image = Image.new("RGB", (width, height), color=0)
|
||||
return {"image": image if num_images == 1 else [image] * num_images}
|
||||
|
||||
|
||||
class SiglipEncoderInfo(VisionEncoderInfo[SiglipVisionConfig]):
|
||||
|
||||
def get_num_image_tokens(
|
||||
@ -96,10 +31,10 @@ class SiglipEncoderInfo(VisionEncoderInfo[SiglipVisionConfig]):
|
||||
image_width: int,
|
||||
image_height: int,
|
||||
) -> int:
|
||||
return get_siglip_image_feature_size(self.vision_config)
|
||||
return self.get_patch_grid_length()**2
|
||||
|
||||
def get_max_image_tokens(self) -> int:
|
||||
return get_max_siglip_image_tokens(self.vision_config)
|
||||
return self.get_patch_grid_length()**2
|
||||
|
||||
def get_image_size(self) -> int:
|
||||
return self.vision_config.image_size
|
||||
@ -108,10 +43,8 @@ class SiglipEncoderInfo(VisionEncoderInfo[SiglipVisionConfig]):
|
||||
return self.vision_config.patch_size
|
||||
|
||||
def get_patch_grid_length(self) -> int:
|
||||
return get_siglip_patch_grid_length(
|
||||
image_size=self.vision_config.image_size,
|
||||
patch_size=self.vision_config.patch_size,
|
||||
)
|
||||
image_size, patch_size = self.get_image_size(), self.get_patch_size()
|
||||
return image_size // patch_size
|
||||
|
||||
|
||||
# Adapted from https://github.com/huggingface/transformers/blob/v4.43.3/src/transformers/models/siglip/modeling_siglip.py#L249 # noqa
|
||||
|
Loading…
x
Reference in New Issue
Block a user