from functools import partial

import numpy as np
import pytest
from PIL import Image

from vllm.config import ModelConfig
from vllm.inputs import InputProcessingContext
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.processing import ProcessingCache
from vllm.multimodal.utils import cached_get_tokenizer

from ....multimodal.utils import random_audio, random_image, random_video
from ...registry import HF_EXAMPLE_MODELS


def _test_processing_correctness(
    model_id: str,
    hit_rate: float,
    num_batches: int,
    simplify_rate: float,
):
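    """Check that a multi-modal processor produces identical outputs with
    and without caching, and for both text and pre-tokenized prompts."""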
    model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id)
    model_info.check_available_online(on_fail="skip")
    model_info.check_transformers_version(on_fail="skip")

    model_config = ModelConfig(
        model_id,
        task="auto",
        tokenizer=model_id,
        tokenizer_mode="auto",
        trust_remote_code=model_info.trust_remote_code,
        seed=0,
        dtype="float16",
        revision=None,
        hf_overrides=model_info.hf_overrides,
    )

    model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)
    factories = MULTIMODAL_REGISTRY._processor_factories[model_cls]
    ctx = InputProcessingContext(
        model_config,
        tokenizer=cached_get_tokenizer(
            model_config.tokenizer,
            trust_remote_code=model_info.trust_remote_code,
        ),
    )
    # Ensure that it can fit all of the data
    cache = ProcessingCache(capacity=1 << 30)

    processing_info = factories.info(ctx)
    supported_mm_limits = processing_info.get_supported_mm_limits()
    limit_mm_per_prompt = {
        modality: 3 if limit is None else limit
        for modality, limit in supported_mm_limits.items()
    }

    model_config.get_multimodal_config().limit_per_prompt = limit_mm_per_prompt
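    # Build one processor without a cache (baseline) and one backed by the
    # cache, so their outputs can be compared batch by batch.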
    baseline_processor = factories.build_processor(ctx, cache=None)
    cached_processor = factories.build_processor(ctx, cache=cache)
    dummy_inputs = baseline_processor.dummy_inputs
    tokenizer = baseline_processor.info.get_tokenizer()

    rng = np.random.RandomState(0)
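    # Fixed inputs are reused across batches to exercise cache hits, while
    # the factories below generate fresh random inputs for cache misses.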
    input_to_hit = {
        "image": Image.new("RGB", size=(128, 128)),
        "video": np.zeros((4, 128, 128, 3), dtype=np.uint8),
        "audio": (np.zeros((512, )), 16000),
    }
    input_factory = {
        "image":
        partial(random_image, rng, min_wh=128, max_wh=256),
        "video":
        partial(random_video,
                rng,
                min_frames=2,
                max_frames=8,
                min_wh=128,
                max_wh=256),
        "audio":
        partial(random_audio, rng, min_len=512, max_len=1024, sr=16000),
    }
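    # Each batch draws a random number of items per modality (up to the
    # model's limit), mixing cache hits and fresh inputs per `hit_rate`.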
    for batch_idx in range(num_batches):
        mm_data = {
            k:
            [(input_to_hit[k] if rng.rand() < hit_rate else input_factory[k]())
             # randint excludes the upper bound, so +1 is needed for the
             # item count to ever reach the configured limit
             for _ in range(rng.randint(limit + 1))]
            for k, limit in limit_mm_per_prompt.items()
        }
        mm_counts = {k: len(vs) for k, vs in mm_data.items()}
        prompt = dummy_inputs.get_dummy_processor_inputs(
            model_config.max_model_len,
            mm_counts,
        ).prompt_text
        # Drop unnecessary keys and test single -> multi conversion
        if rng.rand() < simplify_rate:
            for k in list(mm_data.keys()):
                if not mm_data[k]:
                    del mm_data[k]
                elif len(mm_data[k]) == 1:
                    mm_data[k] = mm_data[k][0]
        baseline_result = baseline_processor.apply(
            prompt,
            mm_data=mm_data,
            hf_processor_mm_kwargs={},
        )
        cached_result = cached_processor.apply(
            prompt,
            mm_data=mm_data,
            hf_processor_mm_kwargs={},
        )

        assert baseline_result == cached_result, (
            f"Failed ({batch_idx=}, {prompt=}, {mm_data=})")
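        # Applying the processor to the pre-tokenized prompt should yield
        # the same result as applying it to the raw prompt text.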
        baseline_tokenized_result = baseline_processor.apply(
            tokenizer.encode(prompt),
            mm_data=mm_data,
            hf_processor_mm_kwargs={},
        )

        assert baseline_result == baseline_tokenized_result, (
            f"Failed ({batch_idx=}, {prompt=}, {mm_data=})")

        cached_tokenized_result = cached_processor.apply(
            tokenizer.encode(prompt),
            mm_data=mm_data,
            hf_processor_mm_kwargs={},
        )

        assert cached_result == cached_tokenized_result, (
            f"Failed ({batch_idx=}, {prompt=}, {mm_data=})")

# yapf: disable
# One example checkpoint for each supported multi-modal processor
@pytest.mark.parametrize("model_id", [
    "rhymes-ai/Aria",
    "Salesforce/blip2-opt-2.7b",
    "facebook/chameleon-7b",
    "deepseek-ai/deepseek-vl2-tiny",
    "adept/fuyu-8b",
    "llava-hf/llava-1.5-7b-hf",
    "llava-hf/llava-v1.6-mistral-7b-hf",
    "llava-hf/LLaVA-NeXT-Video-7B-hf",
    "llava-hf/llava-onevision-qwen2-0.5b-ov-hf",
    "TIGER-Lab/Mantis-8B-siglip-llama3",
    "mistral-community/pixtral-12b",
    "openbmb/MiniCPM-o-2_6",
    "openbmb/MiniCPM-V-2_6",
    "Qwen/Qwen-VL-Chat",
    "Qwen/Qwen2-VL-2B-Instruct",
    "Qwen/Qwen2-Audio-7B-Instruct",
    "fixie-ai/ultravox-v0_3",
])
@pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0])
|
|
@pytest.mark.parametrize("num_batches", [32])
|
|
@pytest.mark.parametrize("simplify_rate", [1.0])
|
|
# yapf: enable
|
|
def test_processing_correctness(
|
|
model_id: str,
|
|
hit_rate: float,
|
|
num_batches: int,
|
|
simplify_rate: float,
|
|
):
|
|
_test_processing_correctness(
|
|
model_id,
|
|
hit_rate=hit_rate,
|
|
num_batches=num_batches,
|
|
simplify_rate=simplify_rate,
|
|
)
|
|
|
|
|
|
# yapf: disable
@pytest.mark.parametrize("model_id", ["microsoft/Phi-3-vision-128k-instruct"])
@pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0])
@pytest.mark.parametrize("num_batches", [32])
@pytest.mark.parametrize("simplify_rate", [1.0])
# yapf: enable
def test_processing_correctness_phi3v(
    model_id: str,
    hit_rate: float,
    num_batches: int,
    simplify_rate: float,
):
    # HACK - this is an attempted workaround for the following bug
    # https://github.com/huggingface/transformers/issues/34307
    from transformers import AutoImageProcessor  # noqa: F401
    from transformers import AutoProcessor  # noqa: F401

    AutoImageProcessor.from_pretrained(model_id, trust_remote_code=True)
    _test_processing_correctness(
        model_id,
        hit_rate=hit_rate,
        num_batches=num_batches,
        simplify_rate=simplify_rate,
    )