[Misc] Allow using OpenCV as video IO fallback (#15055)
Signed-off-by: Isotr0py <2037008807@qq.com> Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
parent
b63bd14999
commit
4e5a0f6ae2
@ -30,6 +30,7 @@ msgspec
|
||||
gguf == 0.10.0
|
||||
importlib_metadata
|
||||
mistral_common[opencv] >= 1.5.4
|
||||
opencv-python-headless >= 4.11.0 # required for video IO
|
||||
pyyaml
|
||||
six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that needs to be the latest version for python 3.12
|
||||
setuptools>=74.1.1; python_version > '3.11' # Setuptools is used by triton, we need to ensure a modern version is installed for 3.12+ so that it does not try to import distutils, which was removed in 3.12
|
||||
|
@ -9,7 +9,6 @@ pytest-shard
|
||||
# testing utils
|
||||
awscli
|
||||
backoff # required for phi4mm test
|
||||
decord # required for video tests
|
||||
einops # required for MPT, qwen-vl and Mamba
|
||||
httpx
|
||||
librosa # required for audio tests
|
||||
@ -28,6 +27,7 @@ torchvision==0.21.0
|
||||
transformers_stream_generator # required for qwen-vl test
|
||||
matplotlib # required for qwen-vl test
|
||||
mistral_common[opencv] >= 1.5.4 # required for pixtral test
|
||||
opencv-python-headless >= 4.11.0 # required for video test
|
||||
datamodel_code_generator # required for minicpm3 test
|
||||
lm-eval[api]==0.4.4 # required for model evaluation test
|
||||
transformers==4.50.3
|
||||
|
@ -93,8 +93,6 @@ datasets==3.0.2
|
||||
# lm-eval
|
||||
decorator==5.1.1
|
||||
# via librosa
|
||||
decord==0.6.0
|
||||
# via -r requirements/test.in
|
||||
dill==0.3.8
|
||||
# via
|
||||
# datasets
|
||||
@ -276,7 +274,6 @@ numpy==1.26.4
|
||||
# contourpy
|
||||
# cupy-cuda12x
|
||||
# datasets
|
||||
# decord
|
||||
# einx
|
||||
# encodec
|
||||
# evaluate
|
||||
@ -337,8 +334,10 @@ nvidia-nvjitlink-cu12==12.4.127
|
||||
# torch
|
||||
nvidia-nvtx-cu12==12.4.127
|
||||
# via torch
|
||||
opencv-python-headless==4.10.0.84
|
||||
# via mistral-common
|
||||
opencv-python-headless==4.11.0.86
|
||||
# via
|
||||
# -r requirements/test.in
|
||||
# mistral-common
|
||||
packaging==24.1
|
||||
# via
|
||||
# accelerate
|
||||
|
2
setup.py
2
setup.py
@ -684,7 +684,7 @@ setup(
|
||||
"fastsafetensors": ["fastsafetensors >= 0.1.10"],
|
||||
"runai": ["runai-model-streamer", "runai-model-streamer-s3", "boto3"],
|
||||
"audio": ["librosa", "soundfile"], # Required for audio processing
|
||||
"video": ["decord"] # Required for video processing
|
||||
"video": [] # Kept for backwards compatibility
|
||||
},
|
||||
cmdclass=cmdclass,
|
||||
package_data=package_data,
|
||||
|
@ -10,8 +10,6 @@ import numpy.typing as npt
|
||||
from huggingface_hub import hf_hub_download
|
||||
from PIL import Image
|
||||
|
||||
from vllm.multimodal.video import sample_frames_from_video
|
||||
|
||||
from .base import get_cache_dir
|
||||
|
||||
|
||||
@ -43,14 +41,19 @@ def video_to_ndarrays(path: str, num_frames: int = -1) -> npt.NDArray:
|
||||
|
||||
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
||||
frames = []
|
||||
for i in range(total_frames):
|
||||
ret, frame = cap.read()
|
||||
|
||||
num_frames = num_frames if num_frames > 0 else total_frames
|
||||
frame_indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)
|
||||
for idx in range(total_frames):
|
||||
ok = cap.grab() # next img
|
||||
if not ok:
|
||||
break
|
||||
if idx in frame_indices: # only decompress needed
|
||||
ret, frame = cap.retrieve()
|
||||
if ret:
|
||||
frames.append(frame)
|
||||
cap.release()
|
||||
|
||||
frames = np.stack(frames)
|
||||
frames = sample_frames_from_video(frames, num_frames)
|
||||
if len(frames) < num_frames:
|
||||
raise ValueError(f"Could not read enough frames from video file {path}"
|
||||
f" (expected {num_frames} frames, got {len(frames)})")
|
||||
|
@ -13,7 +13,7 @@ from PIL import Image
|
||||
from vllm.inputs.registry import InputContext
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.processor import cached_get_video_processor
|
||||
from vllm.utils import PlaceholderModule, is_list_of
|
||||
from vllm.utils import is_list_of
|
||||
|
||||
from .base import MediaIO, ModalityData
|
||||
from .image import ImageMediaIO, ImagePlugin
|
||||
@ -22,11 +22,6 @@ from .inputs import MultiModalKwargs, VideoItem
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import ModelConfig
|
||||
|
||||
try:
|
||||
import decord
|
||||
except ImportError:
|
||||
decord = PlaceholderModule("decord") # type: ignore[assignment]
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@ -117,6 +112,69 @@ def sample_frames_from_video(frames: npt.NDArray,
|
||||
return sampled_frames
|
||||
|
||||
|
||||
class VideoLoader:
|
||||
|
||||
@classmethod
|
||||
def load_bytes(self, data: bytes, num_frames: int = -1) -> npt.NDArray:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class OpenCVVideoBackend(VideoLoader):
|
||||
|
||||
def get_cv2_video_api(self):
|
||||
import cv2.videoio_registry as vr
|
||||
|
||||
api_pref = None
|
||||
for backend in vr.getStreamBufferedBackends():
|
||||
if not vr.hasBackend(backend):
|
||||
continue
|
||||
if not vr.isBackendBuiltIn(backend):
|
||||
_, abi, api = vr.getStreamBufferedBackendPluginVersion(backend)
|
||||
if (abi < 1 or (abi == 1 and api < 2)):
|
||||
continue
|
||||
api_pref = backend
|
||||
break
|
||||
return api_pref
|
||||
|
||||
@classmethod
|
||||
def load_bytes(cls, data: bytes, num_frames: int = -1) -> npt.NDArray:
|
||||
import cv2
|
||||
|
||||
backend = cls().get_cv2_video_api()
|
||||
cap = cv2.VideoCapture(BytesIO(data), backend, [])
|
||||
if not cap.isOpened():
|
||||
raise ValueError("Could not open video stream")
|
||||
|
||||
total_frames_num = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
||||
full_read = num_frames == -1 or total_frames_num < num_frames
|
||||
if full_read:
|
||||
frame_idx = list(range(0, total_frames_num))
|
||||
else:
|
||||
uniform_sampled_frames = np.linspace(0,
|
||||
total_frames_num - 1,
|
||||
num_frames,
|
||||
dtype=int)
|
||||
frame_idx = uniform_sampled_frames.tolist()
|
||||
|
||||
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
||||
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
||||
frames = np.empty((len(frame_idx), height, width, 3), dtype=np.uint8)
|
||||
|
||||
i = 0
|
||||
for idx in range(total_frames_num):
|
||||
ok = cap.grab() # next img
|
||||
if not ok:
|
||||
break
|
||||
if idx in frame_idx: # only decompress needed
|
||||
ret, frame = cap.retrieve()
|
||||
if ret:
|
||||
frames[i] = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
||||
i += 1
|
||||
# we expect all frames loaded
|
||||
assert i == num_frames
|
||||
return frames
|
||||
|
||||
|
||||
class VideoMediaIO(MediaIO[npt.NDArray]):
|
||||
|
||||
def __init__(
|
||||
@ -129,22 +187,10 @@ class VideoMediaIO(MediaIO[npt.NDArray]):
|
||||
|
||||
self.image_io = image_io
|
||||
self.num_frames = num_frames
|
||||
self.video_loader = OpenCVVideoBackend
|
||||
|
||||
def load_bytes(self, data: bytes) -> npt.NDArray:
|
||||
vr = decord.VideoReader(BytesIO(data), num_threads=1)
|
||||
total_frame_num = len(vr)
|
||||
|
||||
num_frames = self.num_frames
|
||||
if total_frame_num > num_frames:
|
||||
uniform_sampled_frames = np.linspace(0,
|
||||
total_frame_num - 1,
|
||||
num_frames,
|
||||
dtype=int)
|
||||
frame_idx = uniform_sampled_frames.tolist()
|
||||
else:
|
||||
frame_idx = list(range(0, total_frame_num))
|
||||
|
||||
return vr.get_batch(frame_idx).asnumpy()
|
||||
return self.video_loader.load_bytes(data, self.num_frames)
|
||||
|
||||
def load_base64(self, media_type: str, data: str) -> npt.NDArray:
|
||||
if media_type.lower() == "video/jpeg":
|
||||
|
Loading…
x
Reference in New Issue
Block a user