[Model] Add Support for Multimodal Granite Models (#10291)

Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Author: Alex Brooks · 2024-11-21 03:46:20 -07:00 · committed by GitHub
parent f0e0238016
commit 1cfde82ffd
GPG key ID: B5690EEEBB952194 (no known key found for this signature in database)
6 changed files with 191 additions and 35 deletions

vllm/model_executor/models/clip.py

@@ -21,7 +21,8 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.multimodal.utils import (cached_get_tokenizer,
                                    consecutive_placeholder_ranges,
-                                   repeat_and_pad_placeholder_tokens)
+                                   repeat_and_pad_placeholder_tokens,
+                                   resolve_visual_encoder_outputs)
 from vllm.sequence import SequenceData

 from .utils import get_vit_attn_backend
@@ -389,12 +390,20 @@ class CLIPEncoder(nn.Module):
             for layer_idx in range(num_hidden_layers)
         ])

-    def forward(self, inputs_embeds: torch.Tensor):
+    def forward(
+        self, inputs_embeds: torch.Tensor, return_all_hidden_states: bool
+    ) -> Union[torch.Tensor, list[torch.Tensor]]:
+        hidden_states_pool = []
         hidden_states = inputs_embeds
         for encoder_layer in self.layers:
             hidden_states = encoder_layer(hidden_states)
+            if return_all_hidden_states:
+                hidden_states_pool.append(hidden_states)
+
+        # If we have multiple feature sample layers, we return all hidden
+        # states in order and grab the ones we need by index.
+        if return_all_hidden_states:
+            return hidden_states_pool
         return hidden_states
@@ -419,6 +428,7 @@ class CLIPVisionTransformer(nn.Module):
         # NOTE: This typo of "layrnorm" is not fixed on purpose to match
         # the original transformers code and name of the model weights.
         self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
+
         self.encoder = CLIPEncoder(
             config=config,
             quant_config=quant_config,
@@ -446,16 +456,26 @@
     def forward(
         self,
         pixel_values: torch.Tensor,
+        feature_sample_layers: Optional[list[int]] = None,
     ) -> torch.Tensor:
         hidden_states = self.embeddings(pixel_values)
         hidden_states = self.pre_layrnorm(hidden_states)
-        hidden_states = self.encoder(inputs_embeds=hidden_states)

-        if self.post_layernorm is None:
-            return hidden_states
+        return_all_hidden_states = feature_sample_layers is not None

-        return self.post_layernorm(hidden_states)
+        # Produces either the last layer output or all of the hidden states,
+        # depending on if we have feature_sample_layers or not
+        encoder_outputs = self.encoder(
+            inputs_embeds=hidden_states,
+            return_all_hidden_states=return_all_hidden_states)
+
+        # Handle post-norm (if applicable) and stacks feature layers if needed
+        encoder_outputs = resolve_visual_encoder_outputs(
+            encoder_outputs, feature_sample_layers, self.post_layernorm,
+            self.config.num_hidden_layers)
+
+        return encoder_outputs


 class CLIPVisionModel(nn.Module):
@@ -478,11 +498,14 @@ class CLIPVisionModel(nn.Module):
             quant_config=quant_config,
             num_hidden_layers_override=num_hidden_layers_override,
             require_post_norm=require_post_norm,
-            prefix=f"{prefix}.vision_model",
-        )
+            prefix=f"{prefix}.vision_model")

-    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
-        return self.vision_model(pixel_values)
+    def forward(
+        self,
+        pixel_values: torch.Tensor,
+        feature_sample_layers: Optional[list[int]] = None,
+    ) -> torch.Tensor:
+        return self.vision_model(pixel_values, feature_sample_layers)

     @property
     def device(self):
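To see the return_all_hidden_states pattern above in isolation, here is a minimal standalone sketch with a toy layer stack (not vLLM code; the sizes are made up):

import torch
import torch.nn as nn

# Toy 4-layer "encoder" mirroring the loop in CLIPEncoder.forward above.
layers = nn.ModuleList([nn.Linear(8, 8) for _ in range(4)])

def encode(x: torch.Tensor, return_all_hidden_states: bool):
    hidden_states_pool = []
    for layer in layers:
        x = layer(x)
        if return_all_hidden_states:
            hidden_states_pool.append(x)
    # Multi-layer feature sampling gets every per-layer hidden state;
    # otherwise only the final hidden state is returned.
    return hidden_states_pool if return_all_hidden_states else x

assert len(encode(torch.randn(1, 8), True)) == 4
assert encode(torch.randn(1, 8), False).shape == (1, 8)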

vllm/model_executor/models/llava.py

@@ -204,7 +204,41 @@ def input_processor_for_llava(ctx: InputContext, inputs: DecoderOnlyInputs):
 class LlavaLikeConfig(Protocol):
     vision_config: PretrainedConfig
-    vision_feature_layer: int
+    vision_feature_layer: Union[int, List[int]]
+
+
+def _get_num_hidden_layers(hf_config: LlavaLikeConfig) -> int:
+    """Determine the number of hidden layers to initialize up to in the
+    visual encoder.
+
+    Args:
+        hf_config: Model config with vision feature layer(s).
+    """
+    feature_layers = hf_config.vision_feature_layer
+    num_hidden_layers = hf_config.vision_config.num_hidden_layers
+    # If we have one feature layer, initialize up to that layer
+    if isinstance(feature_layers, int):
+        return _get_layer_index(feature_layers, num_hidden_layers)
+    # If we have multiple feature layers, initialize up to the deepest one
+    elif isinstance(feature_layers, (list, tuple)):
+        return max(
+            _get_layer_index(idx, num_hidden_layers) for idx in feature_layers)
+    raise TypeError(f"vision_layer_feature type: {type(feature_layers)}"
+                    " is not supported")
+
+
+def _get_layer_index(feature_layer_index: int, num_hidden_layers: int) -> int:
+    """Given a signed vision feature layer, get the number of hidden layers
+    needed to leverage it.
+
+    Args:
+        feature_layer_index: Index of a required layer in the visual encoder.
+        num_hidden_layers: The total number of hidden layers in the visual
+            encoder.
+    """
+    if feature_layer_index < 0:
+        return num_hidden_layers + feature_layer_index + 1
+    return feature_layer_index + 1


 def init_vision_tower_for_llava(
@@ -216,13 +250,8 @@ def init_vision_tower_for_llava(
 ):
     vision_config = hf_config.vision_config

-    # Initialize the vision tower only up to the required feature layer
-    vision_feature_layer = hf_config.vision_feature_layer
-    if vision_feature_layer < 0:
-        num_hidden_layers = hf_config.vision_config.num_hidden_layers \
-            + vision_feature_layer + 1
-    else:
-        num_hidden_layers = vision_feature_layer + 1
+    # Initialize the vision tower only up to the deepest required feature layer
+    num_hidden_layers = _get_num_hidden_layers(hf_config)

     if isinstance(vision_config, CLIPVisionConfig):
         return CLIPVisionModel(
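As a quick illustration of the indexing helpers added above (the layer counts and feature-layer lists below are made-up values, not taken from any real config):

# Standalone copy of the helper, for illustration only; see the diff above.
def _get_layer_index(feature_layer_index: int, num_hidden_layers: int) -> int:
    if feature_layer_index < 0:
        return num_hidden_layers + feature_layer_index + 1
    return feature_layer_index + 1

# With a hypothetical 24-layer vision encoder:
assert _get_layer_index(-1, 24) == 24   # last layer -> all 24 layers must be loaded
assert _get_layer_index(-2, 24) == 23   # second-to-last layer -> 23 layers suffice
assert _get_layer_index(10, 24) == 11   # a non-negative index loads index + 1 layers

# For a list of feature layers, only the deepest one matters:
feature_layers = [-24, -20, -12, -1]    # hypothetical multi-layer config
assert max(_get_layer_index(i, 24) for i in feature_layers) == 24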

vllm/model_executor/models/llava_next.py

@@ -288,6 +288,21 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
         pooler_config = vllm_config.model_config.pooler_config
         multimodal_config = vllm_config.model_config.multimodal_config

+        vision_feature_layer = config.vision_feature_layer
+        # Determine the layer up to which we will initialize the vision tower
+        if isinstance(vision_feature_layer, int):
+            vision_hidden_size = config.vision_config.hidden_size
+            self.feature_sample_layers = None
+        # Used for multimodal granite models to control encoder outputs
+        elif isinstance(vision_feature_layer, (list, tuple)):
+            vision_hidden_size = config.vision_config.hidden_size * len(
+                vision_feature_layer)
+            self.feature_sample_layers = vision_feature_layer
+        else:
+            raise TypeError(
+                f"vision_layer_feature type: {type(vision_feature_layer)}"
+                " is not supported")
+
         self.config = config
         self.multimodal_config = multimodal_config
@@ -300,7 +315,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
         self.image_newline = nn.Parameter(
             torch.empty(config.text_config.hidden_size))
         self.multi_modal_projector = LlavaMultiModalProjector(
-            vision_hidden_size=config.vision_config.hidden_size,
+            vision_hidden_size=vision_hidden_size,
             text_hidden_size=config.text_config.hidden_size,
             projector_hidden_act=config.projector_hidden_act)
@@ -419,7 +434,8 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
         # NOTE: we skip the step to select the vision feature layer since
         # this is already done inside the vision tower
-        image_features = vision_tower(pixel_values)
+        image_features = vision_tower(
+            pixel_values, feature_sample_layers=self.feature_sample_layers)

         return self._select_image_features(
             image_features,
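The vision_hidden_size arithmetic above sets the projector input width when several layers are concatenated; a rough worked example with made-up sizes:

# Hypothetical sizes, for illustration only.
hidden_size = 1152                          # per-layer width of the vision tower
vision_feature_layer = [-24, -20, -12, -1]  # hypothetical multi-layer sampling
vision_hidden_size = hidden_size * len(vision_feature_layer)
assert vision_hidden_size == 4608           # projector input = 4 concatenated layers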

vllm/model_executor/models/pixtral.py

@@ -33,7 +33,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
 from vllm.multimodal.inputs import PlaceholderRange
 from vllm.multimodal.utils import (cached_get_tokenizer,
-                                   consecutive_placeholder_ranges)
+                                   consecutive_placeholder_ranges,
+                                   resolve_visual_encoder_outputs)
 from vllm.sequence import IntermediateTensors, SequenceData
 from vllm.transformers_utils.processor import cached_get_processor
 from vllm.utils import is_list_of
@@ -970,9 +971,18 @@ class PixtralHFTransformer(nn.Module):
         x: torch.Tensor,
         attention_mask: torch.Tensor,
         position_embeddings: torch.Tensor,
+        return_all_hidden_states: bool,
     ) -> torch.Tensor:
+        hidden_states_pool = []
         for layer in self.layers:
             x = layer(x, attention_mask, position_embeddings)
+            if return_all_hidden_states:
+                hidden_states_pool.append(x)
+
+        # If we have multiple feature sample layers, we return all hidden
+        # states in order and grab the ones we need by index.
+        if return_all_hidden_states:
+            return hidden_states_pool
         return x
@@ -990,6 +1000,7 @@ class PixtralHFVisionModel(nn.Module):
         super().__init__()

         self.config = config
+
         self.patch_conv = nn.Conv2d(
             in_channels=config.num_channels,
             out_channels=config.hidden_size,
@@ -1024,6 +1035,7 @@
     def forward(
         self,
         pixel_values: List[torch.Tensor],
+        feature_sample_layers: Optional[list[int]] = None,
     ) -> torch.Tensor:
         """
         Args:
@@ -1031,6 +1043,9 @@
                 in pixel_values. This means it will be a list of tensors
                 because multiple requests batched can have multiple images,
                 each with their own shape potentially
+            feature_sample_layers: Layer indices whose features should be
+                concatenated and used as the visual encoder output. If none
+                are provided, the last layer is used.

         Returns:
             image_features: tensor of token features for
@@ -1065,8 +1080,15 @@
             [p.shape[-2] * p.shape[-1] for p in patch_embeds_list],
             patch_embeds)

-        out = self.transformer(patch_embeds, attention_mask,
-                               position_embedding)
+        return_all_hidden_states = feature_sample_layers is not None
+        out = self.transformer(
+            patch_embeds,
+            attention_mask,
+            position_embedding,
+            return_all_hidden_states=return_all_hidden_states)
+
+        out = resolve_visual_encoder_outputs(out, feature_sample_layers, None,
+                                             self.config.num_hidden_layers)

         return out

vllm/model_executor/models/siglip.py

@@ -25,7 +25,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.multimodal.utils import (cached_get_tokenizer,
                                    consecutive_placeholder_ranges,
-                                   repeat_and_pad_placeholder_tokens)
+                                   repeat_and_pad_placeholder_tokens,
+                                   resolve_visual_encoder_outputs)
 from vllm.sequence import SequenceData

 from .utils import get_vit_attn_backend
@@ -450,11 +451,19 @@ class SiglipEncoder(nn.Module):
     def forward(
         self,
         inputs_embeds: torch.Tensor,
-    ) -> torch.Tensor:
+        return_all_hidden_states: bool,
+    ) -> Union[torch.Tensor, list[torch.Tensor]]:
+        hidden_states_pool = []
         hidden_states = inputs_embeds
         for encoder_layer in self.layers:
             hidden_states, _ = encoder_layer(hidden_states)
+            if return_all_hidden_states:
+                hidden_states_pool.append(hidden_states)
+
+        # If we have multiple feature sample layers, we return all hidden
+        # states in order and grab the ones we need by index.
+        if return_all_hidden_states:
+            return hidden_states_pool
         return hidden_states
@@ -509,6 +518,7 @@ class SiglipVisionTransformer(nn.Module):
         embed_dim = config.hidden_size

         self.embeddings = SiglipVisionEmbeddings(config)
+
         self.encoder = SiglipEncoder(
             config,
             quant_config=quant_config,
@@ -546,23 +556,33 @@
         self,
         pixel_values: torch.Tensor,
         interpolate_pos_encoding: bool = True,
+        feature_sample_layers: Optional[list[int]] = None,
     ) -> torch.Tensor:
         hidden_states = self.embeddings(
             pixel_values,
             interpolate_pos_encoding=interpolate_pos_encoding,
         )

-        encoder_outputs = self.encoder(inputs_embeds=hidden_states)
+        return_all_hidden_states = feature_sample_layers is not None

-        if self.post_layernorm is None:
-            return encoder_outputs
+        # Produces either the last layer output or all of the hidden states,
+        # depending on if we have feature_sample_layers or not
+        encoder_outputs = self.encoder(
+            inputs_embeds=hidden_states,
+            return_all_hidden_states=return_all_hidden_states,
+        )

-        last_hidden_state = self.post_layernorm(encoder_outputs)
+        # Handle post-norm (if applicable) and stacks feature layers if needed
+        encoder_outputs = resolve_visual_encoder_outputs(
+            encoder_outputs, feature_sample_layers, self.post_layernorm,
+            self.config.num_hidden_layers)

-        # TODO: add this back when pooled_output is used in inference
+        # TODO: add this back when pooled_output is used in inference.
         # if self.use_head:
-        # pooled_output = self.head(last_hidden_state)
+        # pooled_output = self.head(encoder_outputs)

-        return last_hidden_state
+        return encoder_outputs


 class SiglipVisionModel(nn.Module):
@@ -595,10 +615,12 @@ class SiglipVisionModel(nn.Module):
         self,
         pixel_values: torch.Tensor,
         interpolate_pos_encoding: bool = False,
+        feature_sample_layers: Optional[list[int]] = None,
     ) -> torch.Tensor:
         return self.vision_model(
             pixel_values=pixel_values,
             interpolate_pos_encoding=interpolate_pos_encoding,
+            feature_sample_layers=feature_sample_layers,
         )

     def load_weights(self, weights: Iterable[Tuple[str,

vllm/multimodal/utils.py

@@ -6,6 +6,7 @@ from typing import Any, List, Optional, Tuple, TypeVar, Union
 import numpy as np
 import numpy.typing as npt
+import torch
 from PIL import Image

 import vllm.envs as envs
@@ -392,6 +393,49 @@ def encode_video_base64(frames: npt.NDArray):
     return ",".join(base64_frames)


+def resolve_visual_encoder_outputs(
+    encoder_outputs: Union[torch.Tensor, list[torch.Tensor]],
+    feature_sample_layers: Optional[list[int]],
+    post_layer_norm: Optional[torch.nn.LayerNorm],
+    max_possible_layers: int,
+) -> torch.Tensor:
+    """Given the outputs of a visual encoder module that may correspond to
+    the output of the last layer, or a list of hidden states to be stacked,
+    handle post normalization and resolve it into a single output tensor.
+
+    Args:
+        encoder_outputs: Output of encoder's last layer or all hidden states.
+        feature_sample_layers: Optional layer indices to grab from the encoder
+            outputs; if provided, encoder outputs must be a list.
+        post_layer_norm: Post norm to apply to the output of the encoder.
+        max_possible_layers: Total layers in the fully loaded visual encoder.
+    """
+    if feature_sample_layers is None:
+        if post_layer_norm is not None:
+            return post_layer_norm(encoder_outputs)
+        return encoder_outputs
+
+    # Get the hidden states corresponding to the layer indices.
+    # Negative values are relative to the full visual encoder,
+    # so offset them depending on how many layers were loaded.
+    # NOTE: this assumes that encoder_outputs contains a list
+    # of hidden states in the same order as the encoder layers
+    # that produced them.
+    offset = max_possible_layers - len(encoder_outputs)
+    hs_pool = [
+        encoder_outputs[layer_idx]
+        if layer_idx >= 0 else encoder_outputs[layer_idx + offset]
+        for layer_idx in feature_sample_layers
+    ]
+
+    # Apply post-norm on the final hidden state if we are using it
+    uses_last_layer = feature_sample_layers[-1] in (len(hs_pool) - 1, -1)
+    if post_layer_norm is not None and uses_last_layer:
+        hs_pool[-1] = post_layer_norm(encoder_outputs)
+
+    return torch.cat(hs_pool, dim=-1)
+
+
 # Utilities for input processors
 _T = TypeVar("_T", str, int)
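A toy end-to-end check of the selection and offset logic above (the tensor shapes and layer indices are made up; only resolve_visual_encoder_outputs itself comes from this commit):

import torch
from vllm.multimodal.utils import resolve_visual_encoder_outputs

# Pretend the full encoder has 24 layers but only 23 were loaded because the
# deepest requested feature layer is -2, giving 23 per-layer hidden states.
hidden_states = [torch.full((1, 4, 8), float(i)) for i in range(1, 24)]

feats = resolve_visual_encoder_outputs(
    encoder_outputs=hidden_states,
    feature_sample_layers=[-12, -2],
    post_layer_norm=None,
    max_possible_layers=24,
)
# offset = 24 - 23 = 1, so -12 resolves to hidden_states[-11] (layer 13) and
# -2 resolves to hidden_states[-1] (layer 23); the two selected states are
# concatenated along the feature dimension.
assert feats.shape == (1, 4, 16)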