[Misc] Allow using OpenCV as video IO fallback (#15055)
Signed-off-by: Isotr0py <2037008807@qq.com> Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
parent
b63bd14999
commit
4e5a0f6ae2
@ -30,6 +30,7 @@ msgspec
|
|||||||
gguf == 0.10.0
|
gguf == 0.10.0
|
||||||
importlib_metadata
|
importlib_metadata
|
||||||
mistral_common[opencv] >= 1.5.4
|
mistral_common[opencv] >= 1.5.4
|
||||||
|
opencv-python-headless >= 4.11.0 # required for video IO
|
||||||
pyyaml
|
pyyaml
|
||||||
six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that needs to be the latest version for python 3.12
|
six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that needs to be the latest version for python 3.12
|
||||||
setuptools>=74.1.1; python_version > '3.11' # Setuptools is used by triton, we need to ensure a modern version is installed for 3.12+ so that it does not try to import distutils, which was removed in 3.12
|
setuptools>=74.1.1; python_version > '3.11' # Setuptools is used by triton, we need to ensure a modern version is installed for 3.12+ so that it does not try to import distutils, which was removed in 3.12
|
||||||
|
@ -9,7 +9,6 @@ pytest-shard
|
|||||||
# testing utils
|
# testing utils
|
||||||
awscli
|
awscli
|
||||||
backoff # required for phi4mm test
|
backoff # required for phi4mm test
|
||||||
decord # required for video tests
|
|
||||||
einops # required for MPT, qwen-vl and Mamba
|
einops # required for MPT, qwen-vl and Mamba
|
||||||
httpx
|
httpx
|
||||||
librosa # required for audio tests
|
librosa # required for audio tests
|
||||||
@ -28,6 +27,7 @@ torchvision==0.21.0
|
|||||||
transformers_stream_generator # required for qwen-vl test
|
transformers_stream_generator # required for qwen-vl test
|
||||||
matplotlib # required for qwen-vl test
|
matplotlib # required for qwen-vl test
|
||||||
mistral_common[opencv] >= 1.5.4 # required for pixtral test
|
mistral_common[opencv] >= 1.5.4 # required for pixtral test
|
||||||
|
opencv-python-headless >= 4.11.0 # required for video test
|
||||||
datamodel_code_generator # required for minicpm3 test
|
datamodel_code_generator # required for minicpm3 test
|
||||||
lm-eval[api]==0.4.4 # required for model evaluation test
|
lm-eval[api]==0.4.4 # required for model evaluation test
|
||||||
transformers==4.50.3
|
transformers==4.50.3
|
||||||
|
@ -93,8 +93,6 @@ datasets==3.0.2
|
|||||||
# lm-eval
|
# lm-eval
|
||||||
decorator==5.1.1
|
decorator==5.1.1
|
||||||
# via librosa
|
# via librosa
|
||||||
decord==0.6.0
|
|
||||||
# via -r requirements/test.in
|
|
||||||
dill==0.3.8
|
dill==0.3.8
|
||||||
# via
|
# via
|
||||||
# datasets
|
# datasets
|
||||||
@ -276,7 +274,6 @@ numpy==1.26.4
|
|||||||
# contourpy
|
# contourpy
|
||||||
# cupy-cuda12x
|
# cupy-cuda12x
|
||||||
# datasets
|
# datasets
|
||||||
# decord
|
|
||||||
# einx
|
# einx
|
||||||
# encodec
|
# encodec
|
||||||
# evaluate
|
# evaluate
|
||||||
@ -337,8 +334,10 @@ nvidia-nvjitlink-cu12==12.4.127
|
|||||||
# torch
|
# torch
|
||||||
nvidia-nvtx-cu12==12.4.127
|
nvidia-nvtx-cu12==12.4.127
|
||||||
# via torch
|
# via torch
|
||||||
opencv-python-headless==4.10.0.84
|
opencv-python-headless==4.11.0.86
|
||||||
# via mistral-common
|
# via
|
||||||
|
# -r requirements/test.in
|
||||||
|
# mistral-common
|
||||||
packaging==24.1
|
packaging==24.1
|
||||||
# via
|
# via
|
||||||
# accelerate
|
# accelerate
|
||||||
|
2
setup.py
2
setup.py
@ -684,7 +684,7 @@ setup(
|
|||||||
"fastsafetensors": ["fastsafetensors >= 0.1.10"],
|
"fastsafetensors": ["fastsafetensors >= 0.1.10"],
|
||||||
"runai": ["runai-model-streamer", "runai-model-streamer-s3", "boto3"],
|
"runai": ["runai-model-streamer", "runai-model-streamer-s3", "boto3"],
|
||||||
"audio": ["librosa", "soundfile"], # Required for audio processing
|
"audio": ["librosa", "soundfile"], # Required for audio processing
|
||||||
"video": ["decord"] # Required for video processing
|
"video": [] # Kept for backwards compatibility
|
||||||
},
|
},
|
||||||
cmdclass=cmdclass,
|
cmdclass=cmdclass,
|
||||||
package_data=package_data,
|
package_data=package_data,
|
||||||
|
@ -10,8 +10,6 @@ import numpy.typing as npt
|
|||||||
from huggingface_hub import hf_hub_download
|
from huggingface_hub import hf_hub_download
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from vllm.multimodal.video import sample_frames_from_video
|
|
||||||
|
|
||||||
from .base import get_cache_dir
|
from .base import get_cache_dir
|
||||||
|
|
||||||
|
|
||||||
@ -43,14 +41,19 @@ def video_to_ndarrays(path: str, num_frames: int = -1) -> npt.NDArray:
|
|||||||
|
|
||||||
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
||||||
frames = []
|
frames = []
|
||||||
for i in range(total_frames):
|
|
||||||
ret, frame = cap.read()
|
num_frames = num_frames if num_frames > 0 else total_frames
|
||||||
|
frame_indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)
|
||||||
|
for idx in range(total_frames):
|
||||||
|
ok = cap.grab() # next img
|
||||||
|
if not ok:
|
||||||
|
break
|
||||||
|
if idx in frame_indices: # only decompress needed
|
||||||
|
ret, frame = cap.retrieve()
|
||||||
if ret:
|
if ret:
|
||||||
frames.append(frame)
|
frames.append(frame)
|
||||||
cap.release()
|
|
||||||
|
|
||||||
frames = np.stack(frames)
|
frames = np.stack(frames)
|
||||||
frames = sample_frames_from_video(frames, num_frames)
|
|
||||||
if len(frames) < num_frames:
|
if len(frames) < num_frames:
|
||||||
raise ValueError(f"Could not read enough frames from video file {path}"
|
raise ValueError(f"Could not read enough frames from video file {path}"
|
||||||
f" (expected {num_frames} frames, got {len(frames)})")
|
f" (expected {num_frames} frames, got {len(frames)})")
|
||||||
|
@ -13,7 +13,7 @@ from PIL import Image
|
|||||||
from vllm.inputs.registry import InputContext
|
from vllm.inputs.registry import InputContext
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.transformers_utils.processor import cached_get_video_processor
|
from vllm.transformers_utils.processor import cached_get_video_processor
|
||||||
from vllm.utils import PlaceholderModule, is_list_of
|
from vllm.utils import is_list_of
|
||||||
|
|
||||||
from .base import MediaIO, ModalityData
|
from .base import MediaIO, ModalityData
|
||||||
from .image import ImageMediaIO, ImagePlugin
|
from .image import ImageMediaIO, ImagePlugin
|
||||||
@ -22,11 +22,6 @@ from .inputs import MultiModalKwargs, VideoItem
|
|||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from vllm.config import ModelConfig
|
from vllm.config import ModelConfig
|
||||||
|
|
||||||
try:
|
|
||||||
import decord
|
|
||||||
except ImportError:
|
|
||||||
decord = PlaceholderModule("decord") # type: ignore[assignment]
|
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@ -117,6 +112,69 @@ def sample_frames_from_video(frames: npt.NDArray,
|
|||||||
return sampled_frames
|
return sampled_frames
|
||||||
|
|
||||||
|
|
||||||
|
class VideoLoader:
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def load_bytes(self, data: bytes, num_frames: int = -1) -> npt.NDArray:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
class OpenCVVideoBackend(VideoLoader):
|
||||||
|
|
||||||
|
def get_cv2_video_api(self):
|
||||||
|
import cv2.videoio_registry as vr
|
||||||
|
|
||||||
|
api_pref = None
|
||||||
|
for backend in vr.getStreamBufferedBackends():
|
||||||
|
if not vr.hasBackend(backend):
|
||||||
|
continue
|
||||||
|
if not vr.isBackendBuiltIn(backend):
|
||||||
|
_, abi, api = vr.getStreamBufferedBackendPluginVersion(backend)
|
||||||
|
if (abi < 1 or (abi == 1 and api < 2)):
|
||||||
|
continue
|
||||||
|
api_pref = backend
|
||||||
|
break
|
||||||
|
return api_pref
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def load_bytes(cls, data: bytes, num_frames: int = -1) -> npt.NDArray:
|
||||||
|
import cv2
|
||||||
|
|
||||||
|
backend = cls().get_cv2_video_api()
|
||||||
|
cap = cv2.VideoCapture(BytesIO(data), backend, [])
|
||||||
|
if not cap.isOpened():
|
||||||
|
raise ValueError("Could not open video stream")
|
||||||
|
|
||||||
|
total_frames_num = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
||||||
|
full_read = num_frames == -1 or total_frames_num < num_frames
|
||||||
|
if full_read:
|
||||||
|
frame_idx = list(range(0, total_frames_num))
|
||||||
|
else:
|
||||||
|
uniform_sampled_frames = np.linspace(0,
|
||||||
|
total_frames_num - 1,
|
||||||
|
num_frames,
|
||||||
|
dtype=int)
|
||||||
|
frame_idx = uniform_sampled_frames.tolist()
|
||||||
|
|
||||||
|
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
||||||
|
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
||||||
|
frames = np.empty((len(frame_idx), height, width, 3), dtype=np.uint8)
|
||||||
|
|
||||||
|
i = 0
|
||||||
|
for idx in range(total_frames_num):
|
||||||
|
ok = cap.grab() # next img
|
||||||
|
if not ok:
|
||||||
|
break
|
||||||
|
if idx in frame_idx: # only decompress needed
|
||||||
|
ret, frame = cap.retrieve()
|
||||||
|
if ret:
|
||||||
|
frames[i] = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
||||||
|
i += 1
|
||||||
|
# we expect all frames loaded
|
||||||
|
assert i == num_frames
|
||||||
|
return frames
|
||||||
|
|
||||||
|
|
||||||
class VideoMediaIO(MediaIO[npt.NDArray]):
|
class VideoMediaIO(MediaIO[npt.NDArray]):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -129,22 +187,10 @@ class VideoMediaIO(MediaIO[npt.NDArray]):
|
|||||||
|
|
||||||
self.image_io = image_io
|
self.image_io = image_io
|
||||||
self.num_frames = num_frames
|
self.num_frames = num_frames
|
||||||
|
self.video_loader = OpenCVVideoBackend
|
||||||
|
|
||||||
def load_bytes(self, data: bytes) -> npt.NDArray:
|
def load_bytes(self, data: bytes) -> npt.NDArray:
|
||||||
vr = decord.VideoReader(BytesIO(data), num_threads=1)
|
return self.video_loader.load_bytes(data, self.num_frames)
|
||||||
total_frame_num = len(vr)
|
|
||||||
|
|
||||||
num_frames = self.num_frames
|
|
||||||
if total_frame_num > num_frames:
|
|
||||||
uniform_sampled_frames = np.linspace(0,
|
|
||||||
total_frame_num - 1,
|
|
||||||
num_frames,
|
|
||||||
dtype=int)
|
|
||||||
frame_idx = uniform_sampled_frames.tolist()
|
|
||||||
else:
|
|
||||||
frame_idx = list(range(0, total_frame_num))
|
|
||||||
|
|
||||||
return vr.get_batch(frame_idx).asnumpy()
|
|
||||||
|
|
||||||
def load_base64(self, media_type: str, data: str) -> npt.NDArray:
|
def load_base64(self, media_type: str, data: str) -> npt.NDArray:
|
||||||
if media_type.lower() == "video/jpeg":
|
if media_type.lower() == "video/jpeg":
|
||||||
|
Loading…
x
Reference in New Issue
Block a user