[Misc] Allow using OpenCV as video IO fallback (#15055)

Signed-off-by: Isotr0py <2037008807@qq.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Author: Isotr0py
Date: 2025-04-01 23:55:13 +08:00 (committed via GitHub)
Commit: 4e5a0f6ae2 (parent: b63bd14999)
6 changed files with 84 additions and 35 deletions

requirements/common.txt

@@ -30,6 +30,7 @@ msgspec
 gguf == 0.10.0
 importlib_metadata
 mistral_common[opencv] >= 1.5.4
+opencv-python-headless >= 4.11.0 # required for video IO
 pyyaml
 six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that needs to be the latest version for python 3.12
 setuptools>=74.1.1; python_version > '3.11' # Setuptools is used by triton, we need to ensure a modern version is installed for 3.12+ so that it does not try to import distutils, which was removed in 3.12

requirements/test.in

@@ -9,7 +9,6 @@ pytest-shard
 # testing utils
 awscli
 backoff # required for phi4mm test
-decord # required for video tests
 einops # required for MPT, qwen-vl and Mamba
 httpx
 librosa # required for audio tests
@@ -28,6 +27,7 @@ torchvision==0.21.0
 transformers_stream_generator # required for qwen-vl test
 matplotlib # required for qwen-vl test
 mistral_common[opencv] >= 1.5.4 # required for pixtral test
+opencv-python-headless >= 4.11.0 # required for video test
 datamodel_code_generator # required for minicpm3 test
 lm-eval[api]==0.4.4 # required for model evaluation test
 transformers==4.50.3

requirements/test.txt

@@ -93,8 +93,6 @@ datasets==3.0.2
     #   lm-eval
 decorator==5.1.1
     # via librosa
-decord==0.6.0
-    # via -r requirements/test.in
 dill==0.3.8
     # via
     #   datasets
@@ -276,7 +274,6 @@ numpy==1.26.4
     #   contourpy
     #   cupy-cuda12x
     #   datasets
-    #   decord
     #   einx
     #   encodec
     #   evaluate
@@ -337,8 +334,10 @@ nvidia-nvjitlink-cu12==12.4.127
     #   torch
 nvidia-nvtx-cu12==12.4.127
     # via torch
-opencv-python-headless==4.10.0.84
-    # via mistral-common
+opencv-python-headless==4.11.0.86
+    # via
+    #   -r requirements/test.in
+    #   mistral-common
 packaging==24.1
     # via
     #   accelerate

setup.py

@@ -684,7 +684,7 @@ setup(
         "fastsafetensors": ["fastsafetensors >= 0.1.10"],
         "runai": ["runai-model-streamer", "runai-model-streamer-s3", "boto3"],
         "audio": ["librosa", "soundfile"],  # Required for audio processing
-        "video": ["decord"]  # Required for video processing
+        "video": []  # Kept for backwards compatibility
     },
     cmdclass=cmdclass,
     package_data=package_data,
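Because the "video" extra is now an empty list, pip install vllm[video] keeps working but no longer pulls in decord; OpenCV comes in through the base requirements instead. A minimal sketch for checking which extras an installed wheel still advertises (assumes vllm is installed; nothing here is part of this diff):

    from importlib.metadata import metadata

    # The "video" extra should still be listed even though it maps
    # to no additional packages.
    extras = metadata("vllm").get_all("Provides-Extra")
    print(extras)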

vllm/assets/video.py

@@ -10,8 +10,6 @@ import numpy.typing as npt
 from huggingface_hub import hf_hub_download
 from PIL import Image

-from vllm.multimodal.video import sample_frames_from_video
-
 from .base import get_cache_dir
@@ -43,14 +41,19 @@ def video_to_ndarrays(path: str, num_frames: int = -1) -> npt.NDArray:
     total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
     frames = []
-    for i in range(total_frames):
-        ret, frame = cap.read()
-        if ret:
-            frames.append(frame)
-    cap.release()
+    num_frames = num_frames if num_frames > 0 else total_frames
+    frame_indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)
+    for idx in range(total_frames):
+        ok = cap.grab()  # advance to the next frame without decoding it
+        if not ok:
+            break
+        if idx in frame_indices:  # decode only the sampled frames
+            ret, frame = cap.retrieve()
+            if ret:
+                frames.append(frame)

     frames = np.stack(frames)
-    frames = sample_frames_from_video(frames, num_frames)
     if len(frames) < num_frames:
         raise ValueError(f"Could not read enough frames from video file {path}"
                          f" (expected {num_frames} frames, got {len(frames)})")

vllm/multimodal/video.py

@@ -13,7 +13,7 @@ from PIL import Image
 from vllm.inputs.registry import InputContext
 from vllm.logger import init_logger
 from vllm.transformers_utils.processor import cached_get_video_processor
-from vllm.utils import PlaceholderModule, is_list_of
+from vllm.utils import is_list_of

 from .base import MediaIO, ModalityData
 from .image import ImageMediaIO, ImagePlugin
@@ -22,11 +22,6 @@ from .inputs import MultiModalKwargs, VideoItem
 if TYPE_CHECKING:
     from vllm.config import ModelConfig

-try:
-    import decord
-except ImportError:
-    decord = PlaceholderModule("decord")  # type: ignore[assignment]
-
 logger = init_logger(__name__)
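The deleted block relied on vLLM's PlaceholderModule, which lets the import succeed when decord is absent and defers the ImportError until the module is actually used. A simplified reimplementation of that pattern (not vLLM's actual class) looks roughly like this:

    class PlaceholderModule:
        """Stand-in object that raises only when the missing module is used."""

        def __init__(self, name: str) -> None:
            self._name = name

        def __getattr__(self, attr: str):
            raise ImportError(f"Module '{self._name}' is required but not installed")


    try:
        import decord
    except ImportError:
        decord = PlaceholderModule("decord")  # error surfaces on first attribute access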
@@ -117,6 +112,69 @@ def sample_frames_from_video(frames: npt.NDArray,
     return sampled_frames


+class VideoLoader:
+
+    @classmethod
+    def load_bytes(cls, data: bytes, num_frames: int = -1) -> npt.NDArray:
+        raise NotImplementedError
+
+
+class OpenCVVideoBackend(VideoLoader):
+
+    def get_cv2_video_api(self):
+        import cv2.videoio_registry as vr
+
+        api_pref = None
+        for backend in vr.getStreamBufferedBackends():
+            if not vr.hasBackend(backend):
+                continue
+            if not vr.isBackendBuiltIn(backend):
+                _, abi, api = vr.getStreamBufferedBackendPluginVersion(backend)
+                if abi < 1 or (abi == 1 and api < 2):
+                    continue
+            api_pref = backend
+            break
+        return api_pref
+
+    @classmethod
+    def load_bytes(cls, data: bytes, num_frames: int = -1) -> npt.NDArray:
+        import cv2
+
+        backend = cls().get_cv2_video_api()
+        cap = cv2.VideoCapture(BytesIO(data), backend, [])
+        if not cap.isOpened():
+            raise ValueError("Could not open video stream")
+
+        total_frames_num = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+        full_read = num_frames == -1 or total_frames_num < num_frames
+        if full_read:
+            num_frames = total_frames_num
+            frame_idx = list(range(0, total_frames_num))
+        else:
+            uniform_sampled_frames = np.linspace(0,
+                                                 total_frames_num - 1,
+                                                 num_frames,
+                                                 dtype=int)
+            frame_idx = uniform_sampled_frames.tolist()
+
+        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+        frames = np.empty((len(frame_idx), height, width, 3), dtype=np.uint8)
+
+        i = 0
+        for idx in range(total_frames_num):
+            ok = cap.grab()  # advance to the next frame without decoding it
+            if not ok:
+                break
+            if idx in frame_idx:  # decode only the sampled frames
+                ret, frame = cap.retrieve()
+                if ret:
+                    frames[i] = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+                    i += 1
+
+        # all sampled frames must have been decoded successfully
+        assert i == num_frames
+        return frames


 class VideoMediaIO(MediaIO[npt.NDArray]):

     def __init__(
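get_cv2_video_api probes OpenCV's videoio registry for backends that can read from an in-memory stream, skipping plugins whose ABI/API version is too old for buffered reads; this is why the requirement is opencv-python-headless >= 4.11.0, where the stream-buffered registry queries and the stream-based VideoCapture constructor are available. The registry can also be inspected directly, as in this short sketch:

    import cv2.videoio_registry as vr

    # Enumerate backends that can read video from an in-memory stream and
    # report whether each one is built in and actually available.
    for backend in vr.getStreamBufferedBackends():
        print(vr.getBackendName(backend),
              "built-in" if vr.isBackendBuiltIn(backend) else "plugin",
              "available" if vr.hasBackend(backend) else "missing")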
@@ -129,22 +187,10 @@ class VideoMediaIO(MediaIO[npt.NDArray]):
         self.image_io = image_io
         self.num_frames = num_frames
+        self.video_loader = OpenCVVideoBackend

     def load_bytes(self, data: bytes) -> npt.NDArray:
-        vr = decord.VideoReader(BytesIO(data), num_threads=1)
-        total_frame_num = len(vr)
-
-        num_frames = self.num_frames
-        if total_frame_num > num_frames:
-            uniform_sampled_frames = np.linspace(0,
-                                                 total_frame_num - 1,
-                                                 num_frames,
-                                                 dtype=int)
-            frame_idx = uniform_sampled_frames.tolist()
-        else:
-            frame_idx = list(range(0, total_frame_num))
-
-        return vr.get_batch(frame_idx).asnumpy()
+        return self.video_loader.load_bytes(data, self.num_frames)

     def load_base64(self, media_type: str, data: str) -> npt.NDArray:
         if media_type.lower() == "video/jpeg":
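With the backend wired in, decoding an in-memory video no longer touches decord at any point. A hedged usage sketch (the input file path is illustrative; the module path and classmethod signature are taken from this diff):

    from pathlib import Path

    from vllm.multimodal.video import OpenCVVideoBackend

    data = Path("sample.mp4").read_bytes()  # illustrative input file
    frames = OpenCVVideoBackend.load_bytes(data, num_frames=16)
    print(frames.shape)  # (16, height, width, 3), RGB, dtype=uint8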