[VLM] Cleanup siglip legacy code and fix broken paligemma multimodal processor (#14602)

Signed-off-by: Isotr0py <2037008807@qq.com>
Isotr0py 2025-03-11 19:27:36 +08:00 committed by GitHub
parent 70b808fe1a
commit 1477ffc381
2 changed files with 14 additions and 76 deletions

vllm/model_executor/models/paligemma.py

@@ -24,9 +24,10 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
 from vllm.sequence import IntermediateTensors
 
 from .interfaces import SupportsMultiModal, SupportsPP
-from .siglip import SiglipVisionModel, get_max_siglip_image_tokens
+from .siglip import SiglipVisionModel
 from .utils import (AutoWeightsLoader, init_vllm_registered_model,
                     maybe_prefix, merge_multimodal_embeddings)
+from .vision import get_vision_encoder_info
 
 logger = init_logger(__name__)
@@ -67,6 +68,9 @@ class PaliGemmaProcessingInfo(BaseProcessingInfo):
     def get_hf_config(self):
         return self.ctx.get_hf_config(PaliGemmaConfig)
 
+    def get_vision_encoder_info(self):
+        return get_vision_encoder_info(self.get_hf_config())
+
     def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
         return {"image": 1}
@@ -78,9 +82,8 @@ class PaliGemmaProcessingInfo(BaseProcessingInfo):
         return {"image": self.get_num_image_tokens()}
 
     def get_num_image_tokens(self) -> int:
-        hf_config = self.get_hf_config()
-        vision_config = hf_config.vision_config
-        return get_max_siglip_image_tokens(vision_config)
+        vision_encoder_info = self.get_vision_encoder_info()
+        return vision_encoder_info.get_max_image_tokens()
 
 
 class PaliGemmaDummyInputsBuilder(
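
For context on this hunk: get_num_image_tokens() no longer reaches into SigLIP-specific helpers; it asks a generic vision-encoder-info object resolved from the HF config. A minimal, hypothetical sketch of that dispatch pattern (simplified names and a plain-dict config; not the actual vLLM factory):

from typing import Protocol


class VisionEncoderInfoLike(Protocol):
    def get_max_image_tokens(self) -> int: ...


class SiglipInfoSketch:
    """Encoder-specific token accounting, resolved from the config."""

    def __init__(self, image_size: int, patch_size: int) -> None:
        self.image_size = image_size
        self.patch_size = patch_size

    def get_max_image_tokens(self) -> int:
        # One token per patch of a square patch grid.
        return (self.image_size // self.patch_size) ** 2


def get_vision_encoder_info_sketch(vision_config: dict) -> VisionEncoderInfoLike:
    # Dispatch on the vision tower type so model code stays encoder-agnostic.
    if vision_config["model_type"] == "siglip_vision_model":
        return SiglipInfoSketch(vision_config["image_size"],
                                vision_config["patch_size"])
    raise NotImplementedError(vision_config["model_type"])


info = get_vision_encoder_info_sketch({
    "model_type": "siglip_vision_model",
    "image_size": 224,
    "patch_size": 14,
})
assert info.get_max_image_tokens() == 256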
@@ -173,8 +176,10 @@ class PaliGemmaMultiModalProcessor(
         prompt: Union[str, list[int]],
         mm_data: MultiModalDataDict,
         hf_processor_mm_kwargs: Mapping[str, object],
+        return_mm_hashes: bool = False,
     ) -> MultiModalInputs:
-        mm_inputs = super().apply(prompt, mm_data, hf_processor_mm_kwargs)
+        mm_inputs = super().apply(prompt, mm_data, hf_processor_mm_kwargs,
+                                  return_mm_hashes)
 
         prompt_token_ids = mm_inputs["prompt_token_ids"]
         tokenizer = self.info.get_tokenizer()
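
This hunk is the "broken paligemma multimodal processor" fix named in the commit title: the base processor's apply() grew a return_mm_hashes parameter, and an override that neither accepts nor forwards it raises a TypeError for any caller passing the new argument. A minimal sketch of the pattern, with hypothetical stand-in classes rather than the real vLLM types:

from typing import Union


class BaseProcessorSketch:

    def apply(
        self,
        prompt: Union[str, list[int]],
        mm_data: dict,
        hf_processor_mm_kwargs: dict,
        return_mm_hashes: bool = False,
    ) -> dict:
        return {"prompt_token_ids": [0], "has_hashes": return_mm_hashes}


class PaliGemmaLikeProcessor(BaseProcessorSketch):

    def apply(
        self,
        prompt: Union[str, list[int]],
        mm_data: dict,
        hf_processor_mm_kwargs: dict,
        return_mm_hashes: bool = False,  # must mirror the base signature
    ) -> dict:
        # Forward the new argument instead of dropping it; omitting the
        # parameter entirely is what broke callers of the old override.
        return super().apply(prompt, mm_data, hf_processor_mm_kwargs,
                             return_mm_hashes)


out = PaliGemmaLikeProcessor().apply("<image> caption", {}, {},
                                     return_mm_hashes=True)
assert out["has_hashes"] is True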

vllm/model_executor/models/siglip.py

@@ -6,7 +6,6 @@ import math
 from typing import Iterable, Optional, Set, Tuple, Union
 
 import torch
-from PIL import Image
 from torch import nn
 from transformers import SiglipVisionConfig
@@ -20,74 +19,10 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.multimodal.utils import consecutive_placeholder_ranges
-from vllm.sequence import SequenceData
 
 from .vision import VisionEncoderInfo, resolve_visual_encoder_outputs
 
-
-def get_siglip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
-    # Since interpolation is applied, the image size need not be divisible
-    # assert image_size % patch_size == 0
-    return image_size // patch_size
-
-
-def get_siglip_num_patches(*, image_size: int, patch_size: int) -> int:
-    grid_length = get_siglip_patch_grid_length(image_size=image_size,
-                                               patch_size=patch_size)
-    return grid_length * grid_length
-
-
-def get_siglip_image_feature_size(hf_config: SiglipVisionConfig) -> int:
-    return get_siglip_num_patches(image_size=hf_config.image_size,
-                                  patch_size=hf_config.patch_size)
-
-
-def get_max_siglip_image_tokens(hf_config: SiglipVisionConfig) -> int:
-    return get_siglip_image_feature_size(hf_config)
-
-
-def dummy_seq_data_for_siglip(
-    hf_config: SiglipVisionConfig,
-    seq_len: int,
-    num_images: int,
-    *,
-    image_token_id: int,
-    image_feature_size_override: Optional[int] = None,
-    mm_key: str = "image",
-):
-    if image_feature_size_override is None:
-        image_feature_size = get_siglip_image_feature_size(hf_config)
-    else:
-        image_feature_size = image_feature_size_override
-
-    return SequenceData.from_prompt_token_counts(
-        (image_token_id, image_feature_size * num_images),
-        (0, seq_len - image_feature_size * num_images),
-    ), {
-        mm_key:
-        consecutive_placeholder_ranges(num_items=num_images,
-                                       item_size=image_feature_size)
-    }
-
-
-def dummy_image_for_siglip(
-    hf_config: SiglipVisionConfig,
-    num_images: int,
-    *,
-    image_width_override: Optional[int] = None,
-    image_height_override: Optional[int] = None,
-):
-    width = height = hf_config.image_size
-    if image_width_override is not None:
-        width = image_width_override
-    if image_height_override is not None:
-        height = image_height_override
-
-    image = Image.new("RGB", (width, height), color=0)
-    return {"image": image if num_images == 1 else [image] * num_images}
-
-
 class SiglipEncoderInfo(VisionEncoderInfo[SiglipVisionConfig]):
 
     def get_num_image_tokens(
@@ -96,10 +31,10 @@ class SiglipEncoderInfo(VisionEncoderInfo[SiglipVisionConfig]):
         image_width: int,
         image_height: int,
     ) -> int:
-        return get_siglip_image_feature_size(self.vision_config)
+        return self.get_patch_grid_length()**2
 
     def get_max_image_tokens(self) -> int:
-        return get_max_siglip_image_tokens(self.vision_config)
+        return self.get_patch_grid_length()**2
 
     def get_image_size(self) -> int:
         return self.vision_config.image_size
@@ -108,10 +43,8 @@ class SiglipEncoderInfo(VisionEncoderInfo[SiglipVisionConfig]):
         return self.vision_config.patch_size
 
     def get_patch_grid_length(self) -> int:
-        return get_siglip_patch_grid_length(
-            image_size=self.vision_config.image_size,
-            patch_size=self.vision_config.patch_size,
-        )
+        image_size, patch_size = self.get_image_size(), self.get_patch_size()
+        return image_size // patch_size
 
 
 # Adapted from https://github.com/huggingface/transformers/blob/v4.43.3/src/transformers/models/siglip/modeling_siglip.py#L249  # noqa
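
Sanity-checking the refactor above: both get_num_image_tokens() and get_max_image_tokens() now reduce to the squared patch-grid length computed inline. A standalone sketch of that arithmetic with illustrative SigLIP-style values (image_size=224, patch_size=14, the common PaliGemma configuration):

def patch_grid_length(image_size: int, patch_size: int) -> int:
    # Position embeddings are interpolated, so image_size need not divide
    # evenly by patch_size; floor division mirrors the removed helper.
    return image_size // patch_size


def num_image_tokens(image_size: int, patch_size: int) -> int:
    # A square grid yields one token per patch: grid_length ** 2.
    return patch_grid_length(image_size, patch_size) ** 2


assert patch_grid_length(224, 14) == 16
assert num_image_tokens(224, 14) == 256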