[Model] Add Support for Multimodal Granite Models (#10291)

Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>

commit 1cfde82ffd (parent f0e0238016)
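
For reference, a minimal offline-inference sketch of how a multimodal Granite (LLaVA-Next style) checkpoint could be exercised once this support lands. The model ID and prompt template below are placeholders rather than anything this commit ships; the dict-based prompt with `multi_modal_data` is vLLM's existing multimodal input API.

```python
from PIL import Image

from vllm import LLM, SamplingParams

# Placeholder checkpoint name; substitute an actual multimodal Granite model.
llm = LLM(model="ibm-granite/granite-vision-model")

image = Image.open("example.jpg").convert("RGB")
outputs = llm.generate(
    {
        # LLaVA-style image placeholder; the exact prompt template depends
        # on the checkpoint's chat/prompt format.
        "prompt": "<image>\nWhat is shown in this image?",
        "multi_modal_data": {"image": image},
    },
    SamplingParams(max_tokens=64),
)
print(outputs[0].outputs[0].text)
```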
@@ -21,7 +21,8 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.multimodal.utils import (cached_get_tokenizer,
                                    consecutive_placeholder_ranges,
-                                   repeat_and_pad_placeholder_tokens)
+                                   repeat_and_pad_placeholder_tokens,
+                                   resolve_visual_encoder_outputs)
 from vllm.sequence import SequenceData
 
 from .utils import get_vit_attn_backend
@@ -389,12 +390,20 @@ class CLIPEncoder(nn.Module):
             for layer_idx in range(num_hidden_layers)
         ])
 
-    def forward(self, inputs_embeds: torch.Tensor):
+    def forward(
+        self, inputs_embeds: torch.Tensor, return_all_hidden_states: bool
+    ) -> Union[torch.Tensor, list[torch.Tensor]]:
+        hidden_states_pool = []
         hidden_states = inputs_embeds
 
         for encoder_layer in self.layers:
             hidden_states = encoder_layer(hidden_states)
+            if return_all_hidden_states:
+                hidden_states_pool.append(hidden_states)
+        # If we have multiple feature sample layers, we return all hidden
+        # states in order and grab the ones we need by index.
+        if return_all_hidden_states:
+            return hidden_states_pool
         return hidden_states
 
 
@@ -419,6 +428,7 @@ class CLIPVisionTransformer(nn.Module):
         # NOTE: This typo of "layrnorm" is not fixed on purpose to match
         # the original transformers code and name of the model weights.
         self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
+
         self.encoder = CLIPEncoder(
             config=config,
             quant_config=quant_config,
@@ -446,16 +456,26 @@ class CLIPVisionTransformer(nn.Module):
     def forward(
         self,
         pixel_values: torch.Tensor,
+        feature_sample_layers: Optional[list[int]] = None,
     ) -> torch.Tensor:
 
         hidden_states = self.embeddings(pixel_values)
         hidden_states = self.pre_layrnorm(hidden_states)
-        hidden_states = self.encoder(inputs_embeds=hidden_states)
 
-        if self.post_layernorm is None:
-            return hidden_states
+        return_all_hidden_states = feature_sample_layers is not None
 
-        return self.post_layernorm(hidden_states)
+        # Produces either the last layer output or all of the hidden states,
+        # depending on if we have feature_sample_layers or not
+        encoder_outputs = self.encoder(
+            inputs_embeds=hidden_states,
+            return_all_hidden_states=return_all_hidden_states)
+
+        # Handle post-norm (if applicable) and stacks feature layers if needed
+        encoder_outputs = resolve_visual_encoder_outputs(
+            encoder_outputs, feature_sample_layers, self.post_layernorm,
+            self.config.num_hidden_layers)
+
+        return encoder_outputs
 
 
 class CLIPVisionModel(nn.Module):
@@ -478,11 +498,14 @@ class CLIPVisionModel(nn.Module):
             quant_config=quant_config,
             num_hidden_layers_override=num_hidden_layers_override,
             require_post_norm=require_post_norm,
-            prefix=f"{prefix}.vision_model",
-        )
+            prefix=f"{prefix}.vision_model")
 
-    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
-        return self.vision_model(pixel_values)
+    def forward(
+        self,
+        pixel_values: torch.Tensor,
+        feature_sample_layers: Optional[list[int]] = None,
+    ) -> torch.Tensor:
+        return self.vision_model(pixel_values, feature_sample_layers)
 
     @property
     def device(self):
@@ -204,7 +204,41 @@ def input_processor_for_llava(ctx: InputContext, inputs: DecoderOnlyInputs):
 
 class LlavaLikeConfig(Protocol):
     vision_config: PretrainedConfig
-    vision_feature_layer: int
+    vision_feature_layer: Union[int, List[int]]
+
+
+def _get_num_hidden_layers(hf_config: LlavaLikeConfig) -> int:
+    """Determine the number of hidden layers to initialize up to in the
+    visual encoder.
+
+    Args:
+        hf_config: Model config with vision feature layer(s).
+    """
+    feature_layers = hf_config.vision_feature_layer
+    num_hidden_layers = hf_config.vision_config.num_hidden_layers
+    # If we have one feature layer, initialize up to that layer
+    if isinstance(feature_layers, int):
+        return _get_layer_index(feature_layers, num_hidden_layers)
+    # If we have multiple feature layers, initialize up to the deepest one
+    elif isinstance(feature_layers, (list, tuple)):
+        return max(
+            _get_layer_index(idx, num_hidden_layers) for idx in feature_layers)
+    raise TypeError(f"vision_feature_layer type: {type(feature_layers)}"
+                    " is not supported")
+
+
+def _get_layer_index(feature_layer_index: int, num_hidden_layers: int) -> int:
+    """Given a signed vision feature layer, get the number of hidden layers
+    needed to leverage it.
+
+    Args:
+        feature_layer_index: Index of a required layer in the visual encoder.
+        num_hidden_layers: The total number of hidden layers in the visual
+            encoder.
+    """
+    if feature_layer_index < 0:
+        return num_hidden_layers + feature_layer_index + 1
+    return feature_layer_index + 1
 
 
 def init_vision_tower_for_llava(
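
To make the index arithmetic above concrete, here is a small standalone sketch (the encoder depth and feature-layer choices are hypothetical, not taken from a specific Granite config): negative indices are interpreted relative to the full encoder, and the tower only needs to be initialized up to the deepest requested layer.

```python
def _get_layer_index(feature_layer_index: int, num_hidden_layers: int) -> int:
    # Mirrors the helper added above: map a signed layer index to the number
    # of encoder layers that must be instantiated to reach it.
    if feature_layer_index < 0:
        return num_hidden_layers + feature_layer_index + 1
    return feature_layer_index + 1


num_hidden_layers = 24  # hypothetical visual encoder depth

for feature_layers in (-2, [3, 7, 15, 23], [-24, -20, -12, -1]):
    if isinstance(feature_layers, int):
        needed = _get_layer_index(feature_layers, num_hidden_layers)
    else:
        needed = max(
            _get_layer_index(idx, num_hidden_layers) for idx in feature_layers)
    print(feature_layers, "->", needed)
# -2 -> 23
# [3, 7, 15, 23] -> 24
# [-24, -20, -12, -1] -> 24
```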
@@ -216,13 +250,8 @@ def init_vision_tower_for_llava(
 ):
     vision_config = hf_config.vision_config
 
-    # Initialize the vision tower only up to the required feature layer
-    vision_feature_layer = hf_config.vision_feature_layer
-    if vision_feature_layer < 0:
-        num_hidden_layers = hf_config.vision_config.num_hidden_layers \
-            + vision_feature_layer + 1
-    else:
-        num_hidden_layers = vision_feature_layer + 1
+    # Initialize the vision tower only up to the deepest required feature layer
+    num_hidden_layers = _get_num_hidden_layers(hf_config)
 
     if isinstance(vision_config, CLIPVisionConfig):
         return CLIPVisionModel(
@@ -288,6 +288,21 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
         pooler_config = vllm_config.model_config.pooler_config
         multimodal_config = vllm_config.model_config.multimodal_config
 
+        vision_feature_layer = config.vision_feature_layer
+        # Determine the layer up to which we will initialize the vision tower
+        if isinstance(vision_feature_layer, int):
+            vision_hidden_size = config.vision_config.hidden_size
+            self.feature_sample_layers = None
+        # Used for multimodal granite models to control encoder outputs
+        elif isinstance(vision_feature_layer, (list, tuple)):
+            vision_hidden_size = config.vision_config.hidden_size * len(
+                vision_feature_layer)
+            self.feature_sample_layers = vision_feature_layer
+        else:
+            raise TypeError(
+                f"vision_feature_layer type: {type(vision_feature_layer)}"
+                " is not supported")
+
         self.config = config
         self.multimodal_config = multimodal_config
 
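
As a rough illustration of the sizing logic above (hypothetical config values, not from a real checkpoint): when vision_feature_layer is a list, the sampled hidden states are concatenated along the feature dimension, so the projector's input width scales with the number of sampled layers.

```python
# Hypothetical values for illustration only.
per_layer_hidden_size = 1024
vision_feature_layer = [-24, -20, -12, -1]

if isinstance(vision_feature_layer, int):
    vision_hidden_size = per_layer_hidden_size
    feature_sample_layers = None
else:
    vision_hidden_size = per_layer_hidden_size * len(vision_feature_layer)
    feature_sample_layers = list(vision_feature_layer)

# The multimodal projector then maps this width to text_config.hidden_size.
print(vision_hidden_size)  # 4096
```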
@@ -300,7 +315,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
         self.image_newline = nn.Parameter(
             torch.empty(config.text_config.hidden_size))
         self.multi_modal_projector = LlavaMultiModalProjector(
-            vision_hidden_size=config.vision_config.hidden_size,
+            vision_hidden_size=vision_hidden_size,
             text_hidden_size=config.text_config.hidden_size,
             projector_hidden_act=config.projector_hidden_act)
 
@@ -419,7 +434,8 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
 
         # NOTE: we skip the step to select the vision feature layer since
         # this is already done inside the vision tower
-        image_features = vision_tower(pixel_values)
+        image_features = vision_tower(
+            pixel_values, feature_sample_layers=self.feature_sample_layers)
 
         return self._select_image_features(
             image_features,
@@ -33,7 +33,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
 from vllm.multimodal.inputs import PlaceholderRange
 from vllm.multimodal.utils import (cached_get_tokenizer,
-                                   consecutive_placeholder_ranges)
+                                   consecutive_placeholder_ranges,
+                                   resolve_visual_encoder_outputs)
 from vllm.sequence import IntermediateTensors, SequenceData
 from vllm.transformers_utils.processor import cached_get_processor
 from vllm.utils import is_list_of
@@ -970,9 +971,18 @@ class PixtralHFTransformer(nn.Module):
         x: torch.Tensor,
         attention_mask: torch.Tensor,
         position_embeddings: torch.Tensor,
+        return_all_hidden_states: bool,
     ) -> torch.Tensor:
+        hidden_states_pool = []
+
         for layer in self.layers:
             x = layer(x, attention_mask, position_embeddings)
+            if return_all_hidden_states:
+                hidden_states_pool.append(x)
+        # If we have multiple feature sample layers, we return all hidden
+        # states in order and grab the ones we need by index.
+        if return_all_hidden_states:
+            return hidden_states_pool
         return x
 
 
@@ -990,6 +1000,7 @@ class PixtralHFVisionModel(nn.Module):
         super().__init__()
 
         self.config = config
+
         self.patch_conv = nn.Conv2d(
             in_channels=config.num_channels,
             out_channels=config.hidden_size,
@@ -1024,6 +1035,7 @@ class PixtralHFVisionModel(nn.Module):
     def forward(
         self,
         pixel_values: List[torch.Tensor],
+        feature_sample_layers: Optional[list[int]] = None,
     ) -> torch.Tensor:
         """
         Args:
@@ -1031,6 +1043,9 @@ class PixtralHFVisionModel(nn.Module):
                 in pixel_values. This means it will be a list of tensors
                 because multiple requests batched can have multiple images,
                 each with their own shape potentially
+            feature_sample_layers: Layer indices whose features should be
+                concatenated and used as the visual encoder output. If none
+                are provided, the last layer is used.
 
         Returns:
             image_features: tensor of token features for
@@ -1065,8 +1080,15 @@ class PixtralHFVisionModel(nn.Module):
             [p.shape[-2] * p.shape[-1] for p in patch_embeds_list],
             patch_embeds)
 
-        out = self.transformer(patch_embeds, attention_mask,
-                               position_embedding)
+        return_all_hidden_states = feature_sample_layers is not None
+        out = self.transformer(
+            patch_embeds,
+            attention_mask,
+            position_embedding,
+            return_all_hidden_states=return_all_hidden_states)
+
+        out = resolve_visual_encoder_outputs(out, feature_sample_layers, None,
+                                             self.config.num_hidden_layers)
 
         return out
 
@@ -25,7 +25,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.multimodal.utils import (cached_get_tokenizer,
                                    consecutive_placeholder_ranges,
-                                   repeat_and_pad_placeholder_tokens)
+                                   repeat_and_pad_placeholder_tokens,
+                                   resolve_visual_encoder_outputs)
 from vllm.sequence import SequenceData
 
 from .utils import get_vit_attn_backend
@@ -450,11 +451,19 @@ class SiglipEncoder(nn.Module):
     def forward(
         self,
         inputs_embeds: torch.Tensor,
-    ) -> torch.Tensor:
+        return_all_hidden_states: bool,
+    ) -> Union[torch.Tensor, list[torch.Tensor]]:
+        hidden_states_pool = []
         hidden_states = inputs_embeds
 
         for encoder_layer in self.layers:
             hidden_states, _ = encoder_layer(hidden_states)
+            if return_all_hidden_states:
+                hidden_states_pool.append(hidden_states)
+        # If we have multiple feature sample layers, we return all hidden
+        # states in order and grab the ones we need by index.
+        if return_all_hidden_states:
+            return hidden_states_pool
         return hidden_states
 
 
@@ -509,6 +518,7 @@ class SiglipVisionTransformer(nn.Module):
         embed_dim = config.hidden_size
 
         self.embeddings = SiglipVisionEmbeddings(config)
+
         self.encoder = SiglipEncoder(
             config,
             quant_config=quant_config,
@@ -546,23 +556,33 @@ class SiglipVisionTransformer(nn.Module):
         self,
         pixel_values: torch.Tensor,
         interpolate_pos_encoding: bool = True,
+        feature_sample_layers: Optional[list[int]] = None,
     ) -> torch.Tensor:
 
         hidden_states = self.embeddings(
             pixel_values,
             interpolate_pos_encoding=interpolate_pos_encoding,
         )
 
-        encoder_outputs = self.encoder(inputs_embeds=hidden_states)
+        return_all_hidden_states = feature_sample_layers is not None
 
-        if self.post_layernorm is None:
-            return encoder_outputs
+        # Produces either the last layer output or all of the hidden states,
+        # depending on if we have feature_sample_layers or not
+        encoder_outputs = self.encoder(
+            inputs_embeds=hidden_states,
+            return_all_hidden_states=return_all_hidden_states,
+        )
 
-        last_hidden_state = self.post_layernorm(encoder_outputs)
-        # TODO: add this back when pooled_output is used in inference
+        # Handle post-norm (if applicable) and stacks feature layers if needed
+        encoder_outputs = resolve_visual_encoder_outputs(
+            encoder_outputs, feature_sample_layers, self.post_layernorm,
+            self.config.num_hidden_layers)
+
+        # TODO: add this back when pooled_output is used in inference.
         # if self.use_head:
-        #     pooled_output = self.head(last_hidden_state)
+        #     pooled_output = self.head(encoder_outputs)
 
-        return last_hidden_state
+        return encoder_outputs
 
 
 class SiglipVisionModel(nn.Module):
@@ -595,10 +615,12 @@ class SiglipVisionModel(nn.Module):
         self,
         pixel_values: torch.Tensor,
         interpolate_pos_encoding: bool = False,
+        feature_sample_layers: Optional[list[int]] = None,
     ) -> torch.Tensor:
         return self.vision_model(
             pixel_values=pixel_values,
             interpolate_pos_encoding=interpolate_pos_encoding,
+            feature_sample_layers=feature_sample_layers,
         )
 
     def load_weights(self, weights: Iterable[Tuple[str,
@@ -6,6 +6,7 @@ from typing import Any, List, Optional, Tuple, TypeVar, Union
 
 import numpy as np
 import numpy.typing as npt
+import torch
 from PIL import Image
 
 import vllm.envs as envs
@@ -392,6 +393,49 @@ def encode_video_base64(frames: npt.NDArray):
     return ",".join(base64_frames)
 
 
+def resolve_visual_encoder_outputs(
+    encoder_outputs: Union[torch.Tensor, list[torch.Tensor]],
+    feature_sample_layers: Optional[list[int]],
+    post_layer_norm: Optional[torch.nn.LayerNorm],
+    max_possible_layers: int,
+) -> torch.Tensor:
+    """Given the outputs of a visual encoder module that may correspond to the
+    output of the last layer, or a list of hidden states to be stacked,
+    handle post normalization and resolve it into a single output tensor.
+
+    Args:
+        encoder_outputs: Output of encoder's last layer or all hidden states.
+        feature_sample_layers: Optional layer indices to grab from the encoder
+            outputs; if provided, encoder outputs must be a list.
+        post_layer_norm: Post norm to apply to the output of the encoder.
+        max_possible_layers: Total layers in the fully loaded visual encoder.
+
+    """
+    if feature_sample_layers is None:
+        if post_layer_norm is not None:
+            return post_layer_norm(encoder_outputs)
+        return encoder_outputs
+
+    # Get the hidden states corresponding to the layer indices.
+    # Negative values are relative to the full visual encoder,
+    # so offset them depending on how many layers were loaded.
+    # NOTE: this assumes that encoder_outputs contains a list
+    # of hidden states in the same order as the encoder layers
+    # that produced them.
+    offset = max_possible_layers - len(encoder_outputs)
+    hs_pool = [
+        encoder_outputs[layer_idx]
+        if layer_idx >= 0 else encoder_outputs[layer_idx + offset]
+        for layer_idx in feature_sample_layers
+    ]
+
+    # Apply post-norm on the final hidden state if we are using it
+    uses_last_layer = feature_sample_layers[-1] in (len(hs_pool) - 1, -1)
+    if post_layer_norm is not None and uses_last_layer:
+        hs_pool[-1] = post_layer_norm(encoder_outputs[-1])
+    return torch.cat(hs_pool, dim=-1)
+
+
 # Utilities for input processors
 _T = TypeVar("_T", str, int)
 
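
A quick usage sketch for the new helper, using the import path added in this commit and random tensors with hypothetical shapes: when feature_sample_layers is set, the function indexes into the pooled hidden states (offsetting negative indices by however many encoder layers were actually loaded) and concatenates them on the last dimension.

```python
import torch

from vllm.multimodal.utils import resolve_visual_encoder_outputs

max_possible_layers = 24
# Pretend the encoder returned one hidden state per loaded layer,
# each of shape (batch=1, tokens=5, hidden=8).
hidden_states = [torch.randn(1, 5, 8) for _ in range(max_possible_layers)]

features = resolve_visual_encoder_outputs(
    encoder_outputs=hidden_states,
    feature_sample_layers=[-24, -20, -12, -1],
    post_layer_norm=None,
    max_possible_layers=max_possible_layers,
)
print(features.shape)  # torch.Size([1, 5, 32]) - four layers concatenated
```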