[Bugfix][Model] Add base class for vision-language models (#4809)
commit f68470e803
parent 2e9a2227ec
tests/models/test_registry.py (new file, 9 lines added)
@@ -0,0 +1,9 @@
+import pytest
+
+from vllm.model_executor.models import _MODELS, ModelRegistry
+
+
+@pytest.mark.parametrize("model_cls", _MODELS)
+def test_registry_imports(model_cls):
+    # Ensure all model classes can be imported successfully
+    ModelRegistry.load_model_cls(model_cls)
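The new test sweeps every architecture name registered in `_MODELS` and asserts that the registry can import it; running `pytest tests/models/test_registry.py` executes the full sweep. As a quick manual check of the same mechanism for a single entry, here is a sketch (assumes a working vLLM install; the architecture string is simply one registered name used for illustration):

from vllm.model_executor.models import ModelRegistry

# Resolve one registered architecture by name; a broken import path in the
# registry would raise here, which is what the parametrized test guards against.
model_cls = ModelRegistry.load_model_cls("LlavaForConditionalGeneration")
assert model_cls is not None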
@@ -26,11 +26,7 @@ from vllm.model_executor.model_loader.weight_utils import (
     download_weights_from_hf, filter_files_not_needed_for_inference,
     get_quant_config, initialize_dummy_weights, np_cache_weights_iterator,
     pt_weights_iterator, safetensors_weights_iterator)
-from vllm.model_executor.models.llava import LlavaForConditionalGeneration
+from vllm.model_executor.models.vlm_base import VisionLanguageModelBase
 
-_VISION_MODEL_CLASSES = [
-    LlavaForConditionalGeneration,
-]
-
 logger = init_logger(__name__)
 
@@ -73,7 +69,12 @@ def _get_model_initialization_kwargs(
             "but LoRA is enabled. Support for this model may "
             "be added in the future. If this is important to you, "
             "please open an issue on github.")
-    elif model_class in _VISION_MODEL_CLASSES:
+    elif issubclass(model_class, VisionLanguageModelBase):
+        if vision_language_config is None:
+            raise ValueError("Provide `image_input_type` and other vision "
+                             "related configurations through LLM entrypoint "
+                             "or engine arguments.")
+
         extra_kwargs["vision_language_config"] = vision_language_config
     return extra_kwargs
 
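The substantive change in these two hunks (which, judging from the weight_utils import context, sit in the model loader) is that vision-language models are now detected with `issubclass(model_class, VisionLanguageModelBase)` instead of membership in the hard-coded `_VISION_MODEL_CLASSES` list, so any model that inherits from the new base class automatically receives `vision_language_config`. A self-contained sketch of that dispatch pattern (the helper and the stand-in config class below are illustrative, not vLLM's actual API):

from typing import Any, Dict, Optional, Type

from torch import nn


class VisionLanguageConfig:
    """Stand-in for vllm.config.VisionLanguageConfig (illustration only)."""


class VisionLanguageModelBase(nn.Module):
    """Mirrors the base class introduced by this commit."""

    def __init__(self, vision_language_config: VisionLanguageConfig) -> None:
        super().__init__()
        self.vision_language_config = vision_language_config


def get_extra_init_kwargs(
        model_class: Type[nn.Module],
        vision_language_config: Optional[VisionLanguageConfig],
) -> Dict[str, Any]:
    """Illustrative helper, not the actual vLLM loader function."""
    extra_kwargs: Dict[str, Any] = {}
    # Every subclass of VisionLanguageModelBase counts as a VLM, so no
    # per-model list has to be kept in sync with new model additions.
    if issubclass(model_class, VisionLanguageModelBase):
        if vision_language_config is None:
            raise ValueError("A vision-language config is required for VLMs.")
        extra_kwargs["vision_language_config"] = vision_language_config
    return extra_kwargs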
@@ -19,6 +19,8 @@ from vllm.model_executor.models.llama import LlamaModel
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import SamplerOutput
 
+from .vlm_base import VisionLanguageModelBase
+
 _KEYS_TO_MODIFY_MAPPING = {
     "language_model.lm_head": "lm_head",
     "language_model.model": "language_model",
@@ -40,7 +42,7 @@ class LlavaMultiModalProjector(nn.Module):
                                   text_hidden_size,
                                   bias=True)
 
-    def forward(self, image_features):
+    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
         hidden_states = self.linear_1(image_features)
         hidden_states = self.act(hidden_states)
         hidden_states = self.linear_2(hidden_states)
@@ -50,30 +52,32 @@ class LlavaMultiModalProjector(nn.Module):
 def _merge_vision_embeddings(input_ids: torch.Tensor,
                              inputs_embeds: torch.Tensor,
                              vision_embeddings: torch.Tensor,
-                             image_token_id: int):
+                             image_token_id: int) -> torch.Tensor:
     """In place merges in vision_embeddings with inputs_embeds."""
     mask = (input_ids == image_token_id)
-    inputs_embeds[mask] = vision_embeddings.view(-1,
+
+    image_feature_size = vision_embeddings.shape[0] * vision_embeddings.shape[1]
+    if mask.sum() != image_feature_size:
+        raise ValueError(f"image_feature_size should be {image_feature_size}, "
+                         f"but found: {mask.sum()}")
+
+    inputs_embeds[mask] = vision_embeddings.view(image_feature_size,
                                                  vision_embeddings.shape[-1])
 
+    return inputs_embeds
+
 
-class LlavaForConditionalGeneration(nn.Module):
+class LlavaForConditionalGeneration(VisionLanguageModelBase):
 
     def __init__(self,
-                 config: "LlavaConfig",
+                 config: LlavaConfig,
                  vision_language_config: VisionLanguageConfig,
                  cache_config: Optional[CacheConfig] = None,
-                 quant_config: Optional["QuantizationConfig"] = None) -> None:
-        super().__init__()
+                 quant_config: Optional[QuantizationConfig] = None) -> None:
+        super().__init__(vision_language_config)
 
         self.config = config
-        self.vision_language_config = vision_language_config
-
-        assert self.vision_language_config, (
-            "Provide `image_input_type` and other vision "
-            "related configurations through LLM entrypoint "
-            "or engine arguments.")
 
         if self.vision_language_config.image_input_type == (
                 VisionLanguageConfig.ImageInputType.PIXEL_VALUES):
             self.vision_tower = CLIPVisionModel(config.vision_config)
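In the LLaVA implementation, the reworked `_merge_vision_embeddings` now validates that the number of image-placeholder tokens in `input_ids` matches the number of vision embedding vectors before scattering them, and returns the merged tensor rather than relying purely on in-place mutation. A toy illustration of the check and the scatter (shapes and token ids invented for the example):

import torch

image_token_id = 32000          # placeholder id, arbitrary for this example
hidden_size = 4
num_images, tokens_per_image = 1, 3

# A "prompt" with three image placeholder tokens followed by two text tokens.
input_ids = torch.tensor([image_token_id] * 3 + [5, 7])
inputs_embeds = torch.zeros(input_ids.shape[0], hidden_size)
vision_embeddings = torch.randn(num_images, tokens_per_image, hidden_size)

mask = (input_ids == image_token_id)
image_feature_size = vision_embeddings.shape[0] * vision_embeddings.shape[1]
assert mask.sum() == image_feature_size  # the new consistency check

# Scatter the flattened vision embeddings into the placeholder positions.
inputs_embeds[mask] = vision_embeddings.view(image_feature_size, hidden_size)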
@@ -98,14 +102,12 @@ class LlavaForConditionalGeneration(nn.Module):
                                                 config.vocab_size, logit_scale)
         self.sampler = Sampler()
 
-    def forward(
-        self,
-        input_ids: torch.Tensor,
-        positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
-        image_input: Optional[torch.Tensor] = None
-    ) -> SamplerOutput:  # noqa: E501
+    def forward(self,
+                input_ids: torch.Tensor,
+                positions: torch.Tensor,
+                kv_caches: List[torch.Tensor],
+                attn_metadata: AttentionMetadata,
+                image_input: Optional[torch.Tensor] = None) -> SamplerOutput:
         """Run forward pass for Llava 1.5.
 
         One key thing to understand is the `input_ids` already accounts for the
@@ -172,7 +174,7 @@ class LlavaForConditionalGeneration(nn.Module):
             image_features = image_input
         vision_embeddings = self.multi_modal_projector(image_features)
         inputs_embeds = self.language_model.get_input_embeddings(input_ids)
-        _merge_vision_embeddings(
+        inputs_embeds = _merge_vision_embeddings(
             input_ids, inputs_embeds, vision_embeddings,
             self.vision_language_config.image_token_id)
         input_ids = None
vllm/model_executor/models/vlm_base.py (new file, 12 lines added)
@@ -0,0 +1,12 @@
+from torch import nn
+
+from vllm.config import VisionLanguageConfig
+
+
+class VisionLanguageModelBase(nn.Module):
+    """Base class for all vision language models (VLMs)."""
+
+    def __init__(self, vision_language_config: VisionLanguageConfig) -> None:
+        super().__init__()
+
+        self.vision_language_config = vision_language_config
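For downstream model implementations, the intent is that a VLM inherits from `VisionLanguageModelBase` (as `LlavaForConditionalGeneration` now does) and passes its config to the base constructor. A minimal sketch, with a made-up model name and placeholder submodule:

from torch import nn

from vllm.config import VisionLanguageConfig
from vllm.model_executor.models.vlm_base import VisionLanguageModelBase


class MyVisionLanguageModel(VisionLanguageModelBase):  # hypothetical model

    def __init__(self, vision_language_config: VisionLanguageConfig) -> None:
        # The base class stores the config as self.vision_language_config,
        # so subclasses no longer need their own assignment or assert.
        super().__init__(vision_language_config)
        self.projector = nn.Identity()  # placeholder for real submodules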