[Misc] Move some multimodal utils to modality-specific modules (#11494)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
parent
6ad909fdda
commit
51a624bf02
@ -3,7 +3,7 @@ from typing import List, Optional, Type
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.multimodal.utils import rescale_image_size
|
||||
from vllm.multimodal.image import rescale_image_size
|
||||
|
||||
from ....conftest import IMAGE_ASSETS, VllmRunner, _ImageAssets
|
||||
from ...utils import check_logprobs_close
|
||||
|
@ -8,7 +8,7 @@ from transformers import AutoConfig
|
||||
# Import the functions to test
|
||||
from vllm.model_executor.models.h2ovl import (calculate_num_blocks,
|
||||
image_to_pixel_values_wrapper)
|
||||
from vllm.multimodal.utils import rescale_image_size
|
||||
from vllm.multimodal.image import rescale_image_size
|
||||
|
||||
models = [
|
||||
"h2oai/h2ovl-mississippi-800m", # Replace with your actual model names
|
||||
|
@ -5,7 +5,7 @@ from typing import List, Optional, Tuple, Type
|
||||
import pytest
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from vllm.multimodal.utils import rescale_image_size
|
||||
from vllm.multimodal.image import rescale_image_size
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sequence import SampleLogprobs
|
||||
|
||||
|
@ -6,8 +6,8 @@ import torch
|
||||
from PIL import Image
|
||||
|
||||
from vllm.entrypoints.llm import LLM
|
||||
from vllm.multimodal.utils import (rescale_image_size, rescale_video_size,
|
||||
sample_frames_from_video)
|
||||
from vllm.multimodal.image import rescale_image_size
|
||||
from vllm.multimodal.video import rescale_video_size, sample_frames_from_video
|
||||
|
||||
from ....conftest import (IMAGE_ASSETS, VIDEO_ASSETS, PromptImageInput,
|
||||
PromptVideoInput, VllmRunner)
|
||||
|
@ -5,8 +5,9 @@ from typing import Callable, Iterable, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.multimodal.utils import (rescale_image_size, rescale_video_size,
|
||||
resize_video, sample_frames_from_video)
|
||||
from vllm.multimodal.image import rescale_image_size
|
||||
from vllm.multimodal.video import (rescale_video_size, resize_video,
|
||||
sample_frames_from_video)
|
||||
|
||||
from .....conftest import _ImageAssets, _VideoAssets
|
||||
from .types import (SINGLE_IMAGE_BASE_PROMPTS, TEST_IMG_PLACEHOLDER,
|
||||
|
@ -1,8 +1,9 @@
|
||||
"""Custom input builders for edge-cases in different models."""
|
||||
from typing import Callable
|
||||
|
||||
from vllm.multimodal.utils import (rescale_image_size, rescale_video_size,
|
||||
resize_video, sample_frames_from_video)
|
||||
from vllm.multimodal.image import rescale_image_size
|
||||
from vllm.multimodal.video import (rescale_video_size, resize_video,
|
||||
sample_frames_from_video)
|
||||
|
||||
from .....conftest import IMAGE_ASSETS, VIDEO_ASSETS
|
||||
from .builders import build_multi_image_inputs, build_single_image_inputs
|
||||
|
@ -6,7 +6,7 @@ from transformers import (AutoConfig, AutoModelForVision2Seq, AutoTokenizer,
|
||||
|
||||
from vllm.attention.selector import (_Backend, _cached_get_attn_backend,
|
||||
global_force_attn_backend_context_manager)
|
||||
from vllm.multimodal.utils import rescale_image_size
|
||||
from vllm.multimodal.image import rescale_image_size
|
||||
from vllm.sequence import SampleLogprobs
|
||||
|
||||
from ....conftest import (IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner,
|
||||
|
@ -6,7 +6,7 @@ from transformers import LlavaNextImageProcessor
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.multimodal import MultiModalRegistry
|
||||
from vllm.multimodal.utils import rescale_image_size
|
||||
from vllm.multimodal.image import rescale_image_size
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
@ -7,7 +7,7 @@ import numpy.typing as npt
|
||||
from huggingface_hub import hf_hub_download
|
||||
from PIL import Image
|
||||
|
||||
from vllm.multimodal.utils import (sample_frames_from_video,
|
||||
from vllm.multimodal.video import (sample_frames_from_video,
|
||||
try_import_video_packages)
|
||||
|
||||
from .base import get_cache_dir
|
||||
|
@ -1,3 +1,5 @@
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
|
||||
@ -26,6 +28,16 @@ class AudioPlugin(MultiModalPlugin):
|
||||
"There is no default maximum multimodal tokens")
|
||||
|
||||
|
||||
def try_import_audio_packages() -> tuple[Any, Any]:
|
||||
try:
|
||||
import librosa
|
||||
import soundfile
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
"Please install vllm[audio] for audio support.") from exc
|
||||
return librosa, soundfile
|
||||
|
||||
|
||||
def resample_audio(
|
||||
audio: npt.NDArray[np.floating],
|
||||
*,
|
||||
|
@ -84,3 +84,15 @@ class ImagePlugin(MultiModalPlugin):
|
||||
|
||||
def _default_max_multimodal_tokens(self, ctx: InputContext) -> int:
|
||||
return 3000
|
||||
|
||||
|
||||
def rescale_image_size(image: Image.Image,
|
||||
size_factor: float,
|
||||
transpose: int = -1) -> Image.Image:
|
||||
"""Rescale the dimensions of an image by a constant factor."""
|
||||
new_width = int(image.width * size_factor)
|
||||
new_height = int(image.height * size_factor)
|
||||
image = image.resize((new_width, new_height))
|
||||
if transpose >= 0:
|
||||
image = image.transpose(Image.Transpose(transpose))
|
||||
return image
|
||||
|
@ -2,7 +2,7 @@ import base64
|
||||
import os
|
||||
from functools import lru_cache
|
||||
from io import BytesIO
|
||||
from typing import Any, List, Optional, Tuple, TypeVar, Union
|
||||
from typing import List, Optional, Tuple, TypeVar, Union
|
||||
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
@ -14,7 +14,9 @@ from vllm.connections import global_http_connection
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer
|
||||
|
||||
from .audio import try_import_audio_packages
|
||||
from .inputs import MultiModalDataDict, PlaceholderRange
|
||||
from .video import try_import_video_packages
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -198,16 +200,6 @@ async def async_fetch_video(video_url: str,
|
||||
return video
|
||||
|
||||
|
||||
def try_import_audio_packages() -> Tuple[Any, Any]:
|
||||
try:
|
||||
import librosa
|
||||
import soundfile
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
"Please install vllm[audio] for audio support.") from exc
|
||||
return librosa, soundfile
|
||||
|
||||
|
||||
def fetch_audio(audio_url: str) -> Tuple[np.ndarray, Union[int, float]]:
|
||||
"""
|
||||
Load audio from a URL.
|
||||
@ -324,60 +316,6 @@ def load_image_from_base64(image: Union[bytes, str]) -> Image.Image:
|
||||
return _load_image_from_bytes(base64.b64decode(image))
|
||||
|
||||
|
||||
def rescale_image_size(image: Image.Image,
|
||||
size_factor: float,
|
||||
transpose: int = -1) -> Image.Image:
|
||||
"""Rescale the dimensions of an image by a constant factor."""
|
||||
new_width = int(image.width * size_factor)
|
||||
new_height = int(image.height * size_factor)
|
||||
image = image.resize((new_width, new_height))
|
||||
if transpose >= 0:
|
||||
image = image.transpose(Image.Transpose(transpose))
|
||||
return image
|
||||
|
||||
|
||||
def try_import_video_packages():
|
||||
try:
|
||||
import cv2
|
||||
import decord
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
"Please install vllm[video] for video support.") from exc
|
||||
return cv2, decord
|
||||
|
||||
|
||||
def resize_video(frames: npt.NDArray, size: Tuple[int, int]) -> npt.NDArray:
|
||||
cv2, _ = try_import_video_packages()
|
||||
|
||||
num_frames, _, _, channels = frames.shape
|
||||
new_height, new_width = size
|
||||
resized_frames = np.empty((num_frames, new_height, new_width, channels),
|
||||
dtype=frames.dtype)
|
||||
for i, frame in enumerate(frames):
|
||||
resized_frame = cv2.resize(frame, (new_width, new_height))
|
||||
resized_frames[i] = resized_frame
|
||||
return resized_frames
|
||||
|
||||
|
||||
def rescale_video_size(frames: npt.NDArray, size_factor: float) -> npt.NDArray:
|
||||
_, height, width, _ = frames.shape
|
||||
new_height = int(height * size_factor)
|
||||
new_width = int(width * size_factor)
|
||||
|
||||
return resize_video(frames, (new_height, new_width))
|
||||
|
||||
|
||||
def sample_frames_from_video(frames: npt.NDArray,
|
||||
num_frames: int) -> npt.NDArray:
|
||||
total_frames = frames.shape[0]
|
||||
if num_frames == -1:
|
||||
return frames
|
||||
else:
|
||||
frame_indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)
|
||||
sampled_frames = frames[frame_indices, ...]
|
||||
return sampled_frames
|
||||
|
||||
|
||||
def encode_video_base64(frames: npt.NDArray) -> str:
|
||||
base64_frames = []
|
||||
frames_list = [frames[i] for i in range(frames.shape[0])]
|
||||
|
@ -2,6 +2,7 @@ from functools import lru_cache
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional
|
||||
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
|
||||
from vllm.inputs.registry import InputContext
|
||||
from vllm.logger import init_logger
|
||||
@ -75,3 +76,45 @@ class VideoPlugin(ImagePlugin):
|
||||
|
||||
def _default_max_multimodal_tokens(self, ctx: InputContext) -> int:
|
||||
return 4096
|
||||
|
||||
|
||||
def try_import_video_packages() -> tuple[Any, Any]:
|
||||
try:
|
||||
import cv2
|
||||
import decord
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
"Please install vllm[video] for video support.") from exc
|
||||
return cv2, decord
|
||||
|
||||
|
||||
def resize_video(frames: npt.NDArray, size: tuple[int, int]) -> npt.NDArray:
|
||||
cv2, _ = try_import_video_packages()
|
||||
|
||||
num_frames, _, _, channels = frames.shape
|
||||
new_height, new_width = size
|
||||
resized_frames = np.empty((num_frames, new_height, new_width, channels),
|
||||
dtype=frames.dtype)
|
||||
for i, frame in enumerate(frames):
|
||||
resized_frame = cv2.resize(frame, (new_width, new_height))
|
||||
resized_frames[i] = resized_frame
|
||||
return resized_frames
|
||||
|
||||
|
||||
def rescale_video_size(frames: npt.NDArray, size_factor: float) -> npt.NDArray:
|
||||
_, height, width, _ = frames.shape
|
||||
new_height = int(height * size_factor)
|
||||
new_width = int(width * size_factor)
|
||||
|
||||
return resize_video(frames, (new_height, new_width))
|
||||
|
||||
|
||||
def sample_frames_from_video(frames: npt.NDArray,
|
||||
num_frames: int) -> npt.NDArray:
|
||||
total_frames = frames.shape[0]
|
||||
if num_frames == -1:
|
||||
return frames
|
||||
|
||||
frame_indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)
|
||||
sampled_frames = frames[frame_indices, ...]
|
||||
return sampled_frames
|
||||
|
Loading…
x
Reference in New Issue
Block a user