[Bugfix] Fix Positive Feature Layers in Llava Models (#13514)

Signed-off-by: Alex-Brooks <Alex.brooks@ibm.com>
Alex Brooks 2025-02-19 01:50:07 -07:00 committed by GitHub
parent fdc5df6f54
commit 983a40a8bb
6 changed files with 44 additions and 9 deletions
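
For context: LLaVA-style configs treat `vision_feature_layer` as an index into the HF-style hidden-states list, where index 0 is the embedding output and index k is the output of encoder layer k. The vision encoder pools previously collected only the per-layer outputs, so positive indices landed one layer too deep; this commit seeds each pool with the input embeddings and adjusts the index/offset math to match. Below is a standalone sketch (not vLLM code) of the indexing convention the fix aligns with:

# A standalone sketch (not vLLM code) of the hidden-states indexing
# convention: index 0 holds the embedding output, index k holds the
# output of encoder layer k, and negative indices count back from the end.
num_layers = 24
hidden_states = ["embeddings"] + [f"layer_{k}_output" for k in range(1, num_layers + 1)]

assert hidden_states[1] == "layer_1_output"               # positive k -> layer k
assert hidden_states[-1] == f"layer_{num_layers}_output"  # -1 -> last layer
assert hidden_states[-2] == hidden_states[num_layers - 1] # -2 and 23 name the same state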


@@ -0,0 +1,34 @@
# SPDX-License-Identifier: Apache-2.0
import pytest
import torch
from vllm.model_executor.models.vision import resolve_visual_encoder_outputs
@pytest.mark.parametrize(
("feature_sample_layers", "num_layers_loaded", "max_possible_layers",
"expected_features"),
[
# All layers loaded
([1, 10], 10, 10, [1, 10]),
([-10, -1], 10, 10, [1, 10]),
# Some layers not loaded
([1, 10], 10, 20, [1, 10]),
([-20, -11], 10, 20, [1, 10]),
])
def test_resolve_visual_encoder_outputs(feature_sample_layers,
num_layers_loaded, max_possible_layers,
expected_features):
"""
Test that offsets are correctly handled for vision feature layers.
"""
encoder_outputs = [
torch.tensor([idx]) for idx in range(num_layers_loaded + 1)
]
output_tensor = resolve_visual_encoder_outputs(
encoder_outputs=encoder_outputs,
feature_sample_layers=feature_sample_layers,
post_layer_norm=None,
max_possible_layers=max_possible_layers)
assert torch.equal(torch.tensor(expected_features), output_tensor)


@@ -251,7 +251,7 @@ class CLIPEncoder(nn.Module):
def forward(
self, inputs_embeds: torch.Tensor, return_all_hidden_states: bool
) -> Union[torch.Tensor, list[torch.Tensor]]:
- hidden_states_pool = []
+ hidden_states_pool = [inputs_embeds]
hidden_states = inputs_embeds
for encoder_layer in self.layers:


@@ -428,7 +428,7 @@ def _get_num_hidden_layers(hf_config: LlavaLikeConfig) -> int:
def _get_layer_index(feature_layer_index: int, num_hidden_layers: int) -> int:
"""Given an signed vision feature layer, get the number of hidden layers
"""Given a signed vision feature layer, get the number of hidden layers
needed to leverage it.
Args:
@@ -438,7 +438,7 @@ def _get_layer_index(feature_layer_index: int, num_hidden_layers: int) -> int:
"""
if feature_layer_index < 0:
return num_hidden_layers + feature_layer_index + 1
- return feature_layer_index + 1
+ return feature_layer_index
def init_vision_tower_for_llava(
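
The `_get_layer_index` change is the crux for positive indices: once the embedding output sits at index 0 of the pool, a positive feature layer k needs exactly k encoder layers, while negative values are still resolved against the encoder's full depth. A minimal standalone sketch of the corrected mapping (renamed here so it is clearly not the source in the diff):

# Standalone sketch of the corrected mapping; the old "+ 1" branch for
# positive indices would have requested one layer too many.
def get_layer_index(feature_layer_index: int, num_hidden_layers: int) -> int:
    if feature_layer_index < 0:
        return num_hidden_layers + feature_layer_index + 1
    return feature_layer_index

# In a 24-layer tower, feature layer -2 and feature layer 23 name the same
# hidden state, so both should require 23 loaded layers.
assert get_layer_index(-2, 24) == get_layer_index(23, 24) == 23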


@@ -969,7 +969,7 @@ class PixtralHFTransformer(nn.Module):
position_embeddings: torch.Tensor,
return_all_hidden_states: bool,
) -> torch.Tensor:
- hidden_states_pool = []
+ hidden_states_pool = [x]
for layer in self.layers:
x = layer(x, attention_mask, position_embeddings)


@@ -378,7 +378,7 @@ class SiglipEncoder(nn.Module):
inputs_embeds: torch.Tensor,
return_all_hidden_states: bool,
) -> Union[torch.Tensor, list[torch.Tensor]]:
- hidden_states_pool = []
+ hidden_states_pool = [inputs_embeds]
hidden_states = inputs_embeds
for encoder_layer in self.layers:


@@ -132,10 +132,11 @@ def resolve_visual_encoder_outputs(
# Get the hidden states corresponding to the layer indices.
# Negative values are relative to the full visual encoder,
# so offset them depending on how many layers were loaded.
- # NOTE: this assumes that encoder_outputs contains a list
- # of hidden states in the same order as the encoder layers
- # that produced them.
- offset = max_possible_layers - len(encoder_outputs)
+ # NOTE: this assumes that encoder_outputs is a list containing
+ # the inputs to the visual encoder, followed by the hidden states
+ # of each layer.
+ num_loaded_layers = len(encoder_outputs) - 1
+ offset = max_possible_layers - num_loaded_layers
hs_pool = [
encoder_outputs[layer_idx]
if layer_idx >= 0 else encoder_outputs[layer_idx + offset]
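
To make the offset handling concrete, here is a small worked sketch with illustrative values (not taken from the commit), mirroring the selection logic above when only part of the tower is loaded:

# Worked sketch: 10 of a possible 20 layers are loaded, so the pool holds
# the embeddings plus 10 layer outputs, and negative indices are shifted
# by the number of unloaded layers.
encoder_outputs = ["embeddings"] + [f"layer_{k}_output" for k in range(1, 11)]
max_possible_layers = 20

num_loaded_layers = len(encoder_outputs) - 1      # 10
offset = max_possible_layers - num_loaded_layers  # 20 - 10 = 10

feature_sample_layers = [1, -11]  # both resolve to layers that were loaded
hs_pool = [
    encoder_outputs[idx] if idx >= 0 else encoder_outputs[idx + offset]
    for idx in feature_sample_layers
]
assert hs_pool == ["layer_1_output", "layer_10_output"]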