[Bugfix] Multi-video inference on LLaVA-Onevision (#15082)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: Isotr0py <2037008807@qq.com>
parent e3f813c33b
commit 27261e40a6
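
For context, a minimal offline-inference sketch of the multi-video case this fix targets. The model checkpoint, prompt template, dummy frame counts, and resolution below are illustrative assumptions, not taken from this commit.

from vllm import LLM, SamplingParams
import numpy as np

# Illustrative model checkpoint and per-prompt video limit (assumptions).
llm = LLM(
    model="llava-hf/llava-onevision-qwen2-0.5b-ov-hf",
    limit_mm_per_prompt={"video": 2},
)

# Two dummy videos with different frame counts, as (frames, H, W, 3) uint8 arrays.
video_a = np.zeros((8, 384, 384, 3), dtype=np.uint8)
video_b = np.zeros((16, 384, 384, 3), dtype=np.uint8)

# One <video> placeholder per video in the prompt.
prompt = ("<|im_start|>user <video><video>\n"
          "Compare the two videos.<|im_end|>"
          "<|im_start|>assistant\n")

outputs = llm.generate(
    {
        "prompt": prompt,
        "multi_modal_data": {"video": [video_a, video_b]},
    },
    SamplingParams(max_tokens=64),
)
print(outputs[0].outputs[0].text)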
@@ -25,7 +25,6 @@ from vllm.multimodal.parse import (ImageSize, MultiModalDataItems,
 from vllm.multimodal.processing import PromptReplacement, PromptUpdate
 from vllm.multimodal.profiling import ProcessorInputs
 from vllm.sequence import IntermediateTensors
-from vllm.utils import is_list_of
 
 from .clip import CLIPVisionModel
 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
@@ -44,7 +43,7 @@ class LlavaOnevisionVideoPixelInputs(TypedDict):
     type: Literal["pixel_values_videos"]
     pixel_values_videos: Union[torch.Tensor, list[torch.Tensor]]
     """
-    Shape: `(batch_size, num_videos, num_frames, num_channels, height, width)`
+    Shape: `(batch_size * num_videos, num_frames, num_channels, height, width)`
 
     Note that `num_videos` may be different for each batch, and 'num_frames'
     may be different for each video, in which case the data is passed as a
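
As a hedged illustration of the documented shape contract (frame counts, resolution, and channel count below are arbitrary):

import torch

# Videos with equal frame counts can be batched into a single tensor of shape
# (batch_size * num_videos, num_frames, num_channels, height, width).
batched = torch.zeros(3, 8, 3, 384, 384)

# Videos with differing frame counts are instead passed as a list of
# (num_frames, num_channels, height, width) tensors.
ragged = [torch.zeros(8, 3, 384, 384), torch.zeros(16, 3, 384, 384)]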
@@ -580,7 +579,7 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
 
         return LlavaOnevisionVideoPixelInputs(
             type="pixel_values_videos",
-            pixel_values_videos=pixel_values_videos,
+            pixel_values_videos=flatten_bn(pixel_values_videos),
         )
 
     def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
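
A simplified sketch of what the flatten_bn call is doing to the parsed video inputs; this is an illustrative re-implementation under assumed semantics (the name flatten_bn_sketch is made up here), not vLLM's actual helper.

from typing import Union

import torch

def flatten_bn_sketch(
    x: Union[torch.Tensor, list[torch.Tensor]]
) -> Union[torch.Tensor, list[torch.Tensor]]:
    if isinstance(x, torch.Tensor):
        # Batched tensor (batch_size, num_videos, ...): merge the first two
        # dims into (batch_size * num_videos, ...).
        return x.flatten(0, 1)
    # Ragged input: each element holds one request's videos; flatten the
    # nesting into a single list of per-video tensors.
    return [item for per_request in x for item in per_request]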
@@ -768,22 +767,6 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
             for i, patch_features_batch in enumerate(patch_embeddings)
         ]
 
-    def _add_image_newline(
-        self,
-        video_features: torch.Tensor,
-        videos: int = 1,
-        frames: int = 1,
-        strategy: str = "one_token",
-    ) -> torch.Tensor:
-        if strategy == "one_token":
-            video_features = video_features.reshape(
-                videos, frames * video_features.shape[1], -1)
-            image_newline = self.image_newline[None, None, :].repeat(
-                videos, 1, 1).to(video_features.device)
-            video_features = torch.cat((video_features, image_newline), dim=1)
-            return video_features
-        raise ValueError(f"Unexpected video newline strategy: {strategy}")
-
     def _video_pixels_to_features(
         self,
         vision_tower: Union[CLIPVisionModel, SiglipVisionModel],
@@ -807,33 +790,43 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
         video_pixels = inputs["pixel_values_videos"]
 
         if isinstance(video_pixels, torch.Tensor):
-            b, num_videos, frames, c, h, w = video_pixels.shape
-            pixel_values = video_pixels.view(b * num_videos * frames, c, h, w)
-            stacked_embeddings = self._video_pixels_to_features(
-                self.vision_tower, pixel_values)
-            stacked_embeddings = self._add_image_newline(stacked_embeddings,
-                                                         videos=b * num_videos,
-                                                         frames=frames,
-                                                         strategy="one_token")
-            return stacked_embeddings
-        elif is_list_of(video_pixels, torch.Tensor):
-            stacked_embeddings = []
-            for video_pixel in video_pixels:
-                num_videos, frames, c, h, w = video_pixel.shape
-                pixel_values = video_pixel.view(num_videos * frames, c, h, w)
-                embeddings = self._video_pixels_to_features(
-                    self.vision_tower, pixel_values)
-                embeddings = self._add_image_newline(embeddings,
-                                                     videos=num_videos,
-                                                     frames=frames,
-                                                     strategy="one_token")
-                stacked_embeddings.append(embeddings)
-            return stacked_embeddings
-        else:
-            raise ValueError(
-                f"Unsupported type of video input {type(video_pixels)}")
+            total_videos, frames, c, h, w = video_pixels.shape
+            video_pixels_flat = video_pixels.view(total_videos * frames, c, h,
+                                                  w)
+
+            embeddings_flat = self._video_pixels_to_features(
+                self.vision_tower, video_pixels_flat)
+
+            embeddings_flat = embeddings_flat.reshape(
+                total_videos, frames * embeddings_flat.shape[1], -1)
+
+            image_newline = self.image_newline[None, None, :].expand(
+                total_videos, -1, -1)
+            return torch.cat((embeddings_flat, image_newline), dim=1)
+
+        frames_per_video = [len(video) for video in video_pixels]
+        video_pixels_flat = torch.cat(video_pixels)
+
+        embeddings_flat = self._video_pixels_to_features(
+            self.vision_tower, video_pixels_flat)
+
+        image_newline = self.image_newline[None, None, :]
+
+        return [
+            torch.cat(
+                (
+                    embeds.reshape(1, num_frame * embeddings_flat.shape[1],
+                                   -1),
+                    image_newline,
+                ),
+                dim=1,
+            ) for num_frame, embeds in zip(
+                frames_per_video,
+                torch.split(embeddings_flat, frames_per_video),
+            )
+        ]
 
-    def apply_pooling(self, image_features, stride=2):
+    def apply_pooling(self, image_features: torch.Tensor, stride: int = 2):
         vision_config = self.config.vision_config
         height = width = vision_config.image_size // vision_config.patch_size
         batch_frames, _, dim = image_features.shape
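
The list branch above flattens ragged videos into one vision-tower call and then splits the result back per video. A standalone sketch of that pattern with a dummy encoder (the fake_encoder, shapes, and frame counts are placeholders):

import torch

def fake_encoder(frames: torch.Tensor) -> torch.Tensor:
    # Stand-in for the vision tower: 4 patch tokens of width 7 per frame.
    return torch.zeros(frames.shape[0], 4, 7)

# Two videos with different frame counts.
videos = [torch.zeros(8, 3, 32, 32), torch.zeros(16, 3, 32, 32)]

frames_per_video = [len(v) for v in videos]       # [8, 16]
embeds_flat = fake_encoder(torch.cat(videos))     # (24, 4, 7)

# Split back into per-video chunks and fold frames into the token dimension.
per_video = [
    chunk.reshape(1, n * embeds_flat.shape[1], -1)
    for n, chunk in zip(frames_per_video,
                        torch.split(embeds_flat, frames_per_video))
]
print([tuple(t.shape) for t in per_video])  # [(1, 32, 7), (1, 64, 7)]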