[CI/Build] Split up models tests (#10069)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung 2024-11-10 03:39:14 +08:00 committed by GitHub
parent b09895a618
commit 51c2e1fcef
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
21 changed files with 115 additions and 129 deletions

View File

@ -305,7 +305,7 @@ steps:
##### models test #####
- label: Basic Models Test # 3min
- label: Basic Models Test # 10min
source_file_dependencies:
- vllm/
- tests/models
@ -314,23 +314,24 @@ steps:
- pytest -v -s models/test_oot_registration.py # it needs a clean process
- pytest -v -s models/*.py --ignore=models/test_oot_registration.py
- label: Decoder-only Language Models Test (Standard) # 35min
- label: Decoder-only Language Models Test (Standard) # 18min
#mirror_hardwares: [amd]
source_file_dependencies:
- vllm/
- tests/models/decoder_only/language
commands:
- pytest -v -s models/decoder_only/language/test_models.py
- pytest -v -s models/decoder_only/language -m core_model
- pytest -v -s models/decoder_only/language -m quant_model
- label: Decoder-only Language Models Test (Extended) # 1h20min
- label: Decoder-only Language Models Test (Extended) # 46min
nightly: true
source_file_dependencies:
- vllm/
- tests/models/decoder_only/language
commands:
- pytest -v -s models/decoder_only/language --ignore=models/decoder_only/language/test_models.py
- pytest -v -s models/decoder_only/language -m 'not core_model and not quant_model'
- label: Decoder-only Multi-Modal Models Test (Standard) # 26min
- label: Decoder-only Multi-Modal Models Test (Standard) # 22min
#mirror_hardwares: [amd]
source_file_dependencies:
- vllm/
@ -339,21 +340,24 @@ steps:
commands:
- pytest -v -s models/decoder_only/audio_language -m core_model
- pytest -v -s --ignore models/decoder_only/vision_language/test_phi3v.py models/decoder_only/vision_language -m core_model
# No tests under this group for now
# - pytest -v -s models/decoder_only/audio_language -m quant_model
- pytest -v -s --ignore models/decoder_only/vision_language/test_phi3v.py models/decoder_only/vision_language -m quant_model
- label: Decoder-only Multi-Modal Models Test (Extended)
- label: Decoder-only Multi-Modal Models Test (Extended) # 1h10m
nightly: true
source_file_dependencies:
- vllm/
- tests/models/decoder_only/audio_language
- tests/models/decoder_only/vision_language
commands:
- pytest -v -s models/decoder_only/audio_language -m 'not core_model'
- pytest -v -s models/decoder_only/audio_language -m 'not core_model and not quant_model'
# HACK - run phi3v tests separately to sidestep this transformers bug
# https://github.com/huggingface/transformers/issues/34307
- pytest -v -s models/decoder_only/vision_language/test_phi3v.py
- pytest -v -s --ignore models/decoder_only/vision_language/test_phi3v.py models/decoder_only/vision_language -m 'not core_model'
- pytest -v -s --ignore models/decoder_only/vision_language/test_phi3v.py models/decoder_only/vision_language -m 'not core_model and not quant_model'
- label: Other Models Test # 6min
- label: Other Models Test # 20min
#mirror_hardwares: [amd]
source_file_dependencies:
- vllm/

View File

@ -95,6 +95,7 @@ markers = [
"skip_global_cleanup",
"core_model: enable this model test in each PR instead of only nightly",
"cpu_model: enable this model test in CPU tests",
"quant_model: run this model test under Quantized category",
"distributed_2_gpus: run this test only in distributed tests for 2 GPUs",
"skip_v1: do not run this test with v1",
]

View File

@ -38,6 +38,7 @@ ground_truth_generations = [
]
@pytest.mark.quant_model
@pytest.mark.skipif(not is_quant_method_supported("aqlm"),
reason="AQLM is not supported on this GPU type.")
@pytest.mark.parametrize("model", ["ISTA-DASLab/Llama-2-7b-AQLM-2Bit-1x16-hf"])

View File

@ -15,6 +15,7 @@ from ...utils import check_logprobs_close
os.environ["TOKENIZERS_PARALLELISM"] = "true"
@pytest.mark.quant_model
@pytest.mark.skipif(not is_quant_method_supported("fp8"),
reason="fp8 is not supported on this GPU type.")
@pytest.mark.parametrize(

View File

@ -17,26 +17,21 @@ os.environ["TOKENIZERS_PARALLELISM"] = "true"
MAX_MODEL_LEN = 1024
# FIXME: Move this to confest
MODELS = [
("meta-llama/Llama-3.2-1B-Instruct",
hf_hub_download("bartowski/Llama-3.2-1B-Instruct-GGUF",
filename="Llama-3.2-1B-Instruct-Q4_K_M.gguf")),
("meta-llama/Llama-3.2-1B-Instruct",
hf_hub_download("bartowski/Llama-3.2-1B-Instruct-GGUF",
filename="Llama-3.2-1B-Instruct-IQ4_XS.gguf")),
("Qwen/Qwen2-1.5B-Instruct",
hf_hub_download("Qwen/Qwen2-1.5B-Instruct-GGUF",
filename="qwen2-1_5b-instruct-q4_k_m.gguf")),
("Qwen/Qwen2-1.5B-Instruct",
hf_hub_download("legraphista/Qwen2-1.5B-Instruct-IMat-GGUF",
filename="Qwen2-1.5B-Instruct.IQ4_XS.gguf")),
]
@pytest.mark.skipif(not is_quant_method_supported("gguf"),
reason="gguf is not supported on this GPU type.")
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize(("original_model", "gguf_id", "gguf_path"), [
("meta-llama/Llama-3.2-1B-Instruct",
"bartowski/Llama-3.2-1B-Instruct-GGUF",
"Llama-3.2-1B-Instruct-Q4_K_M.gguf"),
("meta-llama/Llama-3.2-1B-Instruct",
"bartowski/Llama-3.2-1B-Instruct-GGUF",
"Llama-3.2-1B-Instruct-IQ4_XS.gguf"),
("Qwen/Qwen2-1.5B-Instruct", "Qwen/Qwen2-1.5B-Instruct-GGUF",
"qwen2-1_5b-instruct-q4_k_m.gguf"),
("Qwen/Qwen2-1.5B-Instruct", "legraphista/Qwen2-1.5B-Instruct-IMat-GGUF",
"Qwen2-1.5B-Instruct.IQ4_XS.gguf"),
])
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("num_logprobs", [5])
@ -45,7 +40,9 @@ def test_models(
num_gpus_available,
vllm_runner,
example_prompts,
model,
original_model,
gguf_id,
gguf_path,
dtype: str,
max_tokens: int,
num_logprobs: int,
@ -54,7 +51,7 @@ def test_models(
if num_gpus_available < tp_size:
pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}")
original_model, gguf_model = model
gguf_model = hf_hub_download(gguf_id, filename=gguf_path)
tokenizer = AutoTokenizer.from_pretrained(original_model)
messages = [[{

View File

@ -33,6 +33,7 @@ MODELS = [
]
@pytest.mark.quant_model
@pytest.mark.flaky(reruns=3)
@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
reason="gptq_marlin is not supported on this GPU type.")

View File

@ -38,6 +38,7 @@ model_pairs = [
]
@pytest.mark.quant_model
@pytest.mark.flaky(reruns=2)
@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin_24"),
reason="Marlin24 is not supported on this GPU type.")

View File

@ -7,7 +7,9 @@ import pytest
from ...utils import check_logprobs_close
MODELS = [
# TODO(sang): Sliding window should be tested separately.
"ibm/PowerLM-3b",
"ibm/PowerMoE-3b",
]
@ -24,7 +26,6 @@ def test_models(
max_tokens: int,
num_logprobs: int,
) -> None:
# TODO(sang): Sliding window should be tested separately.
with hf_runner(model, dtype=dtype) as hf_model:
hf_outputs = hf_model.generate_greedy_logprobs_limit(
example_prompts, max_tokens, num_logprobs)

View File

@ -1,39 +0,0 @@
"""Compare the outputs of HF and vLLM for Granite models using greedy sampling.
Run `pytest tests/models/test_granite.py`.
"""
import pytest
from ...utils import check_logprobs_close
MODELS = [
"ibm/PowerMoE-3b",
]
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(
hf_runner,
vllm_runner,
example_prompts,
model: str,
dtype: str,
max_tokens: int,
num_logprobs: int,
) -> None:
with hf_runner(model, dtype=dtype) as hf_model:
hf_outputs = hf_model.generate_greedy_logprobs_limit(
example_prompts, max_tokens, num_logprobs)
with vllm_runner(model, dtype=dtype) as vllm_model:
vllm_outputs = vllm_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs)
check_logprobs_close(
outputs_0_lst=hf_outputs,
outputs_1_lst=vllm_outputs,
name_0="hf",
name_1="vllm",
)

View File

@ -39,6 +39,7 @@ EXPECTED_STRS_MAP = {
@pytest.mark.skip(
reason=
"Prevent unstable test based on golden strings from breaking the build.")
@pytest.mark.quant_model
@pytest.mark.skipif(not is_quant_method_supported("fp8"),
reason="fp8 is not supported on this GPU type.")
@pytest.mark.parametrize("model_name", MODELS)

View File

@ -1,8 +1,5 @@
"""Compare the outputs of HF and vLLM when using greedy sampling.
This test only tests small models. Big models such as 7B should be tested from
test_big_models.py because it could use a larger instance to run tests.
Run `pytest tests/models/test_models.py`.
"""
import pytest
@ -35,6 +32,7 @@ if not current_platform.is_cpu():
target_dtype = "half"
@pytest.mark.core_model
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", [target_dtype])
@pytest.mark.parametrize("max_tokens", [32])

View File

@ -56,11 +56,13 @@ def test_dummy_data_for_llava_next_feature_size(dummy_data_for_llava_next,
ctx.model_config.hf_config.image_grid_pinpoints = gridpoints
seq_len = 5000 # bigger than the max feature size for any image
seq_data, mm_data = dummy_data_for_llava_next(
dummy_data = dummy_data_for_llava_next(
ctx,
seq_len=seq_len,
mm_counts={"image": 1},
)
seq_data = dummy_data.seq_data
mm_data = dummy_data.multi_modal_data
# The dummy data dims should match the gridpoint with the biggest feat size
assert mm_data["image"].height == expected_size[0]

View File

@ -131,12 +131,13 @@ def test_dummy_data_override(dummy_data_for_phi3v, model: str, num_crops: int,
mm_processor_kwargs=None,
)
sequence_data, _, = dummy_data_for_phi3v(
dummy_data = dummy_data_for_phi3v(
ctx=ctx,
seq_len=8192, # Should be bigger than num_imgs * toks_per_img
mm_counts={"image": num_imgs},
num_crops=num_crops,
)
sequence_data = dummy_data.seq_data
# Ensure we have the right number of placeholders per num_crops size
img_tok_count = sequence_data.get_token_ids().count(_IMAGE_TOKEN_ID)
assert img_tok_count == toks_per_img * num_imgs

View File

@ -86,10 +86,17 @@ def test_qwen2_vl_dummy_data(dummy_data_for_qwen2_vl,
# NOTE: video value is required, but isn't actually used
# when making the dummy data except for error handling currently
seq_data, mm_data = dummy_data_for_qwen2_vl(qwen2_vl_context, seq_len, {
"image": 1,
"video": 0
}, **mm_processor_kwargs)
dummy_data = dummy_data_for_qwen2_vl(
ctx=qwen2_vl_context,
seq_len=seq_len,
mm_counts={
"image": 1,
"video": 0
},
**mm_processor_kwargs,
)
seq_data = dummy_data.seq_data
mm_data = dummy_data.multi_modal_data
# Ensure we have the right number of placeholders for min/max pixel values
assert seq_data.get_token_ids().count(image_token_id) == token_count

View File

@ -1,4 +1,4 @@
from typing import List, Optional, Tuple, Type
from typing import List, Optional, Type
import pytest
import torch
@ -19,7 +19,8 @@ HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
def run_awq_test(
vllm_runner: Type[VllmRunner],
image_assets: _ImageAssets,
models: Tuple[str, str],
source_model: str,
quant_model: str,
*,
size_factors: List[float],
dtype: str,
@ -28,8 +29,6 @@ def run_awq_test(
tensor_parallel_size: int,
distributed_executor_backend: Optional[str] = None,
):
source_model, quant_model = models
images = [asset.pil_image for asset in image_assets]
inputs_per_image = [(
@ -84,8 +83,11 @@ def run_awq_test(
)
@pytest.mark.quant_model
@pytest.mark.parametrize(
"models", [("OpenGVLab/InternVL2-2B", "OpenGVLab/InternVL2-2B-AWQ")])
("source_model", "quant_model"),
[("OpenGVLab/InternVL2-2B", "OpenGVLab/InternVL2-2B-AWQ")],
)
@pytest.mark.parametrize(
"size_factors",
[
@ -103,12 +105,13 @@ def run_awq_test(
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
@torch.inference_mode()
def test_awq_models(vllm_runner, image_assets, models, size_factors,
dtype: str, max_tokens: int, num_logprobs: int) -> None:
def test_awq_models(vllm_runner, image_assets, source_model, quant_model,
size_factors, dtype, max_tokens, num_logprobs) -> None:
run_awq_test(
vllm_runner,
image_assets,
models,
source_model,
quant_model,
size_factors=size_factors,
dtype=dtype,
max_tokens=max_tokens,

View File

@ -11,21 +11,17 @@ from ....conftest import _ImageAssets
# we use snapshot_download to prevent conflicts between
# dynamic_module and trust_remote_code for hf_runner
DOWNLOAD_PATTERN = ["*.json", "*.py", "*.safetensors", "*.txt", "*.model"]
models = [
snapshot_download("OpenGVLab/InternViT-300M-448px",
allow_patterns=DOWNLOAD_PATTERN),
snapshot_download("OpenGVLab/InternViT-6B-448px-V1-5",
allow_patterns=DOWNLOAD_PATTERN),
]
def run_intern_vit_test(
image_assets: _ImageAssets,
model: str,
model_id: str,
*,
dtype: str,
distributed_executor_backend: Optional[str] = None,
):
model = snapshot_download(model_id, allow_patterns=DOWNLOAD_PATTERN)
img_processor = CLIPImageProcessor.from_pretrained(model)
images = [asset.pil_image for asset in image_assets]
pixel_values = [
@ -67,12 +63,15 @@ def run_intern_vit_test(
assert cos_similar(vllm_output, hf_output).mean() > 0.99
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize("model_id", [
"OpenGVLab/InternViT-300M-448px",
"OpenGVLab/InternViT-6B-448px-V1-5",
])
@pytest.mark.parametrize("dtype", [torch.half])
@torch.inference_mode()
def test_models(dist_init, image_assets, model, dtype: str) -> None:
def test_models(dist_init, image_assets, model_id, dtype: str) -> None:
run_intern_vit_test(
image_assets,
model,
model_id,
dtype=dtype,
)

View File

@ -130,8 +130,8 @@ VLM_TEST_SETTINGS = {
max_num_seqs=2,
auto_cls=AutoModelForVision2Seq,
vllm_output_post_proc=model_utils.qwen2_vllm_to_hf_output,
marks=[pytest.mark.core_model, pytest.mark.cpu_model],
image_size_factors=[(), (0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)],
marks=[pytest.mark.core_model, pytest.mark.cpu_model],
),
#### Extended model tests
"blip2": VLMTestInfo(
@ -159,9 +159,9 @@ VLM_TEST_SETTINGS = {
dtype="bfloat16",
marks=[
pytest.mark.skipif(
transformers.__version__.startswith("4.46"),
transformers.__version__ < "4.46.2",
reason="Model broken in HF, see huggingface/transformers#34379"
)
),
]
),
"fuyu": VLMTestInfo(
@ -185,8 +185,8 @@ VLM_TEST_SETTINGS = {
max_num_seqs=2,
dtype="bfloat16",
get_stop_token_ids=lambda tok: [151329, 151336, 151338],
marks=[large_gpu_mark(min_gb=48)],
patch_hf_runner=model_utils.glm_patch_hf_runner,
marks=[large_gpu_mark(min_gb=48)],
),
"h2ovl": VLMTestInfo(
models = [
@ -205,6 +205,22 @@ VLM_TEST_SETTINGS = {
use_tokenizer_eos=True,
patch_hf_runner=model_utils.h2ovl_patch_hf_runner,
),
"idefics3": VLMTestInfo(
models=["HuggingFaceM4/Idefics3-8B-Llama3"],
test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
prompt_formatter=lambda img_prompt:f"<|begin_of_text|>User:{img_prompt}<end_of_utterance>\nAssistant:", # noqa: E501
img_idx_to_prompt=lambda idx: "<image>",
max_model_len=8192,
max_num_seqs=2,
auto_cls=AutoModelForVision2Seq,
marks=[
pytest.mark.skipif(
transformers.__version__ < "4.46.0",
reason="Model introduced in HF >= 4.46.0"
),
large_gpu_mark(min_gb=48),
],
),
"intern_vl": VLMTestInfo(
models=[
"OpenGVLab/InternVL2-1B",
@ -263,7 +279,6 @@ VLM_TEST_SETTINGS = {
runner_mm_key="videos",
)],
),
# FIXME
"llava_next_video": VLMTestInfo(
models=["llava-hf/LLaVA-NeXT-Video-7B-hf"],
test_type=VLMTestType.VIDEO,
@ -275,7 +290,7 @@ VLM_TEST_SETTINGS = {
image_sizes=[((1669, 2560), (2560, 1669), (183, 488), (488, 183))],
marks=[
pytest.mark.skipif(
transformers.__version__.startswith("4.46"),
transformers.__version__ < "4.46.2",
reason="Model broken with changes in transformers 4.46"
)
],
@ -316,6 +331,7 @@ VLM_TEST_SETTINGS = {
max_model_len=8192,
max_num_seqs=2,
auto_cls=AutoModelForVision2Seq,
marks=[large_gpu_mark(min_gb=48)],
),
"qwen": VLMTestInfo(
models=["Qwen/Qwen-VL"],
@ -327,22 +343,6 @@ VLM_TEST_SETTINGS = {
vllm_output_post_proc=model_utils.qwen_vllm_to_hf_output,
prompt_path_encoder=model_utils.qwen_prompt_path_encoder,
),
"idefics3": VLMTestInfo(
models=["HuggingFaceM4/Idefics3-8B-Llama3"],
test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
prompt_formatter=lambda img_prompt:f"<|begin_of_text|>User:{img_prompt}<end_of_utterance>\nAssistant:", # noqa: E501
img_idx_to_prompt=lambda idx: "<image>",
max_model_len=8192,
max_num_seqs=2,
auto_cls=AutoModelForVision2Seq,
marks=[
pytest.mark.skipif(
transformers.__version__ < "4.46.0",
reason="Model introduced in HF >= 4.46.0"
),
large_gpu_mark(min_gb=48),
],
),
### Tensor parallel / multi-gpu broadcast tests
"broadcast-chameleon": VLMTestInfo(
models=["facebook/chameleon-7b"],
@ -362,7 +362,7 @@ VLM_TEST_SETTINGS = {
reason="Need at least 2 GPUs to run the test.",
),
pytest.mark.skipif(
transformers.__version__.startswith("4.46"),
transformers.__version__ < "4.46.2",
reason="Model broken in HF, see huggingface/transformers#34379"
)
],

View File

@ -1,7 +1,8 @@
import copy
import enum
import json
import warnings
from dataclasses import dataclass, field
from dataclasses import dataclass, field, replace
from typing import (TYPE_CHECKING, Any, ClassVar, Dict, Final, List, Literal,
Mapping, Optional, Set, Tuple, Type, Union)
@ -2078,6 +2079,12 @@ class VllmConfig:
return quant_config
return None
def with_hf_config(self, hf_config: PretrainedConfig) -> "VllmConfig":
model_config = copy.deepcopy(self.model_config)
model_config.hf_config = hf_config
return replace(self, model_config=model_config)
def __post_init__(self):
"""Verify configs are valid & consistent with each other.
"""

View File

@ -229,7 +229,6 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None:
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config
self.config = config
@ -246,9 +245,8 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
quant_config=quant_config,
gather_output=True,
)
self.language_model = PersimmonForCausalLM(config.text_config,
cache_config=cache_config,
quant_config=quant_config)
self.language_model = PersimmonForCausalLM(
vllm_config.with_hf_config(config.text_config))
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors)

View File

@ -164,10 +164,12 @@ class InternLM2VEForCausalLM(InternLM2ForCausalLM):
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
super().__init__(vllm_config, prefix=prefix)
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
super().__init__(config, cache_config, quant_config)
self.model = InternLM2VEModel(config,
cache_config,
quant_config,

View File

@ -241,11 +241,11 @@ def init_vllm_registered_model(
based on the arguments passed to the outer vLLM model.
"""
model_class, _ = ModelRegistry.resolve_model_cls(hf_config.architectures)
import copy
copied_config = copy.deepcopy(vllm_config)
copied_config.model_config.hf_config = hf_config
return model_class(vllm_config=copied_config, prefix=prefix)
return model_class(
vllm_config=vllm_config.with_hf_config(hf_config),
prefix=prefix,
)
@overload