[Bugfix] Multi-video inference on LLaVA-Onevision (#15082)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: Isotr0py <2037008807@qq.com>
Cyrus Leung 2025-03-20 22:10:45 +08:00 committed by GitHub
parent e3f813c33b
commit 27261e40a6

@@ -25,7 +25,6 @@ from vllm.multimodal.parse import (ImageSize, MultiModalDataItems,
 from vllm.multimodal.processing import PromptReplacement, PromptUpdate
 from vllm.multimodal.profiling import ProcessorInputs
 from vllm.sequence import IntermediateTensors
-from vllm.utils import is_list_of
 
 from .clip import CLIPVisionModel
 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
@@ -44,7 +43,7 @@ class LlavaOnevisionVideoPixelInputs(TypedDict):
     type: Literal["pixel_values_videos"]
     pixel_values_videos: Union[torch.Tensor, list[torch.Tensor]]
     """
-    Shape: `(batch_size, num_videos, num_frames, num_channels, height, width)`
+    Shape: `(batch_size * num_videos, num_frames, num_channels, height, width)`
 
     Note that `num_videos` may be different for each batch, and 'num_frames'
     may be different for each video, in which case the data is passed as a
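
As a hedged illustration of the two layouts this docstring now allows (not part of the commit; the tensor sizes below are made up), `pixel_values_videos` is either one stacked tensor when every video has the same frame count, or a list of per-video tensors when frame counts differ:

import torch

# Stacked layout: the batch and video dims are merged into one leading dim.
num_videos, num_frames, c, h, w = 2, 8, 3, 384, 384
stacked = torch.zeros(num_videos, num_frames, c, h, w)
# shape: (batch_size * num_videos, num_frames, num_channels, height, width)

# Ragged layout: frame counts differ per video, so a list is passed instead.
ragged = [
    torch.zeros(8, c, h, w),   # video with 8 frames
    torch.zeros(16, c, h, w),  # video with 16 frames
]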
@@ -580,7 +579,7 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
         return LlavaOnevisionVideoPixelInputs(
             type="pixel_values_videos",
-            pixel_values_videos=pixel_values_videos,
+            pixel_values_videos=flatten_bn(pixel_values_videos),
         )
 
     def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
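
The swap to flatten_bn above is what produces that flattened leading dimension before validation. A minimal sketch of the flattening it performs, assuming (this is an assumption, not code copied from vLLM) that tensor inputs get their first two dims merged and nested per-prompt lists are unrolled into per-video tensors:

import torch
from typing import Union


def flatten_bn_sketch(
    x: Union[torch.Tensor, list[torch.Tensor]]
) -> Union[torch.Tensor, list[torch.Tensor]]:
    # Hypothetical stand-in for vllm's flatten_bn, for illustration only.
    if isinstance(x, torch.Tensor):
        # (batch_size, num_videos, T, C, H, W) -> (batch_size * num_videos, T, C, H, W)
        return x.flatten(0, 1)
    # list of per-prompt tensors (num_videos, T, C, H, W)
    # -> flat list of per-video tensors (T, C, H, W)
    return [video for prompt_videos in x for video in prompt_videos]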
@@ -768,22 +767,6 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
             for i, patch_features_batch in enumerate(patch_embeddings)
         ]
 
-    def _add_image_newline(
-        self,
-        video_features: torch.Tensor,
-        videos: int = 1,
-        frames: int = 1,
-        strategy: str = "one_token",
-    ) -> torch.Tensor:
-        if strategy == "one_token":
-            video_features = video_features.reshape(
-                videos, frames * video_features.shape[1], -1)
-            image_newline = self.image_newline[None, None, :].repeat(
-                videos, 1, 1).to(video_features.device)
-            video_features = torch.cat((video_features, image_newline), dim=1)
-            return video_features
-        raise ValueError(f"Unexpected video newline strategy: {strategy}")
-
     def _video_pixels_to_features(
         self,
         vision_tower: Union[CLIPVisionModel, SiglipVisionModel],
@@ -807,33 +790,43 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
         video_pixels = inputs["pixel_values_videos"]
 
         if isinstance(video_pixels, torch.Tensor):
-            b, num_videos, frames, c, h, w = video_pixels.shape
-            pixel_values = video_pixels.view(b * num_videos * frames, c, h, w)
-            stacked_embeddings = self._video_pixels_to_features(
-                self.vision_tower, pixel_values)
-            stacked_embeddings = self._add_image_newline(stacked_embeddings,
-                                                         videos=b * num_videos,
-                                                         frames=frames,
-                                                         strategy="one_token")
-            return stacked_embeddings
-        elif is_list_of(video_pixels, torch.Tensor):
-            stacked_embeddings = []
-            for video_pixel in video_pixels:
-                num_videos, frames, c, h, w = video_pixel.shape
-                pixel_values = video_pixel.view(num_videos * frames, c, h, w)
-                embeddings = self._video_pixels_to_features(
-                    self.vision_tower, pixel_values)
-                embeddings = self._add_image_newline(embeddings,
-                                                     videos=num_videos,
-                                                     frames=frames,
-                                                     strategy="one_token")
-                stacked_embeddings.append(embeddings)
-            return stacked_embeddings
-        else:
-            raise ValueError(
-                f"Unsupported type of video input {type(video_pixels)}")
+            total_videos, frames, c, h, w = video_pixels.shape
+            video_pixels_flat = video_pixels.view(total_videos * frames, c, h,
+                                                  w)
+
+            embeddings_flat = self._video_pixels_to_features(
+                self.vision_tower, video_pixels_flat)
+
+            embeddings_flat = embeddings_flat.reshape(
+                total_videos, frames * embeddings_flat.shape[1], -1)
+
+            image_newline = self.image_newline[None, None, :].expand(
+                total_videos, -1, -1)
+            return torch.cat((embeddings_flat, image_newline), dim=1)
+
+        frames_per_video = [len(video) for video in video_pixels]
+        video_pixels_flat = torch.cat(video_pixels)
+
+        embeddings_flat = self._video_pixels_to_features(
+            self.vision_tower, video_pixels_flat)
+
+        image_newline = self.image_newline[None, None, :]
+
+        return [
+            torch.cat(
+                (
+                    embeds.reshape(1, num_frame * embeddings_flat.shape[1],
+                                   -1),
+                    image_newline,
+                ),
+                dim=1,
+            ) for num_frame, embeds in zip(
+                frames_per_video,
+                torch.split(embeddings_flat, frames_per_video),
+            )
+        ]
 
-    def apply_pooling(self, image_features, stride=2):
+    def apply_pooling(self, image_features: torch.Tensor, stride: int = 2):
         vision_config = self.config.vision_config
         height = width = vision_config.image_size // vision_config.patch_size
         batch_frames, _, dim = image_features.shape
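
For reference, a self-contained sketch (dummy sizes and illustrative names, not code from this commit) of the cat / split / per-video-newline pattern the new list branch above relies on, showing how videos with different frame counts come back as separate embedding tensors:

import torch

tokens_per_frame, hidden = 4, 16
frames_per_video = [2, 3]  # two videos with different frame counts

# Stand-in for the vision tower + projector output, one row per frame.
embeddings_flat = torch.zeros(sum(frames_per_video), tokens_per_frame, hidden)
image_newline = torch.zeros(1, 1, hidden)  # stand-in for the learned newline embedding

outputs = [
    torch.cat(
        (
            embeds.reshape(1, num_frames * tokens_per_frame, hidden),
            image_newline,
        ),
        dim=1,
    )
    for num_frames, embeds in zip(
        frames_per_video,
        torch.split(embeddings_flat, frames_per_video),
    )
]

assert outputs[0].shape == (1, 2 * tokens_per_frame + 1, hidden)
assert outputs[1].shape == (1, 3 * tokens_per_frame + 1, hidden)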