[Bugfix] Fix Positive Feature Layers in Llava Models (#13514)
Signed-off-by: Alex-Brooks <Alex.brooks@ibm.com>
This commit is contained in:
parent
fdc5df6f54
commit
983a40a8bb
34
tests/models/test_vision.py
Normal file
34
tests/models/test_vision.py
Normal file
@@ -0,0 +1,34 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.models.vision import resolve_visual_encoder_outputs
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    ("feature_sample_layers", "num_layers_loaded", "max_possible_layers",
     "expected_features"),
    [
        # Every possible layer is loaded, so positive and negative
        # indices must pick the same hidden states.
        ([1, 10], 10, 10, [1, 10]),
        ([-10, -1], 10, 10, [1, 10]),
        # Only 10 of the 20 possible layers are loaded; negative indices
        # are still expressed relative to the full 20-layer encoder.
        ([1, 10], 10, 20, [1, 10]),
        ([-20, -11], 10, 20, [1, 10]),
    ],
)
def test_resolve_visual_encoder_outputs(feature_sample_layers,
                                        num_layers_loaded, max_possible_layers,
                                        expected_features):
    """
    Check that signed vision feature layer indices select the right states.

    A fake list of encoder outputs is built with one entry more than the
    number of loaded layers, each entry holding its own position in the
    list. The tensor returned by ``resolve_visual_encoder_outputs`` must
    then contain exactly the positions listed in ``expected_features``.
    """
    # One single-element tensor per slot; the stored value is the slot's
    # index, which makes the selected layers readable from the result.
    fake_hidden_states = [
        torch.tensor([slot]) for slot in range(num_layers_loaded + 1)
    ]

    resolved = resolve_visual_encoder_outputs(
        encoder_outputs=fake_hidden_states,
        feature_sample_layers=feature_sample_layers,
        post_layer_norm=None,
        max_possible_layers=max_possible_layers,
    )

    wanted = torch.tensor(expected_features)
    assert torch.equal(wanted, resolved)
|
@@ -251,7 +251,7 @@ class CLIPEncoder(nn.Module):
|
||||
def forward(
|
||||
self, inputs_embeds: torch.Tensor, return_all_hidden_states: bool
|
||||
) -> Union[torch.Tensor, list[torch.Tensor]]:
|
||||
hidden_states_pool = []
|
||||
hidden_states_pool = [inputs_embeds]
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
for encoder_layer in self.layers:
|
||||
|
@@ -428,7 +428,7 @@ def _get_num_hidden_layers(hf_config: LlavaLikeConfig) -> int:
|
||||
|
||||
|
||||
def _get_layer_index(feature_layer_index: int, num_hidden_layers: int) -> int:
|
||||
"""Given an signed vision feature layer, get the number of hidden layers
|
||||
"""Given a signed vision feature layer, get the number of hidden layers
|
||||
needed to leverage it.
|
||||
|
||||
Args:
|
||||
@@ -438,7 +438,7 @@ def _get_layer_index(feature_layer_index: int, num_hidden_layers: int) -> int:
|
||||
"""
|
||||
if feature_layer_index < 0:
|
||||
return num_hidden_layers + feature_layer_index + 1
|
||||
return feature_layer_index + 1
|
||||
return feature_layer_index
|
||||
|
||||
|
||||
def init_vision_tower_for_llava(
|
||||
|
@@ -969,7 +969,7 @@ class PixtralHFTransformer(nn.Module):
|
||||
position_embeddings: torch.Tensor,
|
||||
return_all_hidden_states: bool,
|
||||
) -> torch.Tensor:
|
||||
hidden_states_pool = []
|
||||
hidden_states_pool = [x]
|
||||
|
||||
for layer in self.layers:
|
||||
x = layer(x, attention_mask, position_embeddings)
|
||||
|
@@ -378,7 +378,7 @@ class SiglipEncoder(nn.Module):
|
||||
inputs_embeds: torch.Tensor,
|
||||
return_all_hidden_states: bool,
|
||||
) -> Union[torch.Tensor, list[torch.Tensor]]:
|
||||
hidden_states_pool = []
|
||||
hidden_states_pool = [inputs_embeds]
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
for encoder_layer in self.layers:
|
||||
|
@@ -132,10 +132,11 @@ def resolve_visual_encoder_outputs(
|
||||
# Get the hidden states corresponding to the layer indices.
|
||||
# Negative values are relative to the full visual encoder,
|
||||
# so offset them depending on how many layers were loaded.
|
||||
# NOTE: this assumes that encoder_outputs contains a list
|
||||
# of hidden states in the same order as the encoder layers
|
||||
# that produced them.
|
||||
offset = max_possible_layers - len(encoder_outputs)
|
||||
# NOTE: this assumes that encoder_outputs is a list containing
|
||||
# the inputs to the visual encoder, followed by the hidden states
|
||||
# of each layer.
|
||||
num_loaded_layers = len(encoder_outputs) - 1
|
||||
offset = max_possible_layers - num_loaded_layers
|
||||
hs_pool = [
|
||||
encoder_outputs[layer_idx]
|
||||
if layer_idx >= 0 else encoder_outputs[layer_idx + offset]
|
||||
|
Loading…
x
Reference in New Issue
Block a user