63 lines
1.7 KiB
Python
63 lines
1.7 KiB
Python
from abc import ABC, abstractmethod
|
|
from typing import Final, Generic, Protocol, TypeVar
|
|
|
|
from transformers import PretrainedConfig
|
|
|
|
_C = TypeVar("_C", bound=PretrainedConfig)
|
|
|
|
|
|
class VisionEncoderInfo(ABC, Generic[_C]):
    """Abstract interface describing a vision encoder's sizing information.

    Each concrete subclass is constructed with an encoder-specific vision
    config (a ``PretrainedConfig`` subtype, per the ``_C`` type variable) and
    must implement every ``get_*`` query below. Calling an unimplemented
    query on the base class raises ``NotImplementedError``.
    """

    def __init__(self, vision_config: _C) -> None:
        super().__init__()
        # Stored so subclass implementations can read encoder settings.
        self.vision_config = vision_config

    @abstractmethod
    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        """Return the number of image tokens for the given image dimensions.

        Both dimensions are keyword-only. Their unit is presumably pixels;
        confirm this against the concrete subclasses.
        """
        raise NotImplementedError

    @abstractmethod
    def get_max_image_tokens(self) -> int:
        """Return the maximum number of image tokens for this encoder."""
        raise NotImplementedError

    @abstractmethod
    def get_image_size(self) -> int:
        """Return the encoder's image size as a single integer."""
        raise NotImplementedError

    @abstractmethod
    def get_patch_size(self) -> int:
        """Return the encoder's patch size as a single integer."""
        raise NotImplementedError

    @abstractmethod
    def get_patch_grid_length(self) -> int:
        """Return the patch grid length.

        This is presumably the number of patches along one side of the
        grid; confirm this against the concrete subclasses.
        """
        raise NotImplementedError
|
|
|
|
|
|
class VisionLanguageConfig(Protocol):
    """Structural type for any config exposing a nested ``vision_config``.

    Because this is a ``typing.Protocol``, objects satisfy it by having a
    matching ``vision_config`` attribute. They do not need to subclass it.
    """

    # ``Final`` marks the attribute as not to be reassigned, from the
    # type checker's point of view.
    vision_config: Final[PretrainedConfig]
|
|
|
|
|
|
def get_vision_encoder_info(
        hf_config: VisionLanguageConfig) -> VisionEncoderInfo:
    """Construct the ``VisionEncoderInfo`` that matches a model's vision config.

    Args:
        hf_config: Any object exposing a ``vision_config`` attribute.

    Returns:
        A ``CLIPEncoderInfo``, ``PixtralHFEncoderInfo`` or ``SiglipEncoderInfo``
        built from ``hf_config.vision_config``.

    Raises:
        NotImplementedError: If the vision config is not an instance of any of
            ``CLIPVisionConfig``, ``PixtralVisionConfig`` or
            ``SiglipVisionConfig``.
    """
    # Imported lazily to avoid circular imports.
    from .clip import CLIPEncoderInfo, CLIPVisionConfig
    from .pixtral import PixtralHFEncoderInfo, PixtralVisionConfig
    from .siglip import SiglipEncoderInfo, SiglipVisionConfig

    # Checked in this order; the first isinstance match wins.
    dispatch = (
        (CLIPVisionConfig, CLIPEncoderInfo),
        (PixtralVisionConfig, PixtralHFEncoderInfo),
        (SiglipVisionConfig, SiglipEncoderInfo),
    )

    cfg = hf_config.vision_config
    for config_cls, info_cls in dispatch:
        if isinstance(cfg, config_cls):
            return info_cls(cfg)

    raise NotImplementedError(f"Unsupported vision config: {type(cfg)}")
|