Online video support for VLMs (#10020)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: litianjian <litianjian@bytedance.com>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
Parent: 97b8475beb
Commit: 28b2877d30
@@ -116,6 +116,7 @@ autodoc_mock_imports = [
    "soundfile",
    "gguf",
    "lark",
    "decord",
]

for mock_target in autodoc_mock_imports:
@@ -8,6 +8,7 @@ pytest-shard

# testing utils
awscli
decord # required for video tests
einops # required for MPT, qwen-vl and Mamba
httpx
librosa # required for audio tests
@@ -15,12 +16,13 @@ opencv-python # required for video tests
peft
requests
ray[adag]==2.35
sentence-transformers # required for embedding
soundfile # required for audio test
sentence-transformers # required for embedding tests
soundfile # required for audio tests
timm # required for internvl test
torch==2.5.1
transformers_stream_generator # required for qwen-vl test
matplotlib # required for qwen-vl test
mistral_common[opencv] >= 1.4.4 # required for pixtral test
datamodel_code_generator # required for minicpm3 test
lm-eval[api]==0.4.4 # required for model evaluation test
@@ -1,8 +1,8 @@
#
# This file is autogenerated by pip-compile with Python 3.12
# This file is autogenerated by pip-compile with Python 3.9
# by the following command:
#
#    pip-compile --output-file=requirements-test.txt requirements-test.in
#    pip-compile requirements-test.in
#
absl-py==2.1.0
    # via rouge-score
@@ -28,6 +28,10 @@ anyio==4.6.2.post1
    # via httpx
argcomplete==3.5.1
    # via datamodel-code-generator
async-timeout==4.0.3
    # via
    #   aiohttp
    #   redis
attrs==24.2.0
    # via
    #   aiohttp
@@ -90,6 +94,8 @@ datasets==3.0.2
    #   lm-eval
decorator==5.1.1
    # via librosa
decord==0.6.0
    # via -r requirements-test.in
dill==0.3.8
    # via
    #   datasets
@@ -106,6 +112,10 @@ email-validator==2.2.0
    # via pydantic
evaluate==0.4.3
    # via lm-eval
exceptiongroup==1.2.2
    # via
    #   anyio
    #   pytest
fastrlock==0.8.2
    # via cupy-cuda12x
filelock==3.16.1
@@ -156,6 +166,8 @@ idna==3.10
    #   httpx
    #   requests
    #   yarl
importlib-resources==6.4.5
    # via matplotlib
inflect==5.6.2
    # via datamodel-code-generator
iniconfig==2.0.0
@@ -178,7 +190,9 @@ joblib==1.4.2
jsonlines==4.0.0
    # via lm-eval
jsonschema==4.23.0
    # via ray
    # via
    #   mistral-common
    #   ray
jsonschema-specifications==2024.10.1
    # via jsonschema
kiwisolver==1.4.7
@@ -204,6 +218,10 @@ mbstrdecoder==1.1.3
    #   dataproperty
    #   pytablewriter
    #   typepy
mistral-common[opencv]==1.4.4
    # via
    #   -r requirements-test.in
    #   mistral-common
more-itertools==10.5.0
    # via lm-eval
mpmath==1.3.0
@@ -238,12 +256,15 @@ numpy==1.26.4
    #   contourpy
    #   cupy-cuda12x
    #   datasets
    #   decord
    #   evaluate
    #   librosa
    #   matplotlib
    #   mistral-common
    #   numba
    #   numexpr
    #   opencv-python
    #   opencv-python-headless
    #   pandas
    #   peft
    #   rouge-score
@@ -288,6 +309,8 @@ nvidia-nvtx-cu12==12.4.127
    # via torch
opencv-python==4.10.0.84
    # via -r requirements-test.in
opencv-python-headless==4.10.0.84
    # via mistral-common
packaging==24.1
    # via
    #   accelerate
@@ -317,9 +340,10 @@ peft==0.13.2
    # via
    #   -r requirements-test.in
    #   lm-eval
pillow==11.0.0
pillow==10.4.0
    # via
    #   matplotlib
    #   mistral-common
    #   sentence-transformers
    #   torchvision
platformdirs==4.3.6
@@ -354,7 +378,9 @@ pybind11==2.13.6
pycparser==2.22
    # via cffi
pydantic[email]==2.9.2
    # via datamodel-code-generator
    # via
    #   datamodel-code-generator
    #   mistral-common
pydantic-core==2.23.4
    # via pydantic
pyparsing==3.2.0
@@ -420,6 +446,7 @@ requests==2.32.3
    #   evaluate
    #   huggingface-hub
    #   lm-eval
    #   mistral-common
    #   pooch
    #   ray
    #   tiktoken
@@ -456,6 +483,8 @@ scipy==1.13.1
    #   sentence-transformers
sentence-transformers==3.2.1
    # via -r requirements-test.in
sentencepiece==0.2.0
    # via mistral-common
six==1.16.0
    # via
    #   python-dateutil
@@ -486,12 +515,20 @@ tensorizer==2.9.0
    # via -r requirements-test.in
threadpoolctl==3.5.0
    # via scikit-learn
tiktoken==0.8.0
    # via lm-eval
tiktoken==0.7.0
    # via
    #   lm-eval
    #   mistral-common
timm==1.0.11
    # via -r requirements-test.in
tokenizers==0.20.1
    # via transformers
toml==0.10.2
    # via datamodel-code-generator
tomli==2.0.2
    # via
    #   black
    #   pytest
torch==2.5.1
    # via
    #   -r requirements-test.in
@@ -535,8 +572,12 @@ typepy[datetime]==1.3.2
    #   tabledata
typing-extensions==4.12.2
    # via
    #   anyio
    #   black
    #   huggingface-hub
    #   librosa
    #   mistral-common
    #   multidict
    #   pydantic
    #   pydantic-core
    #   torch
@@ -554,6 +595,8 @@ xxhash==3.5.0
    #   evaluate
yarl==1.17.1
    # via aiohttp
zipp==3.20.2
    # via importlib-resources
zstandard==0.23.0
    # via lm-eval
setup.py (3 changed lines)
@@ -554,7 +554,8 @@ setup(
    ext_modules=ext_modules,
    extras_require={
        "tensorizer": ["tensorizer>=2.9.0"],
        "audio": ["librosa", "soundfile"]  # Required for audio processing
        "audio": ["librosa", "soundfile"],  # Required for audio processing
        "video": ["decord"]  # Required for video processing
    },
    cmdclass={"build_ext": cmake_build_ext} if len(ext_modules) > 0 else {},
    package_data=package_data,
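The new "video" extra mirrors the existing "audio" extra, so decord is only pulled in when video support is requested (for example via pip install "vllm[video]"). A minimal sketch for checking that the optional packages used by try_import_video_packages are importable; the helper name is illustrative, not part of vLLM:

import importlib.util


def video_deps_available() -> bool:
    # decord is installed by the "video" extra; cv2 comes from the existing
    # opencv-python requirement already used for video preprocessing.
    return all(
        importlib.util.find_spec(pkg) is not None
        for pkg in ("decord", "cv2"))


print("video support available:", video_deps_available())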
tests/entrypoints/openai/test_video.py (new file, 345 lines)
@@ -0,0 +1,345 @@
from typing import Dict, List

import openai
import pytest
import pytest_asyncio

from vllm.multimodal.utils import encode_video_base64, fetch_video

from ...utils import RemoteOpenAIServer

MODEL_NAME = "llava-hf/llava-onevision-qwen2-0.5b-ov-hf"
MAXIMUM_VIDEOS = 4

TEST_VIDEO_URLS = [
    "http://commondatastorage.googleapis.com/gtv-videos-bucket/sample/BigBuckBunny.mp4",
    "http://commondatastorage.googleapis.com/gtv-videos-bucket/sample/ElephantsDream.mp4",
    "http://commondatastorage.googleapis.com/gtv-videos-bucket/sample/ForBiggerBlazes.mp4",
    "http://commondatastorage.googleapis.com/gtv-videos-bucket/sample/ForBiggerFun.mp4",
]


@pytest.fixture(scope="module")
def server():
    args = [
        "--task",
        "generate",
        "--dtype",
        "bfloat16",
        "--max-model-len",
        "32768",
        "--max-num-seqs",
        "2",
        "--enforce-eager",
        "--trust-remote-code",
        "--limit-mm-per-prompt",
        f"video={MAXIMUM_VIDEOS}",
    ]

    with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
        yield remote_server


@pytest_asyncio.fixture
async def client(server):
    async with server.get_async_client() as async_client:
        yield async_client


@pytest.fixture(scope="session")
def base64_encoded_video() -> Dict[str, str]:
    return {
        video_url: encode_video_base64(fetch_video(video_url))
        for video_url in TEST_VIDEO_URLS
    }


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("video_url", TEST_VIDEO_URLS)
async def test_single_chat_session_video(client: openai.AsyncOpenAI,
                                         model_name: str, video_url: str):
    messages = [{
        "role":
        "user",
        "content": [
            {
                "type": "video_url",
                "video_url": {
                    "url": video_url
                }
            },
            {
                "type": "text",
                "text": "What's in this video?"
            },
        ],
    }]

    # test single completion
    chat_completion = await client.chat.completions.create(
        model=model_name,
        messages=messages,
        max_completion_tokens=10,
        logprobs=True,
        top_logprobs=5)
    assert len(chat_completion.choices) == 1

    choice = chat_completion.choices[0]
    assert choice.finish_reason == "length"
    assert chat_completion.usage == openai.types.CompletionUsage(
        completion_tokens=10, prompt_tokens=6299, total_tokens=6309)

    message = choice.message
    message = chat_completion.choices[0].message
    assert message.content is not None and len(message.content) >= 10
    assert message.role == "assistant"
    messages.append({"role": "assistant", "content": message.content})

    # test multi-turn dialogue
    messages.append({"role": "user", "content": "express your result in json"})
    chat_completion = await client.chat.completions.create(
        model=model_name,
        messages=messages,
        max_completion_tokens=10,
    )
    message = chat_completion.choices[0].message
    assert message.content is not None and len(message.content) >= 0


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("video_url", TEST_VIDEO_URLS)
async def test_single_chat_session_video_beamsearch(client: openai.AsyncOpenAI,
                                                    model_name: str,
                                                    video_url: str):
    messages = [{
        "role":
        "user",
        "content": [
            {
                "type": "video_url",
                "video_url": {
                    "url": video_url
                }
            },
            {
                "type": "text",
                "text": "What's in this video?"
            },
        ],
    }]

    chat_completion = await client.chat.completions.create(
        model=model_name,
        messages=messages,
        n=2,
        max_completion_tokens=10,
        logprobs=True,
        top_logprobs=5,
        extra_body=dict(use_beam_search=True))
    assert len(chat_completion.choices) == 2
    assert chat_completion.choices[
        0].message.content != chat_completion.choices[1].message.content


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("video_url", TEST_VIDEO_URLS)
async def test_single_chat_session_video_base64encoded(
        client: openai.AsyncOpenAI, model_name: str, video_url: str,
        base64_encoded_video: Dict[str, str]):

    messages = [{
        "role":
        "user",
        "content": [
            {
                "type": "video_url",
                "video_url": {
                    "url":
                    f"data:video/jpeg;base64,{base64_encoded_video[video_url]}"
                }
            },
            {
                "type": "text",
                "text": "What's in this video?"
            },
        ],
    }]

    # test single completion
    chat_completion = await client.chat.completions.create(
        model=model_name,
        messages=messages,
        max_completion_tokens=10,
        logprobs=True,
        top_logprobs=5)
    assert len(chat_completion.choices) == 1

    choice = chat_completion.choices[0]
    assert choice.finish_reason == "length"
    assert chat_completion.usage == openai.types.CompletionUsage(
        completion_tokens=10, prompt_tokens=6299, total_tokens=6309)

    message = choice.message
    message = chat_completion.choices[0].message
    assert message.content is not None and len(message.content) >= 10
    assert message.role == "assistant"
    messages.append({"role": "assistant", "content": message.content})

    # test multi-turn dialogue
    messages.append({"role": "user", "content": "express your result in json"})
    chat_completion = await client.chat.completions.create(
        model=model_name,
        messages=messages,
        max_completion_tokens=10,
    )
    message = chat_completion.choices[0].message
    assert message.content is not None and len(message.content) >= 0


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("video_url", TEST_VIDEO_URLS)
async def test_single_chat_session_video_base64encoded_beamsearch(
        client: openai.AsyncOpenAI, model_name: str, video_url: str,
        base64_encoded_video: Dict[str, str]):

    messages = [{
        "role":
        "user",
        "content": [
            {
                "type": "video_url",
                "video_url": {
                    "url":
                    f"data:video/jpeg;base64,{base64_encoded_video[video_url]}"
                }
            },
            {
                "type": "text",
                "text": "What's in this video?"
            },
        ],
    }]
    chat_completion = await client.chat.completions.create(
        model=model_name,
        messages=messages,
        n=2,
        max_completion_tokens=10,
        extra_body=dict(use_beam_search=True))
    assert len(chat_completion.choices) == 2
    assert chat_completion.choices[
        0].message.content != chat_completion.choices[1].message.content


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("video_url", TEST_VIDEO_URLS)
async def test_chat_streaming_video(client: openai.AsyncOpenAI,
                                    model_name: str, video_url: str):
    messages = [{
        "role":
        "user",
        "content": [
            {
                "type": "video_url",
                "video_url": {
                    "url": video_url
                }
            },
            {
                "type": "text",
                "text": "What's in this video?"
            },
        ],
    }]

    # test single completion
    chat_completion = await client.chat.completions.create(
        model=model_name,
        messages=messages,
        max_completion_tokens=10,
        temperature=0.0,
    )
    output = chat_completion.choices[0].message.content
    stop_reason = chat_completion.choices[0].finish_reason

    # test streaming
    stream = await client.chat.completions.create(
        model=model_name,
        messages=messages,
        max_completion_tokens=10,
        temperature=0.0,
        stream=True,
    )
    chunks: List[str] = []
    finish_reason_count = 0
    async for chunk in stream:
        delta = chunk.choices[0].delta
        if delta.role:
            assert delta.role == "assistant"
        if delta.content:
            chunks.append(delta.content)
        if chunk.choices[0].finish_reason is not None:
            finish_reason_count += 1
    # finish reason should only return in last block
    assert finish_reason_count == 1
    assert chunk.choices[0].finish_reason == stop_reason
    assert delta.content
    assert "".join(chunks) == output


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize(
    "video_urls",
    [TEST_VIDEO_URLS[:i] for i in range(2, len(TEST_VIDEO_URLS))])
async def test_multi_video_input(client: openai.AsyncOpenAI, model_name: str,
                                 video_urls: List[str]):

    messages = [{
        "role":
        "user",
        "content": [
            *({
                "type": "video_url",
                "video_url": {
                    "url": video_url
                }
            } for video_url in video_urls),
            {
                "type": "text",
                "text": "What's in this video?"
            },
        ],
    }]

    if len(video_urls) > MAXIMUM_VIDEOS:
        with pytest.raises(openai.BadRequestError):  # test multi-video input
            await client.chat.completions.create(
                model=model_name,
                messages=messages,
                max_completion_tokens=10,
                temperature=0.0,
            )

        # the server should still work afterwards
        completion = await client.completions.create(
            model=model_name,
            prompt=[0, 0, 0, 0, 0],
            max_tokens=5,
            temperature=0.0,
        )
        completion = completion.choices[0].text
        assert completion is not None and len(completion) >= 0
    else:
        chat_completion = await client.chat.completions.create(
            model=model_name,
            messages=messages,
            max_completion_tokens=10,
            temperature=0.0,
        )
        message = chat_completion.choices[0].message
        assert message.content is not None and len(message.content) >= 0
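The tests above exercise the new video_url content part end to end. Outside pytest, the same payload can be sent to a running OpenAI-compatible vLLM server; a sketch under the assumption that a server with the same model is already running on the default local port:

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

chat_completion = client.chat.completions.create(
    model="llava-hf/llava-onevision-qwen2-0.5b-ov-hf",
    messages=[{
        "role": "user",
        "content": [
            {
                "type": "video_url",
                "video_url": {
                    "url": "http://commondatastorage.googleapis.com/gtv-videos-bucket/sample/BigBuckBunny.mp4"
                }
            },
            {
                "type": "text",
                "text": "What's in this video?"
            },
        ],
    }],
    max_completion_tokens=64,
)
print(chat_completion.choices[0].message.content)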
@@ -35,7 +35,7 @@ def download_video_asset(filename: str) -> str:


def video_to_ndarrays(path: str, num_frames: int = -1) -> npt.NDArray:
    cv2 = try_import_video_packages()
    cv2, _ = try_import_video_packages()

    cap = cv2.VideoCapture(path)
    if not cap.isOpened():
@@ -59,7 +59,7 @@ def video_to_ndarrays(path: str, num_frames: int = -1) -> npt.NDArray:

def video_to_pil_images_list(path: str,
                             num_frames: int = -1) -> List[Image.Image]:
    cv2 = try_import_video_packages()
    cv2, _ = try_import_video_packages()
    frames = video_to_ndarrays(path, num_frames)
    return [
        Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
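try_import_video_packages now returns a (cv2, decord) tuple, so these helpers unpack only what they need. A usage sketch for the two helpers shown above; the module path vllm.assets.video is an assumption based on download_video_asset living in the assets package, and sample.mp4 is a placeholder local file:

# Sketch only: decode a local MP4 into frames with the helpers shown above.
from vllm.assets.video import (video_to_ndarrays,  # assumed module path
                               video_to_pil_images_list)

frames = video_to_ndarrays("sample.mp4", num_frames=8)  # ndarray of sampled frames
pil_frames = video_to_pil_images_list("sample.mp4", num_frames=8)  # RGB PIL images
print(frames.shape, len(pil_frames))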
@@ -30,7 +30,9 @@ from vllm.logger import init_logger
from vllm.multimodal import MultiModalDataDict
from vllm.multimodal.utils import (async_get_and_parse_audio,
                                   async_get_and_parse_image,
                                   get_and_parse_audio, get_and_parse_image)
                                   async_get_and_parse_video,
                                   get_and_parse_audio, get_and_parse_image,
                                   get_and_parse_video)
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.utils import print_warning_once

@@ -51,6 +53,20 @@ class ChatCompletionContentPartAudioParam(TypedDict, total=False):
    """The type of the content part."""


class VideoURL(TypedDict, total=False):
    url: Required[str]
    """
    Either a URL of the video or a data URL with base64 encoded video data.
    """


class ChatCompletionContentPartVideoParam(TypedDict, total=False):
    video_url: Required[VideoURL]

    type: Required[Literal["video_url"]]
    """The type of the content part."""


class CustomChatCompletionContentSimpleImageParam(TypedDict, total=False):
    """A simpler version of the param that only accepts a plain image_url.
    This is supported by OpenAI API, although it is not documented.
@@ -74,11 +90,23 @@ class CustomChatCompletionContentSimpleAudioParam(TypedDict, total=False):
    audio_url: Required[str]


class CustomChatCompletionContentSimpleVideoParam(TypedDict, total=False):
"""A simpler version of the param that only accepts a plain audio_url.
|
||||

    Example:
    {
        "video_url": "https://example.com/video.mp4"
    }
    """
    video_url: Required[str]


ChatCompletionContentPartParam: TypeAlias = Union[
    OpenAIChatCompletionContentPartParam, ChatCompletionContentPartAudioParam,
    ChatCompletionContentPartRefusalParam,
    ChatCompletionContentPartVideoParam, ChatCompletionContentPartRefusalParam,
    CustomChatCompletionContentSimpleImageParam,
    CustomChatCompletionContentSimpleAudioParam, str]
    CustomChatCompletionContentSimpleAudioParam,
    CustomChatCompletionContentSimpleVideoParam, str]


class CustomChatCompletionMessageParam(TypedDict, total=False):
@@ -201,6 +229,9 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
            elif modality == "video":
                if model_type == "qwen2_vl":
                    return "<|vision_start|><|video_pad|><|vision_end|>"
                if model_type.startswith("llava"):
                    return self._cached_token_str(self._tokenizer,
                                                  hf_config.video_token_index)
                raise TypeError(f"Unknown {modality} model type: {model_type}")
            else:
                raise TypeError(f"Unknown modality: {modality}")
@@ -291,6 +322,10 @@ class BaseMultiModalContentParser(ABC):
    def parse_audio(self, audio_url: str) -> None:
        raise NotImplementedError

    @abstractmethod
    def parse_video(self, video_url: str) -> None:
        raise NotImplementedError


class MultiModalContentParser(BaseMultiModalContentParser):

@@ -313,6 +348,12 @@ class MultiModalContentParser(BaseMultiModalContentParser):
        placeholder = self._tracker.add("audio", audio)
        self._add_placeholder(placeholder)

    def parse_video(self, video_url: str) -> None:
        video = get_and_parse_video(video_url)

        placeholder = self._tracker.add("video", video)
        self._add_placeholder(placeholder)


class AsyncMultiModalContentParser(BaseMultiModalContentParser):

@@ -336,6 +377,12 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
        placeholder = self._tracker.add("audio", audio_coro)
        self._add_placeholder(placeholder)

    def parse_video(self, video_url: str) -> None:
        video = async_get_and_parse_video(video_url)

        placeholder = self._tracker.add("video", video)
        self._add_placeholder(placeholder)


def validate_chat_template(chat_template: Optional[Union[Path, str]]):
    """Raises if the provided chat template appears invalid."""
@@ -416,6 +463,7 @@ _TextParser = partial(cast, ChatCompletionContentPartTextParam)
_ImageParser = partial(cast, ChatCompletionContentPartImageParam)
_AudioParser = partial(cast, ChatCompletionContentPartAudioParam)
_RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam)
_VideoParser = partial(cast, ChatCompletionContentPartVideoParam)
MODEL_KEEP_MULTI_MODAL_CONTENT = {'mllama'}

# Define a mapping from part types to their corresponding parsing functions.
@@ -428,6 +476,8 @@ MM_PARSER_MAP: Dict[str, Callable[[ChatCompletionContentPartParam], str]] = {
    lambda part: _AudioParser(part).get("audio_url", {}).get("url", ""),
    "refusal":
    lambda part: _RefusalParser(part).get("refusal", ""),
    "video_url":
    lambda part: _VideoParser(part).get("video_url", {}).get("url", ""),
}


@@ -472,7 +522,10 @@ def _parse_chat_message_content_mm_part(
        audio_params = cast(CustomChatCompletionContentSimpleAudioParam,
                            part)
        return "audio_url", audio_params.get("audio_url", "")

    if part.get("video_url") is not None:
        video_params = cast(CustomChatCompletionContentSimpleVideoParam,
                            part)
        return "video_url", video_params.get("video_url", "")
    # Raise an error if no 'type' or direct URL is found.
    raise ValueError("Missing 'type' field in multimodal part.")

@@ -482,7 +535,7 @@ def _parse_chat_message_content_mm_part(


VALID_MESSAGE_CONTENT_MM_PART_TYPES = ("text", "refusal", "image_url",
                                       "audio_url")
                                       "audio_url", "video_url")


def _parse_chat_message_content_parts(
@@ -542,7 +595,7 @@ def _parse_chat_message_content_part(
    # Handle structured dictionary parts
    part_type, content = _parse_chat_message_content_mm_part(part)

    # if part_type is text/refusal/image_url/audio_url but
    # if part_type is text/refusal/image_url/audio_url/video_url but
    # content is empty, log a warning and skip
    if part_type in VALID_MESSAGE_CONTENT_MM_PART_TYPES and not content:
        logger.warning(
@@ -561,6 +614,10 @@ def _parse_chat_message_content_part(
        mm_parser.parse_audio(content)
        return {'type': 'audio'} if wrap_dicts else None

    if part_type == "video_url":
        mm_parser.parse_video(content)
        return {'type': 'video'} if wrap_dicts else None

    raise NotImplementedError(f"Unknown part type: {part_type}")
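With the additions above, both the OpenAI-style structured part and the simpler custom form resolve to the same URL string through MM_PARSER_MAP and _parse_chat_message_content_mm_part, and then reach parse_video(). A small illustration of the two accepted shapes (plain dicts; the TypedDicts are only for typing):

# OpenAI-style structured part, matching ChatCompletionContentPartVideoParam:
structured_part = {
    "type": "video_url",
    "video_url": {"url": "https://example.com/video.mp4"},
}

# Simpler custom form, matching CustomChatCompletionContentSimpleVideoParam:
simple_part = {"video_url": "https://example.com/video.mp4"}

# Either way, the chat parser extracts the part type "video_url" and the URL
# string, and hands the URL to parse_video().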
vllm/envs.py (12 changed lines)
@@ -49,7 +49,8 @@ if TYPE_CHECKING:
    VLLM_WORKER_MULTIPROC_METHOD: str = "fork"
    VLLM_ASSETS_CACHE: str = os.path.join(VLLM_CACHE_ROOT, "assets")
    VLLM_IMAGE_FETCH_TIMEOUT: int = 5
    VLLM_AUDIO_FETCH_TIMEOUT: int = 5
    VLLM_VIDEO_FETCH_TIMEOUT: int = 15
    VLLM_AUDIO_FETCH_TIMEOUT: int = 10
    VLLM_TARGET_DEVICE: str = "cuda"
    MAX_JOBS: Optional[str] = None
    NVCC_THREADS: Optional[str] = None
@@ -376,10 +377,15 @@ environment_variables: Dict[str, Callable[[], Any]] = {
    "VLLM_IMAGE_FETCH_TIMEOUT":
    lambda: int(os.getenv("VLLM_IMAGE_FETCH_TIMEOUT", "5")),

    # Timeout for fetching videos when serving multimodal models
    # Default is 15 seconds
    "VLLM_VIDEO_FETCH_TIMEOUT":
    lambda: int(os.getenv("VLLM_VIDEO_FETCH_TIMEOUT", "15")),

    # Timeout for fetching audio when serving multimodal models
    # Default is 5 seconds
    # Default is 10 seconds
    "VLLM_AUDIO_FETCH_TIMEOUT":
    lambda: int(os.getenv("VLLM_AUDIO_FETCH_TIMEOUT", "5")),
    lambda: int(os.getenv("VLLM_AUDIO_FETCH_TIMEOUT", "10")),

    # Path to the XLA persistent cache directory.
    # Only used for XLA devices such as TPUs.
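The new VLLM_VIDEO_FETCH_TIMEOUT (default 15 s) and the raised VLLM_AUDIO_FETCH_TIMEOUT default (10 s) are read lazily through the envs module, so they can be overridden per process before the server or engine starts. A minimal sketch; the 60-second value is only an example:

import os

# Allow up to 60 seconds for each video URL download before the request fails.
os.environ["VLLM_VIDEO_FETCH_TIMEOUT"] = "60"

import vllm.envs as envs

print(envs.VLLM_VIDEO_FETCH_TIMEOUT)  # -> 60, read from the environment at access time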
@@ -341,7 +341,7 @@ def input_processor_when_multimodal_input_video(ctx: InputContext,
            get_llava_onevision_video_tokens(ctx, num_frames))

        tokenizer = cached_get_tokenizer(model_config.tokenizer)
        new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens(
        new_prompt, new_token_ids, ranges = repeat_and_pad_placeholder_tokens(
            tokenizer,
            inputs.get("prompt"),
            inputs["prompt_token_ids"],
@@ -350,7 +350,8 @@ def input_processor_when_multimodal_input_video(ctx: InputContext,
        )
        return token_inputs(prompt_token_ids=new_token_ids,
                            prompt=new_prompt,
                            multi_modal_data=multi_modal_data)
                            multi_modal_data=multi_modal_data,
                            multi_modal_placeholders={"video": ranges})
    else:
        raise TypeError(f"Invalid video type: {type(video_data)}")
@@ -136,6 +136,9 @@ class MultiModalDataBuiltins(TypedDict, total=False):
    audio: MultiModalData[Tuple[np.ndarray, Union[int, float]]]
    """The input audio item(s) and corresponding sampling rate(s)."""

    video: MultiModalData[Tuple[np.ndarray]]
    """The input video(s)."""


MultiModalDataDict = Union[MultiModalDataBuiltins,
                           Mapping[str, MultiModalData[object]]]
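MultiModalDataBuiltins gains a "video" field, so offline inputs can carry decoded frames alongside images and audio. A sketch of building such a dict from raw frames; the array shape and zero content are illustrative:

import numpy as np

# Any (num_frames, height, width, 3) uint8 array of decoded frames works here;
# the zeros are just a stand-in for real video content.
frames = np.zeros((8, 336, 336, 3), dtype=np.uint8)

multi_modal_data = {"video": frames}  # matches the new MultiModalDataBuiltins field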
@@ -8,8 +8,8 @@ import numpy as np
import numpy.typing as npt
from PIL import Image

import vllm.envs as envs
from vllm.connections import global_http_connection
from vllm.envs import VLLM_AUDIO_FETCH_TIMEOUT, VLLM_IMAGE_FETCH_TIMEOUT
from vllm.logger import init_logger
from vllm.multimodal.base import MultiModalDataDict, PlaceholderRange
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer
@@ -80,7 +80,9 @@ def fetch_image(image_url: str,
    """
    if image_url.startswith('http'):
        image_raw = global_http_connection.get_bytes(
            image_url, timeout=VLLM_IMAGE_FETCH_TIMEOUT)
            image_url,
            timeout=envs.VLLM_IMAGE_FETCH_TIMEOUT,
        )
        image = _load_image_from_bytes(image_raw)

    elif image_url.startswith('data:image'):
@@ -105,7 +107,9 @@ async def async_fetch_image(image_url: str,
    """
    if image_url.startswith('http'):
        image_raw = await global_http_connection.async_get_bytes(
            image_url, timeout=VLLM_IMAGE_FETCH_TIMEOUT)
            image_url,
            timeout=envs.VLLM_IMAGE_FETCH_TIMEOUT,
        )
        image = _load_image_from_bytes(image_raw)

    elif image_url.startswith('data:image'):
@@ -119,6 +123,85 @@ async def async_fetch_image(image_url: str,
    return image.convert(image_mode)


def _load_video_frames_from_bytes(b: bytes):
    frame = Image.open(BytesIO(b))
    return np.array(frame)


def load_video_frames_from_base64(frame: Union[bytes, str]):
    """Load frame from base64 format."""
    return _load_video_frames_from_bytes(base64.b64decode(frame))


def _load_video_from_bytes(b: bytes, num_frames: int = 32):
    _, decord = try_import_video_packages()

    video_path = BytesIO(b)
    vr = decord.VideoReader(video_path, num_threads=1)
    total_frame_num = len(vr)

    if total_frame_num > num_frames:
        uniform_sampled_frames = np.linspace(0,
                                             total_frame_num - 1,
                                             num_frames,
                                             dtype=int)
        frame_idx = uniform_sampled_frames.tolist()
    else:
        frame_idx = [i for i in range(0, total_frame_num)]
    frames = vr.get_batch(frame_idx).asnumpy()

    return frames


def _load_video_from_data_url(video_url: str):
    # Only split once and assume the second part is the base64 encoded image
    frames_base64 = video_url.split(",")[1:]
    return np.stack([
        load_video_frames_from_base64(frame_base64)
        for frame_base64 in frames_base64
    ])


def fetch_video(video_url: str, *, num_frames: int = 32) -> npt.NDArray:
    """
    Load video from a HTTP or base64 data URL.
    """
    if video_url.startswith('http') or video_url.startswith('https'):
        video_raw = global_http_connection.get_bytes(
            video_url,
            timeout=envs.VLLM_VIDEO_FETCH_TIMEOUT,
        )
        video = _load_video_from_bytes(video_raw, num_frames)
    elif video_url.startswith('data:video'):
        video = _load_video_from_data_url(video_url)
    else:
        raise ValueError("Invalid 'video_url': A valid 'video_url' must start "
                         "with either 'data:video' or 'http'.")
    return video


async def async_fetch_video(video_url: str,
                            *,
                            num_frames: int = 32) -> npt.NDArray:
    """
    Asynchronously load video from a HTTP or base64 data URL.

    By default, the image is converted into RGB format.
    """
    if video_url.startswith('http') or video_url.startswith('https'):
        video_raw = await global_http_connection.async_get_bytes(
            video_url,
            timeout=envs.VLLM_VIDEO_FETCH_TIMEOUT,
        )
        video = _load_video_from_bytes(video_raw, num_frames)
    elif video_url.startswith('data:video'):
        video = _load_video_from_data_url(video_url)
    else:
        raise ValueError("Invalid 'video_url': A valid 'video_url' must start "
                         "with either 'data:video' or 'http'.")
    return video


def try_import_audio_packages() -> Tuple[Any, Any]:
    try:
        import librosa
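fetch_video and async_fetch_video mirror the image helpers: HTTP(S) URLs are downloaded with the new video timeout and decoded through decord, while data:video URLs are treated as a comma-separated list of base64-encoded frames. A short sketch of the synchronous path, reusing one of the sample URLs from the tests above:

from vllm.multimodal.utils import fetch_video

url = "http://commondatastorage.googleapis.com/gtv-videos-bucket/sample/BigBuckBunny.mp4"
frames = fetch_video(url, num_frames=8)  # uniformly sampled frames as an ndarray
print(frames.shape)  # (8, height, width, 3) for a clip longer than 8 frames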
@@ -137,7 +220,9 @@ def fetch_audio(audio_url: str) -> Tuple[np.ndarray, Union[int, float]]:

    if audio_url.startswith("http"):
        audio_bytes = global_http_connection.get_bytes(
            audio_url, timeout=VLLM_AUDIO_FETCH_TIMEOUT)
            audio_url,
            timeout=envs.VLLM_AUDIO_FETCH_TIMEOUT,
        )
    elif audio_url.startswith("data:audio"):
        _, audio_base64 = audio_url.split(",", 1)
        audio_bytes = base64.b64decode(audio_base64)
@@ -157,7 +242,9 @@ async def async_fetch_audio(

    if audio_url.startswith("http"):
        audio_bytes = await global_http_connection.async_get_bytes(
            audio_url, timeout=VLLM_AUDIO_FETCH_TIMEOUT)
            audio_url,
            timeout=envs.VLLM_AUDIO_FETCH_TIMEOUT,
        )
    elif audio_url.startswith("data:audio"):
        _, audio_base64 = audio_url.split(",", 1)
        audio_bytes = base64.b64decode(audio_base64)
@@ -182,6 +269,11 @@ def get_and_parse_image(
    return {"image": image}


def get_and_parse_video(video_url: str) -> MultiModalDataDict:
    video = fetch_video(video_url)
    return {"video": video}


async def async_get_and_parse_audio(audio_url: str) -> MultiModalDataDict:
    audio, sr = await async_fetch_audio(audio_url)
    return {"audio": (audio, sr)}
@@ -196,6 +288,11 @@ async def async_get_and_parse_image(
    return {"image": image}


async def async_get_and_parse_video(video_url: str) -> MultiModalDataDict:
    video = await async_fetch_video(video_url)
    return {"video": video}


def encode_audio_base64(
    audio: np.ndarray,
    sampling_rate: int,
@@ -246,14 +343,15 @@ def rescale_image_size(image: Image.Image,
def try_import_video_packages() -> Any:
    try:
        import cv2
        import decord
    except ImportError:
        raise ImportError(
            "Please install vllm[video] for video support.") from None
    return cv2
    return cv2, decord


def resize_video(frames: npt.NDArray, size: Tuple[int, int]) -> npt.NDArray:
    cv2 = try_import_video_packages()
    cv2, _ = try_import_video_packages()

    num_frames, _, _, channels = frames.shape
    new_height, new_width = size
@@ -284,6 +382,15 @@ def sample_frames_from_video(frames: npt.NDArray,
    return sampled_frames


def encode_video_base64(frames: npt.NDArray):
    base64_frames = []
    frames_list = [frames[i] for i in range(frames.shape[0])]
    for frame in frames_list:
        img_base64 = encode_image_base64(Image.fromarray(frame))
        base64_frames.append(img_base64)
    return ",".join(base64_frames)


# Utilities for input processors
_T = TypeVar("_T", str, int)
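encode_video_base64 serializes each frame with encode_image_base64 and joins the results with commas, which is exactly the payload carried by the data:video...;base64 URLs used in the tests above. A round-trip sketch based on those helpers:

from vllm.multimodal.utils import encode_video_base64, fetch_video

url = "http://commondatastorage.googleapis.com/gtv-videos-bucket/sample/ElephantsDream.mp4"
frames = fetch_video(url, num_frames=8)

# One base64-encoded frame per element, joined by commas; usable as a chat
# "video_url" of the form data:video/jpeg;base64,<payload>.
video_data_url = "data:video/jpeg;base64," + encode_video_base64(frames)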
@@ -7,6 +7,7 @@ from vllm.inputs.registry import InputContext
from vllm.logger import init_logger
from vllm.transformers_utils.processor import get_video_processor
from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.utils import is_list_of

from .base import MultiModalData, MultiModalInputs
from .image import ImagePlugin
@@ -60,7 +61,7 @@ class VideoPlugin(ImagePlugin):
        if isinstance(data, list) and len(data) == 1:
            data = data[0]

        if isinstance(data, np.ndarray):
        if isinstance(data, np.ndarray) or is_list_of(data, np.ndarray):
            video_processor = self._get_hf_video_processor(
                model_config,
                mm_processor_kwargs,
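The video plugin now also accepts a list of per-frame arrays, not just a single stacked ndarray, as the "video" item. A hedged sketch of offline inference with such an input; the prompt template follows vLLM's LLaVA-OneVision examples and is an assumption here, as is the dummy frame content:

import numpy as np

from vllm import LLM

llm = LLM(model="llava-hf/llava-onevision-qwen2-0.5b-ov-hf", max_model_len=32768)

# Either a single (num_frames, H, W, 3) array or a list of (H, W, 3) arrays
# is accepted by the video plugin after this change.
frames = [np.zeros((336, 336, 3), dtype=np.uint8) for _ in range(8)]

outputs = llm.generate({
    "prompt": "<|im_start|>user <video>\nWhat's in this video?<|im_end|>"
              "<|im_start|>assistant\n",
    "multi_modal_data": {"video": frames},
})
print(outputs[0].outputs[0].text)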