[CI/Build] Refactor image test assets (#5821)
This commit is contained in:
parent
3439c5a8e3
commit
6984c02a27
@ -1,7 +1,12 @@
|
||||
import contextlib
|
||||
import gc
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional, Tuple, TypeVar
|
||||
from collections import UserList
|
||||
from dataclasses import dataclass
|
||||
from functools import cached_property
|
||||
from pathlib import Path
|
||||
from typing import (Any, Dict, List, Literal, Optional, Tuple, TypedDict,
|
||||
TypeVar)
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
@ -28,21 +33,8 @@ _TEST_DIR = os.path.dirname(__file__)
|
||||
_TEST_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "example.txt")]
|
||||
_LONG_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "summary.txt")]
|
||||
|
||||
# Multi modal related
|
||||
# You can use `.buildkite/download-images.sh` to download the assets
|
||||
PIXEL_VALUES_FILES = [
|
||||
os.path.join(_TEST_DIR, "images", filename) for filename in
|
||||
["stop_sign_pixel_values.pt", "cherry_blossom_pixel_values.pt"]
|
||||
]
|
||||
IMAGE_FEATURES_FILES = [
|
||||
os.path.join(_TEST_DIR, "images", filename) for filename in
|
||||
["stop_sign_image_features.pt", "cherry_blossom_image_features.pt"]
|
||||
]
|
||||
IMAGE_FILES = [
|
||||
os.path.join(_TEST_DIR, "images", filename)
|
||||
for filename in ["stop_sign.jpg", "cherry_blossom.jpg"]
|
||||
]
|
||||
assert len(PIXEL_VALUES_FILES) == len(IMAGE_FEATURES_FILES) == len(IMAGE_FILES)
|
||||
_IMAGE_DIR = Path(_TEST_DIR) / "images"
|
||||
"""You can use `.buildkite/download-images.sh` to download the assets."""
|
||||
|
||||
|
||||
def _read_prompts(filename: str) -> List[str]:
|
||||
@ -51,6 +43,63 @@ def _read_prompts(filename: str) -> List[str]:
|
||||
return prompts
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ImageAsset:
|
||||
name: Literal["stop_sign", "cherry_blossom"]
|
||||
|
||||
@cached_property
|
||||
def pixel_values(self) -> torch.Tensor:
|
||||
return torch.load(_IMAGE_DIR / f"{self.name}_pixel_values.pt")
|
||||
|
||||
@cached_property
|
||||
def image_features(self) -> torch.Tensor:
|
||||
return torch.load(_IMAGE_DIR / f"{self.name}_image_features.pt")
|
||||
|
||||
@cached_property
|
||||
def pil_image(self) -> Image.Image:
|
||||
return Image.open(_IMAGE_DIR / f"{self.name}.jpg")
|
||||
|
||||
def for_hf(self) -> Image.Image:
|
||||
return self.pil_image
|
||||
|
||||
def for_vllm(self, vision_config: VisionLanguageConfig) -> MultiModalData:
|
||||
image_input_type = vision_config.image_input_type
|
||||
ImageInputType = VisionLanguageConfig.ImageInputType
|
||||
|
||||
if image_input_type == ImageInputType.IMAGE_FEATURES:
|
||||
return ImageFeatureData(self.image_features)
|
||||
if image_input_type == ImageInputType.PIXEL_VALUES:
|
||||
return ImagePixelData(self.pil_image)
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class _ImageAssetPrompts(TypedDict):
|
||||
stop_sign: str
|
||||
cherry_blossom: str
|
||||
|
||||
|
||||
class _ImageAssets(UserList[ImageAsset]):
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__(
|
||||
[ImageAsset("stop_sign"),
|
||||
ImageAsset("cherry_blossom")])
|
||||
|
||||
def prompts(self, prompts: _ImageAssetPrompts) -> List[str]:
|
||||
"""
|
||||
Convenience method to define the prompt for each test image.
|
||||
|
||||
The order of the returned prompts matches the order of the
|
||||
assets when iterating through this object.
|
||||
"""
|
||||
return [prompts["stop_sign"], prompts["cherry_blossom"]]
|
||||
|
||||
|
||||
IMAGE_ASSETS = _ImageAssets()
|
||||
"""Singleton instance of :class:`_ImageAssets`."""
|
||||
|
||||
|
||||
def cleanup():
|
||||
destroy_model_parallel()
|
||||
destroy_distributed_environment()
|
||||
@ -81,31 +130,6 @@ def cleanup_fixture(should_do_global_cleanup_after_test: bool):
|
||||
cleanup()
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def hf_images() -> List[Image.Image]:
|
||||
return [Image.open(filename) for filename in IMAGE_FILES]
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def vllm_images(request) -> List[MultiModalData]:
|
||||
vision_language_config = request.getfixturevalue("model_and_config")[1]
|
||||
if vision_language_config.image_input_type == (
|
||||
VisionLanguageConfig.ImageInputType.IMAGE_FEATURES):
|
||||
return [
|
||||
ImageFeatureData(torch.load(filename))
|
||||
for filename in IMAGE_FEATURES_FILES
|
||||
]
|
||||
else:
|
||||
return [
|
||||
ImagePixelData(Image.open(filename)) for filename in IMAGE_FILES
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def vllm_image_tensors(request) -> List[torch.Tensor]:
|
||||
return [torch.load(filename) for filename in PIXEL_VALUES_FILES]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def example_prompts() -> List[str]:
|
||||
prompts = []
|
||||
@ -122,6 +146,11 @@ def example_long_prompts() -> List[str]:
|
||||
return prompts
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def image_assets() -> _ImageAssets:
|
||||
return IMAGE_ASSETS
|
||||
|
||||
|
||||
_STR_DTYPE_TO_TORCH_DTYPE = {
|
||||
"half": torch.half,
|
||||
"bfloat16": torch.bfloat16,
|
||||
|
@ -5,17 +5,17 @@ from transformers import AutoTokenizer
|
||||
|
||||
from vllm.config import VisionLanguageConfig
|
||||
|
||||
from ..conftest import IMAGE_FILES
|
||||
from ..conftest import IMAGE_ASSETS
|
||||
|
||||
pytestmark = pytest.mark.vlm
|
||||
|
||||
# The image token is placed before "user" on purpose so that the test can pass
|
||||
HF_IMAGE_PROMPTS = [
|
||||
HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
|
||||
"stop_sign":
|
||||
"<image>\nUSER: What's the content of the image?\nASSISTANT:",
|
||||
"cherry_blossom":
|
||||
"<image>\nUSER: What is the season?\nASSISTANT:",
|
||||
]
|
||||
|
||||
assert len(HF_IMAGE_PROMPTS) == len(IMAGE_FILES)
|
||||
})
|
||||
|
||||
|
||||
def iter_llava_configs(model_name: str):
|
||||
@ -49,28 +49,28 @@ def vllm_to_hf_output(vllm_output: Tuple[List[int], str],
|
||||
x1, x2, x3 ... to 1, 32000, x1, x2, x3 ...
|
||||
It also reduces `output_str` from "<image><image>bla" to "bla".
|
||||
"""
|
||||
input_ids, output_str = vllm_output
|
||||
output_ids, output_str = vllm_output
|
||||
image_token_id = vlm_config.image_token_id
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
image_token_str = tokenizer.decode(image_token_id)
|
||||
|
||||
hf_input_ids = [
|
||||
input_id for idx, input_id in enumerate(input_ids)
|
||||
if input_id != image_token_id or input_ids[idx - 1] != image_token_id
|
||||
hf_output_ids = [
|
||||
token_id for idx, token_id in enumerate(output_ids)
|
||||
if token_id != image_token_id or output_ids[idx - 1] != image_token_id
|
||||
]
|
||||
hf_output_str = output_str \
|
||||
.replace(image_token_str * vlm_config.image_feature_size, "")
|
||||
|
||||
return hf_input_ids, hf_output_str
|
||||
return hf_output_ids, hf_output_str
|
||||
|
||||
|
||||
# TODO: Add test for `tensor_parallel_size` [ref: PR #3883]
|
||||
@pytest.mark.parametrize("model_and_config", model_and_vl_config)
|
||||
@pytest.mark.parametrize("dtype", ["half"])
|
||||
@pytest.mark.parametrize("max_tokens", [128])
|
||||
def test_models(hf_runner, vllm_runner, hf_images, vllm_images,
|
||||
model_and_config, dtype: str, max_tokens: int) -> None:
|
||||
def test_models(hf_runner, vllm_runner, image_assets, model_and_config,
|
||||
dtype: str, max_tokens: int) -> None:
|
||||
"""Inference result should be the same between hf and vllm.
|
||||
|
||||
All the image fixtures for the test is under tests/images.
|
||||
@ -81,6 +81,8 @@ def test_models(hf_runner, vllm_runner, hf_images, vllm_images,
|
||||
The text output is sanitized to be able to compare with hf.
|
||||
"""
|
||||
model_id, vlm_config = model_and_config
|
||||
hf_images = [asset.for_hf() for asset in image_assets]
|
||||
vllm_images = [asset.for_vllm(vlm_config) for asset in image_assets]
|
||||
|
||||
with hf_runner(model_id, dtype=dtype, is_vision_model=True) as hf_model:
|
||||
hf_outputs = hf_model.generate_greedy(HF_IMAGE_PROMPTS,
|
||||
|
@ -5,7 +5,7 @@ from transformers import AutoTokenizer
|
||||
|
||||
from vllm.config import VisionLanguageConfig
|
||||
|
||||
from ..conftest import IMAGE_FILES
|
||||
from ..conftest import IMAGE_ASSETS
|
||||
|
||||
pytestmark = pytest.mark.vlm
|
||||
|
||||
@ -15,12 +15,12 @@ _PREFACE = (
|
||||
"questions.")
|
||||
|
||||
# The image token is placed before "user" on purpose so that the test can pass
|
||||
HF_IMAGE_PROMPTS = [
|
||||
f"{_PREFACE} <image>\nUSER: What's the content of the image? ASSISTANT:",
|
||||
f"{_PREFACE} <image>\nUSER: What is the season? ASSISTANT:",
|
||||
]
|
||||
|
||||
assert len(HF_IMAGE_PROMPTS) == len(IMAGE_FILES)
|
||||
HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
|
||||
"stop_sign":
|
||||
f"{_PREFACE} <image>\nUSER: What's the content of the image?\nASSISTANT:",
|
||||
"cherry_blossom":
|
||||
f"{_PREFACE} <image>\nUSER: What is the season?\nASSISTANT:",
|
||||
})
|
||||
|
||||
|
||||
def iter_llava_next_configs(model_name: str):
|
||||
@ -56,20 +56,20 @@ def vllm_to_hf_output(vllm_output: Tuple[List[int], str],
|
||||
x1, x2, x3 ... to 1, 32000, x1, x2, x3 ...
|
||||
It also reduces `output_str` from "<image><image>bla" to "bla".
|
||||
"""
|
||||
input_ids, output_str = vllm_output
|
||||
output_ids, output_str = vllm_output
|
||||
image_token_id = vlm_config.image_token_id
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
image_token_str = tokenizer.decode(image_token_id)
|
||||
|
||||
hf_input_ids = [
|
||||
input_id for idx, input_id in enumerate(input_ids)
|
||||
if input_id != image_token_id or input_ids[idx - 1] != image_token_id
|
||||
hf_output_ids = [
|
||||
token_id for idx, token_id in enumerate(output_ids)
|
||||
if token_id != image_token_id or output_ids[idx - 1] != image_token_id
|
||||
]
|
||||
hf_output_str = output_str \
|
||||
.replace(image_token_str * vlm_config.image_feature_size, " ")
|
||||
|
||||
return hf_input_ids, hf_output_str
|
||||
return hf_output_ids, hf_output_str
|
||||
|
||||
|
||||
@pytest.mark.xfail(
|
||||
@ -78,8 +78,8 @@ def vllm_to_hf_output(vllm_output: Tuple[List[int], str],
|
||||
@pytest.mark.parametrize("model_and_config", model_and_vl_config)
|
||||
@pytest.mark.parametrize("dtype", ["half"])
|
||||
@pytest.mark.parametrize("max_tokens", [128])
|
||||
def test_models(hf_runner, vllm_runner, hf_images, vllm_images,
|
||||
model_and_config, dtype: str, max_tokens: int) -> None:
|
||||
def test_models(hf_runner, vllm_runner, image_assets, model_and_config,
|
||||
dtype: str, max_tokens: int) -> None:
|
||||
"""Inference result should be the same between hf and vllm.
|
||||
|
||||
All the image fixtures for the test is under tests/images.
|
||||
@ -90,6 +90,8 @@ def test_models(hf_runner, vllm_runner, hf_images, vllm_images,
|
||||
The text output is sanitized to be able to compare with hf.
|
||||
"""
|
||||
model_id, vlm_config = model_and_config
|
||||
hf_images = [asset.for_hf() for asset in image_assets]
|
||||
vllm_images = [asset.for_vllm(vlm_config) for asset in image_assets]
|
||||
|
||||
with hf_runner(model_id, dtype=dtype, is_vision_model=True) as hf_model:
|
||||
hf_outputs = hf_model.generate_greedy(HF_IMAGE_PROMPTS,
|
||||
|
@ -6,17 +6,17 @@ from transformers import AutoTokenizer
|
||||
from vllm.config import VisionLanguageConfig
|
||||
from vllm.utils import is_cpu
|
||||
|
||||
from ..conftest import IMAGE_FILES
|
||||
from ..conftest import IMAGE_ASSETS
|
||||
|
||||
pytestmark = pytest.mark.vlm
|
||||
|
||||
# The image token is placed before "user" on purpose so that the test can pass
|
||||
HF_IMAGE_PROMPTS = [
|
||||
HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
|
||||
"stop_sign":
|
||||
"<|user|>\n<|image_1|>\nWhat's the content of the image?<|end|>\n<|assistant|>\n", # noqa: E501
|
||||
"<|user|>\n<|image_1|>\nWhat is the season?<|end|>\n<|assistant|>\n",
|
||||
]
|
||||
|
||||
assert len(HF_IMAGE_PROMPTS) == len(IMAGE_FILES)
|
||||
"cherry_blossom":
|
||||
"<|user|>\n<|image_1|>\nWhat is the season?<|end|>\n<|assistant|>\n", # noqa: E501
|
||||
})
|
||||
|
||||
|
||||
def iter_phi3v_configs(model_name: str):
|
||||
@ -50,22 +50,22 @@ def vllm_to_hf_output(vllm_output: Tuple[List[int], str],
|
||||
x1, x2, x3 ... to 1, 32000, x1, x2, x3 ...
|
||||
It also reduces `output_str` from "<image><image>bla" to "bla".
|
||||
"""
|
||||
input_ids, output_str = vllm_output
|
||||
output_ids, output_str = vllm_output
|
||||
image_token_id = vlm_config.image_token_id
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
image_token_str = tokenizer.decode(image_token_id)
|
||||
|
||||
hf_input_ids = [
|
||||
input_id if input_id != image_token_id else 0
|
||||
for idx, input_id in enumerate(input_ids)
|
||||
hf_output_ids = [
|
||||
token_id if token_id != image_token_id else 0
|
||||
for idx, token_id in enumerate(output_ids)
|
||||
]
|
||||
hf_output_str = output_str \
|
||||
.replace(image_token_str * vlm_config.image_feature_size, "") \
|
||||
.replace("<s>", " ").replace("<|user|>", "") \
|
||||
.replace("<|end|>\n<|assistant|>", " ")
|
||||
|
||||
return hf_input_ids, hf_output_str
|
||||
return hf_output_ids, hf_output_str
|
||||
|
||||
|
||||
target_dtype = "half"
|
||||
@ -82,8 +82,8 @@ if is_cpu():
|
||||
@pytest.mark.parametrize("model_and_config", model_and_vl_config)
|
||||
@pytest.mark.parametrize("dtype", [target_dtype])
|
||||
@pytest.mark.parametrize("max_tokens", [128])
|
||||
def test_models(hf_runner, vllm_runner, hf_images, vllm_images,
|
||||
model_and_config, dtype: str, max_tokens: int) -> None:
|
||||
def test_models(hf_runner, vllm_runner, image_assets, model_and_config,
|
||||
dtype: str, max_tokens: int) -> None:
|
||||
"""Inference result should be the same between hf and vllm.
|
||||
|
||||
All the image fixtures for the test is under tests/images.
|
||||
@ -94,6 +94,8 @@ def test_models(hf_runner, vllm_runner, hf_images, vllm_images,
|
||||
The text output is sanitized to be able to compare with hf.
|
||||
"""
|
||||
model_id, vlm_config = model_and_config
|
||||
hf_images = [asset.for_hf() for asset in image_assets]
|
||||
vllm_images = [asset.for_vllm(vlm_config) for asset in image_assets]
|
||||
|
||||
# use eager mode for hf runner, since phi3_v didn't work with flash_attn
|
||||
hf_model_kwargs = {"_attn_implementation": "eager"}
|
||||
|
@ -10,7 +10,7 @@ from ..conftest import _STR_DTYPE_TO_TORCH_DTYPE
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", ["half", "float"])
|
||||
def test_clip_image_processor(hf_images, dtype):
|
||||
def test_clip_image_processor(image_assets, dtype):
|
||||
MODEL_NAME = "llava-hf/llava-1.5-7b-hf"
|
||||
IMAGE_HEIGHT = IMAGE_WIDTH = 560
|
||||
|
||||
@ -35,13 +35,13 @@ def test_clip_image_processor(hf_images, dtype):
|
||||
image_processor_revision=None,
|
||||
)
|
||||
|
||||
for image in hf_images:
|
||||
for asset in image_assets:
|
||||
hf_result = hf_processor.preprocess(
|
||||
image,
|
||||
asset.pil_image,
|
||||
return_tensors="pt",
|
||||
).to(dtype=_STR_DTYPE_TO_TORCH_DTYPE[dtype])
|
||||
vllm_result = MULTIMODAL_REGISTRY.process_input(
|
||||
ImagePixelData(image),
|
||||
ImagePixelData(asset.pil_image),
|
||||
model_config=model_config,
|
||||
vlm_config=vlm_config,
|
||||
)
|
||||
@ -59,7 +59,7 @@ def test_clip_image_processor(hf_images, dtype):
|
||||
reason="Inconsistent image processor being used due to lack "
|
||||
"of support for dynamic image token replacement")
|
||||
@pytest.mark.parametrize("dtype", ["half", "float"])
|
||||
def test_llava_next_image_processor(hf_images, dtype):
|
||||
def test_llava_next_image_processor(image_assets, dtype):
|
||||
MODEL_NAME = "llava-hf/llava-v1.6-34b-hf"
|
||||
IMAGE_HEIGHT = IMAGE_WIDTH = 560
|
||||
|
||||
@ -84,13 +84,13 @@ def test_llava_next_image_processor(hf_images, dtype):
|
||||
image_processor_revision=None,
|
||||
)
|
||||
|
||||
for image in hf_images:
|
||||
for asset in image_assets:
|
||||
hf_result = hf_processor.preprocess(
|
||||
image,
|
||||
asset.pil_image,
|
||||
return_tensors="pt",
|
||||
).to(dtype=_STR_DTYPE_TO_TORCH_DTYPE[dtype])
|
||||
vllm_result = MULTIMODAL_REGISTRY.process_input(
|
||||
ImagePixelData(image),
|
||||
ImagePixelData(asset.pil_image),
|
||||
model_config=model_config,
|
||||
vlm_config=vlm_config,
|
||||
)
|
||||
@ -107,7 +107,7 @@ def test_llava_next_image_processor(hf_images, dtype):
|
||||
@pytest.mark.xfail(
|
||||
reason="Example image pixels were not processed using HuggingFace")
|
||||
@pytest.mark.parametrize("dtype", ["float"])
|
||||
def test_image_pixel_types(hf_images, vllm_image_tensors, dtype):
|
||||
def test_image_pixel_types(image_assets, dtype):
|
||||
MODEL_NAME = "llava-hf/llava-1.5-7b-hf"
|
||||
IMAGE_HEIGHT = IMAGE_WIDTH = 560
|
||||
|
||||
@ -129,14 +129,14 @@ def test_image_pixel_types(hf_images, vllm_image_tensors, dtype):
|
||||
image_processor_revision=None,
|
||||
)
|
||||
|
||||
for image, tensor in zip(hf_images, vllm_image_tensors):
|
||||
for asset in image_assets:
|
||||
image_result = MULTIMODAL_REGISTRY.process_input(
|
||||
ImagePixelData(image),
|
||||
ImagePixelData(asset.pil_image),
|
||||
model_config=model_config,
|
||||
vlm_config=vlm_config,
|
||||
)
|
||||
tensor_result = MULTIMODAL_REGISTRY.process_input(
|
||||
ImagePixelData(tensor),
|
||||
ImagePixelData(asset.pixel_values),
|
||||
model_config=model_config,
|
||||
vlm_config=vlm_config,
|
||||
)
|
||||
|
Loading…
x
Reference in New Issue
Block a user