Cyrus Leung eed11ebee9
[VLM] Merged multi-modal processors for LLaVA-NeXT-Video and LLaVA-OneVision (#11717)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
2025-01-04 11:40:53 +00:00

145 lines
5.5 KiB
Python

"""Tests for Qwen's multimodal preprocessing kwargs."""
from typing import Dict, List, Union
import pytest
import torch
from PIL.Image import Image
from vllm.inputs import InputContext, token_inputs
from vllm.multimodal import MultiModalKwargs
from vllm.multimodal.utils import cached_get_tokenizer
from .....conftest import IMAGE_ASSETS
from ....utils import build_model_context
### Multimodal preprocessing tests
# PIL image of the first image asset; used as the raw-image input for the
# input mapper happy-path cases below.
SAMPLE_IMAGE = IMAGE_ASSETS[0].pil_image
# These values are specific to Qwen-VL/Chat. They could also be read from the
# model config, but they are hardcoded here to keep the parametrize/fixtures
# easy to read.
# Token IDs expected exactly once per image in the processed prompt
# (presumably the IDs of the `<img>` / `</img>` tags -- confirm against the
# Qwen-VL tokenizer).
IMG_START_ID = 151857
IMG_END_ID = 151858
# Token ID expected TOKS_PER_IMG times per image in the processed prompt.
IMG_PAD_ID = 151859
# Number of IMG_PAD_ID tokens expected per image; also the second-to-last
# dimension of the image embedding tensors used in these tests.
TOKS_PER_IMG = 256
# Last dimension of the image embedding tensors used in these tests.
VIS_ENC_DIM = 4096
# Size of the last two dimensions of the expected `pixel_values` shape.
IMG_SIZE = 448
@pytest.fixture()
def input_mapper_for_qwen():
    """Provide Qwen's multimodal input mapper callable to a test.

    The model module is imported inside the fixture body, not at module
    import time, to avoid initializing CUDA during test collection.
    """
    from vllm.model_executor.models import qwen as qwen_module

    return qwen_module.input_mapper_for_qwen
@pytest.fixture()
def input_processor_for_qwen():
    """Provide Qwen's multimodal input processor callable to a test.

    The model module is imported inside the fixture body, not at module
    import time, to avoid initializing CUDA during test collection.
    """
    from vllm.model_executor.models import qwen as qwen_module

    return qwen_module.input_processor_for_qwen
@pytest.fixture()
def qwen_vl_context() -> InputContext:
    """Build the InputContext that the Qwen-VL tests run against."""
    # The context is built for "Qwen/Qwen-VL" with trust_remote_code enabled.
    context = build_model_context(
        model_name="Qwen/Qwen-VL",
        trust_remote_code=True,
    )
    return context
# Happy path tests for single/multi-image scenarios for the multimodal
# input processor and mapper, respectively
@pytest.mark.parametrize("num_images", [1, 2])
def test_input_processor_valid_mm_data(input_processor_for_qwen,
                                       qwen_vl_context: InputContext,
                                       num_images: int):
    """Happy cases for image inputs to Qwen's multimodal input processor.

    Builds a prompt with one ``<img></img>`` placeholder per image, feeds the
    processor image embeddings of shape
    ``(num_images, TOKS_PER_IMG, VIS_ENC_DIM)``, and checks that the processed
    token IDs contain the expected number of image start, end and pad tokens.

    Args:
        input_processor_for_qwen: Fixture yielding Qwen's input processor.
        qwen_vl_context: Fixture yielding the Qwen-VL input context.
        num_images: Number of images referenced by the prompt.
    """
    prompt = "".join(
        f"Picture {num}: <img></img>\n" for num in range(1, num_images + 1))
    inputs = token_inputs(
        prompt=prompt,
        # When processing multimodal data for a multimodal model, the qwen
        # input processor will overwrite the provided prompt_token_ids with
        # the image prompts
        prompt_token_ids=[],
        # Use the shared VIS_ENC_DIM constant rather than a bare 4096 so the
        # embedding width stays in sync with the other tests in this module.
        multi_modal_data={
            "image": torch.rand(num_images, TOKS_PER_IMG, VIS_ENC_DIM)
        },
    )

    proc_inputs = input_processor_for_qwen(qwen_vl_context, inputs)
    assert isinstance(proc_inputs, dict)

    # Each image should have one start / end token and a fixed context of
    # TOKS_PER_IMG pad tokens.
    proc_tokens = proc_inputs["prompt_token_ids"]
    assert proc_tokens.count(IMG_START_ID) == num_images
    assert proc_tokens.count(IMG_END_ID) == num_images
    assert proc_tokens.count(IMG_PAD_ID) == num_images * TOKS_PER_IMG
@pytest.mark.parametrize(
    ("img_data", "expected_shape"),
    [
        # Raw PIL input: one image, then a list of two images.
        (SAMPLE_IMAGE, (1, 3, IMG_SIZE, IMG_SIZE)),
        ([SAMPLE_IMAGE] * 2, (2, 3, IMG_SIZE, IMG_SIZE)),
        # Embedding input: unbatched, batch of one, batch of two.
        (
            torch.rand(TOKS_PER_IMG, VIS_ENC_DIM),
            (1, TOKS_PER_IMG, VIS_ENC_DIM),
        ),
        (
            torch.rand(1, TOKS_PER_IMG, VIS_ENC_DIM),
            (1, TOKS_PER_IMG, VIS_ENC_DIM),
        ),
        (
            torch.rand(2, TOKS_PER_IMG, VIS_ENC_DIM),
            (2, TOKS_PER_IMG, VIS_ENC_DIM),
        ),
    ],
)
def test_input_mapper_valid_mm_data(input_mapper_for_qwen,
                                    qwen_vl_context: InputContext,
                                    img_data: Union[torch.Tensor, List[Image],
                                                    Image],
                                    expected_shape: List[int]):
    """Check the input mapper's output for well-formed image inputs.

    Each case supplies either PIL image(s) or an embedding tensor together
    with the shape expected for the resulting ``pixel_values`` entry.
    """
    mm_kwargs = input_mapper_for_qwen(qwen_vl_context, img_data)

    # The mapper must hand back MultiModalKwargs carrying "pixel_values" ...
    assert isinstance(mm_kwargs, MultiModalKwargs)
    assert "pixel_values" in mm_kwargs

    # ... whose shape matches the expectation for this kind of input.
    pixel_values = mm_kwargs["pixel_values"]
    assert pixel_values.shape == expected_shape
# Sad path tests for the multimodal input processor and mapper, respectively
@pytest.mark.parametrize(
    "mm_data",
    [
        {"image": torch.rand(5)},
        {"image": torch.rand(5, 5, 5, 5, 5)},
    ],
)
def test_input_processor_invalid_mm_data(input_processor_for_qwen,
                                         qwen_vl_context: InputContext,
                                         mm_data: Dict[str, torch.Tensor]):
    """Check that the input processor rejects badly shaped embeddings.

    The parametrized tensors have too few or too many dimensions to be image
    embeddings, so processing them is expected to raise ``ValueError``.
    """
    prompt = "Picture 1: <img></img>\n"
    tokenizer_name = qwen_vl_context.model_config.tokenizer
    tokenizer = cached_get_tokenizer(tokenizer_name, trust_remote_code=True)

    inputs = token_inputs(
        prompt=prompt,
        prompt_token_ids=tokenizer.encode(prompt),
        multi_modal_data=mm_data,
    )

    # Too many or too few dimensions for embeddings must be rejected.
    with pytest.raises(ValueError):
        input_processor_for_qwen(qwen_vl_context, inputs)
@pytest.mark.parametrize(
    "img_data",
    [
        # Context length differs from TOKS_PER_IMG.
        torch.rand(1, TOKS_PER_IMG + 10, VIS_ENC_DIM),
        # Visual encoder output size differs from VIS_ENC_DIM.
        torch.rand(1, TOKS_PER_IMG, VIS_ENC_DIM + 10),
    ],
)
def test_input_mapper_invalid_mm_data(
    input_mapper_for_qwen,
    qwen_vl_context: InputContext,
    img_data: Union[torch.Tensor, List[Image], Image],
):
    """Check that the input mapper rejects mis-sized image embeddings.

    Each parametrized tensor deviates from the expected embedding shape in
    one dimension, so mapping it is expected to raise ``ValueError``.
    """
    with pytest.raises(ValueError):
        input_mapper_for_qwen(qwen_vl_context, img_data)