[CI/Build] Add Model Tests for Qwen2-VL (#9846)
Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com> Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com> Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
parent
5608e611c2
commit
16b8f7a86f
@ -9,6 +9,7 @@
|
||||
# label(str): the name of the test. emoji allowed.
|
||||
# fast_check(bool): whether to run this on each commit on fastcheck pipeline.
|
||||
# fast_check_only(bool): run this test on fastcheck pipeline only
|
||||
# nightly(bool): run this test in nightly pipeline only
|
||||
# optional(bool): never run this test by default (i.e. need to unblock manually)
|
||||
# command(str): the single command to run for tests. incompatible with commands.
|
||||
# commands(list): the list of commands to run for test. incompatbile with command.
|
||||
@ -330,18 +331,28 @@ steps:
|
||||
commands:
|
||||
- pytest -v -s models/decoder_only/language --ignore=models/decoder_only/language/test_models.py --ignore=models/decoder_only/language/test_big_models.py
|
||||
|
||||
- label: Decoder-only Multi-Modal Models Test # 1h31min
|
||||
- label: Decoder-only Multi-Modal Models Test (Standard)
|
||||
#mirror_hardwares: [amd]
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/models/decoder_only/audio_language
|
||||
- tests/models/decoder_only/vision_language
|
||||
commands:
|
||||
- pytest -v -s models/decoder_only/audio_language
|
||||
- pytest -v -s models/decoder_only/audio_language -m core_model
|
||||
- pytest -v -s --ignore models/decoder_only/vision_language/test_phi3v.py models/decoder_only/vision_language -m core_model
|
||||
|
||||
- label: Decoder-only Multi-Modal Models Test (Extended)
|
||||
nightly: true
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/models/decoder_only/audio_language
|
||||
- tests/models/decoder_only/vision_language
|
||||
commands:
|
||||
- pytest -v -s models/decoder_only/audio_language -m 'not core_model'
|
||||
# HACK - run phi3v tests separately to sidestep this transformers bug
|
||||
# https://github.com/huggingface/transformers/issues/34307
|
||||
- pytest -v -s models/decoder_only/vision_language/test_phi3v.py
|
||||
- pytest -v -s --ignore models/decoder_only/vision_language/test_phi3v.py models/decoder_only/vision_language
|
||||
- pytest -v -s --ignore models/decoder_only/vision_language/test_phi3v.py models/decoder_only/vision_language -m 'not core_model'
|
||||
|
||||
- label: Other Models Test # 6min
|
||||
#mirror_hardwares: [amd]
|
||||
|
@ -262,10 +262,9 @@ def run_qwen2_vl(question: str, modality: str):
|
||||
|
||||
model_name = "Qwen/Qwen2-VL-7B-Instruct"
|
||||
|
||||
# Tested on L40
|
||||
llm = LLM(
|
||||
model=model_name,
|
||||
max_model_len=8192,
|
||||
max_model_len=4096,
|
||||
max_num_seqs=5,
|
||||
# Note - mm_processor_kwargs can also be passed to generate/chat calls
|
||||
mm_processor_kwargs={
|
||||
|
@ -158,6 +158,7 @@ def run_multi_audio_test(
|
||||
assert all(tokens for tokens, *_ in vllm_outputs)
|
||||
|
||||
|
||||
@pytest.mark.core_model
|
||||
@pytest.mark.parametrize("dtype", ["half"])
|
||||
@pytest.mark.parametrize("max_tokens", [128])
|
||||
@pytest.mark.parametrize("num_logprobs", [5])
|
||||
@ -178,6 +179,7 @@ def test_models(hf_runner, vllm_runner, audio, dtype: str, max_tokens: int,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.core_model
|
||||
@pytest.mark.parametrize("dtype", ["half"])
|
||||
@pytest.mark.parametrize("max_tokens", [128])
|
||||
@pytest.mark.parametrize("num_logprobs", [5])
|
||||
|
@ -17,7 +17,7 @@ MAX_PIXELS = "max_pixels"
|
||||
|
||||
|
||||
# Fixtures lazy import to avoid initializing CUDA during test collection
|
||||
# NOTE: Qwen2vl supports multiple input modalities, so it registers multiple
|
||||
# NOTE: Qwen2VL supports multiple input modalities, so it registers multiple
|
||||
# input mappers.
|
||||
@pytest.fixture()
|
||||
def image_input_mapper_for_qwen2_vl():
|
||||
|
@ -75,6 +75,63 @@ COMMON_BROADCAST_SETTINGS = {
|
||||
# this is a good idea for checking your command first, since tests are slow.
|
||||
|
||||
VLM_TEST_SETTINGS = {
|
||||
#### Core tests to always run in the CI
|
||||
"llava": VLMTestInfo(
|
||||
models=["llava-hf/llava-1.5-7b-hf"],
|
||||
test_type=(
|
||||
VLMTestType.EMBEDDING,
|
||||
VLMTestType.IMAGE,
|
||||
VLMTestType.CUSTOM_INPUTS
|
||||
),
|
||||
prompt_formatter=lambda img_prompt: f"USER: {img_prompt}\nASSISTANT:",
|
||||
convert_assets_to_embeddings=model_utils.get_llava_embeddings,
|
||||
max_model_len=4096,
|
||||
auto_cls=AutoModelForVision2Seq,
|
||||
vllm_output_post_proc=model_utils.llava_image_vllm_to_hf_output,
|
||||
custom_test_opts=[CustomTestOptions(
|
||||
inputs=custom_inputs.multi_image_multi_aspect_ratio_inputs(
|
||||
formatter=lambda img_prompt: f"USER: {img_prompt}\nASSISTANT:"
|
||||
),
|
||||
limit_mm_per_prompt={"image": 4},
|
||||
)],
|
||||
marks=[pytest.mark.core_model],
|
||||
),
|
||||
"paligemma": VLMTestInfo(
|
||||
models=["google/paligemma-3b-mix-224"],
|
||||
test_type=VLMTestType.IMAGE,
|
||||
prompt_formatter=identity,
|
||||
img_idx_to_prompt = lambda idx: "",
|
||||
# Paligemma uses its own sample prompts because the default one fails
|
||||
single_image_prompts=IMAGE_ASSETS.prompts({
|
||||
"stop_sign": "caption es",
|
||||
"cherry_blossom": "What is in the picture?",
|
||||
}),
|
||||
auto_cls=AutoModelForVision2Seq,
|
||||
postprocess_inputs=model_utils.get_key_type_post_processor(
|
||||
"pixel_values"
|
||||
),
|
||||
vllm_output_post_proc=model_utils.paligemma_vllm_to_hf_output,
|
||||
dtype="half" if current_platform.is_rocm() else ("half", "float"),
|
||||
marks=[pytest.mark.core_model],
|
||||
),
|
||||
"qwen2_vl": VLMTestInfo(
|
||||
models=["Qwen/Qwen2-VL-2B-Instruct"],
|
||||
test_type=(
|
||||
VLMTestType.IMAGE,
|
||||
VLMTestType.MULTI_IMAGE,
|
||||
VLMTestType.VIDEO
|
||||
),
|
||||
prompt_formatter=lambda img_prompt: f"<|im_start|>User\n{img_prompt}<|im_end|>\n<|im_start|>assistant\n", # noqa: E501
|
||||
img_idx_to_prompt=lambda idx: "<|vision_start|><|image_pad|><|vision_end|>", # noqa: E501
|
||||
video_idx_to_prompt=lambda idx: "<|vision_start|><|video_pad|><|vision_end|>", # noqa: E501
|
||||
max_model_len=4096,
|
||||
max_num_seqs=2,
|
||||
auto_cls=AutoModelForVision2Seq,
|
||||
vllm_output_post_proc=model_utils.qwen2_vllm_to_hf_output,
|
||||
marks=[pytest.mark.core_model],
|
||||
image_size_factors=[(), (0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)],
|
||||
),
|
||||
#### Extended model tests
|
||||
"blip2": VLMTestInfo(
|
||||
models=["Salesforce/blip2-opt-2.7b"],
|
||||
test_type=VLMTestType.IMAGE,
|
||||
@ -151,25 +208,6 @@ VLM_TEST_SETTINGS = {
|
||||
use_tokenizer_eos=True,
|
||||
patch_hf_runner=model_utils.internvl_patch_hf_runner,
|
||||
),
|
||||
"llava": VLMTestInfo(
|
||||
models=["llava-hf/llava-1.5-7b-hf"],
|
||||
test_type=(
|
||||
VLMTestType.EMBEDDING,
|
||||
VLMTestType.IMAGE,
|
||||
VLMTestType.CUSTOM_INPUTS
|
||||
),
|
||||
prompt_formatter=lambda img_prompt: f"USER: {img_prompt}\nASSISTANT:",
|
||||
convert_assets_to_embeddings=model_utils.get_llava_embeddings,
|
||||
max_model_len=4096,
|
||||
auto_cls=AutoModelForVision2Seq,
|
||||
vllm_output_post_proc=model_utils.llava_image_vllm_to_hf_output,
|
||||
custom_test_opts=[CustomTestOptions(
|
||||
inputs=custom_inputs.multi_image_multi_aspect_ratio_inputs(
|
||||
formatter=lambda img_prompt: f"USER: {img_prompt}\nASSISTANT:"
|
||||
),
|
||||
limit_mm_per_prompt={"image": 4},
|
||||
)],
|
||||
),
|
||||
"llava_next": VLMTestInfo(
|
||||
models=["llava-hf/llava-v1.6-mistral-7b-hf"],
|
||||
test_type=(VLMTestType.IMAGE, VLMTestType.CUSTOM_INPUTS),
|
||||
@ -200,12 +238,12 @@ VLM_TEST_SETTINGS = {
|
||||
vllm_output_post_proc=model_utils.llava_onevision_vllm_to_hf_output,
|
||||
# Llava-one-vision tests fixed sizes & the default size factors
|
||||
image_sizes=[((1669, 2560), (2560, 1669), (183, 488), (488, 183))],
|
||||
runner_mm_key="videos",
|
||||
custom_test_opts=[CustomTestOptions(
|
||||
inputs=custom_inputs.multi_video_multi_aspect_ratio_inputs(
|
||||
formatter=lambda vid_prompt: f"<|im_start|>user\n{vid_prompt}<|im_end|>\n<|im_start|>assistant\n", # noqa: E501
|
||||
),
|
||||
limit_mm_per_prompt={"video": 4},
|
||||
runner_mm_key="videos",
|
||||
)],
|
||||
),
|
||||
# FIXME
|
||||
@ -218,9 +256,11 @@ VLM_TEST_SETTINGS = {
|
||||
auto_cls=AutoModelForVision2Seq,
|
||||
vllm_output_post_proc=model_utils.llava_video_vllm_to_hf_output,
|
||||
image_sizes=[((1669, 2560), (2560, 1669), (183, 488), (488, 183))],
|
||||
runner_mm_key="videos",
|
||||
marks=[
|
||||
pytest.mark.skip(reason="LLava next video tests currently fail.")
|
||||
pytest.mark.skipif(
|
||||
transformers.__version__.startswith("4.46"),
|
||||
reason="Model broken with changes in transformers 4.46"
|
||||
)
|
||||
],
|
||||
),
|
||||
"minicpmv": VLMTestInfo(
|
||||
@ -234,23 +274,6 @@ VLM_TEST_SETTINGS = {
|
||||
postprocess_inputs=model_utils.wrap_inputs_post_processor,
|
||||
hf_output_post_proc=model_utils.minicmpv_trunc_hf_output,
|
||||
),
|
||||
"paligemma": VLMTestInfo(
|
||||
models=["google/paligemma-3b-mix-224"],
|
||||
test_type=VLMTestType.IMAGE,
|
||||
prompt_formatter=identity,
|
||||
img_idx_to_prompt = lambda idx: "",
|
||||
# Paligemma uses its own sample prompts because the default one fails
|
||||
single_image_prompts=IMAGE_ASSETS.prompts({
|
||||
"stop_sign": "caption es",
|
||||
"cherry_blossom": "What is in the picture?",
|
||||
}),
|
||||
auto_cls=AutoModelForVision2Seq,
|
||||
postprocess_inputs=model_utils.get_key_type_post_processor(
|
||||
"pixel_values"
|
||||
),
|
||||
vllm_output_post_proc=model_utils.paligemma_vllm_to_hf_output,
|
||||
dtype="half" if current_platform.is_rocm() else ("half", "float"),
|
||||
),
|
||||
# Tests for phi3v currently live in another file because of a bug in
|
||||
# transformers. Once this issue is fixed, we can enable them here instead.
|
||||
# https://github.com/huggingface/transformers/issues/34307
|
||||
|
@ -56,6 +56,17 @@ def qwen_vllm_to_hf_output(
|
||||
return output_ids, hf_output_str, out_logprobs
|
||||
|
||||
|
||||
def qwen2_vllm_to_hf_output(
|
||||
vllm_output: RunnerOutput,
|
||||
model: str) -> Tuple[List[int], str, Optional[SampleLogprobs]]:
|
||||
"""Sanitize vllm output [qwen2 models] to be comparable with hf output."""
|
||||
output_ids, output_str, out_logprobs = vllm_output
|
||||
|
||||
hf_output_str = output_str + "<|im_end|>"
|
||||
|
||||
return output_ids, hf_output_str, out_logprobs
|
||||
|
||||
|
||||
def llava_image_vllm_to_hf_output(vllm_output: RunnerOutput,
|
||||
model: str) -> RunnerOutput:
|
||||
config = AutoConfig.from_pretrained(model)
|
||||
|
@ -29,6 +29,7 @@ def run_single_image_test(*, tmp_path: PosixPath, model_test_info: VLMTestInfo,
|
||||
num_logprobs=test_case.num_logprobs,
|
||||
limit_mm_per_prompt={"image": 1},
|
||||
distributed_executor_backend=test_case.distributed_executor_backend,
|
||||
runner_mm_key="images",
|
||||
**model_test_info.get_non_parametrized_runner_kwargs())
|
||||
|
||||
|
||||
@ -51,6 +52,7 @@ def run_multi_image_test(*, tmp_path: PosixPath, model_test_info: VLMTestInfo,
|
||||
num_logprobs=test_case.num_logprobs,
|
||||
limit_mm_per_prompt={"image": len(image_assets)},
|
||||
distributed_executor_backend=test_case.distributed_executor_backend,
|
||||
runner_mm_key="images",
|
||||
**model_test_info.get_non_parametrized_runner_kwargs())
|
||||
|
||||
|
||||
@ -74,6 +76,7 @@ def run_embedding_test(*, model_test_info: VLMTestInfo,
|
||||
limit_mm_per_prompt={"image": 1},
|
||||
vllm_embeddings=vllm_embeddings,
|
||||
distributed_executor_backend=test_case.distributed_executor_backend,
|
||||
runner_mm_key="images",
|
||||
**model_test_info.get_non_parametrized_runner_kwargs())
|
||||
|
||||
|
||||
@ -101,6 +104,7 @@ def run_video_test(
|
||||
num_logprobs=test_case.num_logprobs,
|
||||
limit_mm_per_prompt={"video": len(video_assets)},
|
||||
distributed_executor_backend=test_case.distributed_executor_backend,
|
||||
runner_mm_key="videos",
|
||||
**model_test_info.get_non_parametrized_runner_kwargs())
|
||||
|
||||
|
||||
@ -115,7 +119,11 @@ def run_custom_inputs_test(*, model_test_info: VLMTestInfo,
|
||||
|
||||
inputs = test_case.custom_test_opts.inputs
|
||||
limit_mm_per_prompt = test_case.custom_test_opts.limit_mm_per_prompt
|
||||
assert inputs is not None and limit_mm_per_prompt is not None
|
||||
runner_mm_key = test_case.custom_test_opts.runner_mm_key
|
||||
# Inputs, limit_mm_per_prompt, and runner_mm_key should all be set
|
||||
assert inputs is not None
|
||||
assert limit_mm_per_prompt is not None
|
||||
assert runner_mm_key is not None
|
||||
|
||||
core.run_test(
|
||||
hf_runner=hf_runner,
|
||||
@ -127,4 +135,5 @@ def run_custom_inputs_test(*, model_test_info: VLMTestInfo,
|
||||
num_logprobs=test_case.num_logprobs,
|
||||
limit_mm_per_prompt=limit_mm_per_prompt,
|
||||
distributed_executor_backend=test_case.distributed_executor_backend,
|
||||
runner_mm_key=runner_mm_key,
|
||||
**model_test_info.get_non_parametrized_runner_kwargs())
|
||||
|
@ -52,6 +52,8 @@ class SizeType(Enum):
|
||||
class CustomTestOptions(NamedTuple):
|
||||
inputs: List[Tuple[List[str], List[Union[List[Image], Image]]]]
|
||||
limit_mm_per_prompt: Dict[str, int]
|
||||
# kwarg to pass multimodal data in as to vllm/hf runner instances.
|
||||
runner_mm_key: str = "images"
|
||||
|
||||
|
||||
class ImageSizeWrapper(NamedTuple):
|
||||
@ -141,9 +143,6 @@ class VLMTestInfo(NamedTuple):
|
||||
Callable[[PosixPath, str, Union[List[ImageAsset], _ImageAssets]],
|
||||
str]] = None # noqa: E501
|
||||
|
||||
# kwarg to pass multimodal data in as to vllm/hf runner instances
|
||||
runner_mm_key: str = "images"
|
||||
|
||||
# Allows configuring a test to run with custom inputs
|
||||
custom_test_opts: Optional[List[CustomTestOptions]] = None
|
||||
|
||||
@ -168,7 +167,6 @@ class VLMTestInfo(NamedTuple):
|
||||
"get_stop_token_ids": self.get_stop_token_ids,
|
||||
"model_kwargs": self.model_kwargs,
|
||||
"patch_hf_runner": self.patch_hf_runner,
|
||||
"runner_mm_key": self.runner_mm_key,
|
||||
}
|
||||
|
||||
|
||||
|
@ -2,6 +2,7 @@ from typing import List, Type
|
||||
|
||||
import pytest
|
||||
import torch.nn.functional as F
|
||||
import transformers
|
||||
from transformers import AutoModelForVision2Seq
|
||||
|
||||
from ....conftest import IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner
|
||||
@ -85,8 +86,8 @@ def _run_test(
|
||||
)
|
||||
|
||||
|
||||
# FIXME
|
||||
@pytest.mark.skip(reason="LLava next embedding tests currently fail")
|
||||
@pytest.mark.skipif(transformers.__version__.startswith("4.46"),
|
||||
reason="Model broken with changes in transformers 4.46")
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("dtype", ["half"])
|
||||
def test_models_text(
|
||||
|
Loading…
x
Reference in New Issue
Block a user