[Misc] Move some multimodal utils to modality-specific modules (#11494)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung 2024-12-26 12:23:20 +08:00 committed by GitHub
parent 6ad909fdda
commit 51a624bf02
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 84 additions and 77 deletions

View File

@ -3,7 +3,7 @@ from typing import List, Optional, Type
import pytest import pytest
import torch import torch
from vllm.multimodal.utils import rescale_image_size from vllm.multimodal.image import rescale_image_size
from ....conftest import IMAGE_ASSETS, VllmRunner, _ImageAssets from ....conftest import IMAGE_ASSETS, VllmRunner, _ImageAssets
from ...utils import check_logprobs_close from ...utils import check_logprobs_close

View File

@ -8,7 +8,7 @@ from transformers import AutoConfig
# Import the functions to test # Import the functions to test
from vllm.model_executor.models.h2ovl import (calculate_num_blocks, from vllm.model_executor.models.h2ovl import (calculate_num_blocks,
image_to_pixel_values_wrapper) image_to_pixel_values_wrapper)
from vllm.multimodal.utils import rescale_image_size from vllm.multimodal.image import rescale_image_size
models = [ models = [
"h2oai/h2ovl-mississippi-800m", # Replace with your actual model names "h2oai/h2ovl-mississippi-800m", # Replace with your actual model names

View File

@ -5,7 +5,7 @@ from typing import List, Optional, Tuple, Type
import pytest import pytest
from transformers import AutoTokenizer from transformers import AutoTokenizer
from vllm.multimodal.utils import rescale_image_size from vllm.multimodal.image import rescale_image_size
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.sequence import SampleLogprobs from vllm.sequence import SampleLogprobs

View File

@ -6,8 +6,8 @@ import torch
from PIL import Image from PIL import Image
from vllm.entrypoints.llm import LLM from vllm.entrypoints.llm import LLM
from vllm.multimodal.utils import (rescale_image_size, rescale_video_size, from vllm.multimodal.image import rescale_image_size
sample_frames_from_video) from vllm.multimodal.video import rescale_video_size, sample_frames_from_video
from ....conftest import (IMAGE_ASSETS, VIDEO_ASSETS, PromptImageInput, from ....conftest import (IMAGE_ASSETS, VIDEO_ASSETS, PromptImageInput,
PromptVideoInput, VllmRunner) PromptVideoInput, VllmRunner)

View File

@ -5,8 +5,9 @@ from typing import Callable, Iterable, List, Optional, Tuple, Union
import torch import torch
from vllm.multimodal.utils import (rescale_image_size, rescale_video_size, from vllm.multimodal.image import rescale_image_size
resize_video, sample_frames_from_video) from vllm.multimodal.video import (rescale_video_size, resize_video,
sample_frames_from_video)
from .....conftest import _ImageAssets, _VideoAssets from .....conftest import _ImageAssets, _VideoAssets
from .types import (SINGLE_IMAGE_BASE_PROMPTS, TEST_IMG_PLACEHOLDER, from .types import (SINGLE_IMAGE_BASE_PROMPTS, TEST_IMG_PLACEHOLDER,

View File

@ -1,8 +1,9 @@
"""Custom input builders for edge-cases in different models.""" """Custom input builders for edge-cases in different models."""
from typing import Callable from typing import Callable
from vllm.multimodal.utils import (rescale_image_size, rescale_video_size, from vllm.multimodal.image import rescale_image_size
resize_video, sample_frames_from_video) from vllm.multimodal.video import (rescale_video_size, resize_video,
sample_frames_from_video)
from .....conftest import IMAGE_ASSETS, VIDEO_ASSETS from .....conftest import IMAGE_ASSETS, VIDEO_ASSETS
from .builders import build_multi_image_inputs, build_single_image_inputs from .builders import build_multi_image_inputs, build_single_image_inputs

View File

@ -6,7 +6,7 @@ from transformers import (AutoConfig, AutoModelForVision2Seq, AutoTokenizer,
from vllm.attention.selector import (_Backend, _cached_get_attn_backend, from vllm.attention.selector import (_Backend, _cached_get_attn_backend,
global_force_attn_backend_context_manager) global_force_attn_backend_context_manager)
from vllm.multimodal.utils import rescale_image_size from vllm.multimodal.image import rescale_image_size
from vllm.sequence import SampleLogprobs from vllm.sequence import SampleLogprobs
from ....conftest import (IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner, from ....conftest import (IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner,

View File

@ -6,7 +6,7 @@ from transformers import LlavaNextImageProcessor
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.multimodal import MultiModalRegistry from vllm.multimodal import MultiModalRegistry
from vllm.multimodal.utils import rescale_image_size from vllm.multimodal.image import rescale_image_size
@pytest.fixture @pytest.fixture

View File

@ -7,7 +7,7 @@ import numpy.typing as npt
from huggingface_hub import hf_hub_download from huggingface_hub import hf_hub_download
from PIL import Image from PIL import Image
from vllm.multimodal.utils import (sample_frames_from_video, from vllm.multimodal.video import (sample_frames_from_video,
try_import_video_packages) try_import_video_packages)
from .base import get_cache_dir from .base import get_cache_dir

View File

@ -1,3 +1,5 @@
from typing import Any
import numpy as np import numpy as np
import numpy.typing as npt import numpy.typing as npt
@ -26,6 +28,16 @@ class AudioPlugin(MultiModalPlugin):
"There is no default maximum multimodal tokens") "There is no default maximum multimodal tokens")
def try_import_audio_packages() -> tuple[Any, Any]:
try:
import librosa
import soundfile
except ImportError as exc:
raise ImportError(
"Please install vllm[audio] for audio support.") from exc
return librosa, soundfile
def resample_audio( def resample_audio(
audio: npt.NDArray[np.floating], audio: npt.NDArray[np.floating],
*, *,

View File

@ -84,3 +84,15 @@ class ImagePlugin(MultiModalPlugin):
def _default_max_multimodal_tokens(self, ctx: InputContext) -> int: def _default_max_multimodal_tokens(self, ctx: InputContext) -> int:
return 3000 return 3000
def rescale_image_size(image: Image.Image,
size_factor: float,
transpose: int = -1) -> Image.Image:
"""Rescale the dimensions of an image by a constant factor."""
new_width = int(image.width * size_factor)
new_height = int(image.height * size_factor)
image = image.resize((new_width, new_height))
if transpose >= 0:
image = image.transpose(Image.Transpose(transpose))
return image

View File

@ -2,7 +2,7 @@ import base64
import os import os
from functools import lru_cache from functools import lru_cache
from io import BytesIO from io import BytesIO
from typing import Any, List, Optional, Tuple, TypeVar, Union from typing import List, Optional, Tuple, TypeVar, Union
import numpy as np import numpy as np
import numpy.typing as npt import numpy.typing as npt
@ -14,7 +14,9 @@ from vllm.connections import global_http_connection
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer
from .audio import try_import_audio_packages
from .inputs import MultiModalDataDict, PlaceholderRange from .inputs import MultiModalDataDict, PlaceholderRange
from .video import try_import_video_packages
logger = init_logger(__name__) logger = init_logger(__name__)
@ -198,16 +200,6 @@ async def async_fetch_video(video_url: str,
return video return video
def try_import_audio_packages() -> Tuple[Any, Any]:
try:
import librosa
import soundfile
except ImportError as exc:
raise ImportError(
"Please install vllm[audio] for audio support.") from exc
return librosa, soundfile
def fetch_audio(audio_url: str) -> Tuple[np.ndarray, Union[int, float]]: def fetch_audio(audio_url: str) -> Tuple[np.ndarray, Union[int, float]]:
""" """
Load audio from a URL. Load audio from a URL.
@ -324,60 +316,6 @@ def load_image_from_base64(image: Union[bytes, str]) -> Image.Image:
return _load_image_from_bytes(base64.b64decode(image)) return _load_image_from_bytes(base64.b64decode(image))
def rescale_image_size(image: Image.Image,
size_factor: float,
transpose: int = -1) -> Image.Image:
"""Rescale the dimensions of an image by a constant factor."""
new_width = int(image.width * size_factor)
new_height = int(image.height * size_factor)
image = image.resize((new_width, new_height))
if transpose >= 0:
image = image.transpose(Image.Transpose(transpose))
return image
def try_import_video_packages():
try:
import cv2
import decord
except ImportError as exc:
raise ImportError(
"Please install vllm[video] for video support.") from exc
return cv2, decord
def resize_video(frames: npt.NDArray, size: Tuple[int, int]) -> npt.NDArray:
cv2, _ = try_import_video_packages()
num_frames, _, _, channels = frames.shape
new_height, new_width = size
resized_frames = np.empty((num_frames, new_height, new_width, channels),
dtype=frames.dtype)
for i, frame in enumerate(frames):
resized_frame = cv2.resize(frame, (new_width, new_height))
resized_frames[i] = resized_frame
return resized_frames
def rescale_video_size(frames: npt.NDArray, size_factor: float) -> npt.NDArray:
_, height, width, _ = frames.shape
new_height = int(height * size_factor)
new_width = int(width * size_factor)
return resize_video(frames, (new_height, new_width))
def sample_frames_from_video(frames: npt.NDArray,
num_frames: int) -> npt.NDArray:
total_frames = frames.shape[0]
if num_frames == -1:
return frames
else:
frame_indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)
sampled_frames = frames[frame_indices, ...]
return sampled_frames
def encode_video_base64(frames: npt.NDArray) -> str: def encode_video_base64(frames: npt.NDArray) -> str:
base64_frames = [] base64_frames = []
frames_list = [frames[i] for i in range(frames.shape[0])] frames_list = [frames[i] for i in range(frames.shape[0])]

View File

@ -2,6 +2,7 @@ from functools import lru_cache
from typing import TYPE_CHECKING, Any, Dict, Optional from typing import TYPE_CHECKING, Any, Dict, Optional
import numpy as np import numpy as np
import numpy.typing as npt
from vllm.inputs.registry import InputContext from vllm.inputs.registry import InputContext
from vllm.logger import init_logger from vllm.logger import init_logger
@ -75,3 +76,45 @@ class VideoPlugin(ImagePlugin):
def _default_max_multimodal_tokens(self, ctx: InputContext) -> int: def _default_max_multimodal_tokens(self, ctx: InputContext) -> int:
return 4096 return 4096
def try_import_video_packages() -> tuple[Any, Any]:
try:
import cv2
import decord
except ImportError as exc:
raise ImportError(
"Please install vllm[video] for video support.") from exc
return cv2, decord
def resize_video(frames: npt.NDArray, size: tuple[int, int]) -> npt.NDArray:
cv2, _ = try_import_video_packages()
num_frames, _, _, channels = frames.shape
new_height, new_width = size
resized_frames = np.empty((num_frames, new_height, new_width, channels),
dtype=frames.dtype)
for i, frame in enumerate(frames):
resized_frame = cv2.resize(frame, (new_width, new_height))
resized_frames[i] = resized_frame
return resized_frames
def rescale_video_size(frames: npt.NDArray, size_factor: float) -> npt.NDArray:
_, height, width, _ = frames.shape
new_height = int(height * size_factor)
new_width = int(width * size_factor)
return resize_video(frames, (new_height, new_width))
def sample_frames_from_video(frames: npt.NDArray,
num_frames: int) -> npt.NDArray:
total_frames = frames.shape[0]
if num_frames == -1:
return frames
frame_indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)
sampled_frames = frames[frame_indices, ...]
return sampled_frames