[CI/Build] Move model-specific multi-modal processing tests (#11934)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
parent
c32a7c7c0c
commit
7a3a83e3b8
@ -368,6 +368,7 @@ steps:
|
|||||||
- tests/models/encoder_decoder/vision_language
|
- tests/models/encoder_decoder/vision_language
|
||||||
commands:
|
commands:
|
||||||
- pip install git+https://github.com/TIGER-AI-Lab/Mantis.git
|
- pip install git+https://github.com/TIGER-AI-Lab/Mantis.git
|
||||||
|
- pytest -v -s models/multimodal
|
||||||
- pytest -v -s models/decoder_only/audio_language -m 'core_model or quant_model'
|
- pytest -v -s models/decoder_only/audio_language -m 'core_model or quant_model'
|
||||||
- pytest -v -s --ignore models/decoder_only/vision_language/test_phi3v.py models/decoder_only/vision_language -m 'core_model or quant_model'
|
- pytest -v -s --ignore models/decoder_only/vision_language/test_phi3v.py models/decoder_only/vision_language -m 'core_model or quant_model'
|
||||||
- pytest -v -s models/embedding/vision_language -m core_model
|
- pytest -v -s models/embedding/vision_language -m core_model
|
||||||
|
0
tests/models/multimodal/processing/__init__.py
Normal file
0
tests/models/multimodal/processing/__init__.py
Normal file
201
tests/models/multimodal/processing/test_common.py
Normal file
201
tests/models/multimodal/processing/test_common.py
Normal file
@ -0,0 +1,201 @@
|
|||||||
|
from functools import partial
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from vllm.config import ModelConfig
|
||||||
|
from vllm.inputs import InputProcessingContext
|
||||||
|
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||||
|
from vllm.multimodal.processing import ProcessingCache
|
||||||
|
from vllm.multimodal.utils import cached_get_tokenizer
|
||||||
|
|
||||||
|
from ....multimodal.utils import random_audio, random_image, random_video
|
||||||
|
|
||||||
|
|
||||||
|
def _test_processing_correctness(
|
||||||
|
model_id: str,
|
||||||
|
modalities: dict[str, bool],
|
||||||
|
hit_rate: float,
|
||||||
|
num_batches: int,
|
||||||
|
simplify_rate: float,
|
||||||
|
):
|
||||||
|
if model_id == "TIGER-Lab/Mantis-8B-siglip-llama3":
|
||||||
|
hf_overrides = {"architectures": ["MantisForConditionalGeneration"]}
|
||||||
|
else:
|
||||||
|
hf_overrides = {}
|
||||||
|
|
||||||
|
limit_mm_per_prompt = {
|
||||||
|
modality: 3 if supports_multi else 1
|
||||||
|
for modality, supports_multi in modalities.items()
|
||||||
|
}
|
||||||
|
|
||||||
|
model_config = ModelConfig(
|
||||||
|
model_id,
|
||||||
|
task="auto",
|
||||||
|
tokenizer=model_id,
|
||||||
|
tokenizer_mode="auto",
|
||||||
|
trust_remote_code=True,
|
||||||
|
seed=0,
|
||||||
|
dtype="float16",
|
||||||
|
revision=None,
|
||||||
|
hf_overrides=hf_overrides,
|
||||||
|
limit_mm_per_prompt=limit_mm_per_prompt,
|
||||||
|
)
|
||||||
|
|
||||||
|
model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)
|
||||||
|
factories = MULTIMODAL_REGISTRY._processor_factories[model_cls]
|
||||||
|
ctx = InputProcessingContext(
|
||||||
|
model_config,
|
||||||
|
tokenizer=cached_get_tokenizer(model_config.tokenizer),
|
||||||
|
)
|
||||||
|
# Ensure that it can fit all of the data
|
||||||
|
cache = ProcessingCache(capacity=1 << 30)
|
||||||
|
|
||||||
|
baseline_processor = factories.build_processor(ctx, cache=None)
|
||||||
|
cached_processor = factories.build_processor(ctx, cache=cache)
|
||||||
|
dummy_inputs = baseline_processor.dummy_inputs
|
||||||
|
tokenizer = baseline_processor.info.get_tokenizer()
|
||||||
|
|
||||||
|
rng = np.random.RandomState(0)
|
||||||
|
|
||||||
|
input_to_hit = {
|
||||||
|
"image": Image.new("RGB", size=(128, 128)),
|
||||||
|
"video": np.zeros((4, 128, 128, 3), dtype=np.uint8),
|
||||||
|
"audio": (np.zeros((512, )), 16000),
|
||||||
|
}
|
||||||
|
input_factory = {
|
||||||
|
"image":
|
||||||
|
partial(random_image, rng, min_wh=128, max_wh=256),
|
||||||
|
"video":
|
||||||
|
partial(random_video,
|
||||||
|
rng,
|
||||||
|
min_frames=2,
|
||||||
|
max_frames=8,
|
||||||
|
min_wh=128,
|
||||||
|
max_wh=256),
|
||||||
|
"audio":
|
||||||
|
partial(random_audio, rng, min_len=512, max_len=1024, sr=16000),
|
||||||
|
}
|
||||||
|
|
||||||
|
for batch_idx in range(num_batches):
|
||||||
|
mm_data = {
|
||||||
|
k:
|
||||||
|
[(input_to_hit[k] if rng.rand() < hit_rate else input_factory[k]())
|
||||||
|
for _ in range(rng.randint(limit_mm_per_prompt[k]))]
|
||||||
|
for k in modalities
|
||||||
|
}
|
||||||
|
|
||||||
|
mm_counts = {k: len(vs) for k, vs in mm_data.items()}
|
||||||
|
prompt = dummy_inputs.get_dummy_processor_inputs(
|
||||||
|
model_config.max_model_len,
|
||||||
|
mm_counts,
|
||||||
|
).prompt_text
|
||||||
|
|
||||||
|
# Drop unnecessary keys and test single -> multi conversion
|
||||||
|
if rng.rand() < simplify_rate:
|
||||||
|
for k in list(mm_data.keys()):
|
||||||
|
if not mm_data[k]:
|
||||||
|
del mm_data[k]
|
||||||
|
elif len(mm_data[k]) == 1:
|
||||||
|
mm_data[k] = mm_data[k][0]
|
||||||
|
|
||||||
|
baseline_result = baseline_processor.apply(
|
||||||
|
prompt,
|
||||||
|
mm_data=mm_data,
|
||||||
|
hf_processor_mm_kwargs={},
|
||||||
|
)
|
||||||
|
cached_result = cached_processor.apply(
|
||||||
|
prompt,
|
||||||
|
mm_data=mm_data,
|
||||||
|
hf_processor_mm_kwargs={},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert baseline_result == cached_result, (
|
||||||
|
f"Failed ({batch_idx=}, {prompt=}, {mm_data=})")
|
||||||
|
|
||||||
|
baseline_tokenized_result = baseline_processor.apply(
|
||||||
|
tokenizer.encode(prompt),
|
||||||
|
mm_data=mm_data,
|
||||||
|
hf_processor_mm_kwargs={},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert baseline_result == baseline_tokenized_result, (
|
||||||
|
f"Failed ({batch_idx=}, {prompt=}, {mm_data=})")
|
||||||
|
|
||||||
|
cached_tokenized_result = cached_processor.apply(
|
||||||
|
tokenizer.encode(prompt),
|
||||||
|
mm_data=mm_data,
|
||||||
|
hf_processor_mm_kwargs={},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert cached_result == cached_tokenized_result, (
|
||||||
|
f"Failed ({batch_idx=}, {prompt=}, {mm_data=})")
|
||||||
|
|
||||||
|
|
||||||
|
# yapf: disable
|
||||||
|
# True if the model supports multiple data items of the modality per request
|
||||||
|
@pytest.mark.parametrize(("model_id", "modalities"), [
|
||||||
|
("rhymes-ai/Aria", {"image": True}),
|
||||||
|
("Salesforce/blip2-opt-2.7b", {"image": False}),
|
||||||
|
("facebook/chameleon-7b", {"image": False}),
|
||||||
|
("adept/fuyu-8b", {"image": False}),
|
||||||
|
("llava-hf/llava-1.5-7b-hf", {"image": True}),
|
||||||
|
("llava-hf/llava-v1.6-mistral-7b-hf", {"image": True}),
|
||||||
|
("llava-hf/LLaVA-NeXT-Video-7B-hf", {"video": False}),
|
||||||
|
("llava-hf/llava-onevision-qwen2-0.5b-ov-hf", {"image": True, "video": True}), # noqa: E501
|
||||||
|
("TIGER-Lab/Mantis-8B-siglip-llama3", {"image": True}),
|
||||||
|
("mistral-community/pixtral-12b", {"image": True}),
|
||||||
|
("Qwen/Qwen2-VL-2B-Instruct", {"image": True, "video": True}),
|
||||||
|
("Qwen/Qwen2-Audio-7B-Instruct", {"audio": True}),
|
||||||
|
("fixie-ai/ultravox-v0_3", {"audio": True}),
|
||||||
|
])
|
||||||
|
@pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0])
|
||||||
|
@pytest.mark.parametrize("num_batches", [32])
|
||||||
|
@pytest.mark.parametrize("simplify_rate", [1.0])
|
||||||
|
# yapf: enable
|
||||||
|
def test_processing_correctness(
|
||||||
|
model_id: str,
|
||||||
|
modalities: dict[str, bool],
|
||||||
|
hit_rate: float,
|
||||||
|
num_batches: int,
|
||||||
|
simplify_rate: float,
|
||||||
|
):
|
||||||
|
_test_processing_correctness(
|
||||||
|
model_id,
|
||||||
|
modalities,
|
||||||
|
hit_rate=hit_rate,
|
||||||
|
num_batches=num_batches,
|
||||||
|
simplify_rate=simplify_rate,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# yapf: disable
|
||||||
|
@pytest.mark.parametrize(("model_id", "modalities"), [
|
||||||
|
("microsoft/Phi-3-vision-128k-instruct", {"image": True}),
|
||||||
|
])
|
||||||
|
@pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0])
|
||||||
|
@pytest.mark.parametrize("num_batches", [32])
|
||||||
|
@pytest.mark.parametrize("simplify_rate", [1.0])
|
||||||
|
# yapf: enable
|
||||||
|
def test_processing_correctness_phi3v(
|
||||||
|
model_id: str,
|
||||||
|
modalities: dict[str, bool],
|
||||||
|
hit_rate: float,
|
||||||
|
num_batches: int,
|
||||||
|
simplify_rate: float,
|
||||||
|
):
|
||||||
|
# HACK - this is an attempted workaround for the following bug
|
||||||
|
# https://github.com/huggingface/transformers/issues/34307
|
||||||
|
from transformers import AutoImageProcessor # noqa: F401
|
||||||
|
from transformers import AutoProcessor # noqa: F401
|
||||||
|
|
||||||
|
AutoImageProcessor.from_pretrained(model_id, trust_remote_code=True)
|
||||||
|
|
||||||
|
_test_processing_correctness(
|
||||||
|
model_id,
|
||||||
|
modalities,
|
||||||
|
hit_rate=hit_rate,
|
||||||
|
num_batches=num_batches,
|
||||||
|
simplify_rate=simplify_rate,
|
||||||
|
)
|
@ -8,8 +8,8 @@ from transformers import AutoImageProcessor, AutoTokenizer
|
|||||||
from vllm.inputs import InputContext, token_inputs
|
from vllm.inputs import InputContext, token_inputs
|
||||||
from vllm.multimodal import MultiModalRegistry
|
from vllm.multimodal import MultiModalRegistry
|
||||||
|
|
||||||
from .....conftest import _ImageAssets
|
from ....conftest import _ImageAssets
|
||||||
from ....utils import build_model_context
|
from ...utils import build_model_context
|
||||||
|
|
||||||
models = ["HuggingFaceM4/Idefics3-8B-Llama3"]
|
models = ["HuggingFaceM4/Idefics3-8B-Llama3"]
|
||||||
|
|
@ -7,8 +7,8 @@ from transformers import AutoTokenizer
|
|||||||
from vllm.inputs import InputContext, token_inputs
|
from vllm.inputs import InputContext, token_inputs
|
||||||
from vllm.multimodal import MultiModalRegistry
|
from vllm.multimodal import MultiModalRegistry
|
||||||
|
|
||||||
from .....conftest import _ImageAssets
|
from ....conftest import _ImageAssets
|
||||||
from ....utils import build_model_context
|
from ...utils import build_model_context
|
||||||
|
|
||||||
models = ["OpenGVLab/InternVL2-2B"]
|
models = ["OpenGVLab/InternVL2-2B"]
|
||||||
|
|
@ -10,7 +10,7 @@ from vllm.multimodal.parse import ImageSize
|
|||||||
from vllm.multimodal.processing import BaseMultiModalProcessor
|
from vllm.multimodal.processing import BaseMultiModalProcessor
|
||||||
from vllm.multimodal.utils import cached_get_tokenizer
|
from vllm.multimodal.utils import cached_get_tokenizer
|
||||||
|
|
||||||
from ....utils import build_model_context
|
from ...utils import build_model_context
|
||||||
|
|
||||||
|
|
||||||
def _validate_image_prompt_replacements_one(
|
def _validate_image_prompt_replacements_one(
|
@ -10,7 +10,7 @@ from vllm.multimodal.parse import ImageSize
|
|||||||
from vllm.multimodal.processing import BaseMultiModalProcessor
|
from vllm.multimodal.processing import BaseMultiModalProcessor
|
||||||
from vllm.multimodal.utils import cached_get_tokenizer
|
from vllm.multimodal.utils import cached_get_tokenizer
|
||||||
|
|
||||||
from ....utils import build_model_context
|
from ...utils import build_model_context
|
||||||
|
|
||||||
|
|
||||||
def _validate_image_prompt_replacements_one(
|
def _validate_image_prompt_replacements_one(
|
@ -4,8 +4,8 @@ import pytest
|
|||||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||||
from vllm.multimodal.utils import cached_get_tokenizer
|
from vllm.multimodal.utils import cached_get_tokenizer
|
||||||
|
|
||||||
from .....conftest import _ImageAssets
|
from ....conftest import _ImageAssets
|
||||||
from ....utils import build_model_context
|
from ...utils import build_model_context
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("model_id", ["microsoft/Phi-3.5-vision-instruct"])
|
@pytest.mark.parametrize("model_id", ["microsoft/Phi-3.5-vision-instruct"])
|
@ -9,8 +9,8 @@ from vllm.inputs import InputContext, token_inputs
|
|||||||
from vllm.multimodal import MultiModalKwargs
|
from vllm.multimodal import MultiModalKwargs
|
||||||
from vllm.multimodal.utils import cached_get_tokenizer
|
from vllm.multimodal.utils import cached_get_tokenizer
|
||||||
|
|
||||||
from .....conftest import IMAGE_ASSETS
|
from ....conftest import IMAGE_ASSETS
|
||||||
from ....utils import build_model_context
|
from ...utils import build_model_context
|
||||||
|
|
||||||
### Multimodal preprocessing tests
|
### Multimodal preprocessing tests
|
||||||
SAMPLE_IMAGE = IMAGE_ASSETS[0].pil_image
|
SAMPLE_IMAGE = IMAGE_ASSETS[0].pil_image
|
@ -3,8 +3,8 @@ import pytest
|
|||||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||||
from vllm.multimodal.utils import cached_get_tokenizer
|
from vllm.multimodal.utils import cached_get_tokenizer
|
||||||
|
|
||||||
from .....conftest import _ImageAssets
|
from ....conftest import _ImageAssets
|
||||||
from ....utils import build_model_context
|
from ...utils import build_model_context
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("model_id", ["Qwen/Qwen2-VL-2B-Instruct"])
|
@pytest.mark.parametrize("model_id", ["Qwen/Qwen2-VL-2B-Instruct"])
|
@ -1,30 +1,25 @@
|
|||||||
from contextlib import nullcontext
|
from contextlib import nullcontext
|
||||||
from functools import partial
|
|
||||||
from typing import cast
|
from typing import cast
|
||||||
from unittest.mock import MagicMock
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
from PIL import Image
|
|
||||||
|
|
||||||
from vllm.config import ModelConfig
|
from vllm.config import ModelConfig
|
||||||
from vllm.inputs import InputProcessingContext
|
|
||||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||||
# yapf conflicts with isort for this block
|
from vllm.multimodal.processing import (PlaceholderInfo, PromptReplacement,
|
||||||
# yapf: disable
|
|
||||||
from vllm.multimodal.processing import (PlaceholderInfo, ProcessingCache,
|
|
||||||
PromptReplacement,
|
|
||||||
find_mm_placeholders,
|
find_mm_placeholders,
|
||||||
find_text_matches, find_token_matches,
|
find_text_matches, find_token_matches,
|
||||||
iter_token_matches,
|
iter_token_matches,
|
||||||
replace_text_matches,
|
replace_text_matches,
|
||||||
replace_token_matches)
|
replace_token_matches)
|
||||||
# yapf: enable
|
|
||||||
from vllm.multimodal.profiling import MultiModalProfiler
|
from vllm.multimodal.profiling import MultiModalProfiler
|
||||||
from vllm.multimodal.utils import cached_get_tokenizer
|
from vllm.multimodal.utils import cached_get_tokenizer
|
||||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||||
from vllm.utils import full_groupby
|
from vllm.utils import full_groupby
|
||||||
|
|
||||||
|
from .utils import random_image
|
||||||
|
|
||||||
|
|
||||||
# yapf: disable
|
# yapf: disable
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
@ -531,37 +526,6 @@ def test_find_mm_placeholders(
|
|||||||
assert result == expected
|
assert result == expected
|
||||||
|
|
||||||
|
|
||||||
def _rand_img(rng: np.random.RandomState, min_wh: int, max_wh: int):
|
|
||||||
w, h = rng.randint(min_wh, max_wh, size=(2, ))
|
|
||||||
arr = rng.randint(0, 255, size=(w, h, 3), dtype=np.uint8)
|
|
||||||
return Image.fromarray(arr)
|
|
||||||
|
|
||||||
|
|
||||||
def _rand_video(
|
|
||||||
rng: np.random.RandomState,
|
|
||||||
min_frames: int,
|
|
||||||
max_frames: int,
|
|
||||||
min_wh: int,
|
|
||||||
max_wh: int,
|
|
||||||
):
|
|
||||||
# Temporary workaround for https://github.com/huggingface/transformers/issues/35412
|
|
||||||
num_frames = rng.randint(min_frames, max_frames)
|
|
||||||
num_frames = (num_frames // 2) * 2
|
|
||||||
|
|
||||||
w, h = rng.randint(min_wh, max_wh, size=(2, ))
|
|
||||||
return rng.randint(0, 255, size=(num_frames, w, h, 3), dtype=np.uint8)
|
|
||||||
|
|
||||||
|
|
||||||
def _rand_audio(
|
|
||||||
rng: np.random.RandomState,
|
|
||||||
min_len: int,
|
|
||||||
max_len: int,
|
|
||||||
sr: int,
|
|
||||||
):
|
|
||||||
audio_len = rng.randint(min_len, max_len)
|
|
||||||
return rng.rand(audio_len), sr
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"])
|
@pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"])
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
("limit", "num_supported", "is_valid"),
|
("limit", "num_supported", "is_valid"),
|
||||||
@ -628,7 +592,7 @@ def test_limit_mm_per_prompt_apply(model_id, num_images, limit, is_valid):
|
|||||||
)
|
)
|
||||||
|
|
||||||
rng = np.random.RandomState(0)
|
rng = np.random.RandomState(0)
|
||||||
image = _rand_img(rng, min_wh=128, max_wh=256)
|
image = random_image(rng, min_wh=128, max_wh=256)
|
||||||
if num_images == 0:
|
if num_images == 0:
|
||||||
mm_data = {}
|
mm_data = {}
|
||||||
elif num_images == 1:
|
elif num_images == 1:
|
||||||
@ -647,191 +611,3 @@ def test_limit_mm_per_prompt_apply(model_id, num_images, limit, is_valid):
|
|||||||
mm_data=mm_data,
|
mm_data=mm_data,
|
||||||
hf_processor_mm_kwargs={},
|
hf_processor_mm_kwargs={},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _test_processing_correctness(
|
|
||||||
model_id: str,
|
|
||||||
modalities: dict[str, bool],
|
|
||||||
hit_rate: float,
|
|
||||||
num_batches: int,
|
|
||||||
simplify_rate: float,
|
|
||||||
):
|
|
||||||
if model_id == "TIGER-Lab/Mantis-8B-siglip-llama3":
|
|
||||||
hf_overrides = {"architectures": ["MantisForConditionalGeneration"]}
|
|
||||||
else:
|
|
||||||
hf_overrides = {}
|
|
||||||
|
|
||||||
limit_mm_per_prompt = {
|
|
||||||
modality: 3 if supports_multi else 1
|
|
||||||
for modality, supports_multi in modalities.items()
|
|
||||||
}
|
|
||||||
|
|
||||||
model_config = ModelConfig(
|
|
||||||
model_id,
|
|
||||||
task="auto",
|
|
||||||
tokenizer=model_id,
|
|
||||||
tokenizer_mode="auto",
|
|
||||||
trust_remote_code=True,
|
|
||||||
seed=0,
|
|
||||||
dtype="float16",
|
|
||||||
revision=None,
|
|
||||||
hf_overrides=hf_overrides,
|
|
||||||
limit_mm_per_prompt=limit_mm_per_prompt,
|
|
||||||
)
|
|
||||||
|
|
||||||
model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)
|
|
||||||
factories = MULTIMODAL_REGISTRY._processor_factories[model_cls]
|
|
||||||
ctx = InputProcessingContext(
|
|
||||||
model_config,
|
|
||||||
tokenizer=cached_get_tokenizer(model_config.tokenizer),
|
|
||||||
)
|
|
||||||
# Ensure that it can fit all of the data
|
|
||||||
cache = ProcessingCache(capacity=1 << 30)
|
|
||||||
|
|
||||||
baseline_processor = factories.build_processor(ctx, cache=None)
|
|
||||||
cached_processor = factories.build_processor(ctx, cache=cache)
|
|
||||||
dummy_inputs = baseline_processor.dummy_inputs
|
|
||||||
tokenizer = baseline_processor.info.get_tokenizer()
|
|
||||||
|
|
||||||
rng = np.random.RandomState(0)
|
|
||||||
|
|
||||||
input_to_hit = {
|
|
||||||
"image": Image.new("RGB", size=(128, 128)),
|
|
||||||
"video": np.zeros((4, 128, 128, 3), dtype=np.uint8),
|
|
||||||
"audio": (np.zeros((512, )), 16000),
|
|
||||||
}
|
|
||||||
input_factory = {
|
|
||||||
"image":
|
|
||||||
partial(_rand_img, rng, min_wh=128, max_wh=256),
|
|
||||||
"video":
|
|
||||||
partial(_rand_video,
|
|
||||||
rng,
|
|
||||||
min_frames=2,
|
|
||||||
max_frames=8,
|
|
||||||
min_wh=128,
|
|
||||||
max_wh=256),
|
|
||||||
"audio":
|
|
||||||
partial(_rand_audio, rng, min_len=512, max_len=1024, sr=16000),
|
|
||||||
}
|
|
||||||
|
|
||||||
for batch_idx in range(num_batches):
|
|
||||||
mm_data = {
|
|
||||||
k:
|
|
||||||
[(input_to_hit[k] if rng.rand() < hit_rate else input_factory[k]())
|
|
||||||
for _ in range(rng.randint(limit_mm_per_prompt[k]))]
|
|
||||||
for k in modalities
|
|
||||||
}
|
|
||||||
|
|
||||||
mm_counts = {k: len(vs) for k, vs in mm_data.items()}
|
|
||||||
prompt = dummy_inputs.get_dummy_processor_inputs(
|
|
||||||
model_config.max_model_len,
|
|
||||||
mm_counts,
|
|
||||||
).prompt_text
|
|
||||||
|
|
||||||
# Drop unnecessary keys and test single -> multi conversion
|
|
||||||
if rng.rand() < simplify_rate:
|
|
||||||
for k in list(mm_data.keys()):
|
|
||||||
if not mm_data[k]:
|
|
||||||
del mm_data[k]
|
|
||||||
elif len(mm_data[k]) == 1:
|
|
||||||
mm_data[k] = mm_data[k][0]
|
|
||||||
|
|
||||||
baseline_result = baseline_processor.apply(
|
|
||||||
prompt,
|
|
||||||
mm_data=mm_data,
|
|
||||||
hf_processor_mm_kwargs={},
|
|
||||||
)
|
|
||||||
cached_result = cached_processor.apply(
|
|
||||||
prompt,
|
|
||||||
mm_data=mm_data,
|
|
||||||
hf_processor_mm_kwargs={},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert baseline_result == cached_result, (
|
|
||||||
f"Failed ({batch_idx=}, {prompt=}, {mm_data=})")
|
|
||||||
|
|
||||||
baseline_tokenized_result = baseline_processor.apply(
|
|
||||||
tokenizer.encode(prompt),
|
|
||||||
mm_data=mm_data,
|
|
||||||
hf_processor_mm_kwargs={},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert baseline_result == baseline_tokenized_result, (
|
|
||||||
f"Failed ({batch_idx=}, {prompt=}, {mm_data=})")
|
|
||||||
|
|
||||||
cached_tokenized_result = cached_processor.apply(
|
|
||||||
tokenizer.encode(prompt),
|
|
||||||
mm_data=mm_data,
|
|
||||||
hf_processor_mm_kwargs={},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert cached_result == cached_tokenized_result, (
|
|
||||||
f"Failed ({batch_idx=}, {prompt=}, {mm_data=})")
|
|
||||||
|
|
||||||
|
|
||||||
# yapf: disable
|
|
||||||
# True if the model supports multiple data items of the modality per request
|
|
||||||
@pytest.mark.parametrize(("model_id", "modalities"), [
|
|
||||||
("rhymes-ai/Aria", {"image": True}),
|
|
||||||
("Salesforce/blip2-opt-2.7b", {"image": False}),
|
|
||||||
("facebook/chameleon-7b", {"image": False}),
|
|
||||||
("adept/fuyu-8b", {"image": False}),
|
|
||||||
("llava-hf/llava-1.5-7b-hf", {"image": True}),
|
|
||||||
("llava-hf/llava-v1.6-mistral-7b-hf", {"image": True}),
|
|
||||||
("llava-hf/LLaVA-NeXT-Video-7B-hf", {"video": False}),
|
|
||||||
("llava-hf/llava-onevision-qwen2-0.5b-ov-hf", {"image": True, "video": True}), # noqa: E501
|
|
||||||
("TIGER-Lab/Mantis-8B-siglip-llama3", {"image": True}),
|
|
||||||
("mistral-community/pixtral-12b", {"image": True}),
|
|
||||||
("Qwen/Qwen2-VL-2B-Instruct", {"image": True, "video": True}),
|
|
||||||
("Qwen/Qwen2-Audio-7B-Instruct", {"audio": True}),
|
|
||||||
("fixie-ai/ultravox-v0_3", {"audio": True}),
|
|
||||||
])
|
|
||||||
@pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0])
|
|
||||||
@pytest.mark.parametrize("num_batches", [32])
|
|
||||||
@pytest.mark.parametrize("simplify_rate", [1.0])
|
|
||||||
# yapf: enable
|
|
||||||
def test_processing_correctness(
|
|
||||||
model_id: str,
|
|
||||||
modalities: dict[str, bool],
|
|
||||||
hit_rate: float,
|
|
||||||
num_batches: int,
|
|
||||||
simplify_rate: float,
|
|
||||||
):
|
|
||||||
_test_processing_correctness(
|
|
||||||
model_id,
|
|
||||||
modalities,
|
|
||||||
hit_rate=hit_rate,
|
|
||||||
num_batches=num_batches,
|
|
||||||
simplify_rate=simplify_rate,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# yapf: disable
|
|
||||||
@pytest.mark.parametrize(("model_id", "modalities"), [
|
|
||||||
("microsoft/Phi-3-vision-128k-instruct", {"image": True}),
|
|
||||||
])
|
|
||||||
@pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0])
|
|
||||||
@pytest.mark.parametrize("num_batches", [32])
|
|
||||||
@pytest.mark.parametrize("simplify_rate", [1.0])
|
|
||||||
# yapf: enable
|
|
||||||
def test_processing_correctness_phi3v(
|
|
||||||
model_id: str,
|
|
||||||
modalities: dict[str, bool],
|
|
||||||
hit_rate: float,
|
|
||||||
num_batches: int,
|
|
||||||
simplify_rate: float,
|
|
||||||
):
|
|
||||||
# HACK - this is an attempted workaround for the following bug
|
|
||||||
# https://github.com/huggingface/transformers/issues/34307
|
|
||||||
from transformers import AutoImageProcessor # noqa: F401
|
|
||||||
from transformers import AutoProcessor # noqa: F401
|
|
||||||
|
|
||||||
AutoImageProcessor.from_pretrained(model_id, trust_remote_code=True)
|
|
||||||
|
|
||||||
_test_processing_correctness(
|
|
||||||
model_id,
|
|
||||||
modalities,
|
|
||||||
hit_rate=hit_rate,
|
|
||||||
num_batches=num_batches,
|
|
||||||
simplify_rate=simplify_rate,
|
|
||||||
)
|
|
||||||
|
33
tests/multimodal/utils.py
Normal file
33
tests/multimodal/utils.py
Normal file
@ -0,0 +1,33 @@
|
|||||||
|
import numpy as np
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
|
||||||
|
def random_image(rng: np.random.RandomState, min_wh: int, max_wh: int):
|
||||||
|
w, h = rng.randint(min_wh, max_wh, size=(2, ))
|
||||||
|
arr = rng.randint(0, 255, size=(w, h, 3), dtype=np.uint8)
|
||||||
|
return Image.fromarray(arr)
|
||||||
|
|
||||||
|
|
||||||
|
def random_video(
|
||||||
|
rng: np.random.RandomState,
|
||||||
|
min_frames: int,
|
||||||
|
max_frames: int,
|
||||||
|
min_wh: int,
|
||||||
|
max_wh: int,
|
||||||
|
):
|
||||||
|
# Temporary workaround for https://github.com/huggingface/transformers/issues/35412
|
||||||
|
num_frames = rng.randint(min_frames, max_frames)
|
||||||
|
num_frames = (num_frames // 2) * 2
|
||||||
|
|
||||||
|
w, h = rng.randint(min_wh, max_wh, size=(2, ))
|
||||||
|
return rng.randint(0, 255, size=(num_frames, w, h, 3), dtype=np.uint8)
|
||||||
|
|
||||||
|
|
||||||
|
def random_audio(
|
||||||
|
rng: np.random.RandomState,
|
||||||
|
min_len: int,
|
||||||
|
max_len: int,
|
||||||
|
sr: int,
|
||||||
|
):
|
||||||
|
audio_len = rng.randint(min_len, max_len)
|
||||||
|
return rng.rand(audio_len), sr
|
Loading…
x
Reference in New Issue
Block a user