[CI/Build] Refactor image test assets (#5821)

Authored by Cyrus Leung on 2024-06-26 16:02:34 +08:00; committed by GitHub
parent 3439c5a8e3
commit 6984c02a27
5 changed files with 127 additions and 92 deletions

View File

@@ -1,7 +1,12 @@
 import contextlib
 import gc
 import os
-from typing import Any, Dict, List, Optional, Tuple, TypeVar
+from collections import UserList
+from dataclasses import dataclass
+from functools import cached_property
+from pathlib import Path
+from typing import (Any, Dict, List, Literal, Optional, Tuple, TypedDict,
+                    TypeVar)
 
 import pytest
 import torch
@@ -28,21 +33,8 @@ _TEST_DIR = os.path.dirname(__file__)
 _TEST_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "example.txt")]
 _LONG_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "summary.txt")]
 
-# Multi modal related
-# You can use `.buildkite/download-images.sh` to download the assets
-PIXEL_VALUES_FILES = [
-    os.path.join(_TEST_DIR, "images", filename) for filename in
-    ["stop_sign_pixel_values.pt", "cherry_blossom_pixel_values.pt"]
-]
-IMAGE_FEATURES_FILES = [
-    os.path.join(_TEST_DIR, "images", filename) for filename in
-    ["stop_sign_image_features.pt", "cherry_blossom_image_features.pt"]
-]
-IMAGE_FILES = [
-    os.path.join(_TEST_DIR, "images", filename)
-    for filename in ["stop_sign.jpg", "cherry_blossom.jpg"]
-]
-assert len(PIXEL_VALUES_FILES) == len(IMAGE_FEATURES_FILES) == len(IMAGE_FILES)
+_IMAGE_DIR = Path(_TEST_DIR) / "images"
+"""You can use `.buildkite/download-images.sh` to download the assets."""
 
 
 def _read_prompts(filename: str) -> List[str]:
@@ -51,6 +43,63 @@ def _read_prompts(filename: str) -> List[str]:
     return prompts
 
 
+@dataclass(frozen=True)
+class ImageAsset:
+    name: Literal["stop_sign", "cherry_blossom"]
+
+    @cached_property
+    def pixel_values(self) -> torch.Tensor:
+        return torch.load(_IMAGE_DIR / f"{self.name}_pixel_values.pt")
+
+    @cached_property
+    def image_features(self) -> torch.Tensor:
+        return torch.load(_IMAGE_DIR / f"{self.name}_image_features.pt")
+
+    @cached_property
+    def pil_image(self) -> Image.Image:
+        return Image.open(_IMAGE_DIR / f"{self.name}.jpg")
+
+    def for_hf(self) -> Image.Image:
+        return self.pil_image
+
+    def for_vllm(self, vision_config: VisionLanguageConfig) -> MultiModalData:
+        image_input_type = vision_config.image_input_type
+        ImageInputType = VisionLanguageConfig.ImageInputType
+
+        if image_input_type == ImageInputType.IMAGE_FEATURES:
+            return ImageFeatureData(self.image_features)
+        if image_input_type == ImageInputType.PIXEL_VALUES:
+            return ImagePixelData(self.pil_image)
+
+        raise NotImplementedError
+
+
+class _ImageAssetPrompts(TypedDict):
+    stop_sign: str
+    cherry_blossom: str
+
+
+class _ImageAssets(UserList[ImageAsset]):
+
+    def __init__(self) -> None:
+        super().__init__(
+            [ImageAsset("stop_sign"),
+             ImageAsset("cherry_blossom")])
+
+    def prompts(self, prompts: _ImageAssetPrompts) -> List[str]:
+        """
+        Convenience method to define the prompt for each test image.
+
+        The order of the returned prompts matches the order of the
+        assets when iterating through this object.
+        """
+        return [prompts["stop_sign"], prompts["cherry_blossom"]]
+
+
+IMAGE_ASSETS = _ImageAssets()
+"""Singleton instance of :class:`_ImageAssets`."""
+
+
 def cleanup():
     destroy_model_parallel()
     destroy_distributed_environment()
@@ -81,31 +130,6 @@ def cleanup_fixture(should_do_global_cleanup_after_test: bool):
         cleanup()
 
 
-@pytest.fixture(scope="session")
-def hf_images() -> List[Image.Image]:
-    return [Image.open(filename) for filename in IMAGE_FILES]
-
-
-@pytest.fixture()
-def vllm_images(request) -> List[MultiModalData]:
-    vision_language_config = request.getfixturevalue("model_and_config")[1]
-    if vision_language_config.image_input_type == (
-            VisionLanguageConfig.ImageInputType.IMAGE_FEATURES):
-        return [
-            ImageFeatureData(torch.load(filename))
-            for filename in IMAGE_FEATURES_FILES
-        ]
-    else:
-        return [
-            ImagePixelData(Image.open(filename)) for filename in IMAGE_FILES
-        ]
-
-
-@pytest.fixture()
-def vllm_image_tensors(request) -> List[torch.Tensor]:
-    return [torch.load(filename) for filename in PIXEL_VALUES_FILES]
-
-
 @pytest.fixture
 def example_prompts() -> List[str]:
     prompts = []
@@ -122,6 +146,11 @@ def example_long_prompts() -> List[str]:
     return prompts
 
 
+@pytest.fixture(scope="session")
+def image_assets() -> _ImageAssets:
+    return IMAGE_ASSETS
+
+
 _STR_DTYPE_TO_TORCH_DTYPE = {
     "half": torch.half,
     "bfloat16": torch.bfloat16,

View File

@@ -5,17 +5,17 @@ from transformers import AutoTokenizer
 
 from vllm.config import VisionLanguageConfig
 
-from ..conftest import IMAGE_FILES
+from ..conftest import IMAGE_ASSETS
 
 pytestmark = pytest.mark.vlm
 
 # The image token is placed before "user" on purpose so that the test can pass
-HF_IMAGE_PROMPTS = [
+HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
+    "stop_sign":
     "<image>\nUSER: What's the content of the image?\nASSISTANT:",
+    "cherry_blossom":
     "<image>\nUSER: What is the season?\nASSISTANT:",
-]
-
-assert len(HF_IMAGE_PROMPTS) == len(IMAGE_FILES)
+})
 
 
 def iter_llava_configs(model_name: str):
@@ -49,28 +49,28 @@ def vllm_to_hf_output(vllm_output: Tuple[List[int], str],
     x1, x2, x3 ... to 1, 32000, x1, x2, x3 ...
     It also reduces `output_str` from "<image><image>bla" to "bla".
     """
-    input_ids, output_str = vllm_output
+    output_ids, output_str = vllm_output
     image_token_id = vlm_config.image_token_id
 
     tokenizer = AutoTokenizer.from_pretrained(model_id)
     image_token_str = tokenizer.decode(image_token_id)
 
-    hf_input_ids = [
-        input_id for idx, input_id in enumerate(input_ids)
-        if input_id != image_token_id or input_ids[idx - 1] != image_token_id
+    hf_output_ids = [
+        token_id for idx, token_id in enumerate(output_ids)
+        if token_id != image_token_id or output_ids[idx - 1] != image_token_id
     ]
     hf_output_str = output_str \
         .replace(image_token_str * vlm_config.image_feature_size, "")
 
-    return hf_input_ids, hf_output_str
+    return hf_output_ids, hf_output_str
 
 
 # TODO: Add test for `tensor_parallel_size` [ref: PR #3883]
 @pytest.mark.parametrize("model_and_config", model_and_vl_config)
 @pytest.mark.parametrize("dtype", ["half"])
 @pytest.mark.parametrize("max_tokens", [128])
-def test_models(hf_runner, vllm_runner, hf_images, vllm_images,
-                model_and_config, dtype: str, max_tokens: int) -> None:
+def test_models(hf_runner, vllm_runner, image_assets, model_and_config,
+                dtype: str, max_tokens: int) -> None:
     """Inference result should be the same between hf and vllm.
 
     All the image fixtures for the test is under tests/images.
@@ -81,6 +81,8 @@ def test_models(hf_runner, vllm_runner, hf_images, vllm_images,
     The text output is sanitized to be able to compare with hf.
     """
     model_id, vlm_config = model_and_config
+    hf_images = [asset.for_hf() for asset in image_assets]
+    vllm_images = [asset.for_vllm(vlm_config) for asset in image_assets]
 
     with hf_runner(model_id, dtype=dtype, is_vision_model=True) as hf_model:
         hf_outputs = hf_model.generate_greedy(HF_IMAGE_PROMPTS,
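
The renamed `vllm_to_hf_output` helper is easiest to follow with concrete numbers. A toy illustration of the list comprehension above (ids invented; 32000 merely stands in for the model's image token id):

image_token_id = 32000
output_ids = [1, 32000, 32000, 32000, 319, 13]

hf_output_ids = [
    token_id for idx, token_id in enumerate(output_ids)
    if token_id != image_token_id or output_ids[idx - 1] != image_token_id
]
assert hf_output_ids == [1, 32000, 319, 13]  # the run of image tokens collapses to one

This mirrors the docstring: vLLM's output carries the image placeholder expanded to `image_feature_size` copies, while the HF reference contains it once, so the duplicate ids (and the repeated `image_token_str` in the text) are stripped before comparison.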

View File

@@ -5,7 +5,7 @@ from transformers import AutoTokenizer
 
 from vllm.config import VisionLanguageConfig
 
-from ..conftest import IMAGE_FILES
+from ..conftest import IMAGE_ASSETS
 
 pytestmark = pytest.mark.vlm
 
@@ -15,12 +15,12 @@ _PREFACE = (
     "questions.")
 
 # The image token is placed before "user" on purpose so that the test can pass
-HF_IMAGE_PROMPTS = [
-    f"{_PREFACE} <image>\nUSER: What's the content of the image? ASSISTANT:",
-    f"{_PREFACE} <image>\nUSER: What is the season? ASSISTANT:",
-]
-
-assert len(HF_IMAGE_PROMPTS) == len(IMAGE_FILES)
+HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
+    "stop_sign":
+    f"{_PREFACE} <image>\nUSER: What's the content of the image?\nASSISTANT:",
+    "cherry_blossom":
+    f"{_PREFACE} <image>\nUSER: What is the season?\nASSISTANT:",
+})
 
 
 def iter_llava_next_configs(model_name: str):
@@ -56,20 +56,20 @@ def vllm_to_hf_output(vllm_output: Tuple[List[int], str],
     x1, x2, x3 ... to 1, 32000, x1, x2, x3 ...
     It also reduces `output_str` from "<image><image>bla" to "bla".
     """
-    input_ids, output_str = vllm_output
+    output_ids, output_str = vllm_output
     image_token_id = vlm_config.image_token_id
 
     tokenizer = AutoTokenizer.from_pretrained(model_id)
     image_token_str = tokenizer.decode(image_token_id)
 
-    hf_input_ids = [
-        input_id for idx, input_id in enumerate(input_ids)
-        if input_id != image_token_id or input_ids[idx - 1] != image_token_id
+    hf_output_ids = [
+        token_id for idx, token_id in enumerate(output_ids)
+        if token_id != image_token_id or output_ids[idx - 1] != image_token_id
     ]
     hf_output_str = output_str \
         .replace(image_token_str * vlm_config.image_feature_size, " ")
 
-    return hf_input_ids, hf_output_str
+    return hf_output_ids, hf_output_str
 
 
 @pytest.mark.xfail(
@@ -78,8 +78,8 @@ def vllm_to_hf_output(vllm_output: Tuple[List[int], str],
 @pytest.mark.parametrize("model_and_config", model_and_vl_config)
 @pytest.mark.parametrize("dtype", ["half"])
 @pytest.mark.parametrize("max_tokens", [128])
-def test_models(hf_runner, vllm_runner, hf_images, vllm_images,
-                model_and_config, dtype: str, max_tokens: int) -> None:
+def test_models(hf_runner, vllm_runner, image_assets, model_and_config,
+                dtype: str, max_tokens: int) -> None:
     """Inference result should be the same between hf and vllm.
 
     All the image fixtures for the test is under tests/images.
@@ -90,6 +90,8 @@ def test_models(hf_runner, vllm_runner, hf_images, vllm_images,
     The text output is sanitized to be able to compare with hf.
     """
     model_id, vlm_config = model_and_config
+    hf_images = [asset.for_hf() for asset in image_assets]
+    vllm_images = [asset.for_vllm(vlm_config) for asset in image_assets]
 
     with hf_runner(model_id, dtype=dtype, is_vision_model=True) as hf_model:
         hf_outputs = hf_model.generate_greedy(HF_IMAGE_PROMPTS,

View File

@@ -6,17 +6,17 @@ from transformers import AutoTokenizer
 from vllm.config import VisionLanguageConfig
 from vllm.utils import is_cpu
 
-from ..conftest import IMAGE_FILES
+from ..conftest import IMAGE_ASSETS
 
 pytestmark = pytest.mark.vlm
 
 # The image token is placed before "user" on purpose so that the test can pass
-HF_IMAGE_PROMPTS = [
+HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
+    "stop_sign":
     "<|user|>\n<|image_1|>\nWhat's the content of the image?<|end|>\n<|assistant|>\n",  # noqa: E501
-    "<|user|>\n<|image_1|>\nWhat is the season?<|end|>\n<|assistant|>\n",
-]
-
-assert len(HF_IMAGE_PROMPTS) == len(IMAGE_FILES)
+    "cherry_blossom":
+    "<|user|>\n<|image_1|>\nWhat is the season?<|end|>\n<|assistant|>\n",  # noqa: E501
+})
 
 
 def iter_phi3v_configs(model_name: str):
@@ -50,22 +50,22 @@ def vllm_to_hf_output(vllm_output: Tuple[List[int], str],
     x1, x2, x3 ... to 1, 32000, x1, x2, x3 ...
     It also reduces `output_str` from "<image><image>bla" to "bla".
     """
-    input_ids, output_str = vllm_output
+    output_ids, output_str = vllm_output
     image_token_id = vlm_config.image_token_id
 
    tokenizer = AutoTokenizer.from_pretrained(model_id)
     image_token_str = tokenizer.decode(image_token_id)
 
-    hf_input_ids = [
-        input_id if input_id != image_token_id else 0
-        for idx, input_id in enumerate(input_ids)
+    hf_output_ids = [
+        token_id if token_id != image_token_id else 0
+        for idx, token_id in enumerate(output_ids)
     ]
     hf_output_str = output_str \
         .replace(image_token_str * vlm_config.image_feature_size, "") \
         .replace("<s>", " ").replace("<|user|>", "") \
         .replace("<|end|>\n<|assistant|>", " ")
 
-    return hf_input_ids, hf_output_str
+    return hf_output_ids, hf_output_str
 
 
 target_dtype = "half"
@@ -82,8 +82,8 @@ if is_cpu():
 @pytest.mark.parametrize("model_and_config", model_and_vl_config)
 @pytest.mark.parametrize("dtype", [target_dtype])
 @pytest.mark.parametrize("max_tokens", [128])
-def test_models(hf_runner, vllm_runner, hf_images, vllm_images,
-                model_and_config, dtype: str, max_tokens: int) -> None:
+def test_models(hf_runner, vllm_runner, image_assets, model_and_config,
+                dtype: str, max_tokens: int) -> None:
     """Inference result should be the same between hf and vllm.
 
     All the image fixtures for the test is under tests/images.
@@ -94,6 +94,8 @@ def test_models(hf_runner, vllm_runner, hf_images, vllm_images,
     The text output is sanitized to be able to compare with hf.
     """
     model_id, vlm_config = model_and_config
+    hf_images = [asset.for_hf() for asset in image_assets]
+    vllm_images = [asset.for_vllm(vlm_config) for asset in image_assets]
 
     # use eager mode for hf runner, since phi3_v didn't work with flash_attn
     hf_model_kwargs = {"_attn_implementation": "eager"}

View File

@@ -10,7 +10,7 @@ from ..conftest import _STR_DTYPE_TO_TORCH_DTYPE
 
 
 @pytest.mark.parametrize("dtype", ["half", "float"])
-def test_clip_image_processor(hf_images, dtype):
+def test_clip_image_processor(image_assets, dtype):
     MODEL_NAME = "llava-hf/llava-1.5-7b-hf"
     IMAGE_HEIGHT = IMAGE_WIDTH = 560
 
@@ -35,13 +35,13 @@ def test_clip_image_processor(hf_images, dtype):
         image_processor_revision=None,
     )
 
-    for image in hf_images:
+    for asset in image_assets:
         hf_result = hf_processor.preprocess(
-            image,
+            asset.pil_image,
             return_tensors="pt",
         ).to(dtype=_STR_DTYPE_TO_TORCH_DTYPE[dtype])
         vllm_result = MULTIMODAL_REGISTRY.process_input(
-            ImagePixelData(image),
+            ImagePixelData(asset.pil_image),
             model_config=model_config,
             vlm_config=vlm_config,
         )
@@ -59,7 +59,7 @@ def test_clip_image_processor(hf_images, dtype):
     reason="Inconsistent image processor being used due to lack "
     "of support for dynamic image token replacement")
 @pytest.mark.parametrize("dtype", ["half", "float"])
-def test_llava_next_image_processor(hf_images, dtype):
+def test_llava_next_image_processor(image_assets, dtype):
     MODEL_NAME = "llava-hf/llava-v1.6-34b-hf"
     IMAGE_HEIGHT = IMAGE_WIDTH = 560
 
@@ -84,13 +84,13 @@ def test_llava_next_image_processor(hf_images, dtype):
         image_processor_revision=None,
     )
 
-    for image in hf_images:
+    for asset in image_assets:
         hf_result = hf_processor.preprocess(
-            image,
+            asset.pil_image,
            return_tensors="pt",
         ).to(dtype=_STR_DTYPE_TO_TORCH_DTYPE[dtype])
         vllm_result = MULTIMODAL_REGISTRY.process_input(
-            ImagePixelData(image),
+            ImagePixelData(asset.pil_image),
             model_config=model_config,
             vlm_config=vlm_config,
         )
@@ -107,7 +107,7 @@ def test_llava_next_image_processor(hf_images, dtype):
 @pytest.mark.xfail(
     reason="Example image pixels were not processed using HuggingFace")
 @pytest.mark.parametrize("dtype", ["float"])
-def test_image_pixel_types(hf_images, vllm_image_tensors, dtype):
+def test_image_pixel_types(image_assets, dtype):
     MODEL_NAME = "llava-hf/llava-1.5-7b-hf"
     IMAGE_HEIGHT = IMAGE_WIDTH = 560
 
@@ -129,14 +129,14 @@ def test_image_pixel_types(hf_images, vllm_image_tensors, dtype):
         image_processor_revision=None,
     )
 
-    for image, tensor in zip(hf_images, vllm_image_tensors):
+    for asset in image_assets:
         image_result = MULTIMODAL_REGISTRY.process_input(
-            ImagePixelData(image),
+            ImagePixelData(asset.pil_image),
             model_config=model_config,
             vlm_config=vlm_config,
         )
         tensor_result = MULTIMODAL_REGISTRY.process_input(
-            ImagePixelData(tensor),
+            ImagePixelData(asset.pixel_values),
             model_config=model_config,
             vlm_config=vlm_config,
         )
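
The final hunk feeds both asset representations into `ImagePixelData`. A standalone sketch of the two input paths (the import path for `ImagePixelData` is an assumption about the vLLM version these tests target; the registry/config plumbing from the test is omitted):

# Assumed import path; the test module above already has ImagePixelData in scope.
from vllm.multimodal.image import ImagePixelData

from ..conftest import IMAGE_ASSETS

asset = IMAGE_ASSETS[0]                             # the "stop_sign" asset
from_image = ImagePixelData(asset.pil_image)        # raw PIL image, run through the HF image processor
from_tensor = ImagePixelData(asset.pixel_values)    # pre-computed pixel tensor loaded from tests/images/*.pt
# The surrounding test is marked xfail because these pre-computed pixel values
# were not produced by the HuggingFace processor, so the two results need not match.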