[Bugfix][Model] Add base class for vision-language models (#4809)
This commit is contained in:
parent
2e9a2227ec
commit
f68470e803
9
tests/models/test_registry.py
Normal file
9
tests/models/test_registry.py
Normal file
@ -0,0 +1,9 @@
|
||||
import pytest
|
||||
|
||||
from vllm.model_executor.models import _MODELS, ModelRegistry
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_cls", _MODELS)
|
||||
def test_registry_imports(model_cls):
|
||||
# Ensure all model classes can be imported successfully
|
||||
ModelRegistry.load_model_cls(model_cls)
|
@ -26,11 +26,7 @@ from vllm.model_executor.model_loader.weight_utils import (
|
||||
download_weights_from_hf, filter_files_not_needed_for_inference,
|
||||
get_quant_config, initialize_dummy_weights, np_cache_weights_iterator,
|
||||
pt_weights_iterator, safetensors_weights_iterator)
|
||||
from vllm.model_executor.models.llava import LlavaForConditionalGeneration
|
||||
|
||||
_VISION_MODEL_CLASSES = [
|
||||
LlavaForConditionalGeneration,
|
||||
]
|
||||
from vllm.model_executor.models.vlm_base import VisionLanguageModelBase
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -73,7 +69,12 @@ def _get_model_initialization_kwargs(
|
||||
"but LoRA is enabled. Support for this model may "
|
||||
"be added in the future. If this is important to you, "
|
||||
"please open an issue on github.")
|
||||
elif model_class in _VISION_MODEL_CLASSES:
|
||||
elif issubclass(model_class, VisionLanguageModelBase):
|
||||
if vision_language_config is None:
|
||||
raise ValueError("Provide `image_input_type` and other vision "
|
||||
"related configurations through LLM entrypoint "
|
||||
"or engine arguments.")
|
||||
|
||||
extra_kwargs["vision_language_config"] = vision_language_config
|
||||
return extra_kwargs
|
||||
|
||||
|
@ -19,6 +19,8 @@ from vllm.model_executor.models.llama import LlamaModel
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import SamplerOutput
|
||||
|
||||
from .vlm_base import VisionLanguageModelBase
|
||||
|
||||
_KEYS_TO_MODIFY_MAPPING = {
|
||||
"language_model.lm_head": "lm_head",
|
||||
"language_model.model": "language_model",
|
||||
@ -40,7 +42,7 @@ class LlavaMultiModalProjector(nn.Module):
|
||||
text_hidden_size,
|
||||
bias=True)
|
||||
|
||||
def forward(self, image_features):
|
||||
def forward(self, image_features: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = self.linear_1(image_features)
|
||||
hidden_states = self.act(hidden_states)
|
||||
hidden_states = self.linear_2(hidden_states)
|
||||
@ -50,30 +52,32 @@ class LlavaMultiModalProjector(nn.Module):
|
||||
def _merge_vision_embeddings(input_ids: torch.Tensor,
|
||||
inputs_embeds: torch.Tensor,
|
||||
vision_embeddings: torch.Tensor,
|
||||
image_token_id: int):
|
||||
image_token_id: int) -> torch.Tensor:
|
||||
"""In place merges in vision_embeddings with inputs_embeds."""
|
||||
mask = (input_ids == image_token_id)
|
||||
inputs_embeds[mask] = vision_embeddings.view(-1,
|
||||
|
||||
image_feature_size = vision_embeddings.shape[0] * vision_embeddings.shape[1]
|
||||
if mask.sum() != image_feature_size:
|
||||
raise ValueError(f"image_feature_size should be {image_feature_size}, "
|
||||
f"but found: {mask.sum()}")
|
||||
|
||||
inputs_embeds[mask] = vision_embeddings.view(image_feature_size,
|
||||
vision_embeddings.shape[-1])
|
||||
|
||||
return inputs_embeds
|
||||
|
||||
class LlavaForConditionalGeneration(nn.Module):
|
||||
|
||||
class LlavaForConditionalGeneration(VisionLanguageModelBase):
|
||||
|
||||
def __init__(self,
|
||||
config: "LlavaConfig",
|
||||
config: LlavaConfig,
|
||||
vision_language_config: VisionLanguageConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional["QuantizationConfig"] = None) -> None:
|
||||
super().__init__()
|
||||
quant_config: Optional[QuantizationConfig] = None) -> None:
|
||||
super().__init__(vision_language_config)
|
||||
|
||||
self.config = config
|
||||
|
||||
self.vision_language_config = vision_language_config
|
||||
|
||||
assert self.vision_language_config, (
|
||||
"Provide `image_input_type` and other vision "
|
||||
"related configurations through LLM entrypoint "
|
||||
"or engine arguments.")
|
||||
|
||||
if self.vision_language_config.image_input_type == (
|
||||
VisionLanguageConfig.ImageInputType.PIXEL_VALUES):
|
||||
self.vision_tower = CLIPVisionModel(config.vision_config)
|
||||
@ -98,14 +102,12 @@ class LlavaForConditionalGeneration(nn.Module):
|
||||
config.vocab_size, logit_scale)
|
||||
self.sampler = Sampler()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
image_input: Optional[torch.Tensor] = None
|
||||
) -> SamplerOutput: # noqa: E501
|
||||
def forward(self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
image_input: Optional[torch.Tensor] = None) -> SamplerOutput:
|
||||
"""Run forward pass for Llava 1.5.
|
||||
|
||||
One key thing to understand is the `input_ids` already accounts for the
|
||||
@ -172,7 +174,7 @@ class LlavaForConditionalGeneration(nn.Module):
|
||||
image_features = image_input
|
||||
vision_embeddings = self.multi_modal_projector(image_features)
|
||||
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
|
||||
_merge_vision_embeddings(
|
||||
inputs_embeds = _merge_vision_embeddings(
|
||||
input_ids, inputs_embeds, vision_embeddings,
|
||||
self.vision_language_config.image_token_id)
|
||||
input_ids = None
|
||||
|
12
vllm/model_executor/models/vlm_base.py
Normal file
12
vllm/model_executor/models/vlm_base.py
Normal file
@ -0,0 +1,12 @@
|
||||
from torch import nn
|
||||
|
||||
from vllm.config import VisionLanguageConfig
|
||||
|
||||
|
||||
class VisionLanguageModelBase(nn.Module):
|
||||
"""Base class for all vision language models (VLMs)."""
|
||||
|
||||
def __init__(self, vision_language_config: VisionLanguageConfig) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.vision_language_config = vision_language_config
|
Loading…
x
Reference in New Issue
Block a user