[Model] LLaVA model refactor (#4910)
This commit is contained in:
parent
b57e6c5949
commit
6287537a0c
@ -1,4 +1,4 @@
|
|||||||
from typing import Iterable, List, Optional, Tuple
|
from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
@ -67,6 +67,21 @@ def _merge_vision_embeddings(input_ids: torch.Tensor,
|
|||||||
return inputs_embeds
|
return inputs_embeds
|
||||||
|
|
||||||
|
|
||||||
|
class LlavaImagePixelInputs(TypedDict):
|
||||||
|
type: Literal["pixel_values"]
|
||||||
|
data: torch.Tensor
|
||||||
|
"""Shape: (batch_size, num_channels, height, width)"""
|
||||||
|
|
||||||
|
|
||||||
|
class LlavaImageFeatureInputs(TypedDict):
|
||||||
|
type: Literal["image_features"]
|
||||||
|
data: torch.Tensor
|
||||||
|
"""Shape: (batch_size, image_feature_size, hidden_size)"""
|
||||||
|
|
||||||
|
|
||||||
|
LlavaImageInputs = Union[LlavaImagePixelInputs, LlavaImageFeatureInputs]
|
||||||
|
|
||||||
|
|
||||||
class LlavaForConditionalGeneration(VisionLanguageModelBase):
|
class LlavaForConditionalGeneration(VisionLanguageModelBase):
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
@ -102,6 +117,90 @@ class LlavaForConditionalGeneration(VisionLanguageModelBase):
|
|||||||
config.vocab_size, logit_scale)
|
config.vocab_size, logit_scale)
|
||||||
self.sampler = Sampler()
|
self.sampler = Sampler()
|
||||||
|
|
||||||
|
def _validate_image_data(self, data: torch.Tensor) -> torch.Tensor:
|
||||||
|
if list(data.shape[1:]) != list(
|
||||||
|
self.vision_language_config.image_input_shape[1:]):
|
||||||
|
raise ValueError(
|
||||||
|
f"The expected image tensor shape is batch dimension plus "
|
||||||
|
f"{self.vision_language_config.image_input_shape[1:]}. "
|
||||||
|
f"You supplied {data.shape}. "
|
||||||
|
f"If you are using vLLM's entrypoint, make sure your "
|
||||||
|
f"supplied image input is consistent with "
|
||||||
|
f"image_input_shape in engine args.")
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
def _parse_and_validate_image_input(
|
||||||
|
self, data: object) -> Optional[LlavaImageInputs]:
|
||||||
|
expected_input_type = self.vision_language_config.image_input_type
|
||||||
|
ImageInputType = VisionLanguageConfig.ImageInputType
|
||||||
|
|
||||||
|
if data is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if expected_input_type == ImageInputType.PIXEL_VALUES:
|
||||||
|
if not isinstance(data, torch.Tensor):
|
||||||
|
raise TypeError("Image pixel vector should be a tensor, "
|
||||||
|
f"but received type: {type(data)}")
|
||||||
|
|
||||||
|
return LlavaImagePixelInputs(
|
||||||
|
type="pixel_values",
|
||||||
|
data=self._validate_image_data(data),
|
||||||
|
)
|
||||||
|
elif expected_input_type == ImageInputType.IMAGE_FEATURES:
|
||||||
|
if not isinstance(data, torch.Tensor):
|
||||||
|
raise TypeError("Image feature vector should be a tensor, "
|
||||||
|
f"but received type: {type(data)}")
|
||||||
|
|
||||||
|
return LlavaImageFeatureInputs(
|
||||||
|
type="image_features",
|
||||||
|
data=self._validate_image_data(data),
|
||||||
|
)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _select_image_features(self, image_features: torch.Tensor, *,
|
||||||
|
strategy: str) -> torch.Tensor:
|
||||||
|
# Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421 # noqa
|
||||||
|
if strategy == "default":
|
||||||
|
return image_features[:, 1:]
|
||||||
|
elif strategy == "full":
|
||||||
|
return image_features
|
||||||
|
|
||||||
|
raise ValueError(f"Unexpected select feature strategy: {strategy}")
|
||||||
|
|
||||||
|
def _image_pixels_to_features(self, vision_tower: CLIPVisionModel,
|
||||||
|
pixel_values: torch.Tensor) -> torch.Tensor:
|
||||||
|
# TODO(xwjiang): Maybe port minimal CLIPVisionModel over.
|
||||||
|
image_outputs = vision_tower(pixel_values.to(vision_tower.device),
|
||||||
|
output_hidden_states=True)
|
||||||
|
|
||||||
|
image_features = image_outputs.hidden_states[
|
||||||
|
self.config.vision_feature_layer]
|
||||||
|
|
||||||
|
return self._select_image_features(
|
||||||
|
image_features,
|
||||||
|
strategy=self.config.vision_feature_select_strategy,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _process_image_pixels(self,
|
||||||
|
inputs: LlavaImagePixelInputs) -> torch.Tensor:
|
||||||
|
assert self.vision_tower is not None
|
||||||
|
|
||||||
|
pixel_values = inputs["data"]
|
||||||
|
|
||||||
|
return self._image_pixels_to_features(self.vision_tower, pixel_values)
|
||||||
|
|
||||||
|
def _process_image_input(self,
|
||||||
|
image_input: LlavaImageInputs) -> torch.Tensor:
|
||||||
|
if image_input["type"] == "pixel_values":
|
||||||
|
assert self.vision_tower is not None
|
||||||
|
image_features = self._process_image_pixels(image_input)
|
||||||
|
else:
|
||||||
|
image_features = image_input["data"]
|
||||||
|
|
||||||
|
return self.multi_modal_projector(image_features)
|
||||||
|
|
||||||
def forward(self,
|
def forward(self,
|
||||||
input_ids: torch.Tensor,
|
input_ids: torch.Tensor,
|
||||||
positions: torch.Tensor,
|
positions: torch.Tensor,
|
||||||
@ -144,42 +243,20 @@ class LlavaForConditionalGeneration(VisionLanguageModelBase):
|
|||||||
For PIXEL_VALUES, expecting [1, 3, 336, 336].
|
For PIXEL_VALUES, expecting [1, 3, 336, 336].
|
||||||
For IMAGE_FEATURES, expecting [1, 576, 1024].
|
For IMAGE_FEATURES, expecting [1, 576, 1024].
|
||||||
"""
|
"""
|
||||||
if image_input is not None:
|
parsed_image_input = self._parse_and_validate_image_input(image_input)
|
||||||
if list(image_input.shape[1:]) != list(
|
|
||||||
self.vision_language_config.image_input_shape[1:]):
|
if parsed_image_input is not None:
|
||||||
raise ValueError(
|
vision_embeddings = self._process_image_input(parsed_image_input)
|
||||||
f"The expected image tensor shape is batch dimension "
|
|
||||||
f"plus "
|
|
||||||
f"{self.vision_language_config.image_input_shape[1:]}."
|
|
||||||
f" You supplied {image_input.shape}. "
|
|
||||||
f"If you are using vLLM's entrypoint, make sure your "
|
|
||||||
f"supplied image input is consistent with "
|
|
||||||
f"image_input_shape in engine args.")
|
|
||||||
if self.vision_tower is not None:
|
|
||||||
# TODO(xwjiang): Maybe port minimal CLIPVisionModel over.
|
|
||||||
image_outputs = self.vision_tower(image_input,
|
|
||||||
output_hidden_states=True)
|
|
||||||
image_features = image_outputs.hidden_states[
|
|
||||||
self.config.vision_feature_layer]
|
|
||||||
# Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421 # noqa
|
|
||||||
if self.config.vision_feature_select_strategy == "default":
|
|
||||||
image_features = image_features[:, 1:]
|
|
||||||
elif self.config.vision_feature_select_strategy == "full":
|
|
||||||
image_features = image_features
|
|
||||||
else:
|
|
||||||
raise ValueError(
|
|
||||||
f"Unexpected select feature strategy: "
|
|
||||||
f"{self.config.vision_feature_select_strategy}")
|
|
||||||
else:
|
|
||||||
image_features = image_input
|
|
||||||
vision_embeddings = self.multi_modal_projector(image_features)
|
|
||||||
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
|
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
|
||||||
|
|
||||||
inputs_embeds = _merge_vision_embeddings(
|
inputs_embeds = _merge_vision_embeddings(
|
||||||
input_ids, inputs_embeds, vision_embeddings,
|
input_ids, inputs_embeds, vision_embeddings,
|
||||||
self.vision_language_config.image_token_id)
|
self.vision_language_config.image_token_id)
|
||||||
|
|
||||||
input_ids = None
|
input_ids = None
|
||||||
else:
|
else:
|
||||||
inputs_embeds = None
|
inputs_embeds = None
|
||||||
|
|
||||||
hidden_states = self.language_model(input_ids,
|
hidden_states = self.language_model(input_ids,
|
||||||
positions,
|
positions,
|
||||||
kv_caches,
|
kv_caches,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user