[Bugfix] Multi-video inference on LLaVA-Onevision (#15082)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: Isotr0py <2037008807@qq.com>
commit 27261e40a6
parent e3f813c33b
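For context: the fix targets prompts that attach more than one video to a single request. A minimal sketch of such a request through vLLM's offline API is below; the checkpoint name, chat template, frame counts, and resolutions are illustrative assumptions, not part of this commit.

# Hypothetical multi-video request to LLaVA-Onevision (sketch only).
# Assumes numpy frame arrays of shape (num_frames, height, width, 3);
# the checkpoint and prompt format below are assumptions for illustration.
import numpy as np
from vllm import LLM, SamplingParams

llm = LLM(
    model="llava-hf/llava-onevision-qwen2-0.5b-ov-hf",  # assumed checkpoint
    limit_mm_per_prompt={"video": 2},  # allow two videos per prompt
)

# Two clips with different frame counts.
video_a = np.zeros((8, 336, 336, 3), dtype=np.uint8)
video_b = np.zeros((16, 336, 336, 3), dtype=np.uint8)

outputs = llm.generate(
    {
        "prompt": "<|im_start|>user <video><video>\nCompare the two clips."
                  "<|im_end|><|im_start|>assistant\n",
        "multi_modal_data": {"video": [video_a, video_b]},
    },
    SamplingParams(max_tokens=64),
)
print(outputs[0].outputs[0].text)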
@@ -25,7 +25,6 @@ from vllm.multimodal.parse import (ImageSize, MultiModalDataItems,
 from vllm.multimodal.processing import PromptReplacement, PromptUpdate
 from vllm.multimodal.profiling import ProcessorInputs
 from vllm.sequence import IntermediateTensors
-from vllm.utils import is_list_of
 
 from .clip import CLIPVisionModel
 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
@@ -44,7 +43,7 @@ class LlavaOnevisionVideoPixelInputs(TypedDict):
     type: Literal["pixel_values_videos"]
     pixel_values_videos: Union[torch.Tensor, list[torch.Tensor]]
     """
-    Shape: `(batch_size, num_videos, num_frames, num_channels, height, width)`
+    Shape: `(batch_size * num_videos, num_frames, num_channels, height, width)`
 
     Note that `num_videos` may be different for each batch, and 'num_frames'
     may be different for each video, in which case the data is passed as a
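The docstring change above reflects that the model now receives video pixels with the batch and per-request video dimensions already merged into one leading dimension. A minimal shape check with made-up sizes:

import torch

# Illustrative sizes only: 2 requests x 3 videos, 8 frames of 3x384x384 pixels.
batch_size, num_videos, num_frames = 2, 3, 8
pixel_values_videos = torch.zeros(batch_size * num_videos, num_frames, 3, 384,
                                  384)

# The leading dimension is batch_size * num_videos, not two separate axes.
assert pixel_values_videos.shape[0] == batch_size * num_videos
assert pixel_values_videos.ndim == 5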
@@ -580,7 +579,7 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
 
         return LlavaOnevisionVideoPixelInputs(
             type="pixel_values_videos",
-            pixel_values_videos=pixel_values_videos,
+            pixel_values_videos=flatten_bn(pixel_values_videos),
         )
 
     def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
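flatten_bn is vLLM's existing helper for collapsing the batch dimension of nested multimodal inputs, so downstream code sees one entry per video. Roughly, and only as a stand-in for the real helper, its effect here is:

import torch

def flatten_bn_sketch(x):
    # Stand-in for vLLM's flatten_bn: merge the leading batch and per-item
    # dimensions of a stacked tensor, or splice a nested list of tensors.
    if isinstance(x, torch.Tensor):
        return x.flatten(0, 1)
    return [item for batch in x for item in batch]

# One request carrying two videos with different frame counts stays a flat
# list with one tensor per video (tiny placeholder sizes).
videos = [[torch.zeros(8, 3, 32, 32), torch.zeros(16, 3, 32, 32)]]
assert len(flatten_bn_sketch(videos)) == 2

# A fully batched tensor is flattened to (batch_size * num_videos, ...).
stacked = torch.zeros(2, 3, 8, 3, 32, 32)
assert flatten_bn_sketch(stacked).shape[0] == 6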
@@ -768,22 +767,6 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
             for i, patch_features_batch in enumerate(patch_embeddings)
         ]
 
-    def _add_image_newline(
-        self,
-        video_features: torch.Tensor,
-        videos: int = 1,
-        frames: int = 1,
-        strategy: str = "one_token",
-    ) -> torch.Tensor:
-        if strategy == "one_token":
-            video_features = video_features.reshape(
-                videos, frames * video_features.shape[1], -1)
-            image_newline = self.image_newline[None, None, :].repeat(
-                videos, 1, 1).to(video_features.device)
-            video_features = torch.cat((video_features, image_newline), dim=1)
-            return video_features
-        raise ValueError(f"Unexpected video newline strategy: {strategy}")
-
     def _video_pixels_to_features(
         self,
         vision_tower: Union[CLIPVisionModel, SiglipVisionModel],
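The removed _add_image_newline helper reshaped per-frame features to one row per video and appended a single learned image_newline embedding; the next hunk inlines that logic. A self-contained sketch of the appended-token operation (the token count and hidden size are placeholders):

import torch

videos, frames, tokens, dim = 3, 8, 729, 1152  # placeholder sizes
features = torch.zeros(videos * frames, tokens, dim)
image_newline = torch.zeros(dim)  # stand-in for the learned parameter

features = features.reshape(videos, frames * tokens, dim)
newline = image_newline[None, None, :].expand(videos, -1, -1)
out = torch.cat((features, newline), dim=1)

# One extra token per video marks the end of its visual sequence.
assert out.shape == (videos, frames * tokens + 1, dim)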
@@ -807,33 +790,43 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
         video_pixels = inputs["pixel_values_videos"]
 
         if isinstance(video_pixels, torch.Tensor):
-            b, num_videos, frames, c, h, w = video_pixels.shape
-            pixel_values = video_pixels.view(b * num_videos * frames, c, h, w)
-            stacked_embeddings = self._video_pixels_to_features(
-                self.vision_tower, pixel_values)
-            stacked_embeddings = self._add_image_newline(stacked_embeddings,
-                                                         videos=b * num_videos,
-                                                         frames=frames,
-                                                         strategy="one_token")
-            return stacked_embeddings
-        elif is_list_of(video_pixels, torch.Tensor):
-            stacked_embeddings = []
-            for video_pixel in video_pixels:
-                num_videos, frames, c, h, w = video_pixel.shape
-                pixel_values = video_pixel.view(num_videos * frames, c, h, w)
-                embeddings = self._video_pixels_to_features(
-                    self.vision_tower, pixel_values)
-                embeddings = self._add_image_newline(embeddings,
-                                                     videos=num_videos,
-                                                     frames=frames,
-                                                     strategy="one_token")
-                stacked_embeddings.append(embeddings)
-            return stacked_embeddings
-        else:
-            raise ValueError(
-                f"Unsupported type of video input {type(video_pixels)}")
+            total_videos, frames, c, h, w = video_pixels.shape
+            video_pixels_flat = video_pixels.view(total_videos * frames, c, h,
+                                                  w)
+
+            embeddings_flat = self._video_pixels_to_features(
+                self.vision_tower, video_pixels_flat)
+
+            embeddings_flat = embeddings_flat.reshape(
+                total_videos, frames * embeddings_flat.shape[1], -1)
+
+            image_newline = self.image_newline[None, None, :].expand(
+                total_videos, -1, -1)
+            return torch.cat((embeddings_flat, image_newline), dim=1)
+
+        frames_per_video = [len(video) for video in video_pixels]
+        video_pixels_flat = torch.cat(video_pixels)
+
+        embeddings_flat = self._video_pixels_to_features(
+            self.vision_tower, video_pixels_flat)
+
+        image_newline = self.image_newline[None, None, :]
+
+        return [
+            torch.cat(
+                (
+                    embeds.reshape(1, num_frame * embeddings_flat.shape[1],
                                   -1),
+                    image_newline,
+                ),
+                dim=1,
+            ) for num_frame, embeds in zip(
+                frames_per_video,
+                torch.split(embeddings_flat, frames_per_video),
+            )
+        ]
 
-    def apply_pooling(self, image_features, stride=2):
+    def apply_pooling(self, image_features: torch.Tensor, stride: int = 2):
         vision_config = self.config.vision_config
         height = width = vision_config.image_size // vision_config.patch_size
         batch_frames, _, dim = image_features.shape
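The list branch above is what makes videos with different frame counts work: all frames are encoded in a single pass, then split back per video before the newline token is appended. A standalone sketch of that split-and-append step with a dummy encoder and placeholder sizes:

import torch

tokens, dim = 196, 1152  # placeholder per-frame token count and hidden size

def encode_frames(frames: torch.Tensor) -> torch.Tensor:
    # Stand-in for the vision tower + projector: one feature row per frame.
    return torch.zeros(frames.shape[0], tokens, dim)

# Two videos in one request, with different numbers of frames.
video_pixels = [torch.zeros(8, 3, 64, 64), torch.zeros(16, 3, 64, 64)]
frames_per_video = [len(v) for v in video_pixels]

embeddings_flat = encode_frames(torch.cat(video_pixels))
image_newline = torch.zeros(1, 1, dim)  # stand-in for the learned parameter

per_video = [
    torch.cat((embeds.reshape(1, n * tokens, dim), image_newline), dim=1)
    for n, embeds in zip(frames_per_video,
                         torch.split(embeddings_flat, frames_per_video))
]
assert [t.shape[1] for t in per_video] == [8 * tokens + 1, 16 * tokens + 1]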