[CI/Build] Split up VLM tests (#11083)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Cyrus Leung 2024-12-12 06:18:16 +08:00 committed by GitHub
parent 72ff3a9686
commit d1e21a979b
4 changed files with 94 additions and 50 deletions

.buildkite/test-pipeline.yaml

@@ -321,7 +321,7 @@ steps:
 ##### models test #####

-- label: Basic Models Test # 30min
+- label: Basic Models Test # 24min
   source_file_dependencies:
   - vllm/
   - tests/models
@@ -331,7 +331,7 @@ steps:
   - pytest -v -s models/test_registry.py
   - pytest -v -s models/test_initialization.py

-- label: Language Models Test (Standard) # 42min
+- label: Language Models Test (Standard) # 32min
   #mirror_hardwares: [amd]
   source_file_dependencies:
   - vllm/
@@ -342,7 +342,7 @@ steps:
   - pytest -v -s models/decoder_only/language -m 'core_model or quant_model'
   - pytest -v -s models/embedding/language -m core_model

-- label: Language Models Test (Extended) # 50min
+- label: Language Models Test (Extended) # 1h10min
   optional: true
   source_file_dependencies:
   - vllm/
@@ -353,7 +353,7 @@ steps:
   - pytest -v -s models/decoder_only/language -m 'not core_model and not quant_model'
   - pytest -v -s models/embedding/language -m 'not core_model'

-- label: Multi-Modal Models Test (Standard) # 26min
+- label: Multi-Modal Models Test (Standard) # 28min
   #mirror_hardwares: [amd]
   source_file_dependencies:
   - vllm/
@@ -369,7 +369,7 @@ steps:
   - pytest -v -s models/encoder_decoder/language -m core_model
   - pytest -v -s models/encoder_decoder/vision_language -m core_model

-- label: Multi-Modal Models Test (Extended) # 1h15m
+- label: Multi-Modal Models Test (Extended) 1 # 1h16m
   optional: true
   source_file_dependencies:
   - vllm/
@@ -380,14 +380,24 @@ steps:
   commands:
   - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git
   - pytest -v -s models/decoder_only/audio_language -m 'not core_model and not quant_model'
+  - pytest -v -s models/decoder_only/vision_language/test_models.py -m 'split(group=0) and not core_model and not quant_model'
   # HACK - run phi3v tests separately to sidestep this transformers bug
   # https://github.com/huggingface/transformers/issues/34307
   - pytest -v -s models/decoder_only/vision_language/test_phi3v.py
-  - pytest -v -s --ignore models/decoder_only/vision_language/test_phi3v.py models/decoder_only/vision_language -m 'not core_model and not quant_model'
+  - pytest -v -s --ignore models/decoder_only/vision_language/test_models.py --ignore models/decoder_only/vision_language/test_phi3v.py models/decoder_only/vision_language -m 'not core_model and not quant_model'
   - pytest -v -s models/embedding/vision_language -m 'not core_model'
   - pytest -v -s models/encoder_decoder/language -m 'not core_model'
   - pytest -v -s models/encoder_decoder/vision_language -m 'not core_model'
+
+- label: Multi-Modal Models Test (Extended) 2 # 38m
+  optional: true
+  source_file_dependencies:
+  - vllm/
+  - tests/models/decoder_only/vision_language
+  commands:
+  - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git
+  - pytest -v -s models/decoder_only/vision_language/test_models.py -m 'split(group=1) and not core_model and not quant_model'

 # This test is used only in PR development phase to test individual models and should never run on main
 - label: Custom Models Test
   optional: true
@@ -446,11 +456,11 @@ steps:
   - pytest -v -s ./compile/test_basic_correctness.py
   - pytest -v -s ./compile/test_wrapper.py
   - VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep -q 'Same node test passed'
-  - TARGET_TEST_SUITE=L4 pytest basic_correctness/ -v -s -m distributed_2_gpus
+  - TARGET_TEST_SUITE=L4 pytest basic_correctness/ -v -s -m 'distributed(num_gpus=2)'
   # Avoid importing model tests that cause CUDA reinitialization error
-  - pytest models/encoder_decoder/language/test_bart.py -v -s -m distributed_2_gpus
-  - pytest models/encoder_decoder/vision_language/test_broadcast.py -v -s -m distributed_2_gpus
-  - pytest models/decoder_only/vision_language/test_models.py -v -s -m distributed_2_gpus
+  - pytest models/encoder_decoder/language/test_bart.py -v -s -m 'distributed(num_gpus=2)'
+  - pytest models/encoder_decoder/vision_language/test_broadcast.py -v -s -m 'distributed(num_gpus=2)'
+  - pytest models/decoder_only/vision_language/test_models.py -v -s -m 'distributed(num_gpus=2)'
   - pytest -v -s spec_decode/e2e/test_integration_dist_tp2.py
   - pip install -e ./plugins/vllm_add_dummy_model
   - pytest -v -s distributed/test_distributed_oot.py
@@ -540,7 +550,7 @@ steps:
   # see https://github.com/vllm-project/vllm/pull/5689 for details
   - pytest -v -s distributed/test_custom_all_reduce.py
   - torchrun --nproc_per_node=2 distributed/test_ca_buffer_sharing.py
-  - TARGET_TEST_SUITE=A100 pytest basic_correctness/ -v -s -m distributed_2_gpus
+  - TARGET_TEST_SUITE=A100 pytest basic_correctness/ -v -s -m 'distributed(num_gpus=2)'
   - pytest -v -s -x lora/test_mixtral.py

 - label: LM Eval Large Models # optional

pyproject.toml

@@ -96,7 +96,8 @@ markers = [
     "core_model: enable this model test in each PR instead of only nightly",
     "cpu_model: enable this model test in CPU tests",
     "quant_model: run this model test under Quantized category",
-    "distributed_2_gpus: run this test only in distributed tests for 2 GPUs",
+    "split: run this test as part of a split",
+    "distributed: run this test only in distributed GPU tests",
     "skip_v1: do not run this test with v1",
     "optional: optional tests that are automatically skipped, include --optional to run them",
 ]
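
The new markers take keyword arguments (`split(group=0)`, `distributed(num_gpus=2)`), which is what the pipeline's `-m` expressions above select on. How such argument-aware expressions are resolved is not shown in this diff, so the following is only a rough, hypothetical conftest-style sketch of filtering on a marker's kwargs, using a made-up `--split-group` option rather than `-m`:

# Hypothetical conftest.py sketch, NOT part of this commit: deselect tests
# whose split(group=...) marker does not match a made-up --split-group option.
def pytest_addoption(parser):
    parser.addoption("--split-group", type=int, default=None,
                     help="Only run tests marked split(group=N) for this N.")


def pytest_collection_modifyitems(config, items):
    wanted = config.getoption("--split-group")
    if wanted is None:
        return  # no filtering requested

    selected, deselected = [], []
    for item in items:
        marker = item.get_closest_marker("split")
        # Keep items whose split group matches the requested one.
        if marker is not None and marker.kwargs.get("group") == wanted:
            selected.append(item)
        else:
            deselected.append(item)

    if deselected:
        config.hook.pytest_deselected(items=deselected)
        items[:] = selected

Under this assumption, one split could be run locally with something like `pytest models/decoder_only/vision_language/test_models.py --split-group 0`.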

tests/models/decoder_only/vision_language/test_models.py

@@ -1,7 +1,9 @@
 """Common tests for testing .generate() functionality for single / multiple
 image, embedding, and video support for different VLMs in vLLM.
 """
+import math
 import os
+from collections import defaultdict
 from pathlib import PosixPath
 from typing import Type
@@ -10,11 +12,12 @@ from transformers import AutoModelForVision2Seq
 from transformers.utils import is_flash_attn_2_available

 from vllm.platforms import current_platform
-from vllm.utils import cuda_device_count_stateless, identity
+from vllm.utils import identity

 from ....conftest import (IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets,
                           _VideoAssets)
-from ....utils import fork_new_process_for_each_test, large_gpu_mark
+from ....utils import (fork_new_process_for_each_test, large_gpu_mark,
+                       multi_gpu_marks)
 from ...utils import check_outputs_equal
 from .vlm_utils import custom_inputs, model_utils, runners
 from .vlm_utils.case_filtering import get_parametrized_options
@@ -382,7 +385,7 @@ VLM_TEST_SETTINGS = {
         prompt_path_encoder=model_utils.qwen_prompt_path_encoder,
     ),
     ### Tensor parallel / multi-gpu broadcast tests
-    "broadcast-chameleon": VLMTestInfo(
+    "chameleon-broadcast": VLMTestInfo(
         models=["facebook/chameleon-7b"],
         prompt_formatter=lambda img_prompt: f"USER: {img_prompt}\nASSISTANT:",
         max_model_len=4096,
@@ -393,43 +396,25 @@ VLM_TEST_SETTINGS = {
         vllm_output_post_proc = lambda vllm_output, model: vllm_output[:2],
         hf_output_post_proc = lambda hf_output, model: hf_output[:2],
         comparator=check_outputs_equal,
-        marks=[
-            pytest.mark.distributed_2_gpus,
-            pytest.mark.skipif(
-                cuda_device_count_stateless() < 2,
-                reason="Need at least 2 GPUs to run the test.",
-            ),
-        ],
+        marks=multi_gpu_marks(num_gpus=2),
         **COMMON_BROADCAST_SETTINGS # type: ignore
     ),
-    "broadcast-llava": VLMTestInfo(
+    "llava-broadcast": VLMTestInfo(
         models=["llava-hf/llava-1.5-7b-hf"],
         prompt_formatter=lambda img_prompt: f"USER: {img_prompt}\nASSISTANT:",
         max_model_len=4096,
         auto_cls=AutoModelForVision2Seq,
         vllm_output_post_proc=model_utils.llava_image_vllm_to_hf_output,
-        marks=[
-            pytest.mark.distributed_2_gpus,
-            pytest.mark.skipif(
-                cuda_device_count_stateless() < 2,
-                reason="Need at least 2 GPUs to run the test.",
-            )
-        ],
+        marks=multi_gpu_marks(num_gpus=2),
         **COMMON_BROADCAST_SETTINGS # type: ignore
     ),
-    "broadcast-llava_next": VLMTestInfo(
+    "llava_next-broadcast": VLMTestInfo(
         models=["llava-hf/llava-v1.6-mistral-7b-hf"],
         prompt_formatter=lambda img_prompt: f"[INST] {img_prompt} [/INST]",
         max_model_len=10240,
         auto_cls=AutoModelForVision2Seq,
         vllm_output_post_proc=model_utils.llava_image_vllm_to_hf_output,
-        marks=[
-            pytest.mark.distributed_2_gpus,
-            pytest.mark.skipif(
-                cuda_device_count_stateless() < 2,
-                reason="Need at least 2 GPUs to run the test.",
-            )
-        ],
+        marks=multi_gpu_marks(num_gpus=2),
         **COMMON_BROADCAST_SETTINGS # type: ignore
     ),
     ### Custom input edge-cases for specific models
@@ -468,6 +453,41 @@ VLM_TEST_SETTINGS = {
 # yapf: enable


+def _mark_splits(
+    test_settings: dict[str, VLMTestInfo],
+    *,
+    num_groups: int,
+) -> dict[str, VLMTestInfo]:
+    name_by_test_info_id = {id(v): k for k, v in test_settings.items()}
+
+    test_infos_by_model = defaultdict[str, list[VLMTestInfo]](list)
+    for info in test_settings.values():
+        for model in info.models:
+            test_infos_by_model[model].append(info)
+
+    models = sorted(test_infos_by_model.keys())
+    split_size = math.ceil(len(models) / num_groups)
+
+    new_test_settings = dict[str, VLMTestInfo]()
+
+    for i in range(num_groups):
+        models_in_group = models[i * split_size:(i + 1) * split_size]
+
+        for model in models_in_group:
+            for info in test_infos_by_model[model]:
+                new_marks = (info.marks or []) + [pytest.mark.split(group=i)]
+                new_info = info._replace(marks=new_marks)
+                new_test_settings[name_by_test_info_id[id(info)]] = new_info
+
+    missing_keys = test_settings.keys() - new_test_settings.keys()
+    assert not missing_keys, f"Missing keys: {missing_keys}"
+
+    return new_test_settings
+
+
+VLM_TEST_SETTINGS = _mark_splits(VLM_TEST_SETTINGS, num_groups=2)
+
+
 ### Test wrappers
 # Wrappers around the core test running func for:
 # - single image
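
For reference, `_mark_splits` chunks the sorted model list into contiguous groups (not round-robin) and appends one `pytest.mark.split(group=i)` mark per entry, so the "Multi-Modal Models Test (Extended) 1" and "2" CI steps each collect a disjoint portion of test_models.py. A simplified, self-contained sketch of that grouping, using a stand-in tuple and made-up model names, behaves as follows:

# Simplified illustration of the grouping in _mark_splits; FakeTestInfo and
# the model names are invented for this sketch.
import math
from typing import List, NamedTuple, Optional

import pytest


class FakeTestInfo(NamedTuple):
    models: List[str]
    marks: Optional[list] = None


settings = {
    "model-a": FakeTestInfo(models=["org/model-a"]),
    "model-b": FakeTestInfo(models=["org/model-b"]),
    "model-c": FakeTestInfo(models=["org/model-c"]),
    "model-d": FakeTestInfo(models=["org/model-d"]),
}

num_groups = 2
models = sorted({m for info in settings.values() for m in info.models})
split_size = math.ceil(len(models) / num_groups)  # contiguous chunks of the sorted list

for name, info in settings.items():
    group = models.index(info.models[0]) // split_size
    marked = info._replace(
        marks=(info.marks or []) + [pytest.mark.split(group=group)])
    # org/model-a and org/model-b land in group 0; the rest in group 1.
    print(name, "-> split group", group, marked.marks)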

tests/utils.py

@@ -682,10 +682,12 @@ def fork_new_process_for_each_test(


 def large_gpu_mark(min_gb: int) -> pytest.MarkDecorator:
-    """Gets a pytest skipif mark, which triggers ig the the device doesn't have
-    meet a minimum memory requirement in gb; can be leveraged via
-    @large_gpu_test to skip tests in environments without enough resources, or
-    called when filtering tests to run directly.
+    """
+    Get a pytest mark, which skips the test if the GPU doesn't meet
+    a minimum memory requirement in GB.
+
+    This can be leveraged via `@large_gpu_test` to skip tests in environments
+    without enough resources, or called when filtering tests to run directly.
     """
     try:
         if current_platform.is_cpu():
@@ -712,26 +714,37 @@ def large_gpu_test(*, min_gb: int):
     Currently, the CI machine uses L4 GPU which has 24 GB VRAM.
     """
-    test_skipif = large_gpu_mark(min_gb)
+    mark = large_gpu_mark(min_gb)

     def wrapper(f: Callable[_P, None]) -> Callable[_P, None]:
-        return test_skipif(f)
+        return mark(f)

     return wrapper


+def multi_gpu_marks(*, num_gpus: int):
+    """Get a collection of pytest marks to apply for `@multi_gpu_test`."""
+    test_selector = pytest.mark.distributed(num_gpus=num_gpus)
+    test_skipif = pytest.mark.skipif(
+        cuda_device_count_stateless() < num_gpus,
+        reason=f"Need at least {num_gpus} GPUs to run the test.",
+    )
+
+    return [test_selector, test_skipif]
+
+
 def multi_gpu_test(*, num_gpus: int):
     """
     Decorate a test to be run only when multiple GPUs are available.
     """
-    test_selector = getattr(pytest.mark, f"distributed_{num_gpus}_gpus")
-    test_skipif = pytest.mark.skipif(
-        cuda_device_count_stateless() < num_gpus,
-        reason=f"Need at least {num_gpus} GPUs to run the test.",
-    )
+    marks = multi_gpu_marks(num_gpus=num_gpus)

     def wrapper(f: Callable[_P, None]) -> Callable[_P, None]:
-        return test_selector(test_skipif(fork_new_process_for_each_test(f)))
+        func = fork_new_process_for_each_test(f)
+        for mark in reversed(marks):
+            func = mark(func)
+
+        return func

     return wrapper
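
For context, a typical use of the refactored helpers: `@multi_gpu_test(num_gpus=2)` forks each test into its own process and attaches both the parametrized `distributed(num_gpus=2)` marker (selected in CI via `-m 'distributed(num_gpus=2)'`) and a skipif for machines with fewer GPUs, while `multi_gpu_marks(num_gpus=2)` supplies the same marks for declarative settings such as the `VLMTestInfo` entries above. A hypothetical example (the test body, model name, and fixture usage are illustrative, not from this commit):

# Hypothetical test sketch; the model name and assertion are made up.
from ..utils import multi_gpu_test  # import path depends on the test's location


@multi_gpu_test(num_gpus=2)
def test_tp2_generation(vllm_runner):
    # The decorator wraps this in fork_new_process_for_each_test and tags it
    # with pytest.mark.distributed(num_gpus=2) plus a skipif for < 2 GPUs.
    with vllm_runner("org/some-model", tensor_parallel_size=2) as vllm_model:
        outputs = vllm_model.generate_greedy(["Hello, my name is"], max_tokens=8)
        assert len(outputs) == 1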