Online video support for VLMs (#10020)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: litianjian <litianjian@bytedance.com>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
Author: litianjian
Date: 2024-11-08 04:25:59 +08:00 (committed by GitHub)
Parent: 97b8475beb
Commit: 28b2877d30
12 changed files with 598 additions and 31 deletions

View File

@@ -116,6 +116,7 @@ autodoc_mock_imports = [
     "soundfile",
     "gguf",
     "lark",
+    "decord",
 ]
 for mock_target in autodoc_mock_imports:

View File

@@ -8,6 +8,7 @@ pytest-shard
 # testing utils
 awscli
+decord # required for video tests
 einops # required for MPT, qwen-vl and Mamba
 httpx
 librosa # required for audio tests
@@ -15,12 +16,13 @@ opencv-python # required for video tests
 peft
 requests
 ray[adag]==2.35
-sentence-transformers # required for embedding
+sentence-transformers # required for embedding tests
-soundfile # required for audio test
+soundfile # required for audio tests
 timm # required for internvl test
 torch==2.5.1
 transformers_stream_generator # required for qwen-vl test
 matplotlib # required for qwen-vl test
+mistral_common[opencv] >= 1.4.4 # required for pixtral test
 datamodel_code_generator # required for minicpm3 test
 lm-eval[api]==0.4.4 # required for model evaluation test

View File

@@ -1,8 +1,8 @@
 #
-# This file is autogenerated by pip-compile with Python 3.12
+# This file is autogenerated by pip-compile with Python 3.9
 # by the following command:
 #
-#    pip-compile --output-file=requirements-test.txt requirements-test.in
+#    pip-compile requirements-test.in
 #
 absl-py==2.1.0
     # via rouge-score
@@ -28,6 +28,10 @@ anyio==4.6.2.post1
     # via httpx
 argcomplete==3.5.1
     # via datamodel-code-generator
+async-timeout==4.0.3
+    # via
+    #   aiohttp
+    #   redis
 attrs==24.2.0
     # via
     #   aiohttp
@@ -90,6 +94,8 @@ datasets==3.0.2
     #   lm-eval
 decorator==5.1.1
     # via librosa
+decord==0.6.0
+    # via -r requirements-test.in
 dill==0.3.8
     # via
     #   datasets
@@ -106,6 +112,10 @@ email-validator==2.2.0
     # via pydantic
 evaluate==0.4.3
     # via lm-eval
+exceptiongroup==1.2.2
+    # via
+    #   anyio
+    #   pytest
 fastrlock==0.8.2
     # via cupy-cuda12x
 filelock==3.16.1
@@ -156,6 +166,8 @@ idna==3.10
     #   httpx
     #   requests
    #   yarl
+importlib-resources==6.4.5
+    # via matplotlib
 inflect==5.6.2
     # via datamodel-code-generator
 iniconfig==2.0.0
@@ -178,7 +190,9 @@ joblib==1.4.2
 jsonlines==4.0.0
     # via lm-eval
 jsonschema==4.23.0
-    # via ray
+    # via
+    #   mistral-common
+    #   ray
 jsonschema-specifications==2024.10.1
     # via jsonschema
 kiwisolver==1.4.7
@@ -204,6 +218,10 @@ mbstrdecoder==1.1.3
     #   dataproperty
     #   pytablewriter
     #   typepy
+mistral-common[opencv]==1.4.4
+    # via
+    #   -r requirements-test.in
+    #   mistral-common
 more-itertools==10.5.0
     # via lm-eval
 mpmath==1.3.0
@@ -238,12 +256,15 @@ numpy==1.26.4
     #   contourpy
     #   cupy-cuda12x
     #   datasets
+    #   decord
     #   evaluate
     #   librosa
     #   matplotlib
+    #   mistral-common
     #   numba
     #   numexpr
     #   opencv-python
+    #   opencv-python-headless
     #   pandas
     #   peft
     #   rouge-score
@@ -288,6 +309,8 @@ nvidia-nvtx-cu12==12.4.127
     # via torch
 opencv-python==4.10.0.84
     # via -r requirements-test.in
+opencv-python-headless==4.10.0.84
+    # via mistral-common
 packaging==24.1
     # via
     #   accelerate
@@ -317,9 +340,10 @@ peft==0.13.2
     # via
     #   -r requirements-test.in
     #   lm-eval
-pillow==11.0.0
+pillow==10.4.0
     # via
     #   matplotlib
+    #   mistral-common
     #   sentence-transformers
     #   torchvision
 platformdirs==4.3.6
@@ -354,7 +378,9 @@ pybind11==2.13.6
 pycparser==2.22
     # via cffi
 pydantic[email]==2.9.2
-    # via datamodel-code-generator
+    # via
+    #   datamodel-code-generator
+    #   mistral-common
 pydantic-core==2.23.4
     # via pydantic
 pyparsing==3.2.0
@@ -420,6 +446,7 @@ requests==2.32.3
     #   evaluate
     #   huggingface-hub
     #   lm-eval
+    #   mistral-common
     #   pooch
     #   ray
     #   tiktoken
@@ -456,6 +483,8 @@ scipy==1.13.1
     #   sentence-transformers
 sentence-transformers==3.2.1
     # via -r requirements-test.in
+sentencepiece==0.2.0
+    # via mistral-common
 six==1.16.0
     # via
     #   python-dateutil
@@ -486,12 +515,20 @@ tensorizer==2.9.0
     # via -r requirements-test.in
 threadpoolctl==3.5.0
     # via scikit-learn
-tiktoken==0.8.0
+tiktoken==0.7.0
-    # via lm-eval
+    # via
+    #   lm-eval
+    #   mistral-common
 timm==1.0.11
     # via -r requirements-test.in
 tokenizers==0.20.1
     # via transformers
+toml==0.10.2
+    # via datamodel-code-generator
+tomli==2.0.2
+    # via
+    #   black
+    #   pytest
 torch==2.5.1
     # via
     #   -r requirements-test.in
@@ -535,8 +572,12 @@ typepy[datetime]==1.3.2
     #   tabledata
 typing-extensions==4.12.2
     # via
+    #   anyio
+    #   black
     #   huggingface-hub
     #   librosa
+    #   mistral-common
+    #   multidict
     #   pydantic
     #   pydantic-core
     #   torch
@@ -554,6 +595,8 @@ xxhash==3.5.0
     #   evaluate
 yarl==1.17.1
     # via aiohttp
+zipp==3.20.2
+    # via importlib-resources
 zstandard==0.23.0
     # via lm-eval

View File

@@ -554,7 +554,8 @@ setup(
     ext_modules=ext_modules,
     extras_require={
         "tensorizer": ["tensorizer>=2.9.0"],
-        "audio": ["librosa", "soundfile"]  # Required for audio processing
+        "audio": ["librosa", "soundfile"],  # Required for audio processing
+        "video": ["decord"]  # Required for video processing
     },
     cmdclass={"build_ext": cmake_build_ext} if len(ext_modules) > 0 else {},
     package_data=package_data,
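The new extra only declares the dependency; at runtime the import guard in vllm.multimodal.utils (shown later in this diff) decides whether video inputs can be handled. A minimal sketch of that check, assuming the hypothetical helper name has_video_support and installation via pip install vllm[video]:

```python
# Minimal sketch; has_video_support is a hypothetical helper, not part of
# this diff. Installing the "video" extra makes both imports below succeed.
def has_video_support() -> bool:
    try:
        import cv2  # noqa: F401  # from opencv-python
        import decord  # noqa: F401  # from the new "video" extra
    except ImportError:
        return False
    return True


if not has_video_support():
    print("Install the optional video dependencies: pip install vllm[video]")
```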

View File

@@ -0,0 +1,345 @@
from typing import Dict, List

import openai
import pytest
import pytest_asyncio

from vllm.multimodal.utils import encode_video_base64, fetch_video

from ...utils import RemoteOpenAIServer

MODEL_NAME = "llava-hf/llava-onevision-qwen2-0.5b-ov-hf"
MAXIMUM_VIDEOS = 4

TEST_VIDEO_URLS = [
    "http://commondatastorage.googleapis.com/gtv-videos-bucket/sample/BigBuckBunny.mp4",
    "http://commondatastorage.googleapis.com/gtv-videos-bucket/sample/ElephantsDream.mp4",
    "http://commondatastorage.googleapis.com/gtv-videos-bucket/sample/ForBiggerBlazes.mp4",
    "http://commondatastorage.googleapis.com/gtv-videos-bucket/sample/ForBiggerFun.mp4",
]


@pytest.fixture(scope="module")
def server():
    args = [
        "--task",
        "generate",
        "--dtype",
        "bfloat16",
        "--max-model-len",
        "32768",
        "--max-num-seqs",
        "2",
        "--enforce-eager",
        "--trust-remote-code",
        "--limit-mm-per-prompt",
        f"video={MAXIMUM_VIDEOS}",
    ]

    with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
        yield remote_server


@pytest_asyncio.fixture
async def client(server):
    async with server.get_async_client() as async_client:
        yield async_client


@pytest.fixture(scope="session")
def base64_encoded_video() -> Dict[str, str]:
    return {
        video_url: encode_video_base64(fetch_video(video_url))
        for video_url in TEST_VIDEO_URLS
    }


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("video_url", TEST_VIDEO_URLS)
async def test_single_chat_session_video(client: openai.AsyncOpenAI,
                                         model_name: str, video_url: str):
    messages = [{
        "role":
        "user",
        "content": [
            {
                "type": "video_url",
                "video_url": {
                    "url": video_url
                }
            },
            {
                "type": "text",
                "text": "What's in this video?"
            },
        ],
    }]

    # test single completion
    chat_completion = await client.chat.completions.create(
        model=model_name,
        messages=messages,
        max_completion_tokens=10,
        logprobs=True,
        top_logprobs=5)
    assert len(chat_completion.choices) == 1

    choice = chat_completion.choices[0]
    assert choice.finish_reason == "length"
    assert chat_completion.usage == openai.types.CompletionUsage(
        completion_tokens=10, prompt_tokens=6299, total_tokens=6309)

    message = choice.message
    message = chat_completion.choices[0].message
    assert message.content is not None and len(message.content) >= 10
    assert message.role == "assistant"
    messages.append({"role": "assistant", "content": message.content})

    # test multi-turn dialogue
    messages.append({"role": "user", "content": "express your result in json"})
    chat_completion = await client.chat.completions.create(
        model=model_name,
        messages=messages,
        max_completion_tokens=10,
    )
    message = chat_completion.choices[0].message
    assert message.content is not None and len(message.content) >= 0


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("video_url", TEST_VIDEO_URLS)
async def test_single_chat_session_video_beamsearch(client: openai.AsyncOpenAI,
                                                    model_name: str,
                                                    video_url: str):
    messages = [{
        "role":
        "user",
        "content": [
            {
                "type": "video_url",
                "video_url": {
                    "url": video_url
                }
            },
            {
                "type": "text",
                "text": "What's in this video?"
            },
        ],
    }]

    chat_completion = await client.chat.completions.create(
        model=model_name,
        messages=messages,
        n=2,
        max_completion_tokens=10,
        logprobs=True,
        top_logprobs=5,
        extra_body=dict(use_beam_search=True))
    assert len(chat_completion.choices) == 2
    assert chat_completion.choices[
        0].message.content != chat_completion.choices[1].message.content


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("video_url", TEST_VIDEO_URLS)
async def test_single_chat_session_video_base64encoded(
        client: openai.AsyncOpenAI, model_name: str, video_url: str,
        base64_encoded_video: Dict[str, str]):

    messages = [{
        "role":
        "user",
        "content": [
            {
                "type": "video_url",
                "video_url": {
                    "url":
                    f"data:video/jpeg;base64,{base64_encoded_video[video_url]}"
                }
            },
            {
                "type": "text",
                "text": "What's in this video?"
            },
        ],
    }]

    # test single completion
    chat_completion = await client.chat.completions.create(
        model=model_name,
        messages=messages,
        max_completion_tokens=10,
        logprobs=True,
        top_logprobs=5)
    assert len(chat_completion.choices) == 1

    choice = chat_completion.choices[0]
    assert choice.finish_reason == "length"
    assert chat_completion.usage == openai.types.CompletionUsage(
        completion_tokens=10, prompt_tokens=6299, total_tokens=6309)

    message = choice.message
    message = chat_completion.choices[0].message
    assert message.content is not None and len(message.content) >= 10
    assert message.role == "assistant"
    messages.append({"role": "assistant", "content": message.content})

    # test multi-turn dialogue
    messages.append({"role": "user", "content": "express your result in json"})
    chat_completion = await client.chat.completions.create(
        model=model_name,
        messages=messages,
        max_completion_tokens=10,
    )
    message = chat_completion.choices[0].message
    assert message.content is not None and len(message.content) >= 0


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("video_url", TEST_VIDEO_URLS)
async def test_single_chat_session_video_base64encoded_beamsearch(
        client: openai.AsyncOpenAI, model_name: str, video_url: str,
        base64_encoded_video: Dict[str, str]):

    messages = [{
        "role":
        "user",
        "content": [
            {
                "type": "video_url",
                "video_url": {
                    "url":
                    f"data:video/jpeg;base64,{base64_encoded_video[video_url]}"
                }
            },
            {
                "type": "text",
                "text": "What's in this video?"
            },
        ],
    }]
    chat_completion = await client.chat.completions.create(
        model=model_name,
        messages=messages,
        n=2,
        max_completion_tokens=10,
        extra_body=dict(use_beam_search=True))
    assert len(chat_completion.choices) == 2
    assert chat_completion.choices[
        0].message.content != chat_completion.choices[1].message.content


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("video_url", TEST_VIDEO_URLS)
async def test_chat_streaming_video(client: openai.AsyncOpenAI,
                                    model_name: str, video_url: str):
    messages = [{
        "role":
        "user",
        "content": [
            {
                "type": "video_url",
                "video_url": {
                    "url": video_url
                }
            },
            {
                "type": "text",
                "text": "What's in this video?"
            },
        ],
    }]

    # test single completion
    chat_completion = await client.chat.completions.create(
        model=model_name,
        messages=messages,
        max_completion_tokens=10,
        temperature=0.0,
    )
    output = chat_completion.choices[0].message.content
    stop_reason = chat_completion.choices[0].finish_reason

    # test streaming
    stream = await client.chat.completions.create(
        model=model_name,
        messages=messages,
        max_completion_tokens=10,
        temperature=0.0,
        stream=True,
    )
    chunks: List[str] = []
    finish_reason_count = 0
    async for chunk in stream:
        delta = chunk.choices[0].delta
        if delta.role:
            assert delta.role == "assistant"
        if delta.content:
            chunks.append(delta.content)
        if chunk.choices[0].finish_reason is not None:
            finish_reason_count += 1

    # finish reason should only return in last block
    assert finish_reason_count == 1
    assert chunk.choices[0].finish_reason == stop_reason
    assert delta.content
    assert "".join(chunks) == output


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize(
    "video_urls",
    [TEST_VIDEO_URLS[:i] for i in range(2, len(TEST_VIDEO_URLS))])
async def test_multi_video_input(client: openai.AsyncOpenAI, model_name: str,
                                 video_urls: List[str]):

    messages = [{
        "role":
        "user",
        "content": [
            *({
                "type": "video_url",
                "video_url": {
                    "url": video_url
                }
            } for video_url in video_urls),
            {
                "type": "text",
                "text": "What's in this video?"
            },
        ],
    }]

    if len(video_urls) > MAXIMUM_VIDEOS:
        with pytest.raises(openai.BadRequestError):  # test multi-video input
            await client.chat.completions.create(
                model=model_name,
                messages=messages,
                max_completion_tokens=10,
                temperature=0.0,
            )

        # the server should still work afterwards
        completion = await client.completions.create(
            model=model_name,
            prompt=[0, 0, 0, 0, 0],
            max_tokens=5,
            temperature=0.0,
        )
        completion = completion.choices[0].text
        assert completion is not None and len(completion) >= 0
    else:
        chat_completion = await client.chat.completions.create(
            model=model_name,
            messages=messages,
            max_completion_tokens=10,
            temperature=0.0,
        )
        message = chat_completion.choices[0].message
        assert message.content is not None and len(message.content) >= 0
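Outside the test suite, the same request shape works against any running vLLM OpenAI-compatible server. A hedged standalone sketch, assuming a server on localhost:8000 serving the model used above; the prompt and sampling settings are illustrative:

```python
import openai

from vllm.multimodal.utils import encode_video_base64, fetch_video

client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

video_url = ("http://commondatastorage.googleapis.com/gtv-videos-bucket/"
             "sample/BigBuckBunny.mp4")

# Option 1: let the server fetch the video itself.
remote_part = {"type": "video_url", "video_url": {"url": video_url}}

# Option 2: inline sampled frames as a base64 data URL, as the tests above do.
frames = fetch_video(video_url)
inline_part = {
    "type": "video_url",
    "video_url": {
        "url": f"data:video/jpeg;base64,{encode_video_base64(frames)}"
    },
}

chat = client.chat.completions.create(
    model="llava-hf/llava-onevision-qwen2-0.5b-ov-hf",
    messages=[{
        "role": "user",
        "content": [remote_part,
                    {"type": "text", "text": "What's in this video?"}],
    }],
    max_tokens=64,
)
print(chat.choices[0].message.content)
```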

View File

@@ -35,7 +35,7 @@ def download_video_asset(filename: str) -> str:
 def video_to_ndarrays(path: str, num_frames: int = -1) -> npt.NDArray:
-    cv2 = try_import_video_packages()
+    cv2, _ = try_import_video_packages()
     cap = cv2.VideoCapture(path)
     if not cap.isOpened():
@@ -59,7 +59,7 @@ def video_to_ndarrays(path: str, num_frames: int = -1) -> npt.NDArray:
 def video_to_pil_images_list(path: str,
                              num_frames: int = -1) -> List[Image.Image]:
-    cv2 = try_import_video_packages()
+    cv2, _ = try_import_video_packages()
     frames = video_to_ndarrays(path, num_frames)
     return [
         Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
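Both helpers now unpack the (cv2, decord) pair returned by try_import_video_packages. A hedged usage sketch of the two asset representations; the local path is illustrative:

```python
# Hedged sketch using the helpers shown above (the local path is illustrative).
frames = video_to_ndarrays("sample_video.mp4", num_frames=16)
pil_frames = video_to_pil_images_list("sample_video.mp4", num_frames=16)

print(frames.shape)     # roughly (16, H, W, 3), in OpenCV channel order
print(len(pil_frames))  # 16 RGB PIL images (converted via COLOR_BGR2RGB)
```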

View File

@@ -30,7 +30,9 @@ from vllm.logger import init_logger
 from vllm.multimodal import MultiModalDataDict
 from vllm.multimodal.utils import (async_get_and_parse_audio,
                                    async_get_and_parse_image,
-                                   get_and_parse_audio, get_and_parse_image)
+                                   async_get_and_parse_video,
+                                   get_and_parse_audio, get_and_parse_image,
+                                   get_and_parse_video)
 from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
 from vllm.utils import print_warning_once
@@ -51,6 +53,20 @@ class ChatCompletionContentPartAudioParam(TypedDict, total=False):
     """The type of the content part."""

+class VideoURL(TypedDict, total=False):
+    url: Required[str]
+    """
+    Either a URL of the video or a data URL with base64 encoded video data.
+    """
+
+
+class ChatCompletionContentPartVideoParam(TypedDict, total=False):
+    video_url: Required[VideoURL]
+
+    type: Required[Literal["video_url"]]
+    """The type of the content part."""
+
 class CustomChatCompletionContentSimpleImageParam(TypedDict, total=False):
     """A simpler version of the param that only accepts a plain image_url.
     This is supported by OpenAI API, although it is not documented.
@@ -74,11 +90,23 @@ class CustomChatCompletionContentSimpleAudioParam(TypedDict, total=False):
     audio_url: Required[str]

+class CustomChatCompletionContentSimpleVideoParam(TypedDict, total=False):
+    """A simpler version of the param that only accepts a plain video_url.
+
+    Example:
+    {
+        "video_url": "https://example.com/video.mp4"
+    }
+    """
+
+    video_url: Required[str]
+
 ChatCompletionContentPartParam: TypeAlias = Union[
     OpenAIChatCompletionContentPartParam, ChatCompletionContentPartAudioParam,
-    ChatCompletionContentPartRefusalParam,
+    ChatCompletionContentPartVideoParam, ChatCompletionContentPartRefusalParam,
     CustomChatCompletionContentSimpleImageParam,
-    CustomChatCompletionContentSimpleAudioParam, str]
+    CustomChatCompletionContentSimpleAudioParam,
+    CustomChatCompletionContentSimpleVideoParam, str]

 class CustomChatCompletionMessageParam(TypedDict, total=False):
@@ -201,6 +229,9 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
         elif modality == "video":
             if model_type == "qwen2_vl":
                 return "<|vision_start|><|video_pad|><|vision_end|>"
+            if model_type.startswith("llava"):
+                return self._cached_token_str(self._tokenizer,
+                                              hf_config.video_token_index)
             raise TypeError(f"Unknown {modality} model type: {model_type}")
         else:
             raise TypeError(f"Unknown modality: {modality}")
@@ -291,6 +322,10 @@ class BaseMultiModalContentParser(ABC):
     def parse_audio(self, audio_url: str) -> None:
         raise NotImplementedError

+    @abstractmethod
+    def parse_video(self, video_url: str) -> None:
+        raise NotImplementedError
+
 class MultiModalContentParser(BaseMultiModalContentParser):
@@ -313,6 +348,12 @@ class MultiModalContentParser(BaseMultiModalContentParser):
         placeholder = self._tracker.add("audio", audio)
         self._add_placeholder(placeholder)

+    def parse_video(self, video_url: str) -> None:
+        video = get_and_parse_video(video_url)
+
+        placeholder = self._tracker.add("video", video)
+        self._add_placeholder(placeholder)
+
 class AsyncMultiModalContentParser(BaseMultiModalContentParser):
@@ -336,6 +377,12 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
         placeholder = self._tracker.add("audio", audio_coro)
         self._add_placeholder(placeholder)

+    def parse_video(self, video_url: str) -> None:
+        video = async_get_and_parse_video(video_url)
+
+        placeholder = self._tracker.add("video", video)
+        self._add_placeholder(placeholder)
+
 def validate_chat_template(chat_template: Optional[Union[Path, str]]):
     """Raises if the provided chat template appears invalid."""
@@ -416,6 +463,7 @@ _TextParser = partial(cast, ChatCompletionContentPartTextParam)
 _ImageParser = partial(cast, ChatCompletionContentPartImageParam)
 _AudioParser = partial(cast, ChatCompletionContentPartAudioParam)
 _RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam)
+_VideoParser = partial(cast, ChatCompletionContentPartVideoParam)

 MODEL_KEEP_MULTI_MODAL_CONTENT = {'mllama'}

 # Define a mapping from part types to their corresponding parsing functions.
@@ -428,6 +476,8 @@ MM_PARSER_MAP: Dict[str, Callable[[ChatCompletionContentPartParam], str]] = {
     lambda part: _AudioParser(part).get("audio_url", {}).get("url", ""),
     "refusal":
     lambda part: _RefusalParser(part).get("refusal", ""),
+    "video_url":
+    lambda part: _VideoParser(part).get("video_url", {}).get("url", ""),
 }
@@ -472,7 +522,10 @@ def _parse_chat_message_content_mm_part(
             audio_params = cast(CustomChatCompletionContentSimpleAudioParam,
                                 part)
             return "audio_url", audio_params.get("audio_url", "")
+        if part.get("video_url") is not None:
+            video_params = cast(CustomChatCompletionContentSimpleVideoParam,
+                                part)
+            return "video_url", video_params.get("video_url", "")

         # Raise an error if no 'type' or direct URL is found.
         raise ValueError("Missing 'type' field in multimodal part.")
@@ -482,7 +535,7 @@ def _parse_chat_message_content_mm_part(
 VALID_MESSAGE_CONTENT_MM_PART_TYPES = ("text", "refusal", "image_url",
-                                       "audio_url")
+                                       "audio_url", "video_url")

 def _parse_chat_message_content_parts(
@@ -542,7 +595,7 @@ def _parse_chat_message_content_part(
     # Handle structured dictionary parts
     part_type, content = _parse_chat_message_content_mm_part(part)

-    # if part_type is text/refusal/image_url/audio_url but
+    # if part_type is text/refusal/image_url/audio_url/video_url but
     # content is empty, log a warning and skip
     if part_type in VALID_MESSAGE_CONTENT_MM_PART_TYPES and not content:
         logger.warning(
@@ -561,6 +614,10 @@
         mm_parser.parse_audio(content)
         return {'type': 'audio'} if wrap_dicts else None

+    if part_type == "video_url":
+        mm_parser.parse_video(content)
+        return {'type': 'video'} if wrap_dicts else None
+
     raise NotImplementedError(f"Unknown part type: {part_type}")
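Taken together, the parser changes accept either the structured OpenAI-style part or the simplified plain-URL form for videos. A hedged sketch of the two equivalent payloads; the URL is illustrative:

```python
# Structured form, matching ChatCompletionContentPartVideoParam.
structured_part = {
    "type": "video_url",
    "video_url": {"url": "https://example.com/video.mp4"},
}

# Simplified form, matching CustomChatCompletionContentSimpleVideoParam.
simple_part = {"video_url": "https://example.com/video.mp4"}

messages = [{
    "role": "user",
    "content": [structured_part,
                {"type": "text", "text": "Describe the clip."}],
}]
```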

View File

@@ -49,7 +49,8 @@ if TYPE_CHECKING:
     VLLM_WORKER_MULTIPROC_METHOD: str = "fork"
     VLLM_ASSETS_CACHE: str = os.path.join(VLLM_CACHE_ROOT, "assets")
     VLLM_IMAGE_FETCH_TIMEOUT: int = 5
-    VLLM_AUDIO_FETCH_TIMEOUT: int = 5
+    VLLM_VIDEO_FETCH_TIMEOUT: int = 15
+    VLLM_AUDIO_FETCH_TIMEOUT: int = 10
     VLLM_TARGET_DEVICE: str = "cuda"
     MAX_JOBS: Optional[str] = None
     NVCC_THREADS: Optional[str] = None
@@ -376,10 +377,15 @@ environment_variables: Dict[str, Callable[[], Any]] = {
     "VLLM_IMAGE_FETCH_TIMEOUT":
     lambda: int(os.getenv("VLLM_IMAGE_FETCH_TIMEOUT", "5")),

+    # Timeout for fetching videos when serving multimodal models
+    # Default is 15 seconds
+    "VLLM_VIDEO_FETCH_TIMEOUT":
+    lambda: int(os.getenv("VLLM_VIDEO_FETCH_TIMEOUT", "15")),
+
     # Timeout for fetching audio when serving multimodal models
-    # Default is 5 seconds
+    # Default is 10 seconds
     "VLLM_AUDIO_FETCH_TIMEOUT":
-    lambda: int(os.getenv("VLLM_AUDIO_FETCH_TIMEOUT", "5")),
+    lambda: int(os.getenv("VLLM_AUDIO_FETCH_TIMEOUT", "10")),

     # Path to the XLA persistent cache directory.
     # Only used for XLA devices such as TPUs.
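Since the entries in environment_variables are evaluated lazily on access, the new timeout can be raised per process before the first video fetch. A hedged sketch; the value is illustrative:

```python
import os

os.environ["VLLM_VIDEO_FETCH_TIMEOUT"] = "60"  # seconds; the default is 15

import vllm.envs as envs

assert envs.VLLM_VIDEO_FETCH_TIMEOUT == 60
```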

View File

@@ -341,7 +341,7 @@ def input_processor_when_multimodal_input_video(ctx: InputContext,
             get_llava_onevision_video_tokens(ctx, num_frames))
         tokenizer = cached_get_tokenizer(model_config.tokenizer)

-        new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens(
+        new_prompt, new_token_ids, ranges = repeat_and_pad_placeholder_tokens(
             tokenizer,
             inputs.get("prompt"),
             inputs["prompt_token_ids"],
@@ -350,7 +350,8 @@
         )

         return token_inputs(prompt_token_ids=new_token_ids,
                             prompt=new_prompt,
-                            multi_modal_data=multi_modal_data)
+                            multi_modal_data=multi_modal_data,
+                            multi_modal_placeholders={"video": ranges})
     else:
         raise TypeError(f"Invalid video type: {type(video_data)}")
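The extra ranges value records where the repeated video placeholder tokens land in the expanded prompt. A hedged sketch of the shape of that data, assuming the PlaceholderRange offset/length mapping imported elsewhere in this diff; the numbers are illustrative:

```python
# Illustrative only: one range per video, expressed as offset/length over
# the expanded prompt_token_ids (PlaceholderRange-style entries).
ranges = [{"offset": 14, "length": 6300}]

token_inputs_kwargs = {
    "multi_modal_data": {"video": frames},          # frames produced upstream
    "multi_modal_placeholders": {"video": ranges},  # new in this change
}
```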

View File

@@ -136,6 +136,9 @@ class MultiModalDataBuiltins(TypedDict, total=False):
     audio: MultiModalData[Tuple[np.ndarray, Union[int, float]]]
     """The input audio item(s) and corresponding sampling rate(s)."""

+    video: MultiModalData[Tuple[np.ndarray]]
+    """The input video(s)."""
+
 MultiModalDataDict = Union[MultiModalDataBuiltins,
                            Mapping[str, MultiModalData[object]]]
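With the builtin video key, frames can also be passed through the offline LLM API in the same way images are. A hedged sketch; the model, prompt template, and sampling settings are illustrative, and fetch_video is the helper added later in this diff:

```python
from vllm import LLM, SamplingParams
from vllm.multimodal.utils import fetch_video

frames = fetch_video("http://commondatastorage.googleapis.com/"
                     "gtv-videos-bucket/sample/BigBuckBunny.mp4")

llm = LLM(model="llava-hf/llava-onevision-qwen2-0.5b-ov-hf",
          max_model_len=32768)

outputs = llm.generate(
    {
        "prompt": "USER: <video>\nWhat is happening in this clip? ASSISTANT:",
        "multi_modal_data": {"video": frames},
    },
    SamplingParams(max_tokens=64),
)
print(outputs[0].outputs[0].text)
```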

View File

@@ -8,8 +8,8 @@ import numpy as np
 import numpy.typing as npt
 from PIL import Image

+import vllm.envs as envs
 from vllm.connections import global_http_connection
-from vllm.envs import VLLM_AUDIO_FETCH_TIMEOUT, VLLM_IMAGE_FETCH_TIMEOUT
 from vllm.logger import init_logger
 from vllm.multimodal.base import MultiModalDataDict, PlaceholderRange
 from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer
@@ -80,7 +80,9 @@ def fetch_image(image_url: str,
     """
     if image_url.startswith('http'):
         image_raw = global_http_connection.get_bytes(
-            image_url, timeout=VLLM_IMAGE_FETCH_TIMEOUT)
+            image_url,
+            timeout=envs.VLLM_IMAGE_FETCH_TIMEOUT,
+        )
         image = _load_image_from_bytes(image_raw)

     elif image_url.startswith('data:image'):
@@ -105,7 +107,9 @@ async def async_fetch_image(image_url: str,
     """
     if image_url.startswith('http'):
         image_raw = await global_http_connection.async_get_bytes(
-            image_url, timeout=VLLM_IMAGE_FETCH_TIMEOUT)
+            image_url,
+            timeout=envs.VLLM_IMAGE_FETCH_TIMEOUT,
+        )
         image = _load_image_from_bytes(image_raw)

     elif image_url.startswith('data:image'):
@@ -119,6 +123,85 @@ async def async_fetch_image(image_url: str,
     return image.convert(image_mode)

+def _load_video_frames_from_bytes(b: bytes):
+    frame = Image.open(BytesIO(b))
+    return np.array(frame)
+
+
+def load_video_frames_from_base64(frame: Union[bytes, str]):
+    """Load frame from base64 format."""
+    return _load_video_frames_from_bytes(base64.b64decode(frame))
+
+
+def _load_video_from_bytes(b: bytes, num_frames: int = 32):
+    _, decord = try_import_video_packages()
+
+    video_path = BytesIO(b)
+    vr = decord.VideoReader(video_path, num_threads=1)
+    total_frame_num = len(vr)
+
+    if total_frame_num > num_frames:
+        uniform_sampled_frames = np.linspace(0,
+                                             total_frame_num - 1,
+                                             num_frames,
+                                             dtype=int)
+        frame_idx = uniform_sampled_frames.tolist()
+    else:
+        frame_idx = [i for i in range(0, total_frame_num)]
+    frames = vr.get_batch(frame_idx).asnumpy()
+
+    return frames
+
+
+def _load_video_from_data_url(video_url: str):
+    # Everything after the first comma is a base64-encoded video frame
+    frames_base64 = video_url.split(",")[1:]
+    return np.stack([
+        load_video_frames_from_base64(frame_base64)
+        for frame_base64 in frames_base64
+    ])
+
+
+def fetch_video(video_url: str, *, num_frames: int = 32) -> npt.NDArray:
+    """
+    Load video from a HTTP or base64 data URL.
+    """
+    if video_url.startswith('http') or video_url.startswith('https'):
+        video_raw = global_http_connection.get_bytes(
+            video_url,
+            timeout=envs.VLLM_VIDEO_FETCH_TIMEOUT,
+        )
+        video = _load_video_from_bytes(video_raw, num_frames)
+    elif video_url.startswith('data:video'):
+        video = _load_video_from_data_url(video_url)
+    else:
+        raise ValueError("Invalid 'video_url': A valid 'video_url' must start "
+                         "with either 'data:video' or 'http'.")
+    return video
+
+
+async def async_fetch_video(video_url: str,
+                            *,
+                            num_frames: int = 32) -> npt.NDArray:
+    """
+    Asynchronously load video from a HTTP or base64 data URL.
+    """
+    if video_url.startswith('http') or video_url.startswith('https'):
+        video_raw = await global_http_connection.async_get_bytes(
+            video_url,
+            timeout=envs.VLLM_VIDEO_FETCH_TIMEOUT,
+        )
+        video = _load_video_from_bytes(video_raw, num_frames)
+    elif video_url.startswith('data:video'):
+        video = _load_video_from_data_url(video_url)
+    else:
+        raise ValueError("Invalid 'video_url': A valid 'video_url' must start "
+                         "with either 'data:video' or 'http'.")
+    return video
+
+
 def try_import_audio_packages() -> Tuple[Any, Any]:
     try:
         import librosa
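fetch_video uniformly samples at most num_frames frames over the clip (np.linspace over the frame indices) before stacking them. A hedged usage sketch; the URL is illustrative and decord must be installed:

```python
from vllm.multimodal.utils import fetch_video

frames = fetch_video(
    "http://commondatastorage.googleapis.com/gtv-videos-bucket/sample/"
    "ElephantsDream.mp4",
    num_frames=16,
)
# At most 16 frames, uniformly spaced across the video.
print(frames.shape)  # roughly (16, H, W, 3)
```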
@@ -137,7 +220,9 @@ def fetch_audio(audio_url: str) -> Tuple[np.ndarray, Union[int, float]]:
     if audio_url.startswith("http"):
         audio_bytes = global_http_connection.get_bytes(
-            audio_url, timeout=VLLM_AUDIO_FETCH_TIMEOUT)
+            audio_url,
+            timeout=envs.VLLM_AUDIO_FETCH_TIMEOUT,
+        )
     elif audio_url.startswith("data:audio"):
         _, audio_base64 = audio_url.split(",", 1)
         audio_bytes = base64.b64decode(audio_base64)
@@ -157,7 +242,9 @@ async def async_fetch_audio(
     if audio_url.startswith("http"):
         audio_bytes = await global_http_connection.async_get_bytes(
-            audio_url, timeout=VLLM_AUDIO_FETCH_TIMEOUT)
+            audio_url,
+            timeout=envs.VLLM_AUDIO_FETCH_TIMEOUT,
+        )
     elif audio_url.startswith("data:audio"):
         _, audio_base64 = audio_url.split(",", 1)
         audio_bytes = base64.b64decode(audio_base64)
@@ -182,6 +269,11 @@ def get_and_parse_image(
     return {"image": image}

+def get_and_parse_video(video_url: str) -> MultiModalDataDict:
+    video = fetch_video(video_url)
+    return {"video": video}
+
 async def async_get_and_parse_audio(audio_url: str) -> MultiModalDataDict:
     audio, sr = await async_fetch_audio(audio_url)
     return {"audio": (audio, sr)}
@@ -196,6 +288,11 @@ async def async_get_and_parse_image(
     return {"image": image}

+async def async_get_and_parse_video(video_url: str) -> MultiModalDataDict:
+    video = await async_fetch_video(video_url)
+    return {"video": video}
+
 def encode_audio_base64(
     audio: np.ndarray,
     sampling_rate: int,
@@ -246,14 +343,15 @@ def rescale_image_size(image: Image.Image,
 def try_import_video_packages() -> Any:
     try:
         import cv2
+        import decord
     except ImportError:
         raise ImportError(
             "Please install vllm[video] for video support.") from None
-    return cv2
+    return cv2, decord

 def resize_video(frames: npt.NDArray, size: Tuple[int, int]) -> npt.NDArray:
-    cv2 = try_import_video_packages()
+    cv2, _ = try_import_video_packages()

     num_frames, _, _, channels = frames.shape
     new_height, new_width = size
@@ -284,6 +382,15 @@ def sample_frames_from_video(frames: npt.NDArray,
     return sampled_frames

+def encode_video_base64(frames: npt.NDArray):
+    base64_frames = []
+    frames_list = [frames[i] for i in range(frames.shape[0])]
+    for frame in frames_list:
+        img_base64 = encode_image_base64(Image.fromarray(frame))
+        base64_frames.append(img_base64)
+    return ",".join(base64_frames)
+
 # Utilities for input processors
 _T = TypeVar("_T", str, int)
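encode_video_base64 serializes each frame with encode_image_base64 and joins the results with commas, which is exactly the payload _load_video_from_data_url splits back apart. A hedged round-trip sketch; the URL is illustrative:

```python
from vllm.multimodal.utils import encode_video_base64, fetch_video

frames = fetch_video("http://commondatastorage.googleapis.com/"
                     "gtv-videos-bucket/sample/ForBiggerFun.mp4",
                     num_frames=8)
data_url = f"data:video/jpeg;base64,{encode_video_base64(frames)}"

# One chunk for the "data:" prefix plus one base64 chunk per frame.
chunks = data_url.split(",")
assert len(chunks) == 1 + 8
```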

View File

@@ -7,6 +7,7 @@ from vllm.inputs.registry import InputContext
 from vllm.logger import init_logger
 from vllm.transformers_utils.processor import get_video_processor
 from vllm.transformers_utils.tokenizer import get_tokenizer
+from vllm.utils import is_list_of

 from .base import MultiModalData, MultiModalInputs
 from .image import ImagePlugin
@@ -60,7 +61,7 @@ class VideoPlugin(ImagePlugin):
         if isinstance(data, list) and len(data) == 1:
             data = data[0]

-        if isinstance(data, np.ndarray):
+        if isinstance(data, np.ndarray) or is_list_of(data, np.ndarray):
             video_processor = self._get_hf_video_processor(
                 model_config,
                 mm_processor_kwargs,
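The widened check means VideoPlugin now also accepts a plain Python list of per-frame arrays in addition to a single stacked array. A hedged sketch of the two accepted shapes:

```python
import numpy as np

from vllm.utils import is_list_of

stacked = np.zeros((8, 224, 224, 3), dtype=np.uint8)        # one ndarray
frame_list = [np.zeros((224, 224, 3), dtype=np.uint8)] * 8  # list of frames

assert isinstance(stacked, np.ndarray)
assert is_list_of(frame_list, np.ndarray)
# Either form is forwarded to the HF video processor by the plugin.
```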