[Bugfix] Multi-video inference on LLaVA-Onevision (#15082)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: Isotr0py <2037008807@qq.com>
Author: Cyrus Leung
Date: 2025-03-20 22:10:45 +08:00 (committed by GitHub)
parent e3f813c33b
commit 27261e40a6


@@ -25,7 +25,6 @@ from vllm.multimodal.parse import (ImageSize, MultiModalDataItems,
 from vllm.multimodal.processing import PromptReplacement, PromptUpdate
 from vllm.multimodal.profiling import ProcessorInputs
 from vllm.sequence import IntermediateTensors
-from vllm.utils import is_list_of
 
 from .clip import CLIPVisionModel
 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
@@ -44,7 +43,7 @@ class LlavaOnevisionVideoPixelInputs(TypedDict):
     type: Literal["pixel_values_videos"]
     pixel_values_videos: Union[torch.Tensor, list[torch.Tensor]]
     """
-    Shape: `(batch_size, num_videos, num_frames, num_channels, height, width)`
+    Shape: `(batch_size * num_videos, num_frames, num_channels, height, width)`
 
     Note that `num_videos` may be different for each batch, and 'num_frames'
     may be different for each video, in which case the data is passed as a
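The shape change above is the heart of the fix: a leading `(batch_size, num_videos, ...)` layout cannot batch requests that contribute different numbers of videos, while a flattened `(batch_size * num_videos, ...)` layout can. A minimal sketch of the idea, with toy shapes invented for illustration (assuming all videos share a frame count, so they can share one tensor):

import torch

# Hypothetical sizes: request A has 1 video, request B has 2, each with
# 8 frames of 3x336x336 pixels.
videos_a = torch.randn(1, 8, 3, 336, 336)
videos_b = torch.randn(2, 8, 3, 336, 336)

# The old `(batch_size, num_videos, ...)` layout cannot hold A and B
# together, since num_videos differs per request. The flattened
# `(batch_size * num_videos, num_frames, ...)` layout can:
batched = torch.cat([videos_a, videos_b], dim=0)
print(batched.shape)  # torch.Size([3, 8, 3, 336, 336])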
@@ -580,7 +579,7 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
         return LlavaOnevisionVideoPixelInputs(
             type="pixel_values_videos",
-            pixel_values_videos=pixel_values_videos,
+            pixel_values_videos=flatten_bn(pixel_values_videos),
         )
 
     def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
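`flatten_bn` is vLLM's helper for collapsing the leading batch dimension of multimodal inputs at validation time. A rough sketch of the behavior relied on here (not the library's actual implementation; the name `flatten_bn_sketch` is hypothetical):

from typing import Union
import torch

def flatten_bn_sketch(
    x: Union[torch.Tensor, list[torch.Tensor]],
) -> Union[torch.Tensor, list[torch.Tensor]]:
    # Batched tensor (b, n, ...) -> flat tensor (b * n, ...).
    if isinstance(x, torch.Tensor):
        return x.flatten(0, 1)
    # List of per-request tensors (n_i, ...) -> one flat list of items.
    return [item for per_request in x for item in per_request]

print(flatten_bn_sketch(torch.zeros(2, 3, 8, 3, 336, 336)).shape)
# torch.Size([6, 8, 3, 336, 336])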
@@ -768,22 +767,6 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
             for i, patch_features_batch in enumerate(patch_embeddings)
         ]
 
-    def _add_image_newline(
-        self,
-        video_features: torch.Tensor,
-        videos: int = 1,
-        frames: int = 1,
-        strategy: str = "one_token",
-    ) -> torch.Tensor:
-        if strategy == "one_token":
-            video_features = video_features.reshape(
-                videos, frames * video_features.shape[1], -1)
-            image_newline = self.image_newline[None, None, :].repeat(
-                videos, 1, 1).to(video_features.device)
-            video_features = torch.cat((video_features, image_newline), dim=1)
-            return video_features
-        raise ValueError(f"Unexpected video newline strategy: {strategy}")
-
     def _video_pixels_to_features(
         self,
         vision_tower: Union[CLIPVisionModel, SiglipVisionModel],
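The deleted `_add_image_newline` helper appended the newline token with `.repeat(...)`, which materializes copies; the inlined replacement in the next hunk uses `.expand(...)`, which broadcasts a view instead. A quick check of the equivalence, with a toy embedding size chosen for illustration:

import torch

image_newline = torch.randn(1024)  # stand-in for the learned newline embedding
total_videos = 3

repeated = image_newline[None, None, :].repeat(total_videos, 1, 1)
expanded = image_newline[None, None, :].expand(total_videos, -1, -1)

assert repeated.shape == expanded.shape == (3, 1, 1024)
assert torch.equal(repeated, expanded)  # same values; expand allocates no copy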
@@ -807,33 +790,43 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
         video_pixels = inputs["pixel_values_videos"]
 
         if isinstance(video_pixels, torch.Tensor):
-            b, num_videos, frames, c, h, w = video_pixels.shape
-            pixel_values = video_pixels.view(b * num_videos * frames, c, h, w)
-            stacked_embeddings = self._video_pixels_to_features(
-                self.vision_tower, pixel_values)
-            stacked_embeddings = self._add_image_newline(stacked_embeddings,
-                                                         videos=b * num_videos,
-                                                         frames=frames,
-                                                         strategy="one_token")
-            return stacked_embeddings
-        elif is_list_of(video_pixels, torch.Tensor):
-            stacked_embeddings = []
-            for video_pixel in video_pixels:
-                num_videos, frames, c, h, w = video_pixel.shape
-                pixel_values = video_pixel.view(num_videos * frames, c, h, w)
-                embeddings = self._video_pixels_to_features(
-                    self.vision_tower, pixel_values)
-                embeddings = self._add_image_newline(embeddings,
-                                                     videos=num_videos,
-                                                     frames=frames,
-                                                     strategy="one_token")
-                stacked_embeddings.append(embeddings)
-            return stacked_embeddings
-        else:
-            raise ValueError(
-                f"Unsupported type of video input {type(video_pixels)}")
+            total_videos, frames, c, h, w = video_pixels.shape
+            video_pixels_flat = video_pixels.view(total_videos * frames, c, h,
+                                                  w)
+
+            embeddings_flat = self._video_pixels_to_features(
+                self.vision_tower, video_pixels_flat)
+
+            embeddings_flat = embeddings_flat.reshape(
+                total_videos, frames * embeddings_flat.shape[1], -1)
+
+            image_newline = self.image_newline[None, None, :].expand(
+                total_videos, -1, -1)
+
+            return torch.cat((embeddings_flat, image_newline), dim=1)
+
+        frames_per_video = [len(video) for video in video_pixels]
+        video_pixels_flat = torch.cat(video_pixels)
+
+        embeddings_flat = self._video_pixels_to_features(
+            self.vision_tower, video_pixels_flat)
+
+        image_newline = self.image_newline[None, None, :]
+
+        return [
+            torch.cat(
+                (
+                    embeds.reshape(1, num_frame * embeddings_flat.shape[1],
+                                   -1),
+                    image_newline,
+                ),
+                dim=1,
+            ) for num_frame, embeds in zip(
+                frames_per_video,
+                torch.split(embeddings_flat, frames_per_video),
+            )
+        ]
 
-    def apply_pooling(self, image_features, stride=2):
+    def apply_pooling(self, image_features: torch.Tensor, stride: int = 2):
         vision_config = self.config.vision_config
         height = width = vision_config.image_size // vision_config.patch_size
         batch_frames, _, dim = image_features.shape
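The list path above is what makes multi-video inference with differing frame counts work: all frames are pushed through the vision tower in a single call, then split back per video before the newline embedding is appended. A self-contained sketch of that pattern, with a dummy feature extractor and toy sizes standing in for the real vision tower:

import torch

tokens_per_frame, hidden = 4, 16  # toy sizes for illustration

def dummy_video_pixels_to_features(pixels: torch.Tensor) -> torch.Tensor:
    # Stand-in for the vision tower: (frames, c, h, w) -> (frames, tokens, hidden)
    return torch.randn(pixels.shape[0], tokens_per_frame, hidden)

image_newline = torch.randn(1, 1, hidden)

# Two videos with different frame counts.
video_pixels = [torch.randn(5, 3, 336, 336), torch.randn(8, 3, 336, 336)]

frames_per_video = [len(v) for v in video_pixels]  # [5, 8]
embeddings_flat = dummy_video_pixels_to_features(
    torch.cat(video_pixels))                       # one batched pass: (13, 4, 16)

out = [
    torch.cat(
        (embeds.reshape(1, n * tokens_per_frame, -1), image_newline),
        dim=1,
    )
    for n, embeds in zip(frames_per_video,
                         torch.split(embeddings_flat, frames_per_video))
]
print([o.shape for o in out])
# [torch.Size([1, 21, 16]), torch.Size([1, 33, 16])]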