[Bugfix][Model] Add base class for vision-language models (#4809)

Authored by Cyrus Leung on 2024-05-19 15:13:33 +08:00, committed by GitHub
parent 2e9a2227ec
commit f68470e803
4 changed files with 53 additions and 29 deletions

tests/models/test_registry.py (new file)

@@ -0,0 +1,9 @@
+import pytest
+
+from vllm.model_executor.models import _MODELS, ModelRegistry
+
+
+@pytest.mark.parametrize("model_cls", _MODELS)
+def test_registry_imports(model_cls):
+    # Ensure all model classes can be imported successfully
+    ModelRegistry.load_model_cls(model_cls)
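Note (not part of the diff): a minimal sketch of what the new test exercises. ModelRegistry.load_model_cls resolves an architecture name (a key of _MODELS) to its implementation class, importing the model module on demand; the architecture name below is the one touched by this commit.

    from vllm.model_executor.models import ModelRegistry

    # Resolve a registered architecture name to its class. The test above
    # simply does this for every key in _MODELS.
    model_cls = ModelRegistry.load_model_cls("LlavaForConditionalGeneration")
    print(model_cls)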

vllm/model_executor/model_loader/loader.py

@@ -26,11 +26,7 @@ from vllm.model_executor.model_loader.weight_utils import (
     download_weights_from_hf, filter_files_not_needed_for_inference,
     get_quant_config, initialize_dummy_weights, np_cache_weights_iterator,
     pt_weights_iterator, safetensors_weights_iterator)
-from vllm.model_executor.models.llava import LlavaForConditionalGeneration
-
-_VISION_MODEL_CLASSES = [
-    LlavaForConditionalGeneration,
-]
+from vllm.model_executor.models.vlm_base import VisionLanguageModelBase
 
 logger = init_logger(__name__)
@@ -73,7 +69,12 @@ def _get_model_initialization_kwargs(
             "but LoRA is enabled. Support for this model may "
             "be added in the future. If this is important to you, "
             "please open an issue on github.")
-    elif model_class in _VISION_MODEL_CLASSES:
+    elif issubclass(model_class, VisionLanguageModelBase):
+        if vision_language_config is None:
+            raise ValueError("Provide `image_input_type` and other vision "
+                             "related configurations through LLM entrypoint "
+                             "or engine arguments.")
+
         extra_kwargs["vision_language_config"] = vision_language_config
 
     return extra_kwargs
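Note (not part of the diff): the new error message points at the LLM entrypoint. As a hedged sketch of what that looked like around this version of vLLM, the vision settings were passed as keyword arguments when constructing the LLM; treat the exact argument names and values below as assumptions taken from the LLaVA example of this era, not a stable API.

    from vllm import LLM

    # Assumed LLaVA-1.5 settings for this vLLM version; names and values may
    # differ in other releases.
    llm = LLM(
        model="llava-hf/llava-1.5-7b-hf",
        image_input_type="pixel_values",
        image_token_id=32000,
        image_input_shape="1,3,336,336",
        image_feature_size=576,
    )

With these set, vision_language_config is populated and the issubclass branch above forwards it to the model constructor.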

vllm/model_executor/models/llava.py

@@ -19,6 +19,8 @@ from vllm.model_executor.models.llama import LlamaModel
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import SamplerOutput
 
+from .vlm_base import VisionLanguageModelBase
+
 _KEYS_TO_MODIFY_MAPPING = {
     "language_model.lm_head": "lm_head",
     "language_model.model": "language_model",
@@ -40,7 +42,7 @@ class LlavaMultiModalProjector(nn.Module):
                                   text_hidden_size,
                                   bias=True)
 
-    def forward(self, image_features):
+    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
         hidden_states = self.linear_1(image_features)
         hidden_states = self.act(hidden_states)
         hidden_states = self.linear_2(hidden_states)
@@ -50,30 +52,32 @@ class LlavaMultiModalProjector(nn.Module):
 def _merge_vision_embeddings(input_ids: torch.Tensor,
                              inputs_embeds: torch.Tensor,
                              vision_embeddings: torch.Tensor,
-                             image_token_id: int):
+                             image_token_id: int) -> torch.Tensor:
     """In place merges in vision_embeddings with inputs_embeds."""
     mask = (input_ids == image_token_id)
-    inputs_embeds[mask] = vision_embeddings.view(-1,
+
+    image_feature_size = vision_embeddings.shape[0] * vision_embeddings.shape[1]
+    if mask.sum() != image_feature_size:
+        raise ValueError(f"image_feature_size should be {image_feature_size}, "
+                         f"but found: {mask.sum()}")
+
+    inputs_embeds[mask] = vision_embeddings.view(image_feature_size,
                                                  vision_embeddings.shape[-1])
+
     return inputs_embeds
 
 
-class LlavaForConditionalGeneration(nn.Module):
+class LlavaForConditionalGeneration(VisionLanguageModelBase):
 
     def __init__(self,
-                 config: "LlavaConfig",
+                 config: LlavaConfig,
                  vision_language_config: VisionLanguageConfig,
                  cache_config: Optional[CacheConfig] = None,
-                 quant_config: Optional["QuantizationConfig"] = None) -> None:
-        super().__init__()
+                 quant_config: Optional[QuantizationConfig] = None) -> None:
+        super().__init__(vision_language_config)
 
         self.config = config
-        self.vision_language_config = vision_language_config
-
-        assert self.vision_language_config, (
-            "Provide `image_input_type` and other vision "
-            "related configurations through LLM entrypoint "
-            "or engine arguments.")
 
         if self.vision_language_config.image_input_type == (
                 VisionLanguageConfig.ImageInputType.PIXEL_VALUES):
             self.vision_tower = CLIPVisionModel(config.vision_config)
@@ -98,14 +102,12 @@ class LlavaForConditionalGeneration(nn.Module):
                                                 config.vocab_size, logit_scale)
         self.sampler = Sampler()
 
-    def forward(
-        self,
+    def forward(self,
                 input_ids: torch.Tensor,
                 positions: torch.Tensor,
                 kv_caches: List[torch.Tensor],
                 attn_metadata: AttentionMetadata,
-        image_input: Optional[torch.Tensor] = None
-    ) -> SamplerOutput:  # noqa: E501
+                image_input: Optional[torch.Tensor] = None) -> SamplerOutput:
         """Run forward pass for Llava 1.5.
 
         One key thing to understand is the `input_ids` already accounts for the
@@ -172,7 +174,7 @@ class LlavaForConditionalGeneration(nn.Module):
             image_features = image_input
         vision_embeddings = self.multi_modal_projector(image_features)
         inputs_embeds = self.language_model.get_input_embeddings(input_ids)
-        _merge_vision_embeddings(
+        inputs_embeds = _merge_vision_embeddings(
             input_ids, inputs_embeds, vision_embeddings,
             self.vision_language_config.image_token_id)
         input_ids = None
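Note (not part of the diff): a self-contained toy example (made-up sizes, no vLLM imports) of the behavior the reworked _merge_vision_embeddings enforces. input_ids carries image_token_id placeholders, and the projected vision embeddings of shape (num_images, features_per_image, hidden) must fill exactly that many positions, otherwise the new ValueError is raised.

    import torch

    IMAGE_TOKEN_ID = 32000  # placeholder value, for illustration only
    HIDDEN = 4

    # A prompt with two image-placeholder positions.
    input_ids = torch.tensor([1, IMAGE_TOKEN_ID, IMAGE_TOKEN_ID, 9, 2])
    inputs_embeds = torch.zeros(5, HIDDEN)

    # One image yielding two feature vectors: image_feature_size = 1 * 2 = 2,
    # which matches the number of placeholder tokens above.
    vision_embeddings = torch.arange(2 * HIDDEN, dtype=torch.float).view(1, 2, HIDDEN)

    mask = input_ids == IMAGE_TOKEN_ID
    image_feature_size = vision_embeddings.shape[0] * vision_embeddings.shape[1]
    assert mask.sum() == image_feature_size  # the condition the new check enforces

    inputs_embeds[mask] = vision_embeddings.view(image_feature_size, HIDDEN)
    # Rows 1 and 2 of inputs_embeds now hold the projected image features.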

vllm/model_executor/models/vlm_base.py (new file)

@@ -0,0 +1,12 @@
+from torch import nn
+
+from vllm.config import VisionLanguageConfig
+
+
+class VisionLanguageModelBase(nn.Module):
+    """Base class for all vision language models (VLMs)."""
+
+    def __init__(self, vision_language_config: VisionLanguageConfig) -> None:
+        super().__init__()
+
+        self.vision_language_config = vision_language_config
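Note (not part of the diff): a sketch of how a model opts into the new mechanism; the class name is hypothetical. Subclassing VisionLanguageModelBase both stores the config (as LLaVA now does above) and makes the loader's issubclass check pass, so vision_language_config is injected without any registration list.

    from vllm.config import VisionLanguageConfig
    from vllm.model_executor.models.vlm_base import VisionLanguageModelBase


    class MyVisionLanguageModel(VisionLanguageModelBase):
        """Hypothetical VLM used only to illustrate the new base class."""

        def __init__(self, vision_language_config: VisionLanguageConfig) -> None:
            # The base class stores the config as self.vision_language_config.
            super().__init__(vision_language_config)


    # The loader recognizes the model via a subclass check instead of the old
    # hand-maintained _VISION_MODEL_CLASSES list.
    assert issubclass(MyVisionLanguageModel, VisionLanguageModelBase)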