From 4e5a0f6ae208c56c169de58a2f3a02c533d9ec00 Mon Sep 17 00:00:00 2001 From: Isotr0py Date: Tue, 1 Apr 2025 23:55:13 +0800 Subject: [PATCH] [Misc] Allow using OpenCV as video IO fallback (#15055) Signed-off-by: Isotr0py <2037008807@qq.com> Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- requirements/common.txt | 1 + requirements/test.in | 2 +- requirements/test.txt | 9 ++--- setup.py | 2 +- vllm/assets/video.py | 19 +++++---- vllm/multimodal/video.py | 86 ++++++++++++++++++++++++++++++---------- 6 files changed, 84 insertions(+), 35 deletions(-) diff --git a/requirements/common.txt b/requirements/common.txt index c7bbdb71..48e58c85 100644 --- a/requirements/common.txt +++ b/requirements/common.txt @@ -30,6 +30,7 @@ msgspec gguf == 0.10.0 importlib_metadata mistral_common[opencv] >= 1.5.4 +opencv-python-headless >= 4.11.0 # required for video IO pyyaml six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that needs to be the latest version for python 3.12 setuptools>=74.1.1; python_version > '3.11' # Setuptools is used by triton, we need to ensure a modern version is installed for 3.12+ so that it does not try to import distutils, which was removed in 3.12 diff --git a/requirements/test.in b/requirements/test.in index cf89794b..c1b70bca 100644 --- a/requirements/test.in +++ b/requirements/test.in @@ -9,7 +9,6 @@ pytest-shard # testing utils awscli backoff # required for phi4mm test -decord # required for video tests einops # required for MPT, qwen-vl and Mamba httpx librosa # required for audio tests @@ -28,6 +27,7 @@ torchvision==0.21.0 transformers_stream_generator # required for qwen-vl test matplotlib # required for qwen-vl test mistral_common[opencv] >= 1.5.4 # required for pixtral test +opencv-python-headless >= 4.11.0 # required for video test datamodel_code_generator # required for minicpm3 test lm-eval[api]==0.4.4 # required for model evaluation test transformers==4.50.3 diff --git a/requirements/test.txt b/requirements/test.txt index 26ed9dbe..c46fa072 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -93,8 +93,6 @@ datasets==3.0.2 # lm-eval decorator==5.1.1 # via librosa -decord==0.6.0 - # via -r requirements/test.in dill==0.3.8 # via # datasets @@ -276,7 +274,6 @@ numpy==1.26.4 # contourpy # cupy-cuda12x # datasets - # decord # einx # encodec # evaluate @@ -337,8 +334,10 @@ nvidia-nvjitlink-cu12==12.4.127 # torch nvidia-nvtx-cu12==12.4.127 # via torch -opencv-python-headless==4.10.0.84 - # via mistral-common +opencv-python-headless==4.11.0.86 + # via + # -r requirements/test.in + # mistral-common packaging==24.1 # via # accelerate diff --git a/setup.py b/setup.py index cf2acb20..b0cc2f48 100755 --- a/setup.py +++ b/setup.py @@ -684,7 +684,7 @@ setup( "fastsafetensors": ["fastsafetensors >= 0.1.10"], "runai": ["runai-model-streamer", "runai-model-streamer-s3", "boto3"], "audio": ["librosa", "soundfile"], # Required for audio processing - "video": ["decord"] # Required for video processing + "video": [] # Kept for backwards compatibility }, cmdclass=cmdclass, package_data=package_data, diff --git a/vllm/assets/video.py b/vllm/assets/video.py index e45e1a65..32b0b86b 100644 --- a/vllm/assets/video.py +++ b/vllm/assets/video.py @@ -10,8 +10,6 @@ import numpy.typing as npt from huggingface_hub import hf_hub_download from PIL import Image -from vllm.multimodal.video import sample_frames_from_video - from .base import get_cache_dir @@ -43,14 +41,19 @@ def video_to_ndarrays(path: str, num_frames: int = -1) -> npt.NDArray: 
     total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
     frames = []
-    for i in range(total_frames):
-        ret, frame = cap.read()
-        if ret:
-            frames.append(frame)
-    cap.release()
+
+    num_frames = num_frames if num_frames > 0 else total_frames
+    frame_indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)
+    for idx in range(total_frames):
+        ok = cap.grab()  # advance to the next frame without decoding it
+        if not ok:
+            break
+        if idx in frame_indices:  # decode only the sampled frames
+            ret, frame = cap.retrieve()
+            if ret:
+                frames.append(frame)
 
     frames = np.stack(frames)
-    frames = sample_frames_from_video(frames, num_frames)
     if len(frames) < num_frames:
         raise ValueError(f"Could not read enough frames from video file {path}"
                          f" (expected {num_frames} frames, got {len(frames)})")
diff --git a/vllm/multimodal/video.py b/vllm/multimodal/video.py
index 0b3d3f8c..f7c3f105 100644
--- a/vllm/multimodal/video.py
+++ b/vllm/multimodal/video.py
@@ -13,7 +13,7 @@ from PIL import Image
 from vllm.inputs.registry import InputContext
 from vllm.logger import init_logger
 from vllm.transformers_utils.processor import cached_get_video_processor
-from vllm.utils import PlaceholderModule, is_list_of
+from vllm.utils import is_list_of
 
 from .base import MediaIO, ModalityData
 from .image import ImageMediaIO, ImagePlugin
@@ -22,11 +22,6 @@ from .inputs import MultiModalKwargs, VideoItem
 if TYPE_CHECKING:
     from vllm.config import ModelConfig
 
-try:
-    import decord
-except ImportError:
-    decord = PlaceholderModule("decord")  # type: ignore[assignment]
-
 logger = init_logger(__name__)
 
 
@@ -117,6 +112,69 @@ def sample_frames_from_video(frames: npt.NDArray,
     return sampled_frames
 
 
+class VideoLoader:
+
+    @classmethod
+    def load_bytes(cls, data: bytes, num_frames: int = -1) -> npt.NDArray:
+        raise NotImplementedError
+
+
+class OpenCVVideoBackend(VideoLoader):
+
+    def get_cv2_video_api(self):
+        import cv2.videoio_registry as vr
+
+        api_pref = None
+        for backend in vr.getStreamBufferedBackends():
+            if not vr.hasBackend(backend):
+                continue
+            if not vr.isBackendBuiltIn(backend):
+                _, abi, api = vr.getStreamBufferedBackendPluginVersion(backend)
+                if abi < 1 or (abi == 1 and api < 2):
+                    continue
+            api_pref = backend
+            break
+        return api_pref
+
+    @classmethod
+    def load_bytes(cls, data: bytes, num_frames: int = -1) -> npt.NDArray:
+        import cv2
+
+        backend = cls().get_cv2_video_api()
+        cap = cv2.VideoCapture(BytesIO(data), backend, [])
+        if not cap.isOpened():
+            raise ValueError("Could not open video stream")
+
+        total_frames_num = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+        full_read = num_frames == -1 or total_frames_num < num_frames
+        if full_read:
+            frame_idx = list(range(0, total_frames_num))
+        else:
+            uniform_sampled_frames = np.linspace(0,
+                                                 total_frames_num - 1,
+                                                 num_frames,
+                                                 dtype=int)
+            frame_idx = uniform_sampled_frames.tolist()
+
+        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+        frames = np.empty((len(frame_idx), height, width, 3), dtype=np.uint8)
+
+        i = 0
+        for idx in range(total_frames_num):
+            ok = cap.grab()  # advance to the next frame without decoding it
+            if not ok:
+                break
+            if idx in frame_idx:  # decode only the sampled frames
+                ret, frame = cap.retrieve()
+                if ret:
+                    frames[i] = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+                    i += 1
+        # every sampled frame should have been decoded
+        assert i == len(frame_idx)
+        return frames
+
+
 class VideoMediaIO(MediaIO[npt.NDArray]):
 
     def __init__(
@@ -129,22 +187,10 @@ class VideoMediaIO(MediaIO[npt.NDArray]):
         self.image_io = image_io
         self.num_frames = num_frames
+        self.video_loader = OpenCVVideoBackend
 
     def load_bytes(self, data: bytes) -> npt.NDArray:
-        vr = decord.VideoReader(BytesIO(data), num_threads=1)
-        total_frame_num = len(vr)
-
-        num_frames = self.num_frames
-        if total_frame_num > num_frames:
-            uniform_sampled_frames = np.linspace(0,
-                                                 total_frame_num - 1,
-                                                 num_frames,
-                                                 dtype=int)
-            frame_idx = uniform_sampled_frames.tolist()
-        else:
-            frame_idx = list(range(0, total_frame_num))
-
-        return vr.get_batch(frame_idx).asnumpy()
+        return self.video_loader.load_bytes(data, self.num_frames)
 
     def load_base64(self, media_type: str, data: str) -> npt.NDArray:
         if media_type.lower() == "video/jpeg":
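
Note (not part of the patch): the sketch below illustrates the grab()/retrieve()
sampling pattern that both video_to_ndarrays() and OpenCVVideoBackend.load_bytes()
now rely on. cap.grab() advances the demuxer by one frame without decoding it,
and cap.retrieve() decodes only the frames whose indices were selected, so
uniformly sampling N frames avoids decompressing the rest of the video. This is
a minimal standalone version, not the patch's code: it reads from a file path
rather than an in-memory buffer, and the function name, path, and frame count
are illustrative assumptions.

    import cv2
    import numpy as np

    def sample_uniform_frames(path: str, num_frames: int) -> np.ndarray:
        """Decode num_frames uniformly spaced RGB frames from a video file."""
        cap = cv2.VideoCapture(path)
        if not cap.isOpened():
            raise ValueError(f"Could not open video file {path}")

        total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        # Uniformly spaced frame indices; clamped so short videos are fully read.
        wanted = set(
            np.linspace(0, total - 1, min(num_frames, total), dtype=int).tolist())

        frames = []
        for idx in range(total):
            if not cap.grab():  # advance one frame without decoding it
                break
            if idx in wanted:  # decode only the selected frames
                ok, frame = cap.retrieve()
                if ok:
                    frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        cap.release()

        if not frames:
            raise ValueError(f"Could not read any frames from {path}")
        return np.stack(frames)

    # Usage: sample_uniform_frames("example.mp4", 16) -> ndarray of shape (16, H, W, 3)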