[Model] Add Support for Multimodal Granite Models (#10291)

Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Alex Brooks 2024-11-21 03:46:20 -07:00 committed by GitHub
parent f0e0238016
commit 1cfde82ffd
6 changed files with 191 additions and 35 deletions

vllm/model_executor/models/clip.py

@@ -21,7 +21,8 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.multimodal.utils import (cached_get_tokenizer,
consecutive_placeholder_ranges,
repeat_and_pad_placeholder_tokens)
repeat_and_pad_placeholder_tokens,
resolve_visual_encoder_outputs)
from vllm.sequence import SequenceData
from .utils import get_vit_attn_backend
@@ -389,12 +390,20 @@ class CLIPEncoder(nn.Module):
for layer_idx in range(num_hidden_layers)
])
def forward(self, inputs_embeds: torch.Tensor):
def forward(
self, inputs_embeds: torch.Tensor, return_all_hidden_states: bool
) -> Union[torch.Tensor, list[torch.Tensor]]:
hidden_states_pool = []
hidden_states = inputs_embeds
for encoder_layer in self.layers:
hidden_states = encoder_layer(hidden_states)
if return_all_hidden_states:
hidden_states_pool.append(hidden_states)
# If we have multiple feature sample layers, we return all hidden
# states in order and grab the ones we need by index.
if return_all_hidden_states:
return hidden_states_pool
return hidden_states
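The pattern above (optionally pooling every layer's output instead of returning only the last hidden state) is the core of the multi-layer feature sampling used by the Granite models. A minimal, self-contained sketch of the same pattern, using a toy module rather than vLLM's CLIPEncoder:

import torch
import torch.nn as nn
from typing import List, Union


class ToyEncoder(nn.Module):
    """Illustrative stand-in for a vision encoder stack (not vLLM code)."""

    def __init__(self, hidden_size: int = 8, num_layers: int = 4):
        super().__init__()
        self.layers = nn.ModuleList(
            [nn.Linear(hidden_size, hidden_size) for _ in range(num_layers)])

    def forward(
        self, x: torch.Tensor, return_all_hidden_states: bool
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        hidden_states_pool = []
        for layer in self.layers:
            x = layer(x)
            if return_all_hidden_states:
                hidden_states_pool.append(x)
        # Either every per-layer hidden state (for multi-layer feature
        # sampling) or just the final one.
        return hidden_states_pool if return_all_hidden_states else x


enc = ToyEncoder()
x = torch.randn(1, 3, 8)
assert isinstance(enc(x, return_all_hidden_states=False), torch.Tensor)
assert len(enc(x, return_all_hidden_states=True)) == 4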
@@ -419,6 +428,7 @@ class CLIPVisionTransformer(nn.Module):
# NOTE: This typo of "layrnorm" is not fixed on purpose to match
# the original transformers code and name of the model weights.
self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
self.encoder = CLIPEncoder(
config=config,
quant_config=quant_config,
@@ -446,16 +456,26 @@ class CLIPVisionTransformer(nn.Module):
def forward(
self,
pixel_values: torch.Tensor,
feature_sample_layers: Optional[list[int]] = None,
) -> torch.Tensor:
hidden_states = self.embeddings(pixel_values)
hidden_states = self.pre_layrnorm(hidden_states)
hidden_states = self.encoder(inputs_embeds=hidden_states)
if self.post_layernorm is None:
return hidden_states
return_all_hidden_states = feature_sample_layers is not None
return self.post_layernorm(hidden_states)
# Produces either the last layer output or all of the hidden states,
# depending on if we have feature_sample_layers or not
encoder_outputs = self.encoder(
inputs_embeds=hidden_states,
return_all_hidden_states=return_all_hidden_states)
# Handle post-norm (if applicable) and stack feature layers if needed
encoder_outputs = resolve_visual_encoder_outputs(
encoder_outputs, feature_sample_layers, self.post_layernorm,
self.config.num_hidden_layers)
return encoder_outputs
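When feature_sample_layers is None, this refactor is behavior-preserving: the encoder still returns only the last hidden state and resolve_visual_encoder_outputs merely applies the post-norm, which is what the removed lines did. A minimal sketch of that reduced path (toy shapes; the helper name here is local to the example, not vLLM's):

import torch
import torch.nn as nn
from typing import Optional


def resolve_without_sampling(last_hidden_state: torch.Tensor,
                             post_layer_norm: Optional[nn.LayerNorm]):
    # Mirrors the `feature_sample_layers is None` branch of the new helper.
    if post_layer_norm is not None:
        return post_layer_norm(last_hidden_state)
    return last_hidden_state


hs = torch.randn(2, 577, 1024)
norm = nn.LayerNorm(1024)
assert torch.allclose(resolve_without_sampling(hs, norm), norm(hs))
assert resolve_without_sampling(hs, None) is hs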
class CLIPVisionModel(nn.Module):
@@ -478,11 +498,14 @@ class CLIPVisionModel(nn.Module):
quant_config=quant_config,
num_hidden_layers_override=num_hidden_layers_override,
require_post_norm=require_post_norm,
prefix=f"{prefix}.vision_model",
)
prefix=f"{prefix}.vision_model")
def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
return self.vision_model(pixel_values)
def forward(
self,
pixel_values: torch.Tensor,
feature_sample_layers: Optional[list[int]] = None,
) -> torch.Tensor:
return self.vision_model(pixel_values, feature_sample_layers)
@property
def device(self):

vllm/model_executor/models/llava.py

@@ -204,7 +204,41 @@ def input_processor_for_llava(ctx: InputContext, inputs: DecoderOnlyInputs):
class LlavaLikeConfig(Protocol):
vision_config: PretrainedConfig
vision_feature_layer: int
vision_feature_layer: Union[int, List[int]]
def _get_num_hidden_layers(hf_config: LlavaLikeConfig) -> int:
"""Determine the number of hidden layers to initialize up to in the
visual encoder.
Args:
hf_config: Model config with vision feature layer(s).
"""
feature_layers = hf_config.vision_feature_layer
num_hidden_layers = hf_config.vision_config.num_hidden_layers
# If we have one feature layer, initialize up to that layer
if isinstance(feature_layers, int):
return _get_layer_index(feature_layers, num_hidden_layers)
# If we have multiple feature layers, initialize up to the deepest one
elif isinstance(feature_layers, (list, tuple)):
return max(
_get_layer_index(idx, num_hidden_layers) for idx in feature_layers)
raise TypeError(f"vision_feature_layer type: {type(feature_layers)}"
" is not supported")
def _get_layer_index(feature_layer_index: int, num_hidden_layers: int) -> int:
"""Given an signed vision feature layer, get the number of hidden layers
needed to leverage it.
Args:
feature_layer_index: Index of a required layer in the visual encoder.
num_hidden_layers: The total number of hidden layers in the visual
encoder.
"""
if feature_layer_index < 0:
return num_hidden_layers + feature_layer_index + 1
return feature_layer_index + 1
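The arithmetic of these two helpers determines how deep the vision tower gets initialized. A self-contained check of that arithmetic (the second helper is simplified to take raw values instead of an hf_config, and the layer counts are illustrative):

from typing import List, Union


def _get_layer_index(feature_layer_index: int, num_hidden_layers: int) -> int:
    # Same arithmetic as above: negative indices count from the end.
    if feature_layer_index < 0:
        return num_hidden_layers + feature_layer_index + 1
    return feature_layer_index + 1


def num_layers_to_init(feature_layers: Union[int, List[int]],
                       num_hidden_layers: int) -> int:
    # Simplified version of _get_num_hidden_layers operating on raw values.
    if isinstance(feature_layers, int):
        return _get_layer_index(feature_layers, num_hidden_layers)
    return max(
        _get_layer_index(idx, num_hidden_layers) for idx in feature_layers)


# A 24-layer tower with a single feature layer of -2 only needs 23 layers;
# a multi-layer config that includes -1 needs the full 24.
assert _get_layer_index(-2, 24) == 23
assert num_layers_to_init(-2, 24) == 23
assert num_layers_to_init([-24, -20, -12, -1], 24) == 24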
def init_vision_tower_for_llava(
@@ -216,13 +250,8 @@ def init_vision_tower_for_llava(
):
vision_config = hf_config.vision_config
# Initialize the vision tower only up to the required feature layer
vision_feature_layer = hf_config.vision_feature_layer
if vision_feature_layer < 0:
num_hidden_layers = hf_config.vision_config.num_hidden_layers \
+ vision_feature_layer + 1
else:
num_hidden_layers = vision_feature_layer + 1
# Initialize the vision tower only up to the deepest required feature layer
num_hidden_layers = _get_num_hidden_layers(hf_config)
if isinstance(vision_config, CLIPVisionConfig):
return CLIPVisionModel(

vllm/model_executor/models/llava_next.py

@@ -288,6 +288,21 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
pooler_config = vllm_config.model_config.pooler_config
multimodal_config = vllm_config.model_config.multimodal_config
vision_feature_layer = config.vision_feature_layer
# Determine the layer up to which we will initialize the vision tower
if isinstance(vision_feature_layer, int):
vision_hidden_size = config.vision_config.hidden_size
self.feature_sample_layers = None
# Used for multimodal Granite models to control encoder outputs
elif isinstance(vision_feature_layer, (list, tuple)):
vision_hidden_size = config.vision_config.hidden_size * len(
vision_feature_layer)
self.feature_sample_layers = vision_feature_layer
else:
raise TypeError(
f"vision_layer_feature type: {type(vision_feature_layer)}"
" is not supported")
self.config = config
self.multimodal_config = multimodal_config
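The hidden-size bookkeeping above is what lets the projector accept concatenated features. A quick sanity check with made-up numbers (the layer list is only an example of a Granite-style list-valued vision_feature_layer):

# Illustrative values only; not taken from a real model config.
hidden_size = 1152
vision_feature_layer = [-24, -20, -12, -1]

if isinstance(vision_feature_layer, int):
    vision_hidden_size = hidden_size
else:
    # Sampled layers are concatenated along the feature dimension, so the
    # projector input width grows by a factor of len(vision_feature_layer).
    vision_hidden_size = hidden_size * len(vision_feature_layer)

assert vision_hidden_size == 4 * 1152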
@@ -300,7 +315,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
self.image_newline = nn.Parameter(
torch.empty(config.text_config.hidden_size))
self.multi_modal_projector = LlavaMultiModalProjector(
vision_hidden_size=config.vision_config.hidden_size,
vision_hidden_size=vision_hidden_size,
text_hidden_size=config.text_config.hidden_size,
projector_hidden_act=config.projector_hidden_act)
@@ -419,7 +434,8 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
# NOTE: we skip the step to select the vision feature layer since
# this is already done inside the vision tower
image_features = vision_tower(pixel_values)
image_features = vision_tower(
pixel_values, feature_sample_layers=self.feature_sample_layers)
return self._select_image_features(
image_features,

vllm/model_executor/models/pixtral.py

@@ -33,7 +33,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.multimodal.inputs import PlaceholderRange
from vllm.multimodal.utils import (cached_get_tokenizer,
consecutive_placeholder_ranges)
consecutive_placeholder_ranges,
resolve_visual_encoder_outputs)
from vllm.sequence import IntermediateTensors, SequenceData
from vllm.transformers_utils.processor import cached_get_processor
from vllm.utils import is_list_of
@@ -970,9 +971,18 @@ class PixtralHFTransformer(nn.Module):
x: torch.Tensor,
attention_mask: torch.Tensor,
position_embeddings: torch.Tensor,
return_all_hidden_states: bool,
) -> Union[torch.Tensor, list[torch.Tensor]]:
hidden_states_pool = []
for layer in self.layers:
x = layer(x, attention_mask, position_embeddings)
if return_all_hidden_states:
hidden_states_pool.append(x)
# If we have multiple feature sample layers, we return all hidden
# states in order and grab the ones we need by index.
if return_all_hidden_states:
return hidden_states_pool
return x
@@ -990,6 +1000,7 @@ class PixtralHFVisionModel(nn.Module):
super().__init__()
self.config = config
self.patch_conv = nn.Conv2d(
in_channels=config.num_channels,
out_channels=config.hidden_size,
@@ -1024,6 +1035,7 @@ class PixtralHFVisionModel(nn.Module):
def forward(
self,
pixel_values: List[torch.Tensor],
feature_sample_layers: Optional[list[int]] = None,
) -> torch.Tensor:
"""
Args:
@@ -1031,6 +1043,9 @@ class PixtralHFVisionModel(nn.Module):
in pixel_values. This means it will be a list of tensors
because multiple requests batched can have multiple images,
each with their own shape potentially
feature_sample_layers: Layer indices whose features should be
concatenated and used as the visual encoder output. If none
are provided, the last layer is used.
Returns:
image_features: tensor of token features for
@@ -1065,8 +1080,15 @@ class PixtralHFVisionModel(nn.Module):
[p.shape[-2] * p.shape[-1] for p in patch_embeds_list],
patch_embeds)
out = self.transformer(patch_embeds, attention_mask,
position_embedding)
return_all_hidden_states = feature_sample_layers is not None
out = self.transformer(
patch_embeds,
attention_mask,
position_embedding,
return_all_hidden_states=return_all_hidden_states)
out = resolve_visual_encoder_outputs(out, feature_sample_layers, None,
self.config.num_hidden_layers)
return out
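Pixtral passes None for the post-norm argument here, so for this model the helper only gathers the requested hidden states and concatenates them. A rough shape-level sketch of that list branch (sizes and layer indices are illustrative):

import torch

max_possible_layers = 24                    # full encoder depth
hidden = [torch.randn(2, 1024, 1024) for _ in range(max_possible_layers)]
feature_sample_layers = [-24, -20, -12, -1]

offset = max_possible_layers - len(hidden)  # 0 here: every layer was run
hs_pool = [hidden[i] if i >= 0 else hidden[i + offset]
           for i in feature_sample_layers]
out = torch.cat(hs_pool, dim=-1)            # no post-norm for Pixtral
assert out.shape == (2, 1024, 4 * 1024)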

vllm/model_executor/models/siglip.py

@@ -25,7 +25,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.multimodal.utils import (cached_get_tokenizer,
consecutive_placeholder_ranges,
repeat_and_pad_placeholder_tokens)
repeat_and_pad_placeholder_tokens,
resolve_visual_encoder_outputs)
from vllm.sequence import SequenceData
from .utils import get_vit_attn_backend
@@ -450,11 +451,19 @@ class SiglipEncoder(nn.Module):
def forward(
self,
inputs_embeds: torch.Tensor,
) -> torch.Tensor:
return_all_hidden_states: bool,
) -> Union[torch.Tensor, list[torch.Tensor]]:
hidden_states_pool = []
hidden_states = inputs_embeds
for encoder_layer in self.layers:
hidden_states, _ = encoder_layer(hidden_states)
if return_all_hidden_states:
hidden_states_pool.append(hidden_states)
# If we have multiple feature sample layers, we return all hidden
# states in order and grab the ones we need by index.
if return_all_hidden_states:
return hidden_states_pool
return hidden_states
@@ -509,6 +518,7 @@ class SiglipVisionTransformer(nn.Module):
embed_dim = config.hidden_size
self.embeddings = SiglipVisionEmbeddings(config)
self.encoder = SiglipEncoder(
config,
quant_config=quant_config,
@@ -546,23 +556,33 @@ class SiglipVisionTransformer(nn.Module):
self,
pixel_values: torch.Tensor,
interpolate_pos_encoding: bool = True,
feature_sample_layers: Optional[list[int]] = None,
) -> torch.Tensor:
hidden_states = self.embeddings(
pixel_values,
interpolate_pos_encoding=interpolate_pos_encoding,
)
encoder_outputs = self.encoder(inputs_embeds=hidden_states)
return_all_hidden_states = feature_sample_layers is not None
if self.post_layernorm is None:
return encoder_outputs
# Produces either the last layer output or all of the hidden states,
# depending on if we have feature_sample_layers or not
encoder_outputs = self.encoder(
inputs_embeds=hidden_states,
return_all_hidden_states=return_all_hidden_states,
)
last_hidden_state = self.post_layernorm(encoder_outputs)
# TODO: add this back when pooled_output is used in inference
# Handle post-norm (if applicable) and stack feature layers if needed
encoder_outputs = resolve_visual_encoder_outputs(
encoder_outputs, feature_sample_layers, self.post_layernorm,
self.config.num_hidden_layers)
# TODO: add this back when pooled_output is used in inference.
# if self.use_head:
# pooled_output = self.head(last_hidden_state)
# pooled_output = self.head(encoder_outputs)
return last_hidden_state
return encoder_outputs
class SiglipVisionModel(nn.Module):
@@ -595,10 +615,12 @@ class SiglipVisionModel(nn.Module):
self,
pixel_values: torch.Tensor,
interpolate_pos_encoding: bool = False,
feature_sample_layers: Optional[list[int]] = None,
) -> torch.Tensor:
return self.vision_model(
pixel_values=pixel_values,
interpolate_pos_encoding=interpolate_pos_encoding,
feature_sample_layers=feature_sample_layers,
)
def load_weights(self, weights: Iterable[Tuple[str,

vllm/multimodal/utils.py

@@ -6,6 +6,7 @@ from typing import Any, List, Optional, Tuple, TypeVar, Union
import numpy as np
import numpy.typing as npt
import torch
from PIL import Image
import vllm.envs as envs
@@ -392,6 +393,49 @@ def encode_video_base64(frames: npt.NDArray):
return ",".join(base64_frames)
def resolve_visual_encoder_outputs(
encoder_outputs: Union[torch.Tensor, list[torch.Tensor]],
feature_sample_layers: Optional[list[int]],
post_layer_norm: Optional[torch.nn.LayerNorm],
max_possible_layers: int,
) -> torch.Tensor:
"""Given the outputs a visual encoder module that may correspond to the
output of the last layer, or a list of hidden states to be stacked,
handle post normalization and resolve it into a single output tensor.
Args:
encoder_outputs: Output of encoder's last layer or all hidden states.
feature_sample_layers: Optional layer indices to grab from the encoder
outputs; if provided, encoder outputs must be a list.
post_layer_norm: Post norm to apply to the output of the encoder.
max_possible_layers: Total layers in the fully loaded visual encoder.
"""
if feature_sample_layers is None:
if post_layer_norm is not None:
return post_layer_norm(encoder_outputs)
return encoder_outputs
# Get the hidden states corresponding to the layer indices.
# Negative values are relative to the full visual encoder,
# so offset them depending on how many layers were loaded.
# NOTE: this assumes that encoder_outputs contains a list
# of hidden states in the same order as the encoder layers
# that produced them.
offset = max_possible_layers - len(encoder_outputs)
hs_pool = [
encoder_outputs[layer_idx]
if layer_idx >= 0 else encoder_outputs[layer_idx + offset]
for layer_idx in feature_sample_layers
]
# Apply post-norm on the final hidden state if we are using it
uses_last_layer = feature_sample_layers[-1] in (len(hs_pool) - 1, -1)
if post_layer_norm is not None and uses_last_layer:
hs_pool[-1] = post_layer_norm(encoder_outputs[-1])
return torch.cat(hs_pool, dim=-1)
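A worked example of the helper above, assuming it is importable from vllm.multimodal.utils as in this diff; the shapes and layer indices are made up:

import torch
from torch import nn

from vllm.multimodal.utils import resolve_visual_encoder_outputs

# Case 1: no feature_sample_layers -> just apply the post-norm (if any).
last = torch.randn(2, 577, 1024)
norm = nn.LayerNorm(1024)
assert resolve_visual_encoder_outputs(last, None, norm, 24).shape == last.shape

# Case 2: a truncated tower. The full encoder has 24 layers, but only 20 were
# initialized because the deepest requested layer is -5 (24 - 5 + 1 = 20).
# The offset (24 - 20 = 4) shifts the negative indices onto the loaded
# layers; since -1 is not requested, the post-norm is not applied.
hidden = [torch.randn(2, 577, 1024) for _ in range(20)]
out = resolve_visual_encoder_outputs(hidden, [-10, -5], norm, 24)
assert out.shape == (2, 577, 2 * 1024)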
# Utilities for input processors
_T = TypeVar("_T", str, int)